tokio_postgres/
tls.rs

1//! TLS support.
2
3use std::error::Error;
4use std::future::Future;
5use std::pin::Pin;
6use std::task::{Context, Poll};
7use std::{fmt, io};
8use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
9
// Crate-private items used to restrict parts of the public TLS API.
pub(crate) mod private {
    /// Zero-sized token type taken as a parameter by `TlsConnect::can_connect`.
    ///
    /// Because the enclosing `private` module is only `pub(crate)`, code outside this
    /// crate cannot name this type. NOTE(review): presumably this is the point — it keeps
    /// `can_connect` from being called or overridden outside the crate; confirm intent.
    pub struct ForcePrivateApi;
}
13
/// Channel binding information returned from a TLS handshake.
///
/// Only `tls-server-end-point` binding data is represented. An absent value means
/// no channel binding data is available for the session.
pub struct ChannelBinding {
    // `Some(bytes)` when `tls-server-end-point` data is available, otherwise `None`.
    pub(crate) tls_server_end_point: Option<Vec<u8>>,
}

impl ChannelBinding {
    /// Creates a `ChannelBinding` that carries no channel binding data at all.
    pub fn none() -> ChannelBinding {
        Self {
            tls_server_end_point: None,
        }
    }

    /// Creates a `ChannelBinding` that holds the given `tls-server-end-point` data.
    pub fn tls_server_end_point(tls_server_end_point: Vec<u8>) -> ChannelBinding {
        let tls_server_end_point = Some(tls_server_end_point);
        Self { tls_server_end_point }
    }
}
34
/// A constructor of `TlsConnect`ors.
///
/// `S` is the stream type the produced connector wraps. NOTE(review): presumably the
/// raw socket to the server — confirm against callers.
///
/// Requires the `runtime` Cargo feature (enabled by default).
#[cfg(feature = "runtime")]
pub trait MakeTlsConnect<S> {
    /// The stream type created by the `TlsConnect` implementation.
    type Stream: TlsStream + Unpin;
    /// The `TlsConnect` implementation created by this type.
    ///
    /// Its `Stream` associated type is required to be the same as `Self::Stream`.
    type TlsConnect: TlsConnect<S, Stream = Self::Stream>;
    /// The error type returned by `make_tls_connect` when a connector cannot be created.
    type Error: Into<Box<dyn Error + Sync + Send>>;

    /// Creates a new `TlsConnect`or.
    ///
    /// The domain name is provided for certificate verification and SNI.
    ///
    /// # Errors
    ///
    /// Returns `Self::Error` if a connector cannot be created.
    fn make_tls_connect(&mut self, domain: &str) -> Result<Self::TlsConnect, Self::Error>;
}
52
/// An asynchronous function wrapping a stream in a TLS session.
///
/// `S` is the type of the stream being wrapped. NOTE(review): presumably the raw
/// socket to the server — confirm against callers.
pub trait TlsConnect<S> {
    /// The stream returned by the future.
    type Stream: TlsStream + Unpin;
    /// The error returned by the future.
    type Error: Into<Box<dyn Error + Sync + Send>>;
    /// The future returned by the connector.
    type Future: Future<Output = Result<Self::Stream, Self::Error>>;

    /// Returns a future performing a TLS handshake over the stream.
    ///
    /// The connector is consumed (`self` is taken by value). The future resolves to
    /// the wrapped stream, or to `Self::Error` if the handshake fails.
    fn connect(self, stream: S) -> Self::Future;

    // Reports whether this connector can establish TLS at all; defaults to `true`.
    // Code outside this crate cannot name `private::ForcePrivateApi`, which presumably
    // keeps this method crate-internal (hence `#[doc(hidden)]`).
    #[doc(hidden)]
    fn can_connect(&self, _: private::ForcePrivateApi) -> bool {
        true
    }
}
70
/// A TLS-wrapped connection to a PostgreSQL database.
///
/// The `AsyncRead + AsyncWrite` supertraits require the session to be readable and
/// writable asynchronously.
pub trait TlsStream: AsyncRead + AsyncWrite {
    /// Returns channel binding information for the session.
    ///
    /// Implementations with no binding data can return `ChannelBinding::none()`.
    fn channel_binding(&self) -> ChannelBinding;
}
76
/// A `MakeTlsConnect` and `TlsConnect` implementation that does not support TLS.
///
/// Creating the connector always succeeds, but the handshake future returned by
/// `TlsConnect::connect` always resolves to a `NoTlsError`.
///
/// This can be used when `sslmode` is `disable` or `prefer`.
#[derive(Debug, Copy, Clone)]
pub struct NoTls;
82
// Hands out a `NoTls` connector for any stream type; this step never fails.
#[cfg(feature = "runtime")]
impl<S> MakeTlsConnect<S> for NoTls {
    type Stream = NoTlsStream;
    type TlsConnect = NoTls;
    type Error = NoTlsError;

    fn make_tls_connect(&mut self, _domain: &str) -> Result<NoTls, NoTlsError> {
        // `NoTls` is a stateless `Copy` unit struct, so copying `self` is the same as
        // building a fresh one. The domain is irrelevant without TLS.
        Ok(*self)
    }
}
93
94impl<S> TlsConnect<S> for NoTls {
95    type Stream = NoTlsStream;
96    type Error = NoTlsError;
97    type Future = NoTlsFuture;
98
99    fn connect(self, _: S) -> NoTlsFuture {
100        NoTlsFuture(())
101    }
102
103    fn can_connect(&self, _: private::ForcePrivateApi) -> bool {
104        false
105    }
106}
107
108/// The future returned by `NoTls`.
109pub struct NoTlsFuture(());
110
111impl Future for NoTlsFuture {
112    type Output = Result<NoTlsStream, NoTlsError>;
113
114    fn poll(self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll<Self::Output> {
115        Poll::Ready(Err(NoTlsError(())))
116    }
117}
118
/// The TLS "stream" type produced by the `NoTls` connector.
///
/// Since `NoTls` doesn't support TLS, this type is uninhabited: it has no variants, so
/// no value of it can ever be constructed. It exists to fill the `Stream` associated
/// types of `NoTls`'s connector impls.
pub enum NoTlsStream {}
123
124impl AsyncRead for NoTlsStream {
125    fn poll_read(
126        self: Pin<&mut Self>,
127        _: &mut Context<'_>,
128        _: &mut ReadBuf<'_>,
129    ) -> Poll<io::Result<()>> {
130        match *self {}
131    }
132}
133
134impl AsyncWrite for NoTlsStream {
135    fn poll_write(self: Pin<&mut Self>, _: &mut Context<'_>, _: &[u8]) -> Poll<io::Result<usize>> {
136        match *self {}
137    }
138
139    fn poll_flush(self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll<io::Result<()>> {
140        match *self {}
141    }
142
143    fn poll_shutdown(self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll<io::Result<()>> {
144        match *self {}
145    }
146}
147
148impl TlsStream for NoTlsStream {
149    fn channel_binding(&self) -> ChannelBinding {
150        match *self {}
151    }
152}
153
/// The error returned by `NoTls`.
///
/// The unit field is private, so this error can only be constructed within this module.
#[derive(Debug)]
pub struct NoTlsError(());

impl fmt::Display for NoTlsError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // Fixed message; like `write_str`, this ignores width/fill flags.
        write!(f, "no TLS implementation configured")
    }
}

// No underlying cause, so the default `source()` (returning `None`) is kept.
impl Error for NoTlsError {}