1extern crate openssl;
2extern crate openssl_probe;
3
4use self::openssl::error::ErrorStack;
5use self::openssl::hash::MessageDigest;
6use self::openssl::nid::Nid;
7use self::openssl::pkcs12::Pkcs12;
8use self::openssl::pkey::{PKey, Private};
9use self::openssl::ssl::{
10 self, MidHandshakeSslStream, SslAcceptor, SslConnector, SslContextBuilder, SslMethod,
11 SslVerifyMode,
12};
13use self::openssl::x509::{store::X509StoreBuilder, X509VerifyResult, X509};
14use self::openssl_probe::ProbeResult;
15use std::error;
16use std::fmt;
17use std::io;
18use std::sync::LazyLock;
19
20use {Protocol, TlsAcceptorBuilder, TlsConnectorBuilder};
21
22static PROBE_RESULT: LazyLock<ProbeResult> = LazyLock::new(openssl_probe::probe);
23
/// Restricts the protocol versions `ctx` may negotiate to the inclusive range
/// `[min, max]` using OpenSSL's native min/max protocol-version setters.
///
/// A bound of `None` is passed through as `None` for that side of the range.
///
/// # Errors
///
/// Returns the OpenSSL error stack if either setter fails.
#[cfg(have_min_max_version)]
fn supported_protocols(
    min: Option<Protocol>,
    max: Option<Protocol>,
    ctx: &mut SslContextBuilder,
) -> Result<(), ErrorStack> {
    use self::openssl::ssl::SslVersion;

    // Maps the crate-level protocol enum onto OpenSSL's version constants.
    let to_ssl_version = |protocol: Protocol| match protocol {
        Protocol::Sslv3 => SslVersion::SSL3,
        Protocol::Tlsv10 => SslVersion::TLS1,
        Protocol::Tlsv11 => SslVersion::TLS1_1,
        Protocol::Tlsv12 => SslVersion::TLS1_2,
    };

    ctx.set_min_proto_version(min.map(to_ssl_version))?;
    ctx.set_max_proto_version(max.map(to_ssl_version))
}
46
47#[cfg(not(have_min_max_version))]
48fn supported_protocols(
49 min: Option<Protocol>,
50 max: Option<Protocol>,
51 ctx: &mut SslContextBuilder,
52) -> Result<(), ErrorStack> {
53 use self::openssl::ssl::SslOptions;
54
55 let no_ssl_mask = SslOptions::NO_SSLV2
56 | SslOptions::NO_SSLV3
57 | SslOptions::NO_TLSV1
58 | SslOptions::NO_TLSV1_1
59 | SslOptions::NO_TLSV1_2;
60
61 ctx.clear_options(no_ssl_mask);
62 let mut options = SslOptions::empty();
63 options |= match min {
64 None => SslOptions::empty(),
65 Some(Protocol::Sslv3) => SslOptions::NO_SSLV2,
66 Some(Protocol::Tlsv10) => SslOptions::NO_SSLV2 | SslOptions::NO_SSLV3,
67 Some(Protocol::Tlsv11) => {
68 SslOptions::NO_SSLV2 | SslOptions::NO_SSLV3 | SslOptions::NO_TLSV1
69 }
70 Some(Protocol::Tlsv12) => {
71 SslOptions::NO_SSLV2
72 | SslOptions::NO_SSLV3
73 | SslOptions::NO_TLSV1
74 | SslOptions::NO_TLSV1_1
75 }
76 };
77 options |= match max {
78 None | Some(Protocol::Tlsv12) => SslOptions::empty(),
79 Some(Protocol::Tlsv11) => SslOptions::NO_TLSV1_2,
80 Some(Protocol::Tlsv10) => SslOptions::NO_TLSV1_1 | SslOptions::NO_TLSV1_2,
81 Some(Protocol::Sslv3) => {
82 SslOptions::NO_TLSV1 | SslOptions::NO_TLSV1_1 | SslOptions::NO_TLSV1_2
83 }
84 };
85
86 ctx.set_options(options);
87
88 Ok(())
89}
90
/// Adds every PEM certificate found in `/system/etc/security/cacerts` to the
/// connector's certificate store.
///
/// Best effort: an unreadable directory yields no certificates. Unreadable
/// entries and files that fail to parse as PEM X.509 are skipped. Certificates
/// the store rejects are logged at debug level. Always returns `Ok(())`.
#[cfg(target_os = "android")]
fn load_android_root_certs(connector: &mut SslContextBuilder) -> Result<(), Error> {
    use std::fs;

    let Ok(entries) = fs::read_dir("/system/etc/security/cacerts") else {
        return Ok(());
    };

    for entry in entries.flatten() {
        let Ok(pem) = fs::read(entry.path()) else {
            continue;
        };
        let Ok(cert) = X509::from_pem(&pem) else {
            continue;
        };
        if let Err(err) = connector.cert_store_mut().add_cert(cert) {
            debug!("load_android_root_certs error: {:?}", err);
        }
    }

    Ok(())
}
109
110#[derive(Debug)]
111pub enum Error {
112 Normal(ErrorStack),
113 Ssl(ssl::Error, X509VerifyResult),
114 EmptyChain,
115 NotPkcs8,
116}
117
118impl error::Error for Error {
119 fn source(&self) -> Option<&(dyn error::Error + 'static)> {
120 match *self {
121 Error::Normal(ref e) => error::Error::source(e),
122 Error::Ssl(ref e, _) => error::Error::source(e),
123 Error::EmptyChain => None,
124 Error::NotPkcs8 => None,
125 }
126 }
127}
128
129impl fmt::Display for Error {
130 fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result {
131 match *self {
132 Error::Normal(ref e) => fmt::Display::fmt(e, fmt),
133 Error::Ssl(ref e, X509VerifyResult::OK) => fmt::Display::fmt(e, fmt),
134 Error::Ssl(ref e, v) => write!(fmt, "{} ({})", e, v),
135 Error::EmptyChain => write!(
136 fmt,
137 "at least one certificate must be provided to create an identity"
138 ),
139 Error::NotPkcs8 => write!(fmt, "expected PKCS#8 PEM"),
140 }
141 }
142}
143
144impl From<ErrorStack> for Error {
145 fn from(err: ErrorStack) -> Error {
146 Error::Normal(err)
147 }
148}
149
150#[derive(Clone)]
151pub struct Identity {
152 pkey: PKey<Private>,
153 cert: X509,
154 chain: Vec<X509>,
155}
156
157impl Identity {
158 pub fn from_pkcs12(buf: &[u8], pass: &str) -> Result<Identity, Error> {
159 let pkcs12 = Pkcs12::from_der(buf)?;
160 let parsed = pkcs12.parse2(pass)?;
161 Ok(Identity {
162 pkey: parsed.pkey.ok_or_else(|| Error::EmptyChain)?,
163 cert: parsed.cert.ok_or_else(|| Error::EmptyChain)?,
164 chain: parsed.ca.into_iter().flatten().rev().collect(),
168 })
169 }
170
171 pub fn from_pkcs8(buf: &[u8], key: &[u8]) -> Result<Identity, Error> {
172 if !key.starts_with(b"-----BEGIN PRIVATE KEY-----") {
173 return Err(Error::NotPkcs8);
174 }
175
176 let pkey = PKey::private_key_from_pem(key)?;
177 let mut cert_chain = X509::stack_from_pem(buf)?.into_iter();
178 let cert = cert_chain.next().ok_or(Error::EmptyChain)?;
179 let chain = cert_chain.collect();
180 Ok(Identity { pkey, cert, chain })
181 }
182}
183
184#[derive(Clone)]
185pub struct Certificate(X509);
186
187impl Certificate {
188 pub fn from_der(buf: &[u8]) -> Result<Certificate, Error> {
189 let cert = X509::from_der(buf)?;
190 Ok(Certificate(cert))
191 }
192
193 pub fn from_pem(buf: &[u8]) -> Result<Certificate, Error> {
194 let cert = X509::from_pem(buf)?;
195 Ok(Certificate(cert))
196 }
197
198 pub fn to_der(&self) -> Result<Vec<u8>, Error> {
199 let der = self.0.to_der()?;
200 Ok(der)
201 }
202}
203
204pub struct MidHandshakeTlsStream<S>(MidHandshakeSslStream<S>);
205
206impl<S> fmt::Debug for MidHandshakeTlsStream<S>
207where
208 S: fmt::Debug,
209{
210 fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result {
211 fmt::Debug::fmt(&self.0, fmt)
212 }
213}
214
215impl<S> MidHandshakeTlsStream<S> {
216 pub fn get_ref(&self) -> &S {
217 self.0.get_ref()
218 }
219
220 pub fn get_mut(&mut self) -> &mut S {
221 self.0.get_mut()
222 }
223}
224
225impl<S> MidHandshakeTlsStream<S>
226where
227 S: io::Read + io::Write,
228{
229 pub fn handshake(self) -> Result<TlsStream<S>, HandshakeError<S>> {
230 match self.0.handshake() {
231 Ok(s) => Ok(TlsStream(s)),
232 Err(e) => Err(e.into()),
233 }
234 }
235}
236
237pub enum HandshakeError<S> {
238 Failure(Error),
239 WouldBlock(MidHandshakeTlsStream<S>),
240}
241
242impl<S> From<ssl::HandshakeError<S>> for HandshakeError<S> {
243 fn from(e: ssl::HandshakeError<S>) -> HandshakeError<S> {
244 match e {
245 ssl::HandshakeError::SetupFailure(e) => HandshakeError::Failure(e.into()),
246 ssl::HandshakeError::Failure(e) => {
247 let v = e.ssl().verify_result();
248 HandshakeError::Failure(Error::Ssl(e.into_error(), v))
249 }
250 ssl::HandshakeError::WouldBlock(s) => {
251 HandshakeError::WouldBlock(MidHandshakeTlsStream(s))
252 }
253 }
254 }
255}
256
257impl<S> From<ErrorStack> for HandshakeError<S> {
258 fn from(e: ErrorStack) -> HandshakeError<S> {
259 HandshakeError::Failure(e.into())
260 }
261}
262
263#[derive(Clone)]
264pub struct TlsConnector {
265 connector: SslConnector,
266 use_sni: bool,
267 accept_invalid_hostnames: bool,
268 accept_invalid_certs: bool,
269}
270
271impl TlsConnector {
272 pub fn new(builder: &TlsConnectorBuilder) -> Result<TlsConnector, Error> {
273 let mut connector = SslConnector::builder(SslMethod::tls())?;
274
275 if let Some(cert_file) = &PROBE_RESULT.cert_file {
277 if let Err(e) = connector.load_verify_locations(Some(cert_file), None) {
278 debug!("load_verify_locations cert file error: {:?}", e);
279 }
280 }
281 if let Some(cert_dir) = &PROBE_RESULT.cert_dir {
282 if let Err(e) = connector.load_verify_locations(None, Some(cert_dir)) {
283 debug!("load_verify_locations cert dir error: {:?}", e);
284 }
285 }
286
287 if let Some(ref identity) = builder.identity {
288 connector.set_certificate(&identity.0.cert)?;
289 connector.set_private_key(&identity.0.pkey)?;
290 for cert in identity.0.chain.iter() {
291 connector.add_extra_chain_cert(cert.to_owned())?;
295 }
296 }
297 supported_protocols(builder.min_protocol, builder.max_protocol, &mut connector)?;
298
299 if builder.disable_built_in_roots {
300 connector.set_cert_store(X509StoreBuilder::new()?.build());
301 }
302
303 for cert in &builder.root_certificates {
304 if let Err(err) = connector.cert_store_mut().add_cert((cert.0).0.clone()) {
305 debug!("add_cert error: {:?}", err);
306 }
307 }
308
309 #[cfg(feature = "alpn")]
310 {
311 if !builder.alpn.is_empty() {
312 let mut alpn_wire_format = Vec::with_capacity(
314 builder
315 .alpn
316 .iter()
317 .map(|s| s.as_bytes().len())
318 .sum::<usize>()
319 + builder.alpn.len(),
320 );
321 for alpn in builder.alpn.iter().map(|s| s.as_bytes()) {
322 alpn_wire_format.push(alpn.len() as u8);
323 alpn_wire_format.extend(alpn);
324 }
325 connector.set_alpn_protos(&alpn_wire_format)?;
326 }
327 }
328
329 #[cfg(target_os = "android")]
330 load_android_root_certs(&mut connector)?;
331
332 Ok(TlsConnector {
333 connector: connector.build(),
334 use_sni: builder.use_sni,
335 accept_invalid_hostnames: builder.accept_invalid_hostnames,
336 accept_invalid_certs: builder.accept_invalid_certs,
337 })
338 }
339
340 pub fn connect<S>(&self, domain: &str, stream: S) -> Result<TlsStream<S>, HandshakeError<S>>
341 where
342 S: io::Read + io::Write,
343 {
344 let mut ssl = self
345 .connector
346 .configure()?
347 .use_server_name_indication(self.use_sni)
348 .verify_hostname(!self.accept_invalid_hostnames);
349 if self.accept_invalid_certs {
350 ssl.set_verify(SslVerifyMode::NONE);
351 }
352
353 let s = ssl.connect(domain, stream)?;
354 Ok(TlsStream(s))
355 }
356}
357
358impl fmt::Debug for TlsConnector {
359 fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result {
360 fmt.debug_struct("TlsConnector")
361 .field("use_sni", &self.use_sni)
363 .field("accept_invalid_hostnames", &self.accept_invalid_hostnames)
364 .field("accept_invalid_certs", &self.accept_invalid_certs)
365 .finish()
366 }
367}
368
369#[derive(Clone)]
370pub struct TlsAcceptor(SslAcceptor);
371
372impl TlsAcceptor {
373 pub fn new(builder: &TlsAcceptorBuilder) -> Result<TlsAcceptor, Error> {
374 let mut acceptor = SslAcceptor::mozilla_intermediate(SslMethod::tls())?;
375 acceptor.set_private_key(&builder.identity.0.pkey)?;
376 acceptor.set_certificate(&builder.identity.0.cert)?;
377 for cert in builder.identity.0.chain.iter() {
378 acceptor.add_extra_chain_cert(cert.to_owned())?;
382 }
383 supported_protocols(builder.min_protocol, builder.max_protocol, &mut acceptor)?;
384
385 Ok(TlsAcceptor(acceptor.build()))
386 }
387
388 pub fn accept<S>(&self, stream: S) -> Result<TlsStream<S>, HandshakeError<S>>
389 where
390 S: io::Read + io::Write,
391 {
392 let s = self.0.accept(stream)?;
393 Ok(TlsStream(s))
394 }
395}
396
397pub struct TlsStream<S>(ssl::SslStream<S>);
398
399impl<S: fmt::Debug> fmt::Debug for TlsStream<S> {
400 fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result {
401 fmt::Debug::fmt(&self.0, fmt)
402 }
403}
404
405impl<S> TlsStream<S> {
406 pub fn get_ref(&self) -> &S {
407 self.0.get_ref()
408 }
409
410 pub fn get_mut(&mut self) -> &mut S {
411 self.0.get_mut()
412 }
413}
414
415impl<S: io::Read + io::Write> TlsStream<S> {
416 pub fn buffered_read_size(&self) -> Result<usize, Error> {
417 Ok(self.0.ssl().pending())
418 }
419
420 pub fn peer_certificate(&self) -> Result<Option<Certificate>, Error> {
421 Ok(self.0.ssl().peer_certificate().map(Certificate))
422 }
423
424 #[cfg(feature = "alpn")]
425 pub fn negotiated_alpn(&self) -> Result<Option<Vec<u8>>, Error> {
426 Ok(self
427 .0
428 .ssl()
429 .selected_alpn_protocol()
430 .map(|alpn| alpn.to_vec()))
431 }
432
433 pub fn tls_server_end_point(&self) -> Result<Option<Vec<u8>>, Error> {
434 let cert = if self.0.ssl().is_server() {
435 self.0.ssl().certificate().map(|x| x.to_owned())
436 } else {
437 self.0.ssl().peer_certificate()
438 };
439
440 let cert = match cert {
441 Some(cert) => cert,
442 None => return Ok(None),
443 };
444
445 let algo_nid = cert.signature_algorithm().object().nid();
446 let signature_algorithms = match algo_nid.signature_algorithms() {
447 Some(algs) => algs,
448 None => return Ok(None),
449 };
450
451 let md = match signature_algorithms.digest {
452 Nid::MD5 | Nid::SHA1 => MessageDigest::sha256(),
453 nid => match MessageDigest::from_nid(nid) {
454 Some(md) => md,
455 None => return Ok(None),
456 },
457 };
458
459 let digest = cert.digest(md)?;
460
461 Ok(Some(digest.to_vec()))
462 }
463
464 pub fn shutdown(&mut self) -> io::Result<()> {
465 match self.0.shutdown() {
466 Ok(_) => Ok(()),
467 Err(ref e) if e.code() == ssl::ErrorCode::ZERO_RETURN => Ok(()),
468 Err(e) => Err(e
469 .into_io_error()
470 .unwrap_or_else(|e| io::Error::new(io::ErrorKind::Other, e))),
471 }
472 }
473}
474
475impl<S: io::Read + io::Write> io::Read for TlsStream<S> {
476 fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
477 self.0.read(buf)
478 }
479}
480
481impl<S: io::Read + io::Write> io::Write for TlsStream<S> {
482 fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
483 self.0.write(buf)
484 }
485
486 fn flush(&mut self) -> io::Result<()> {
487 self.0.flush()
488 }
489}