1use std::{
4 io::{Read, Write},
5 net::{SocketAddr, TcpStream, ToSocketAddrs},
6 result::Result as StdResult,
7};
8
9use http::{request::Parts, HeaderName, Uri};
10use log::*;
11
12use crate::{
13 handshake::client::{generate_key, Request, Response},
14 protocol::WebSocketConfig,
15 stream::MaybeTlsStream,
16};
17
18use crate::{
19 error::{Error, Result, UrlError},
20 handshake::{client::ClientHandshake, HandshakeError},
21 protocol::WebSocket,
22 stream::{Mode, NoDelay},
23};
24
25pub fn connect_with_config<Req: IntoClientRequest>(
44 request: Req,
45 config: Option<WebSocketConfig>,
46 max_redirects: u8,
47) -> Result<(WebSocket<MaybeTlsStream<TcpStream>>, Response)> {
48 fn try_client_handshake(
49 request: Request,
50 config: Option<WebSocketConfig>,
51 ) -> Result<(WebSocket<MaybeTlsStream<TcpStream>>, Response)> {
52 let uri = request.uri();
53 let mode = uri_mode(uri)?;
54
55 #[cfg(not(any(feature = "native-tls", feature = "__rustls-tls")))]
56 if let Mode::Tls = mode {
57 return Err(Error::Url(UrlError::TlsFeatureNotEnabled));
58 }
59
60 let host = request.uri().host().ok_or(Error::Url(UrlError::NoHostName))?;
61 let host = if host.starts_with('[') { &host[1..host.len() - 1] } else { host };
62 let port = uri.port_u16().unwrap_or(match mode {
63 Mode::Plain => 80,
64 Mode::Tls => 443,
65 });
66 let addrs = (host, port).to_socket_addrs()?;
67 let mut stream = connect_to_some(addrs.as_slice(), request.uri())?;
68 NoDelay::set_nodelay(&mut stream, true)?;
69
70 #[cfg(not(any(feature = "native-tls", feature = "__rustls-tls")))]
71 let client = client_with_config(request, MaybeTlsStream::Plain(stream), config);
72 #[cfg(any(feature = "native-tls", feature = "__rustls-tls"))]
73 let client = crate::tls::client_tls_with_config(request, stream, config, None);
74
75 client.map_err(|e| match e {
76 HandshakeError::Failure(f) => f,
77 HandshakeError::Interrupted(_) => panic!("Bug: blocking handshake not blocked"),
78 })
79 }
80
81 fn create_request(parts: &Parts, uri: &Uri) -> Request {
82 let mut builder =
83 Request::builder().uri(uri.clone()).method(parts.method.clone()).version(parts.version);
84 *builder.headers_mut().expect("Failed to create `Request`") = parts.headers.clone();
85 builder.body(()).expect("Failed to create `Request`")
86 }
87
88 let (parts, _) = request.into_client_request()?.into_parts();
89 let mut uri = parts.uri.clone();
90
91 for attempt in 0..=max_redirects {
92 let request = create_request(&parts, &uri);
93
94 match try_client_handshake(request, config) {
95 Err(Error::Http(res)) if res.status().is_redirection() && attempt < max_redirects => {
96 if let Some(location) = res.headers().get("Location") {
97 uri = location.to_str()?.parse::<Uri>()?;
98 debug!("Redirecting to {uri:?}");
99 continue;
100 } else {
101 warn!("No `Location` found in redirect");
102 return Err(Error::Http(res));
103 }
104 }
105 other => return other,
106 }
107 }
108
109 unreachable!("Bug in a redirect handling logic")
110}
111
112pub fn connect<Req: IntoClientRequest>(
125 request: Req,
126) -> Result<(WebSocket<MaybeTlsStream<TcpStream>>, Response)> {
127 connect_with_config(request, None, 3)
128}
129
130fn connect_to_some(addrs: &[SocketAddr], uri: &Uri) -> Result<TcpStream> {
131 for addr in addrs {
132 debug!("Trying to contact {uri} at {addr}...");
133 if let Ok(stream) = TcpStream::connect(addr) {
134 return Ok(stream);
135 }
136 }
137 Err(Error::Url(UrlError::UnableToConnect(uri.to_string())))
138}
139
140pub fn uri_mode(uri: &Uri) -> Result<Mode> {
145 match uri.scheme_str() {
146 Some("ws") => Ok(Mode::Plain),
147 Some("wss") => Ok(Mode::Tls),
148 _ => Err(Error::Url(UrlError::UnsupportedUrlScheme)),
149 }
150}
151
152pub fn client_with_config<Stream, Req>(
159 request: Req,
160 stream: Stream,
161 config: Option<WebSocketConfig>,
162) -> StdResult<(WebSocket<Stream>, Response), HandshakeError<ClientHandshake<Stream>>>
163where
164 Stream: Read + Write,
165 Req: IntoClientRequest,
166{
167 ClientHandshake::start(stream, request.into_client_request()?, config)?.handshake()
168}
169
170pub fn client<Stream, Req>(
176 request: Req,
177 stream: Stream,
178) -> StdResult<(WebSocket<Stream>, Response), HandshakeError<ClientHandshake<Stream>>>
179where
180 Stream: Read + Write,
181 Req: IntoClientRequest,
182{
183 client_with_config(request, stream, None)
184}
185
186pub trait IntoClientRequest {
195 fn into_client_request(self) -> Result<Request>;
197}
198
199impl IntoClientRequest for &str {
200 fn into_client_request(self) -> Result<Request> {
201 self.parse::<Uri>()?.into_client_request()
202 }
203}
204
205impl IntoClientRequest for &String {
206 fn into_client_request(self) -> Result<Request> {
207 <&str as IntoClientRequest>::into_client_request(self)
208 }
209}
210
211impl IntoClientRequest for String {
212 fn into_client_request(self) -> Result<Request> {
213 <&str as IntoClientRequest>::into_client_request(&self)
214 }
215}
216
217impl IntoClientRequest for &Uri {
218 fn into_client_request(self) -> Result<Request> {
219 self.clone().into_client_request()
220 }
221}
222
223impl IntoClientRequest for Uri {
224 fn into_client_request(self) -> Result<Request> {
225 let authority = self.authority().ok_or(Error::Url(UrlError::NoHostName))?.as_str();
226 let host = authority
227 .find('@')
228 .map(|idx| authority.split_at(idx + 1).1)
229 .unwrap_or_else(|| authority);
230
231 if host.is_empty() {
232 return Err(Error::Url(UrlError::EmptyHostName));
233 }
234
235 let req = Request::builder()
236 .method("GET")
237 .header("Host", host)
238 .header("Connection", "Upgrade")
239 .header("Upgrade", "websocket")
240 .header("Sec-WebSocket-Version", "13")
241 .header("Sec-WebSocket-Key", generate_key())
242 .uri(self)
243 .body(())?;
244 Ok(req)
245 }
246}
247
#[cfg(feature = "url")]
impl IntoClientRequest for &url::Url {
    /// Converts via the URL's string form, which is re-parsed as a [`Uri`].
    fn into_client_request(self) -> Result<Request> {
        IntoClientRequest::into_client_request(self.as_str())
    }
}
254
#[cfg(feature = "url")]
impl IntoClientRequest for url::Url {
    /// Converts via the URL's string form, which is re-parsed as a [`Uri`].
    fn into_client_request(self) -> Result<Request> {
        IntoClientRequest::into_client_request(self.as_str())
    }
}
261
262impl IntoClientRequest for Request {
263 fn into_client_request(self) -> Result<Request> {
264 Ok(self)
265 }
266}
267
268impl IntoClientRequest for httparse::Request<'_, '_> {
269 fn into_client_request(self) -> Result<Request> {
270 use crate::handshake::headers::FromHttparse;
271 Request::from_httparse(self)
272 }
273}
274
275#[derive(Debug, Clone)]
293pub struct ClientRequestBuilder {
294 uri: Uri,
295 additional_headers: Vec<(String, String)>,
297 subprotocols: Vec<String>,
299}
300
301impl ClientRequestBuilder {
302 #[must_use]
304 pub const fn new(uri: Uri) -> Self {
305 Self { uri, additional_headers: Vec::new(), subprotocols: Vec::new() }
306 }
307
308 pub fn with_header<K, V>(mut self, key: K, value: V) -> Self
310 where
311 K: Into<String>,
312 V: Into<String>,
313 {
314 self.additional_headers.push((key.into(), value.into()));
315 self
316 }
317
318 pub fn with_sub_protocol<P>(mut self, protocol: P) -> Self
320 where
321 P: Into<String>,
322 {
323 self.subprotocols.push(protocol.into());
324 self
325 }
326}
327
328impl IntoClientRequest for ClientRequestBuilder {
329 fn into_client_request(self) -> Result<Request> {
330 let mut request = self.uri.into_client_request()?;
331 let headers = request.headers_mut();
332 for (k, v) in self.additional_headers {
333 let key = HeaderName::try_from(k)?;
334 let value = v.parse()?;
335 headers.append(key, value);
336 }
337 if !self.subprotocols.is_empty() {
338 let protocols = self.subprotocols.join(", ").parse()?;
339 headers.append("Sec-WebSocket-Protocol", protocols);
340 }
341 Ok(request)
342 }
343}