tokio_postgres/
connect_tls.rs

1use crate::config::SslMode;
2use crate::maybe_tls_stream::MaybeTlsStream;
3use crate::tls::private::ForcePrivateApi;
4use crate::tls::TlsConnect;
5use crate::Error;
6use bytes::BytesMut;
7use postgres_protocol::message::frontend;
8use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};
9
10pub async fn connect_tls<S, T>(
11    mut stream: S,
12    mode: SslMode,
13    tls: T,
14    has_hostname: bool,
15) -> Result<MaybeTlsStream<S, T::Stream>, Error>
16where
17    S: AsyncRead + AsyncWrite + Unpin,
18    T: TlsConnect<S>,
19{
20    match mode {
21        SslMode::Disable => return Ok(MaybeTlsStream::Raw(stream)),
22        SslMode::Prefer if !tls.can_connect(ForcePrivateApi) => {
23            return Ok(MaybeTlsStream::Raw(stream))
24        }
25        SslMode::Prefer | SslMode::Require | SslMode::VerifyCa | SslMode::VerifyFull => {}
26    }
27
28    let mut buf = BytesMut::new();
29    frontend::ssl_request(&mut buf);
30    stream.write_all(&buf).await.map_err(Error::io)?;
31
32    let mut buf = [0];
33    stream.read_exact(&mut buf).await.map_err(Error::io)?;
34
35    if buf[0] != b'S' {
36        match mode {
37            SslMode::Require | SslMode::VerifyCa | SslMode::VerifyFull => {
38                return Err(Error::tls("server does not support TLS".into()))
39            }
40            SslMode::Disable | SslMode::Prefer => return Ok(MaybeTlsStream::Raw(stream)),
41        }
42    }
43
44    if !has_hostname {
45        return Err(Error::tls("no hostname provided for TLS handshake".into()));
46    }
47
48    let stream = tls
49        .connect(stream)
50        .await
51        .map_err(|e| Error::tls(e.into()))?;
52
53    Ok(MaybeTlsStream::Tls(stream))
54}