Skip to main content

tokio_tungstenite/
stream.rs

//! Convenience wrapper for streams to switch between plain TCP and TLS at runtime.
//!
//! Each TLS backend is optional: the `native-tls` and `rustls` variants of
//! [`MaybeTlsStream`] are compiled in only when the corresponding Cargo feature is enabled.
//! Whichever stream the wrapper holds, it forwards Tokio's `AsyncRead + AsyncWrite` traits to it.
6use std::{
7    pin::Pin,
8    task::{Context, Poll},
9};
10
11use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
12
/// Either a plain socket stream or that same stream wrapped in a TLS session.
///
/// The set of TLS variants depends on the enabled Cargo features; with none of them
/// enabled only [`MaybeTlsStream::Plain`] exists. Because the enum is
/// `#[non_exhaustive]`, code outside this crate must include a wildcard arm when matching.
#[derive(Debug)]
#[non_exhaustive]
pub enum MaybeTlsStream<S> {
    /// The socket stream as-is, with no encryption layer.
    Plain(S),
    /// The socket stream wrapped by `native-tls` (through `tokio-native-tls`).
    #[cfg(feature = "native-tls")]
    NativeTls(tokio_native_tls::TlsStream<S>),
    /// The socket stream wrapped in a client-side `rustls` session (through `tokio-rustls`).
    #[cfg(feature = "__rustls-tls")]
    Rustls(tokio_rustls::client::TlsStream<S>),
}
26
27impl<S> MaybeTlsStream<S> {
28    /// Returns a shared reference to the inner stream.
29    pub fn get_ref(&self) -> &S {
30        match self {
31            MaybeTlsStream::Plain(s) => s,
32            #[cfg(feature = "native-tls")]
33            MaybeTlsStream::NativeTls(s) => s.get_ref().get_ref().get_ref(),
34            #[cfg(feature = "__rustls-tls")]
35            MaybeTlsStream::Rustls(s) => s.get_ref().0,
36        }
37    }
38}
39
40impl<S: AsyncRead + AsyncWrite + Unpin> AsyncRead for MaybeTlsStream<S> {
41    fn poll_read(
42        self: Pin<&mut Self>,
43        cx: &mut Context<'_>,
44        buf: &mut ReadBuf<'_>,
45    ) -> Poll<std::io::Result<()>> {
46        match self.get_mut() {
47            MaybeTlsStream::Plain(ref mut s) => Pin::new(s).poll_read(cx, buf),
48            #[cfg(feature = "native-tls")]
49            MaybeTlsStream::NativeTls(s) => Pin::new(s).poll_read(cx, buf),
50            #[cfg(feature = "__rustls-tls")]
51            MaybeTlsStream::Rustls(s) => Pin::new(s).poll_read(cx, buf),
52        }
53    }
54}
55
56impl<S: AsyncRead + AsyncWrite + Unpin> AsyncWrite for MaybeTlsStream<S> {
57    fn poll_write(
58        self: Pin<&mut Self>,
59        cx: &mut Context<'_>,
60        buf: &[u8],
61    ) -> Poll<Result<usize, std::io::Error>> {
62        match self.get_mut() {
63            MaybeTlsStream::Plain(ref mut s) => Pin::new(s).poll_write(cx, buf),
64            #[cfg(feature = "native-tls")]
65            MaybeTlsStream::NativeTls(s) => Pin::new(s).poll_write(cx, buf),
66            #[cfg(feature = "__rustls-tls")]
67            MaybeTlsStream::Rustls(s) => Pin::new(s).poll_write(cx, buf),
68        }
69    }
70
71    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), std::io::Error>> {
72        match self.get_mut() {
73            MaybeTlsStream::Plain(ref mut s) => Pin::new(s).poll_flush(cx),
74            #[cfg(feature = "native-tls")]
75            MaybeTlsStream::NativeTls(s) => Pin::new(s).poll_flush(cx),
76            #[cfg(feature = "__rustls-tls")]
77            MaybeTlsStream::Rustls(s) => Pin::new(s).poll_flush(cx),
78        }
79    }
80
81    fn poll_shutdown(
82        self: Pin<&mut Self>,
83        cx: &mut Context<'_>,
84    ) -> Poll<Result<(), std::io::Error>> {
85        match self.get_mut() {
86            MaybeTlsStream::Plain(ref mut s) => Pin::new(s).poll_shutdown(cx),
87            #[cfg(feature = "native-tls")]
88            MaybeTlsStream::NativeTls(s) => Pin::new(s).poll_shutdown(cx),
89            #[cfg(feature = "__rustls-tls")]
90            MaybeTlsStream::Rustls(s) => Pin::new(s).poll_shutdown(cx),
91        }
92    }
93}