Skip to main content

mz_tls_util/
lib.rs

1// Copyright Materialize, Inc. and contributors. All rights reserved.
2//
3// Use of this software is governed by the Business Source License
4// included in the LICENSE file.
5//
6// As of the Change Date specified in that file, in accordance with
7// the Business Source License, use of this software will be governed
8// by the Apache License, Version 2.0.
9
10//! A tiny utility library for making TLS connectors.
11
12use mz_ore::secure::{Zeroize, Zeroizing};
13use openssl::pkcs12::Pkcs12;
14use openssl::pkey::PKey;
15use openssl::ssl::{SslConnector, SslMethod, SslVerifyMode};
16use openssl::stack::Stack;
17use openssl::x509::X509;
18use postgres_openssl::MakeTlsConnector;
19use tokio_postgres::config::SslMode;
20
// Returns early from the enclosing function with `TlsError::Generic`, wrapping
// `$err` (anything `anyhow::anyhow!` accepts) in an `anyhow::Error`.
//
// The enclosing function must return `Result<_, TlsError>`. Paths are fully
// qualified so the expansion does not depend on what the caller has in scope.
macro_rules! bail_generic {
    ($err:expr $(,)?) => {
        return ::core::result::Result::Err(crate::TlsError::Generic(::anyhow::anyhow!($err)))
    };
}
26
27/// An error representing tls failures.
28#[derive(Debug, thiserror::Error)]
29pub enum TlsError {
30    /// Any other error we bail on.
31    #[error(transparent)]
32    Generic(#[from] anyhow::Error),
33    /// Error setting up postgres ssl.
34    #[error(transparent)]
35    OpenSsl(#[from] openssl::error::ErrorStack),
36}
37
38/// Creates a TLS connector for the given [`Config`](tokio_postgres::Config).
39pub fn make_tls(config: &tokio_postgres::Config) -> Result<MakeTlsConnector, TlsError> {
40    let mut builder = SslConnector::builder(SslMethod::tls_client())?;
41    // The mode dictates whether we verify peer certs and hostnames. By default, Postgres is
42    // pretty relaxed and recommends SslMode::VerifyCa or SslMode::VerifyFull for security.
43    //
44    // For more details, check out Table 33.1. SSL Mode Descriptions in
45    // https://postgresql.org/docs/current/libpq-ssl.html#LIBPQ-SSL-PROTECTION.
46    let (verify_mode, verify_hostname) = match config.get_ssl_mode() {
47        SslMode::Disable | SslMode::Prefer => (SslVerifyMode::NONE, false),
48        SslMode::Require => match config.get_ssl_root_cert() {
49            // If a root CA file exists, the behavior of sslmode=require will be the same as
50            // that of verify-ca, meaning the server certificate is validated against the CA.
51            //
52            // For more details, check out the note about backwards compatibility in
53            // https://postgresql.org/docs/current/libpq-ssl.html#LIBQ-SSL-CERTIFICATES.
54            Some(_) => (SslVerifyMode::PEER, false),
55            None => (SslVerifyMode::NONE, false),
56        },
57        SslMode::VerifyCa => (SslVerifyMode::PEER, false),
58        SslMode::VerifyFull => (SslVerifyMode::PEER, true),
59        _ => panic!("unexpected sslmode {:?}", config.get_ssl_mode()),
60    };
61
62    // Configure peer verification
63    builder.set_verify(verify_mode);
64
65    // Configure certificates
66    match (config.get_ssl_cert(), config.get_ssl_key()) {
67        (Some(ssl_cert), Some(ssl_key)) => {
68            builder.set_certificate(&*X509::from_pem(ssl_cert)?)?;
69            builder.set_private_key(&*PKey::private_key_from_pem(ssl_key)?)?;
70        }
71        (None, Some(_)) => {
72            bail_generic!("must provide both sslcert and sslkey, but only provided sslkey")
73        }
74        (Some(_), None) => {
75            bail_generic!("must provide both sslcert and sslkey, but only provided sslcert")
76        }
77        _ => {}
78    }
79    if let Some(ssl_root_cert) = config.get_ssl_root_cert() {
80        for cert in X509::stack_from_pem(ssl_root_cert)? {
81            builder.cert_store_mut().add_cert(cert)?;
82        }
83    }
84
85    let mut tls_connector = MakeTlsConnector::new(builder.build());
86
87    // Configure hostname verification
88    match (verify_mode, verify_hostname) {
89        (SslVerifyMode::PEER, false) => tls_connector.set_callback(|connect, _| {
90            connect.set_verify_hostname(false);
91            Ok(())
92        }),
93        _ => {}
94    }
95
96    Ok(tls_connector)
97}
98
/// A DER-encoded PKCS #12 archive together with the password protecting it.
///
/// Both fields hold secret material and are zeroized when the archive is
/// dropped. Use [`Pkcs12Archive::into_parts`] to take ownership of the raw
/// values instead.
pub struct Pkcs12Archive {
    /// The DER encoding of the PKCS #12 archive.
    pub der: Vec<u8>,
    /// The password the archive was built with.
    pub pass: String,
}
103
104impl Zeroize for Pkcs12Archive {
105    fn zeroize(&mut self) {
106        self.der.zeroize();
107        self.pass.zeroize();
108    }
109}
110
111impl Drop for Pkcs12Archive {
112    fn drop(&mut self) {
113        self.zeroize();
114    }
115}
116
117impl Pkcs12Archive {
118    pub fn into_parts(self) -> (Vec<u8>, String) {
119        let mut md = std::mem::ManuallyDrop::new(self);
120        let der = std::mem::take(&mut md.der);
121        let pass = std::mem::take(&mut md.pass);
122        (der, pass)
123    }
124}
125
126/// Constructs an identity from a PEM-formatted key and certificate using OpenSSL.
127pub fn pkcs12der_from_pem(
128    key: &[u8],
129    cert: &[u8],
130) -> Result<Pkcs12Archive, openssl::error::ErrorStack> {
131    let mut buf = Zeroizing::new(Vec::new());
132    buf.extend(key);
133    buf.push(b'\n');
134    buf.extend(cert);
135    let pem = buf.as_slice();
136    let pkey = PKey::private_key_from_pem(pem)?;
137    let mut certs = Stack::new()?;
138
139    // `X509::stack_from_pem` in openssl as of at least versions <= 0.10.48
140    // does not guarantee that it will either error or return at least 1
141    // element; in fact, it doesn't if the `pem` is not a well-formed
142    // representation of a PEM file. For example, if the represented file
143    // contains a well-formed key but a malformed certificate.
144    //
145    // To circumvent this issue, if `X509::stack_from_pem` returns no
146    // certificates, rely on getting the error message from
147    // `X509::from_pem`.
148    let mut cert_iter = X509::stack_from_pem(pem)?.into_iter();
149    let cert = match cert_iter.next() {
150        Some(cert) => cert,
151        None => X509::from_pem(pem)?,
152    };
153    for cert in cert_iter {
154        certs.push(cert)?;
155    }
156    // We build a PKCS #12 archive solely to have something to pass to
157    // `reqwest::Identity::from_pkcs12_der`, so the password and friendly
158    // name don't matter.
159    let pass = String::new();
160    let friendly_name = "";
161    let der = Pkcs12::builder()
162        .name(friendly_name)
163        .pkey(&pkey)
164        .cert(&cert)
165        .ca(certs)
166        .build2(&pass)?
167        .to_der()?;
168    Ok(Pkcs12Archive { der, pass })
169}
170
#[cfg(test)]
mod tests {
    use super::*;

    #[mz_ore::test]
    fn pkcs12_archive_needs_drop() {
        assert!(std::mem::needs_drop::<Pkcs12Archive>());
    }

    #[mz_ore::test]
    fn pkcs12_archive_zeroize_clears_fields() {
        let mut archive = Pkcs12Archive {
            der: vec![0xDE, 0xAD, 0xBE, 0xEF],
            pass: String::from("hunter2"),
        };

        archive.zeroize();

        assert!(archive.der.is_empty(), "der was not zeroed");
        assert!(archive.pass.is_empty(), "pass was not zeroed");
    }

    #[mz_ore::test]
    fn pkcs12_archive_implements_zeroize() {
        fn assert_zeroize<T: mz_ore::secure::Zeroize>() {}
        assert_zeroize::<Pkcs12Archive>();
    }

    // `into_parts` must hand the secrets to the caller intact rather than
    // letting the `Drop` impl scrub them first.
    #[mz_ore::test]
    fn pkcs12_archive_into_parts_returns_fields() {
        let archive = Pkcs12Archive {
            der: vec![0xDE, 0xAD, 0xBE, 0xEF],
            pass: String::from("hunter2"),
        };

        let (der, pass) = archive.into_parts();

        assert_eq!(der, vec![0xDE, 0xAD, 0xBE, 0xEF], "der was not returned intact");
        assert_eq!(pass, "hunter2", "pass was not returned intact");
    }

    #[mz_ore::test]
    fn pkcs12der_from_pem_rejects_malformed_input() {
        assert!(pkcs12der_from_pem(b"not a key", b"not a cert").is_err());
    }

    #[mz_ore::test]
    fn make_tls_accepts_default_config() {
        assert!(make_tls(&tokio_postgres::Config::new()).is_ok());
    }
}