tokio_tungstenite/
stream.rs1use std::{
7 pin::Pin,
8 task::{Context, Poll},
9};
10
11use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
12
/// A byte stream that may or may not be wrapped in TLS.
///
/// Only the `Plain` variant is always present; the TLS variants exist only when
/// their crate feature is enabled. Because the enum is `#[non_exhaustive]`,
/// code outside this crate must include a wildcard arm when matching on it.
#[non_exhaustive]
#[derive(Debug)]
pub enum MaybeTlsStream<S> {
    /// Unencrypted stream, used as-is.
    Plain(S),
    /// Stream encrypted through `tokio_native_tls` (feature `native-tls`).
    #[cfg(feature = "native-tls")]
    NativeTls(tokio_native_tls::TlsStream<S>),
    /// Client-side stream encrypted through `tokio_rustls` (feature `__rustls-tls`).
    #[cfg(feature = "__rustls-tls")]
    Rustls(tokio_rustls::client::TlsStream<S>),
}
26
27impl<S> MaybeTlsStream<S> {
28 pub fn get_ref(&self) -> &S {
30 match self {
31 MaybeTlsStream::Plain(s) => s,
32 #[cfg(feature = "native-tls")]
33 MaybeTlsStream::NativeTls(s) => s.get_ref().get_ref().get_ref(),
34 #[cfg(feature = "__rustls-tls")]
35 MaybeTlsStream::Rustls(s) => s.get_ref().0,
36 }
37 }
38}
39
40impl<S: AsyncRead + AsyncWrite + Unpin> AsyncRead for MaybeTlsStream<S> {
41 fn poll_read(
42 self: Pin<&mut Self>,
43 cx: &mut Context<'_>,
44 buf: &mut ReadBuf<'_>,
45 ) -> Poll<std::io::Result<()>> {
46 match self.get_mut() {
47 MaybeTlsStream::Plain(ref mut s) => Pin::new(s).poll_read(cx, buf),
48 #[cfg(feature = "native-tls")]
49 MaybeTlsStream::NativeTls(s) => Pin::new(s).poll_read(cx, buf),
50 #[cfg(feature = "__rustls-tls")]
51 MaybeTlsStream::Rustls(s) => Pin::new(s).poll_read(cx, buf),
52 }
53 }
54}
55
56impl<S: AsyncRead + AsyncWrite + Unpin> AsyncWrite for MaybeTlsStream<S> {
57 fn poll_write(
58 self: Pin<&mut Self>,
59 cx: &mut Context<'_>,
60 buf: &[u8],
61 ) -> Poll<Result<usize, std::io::Error>> {
62 match self.get_mut() {
63 MaybeTlsStream::Plain(ref mut s) => Pin::new(s).poll_write(cx, buf),
64 #[cfg(feature = "native-tls")]
65 MaybeTlsStream::NativeTls(s) => Pin::new(s).poll_write(cx, buf),
66 #[cfg(feature = "__rustls-tls")]
67 MaybeTlsStream::Rustls(s) => Pin::new(s).poll_write(cx, buf),
68 }
69 }
70
71 fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), std::io::Error>> {
72 match self.get_mut() {
73 MaybeTlsStream::Plain(ref mut s) => Pin::new(s).poll_flush(cx),
74 #[cfg(feature = "native-tls")]
75 MaybeTlsStream::NativeTls(s) => Pin::new(s).poll_flush(cx),
76 #[cfg(feature = "__rustls-tls")]
77 MaybeTlsStream::Rustls(s) => Pin::new(s).poll_flush(cx),
78 }
79 }
80
81 fn poll_shutdown(
82 self: Pin<&mut Self>,
83 cx: &mut Context<'_>,
84 ) -> Poll<Result<(), std::io::Error>> {
85 match self.get_mut() {
86 MaybeTlsStream::Plain(ref mut s) => Pin::new(s).poll_shutdown(cx),
87 #[cfg(feature = "native-tls")]
88 MaybeTlsStream::NativeTls(s) => Pin::new(s).poll_shutdown(cx),
89 #[cfg(feature = "__rustls-tls")]
90 MaybeTlsStream::Rustls(s) => Pin::new(s).poll_shutdown(cx),
91 }
92 }
93}