// native_tls/imp/openssl.rs

1extern crate openssl;
2extern crate openssl_probe;
3
4use self::openssl::error::ErrorStack;
5use self::openssl::hash::MessageDigest;
6use self::openssl::nid::Nid;
7use self::openssl::pkcs12::Pkcs12;
8use self::openssl::pkey::{PKey, Private};
9use self::openssl::ssl::{
10    self, MidHandshakeSslStream, SslAcceptor, SslConnector, SslContextBuilder, SslMethod,
11    SslVerifyMode,
12};
13use self::openssl::x509::{store::X509StoreBuilder, X509VerifyResult, X509};
14use self::openssl_probe::ProbeResult;
15use std::error;
16use std::fmt;
17use std::io;
18use std::sync::LazyLock;
19
20use {Protocol, TlsAcceptorBuilder, TlsConnectorBuilder};
21
22static PROBE_RESULT: LazyLock<ProbeResult> = LazyLock::new(openssl_probe::probe);
23
/// Restricts `ctx` to the protocol range `[min, max]` using OpenSSL's native
/// minimum/maximum protocol-version setters.
///
/// A `None` bound is forwarded unchanged to the corresponding setter; the
/// `openssl` crate documents that as "use the lowest/highest version the
/// library supports".
#[cfg(have_min_max_version)]
fn supported_protocols(
    min: Option<Protocol>,
    max: Option<Protocol>,
    ctx: &mut SslContextBuilder,
) -> Result<(), ErrorStack> {
    use self::openssl::ssl::SslVersion;

    /// Maps the crate-level protocol enum onto OpenSSL's version constants.
    fn to_ssl_version(protocol: Protocol) -> SslVersion {
        match protocol {
            Protocol::Sslv3 => SslVersion::SSL3,
            Protocol::Tlsv10 => SslVersion::TLS1,
            Protocol::Tlsv11 => SslVersion::TLS1_1,
            Protocol::Tlsv12 => SslVersion::TLS1_2,
        }
    }

    let lower = min.map(to_ssl_version);
    let upper = max.map(to_ssl_version);

    ctx.set_min_proto_version(lower)?;
    ctx.set_max_proto_version(upper)?;
    Ok(())
}
46
47#[cfg(not(have_min_max_version))]
48fn supported_protocols(
49    min: Option<Protocol>,
50    max: Option<Protocol>,
51    ctx: &mut SslContextBuilder,
52) -> Result<(), ErrorStack> {
53    use self::openssl::ssl::SslOptions;
54
55    let no_ssl_mask = SslOptions::NO_SSLV2
56        | SslOptions::NO_SSLV3
57        | SslOptions::NO_TLSV1
58        | SslOptions::NO_TLSV1_1
59        | SslOptions::NO_TLSV1_2;
60
61    ctx.clear_options(no_ssl_mask);
62    let mut options = SslOptions::empty();
63    options |= match min {
64        None => SslOptions::empty(),
65        Some(Protocol::Sslv3) => SslOptions::NO_SSLV2,
66        Some(Protocol::Tlsv10) => SslOptions::NO_SSLV2 | SslOptions::NO_SSLV3,
67        Some(Protocol::Tlsv11) => {
68            SslOptions::NO_SSLV2 | SslOptions::NO_SSLV3 | SslOptions::NO_TLSV1
69        }
70        Some(Protocol::Tlsv12) => {
71            SslOptions::NO_SSLV2
72                | SslOptions::NO_SSLV3
73                | SslOptions::NO_TLSV1
74                | SslOptions::NO_TLSV1_1
75        }
76    };
77    options |= match max {
78        None | Some(Protocol::Tlsv12) => SslOptions::empty(),
79        Some(Protocol::Tlsv11) => SslOptions::NO_TLSV1_2,
80        Some(Protocol::Tlsv10) => SslOptions::NO_TLSV1_1 | SslOptions::NO_TLSV1_2,
81        Some(Protocol::Sslv3) => {
82            SslOptions::NO_TLSV1 | SslOptions::NO_TLSV1_1 | SslOptions::NO_TLSV1_2
83        }
84    };
85
86    ctx.set_options(options);
87
88    Ok(())
89}
90
/// Adds every PEM certificate found in Android's system CA directory
/// (`/system/etc/security/cacerts`) to the connector's certificate store.
///
/// This is best-effort and never fails. An unreadable directory, unreadable
/// entries and unparsable files are skipped silently. Certificates the store
/// rejects are logged at debug level and skipped. Always returns `Ok(())`.
#[cfg(target_os = "android")]
fn load_android_root_certs(connector: &mut SslContextBuilder) -> Result<(), Error> {
    use std::fs;

    let entries = match fs::read_dir("/system/etc/security/cacerts") {
        Ok(entries) => entries,
        Err(_) => return Ok(()),
    };

    for entry in entries.flatten() {
        let bytes = match fs::read(entry.path()) {
            Ok(bytes) => bytes,
            Err(_) => continue,
        };
        let cert = match X509::from_pem(&bytes) {
            Ok(cert) => cert,
            Err(_) => continue,
        };
        if let Err(err) = connector.cert_store_mut().add_cert(cert) {
            debug!("load_android_root_certs error: {:?}", err);
        }
    }

    Ok(())
}
109
110#[derive(Debug)]
111pub enum Error {
112    Normal(ErrorStack),
113    Ssl(ssl::Error, X509VerifyResult),
114    EmptyChain,
115    NotPkcs8,
116}
117
118impl error::Error for Error {
119    fn source(&self) -> Option<&(dyn error::Error + 'static)> {
120        match *self {
121            Error::Normal(ref e) => error::Error::source(e),
122            Error::Ssl(ref e, _) => error::Error::source(e),
123            Error::EmptyChain => None,
124            Error::NotPkcs8 => None,
125        }
126    }
127}
128
129impl fmt::Display for Error {
130    fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result {
131        match *self {
132            Error::Normal(ref e) => fmt::Display::fmt(e, fmt),
133            Error::Ssl(ref e, X509VerifyResult::OK) => fmt::Display::fmt(e, fmt),
134            Error::Ssl(ref e, v) => write!(fmt, "{} ({})", e, v),
135            Error::EmptyChain => write!(
136                fmt,
137                "at least one certificate must be provided to create an identity"
138            ),
139            Error::NotPkcs8 => write!(fmt, "expected PKCS#8 PEM"),
140        }
141    }
142}
143
144impl From<ErrorStack> for Error {
145    fn from(err: ErrorStack) -> Error {
146        Error::Normal(err)
147    }
148}
149
150#[derive(Clone)]
151pub struct Identity {
152    pkey: PKey<Private>,
153    cert: X509,
154    chain: Vec<X509>,
155}
156
157impl Identity {
158    pub fn from_pkcs12(buf: &[u8], pass: &str) -> Result<Identity, Error> {
159        let pkcs12 = Pkcs12::from_der(buf)?;
160        let parsed = pkcs12.parse2(pass)?;
161        Ok(Identity {
162            pkey: parsed.pkey.ok_or_else(|| Error::EmptyChain)?,
163            cert: parsed.cert.ok_or_else(|| Error::EmptyChain)?,
164            // > The stack is the reverse of what you might expect due to the way
165            // > PKCS12_parse is implemented, so we need to load it backwards.
166            // > https://github.com/sfackler/rust-native-tls/commit/05fb5e583be589ab63d9f83d986d095639f8ec44
167            chain: parsed.ca.into_iter().flatten().rev().collect(),
168        })
169    }
170
171    pub fn from_pkcs8(buf: &[u8], key: &[u8]) -> Result<Identity, Error> {
172        if !key.starts_with(b"-----BEGIN PRIVATE KEY-----") {
173            return Err(Error::NotPkcs8);
174        }
175
176        let pkey = PKey::private_key_from_pem(key)?;
177        let mut cert_chain = X509::stack_from_pem(buf)?.into_iter();
178        let cert = cert_chain.next().ok_or(Error::EmptyChain)?;
179        let chain = cert_chain.collect();
180        Ok(Identity { pkey, cert, chain })
181    }
182}
183
184#[derive(Clone)]
185pub struct Certificate(X509);
186
187impl Certificate {
188    pub fn from_der(buf: &[u8]) -> Result<Certificate, Error> {
189        let cert = X509::from_der(buf)?;
190        Ok(Certificate(cert))
191    }
192
193    pub fn from_pem(buf: &[u8]) -> Result<Certificate, Error> {
194        let cert = X509::from_pem(buf)?;
195        Ok(Certificate(cert))
196    }
197
198    pub fn to_der(&self) -> Result<Vec<u8>, Error> {
199        let der = self.0.to_der()?;
200        Ok(der)
201    }
202}
203
204pub struct MidHandshakeTlsStream<S>(MidHandshakeSslStream<S>);
205
206impl<S> fmt::Debug for MidHandshakeTlsStream<S>
207where
208    S: fmt::Debug,
209{
210    fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result {
211        fmt::Debug::fmt(&self.0, fmt)
212    }
213}
214
215impl<S> MidHandshakeTlsStream<S> {
216    pub fn get_ref(&self) -> &S {
217        self.0.get_ref()
218    }
219
220    pub fn get_mut(&mut self) -> &mut S {
221        self.0.get_mut()
222    }
223}
224
225impl<S> MidHandshakeTlsStream<S>
226where
227    S: io::Read + io::Write,
228{
229    pub fn handshake(self) -> Result<TlsStream<S>, HandshakeError<S>> {
230        match self.0.handshake() {
231            Ok(s) => Ok(TlsStream(s)),
232            Err(e) => Err(e.into()),
233        }
234    }
235}
236
237pub enum HandshakeError<S> {
238    Failure(Error),
239    WouldBlock(MidHandshakeTlsStream<S>),
240}
241
242impl<S> From<ssl::HandshakeError<S>> for HandshakeError<S> {
243    fn from(e: ssl::HandshakeError<S>) -> HandshakeError<S> {
244        match e {
245            ssl::HandshakeError::SetupFailure(e) => HandshakeError::Failure(e.into()),
246            ssl::HandshakeError::Failure(e) => {
247                let v = e.ssl().verify_result();
248                HandshakeError::Failure(Error::Ssl(e.into_error(), v))
249            }
250            ssl::HandshakeError::WouldBlock(s) => {
251                HandshakeError::WouldBlock(MidHandshakeTlsStream(s))
252            }
253        }
254    }
255}
256
257impl<S> From<ErrorStack> for HandshakeError<S> {
258    fn from(e: ErrorStack) -> HandshakeError<S> {
259        HandshakeError::Failure(e.into())
260    }
261}
262
263#[derive(Clone)]
264pub struct TlsConnector {
265    connector: SslConnector,
266    use_sni: bool,
267    accept_invalid_hostnames: bool,
268    accept_invalid_certs: bool,
269}
270
271impl TlsConnector {
272    pub fn new(builder: &TlsConnectorBuilder) -> Result<TlsConnector, Error> {
273        let mut connector = SslConnector::builder(SslMethod::tls())?;
274
275        // We need to load these separately so an error on one doesn't prevent the other from loading.
276        if let Some(cert_file) = &PROBE_RESULT.cert_file {
277            if let Err(e) = connector.load_verify_locations(Some(cert_file), None) {
278                debug!("load_verify_locations cert file error: {:?}", e);
279            }
280        }
281        if let Some(cert_dir) = &PROBE_RESULT.cert_dir {
282            if let Err(e) = connector.load_verify_locations(None, Some(cert_dir)) {
283                debug!("load_verify_locations cert dir error: {:?}", e);
284            }
285        }
286
287        if let Some(ref identity) = builder.identity {
288            connector.set_certificate(&identity.0.cert)?;
289            connector.set_private_key(&identity.0.pkey)?;
290            for cert in identity.0.chain.iter() {
291                // https://www.openssl.org/docs/manmaster/man3/SSL_CTX_add_extra_chain_cert.html
292                // specifies that "When sending a certificate chain, extra chain certificates are
293                // sent in order following the end entity certificate."
294                connector.add_extra_chain_cert(cert.to_owned())?;
295            }
296        }
297        supported_protocols(builder.min_protocol, builder.max_protocol, &mut connector)?;
298
299        if builder.disable_built_in_roots {
300            connector.set_cert_store(X509StoreBuilder::new()?.build());
301        }
302
303        for cert in &builder.root_certificates {
304            if let Err(err) = connector.cert_store_mut().add_cert((cert.0).0.clone()) {
305                debug!("add_cert error: {:?}", err);
306            }
307        }
308
309        #[cfg(feature = "alpn")]
310        {
311            if !builder.alpn.is_empty() {
312                // Wire format is each alpn preceded by its length as a byte.
313                let mut alpn_wire_format = Vec::with_capacity(
314                    builder
315                        .alpn
316                        .iter()
317                        .map(|s| s.as_bytes().len())
318                        .sum::<usize>()
319                        + builder.alpn.len(),
320                );
321                for alpn in builder.alpn.iter().map(|s| s.as_bytes()) {
322                    alpn_wire_format.push(alpn.len() as u8);
323                    alpn_wire_format.extend(alpn);
324                }
325                connector.set_alpn_protos(&alpn_wire_format)?;
326            }
327        }
328
329        #[cfg(target_os = "android")]
330        load_android_root_certs(&mut connector)?;
331
332        Ok(TlsConnector {
333            connector: connector.build(),
334            use_sni: builder.use_sni,
335            accept_invalid_hostnames: builder.accept_invalid_hostnames,
336            accept_invalid_certs: builder.accept_invalid_certs,
337        })
338    }
339
340    pub fn connect<S>(&self, domain: &str, stream: S) -> Result<TlsStream<S>, HandshakeError<S>>
341    where
342        S: io::Read + io::Write,
343    {
344        let mut ssl = self
345            .connector
346            .configure()?
347            .use_server_name_indication(self.use_sni)
348            .verify_hostname(!self.accept_invalid_hostnames);
349        if self.accept_invalid_certs {
350            ssl.set_verify(SslVerifyMode::NONE);
351        }
352
353        let s = ssl.connect(domain, stream)?;
354        Ok(TlsStream(s))
355    }
356}
357
358impl fmt::Debug for TlsConnector {
359    fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result {
360        fmt.debug_struct("TlsConnector")
361            // n.b. SslConnector is a newtype on SslContext which implements a noop Debug so it's omitted
362            .field("use_sni", &self.use_sni)
363            .field("accept_invalid_hostnames", &self.accept_invalid_hostnames)
364            .field("accept_invalid_certs", &self.accept_invalid_certs)
365            .finish()
366    }
367}
368
369#[derive(Clone)]
370pub struct TlsAcceptor(SslAcceptor);
371
372impl TlsAcceptor {
373    pub fn new(builder: &TlsAcceptorBuilder) -> Result<TlsAcceptor, Error> {
374        let mut acceptor = SslAcceptor::mozilla_intermediate(SslMethod::tls())?;
375        acceptor.set_private_key(&builder.identity.0.pkey)?;
376        acceptor.set_certificate(&builder.identity.0.cert)?;
377        for cert in builder.identity.0.chain.iter() {
378            // https://www.openssl.org/docs/manmaster/man3/SSL_CTX_add_extra_chain_cert.html
379            // specifies that "When sending a certificate chain, extra chain certificates are
380            // sent in order following the end entity certificate."
381            acceptor.add_extra_chain_cert(cert.to_owned())?;
382        }
383        supported_protocols(builder.min_protocol, builder.max_protocol, &mut acceptor)?;
384
385        Ok(TlsAcceptor(acceptor.build()))
386    }
387
388    pub fn accept<S>(&self, stream: S) -> Result<TlsStream<S>, HandshakeError<S>>
389    where
390        S: io::Read + io::Write,
391    {
392        let s = self.0.accept(stream)?;
393        Ok(TlsStream(s))
394    }
395}
396
397pub struct TlsStream<S>(ssl::SslStream<S>);
398
399impl<S: fmt::Debug> fmt::Debug for TlsStream<S> {
400    fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result {
401        fmt::Debug::fmt(&self.0, fmt)
402    }
403}
404
405impl<S> TlsStream<S> {
406    pub fn get_ref(&self) -> &S {
407        self.0.get_ref()
408    }
409
410    pub fn get_mut(&mut self) -> &mut S {
411        self.0.get_mut()
412    }
413}
414
415impl<S: io::Read + io::Write> TlsStream<S> {
416    pub fn buffered_read_size(&self) -> Result<usize, Error> {
417        Ok(self.0.ssl().pending())
418    }
419
420    pub fn peer_certificate(&self) -> Result<Option<Certificate>, Error> {
421        Ok(self.0.ssl().peer_certificate().map(Certificate))
422    }
423
424    #[cfg(feature = "alpn")]
425    pub fn negotiated_alpn(&self) -> Result<Option<Vec<u8>>, Error> {
426        Ok(self
427            .0
428            .ssl()
429            .selected_alpn_protocol()
430            .map(|alpn| alpn.to_vec()))
431    }
432
433    pub fn tls_server_end_point(&self) -> Result<Option<Vec<u8>>, Error> {
434        let cert = if self.0.ssl().is_server() {
435            self.0.ssl().certificate().map(|x| x.to_owned())
436        } else {
437            self.0.ssl().peer_certificate()
438        };
439
440        let cert = match cert {
441            Some(cert) => cert,
442            None => return Ok(None),
443        };
444
445        let algo_nid = cert.signature_algorithm().object().nid();
446        let signature_algorithms = match algo_nid.signature_algorithms() {
447            Some(algs) => algs,
448            None => return Ok(None),
449        };
450
451        let md = match signature_algorithms.digest {
452            Nid::MD5 | Nid::SHA1 => MessageDigest::sha256(),
453            nid => match MessageDigest::from_nid(nid) {
454                Some(md) => md,
455                None => return Ok(None),
456            },
457        };
458
459        let digest = cert.digest(md)?;
460
461        Ok(Some(digest.to_vec()))
462    }
463
464    pub fn shutdown(&mut self) -> io::Result<()> {
465        match self.0.shutdown() {
466            Ok(_) => Ok(()),
467            Err(ref e) if e.code() == ssl::ErrorCode::ZERO_RETURN => Ok(()),
468            Err(e) => Err(e
469                .into_io_error()
470                .unwrap_or_else(|e| io::Error::new(io::ErrorKind::Other, e))),
471        }
472    }
473}
474
475impl<S: io::Read + io::Write> io::Read for TlsStream<S> {
476    fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
477        self.0.read(buf)
478    }
479}
480
481impl<S: io::Read + io::Write> io::Write for TlsStream<S> {
482    fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
483        self.0.write(buf)
484    }
485
486    fn flush(&mut self) -> io::Result<()> {
487        self.0.flush()
488    }
489}