// tungstenite/stream.rs

//! Convenience wrapper for streams to switch between plain TCP and TLS at runtime.
//!
//! The wrapper does not depend on any particular TLS implementation: anything such as
//! `native_tls` or `openssl` will work as long as it provides a TLS stream implementing the
//! standard `Read + Write` traits.
6
7#[cfg(feature = "__rustls-tls")]
8use std::ops::Deref;
9use std::{
10    fmt::{self, Debug},
11    io::{Read, Result as IoResult, Write},
12};
13
14use std::net::TcpStream;
15
16#[cfg(feature = "native-tls")]
17use native_tls_crate::TlsStream;
18#[cfg(feature = "__rustls-tls")]
19use rustls::StreamOwned;
20
/// Stream mode, either plain TCP or TLS.
///
/// The mode corresponds to the URL scheme: `ws://` maps to [`Mode::Plain`] and `wss://`
/// maps to [`Mode::Tls`]. The type is `Copy` and supports equality and hashing, so callers
/// can compare modes with `==` or use them as map/set keys.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub enum Mode {
    /// Plain mode (`ws://` URL).
    Plain,
    /// TLS mode (`wss://` URL).
    Tls,
}
29
/// Trait to switch TCP_NODELAY.
///
/// Implementors either toggle the option on a socket they own or forward the request to
/// the stream they wrap.
pub trait NoDelay {
    /// Set the TCP_NODELAY option to the given value.
    ///
    /// # Errors
    ///
    /// Returns whatever I/O error the implementation reports while changing the option.
    fn set_nodelay(&mut self, nodelay: bool) -> std::io::Result<()>;
}
35
36impl NoDelay for TcpStream {
37    fn set_nodelay(&mut self, nodelay: bool) -> IoResult<()> {
38        TcpStream::set_nodelay(self, nodelay)
39    }
40}
41
#[cfg(feature = "native-tls")]
impl<S> NoDelay for TlsStream<S>
where
    S: Read + Write + NoDelay,
{
    /// Forwards to the `NoDelay` impl of the stream wrapped by this `native-tls` stream,
    /// reached through `get_mut`.
    fn set_nodelay(&mut self, nodelay: bool) -> IoResult<()> {
        let socket = self.get_mut();
        socket.set_nodelay(nodelay)
    }
}
48
#[cfg(feature = "__rustls-tls")]
impl<S, SD, T> NoDelay for StreamOwned<S, T>
where
    S: Deref<Target = rustls::ConnectionCommon<SD>>,
    SD: rustls::SideData,
    T: Read + Write + NoDelay,
{
    /// Forwards to the `NoDelay` impl of the socket (`sock` field) owned by this `rustls`
    /// stream.
    fn set_nodelay(&mut self, nodelay: bool) -> IoResult<()> {
        NoDelay::set_nodelay(&mut self.sock, nodelay)
    }
}
60
/// A stream that might be protected with TLS.
///
/// [`MaybeTlsStream::Plain`] is always available. The TLS variants exist only when the
/// corresponding cargo feature (`native-tls` or `__rustls-tls`) is enabled. The enum is
/// `#[non_exhaustive]`, so code outside this crate must include a wildcard arm when matching.
#[non_exhaustive]
// NOTE(review): the lint is silenced presumably because the TLS variants are much larger than
// `Plain` and boxing them was judged not worthwhile — confirm the rationale.
#[allow(clippy::large_enum_variant)]
pub enum MaybeTlsStream<S: Read + Write> {
    /// Unencrypted socket stream.
    Plain(S),
    #[cfg(feature = "native-tls")]
    /// Encrypted socket stream using `native-tls`.
    NativeTls(native_tls_crate::TlsStream<S>),
    #[cfg(feature = "__rustls-tls")]
    /// Encrypted socket stream using `rustls`.
    ///
    /// Wraps a client-side `rustls::ClientConnection` together with the socket `S`.
    Rustls(rustls::StreamOwned<rustls::ClientConnection, S>),
}
74
75impl<S: Read + Write + Debug> Debug for MaybeTlsStream<S> {
76    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
77        match self {
78            Self::Plain(s) => f.debug_tuple("MaybeTlsStream::Plain").field(s).finish(),
79            #[cfg(feature = "native-tls")]
80            Self::NativeTls(s) => f.debug_tuple("MaybeTlsStream::NativeTls").field(s).finish(),
81            #[cfg(feature = "__rustls-tls")]
82            Self::Rustls(s) => {
83                struct RustlsStreamDebug<'a, S: Read + Write>(
84                    &'a rustls::StreamOwned<rustls::ClientConnection, S>,
85                );
86
87                impl<S: Read + Write + Debug> Debug for RustlsStreamDebug<'_, S> {
88                    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
89                        f.debug_struct("StreamOwned")
90                            .field("conn", &self.0.conn)
91                            .field("sock", &self.0.sock)
92                            .finish()
93                    }
94                }
95
96                f.debug_tuple("MaybeTlsStream::Rustls").field(&RustlsStreamDebug(s)).finish()
97            }
98        }
99    }
100}
101
102impl<S: Read + Write> Read for MaybeTlsStream<S> {
103    fn read(&mut self, buf: &mut [u8]) -> IoResult<usize> {
104        match *self {
105            MaybeTlsStream::Plain(ref mut s) => s.read(buf),
106            #[cfg(feature = "native-tls")]
107            MaybeTlsStream::NativeTls(ref mut s) => s.read(buf),
108            #[cfg(feature = "__rustls-tls")]
109            MaybeTlsStream::Rustls(ref mut s) => s.read(buf),
110        }
111    }
112}
113
114impl<S: Read + Write> Write for MaybeTlsStream<S> {
115    fn write(&mut self, buf: &[u8]) -> IoResult<usize> {
116        match *self {
117            MaybeTlsStream::Plain(ref mut s) => s.write(buf),
118            #[cfg(feature = "native-tls")]
119            MaybeTlsStream::NativeTls(ref mut s) => s.write(buf),
120            #[cfg(feature = "__rustls-tls")]
121            MaybeTlsStream::Rustls(ref mut s) => s.write(buf),
122        }
123    }
124
125    fn flush(&mut self) -> IoResult<()> {
126        match *self {
127            MaybeTlsStream::Plain(ref mut s) => s.flush(),
128            #[cfg(feature = "native-tls")]
129            MaybeTlsStream::NativeTls(ref mut s) => s.flush(),
130            #[cfg(feature = "__rustls-tls")]
131            MaybeTlsStream::Rustls(ref mut s) => s.flush(),
132        }
133    }
134}
135
136impl<S: Read + Write + NoDelay> NoDelay for MaybeTlsStream<S> {
137    fn set_nodelay(&mut self, nodelay: bool) -> IoResult<()> {
138        match *self {
139            MaybeTlsStream::Plain(ref mut s) => s.set_nodelay(nodelay),
140            #[cfg(feature = "native-tls")]
141            MaybeTlsStream::NativeTls(ref mut s) => s.set_nodelay(nodelay),
142            #[cfg(feature = "__rustls-tls")]
143            MaybeTlsStream::Rustls(ref mut s) => s.set_nodelay(nodelay),
144        }
145    }
146}