postgres_openssl/
lib.rs
1#![warn(rust_2018_idioms, clippy::all, missing_docs)]
49
50#[cfg(feature = "runtime")]
51use openssl::error::ErrorStack;
52use openssl::hash::MessageDigest;
53use openssl::nid::Nid;
54#[cfg(feature = "runtime")]
55use openssl::ssl::SslConnector;
56use openssl::ssl::{self, ConnectConfiguration, SslRef};
57use openssl::x509::X509VerifyResult;
58use std::error::Error;
59use std::fmt::{self, Debug};
60use std::future::Future;
61use std::io;
62use std::pin::Pin;
63#[cfg(feature = "runtime")]
64use std::sync::Arc;
65use std::task::{Context, Poll};
66use tokio::io::{AsyncRead, AsyncWrite, BufReader, ReadBuf};
67use tokio_openssl::SslStream;
68use tokio_postgres::tls;
69#[cfg(feature = "runtime")]
70use tokio_postgres::tls::MakeTlsConnect;
71use tokio_postgres::tls::{ChannelBinding, TlsConnect};
72
73#[cfg(test)]
74mod test;
75
/// Signature of the per-connection configuration hook stored by
/// `MakeTlsConnector`.
///
/// The hook receives the mutable `ConnectConfiguration` for the connection
/// being set up together with the domain name, and may fail with an
/// `ErrorStack`.
// `ErrorStack` is imported only when the `runtime` feature is enabled, and the
// alias is used solely by runtime-gated items, so it must carry the same gate
// to keep builds without that feature compiling.
#[cfg(feature = "runtime")]
type ConfigCallback =
    dyn Fn(&mut ConnectConfiguration, &str) -> Result<(), ErrorStack> + Sync + Send;
78
/// A `MakeTlsConnect` implementation using the `openssl` crate.
///
/// Only compiled when the `runtime` Cargo feature is enabled.
#[cfg(feature = "runtime")]
#[derive(Clone)]
pub struct MakeTlsConnector {
    // Source of the per-connection `ConnectConfiguration`.
    connector: SslConnector,
    // Hook run on every freshly created configuration; shared between clones.
    config: Arc<ConfigCallback>,
}
88
#[cfg(feature = "runtime")]
impl MakeTlsConnector {
    /// Creates a new connector from the given `SslConnector`.
    ///
    /// The configuration callback starts out as a no-op that always succeeds.
    pub fn new(connector: SslConnector) -> MakeTlsConnector {
        Self {
            config: Arc::new(|_, _| Ok(())),
            connector,
        }
    }

    /// Replaces the callback used to apply per-connection configuration.
    ///
    /// The callback is invoked with the `ConnectConfiguration` and the domain
    /// name each time a TLS connector is made; an error it returns is
    /// propagated to the caller of `make_tls_connect`.
    pub fn set_callback<F>(&mut self, f: F)
    where
        F: Fn(&mut ConnectConfiguration, &str) -> Result<(), ErrorStack> + 'static + Sync + Send,
    {
        let callback: Arc<ConfigCallback> = Arc::new(f);
        self.config = callback;
    }
}
109
#[cfg(feature = "runtime")]
impl<S> MakeTlsConnect<S> for MakeTlsConnector
where
    S: AsyncRead + AsyncWrite + Unpin + Debug + 'static + Sync + Send,
{
    type Stream = TlsStream<S>;
    type TlsConnect = TlsConnector;
    type Error = ErrorStack;

    /// Builds a `TlsConnector` for `domain`.
    ///
    /// A fresh configuration is obtained from the stored `SslConnector`, the
    /// user callback is run on it, and the result is wrapped together with
    /// the domain. Any `ErrorStack` from either step is returned as-is.
    fn make_tls_connect(&mut self, domain: &str) -> Result<TlsConnector, ErrorStack> {
        self.connector.configure().and_then(|mut configuration| {
            (self.config)(&mut configuration, domain)?;
            Ok(TlsConnector::new(configuration, domain))
        })
    }
}
125
126pub struct TlsConnector {
128 ssl: ConnectConfiguration,
129 domain: String,
130}
131
132impl TlsConnector {
133 pub fn new(ssl: ConnectConfiguration, domain: &str) -> TlsConnector {
135 TlsConnector {
136 ssl,
137 domain: domain.to_string(),
138 }
139 }
140}
141
142impl<S> TlsConnect<S> for TlsConnector
143where
144 S: AsyncRead + AsyncWrite + Unpin + Send + 'static,
145{
146 type Stream = TlsStream<S>;
147 type Error = Box<dyn Error + Send + Sync>;
148 #[allow(clippy::type_complexity)]
149 type Future = Pin<Box<dyn Future<Output = Result<TlsStream<S>, Self::Error>> + Send>>;
150
151 fn connect(self, stream: S) -> Self::Future {
152 let stream = BufReader::with_capacity(8192, stream);
153 let future = async move {
154 let ssl = self.ssl.into_ssl(&self.domain)?;
155 let mut stream = SslStream::new(ssl, stream)?;
156 match Pin::new(&mut stream).connect().await {
157 Ok(()) => Ok(TlsStream(stream)),
158 Err(error) => Err(Box::new(ConnectError {
159 error,
160 verify_result: stream.ssl().verify_result(),
161 }) as _),
162 }
163 };
164
165 Box::pin(future)
166 }
167}
168
169#[derive(Debug)]
170struct ConnectError {
171 error: ssl::Error,
172 verify_result: X509VerifyResult,
173}
174
175impl fmt::Display for ConnectError {
176 fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result {
177 fmt::Display::fmt(&self.error, fmt)?;
178
179 if self.verify_result != X509VerifyResult::OK {
180 fmt.write_str(": ")?;
181 fmt::Display::fmt(&self.verify_result, fmt)?;
182 }
183
184 Ok(())
185 }
186}
187
188impl Error for ConnectError {
189 fn source(&self) -> Option<&(dyn Error + 'static)> {
190 Some(&self.error)
191 }
192}
193
194pub struct TlsStream<S>(SslStream<BufReader<S>>);
196
197impl<S> AsyncRead for TlsStream<S>
198where
199 S: AsyncRead + AsyncWrite + Unpin,
200{
201 fn poll_read(
202 mut self: Pin<&mut Self>,
203 cx: &mut Context<'_>,
204 buf: &mut ReadBuf<'_>,
205 ) -> Poll<io::Result<()>> {
206 Pin::new(&mut self.0).poll_read(cx, buf)
207 }
208}
209
210impl<S> AsyncWrite for TlsStream<S>
211where
212 S: AsyncRead + AsyncWrite + Unpin,
213{
214 fn poll_write(
215 mut self: Pin<&mut Self>,
216 cx: &mut Context<'_>,
217 buf: &[u8],
218 ) -> Poll<io::Result<usize>> {
219 Pin::new(&mut self.0).poll_write(cx, buf)
220 }
221
222 fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
223 Pin::new(&mut self.0).poll_flush(cx)
224 }
225
226 fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
227 Pin::new(&mut self.0).poll_shutdown(cx)
228 }
229}
230
231impl<S> tls::TlsStream for TlsStream<S>
232where
233 S: AsyncRead + AsyncWrite + Unpin,
234{
235 fn channel_binding(&self) -> ChannelBinding {
236 match tls_server_end_point(self.0.ssl()) {
237 Some(buf) => ChannelBinding::tls_server_end_point(buf),
238 None => ChannelBinding::none(),
239 }
240 }
241}
242
243fn tls_server_end_point(ssl: &SslRef) -> Option<Vec<u8>> {
244 let cert = ssl.peer_certificate()?;
245 let algo_nid = cert.signature_algorithm().object().nid();
246 let signature_algorithms = algo_nid.signature_algorithms()?;
247 let md = match signature_algorithms.digest {
248 Nid::MD5 | Nid::SHA1 => MessageDigest::sha256(),
249 nid => MessageDigest::from_nid(nid)?,
250 };
251 cert.digest(md).ok().map(|b| b.to_vec())
252}