1use crate::client::cache::{SessionCache, SessionKey};
3use crate::SslStream;
4use http::uri::Scheme;
5use hyper::rt::{Read, ReadBufCursor, Write};
6use hyper::Uri;
7#[cfg(feature = "tokio")]
8use hyper_util::client::legacy::connect::HttpConnector;
9use hyper_util::client::legacy::connect::{Connected, Connection};
10use once_cell::sync::OnceCell;
11use openssl::error::ErrorStack;
12use openssl::ex_data::Index;
13use openssl::ssl::{
14 self, ConnectConfiguration, Ssl, SslConnector, SslConnectorBuilder, SslMethod,
15 SslSessionCacheMode,
16};
17use openssl::x509::X509VerifyResult;
18use parking_lot::Mutex;
19use pin_project::pin_project;
20use std::error::Error;
21use std::future::Future;
22use std::pin::Pin;
23use std::sync::Arc;
24use std::task::{Context, Poll};
25use std::{fmt, io};
26use tower_layer::Layer;
27use tower_service::Service;
28
29type ConfigureCallback =
30 dyn Fn(&mut ConnectConfiguration, &Uri) -> Result<(), ErrorStack> + 'static + Sync + Send;
31
32fn key_index() -> Result<Index<Ssl, SessionKey>, ErrorStack> {
33 static IDX: OnceCell<Index<Ssl, SessionKey>> = OnceCell::new();
34 IDX.get_or_try_init(Ssl::new_ex_index).copied()
35}
36
/// State shared between an `HttpsLayer` and every `HttpsConnector` it produces.
#[derive(Clone)]
struct Inner {
    /// Connector used to create the per-connection TLS configuration.
    ssl: SslConnector,
    /// Session cache, shared (behind a mutex) by all clones of this value.
    cache: Arc<Mutex<SessionCache>>,
    /// Optional user hook run on each connection's configuration.
    callback: Option<Arc<ConfigureCallback>>,
}
43
/// A `tower` layer that wraps a connector service in an `HttpsConnector`.
pub struct HttpsLayer {
    // Shared TLS state cloned into every connector produced by `layer`.
    inner: Inner,
}
48
49impl HttpsLayer {
50 pub fn new() -> Result<Self, ErrorStack> {
54 let mut ssl = SslConnector::builder(SslMethod::tls())?;
55
56 #[cfg(ossl102)]
57 ssl.set_alpn_protos(b"\x02h2\x08http/1.1")?;
58
59 Self::with_connector(ssl)
60 }
61
62 pub fn with_connector(mut ssl: SslConnectorBuilder) -> Result<Self, ErrorStack> {
66 let cache = Arc::new(Mutex::new(SessionCache::new()));
67
68 ssl.set_session_cache_mode(SslSessionCacheMode::CLIENT);
69
70 ssl.set_new_session_callback({
71 let cache = cache.clone();
72 move |ssl, session| {
73 if let Some(key) = key_index().ok().and_then(|idx| ssl.ex_data(idx)) {
74 cache.lock().insert(key.clone(), session);
75 }
76 }
77 });
78
79 ssl.set_remove_session_callback({
80 let cache = cache.clone();
81 move |_, session| cache.lock().remove(session)
82 });
83
84 Ok(HttpsLayer {
85 inner: Inner {
86 ssl: ssl.build(),
87 cache,
88 callback: None,
89 },
90 })
91 }
92
93 pub fn set_callback<F>(&mut self, callback: F)
95 where
96 F: Fn(&mut ConnectConfiguration, &Uri) -> Result<(), ErrorStack> + 'static + Sync + Send,
97 {
98 self.inner.callback = Some(Arc::new(callback));
99 }
100}
101
102impl<S> Layer<S> for HttpsLayer {
103 type Service = HttpsConnector<S>;
104
105 fn layer(&self, inner: S) -> Self::Service {
106 HttpsConnector {
107 http: inner,
108 inner: self.inner.clone(),
109 }
110 }
111}
112
/// A connector that wraps another connector service and adds TLS for `https` URIs.
#[derive(Clone)]
pub struct HttpsConnector<T> {
    // The wrapped connector that establishes the underlying stream.
    http: T,
    // TLS connector, session cache and optional configuration hook.
    inner: Inner,
}
119
#[cfg(feature = "tokio")]
impl HttpsConnector<HttpConnector> {
    /// Creates a connector over a default `HttpConnector` and a default TLS
    /// configuration (see `HttpsLayer::new`).
    ///
    /// # Errors
    ///
    /// Returns the `ErrorStack` reported while building the TLS connector.
    pub fn new() -> Result<Self, ErrorStack> {
        let mut http = HttpConnector::new();
        // The wrapped connector must accept non-`http` schemes so that `https`
        // URIs reach the TLS layer.
        http.enforce_http(false);

        let layer = HttpsLayer::new()?;
        Ok(layer.layer(http))
    }
}
133
134impl<S> HttpsConnector<S> {
135 pub fn with_connector(http: S, ssl: SslConnectorBuilder) -> Result<Self, ErrorStack> {
139 HttpsLayer::with_connector(ssl).map(|l| l.layer(http))
140 }
141
142 pub fn set_callback<F>(&mut self, callback: F)
144 where
145 F: Fn(&mut ConnectConfiguration, &Uri) -> Result<(), ErrorStack> + 'static + Sync + Send,
146 {
147 self.inner.callback = Some(Arc::new(callback));
148 }
149}
150
151impl<S> Service<Uri> for HttpsConnector<S>
152where
153 S: Service<Uri>,
154 S::Future: 'static + Send,
155 S::Error: Into<Box<dyn Error + Sync + Send>>,
156 S::Response: Read + Write + Unpin + Connection + Send,
157{
158 type Response = MaybeHttpsStream<S::Response>;
159 type Error = Box<dyn Error + Sync + Send>;
160 type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;
161
162 fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
163 self.http.poll_ready(cx).map_err(Into::into)
164 }
165
166 fn call(&mut self, req: Uri) -> Self::Future {
167 let tls_setup = if req.scheme() == Some(&Scheme::HTTPS) {
168 Some((self.inner.clone(), req.clone()))
169 } else {
170 None
171 };
172
173 let connect = self.http.call(req);
174
175 Box::pin(async move {
176 let conn = connect.await.map_err(Into::into)?;
177
178 let Some((inner, uri)) = tls_setup else {
179 return Ok(MaybeHttpsStream::Http(conn));
180 };
181
182 let Some(host) = uri.host() else {
183 return Err("URI missing host".into());
184 };
185
186 let mut config = inner.ssl.configure()?;
187
188 if let Some(callback) = &inner.callback {
189 callback(&mut config, &uri)?;
190 }
191
192 let key = SessionKey {
193 host: host.to_string(),
194 port: uri.port_u16().unwrap_or(443),
195 };
196
197 if let Some(session) = inner.cache.lock().get(&key) {
198 unsafe {
199 config.set_session(&session)?;
200 }
201 }
202
203 let idx = key_index()?;
204 config.set_ex_data(idx, key);
205
206 let ssl = config.into_ssl(host)?;
207
208 let mut stream = SslStream::new(ssl, conn)?;
209
210 match Pin::new(&mut stream).connect().await {
211 Ok(()) => Ok(MaybeHttpsStream::Https(stream)),
212 Err(error) => Err(Box::new(ConnectError {
213 error,
214 verify_result: stream.ssl().verify_result(),
215 }) as _),
216 }
217 })
218 }
219}
220
/// Error produced when the TLS handshake fails.
#[derive(Debug)]
struct ConnectError {
    /// The handshake error itself.
    error: ssl::Error,
    /// Certificate verification result read from the stream after the failure.
    verify_result: X509VerifyResult,
}
226
227impl fmt::Display for ConnectError {
228 fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result {
229 fmt::Display::fmt(&self.error, fmt)?;
230
231 if self.verify_result != X509VerifyResult::OK {
232 fmt.write_str(": ")?;
233 fmt::Display::fmt(&self.verify_result, fmt)?;
234 }
235
236 Ok(())
237 }
238}
239
240impl Error for ConnectError {
241 fn source(&self) -> Option<&(dyn Error + 'static)> {
242 Some(&self.error)
243 }
244}
245
/// A stream that is either the wrapped connector's raw stream or that stream
/// wrapped in TLS.
#[pin_project(project = MaybeHttpsStreamProj)]
pub enum MaybeHttpsStream<T> {
    /// The raw stream, used when the URI scheme was not `https`.
    Http(#[pin] T),
    /// The stream wrapped in an `SslStream`, used for `https` URIs.
    Https(#[pin] SslStream<T>),
}
254
255impl<T> Read for MaybeHttpsStream<T>
256where
257 T: Read + Write,
258{
259 fn poll_read(
260 self: Pin<&mut Self>,
261 cx: &mut Context<'_>,
262 buf: ReadBufCursor<'_>,
263 ) -> Poll<io::Result<()>> {
264 match self.project() {
265 MaybeHttpsStreamProj::Http(s) => s.poll_read(cx, buf),
266 MaybeHttpsStreamProj::Https(s) => s.poll_read(cx, buf),
267 }
268 }
269}
270
271impl<T> Write for MaybeHttpsStream<T>
272where
273 T: Read + Write,
274{
275 fn poll_write(
276 self: Pin<&mut Self>,
277 cx: &mut Context<'_>,
278 buf: &[u8],
279 ) -> Poll<io::Result<usize>> {
280 match self.project() {
281 MaybeHttpsStreamProj::Http(s) => s.poll_write(cx, buf),
282 MaybeHttpsStreamProj::Https(s) => s.poll_write(cx, buf),
283 }
284 }
285
286 fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
287 match self.project() {
288 MaybeHttpsStreamProj::Http(s) => s.poll_flush(cx),
289 MaybeHttpsStreamProj::Https(s) => s.poll_flush(cx),
290 }
291 }
292
293 fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
294 match self.project() {
295 MaybeHttpsStreamProj::Http(s) => s.poll_shutdown(cx),
296 MaybeHttpsStreamProj::Https(s) => s.poll_shutdown(cx),
297 }
298 }
299}
300
301impl<T> Connection for MaybeHttpsStream<T>
302where
303 T: Connection,
304{
305 fn connected(&self) -> Connected {
306 match self {
307 MaybeHttpsStream::Http(s) => s.connected(),
308 MaybeHttpsStream::Https(s) => {
309 let mut connected = s.get_ref().connected();
310 #[cfg(ossl102)]
311 if s.ssl().selected_alpn_protocol() == Some(b"h2") {
312 connected = connected.negotiated_h2();
313 }
314 connected
315 }
316 }
317 }
318}