// tokio_postgres/maybe_tls_stream.rs

1use crate::tls::{ChannelBinding, TlsStream};
2use std::io;
3use std::pin::Pin;
4use std::task::{Context, Poll};
5use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
6
/// A byte stream that is either used directly (`Raw`) or through a TLS
/// wrapper (`Tls`).
///
/// The `AsyncRead`, `AsyncWrite` and `TlsStream` impls for this type simply
/// forward each call to whichever variant is present.
#[derive(Debug)]
pub enum MaybeTlsStream<S, T> {
    /// The underlying stream with no TLS layer; it reports no channel binding.
    Raw(S),
    /// A TLS-wrapped stream.
    Tls(T),
}
11
12impl<S, T> AsyncRead for MaybeTlsStream<S, T>
13where
14    S: AsyncRead + Unpin,
15    T: AsyncRead + Unpin,
16{
17    fn poll_read(
18        mut self: Pin<&mut Self>,
19        cx: &mut Context<'_>,
20        buf: &mut ReadBuf<'_>,
21    ) -> Poll<io::Result<()>> {
22        match &mut *self {
23            MaybeTlsStream::Raw(s) => Pin::new(s).poll_read(cx, buf),
24            MaybeTlsStream::Tls(s) => Pin::new(s).poll_read(cx, buf),
25        }
26    }
27}
28
29impl<S, T> AsyncWrite for MaybeTlsStream<S, T>
30where
31    S: AsyncWrite + Unpin,
32    T: AsyncWrite + Unpin,
33{
34    fn poll_write(
35        mut self: Pin<&mut Self>,
36        cx: &mut Context<'_>,
37        buf: &[u8],
38    ) -> Poll<io::Result<usize>> {
39        match &mut *self {
40            MaybeTlsStream::Raw(s) => Pin::new(s).poll_write(cx, buf),
41            MaybeTlsStream::Tls(s) => Pin::new(s).poll_write(cx, buf),
42        }
43    }
44
45    fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
46        match &mut *self {
47            MaybeTlsStream::Raw(s) => Pin::new(s).poll_flush(cx),
48            MaybeTlsStream::Tls(s) => Pin::new(s).poll_flush(cx),
49        }
50    }
51
52    fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
53        match &mut *self {
54            MaybeTlsStream::Raw(s) => Pin::new(s).poll_shutdown(cx),
55            MaybeTlsStream::Tls(s) => Pin::new(s).poll_shutdown(cx),
56        }
57    }
58}
59
60impl<S, T> TlsStream for MaybeTlsStream<S, T>
61where
62    S: AsyncRead + AsyncWrite + Unpin,
63    T: TlsStream + Unpin,
64{
65    fn channel_binding(&self) -> ChannelBinding {
66        match self {
67            MaybeTlsStream::Raw(_) => ChannelBinding::none(),
68            MaybeTlsStream::Tls(s) => s.channel_binding(),
69        }
70    }
71}