1use mz_ore::secure::{Zeroize, Zeroizing};
13use openssl::pkcs12::Pkcs12;
14use openssl::pkey::PKey;
15use openssl::ssl::{SslConnector, SslMethod, SslVerifyMode};
16use openssl::stack::Stack;
17use openssl::x509::X509;
18use postgres_openssl::MakeTlsConnector;
19use tokio_postgres::config::SslMode;
20
/// Returns early from the enclosing function with a [`TlsError::Generic`]
/// wrapping an `anyhow::Error` built from `$err` by `anyhow::anyhow!`.
macro_rules! bail_generic {
    ($err:expr $(,)?) => {
        return Err(TlsError::from(anyhow::anyhow!($err)))
    };
}
26
27#[derive(Debug, thiserror::Error)]
29pub enum TlsError {
30 #[error(transparent)]
32 Generic(#[from] anyhow::Error),
33 #[error(transparent)]
35 OpenSsl(#[from] openssl::error::ErrorStack),
36}
37
38pub fn make_tls(config: &tokio_postgres::Config) -> Result<MakeTlsConnector, TlsError> {
40 let mut builder = SslConnector::builder(SslMethod::tls_client())?;
41 let (verify_mode, verify_hostname) = match config.get_ssl_mode() {
47 SslMode::Disable | SslMode::Prefer => (SslVerifyMode::NONE, false),
48 SslMode::Require => match config.get_ssl_root_cert() {
49 Some(_) => (SslVerifyMode::PEER, false),
55 None => (SslVerifyMode::NONE, false),
56 },
57 SslMode::VerifyCa => (SslVerifyMode::PEER, false),
58 SslMode::VerifyFull => (SslVerifyMode::PEER, true),
59 _ => panic!("unexpected sslmode {:?}", config.get_ssl_mode()),
60 };
61
62 builder.set_verify(verify_mode);
64
65 match (config.get_ssl_cert(), config.get_ssl_key()) {
67 (Some(ssl_cert), Some(ssl_key)) => {
68 builder.set_certificate(&*X509::from_pem(ssl_cert)?)?;
69 builder.set_private_key(&*PKey::private_key_from_pem(ssl_key)?)?;
70 }
71 (None, Some(_)) => {
72 bail_generic!("must provide both sslcert and sslkey, but only provided sslkey")
73 }
74 (Some(_), None) => {
75 bail_generic!("must provide both sslcert and sslkey, but only provided sslcert")
76 }
77 _ => {}
78 }
79 if let Some(ssl_root_cert) = config.get_ssl_root_cert() {
80 for cert in X509::stack_from_pem(ssl_root_cert)? {
81 builder.cert_store_mut().add_cert(cert)?;
82 }
83 }
84
85 let mut tls_connector = MakeTlsConnector::new(builder.build());
86
87 match (verify_mode, verify_hostname) {
89 (SslVerifyMode::PEER, false) => tls_connector.set_callback(|connect, _| {
90 connect.set_verify_hostname(false);
91 Ok(())
92 }),
93 _ => {}
94 }
95
96 Ok(tls_connector)
97}
98
/// A DER-serialized PKCS#12 archive paired with the password that protects it.
///
/// Both fields hold secret material and are wiped by this type's `Zeroize`
/// implementation, which also runs on drop.
pub struct Pkcs12Archive {
    /// The PKCS#12 archive bytes, DER-encoded.
    pub der: Vec<u8>,
    /// The password the archive was built with (possibly empty).
    pub pass: String,
}
103
104impl Zeroize for Pkcs12Archive {
105 fn zeroize(&mut self) {
106 self.der.zeroize();
107 self.pass.zeroize();
108 }
109}
110
111impl Drop for Pkcs12Archive {
112 fn drop(&mut self) {
113 self.zeroize();
114 }
115}
116
117pub fn pkcs12der_from_pem(
119 key: &[u8],
120 cert: &[u8],
121) -> Result<Pkcs12Archive, openssl::error::ErrorStack> {
122 let mut buf = Zeroizing::new(Vec::new());
123 buf.extend(key);
124 buf.push(b'\n');
125 buf.extend(cert);
126 let pem = buf.as_slice();
127 let pkey = PKey::private_key_from_pem(pem)?;
128 let mut certs = Stack::new()?;
129
130 let mut cert_iter = X509::stack_from_pem(pem)?.into_iter();
140 let cert = match cert_iter.next() {
141 Some(cert) => cert,
142 None => X509::from_pem(pem)?,
143 };
144 for cert in cert_iter {
145 certs.push(cert)?;
146 }
147 let pass = String::new();
151 let friendly_name = "";
152 let der = Pkcs12::builder()
153 .name(friendly_name)
154 .pkey(&pkey)
155 .cert(&cert)
156 .ca(certs)
157 .build2(&pass)?
158 .to_der()?;
159 Ok(Pkcs12Archive { der, pass })
160}
161
#[cfg(test)]
mod tests {
    use super::*;

    /// The archive type must have drop glue so its destructor can run.
    #[mz_ore::test]
    fn pkcs12_archive_needs_drop() {
        let has_drop_glue = std::mem::needs_drop::<Pkcs12Archive>();
        assert!(has_drop_glue);
    }

    /// Zeroizing the archive must leave both secret fields empty.
    #[mz_ore::test]
    fn pkcs12_archive_zeroize_clears_fields() {
        let secret_der = vec![0xDE, 0xAD, 0xBE, 0xEF];
        let secret_pass = String::from("hunter2");
        let mut archive = Pkcs12Archive {
            der: secret_der,
            pass: secret_pass,
        };

        archive.zeroize();

        assert!(archive.der.is_empty(), "der was not zeroed");
        assert!(archive.pass.is_empty(), "pass was not zeroed");
    }

    /// Compile-time check that the archive implements `Zeroize`.
    #[mz_ore::test]
    fn pkcs12_archive_implements_zeroize() {
        fn assert_zeroize<T: mz_ore::secure::Zeroize>() {}
        assert_zeroize::<Pkcs12Archive>();
    }
}