tungstenite/
client.rs

1//! Methods to connect to a WebSocket as a client.
2
3use std::{
4    io::{Read, Write},
5    net::{SocketAddr, TcpStream, ToSocketAddrs},
6    result::Result as StdResult,
7};
8
9use http::{request::Parts, HeaderName, Uri};
10use log::*;
11
12use crate::{
13    handshake::client::{generate_key, Request, Response},
14    protocol::WebSocketConfig,
15    stream::MaybeTlsStream,
16};
17
18use crate::{
19    error::{Error, Result, UrlError},
20    handshake::{client::ClientHandshake, HandshakeError},
21    protocol::WebSocket,
22    stream::{Mode, NoDelay},
23};
24
25/// Connect to the given WebSocket in blocking mode.
26///
27/// Uses a websocket configuration passed as an argument to the function. Calling it with `None` is
28/// equal to calling `connect()` function.
29///
30/// The URL may be either ws:// or wss://.
31/// To support wss:// URLs, you must activate the TLS feature on the crate level. Please refer to the
32/// project's [README][readme] for more information on available features.
33///
34/// This function "just works" for those who wants a simple blocking solution
35/// similar to `std::net::TcpStream`. If you want a non-blocking or other
36/// custom stream, call `client` instead.
37///
38/// This function uses `native_tls` or `rustls` to do TLS depending on the feature flags enabled. If
39/// you want to use other TLS libraries, use `client` instead. There is no need to enable any of
40/// the `*-tls` features if you don't call `connect` since it's the only function that uses them.
41///
42/// [readme]: https://github.com/snapview/tungstenite-rs/#features
43pub fn connect_with_config<Req: IntoClientRequest>(
44    request: Req,
45    config: Option<WebSocketConfig>,
46    max_redirects: u8,
47) -> Result<(WebSocket<MaybeTlsStream<TcpStream>>, Response)> {
48    fn try_client_handshake(
49        request: Request,
50        config: Option<WebSocketConfig>,
51    ) -> Result<(WebSocket<MaybeTlsStream<TcpStream>>, Response)> {
52        let uri = request.uri();
53        let mode = uri_mode(uri)?;
54
55        #[cfg(not(any(feature = "native-tls", feature = "__rustls-tls")))]
56        if let Mode::Tls = mode {
57            return Err(Error::Url(UrlError::TlsFeatureNotEnabled));
58        }
59
60        let host = request.uri().host().ok_or(Error::Url(UrlError::NoHostName))?;
61        let host = if host.starts_with('[') { &host[1..host.len() - 1] } else { host };
62        let port = uri.port_u16().unwrap_or(match mode {
63            Mode::Plain => 80,
64            Mode::Tls => 443,
65        });
66        let addrs = (host, port).to_socket_addrs()?;
67        let mut stream = connect_to_some(addrs.as_slice(), request.uri())?;
68        NoDelay::set_nodelay(&mut stream, true)?;
69
70        #[cfg(not(any(feature = "native-tls", feature = "__rustls-tls")))]
71        let client = client_with_config(request, MaybeTlsStream::Plain(stream), config);
72        #[cfg(any(feature = "native-tls", feature = "__rustls-tls"))]
73        let client = crate::tls::client_tls_with_config(request, stream, config, None);
74
75        client.map_err(|e| match e {
76            HandshakeError::Failure(f) => f,
77            HandshakeError::Interrupted(_) => panic!("Bug: blocking handshake not blocked"),
78        })
79    }
80
81    fn create_request(parts: &Parts, uri: &Uri) -> Request {
82        let mut builder =
83            Request::builder().uri(uri.clone()).method(parts.method.clone()).version(parts.version);
84        *builder.headers_mut().expect("Failed to create `Request`") = parts.headers.clone();
85        builder.body(()).expect("Failed to create `Request`")
86    }
87
88    let (parts, _) = request.into_client_request()?.into_parts();
89    let mut uri = parts.uri.clone();
90
91    for attempt in 0..=max_redirects {
92        let request = create_request(&parts, &uri);
93
94        match try_client_handshake(request, config) {
95            Err(Error::Http(res)) if res.status().is_redirection() && attempt < max_redirects => {
96                if let Some(location) = res.headers().get("Location") {
97                    uri = location.to_str()?.parse::<Uri>()?;
98                    debug!("Redirecting to {uri:?}");
99                    continue;
100                } else {
101                    warn!("No `Location` found in redirect");
102                    return Err(Error::Http(res));
103                }
104            }
105            other => return other,
106        }
107    }
108
109    unreachable!("Bug in a redirect handling logic")
110}
111
112/// Connect to the given WebSocket in blocking mode.
113///
114/// The URL may be either ws:// or wss://.
115/// To support wss:// URLs, feature `native-tls` or `rustls-tls` must be turned on.
116///
117/// This function "just works" for those who wants a simple blocking solution
118/// similar to `std::net::TcpStream`. If you want a non-blocking or other
119/// custom stream, call `client` instead.
120///
121/// This function uses `native_tls` or `rustls` to do TLS depending on the feature flags enabled. If
122/// you want to use other TLS libraries, use `client` instead. There is no need to enable any of
123/// the `*-tls` features if you don't call `connect` since it's the only function that uses them.
124pub fn connect<Req: IntoClientRequest>(
125    request: Req,
126) -> Result<(WebSocket<MaybeTlsStream<TcpStream>>, Response)> {
127    connect_with_config(request, None, 3)
128}
129
130fn connect_to_some(addrs: &[SocketAddr], uri: &Uri) -> Result<TcpStream> {
131    for addr in addrs {
132        debug!("Trying to contact {uri} at {addr}...");
133        if let Ok(stream) = TcpStream::connect(addr) {
134            return Ok(stream);
135        }
136    }
137    Err(Error::Url(UrlError::UnableToConnect(uri.to_string())))
138}
139
140/// Get the mode of the given URL.
141///
142/// This function may be used to ease the creation of custom TLS streams
143/// in non-blocking algorithms or for use with TLS libraries other than `native_tls` or `rustls`.
144pub fn uri_mode(uri: &Uri) -> Result<Mode> {
145    match uri.scheme_str() {
146        Some("ws") => Ok(Mode::Plain),
147        Some("wss") => Ok(Mode::Tls),
148        _ => Err(Error::Url(UrlError::UnsupportedUrlScheme)),
149    }
150}
151
152/// Do the client handshake over the given stream given a web socket configuration. Passing `None`
153/// as configuration is equal to calling `client()` function.
154///
155/// Use this function if you need a nonblocking handshake support or if you
156/// want to use a custom stream like `mio::net::TcpStream` or `openssl::ssl::SslStream`.
157/// Any stream supporting `Read + Write` will do.
158pub fn client_with_config<Stream, Req>(
159    request: Req,
160    stream: Stream,
161    config: Option<WebSocketConfig>,
162) -> StdResult<(WebSocket<Stream>, Response), HandshakeError<ClientHandshake<Stream>>>
163where
164    Stream: Read + Write,
165    Req: IntoClientRequest,
166{
167    ClientHandshake::start(stream, request.into_client_request()?, config)?.handshake()
168}
169
170/// Do the client handshake over the given stream.
171///
172/// Use this function if you need a nonblocking handshake support or if you
173/// want to use a custom stream like `mio::net::TcpStream` or `openssl::ssl::SslStream`.
174/// Any stream supporting `Read + Write` will do.
175pub fn client<Stream, Req>(
176    request: Req,
177    stream: Stream,
178) -> StdResult<(WebSocket<Stream>, Response), HandshakeError<ClientHandshake<Stream>>>
179where
180    Stream: Read + Write,
181    Req: IntoClientRequest,
182{
183    client_with_config(request, stream, None)
184}
185
186/// Trait for converting various types into HTTP requests used for a client connection.
187///
188/// This trait is implemented by default for string slices, strings, `http::Uri` and
189/// `http::Request<()>`. Note that the implementation for `http::Request<()>` is trivial and will
190/// simply take your request and pass it as is further without altering any headers or URLs, so
191/// be aware of this. If you just want to connect to the endpoint with a certain URL, better pass
192/// a regular string containing the URL in which case `tungstenite-rs` will take care for generating
193/// the proper `http::Request<()>` for you.
194pub trait IntoClientRequest {
195    /// Convert into a `Request` that can be used for a client connection.
196    fn into_client_request(self) -> Result<Request>;
197}
198
199impl IntoClientRequest for &str {
200    fn into_client_request(self) -> Result<Request> {
201        self.parse::<Uri>()?.into_client_request()
202    }
203}
204
205impl IntoClientRequest for &String {
206    fn into_client_request(self) -> Result<Request> {
207        <&str as IntoClientRequest>::into_client_request(self)
208    }
209}
210
211impl IntoClientRequest for String {
212    fn into_client_request(self) -> Result<Request> {
213        <&str as IntoClientRequest>::into_client_request(&self)
214    }
215}
216
217impl IntoClientRequest for &Uri {
218    fn into_client_request(self) -> Result<Request> {
219        self.clone().into_client_request()
220    }
221}
222
223impl IntoClientRequest for Uri {
224    fn into_client_request(self) -> Result<Request> {
225        let authority = self.authority().ok_or(Error::Url(UrlError::NoHostName))?.as_str();
226        let host = authority
227            .find('@')
228            .map(|idx| authority.split_at(idx + 1).1)
229            .unwrap_or_else(|| authority);
230
231        if host.is_empty() {
232            return Err(Error::Url(UrlError::EmptyHostName));
233        }
234
235        let req = Request::builder()
236            .method("GET")
237            .header("Host", host)
238            .header("Connection", "Upgrade")
239            .header("Upgrade", "websocket")
240            .header("Sec-WebSocket-Version", "13")
241            .header("Sec-WebSocket-Key", generate_key())
242            .uri(self)
243            .body(())?;
244        Ok(req)
245    }
246}
247
#[cfg(feature = "url")]
impl IntoClientRequest for &url::Url {
    /// Converts the URL through its serialized string form, like the `&str` implementation.
    fn into_client_request(self) -> Result<Request> {
        <&str as IntoClientRequest>::into_client_request(self.as_str())
    }
}
254
#[cfg(feature = "url")]
impl IntoClientRequest for url::Url {
    /// Converts the URL through its serialized string form, like the `&str` implementation.
    fn into_client_request(self) -> Result<Request> {
        let serialized: &str = self.as_str();
        serialized.into_client_request()
    }
}
261
262impl IntoClientRequest for Request {
263    fn into_client_request(self) -> Result<Request> {
264        Ok(self)
265    }
266}
267
268impl IntoClientRequest for httparse::Request<'_, '_> {
269    fn into_client_request(self) -> Result<Request> {
270        use crate::handshake::headers::FromHttparse;
271        Request::from_httparse(self)
272    }
273}
274
275/// Builder for a custom [`IntoClientRequest`] with options to add
276/// custom additional headers and sub protocols.
277///
278/// # Example
279///
280/// ```rust no_run
281/// # use crate::*;
282/// use http::Uri;
283/// use tungstenite::{connect, ClientRequestBuilder};
284///
285/// let uri: Uri = "ws://localhost:3012/socket".parse().unwrap();
286/// let token = "my_jwt_token";
287/// let builder = ClientRequestBuilder::new(uri)
288///     .with_header("Authorization", format!("Bearer {token}"))
289///     .with_sub_protocol("my_sub_protocol");
290/// let socket = connect(builder).unwrap();
291/// ```
292#[derive(Debug, Clone)]
293pub struct ClientRequestBuilder {
294    uri: Uri,
295    /// Additional [`Request`] handshake headers
296    additional_headers: Vec<(String, String)>,
297    /// Handsake subprotocols
298    subprotocols: Vec<String>,
299}
300
301impl ClientRequestBuilder {
302    /// Initializes an empty request builder
303    #[must_use]
304    pub const fn new(uri: Uri) -> Self {
305        Self { uri, additional_headers: Vec::new(), subprotocols: Vec::new() }
306    }
307
308    /// Adds (`key`, `value`) as an additional header to the handshake request
309    pub fn with_header<K, V>(mut self, key: K, value: V) -> Self
310    where
311        K: Into<String>,
312        V: Into<String>,
313    {
314        self.additional_headers.push((key.into(), value.into()));
315        self
316    }
317
318    /// Adds `protocol` to the handshake request subprotocols (`Sec-WebSocket-Protocol`)
319    pub fn with_sub_protocol<P>(mut self, protocol: P) -> Self
320    where
321        P: Into<String>,
322    {
323        self.subprotocols.push(protocol.into());
324        self
325    }
326}
327
328impl IntoClientRequest for ClientRequestBuilder {
329    fn into_client_request(self) -> Result<Request> {
330        let mut request = self.uri.into_client_request()?;
331        let headers = request.headers_mut();
332        for (k, v) in self.additional_headers {
333            let key = HeaderName::try_from(k)?;
334            let value = v.parse()?;
335            headers.append(key, value);
336        }
337        if !self.subprotocols.is_empty() {
338            let protocols = self.subprotocols.join(", ").parse()?;
339            headers.append("Sec-WebSocket-Protocol", protocols);
340        }
341        Ok(request)
342    }
343}