postgres_openssl/
lib.rs

1//! TLS support for `tokio-postgres` and `postgres` via `openssl`.
2//!
3//! # Examples
4//!
5//! ```no_run
6//! use openssl::ssl::{SslConnector, SslMethod};
7//! # #[cfg(feature = "runtime")]
8//! use postgres_openssl::MakeTlsConnector;
9//!
10//! # fn main() -> Result<(), Box<dyn std::error::Error>> {
11//! # #[cfg(feature = "runtime")] {
12//! let mut builder = SslConnector::builder(SslMethod::tls())?;
13//! builder.set_ca_file("database_cert.pem")?;
14//! let connector = MakeTlsConnector::new(builder.build());
15//!
16//! let connect_future = tokio_postgres::connect(
17//!     "host=localhost user=postgres sslmode=require",
18//!     connector,
19//! );
20//! # }
21//!
22//! // ...
23//! # Ok(())
24//! # }
25//! ```
26//!
27//! ```no_run
28//! use openssl::ssl::{SslConnector, SslMethod};
29//! # #[cfg(feature = "runtime")]
30//! use postgres_openssl::MakeTlsConnector;
31//!
32//! # fn main() -> Result<(), Box<dyn std::error::Error>> {
33//! # #[cfg(feature = "runtime")] {
34//! let mut builder = SslConnector::builder(SslMethod::tls())?;
35//! builder.set_ca_file("database_cert.pem")?;
36//! let connector = MakeTlsConnector::new(builder.build());
37//!
38//! let client = postgres::Client::connect(
39//!     "host=localhost user=postgres sslmode=require",
40//!     connector,
41//! )?;
42//! # }
43//!
44//! // ...
45//! # Ok(())
46//! # }
47//! ```
48#![warn(rust_2018_idioms, clippy::all, missing_docs)]
49
50#[cfg(feature = "runtime")]
51use openssl::error::ErrorStack;
52use openssl::hash::MessageDigest;
53use openssl::nid::Nid;
54#[cfg(feature = "runtime")]
55use openssl::ssl::SslConnector;
56use openssl::ssl::{self, ConnectConfiguration, SslRef};
57use openssl::x509::X509VerifyResult;
58use std::error::Error;
59use std::fmt::{self, Debug};
60use std::future::Future;
61use std::io;
62use std::pin::Pin;
63#[cfg(feature = "runtime")]
64use std::sync::Arc;
65use std::task::{Context, Poll};
66use tokio::io::{AsyncRead, AsyncWrite, BufReader, ReadBuf};
67use tokio_openssl::SslStream;
68use tokio_postgres::tls;
69#[cfg(feature = "runtime")]
70use tokio_postgres::tls::MakeTlsConnect;
71use tokio_postgres::tls::{ChannelBinding, TlsConnect};
72
73#[cfg(test)]
74mod test;
75
/// Signature of the per-connection configuration callback stored by `MakeTlsConnector`.
///
/// The callback receives the connection's `ConnectConfiguration` together with the target
/// domain and may reject the connection by returning an `ErrorStack`.
///
/// Gated on the `runtime` feature, like `MakeTlsConnector` (which stores it) and like the
/// `ErrorStack` import it names. Without the gate, builds with `runtime` disabled cannot
/// resolve `ErrorStack`.
#[cfg(feature = "runtime")]
type ConfigCallback =
    dyn Fn(&mut ConnectConfiguration, &str) -> Result<(), ErrorStack> + Sync + Send;
78
/// A `MakeTlsConnect` implementation using the `openssl` crate.
///
/// Requires the `runtime` Cargo feature (enabled by default).
#[cfg(feature = "runtime")]
#[derive(Clone)]
pub struct MakeTlsConnector {
    /// Connector from which a fresh `ConnectConfiguration` is derived for every connection.
    connector: SslConnector,
    /// Per-connection configuration hook; a no-op until replaced via `set_callback`.
    /// Stored behind an `Arc` so clones of this connector share the same callback.
    config: Arc<ConfigCallback>,
}

#[cfg(feature = "runtime")]
impl MakeTlsConnector {
    /// Creates a new connector.
    ///
    /// No per-connection callback is installed; use `set_callback` to add one.
    pub fn new(connector: SslConnector) -> MakeTlsConnector {
        MakeTlsConnector {
            connector,
            config: Arc::new(|_, _| Ok(())),
        }
    }

    /// Sets a callback used to apply per-connection configuration.
    ///
    /// The callback is provided the domain name along with the `ConnectConfiguration`.
    /// It replaces any previously set callback on this connector. Clones made earlier keep
    /// the callback they already held. An error returned by the callback aborts creation of
    /// that connection's `TlsConnector`.
    pub fn set_callback<F>(&mut self, f: F)
    where
        F: Fn(&mut ConnectConfiguration, &str) -> Result<(), ErrorStack> + 'static + Sync + Send,
    {
        self.config = Arc::new(f);
    }
}
109
#[cfg(feature = "runtime")]
impl<S> MakeTlsConnect<S> for MakeTlsConnector
where
    S: AsyncRead + AsyncWrite + Unpin + Debug + 'static + Sync + Send,
{
    type Stream = TlsStream<S>;
    type TlsConnect = TlsConnector;
    type Error = ErrorStack;

    /// Builds a `TlsConnector` bound to `domain`.
    ///
    /// Derives a fresh `ConnectConfiguration` from the stored `SslConnector`, lets the user
    /// callback adjust it, and then wraps it for `domain`. An `ErrorStack` from either step
    /// is returned as-is.
    fn make_tls_connect(&mut self, domain: &str) -> Result<TlsConnector, ErrorStack> {
        self.connector.configure().and_then(|mut configuration| {
            (self.config)(&mut configuration, domain)?;
            Ok(TlsConnector::new(configuration, domain))
        })
    }
}
125
126/// A `TlsConnect` implementation using the `openssl` crate.
127pub struct TlsConnector {
128    ssl: ConnectConfiguration,
129    domain: String,
130}
131
132impl TlsConnector {
133    /// Creates a new connector configured to connect to the specified domain.
134    pub fn new(ssl: ConnectConfiguration, domain: &str) -> TlsConnector {
135        TlsConnector {
136            ssl,
137            domain: domain.to_string(),
138        }
139    }
140}
141
142impl<S> TlsConnect<S> for TlsConnector
143where
144    S: AsyncRead + AsyncWrite + Unpin + Send + 'static,
145{
146    type Stream = TlsStream<S>;
147    type Error = Box<dyn Error + Send + Sync>;
148    #[allow(clippy::type_complexity)]
149    type Future = Pin<Box<dyn Future<Output = Result<TlsStream<S>, Self::Error>> + Send>>;
150
151    fn connect(self, stream: S) -> Self::Future {
152        let stream = BufReader::with_capacity(8192, stream);
153        let future = async move {
154            let ssl = self.ssl.into_ssl(&self.domain)?;
155            let mut stream = SslStream::new(ssl, stream)?;
156            match Pin::new(&mut stream).connect().await {
157                Ok(()) => Ok(TlsStream(stream)),
158                Err(error) => Err(Box::new(ConnectError {
159                    error,
160                    verify_result: stream.ssl().verify_result(),
161                }) as _),
162            }
163        };
164
165        Box::pin(future)
166    }
167}
168
169#[derive(Debug)]
170struct ConnectError {
171    error: ssl::Error,
172    verify_result: X509VerifyResult,
173}
174
175impl fmt::Display for ConnectError {
176    fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result {
177        fmt::Display::fmt(&self.error, fmt)?;
178
179        if self.verify_result != X509VerifyResult::OK {
180            fmt.write_str(": ")?;
181            fmt::Display::fmt(&self.verify_result, fmt)?;
182        }
183
184        Ok(())
185    }
186}
187
188impl Error for ConnectError {
189    fn source(&self) -> Option<&(dyn Error + 'static)> {
190        Some(&self.error)
191    }
192}
193
194/// The stream returned by `TlsConnector`.
195pub struct TlsStream<S>(SslStream<BufReader<S>>);
196
197impl<S> AsyncRead for TlsStream<S>
198where
199    S: AsyncRead + AsyncWrite + Unpin,
200{
201    fn poll_read(
202        mut self: Pin<&mut Self>,
203        cx: &mut Context<'_>,
204        buf: &mut ReadBuf<'_>,
205    ) -> Poll<io::Result<()>> {
206        Pin::new(&mut self.0).poll_read(cx, buf)
207    }
208}
209
210impl<S> AsyncWrite for TlsStream<S>
211where
212    S: AsyncRead + AsyncWrite + Unpin,
213{
214    fn poll_write(
215        mut self: Pin<&mut Self>,
216        cx: &mut Context<'_>,
217        buf: &[u8],
218    ) -> Poll<io::Result<usize>> {
219        Pin::new(&mut self.0).poll_write(cx, buf)
220    }
221
222    fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
223        Pin::new(&mut self.0).poll_flush(cx)
224    }
225
226    fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
227        Pin::new(&mut self.0).poll_shutdown(cx)
228    }
229}
230
231impl<S> tls::TlsStream for TlsStream<S>
232where
233    S: AsyncRead + AsyncWrite + Unpin,
234{
235    fn channel_binding(&self) -> ChannelBinding {
236        match tls_server_end_point(self.0.ssl()) {
237            Some(buf) => ChannelBinding::tls_server_end_point(buf),
238            None => ChannelBinding::none(),
239        }
240    }
241}
242
243fn tls_server_end_point(ssl: &SslRef) -> Option<Vec<u8>> {
244    let cert = ssl.peer_certificate()?;
245    let algo_nid = cert.signature_algorithm().object().nid();
246    let signature_algorithms = algo_nid.signature_algorithms()?;
247    let md = match signature_algorithms.digest {
248        Nid::MD5 | Nid::SHA1 => MessageDigest::sha256(),
249        nid => MessageDigest::from_nid(nid)?,
250    };
251    cert.digest(md).ok().map(|b| b.to_vec())
252}