1#[cfg(feature = "__tls")]
2use http::header::HeaderValue;
3use http::uri::{Authority, Scheme};
4use http::Uri;
5use hyper::rt::{Read, ReadBufCursor, Write};
6use hyper_util::client::legacy::connect::{Connected, Connection};
7#[cfg(any(feature = "socks", feature = "__tls"))]
8use hyper_util::rt::TokioIo;
9#[cfg(feature = "default-tls")]
10use native_tls_crate::{TlsConnector, TlsConnectorBuilder};
11use pin_project_lite::pin_project;
12use tower::util::{BoxCloneSyncServiceLayer, MapRequestLayer};
13use tower::{timeout::TimeoutLayer, util::BoxCloneSyncService, ServiceBuilder};
14use tower_service::Service;
15
16use std::future::Future;
17use std::io::{self, IoSlice};
18use std::net::IpAddr;
19use std::pin::Pin;
20use std::sync::Arc;
21use std::task::{Context, Poll};
22use std::time::Duration;
23
24#[cfg(feature = "default-tls")]
25use self::native_tls_conn::NativeTlsConn;
26#[cfg(feature = "__rustls")]
27use self::rustls_tls_conn::RustlsTlsConn;
28use crate::dns::DynResolver;
29use crate::error::{cast_to_internal_error, BoxError};
30use crate::proxy::{Proxy, ProxyScheme};
31use sealed::{Conn, Unnameable};
32
/// TCP connector shared by every backend, parameterized over the crate's
/// `DynResolver` for DNS resolution.
pub(crate) type HttpConnector = hyper_util::client::legacy::connect::HttpConnector<DynResolver>;
34
35#[derive(Clone)]
36pub(crate) enum Connector {
37 Simple(ConnectorService),
39 WithLayers(BoxCloneSyncService<Unnameable, Conn, BoxError>),
42}
43
44impl Service<Uri> for Connector {
45 type Response = Conn;
46 type Error = BoxError;
47 type Future = Connecting;
48
49 fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
50 match self {
51 Connector::Simple(service) => service.poll_ready(cx),
52 Connector::WithLayers(service) => service.poll_ready(cx),
53 }
54 }
55
56 fn call(&mut self, dst: Uri) -> Self::Future {
57 match self {
58 Connector::Simple(service) => service.call(dst),
59 Connector::WithLayers(service) => service.call(Unnameable(dst)),
60 }
61 }
62}
63
/// Type-erased connector service, as seen by user-supplied connector layers.
pub(crate) type BoxedConnectorService = BoxCloneSyncService<Unnameable, Conn, BoxError>;

/// Type-erased tower layer that can be stacked on a [`BoxedConnectorService`].
pub(crate) type BoxedConnectorLayer =
    BoxCloneSyncServiceLayer<BoxedConnectorService, Unnameable, Conn, BoxError>;

/// Collects connector configuration; `build` turns it into a [`Connector`].
pub(crate) struct ConnectorBuilder {
    // Transport / TLS backend used to open connections.
    inner: Inner,
    // Proxies consulted, in order, for every destination.
    proxies: Arc<Vec<Proxy>>,
    // Whether connections get wrapped for trace-level I/O logging.
    verbose: verbose::Wrapper,
    // Overall connect timeout, applied by `build`.
    timeout: Option<Duration>,
    // TCP_NODELAY setting requested by the user.
    #[cfg(feature = "__tls")]
    nodelay: bool,
    // Whether connections should expose TLS info (peer certificate).
    #[cfg(feature = "__tls")]
    tls_info: bool,
    // User-Agent sent in proxy CONNECT requests.
    #[cfg(feature = "__tls")]
    user_agent: Option<HeaderValue>,
}
81
82impl ConnectorBuilder {
83 pub(crate) fn build(self, layers: Vec<BoxedConnectorLayer>) -> Connector
84where {
85 let mut base_service = ConnectorService {
87 inner: self.inner,
88 proxies: self.proxies,
89 verbose: self.verbose,
90 #[cfg(feature = "__tls")]
91 nodelay: self.nodelay,
92 #[cfg(feature = "__tls")]
93 tls_info: self.tls_info,
94 #[cfg(feature = "__tls")]
95 user_agent: self.user_agent,
96 simple_timeout: None,
97 };
98
99 if layers.is_empty() {
100 base_service.simple_timeout = self.timeout;
102 return Connector::Simple(base_service);
103 }
104
105 let unnameable_service = ServiceBuilder::new()
109 .layer(MapRequestLayer::new(|request: Unnameable| request.0))
110 .service(base_service);
111 let mut service = BoxCloneSyncService::new(unnameable_service);
112
113 for layer in layers {
114 service = ServiceBuilder::new().layer(layer).service(service);
115 }
116
117 match self.timeout {
121 Some(timeout) => {
122 let service = ServiceBuilder::new()
123 .layer(TimeoutLayer::new(timeout))
124 .service(service);
125 let service = ServiceBuilder::new()
126 .map_err(|error: BoxError| cast_to_internal_error(error))
127 .service(service);
128 let service = BoxCloneSyncService::new(service);
129
130 Connector::WithLayers(service)
131 }
132 None => {
133 let service = ServiceBuilder::new().service(service);
137 let service = ServiceBuilder::new()
138 .map_err(|error: BoxError| cast_to_internal_error(error))
139 .service(service);
140 let service = BoxCloneSyncService::new(service);
141 Connector::WithLayers(service)
142 }
143 }
144 }
145
146 #[cfg(not(feature = "__tls"))]
147 pub(crate) fn new<T>(
148 mut http: HttpConnector,
149 proxies: Arc<Vec<Proxy>>,
150 local_addr: T,
151 #[cfg(any(target_os = "android", target_os = "fuchsia", target_os = "linux"))]
152 interface: Option<&str>,
153 nodelay: bool,
154 ) -> ConnectorBuilder
155 where
156 T: Into<Option<IpAddr>>,
157 {
158 http.set_local_address(local_addr.into());
159 #[cfg(any(target_os = "android", target_os = "fuchsia", target_os = "linux"))]
160 if let Some(interface) = interface {
161 http.set_interface(interface.to_owned());
162 }
163 http.set_nodelay(nodelay);
164
165 ConnectorBuilder {
166 inner: Inner::Http(http),
167 proxies,
168 verbose: verbose::OFF,
169 timeout: None,
170 }
171 }
172
173 #[cfg(feature = "default-tls")]
174 pub(crate) fn new_default_tls<T>(
175 http: HttpConnector,
176 tls: TlsConnectorBuilder,
177 proxies: Arc<Vec<Proxy>>,
178 user_agent: Option<HeaderValue>,
179 local_addr: T,
180 #[cfg(any(target_os = "android", target_os = "fuchsia", target_os = "linux"))]
181 interface: Option<&str>,
182 nodelay: bool,
183 tls_info: bool,
184 ) -> crate::Result<ConnectorBuilder>
185 where
186 T: Into<Option<IpAddr>>,
187 {
188 let tls = tls.build().map_err(crate::error::builder)?;
189 Ok(Self::from_built_default_tls(
190 http,
191 tls,
192 proxies,
193 user_agent,
194 local_addr,
195 #[cfg(any(target_os = "android", target_os = "fuchsia", target_os = "linux"))]
196 interface,
197 nodelay,
198 tls_info,
199 ))
200 }
201
202 #[cfg(feature = "default-tls")]
203 pub(crate) fn from_built_default_tls<T>(
204 mut http: HttpConnector,
205 tls: TlsConnector,
206 proxies: Arc<Vec<Proxy>>,
207 user_agent: Option<HeaderValue>,
208 local_addr: T,
209 #[cfg(any(target_os = "android", target_os = "fuchsia", target_os = "linux"))]
210 interface: Option<&str>,
211 nodelay: bool,
212 tls_info: bool,
213 ) -> ConnectorBuilder
214 where
215 T: Into<Option<IpAddr>>,
216 {
217 http.set_local_address(local_addr.into());
218 #[cfg(any(target_os = "android", target_os = "fuchsia", target_os = "linux"))]
219 if let Some(interface) = interface {
220 http.set_interface(interface);
221 }
222 http.set_nodelay(nodelay);
223 http.enforce_http(false);
224
225 ConnectorBuilder {
226 inner: Inner::DefaultTls(http, tls),
227 proxies,
228 verbose: verbose::OFF,
229 nodelay,
230 tls_info,
231 user_agent,
232 timeout: None,
233 }
234 }
235
236 #[cfg(feature = "__rustls")]
237 pub(crate) fn new_rustls_tls<T>(
238 mut http: HttpConnector,
239 tls: rustls::ClientConfig,
240 proxies: Arc<Vec<Proxy>>,
241 user_agent: Option<HeaderValue>,
242 local_addr: T,
243 #[cfg(any(target_os = "android", target_os = "fuchsia", target_os = "linux"))]
244 interface: Option<&str>,
245 nodelay: bool,
246 tls_info: bool,
247 ) -> ConnectorBuilder
248 where
249 T: Into<Option<IpAddr>>,
250 {
251 http.set_local_address(local_addr.into());
252 #[cfg(any(target_os = "android", target_os = "fuchsia", target_os = "linux"))]
253 if let Some(interface) = interface {
254 http.set_interface(interface.to_owned());
255 }
256 http.set_nodelay(nodelay);
257 http.enforce_http(false);
258
259 let (tls, tls_proxy) = if proxies.is_empty() {
260 let tls = Arc::new(tls);
261 (tls.clone(), tls)
262 } else {
263 let mut tls_proxy = tls.clone();
264 tls_proxy.alpn_protocols.clear();
265 (Arc::new(tls), Arc::new(tls_proxy))
266 };
267
268 ConnectorBuilder {
269 inner: Inner::RustlsTls {
270 http,
271 tls,
272 tls_proxy,
273 },
274 proxies,
275 verbose: verbose::OFF,
276 nodelay,
277 tls_info,
278 user_agent,
279 timeout: None,
280 }
281 }
282
283 pub(crate) fn set_timeout(&mut self, timeout: Option<Duration>) {
284 self.timeout = timeout;
285 }
286
287 pub(crate) fn set_verbose(&mut self, enabled: bool) {
288 self.verbose.0 = enabled;
289 }
290
291 pub(crate) fn set_keepalive(&mut self, dur: Option<Duration>) {
292 match &mut self.inner {
293 #[cfg(feature = "default-tls")]
294 Inner::DefaultTls(http, _tls) => http.set_keepalive(dur),
295 #[cfg(feature = "__rustls")]
296 Inner::RustlsTls { http, .. } => http.set_keepalive(dur),
297 #[cfg(not(feature = "__tls"))]
298 Inner::Http(http) => http.set_keepalive(dur),
299 }
300 }
301}
302
/// Base connector service: opens a (possibly proxied, possibly TLS) connection
/// to a destination `Uri`.
#[allow(missing_debug_implementations)]
#[derive(Clone)]
pub(crate) struct ConnectorService {
    // Transport / TLS backend.
    inner: Inner,
    // Proxies consulted, in order, for each destination.
    proxies: Arc<Vec<Proxy>>,
    // Verbose-logging switch applied to every returned connection.
    verbose: verbose::Wrapper,
    // Timeout applied inside `call`; only set when no tower layers are in use
    // (see `ConnectorBuilder::build`).
    simple_timeout: Option<Duration>,
    // TCP_NODELAY setting requested by the user.
    #[cfg(feature = "__tls")]
    nodelay: bool,
    // Whether returned connections should expose TLS info.
    #[cfg(feature = "__tls")]
    tls_info: bool,
    // User-Agent sent in proxy CONNECT requests.
    #[cfg(feature = "__tls")]
    user_agent: Option<HeaderValue>,
}

/// Transport backend, selected at compile time by cargo features.
#[derive(Clone)]
enum Inner {
    // Plain TCP; only exists when no TLS backend is enabled.
    #[cfg(not(feature = "__tls"))]
    Http(HttpConnector),
    // TCP plus native-tls.
    #[cfg(feature = "default-tls")]
    DefaultTls(HttpConnector, TlsConnector),
    // TCP plus rustls. `tls_proxy` is used for the connection to an HTTPS proxy
    // itself; `tls` is used for the destination server.
    #[cfg(feature = "__rustls")]
    RustlsTls {
        http: HttpConnector,
        tls: Arc<rustls::ClientConfig>,
        tls_proxy: Arc<rustls::ClientConfig>,
    },
}
335
336impl ConnectorService {
337 #[cfg(feature = "socks")]
338 async fn connect_socks(&self, dst: Uri, proxy: ProxyScheme) -> Result<Conn, BoxError> {
339 let dns = match proxy {
340 ProxyScheme::Socks4 {
341 remote_dns: false, ..
342 } => socks::DnsResolve::Local,
343 ProxyScheme::Socks4 {
344 remote_dns: true, ..
345 } => socks::DnsResolve::Proxy,
346 ProxyScheme::Socks5 {
347 remote_dns: false, ..
348 } => socks::DnsResolve::Local,
349 ProxyScheme::Socks5 {
350 remote_dns: true, ..
351 } => socks::DnsResolve::Proxy,
352 ProxyScheme::Http { .. } | ProxyScheme::Https { .. } => {
353 unreachable!("connect_socks is only called for socks proxies");
354 }
355 };
356
357 match &self.inner {
358 #[cfg(feature = "default-tls")]
359 Inner::DefaultTls(_http, tls) => {
360 if dst.scheme() == Some(&Scheme::HTTPS) {
361 let host = dst.host().ok_or("no host in url")?.to_string();
362 let conn = socks::connect(proxy, dst, dns).await?;
363 let conn = TokioIo::new(conn);
364 let conn = TokioIo::new(conn);
365 let tls_connector = tokio_native_tls::TlsConnector::from(tls.clone());
366 let io = tls_connector.connect(&host, conn).await?;
367 let io = TokioIo::new(io);
368 return Ok(Conn {
369 inner: self.verbose.wrap(NativeTlsConn { inner: io }),
370 is_proxy: false,
371 tls_info: self.tls_info,
372 });
373 }
374 }
375 #[cfg(feature = "__rustls")]
376 Inner::RustlsTls { tls, .. } => {
377 if dst.scheme() == Some(&Scheme::HTTPS) {
378 use std::convert::TryFrom;
379 use tokio_rustls::TlsConnector as RustlsConnector;
380
381 let tls = tls.clone();
382 let host = dst.host().ok_or("no host in url")?.to_string();
383 let conn = socks::connect(proxy, dst, dns).await?;
384 let conn = TokioIo::new(conn);
385 let conn = TokioIo::new(conn);
386 let server_name =
387 rustls_pki_types::ServerName::try_from(host.as_str().to_owned())
388 .map_err(|_| "Invalid Server Name")?;
389 let io = RustlsConnector::from(tls)
390 .connect(server_name, conn)
391 .await?;
392 let io = TokioIo::new(io);
393 return Ok(Conn {
394 inner: self.verbose.wrap(RustlsTlsConn { inner: io }),
395 is_proxy: false,
396 tls_info: false,
397 });
398 }
399 }
400 #[cfg(not(feature = "__tls"))]
401 Inner::Http(_) => (),
402 }
403
404 socks::connect(proxy, dst, dns).await.map(|tcp| Conn {
405 inner: self.verbose.wrap(TokioIo::new(tcp)),
406 is_proxy: false,
407 tls_info: false,
408 })
409 }
410
411 async fn connect_with_maybe_proxy(self, dst: Uri, is_proxy: bool) -> Result<Conn, BoxError> {
412 match self.inner {
413 #[cfg(not(feature = "__tls"))]
414 Inner::Http(mut http) => {
415 let io = http.call(dst).await?;
416 Ok(Conn {
417 inner: self.verbose.wrap(io),
418 is_proxy,
419 tls_info: false,
420 })
421 }
422 #[cfg(feature = "default-tls")]
423 Inner::DefaultTls(http, tls) => {
424 let mut http = http.clone();
425
426 if !self.nodelay && (dst.scheme() == Some(&Scheme::HTTPS)) {
430 http.set_nodelay(true);
431 }
432
433 let tls_connector = tokio_native_tls::TlsConnector::from(tls.clone());
434 let mut http = hyper_tls::HttpsConnector::from((http, tls_connector));
435 let io = http.call(dst).await?;
436
437 if let hyper_tls::MaybeHttpsStream::Https(stream) = io {
438 if !self.nodelay {
439 stream
440 .inner()
441 .get_ref()
442 .get_ref()
443 .get_ref()
444 .inner()
445 .inner()
446 .set_nodelay(false)?;
447 }
448 Ok(Conn {
449 inner: self.verbose.wrap(NativeTlsConn { inner: stream }),
450 is_proxy,
451 tls_info: self.tls_info,
452 })
453 } else {
454 Ok(Conn {
455 inner: self.verbose.wrap(io),
456 is_proxy,
457 tls_info: false,
458 })
459 }
460 }
461 #[cfg(feature = "__rustls")]
462 Inner::RustlsTls { http, tls, .. } => {
463 let mut http = http.clone();
464
465 if !self.nodelay && (dst.scheme() == Some(&Scheme::HTTPS)) {
469 http.set_nodelay(true);
470 }
471
472 let mut http = hyper_rustls::HttpsConnector::from((http, tls.clone()));
473 let io = http.call(dst).await?;
474
475 if let hyper_rustls::MaybeHttpsStream::Https(stream) = io {
476 if !self.nodelay {
477 let (io, _) = stream.inner().get_ref();
478 io.inner().inner().set_nodelay(false)?;
479 }
480 Ok(Conn {
481 inner: self.verbose.wrap(RustlsTlsConn { inner: stream }),
482 is_proxy,
483 tls_info: self.tls_info,
484 })
485 } else {
486 Ok(Conn {
487 inner: self.verbose.wrap(io),
488 is_proxy,
489 tls_info: false,
490 })
491 }
492 }
493 }
494 }
495
496 async fn connect_via_proxy(
497 self,
498 dst: Uri,
499 proxy_scheme: ProxyScheme,
500 ) -> Result<Conn, BoxError> {
501 log::debug!("proxy({proxy_scheme:?}) intercepts '{dst:?}'");
502
503 let (proxy_dst, _auth) = match proxy_scheme {
504 ProxyScheme::Http { host, auth } => (into_uri(Scheme::HTTP, host), auth),
505 ProxyScheme::Https { host, auth } => (into_uri(Scheme::HTTPS, host), auth),
506 #[cfg(feature = "socks")]
507 ProxyScheme::Socks4 { .. } => return self.connect_socks(dst, proxy_scheme).await,
508 #[cfg(feature = "socks")]
509 ProxyScheme::Socks5 { .. } => return self.connect_socks(dst, proxy_scheme).await,
510 };
511
512 #[cfg(feature = "__tls")]
513 let auth = _auth;
514
515 match &self.inner {
516 #[cfg(feature = "default-tls")]
517 Inner::DefaultTls(http, tls) => {
518 if dst.scheme() == Some(&Scheme::HTTPS) {
519 let host = dst.host().to_owned();
520 let port = dst.port().map(|p| p.as_u16()).unwrap_or(443);
521 let http = http.clone();
522 let tls_connector = tokio_native_tls::TlsConnector::from(tls.clone());
523 let mut http = hyper_tls::HttpsConnector::from((http, tls_connector));
524 let conn = http.call(proxy_dst).await?;
525 log::trace!("tunneling HTTPS over proxy");
526 let tunneled = tunnel(
527 conn,
528 host.ok_or("no host in url")?.to_string(),
529 port,
530 self.user_agent.clone(),
531 auth,
532 )
533 .await?;
534 let tls_connector = tokio_native_tls::TlsConnector::from(tls.clone());
535 let io = tls_connector
536 .connect(host.ok_or("no host in url")?, TokioIo::new(tunneled))
537 .await?;
538 return Ok(Conn {
539 inner: self.verbose.wrap(NativeTlsConn {
540 inner: TokioIo::new(io),
541 }),
542 is_proxy: false,
543 tls_info: false,
544 });
545 }
546 }
547 #[cfg(feature = "__rustls")]
548 Inner::RustlsTls {
549 http,
550 tls,
551 tls_proxy,
552 } => {
553 if dst.scheme() == Some(&Scheme::HTTPS) {
554 use rustls_pki_types::ServerName;
555 use std::convert::TryFrom;
556 use tokio_rustls::TlsConnector as RustlsConnector;
557
558 let host = dst.host().ok_or("no host in url")?.to_string();
559 let port = dst.port().map(|r| r.as_u16()).unwrap_or(443);
560 let http = http.clone();
561 let mut http = hyper_rustls::HttpsConnector::from((http, tls_proxy.clone()));
562 let tls = tls.clone();
563 let conn = http.call(proxy_dst).await?;
564 log::trace!("tunneling HTTPS over proxy");
565 let maybe_server_name = ServerName::try_from(host.as_str().to_owned())
566 .map_err(|_| "Invalid Server Name");
567 let tunneled = tunnel(conn, host, port, self.user_agent.clone(), auth).await?;
568 let server_name = maybe_server_name?;
569 let io = RustlsConnector::from(tls)
570 .connect(server_name, TokioIo::new(tunneled))
571 .await?;
572
573 return Ok(Conn {
574 inner: self.verbose.wrap(RustlsTlsConn {
575 inner: TokioIo::new(io),
576 }),
577 is_proxy: false,
578 tls_info: false,
579 });
580 }
581 }
582 #[cfg(not(feature = "__tls"))]
583 Inner::Http(_) => (),
584 }
585
586 self.connect_with_maybe_proxy(proxy_dst, true).await
587 }
588}
589
590fn into_uri(scheme: Scheme, host: Authority) -> Uri {
591 http::Uri::builder()
593 .scheme(scheme)
594 .authority(host)
595 .path_and_query(http::uri::PathAndQuery::from_static("/"))
596 .build()
597 .expect("scheme and authority is valid Uri")
598}
599
600async fn with_timeout<T, F>(f: F, timeout: Option<Duration>) -> Result<T, BoxError>
601where
602 F: Future<Output = Result<T, BoxError>>,
603{
604 if let Some(to) = timeout {
605 match tokio::time::timeout(to, f).await {
606 Err(_elapsed) => Err(Box::new(crate::error::TimedOut) as BoxError),
607 Ok(Ok(try_res)) => Ok(try_res),
608 Ok(Err(e)) => Err(e),
609 }
610 } else {
611 f.await
612 }
613}
614
615impl Service<Uri> for ConnectorService {
616 type Response = Conn;
617 type Error = BoxError;
618 type Future = Connecting;
619
620 fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
621 Poll::Ready(Ok(()))
622 }
623
624 fn call(&mut self, dst: Uri) -> Self::Future {
625 log::debug!("starting new connection: {dst:?}");
626 let timeout = self.simple_timeout;
627 for prox in self.proxies.iter() {
628 if let Some(proxy_scheme) = prox.intercept(&dst) {
629 return Box::pin(with_timeout(
630 self.clone().connect_via_proxy(dst, proxy_scheme),
631 timeout,
632 ));
633 }
634 }
635
636 Box::pin(with_timeout(
637 self.clone().connect_with_maybe_proxy(dst, false),
638 timeout,
639 ))
640 }
641}
642
/// Exposes TLS session details of a stream (currently the peer certificate),
/// if the stream has any.
#[cfg(feature = "__tls")]
trait TlsInfoFactory {
    fn tls_info(&self) -> Option<crate::tls::TlsInfo>;
}

// Plain TCP carries no TLS information.
#[cfg(feature = "__tls")]
impl TlsInfoFactory for tokio::net::TcpStream {
    fn tls_info(&self) -> Option<crate::tls::TlsInfo> {
        None
    }
}

// `TokioIo` is an adapter around another stream: delegate to the wrapped stream.
#[cfg(feature = "__tls")]
impl<T: TlsInfoFactory> TlsInfoFactory for TokioIo<T> {
    fn tls_info(&self) -> Option<crate::tls::TlsInfo> {
        self.inner().tls_info()
    }
}
661
/// TLS info for a native-tls session over a direct (or SOCKS) TCP stream.
///
/// Any failure to read or DER-encode the peer certificate yields a `TlsInfo`
/// without a certificate.
#[cfg(feature = "default-tls")]
impl TlsInfoFactory for tokio_native_tls::TlsStream<TokioIo<TokioIo<tokio::net::TcpStream>>> {
    fn tls_info(&self) -> Option<crate::tls::TlsInfo> {
        let peer_certificate = match self.get_ref().peer_certificate() {
            Ok(Some(cert)) => cert.to_der().ok(),
            _ => None,
        };
        Some(crate::tls::TlsInfo { peer_certificate })
    }
}

/// TLS info for a native-tls session tunneled through a proxy connection.
#[cfg(feature = "default-tls")]
impl TlsInfoFactory
    for tokio_native_tls::TlsStream<
        TokioIo<hyper_tls::MaybeHttpsStream<TokioIo<tokio::net::TcpStream>>>,
    >
{
    fn tls_info(&self) -> Option<crate::tls::TlsInfo> {
        let peer_certificate = match self.get_ref().peer_certificate() {
            Ok(Some(cert)) => cert.to_der().ok(),
            _ => None,
        };
        Some(crate::tls::TlsInfo { peer_certificate })
    }
}

/// Only the HTTPS variant carries TLS information.
#[cfg(feature = "default-tls")]
impl TlsInfoFactory for hyper_tls::MaybeHttpsStream<TokioIo<tokio::net::TcpStream>> {
    fn tls_info(&self) -> Option<crate::tls::TlsInfo> {
        if let hyper_tls::MaybeHttpsStream::Https(tls) = self {
            tls.tls_info()
        } else {
            None
        }
    }
}
701
/// TLS info for a rustls session over a direct (or SOCKS) TCP stream: the
/// first (leaf) certificate the peer presented, if any.
#[cfg(feature = "__rustls")]
impl TlsInfoFactory for tokio_rustls::client::TlsStream<TokioIo<TokioIo<tokio::net::TcpStream>>> {
    fn tls_info(&self) -> Option<crate::tls::TlsInfo> {
        let (_, session) = self.get_ref();
        let peer_certificate = session
            .peer_certificates()
            .and_then(|chain| chain.first())
            .map(|leaf| leaf.to_vec());
        Some(crate::tls::TlsInfo { peer_certificate })
    }
}

/// TLS info for a rustls session tunneled through a proxy connection.
#[cfg(feature = "__rustls")]
impl TlsInfoFactory
    for tokio_rustls::client::TlsStream<
        TokioIo<hyper_rustls::MaybeHttpsStream<TokioIo<tokio::net::TcpStream>>>,
    >
{
    fn tls_info(&self) -> Option<crate::tls::TlsInfo> {
        let (_, session) = self.get_ref();
        let peer_certificate = session
            .peer_certificates()
            .and_then(|chain| chain.first())
            .map(|leaf| leaf.to_vec());
        Some(crate::tls::TlsInfo { peer_certificate })
    }
}

/// Only the HTTPS variant carries TLS information.
#[cfg(feature = "__rustls")]
impl TlsInfoFactory for hyper_rustls::MaybeHttpsStream<TokioIo<tokio::net::TcpStream>> {
    fn tls_info(&self) -> Option<crate::tls::TlsInfo> {
        if let hyper_rustls::MaybeHttpsStream::Https(tls) = self {
            tls.tls_info()
        } else {
            None
        }
    }
}
741
/// Bounds required of every connection stream the connector can return.
pub(crate) trait AsyncConn:
    Read + Write + Connection + Send + Sync + Unpin + 'static
{
}

// Blanket impl: any type meeting the bounds is an `AsyncConn`.
impl<T: Read + Write + Connection + Send + Sync + Unpin + 'static> AsyncConn for T {}

// `AsyncConn`, plus `TlsInfoFactory` when a TLS backend is compiled in, so that
// boxed connections can still report TLS details.
#[cfg(feature = "__tls")]
trait AsyncConnWithInfo: AsyncConn + TlsInfoFactory {}
#[cfg(not(feature = "__tls"))]
trait AsyncConnWithInfo: AsyncConn {}

#[cfg(feature = "__tls")]
impl<T: AsyncConn + TlsInfoFactory> AsyncConnWithInfo for T {}
#[cfg(not(feature = "__tls"))]
impl<T: AsyncConn> AsyncConnWithInfo for T {}

/// Type-erased connection stream stored inside `Conn`.
type BoxConn = Box<dyn AsyncConnWithInfo>;
760
761pub(crate) mod sealed {
762 use super::*;
763 #[derive(Debug)]
764 pub struct Unnameable(pub(super) Uri);
765
766 pin_project! {
767 #[allow(missing_debug_implementations)]
772 pub struct Conn {
773 #[pin]
774 pub(super)inner: BoxConn,
775 pub(super) is_proxy: bool,
776 pub(super) tls_info: bool,
778 }
779 }
780
781 impl Connection for Conn {
782 fn connected(&self) -> Connected {
783 let connected = self.inner.connected().proxy(self.is_proxy);
784 #[cfg(feature = "__tls")]
785 if self.tls_info {
786 if let Some(tls_info) = self.inner.tls_info() {
787 connected.extra(tls_info)
788 } else {
789 connected
790 }
791 } else {
792 connected
793 }
794 #[cfg(not(feature = "__tls"))]
795 connected
796 }
797 }
798
799 impl Read for Conn {
800 fn poll_read(
801 self: Pin<&mut Self>,
802 cx: &mut Context,
803 buf: ReadBufCursor<'_>,
804 ) -> Poll<io::Result<()>> {
805 let this = self.project();
806 Read::poll_read(this.inner, cx, buf)
807 }
808 }
809
810 impl Write for Conn {
811 fn poll_write(
812 self: Pin<&mut Self>,
813 cx: &mut Context,
814 buf: &[u8],
815 ) -> Poll<Result<usize, io::Error>> {
816 let this = self.project();
817 Write::poll_write(this.inner, cx, buf)
818 }
819
820 fn poll_write_vectored(
821 self: Pin<&mut Self>,
822 cx: &mut Context<'_>,
823 bufs: &[IoSlice<'_>],
824 ) -> Poll<Result<usize, io::Error>> {
825 let this = self.project();
826 Write::poll_write_vectored(this.inner, cx, bufs)
827 }
828
829 fn is_write_vectored(&self) -> bool {
830 self.inner.is_write_vectored()
831 }
832
833 fn poll_flush(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<(), io::Error>> {
834 let this = self.project();
835 Write::poll_flush(this.inner, cx)
836 }
837
838 fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<(), io::Error>> {
839 let this = self.project();
840 Write::poll_shutdown(this.inner, cx)
841 }
842 }
843}
844
/// Boxed, `Send` future returned by the connector services.
pub(crate) type Connecting = Pin<Box<dyn Future<Output = Result<Conn, BoxError>> + Send>>;
846
/// Sends an HTTP `CONNECT host:port` request over `conn` and waits for the
/// proxy to accept it, returning `conn` ready to carry the tunneled bytes.
///
/// `user_agent` and `auth` (sent as `Proxy-Authorization`) are added as
/// headers when present.
///
/// # Errors
/// - The proxy closes the connection early.
/// - The response headers do not fit in 8 KiB.
/// - The proxy answers 407 (`"proxy authentication required"`).
/// - The proxy answers any other non-200 status (`"unsuccessful tunnel"`).
#[cfg(feature = "__tls")]
async fn tunnel<T>(
    mut conn: T,
    host: String,
    port: u16,
    user_agent: Option<HeaderValue>,
    auth: Option<HeaderValue>,
) -> Result<T, BoxError>
where
    T: Read + Write + Unpin,
{
    use hyper_util::rt::TokioIo;
    use tokio::io::{AsyncReadExt, AsyncWriteExt};

    // Length of "HTTP/1.x NNN", the part of the status line classified below.
    const STATUS_PREFIX_LEN: usize = 12;
    const VERSION_PREFIX: &[u8] = b"HTTP/1.";

    let mut buf = format!(
        "\
         CONNECT {host}:{port} HTTP/1.1\r\n\
         Host: {host}:{port}\r\n\
         "
    )
    .into_bytes();

    if let Some(user_agent) = user_agent {
        buf.extend_from_slice(b"User-Agent: ");
        buf.extend_from_slice(user_agent.as_bytes());
        buf.extend_from_slice(b"\r\n");
    }

    if let Some(value) = auth {
        log::debug!("tunnel to {host}:{port} using basic auth");
        buf.extend_from_slice(b"Proxy-Authorization: ");
        buf.extend_from_slice(value.as_bytes());
        buf.extend_from_slice(b"\r\n");
    }

    // Blank line terminates the request headers.
    buf.extend_from_slice(b"\r\n");

    let mut tokio_conn = TokioIo::new(&mut conn);

    tokio_conn.write_all(&buf).await?;

    let mut buf = [0; 8192];
    let mut pos = 0;

    loop {
        let n = tokio_conn.read(&mut buf[pos..]).await?;

        if n == 0 {
            return Err(tunnel_eof());
        }
        pos += n;

        let recvd = &buf[..pos];

        // A single read may deliver only part of the status line. While what we
        // have is still a plausible "HTTP/1." prefix, read more instead of
        // misreporting an unsuccessful tunnel.
        if pos < STATUS_PREFIX_LEN {
            let checked = pos.min(VERSION_PREFIX.len());
            if recvd[..checked] == VERSION_PREFIX[..checked] {
                continue;
            }
        }

        if recvd.starts_with(b"HTTP/1.1 200") || recvd.starts_with(b"HTTP/1.0 200") {
            if recvd.ends_with(b"\r\n\r\n") {
                return Ok(conn);
            }
            if pos == buf.len() {
                return Err("proxy headers too long for tunnel".into());
            }
            // Headers not complete yet: keep reading.
        } else if recvd.starts_with(b"HTTP/1.1 407") || recvd.starts_with(b"HTTP/1.0 407") {
            return Err("proxy authentication required".into());
        } else {
            return Err("unsuccessful tunnel".into());
        }
    }
}
918
/// Error returned when the proxy closes the connection before finishing its
/// response to the CONNECT request.
#[cfg(feature = "__tls")]
fn tunnel_eof() -> BoxError {
    BoxError::from("unexpected eof while tunneling")
}
923
#[cfg(feature = "default-tls")]
mod native_tls_conn {
    use super::TlsInfoFactory;
    use hyper::rt::{Read, ReadBufCursor, Write};
    use hyper_tls::MaybeHttpsStream;
    use hyper_util::client::legacy::connect::{Connected, Connection};
    use hyper_util::rt::TokioIo;
    use pin_project_lite::pin_project;
    use std::{
        io::{self, IoSlice},
        pin::Pin,
        task::{Context, Poll},
    };
    use tokio::io::{AsyncRead, AsyncWrite};
    use tokio::net::TcpStream;
    use tokio_native_tls::TlsStream;

    // A native-tls stream adapted to hyper's I/O traits.
    pin_project! {
        pub(super) struct NativeTlsConn<T> {
            #[pin] pub(super) inner: TokioIo<TlsStream<T>>,
        }
    }

    // Direct (or SOCKS) connection: connection info comes from the TCP stream
    // under the TLS session, plus h2 if negotiated via ALPN.
    impl Connection for NativeTlsConn<TokioIo<TokioIo<TcpStream>>> {
        fn connected(&self) -> Connected {
            let tls = self.inner.inner();
            // tokio-native-tls -> native-tls -> AllowStd -> TokioIo<TokioIo<TcpStream>>
            let transport = tls.get_ref().get_ref().get_ref().inner();
            let connected = transport.connected();

            #[cfg(feature = "native-tls-alpn")]
            {
                if let Ok(Some(protocol)) = tls.get_ref().negotiated_alpn() {
                    if protocol == b"h2" {
                        return connected.negotiated_h2();
                    }
                }
            }

            connected
        }
    }

    // Tunneled connection: connection info comes from the proxy connection
    // under the TLS session, plus h2 if negotiated via ALPN.
    impl Connection for NativeTlsConn<TokioIo<MaybeHttpsStream<TokioIo<TcpStream>>>> {
        fn connected(&self) -> Connected {
            let tls = self.inner.inner();
            let transport = tls.get_ref().get_ref().get_ref().inner();
            let connected = transport.connected();

            #[cfg(feature = "native-tls-alpn")]
            {
                if let Ok(Some(protocol)) = tls.get_ref().negotiated_alpn() {
                    if protocol == b"h2" {
                        return connected.negotiated_h2();
                    }
                }
            }

            connected
        }
    }

    // I/O is forwarded unchanged to the wrapped TLS stream.
    impl<T: AsyncRead + AsyncWrite + Unpin> Read for NativeTlsConn<T> {
        fn poll_read(
            self: Pin<&mut Self>,
            cx: &mut Context,
            buf: ReadBufCursor<'_>,
        ) -> Poll<tokio::io::Result<()>> {
            Read::poll_read(self.project().inner, cx, buf)
        }
    }

    impl<T: AsyncRead + AsyncWrite + Unpin> Write for NativeTlsConn<T> {
        fn poll_write(
            self: Pin<&mut Self>,
            cx: &mut Context,
            buf: &[u8],
        ) -> Poll<Result<usize, tokio::io::Error>> {
            Write::poll_write(self.project().inner, cx, buf)
        }

        fn poll_write_vectored(
            self: Pin<&mut Self>,
            cx: &mut Context<'_>,
            bufs: &[IoSlice<'_>],
        ) -> Poll<Result<usize, io::Error>> {
            Write::poll_write_vectored(self.project().inner, cx, bufs)
        }

        fn is_write_vectored(&self) -> bool {
            self.inner.is_write_vectored()
        }

        fn poll_flush(
            self: Pin<&mut Self>,
            cx: &mut Context,
        ) -> Poll<Result<(), tokio::io::Error>> {
            Write::poll_flush(self.project().inner, cx)
        }

        fn poll_shutdown(
            self: Pin<&mut Self>,
            cx: &mut Context,
        ) -> Poll<Result<(), tokio::io::Error>> {
            Write::poll_shutdown(self.project().inner, cx)
        }
    }

    // TLS info is whatever the wrapped TLS stream reports.
    impl<T> TlsInfoFactory for NativeTlsConn<T>
    where
        TokioIo<TlsStream<T>>: TlsInfoFactory,
    {
        fn tls_info(&self) -> Option<crate::tls::TlsInfo> {
            self.inner.tls_info()
        }
    }
}
1047
#[cfg(feature = "__rustls")]
mod rustls_tls_conn {
    use super::TlsInfoFactory;
    use hyper::rt::{Read, ReadBufCursor, Write};
    use hyper_rustls::MaybeHttpsStream;
    use hyper_util::client::legacy::connect::{Connected, Connection};
    use hyper_util::rt::TokioIo;
    use pin_project_lite::pin_project;
    use std::{
        io::{self, IoSlice},
        pin::Pin,
        task::{Context, Poll},
    };
    use tokio::io::{AsyncRead, AsyncWrite};
    use tokio::net::TcpStream;
    use tokio_rustls::client::TlsStream;

    // A rustls client stream adapted to hyper's I/O traits.
    pin_project! {
        pub(super) struct RustlsTlsConn<T> {
            #[pin] pub(super) inner: TokioIo<TlsStream<T>>,
        }
    }

    // Direct (or SOCKS) connection: connection info comes from the TCP stream
    // under the TLS session, plus h2 if ALPN selected it.
    impl Connection for RustlsTlsConn<TokioIo<TokioIo<TcpStream>>> {
        fn connected(&self) -> Connected {
            let (io, session) = self.inner.inner().get_ref();
            let connected = io.inner().connected();
            if session.alpn_protocol() == Some(b"h2") {
                connected.negotiated_h2()
            } else {
                connected
            }
        }
    }

    // Tunneled connection: connection info comes from the proxy connection
    // under the TLS session, plus h2 if ALPN selected it.
    impl Connection for RustlsTlsConn<TokioIo<MaybeHttpsStream<TokioIo<TcpStream>>>> {
        fn connected(&self) -> Connected {
            let (io, session) = self.inner.inner().get_ref();
            let connected = io.inner().connected();
            if session.alpn_protocol() == Some(b"h2") {
                connected.negotiated_h2()
            } else {
                connected
            }
        }
    }

    // I/O is forwarded unchanged to the wrapped TLS stream.
    impl<T: AsyncRead + AsyncWrite + Unpin> Read for RustlsTlsConn<T> {
        fn poll_read(
            self: Pin<&mut Self>,
            cx: &mut Context,
            buf: ReadBufCursor<'_>,
        ) -> Poll<tokio::io::Result<()>> {
            Read::poll_read(self.project().inner, cx, buf)
        }
    }

    impl<T: AsyncRead + AsyncWrite + Unpin> Write for RustlsTlsConn<T> {
        fn poll_write(
            self: Pin<&mut Self>,
            cx: &mut Context,
            buf: &[u8],
        ) -> Poll<Result<usize, tokio::io::Error>> {
            Write::poll_write(self.project().inner, cx, buf)
        }

        fn poll_write_vectored(
            self: Pin<&mut Self>,
            cx: &mut Context<'_>,
            bufs: &[IoSlice<'_>],
        ) -> Poll<Result<usize, io::Error>> {
            Write::poll_write_vectored(self.project().inner, cx, bufs)
        }

        fn is_write_vectored(&self) -> bool {
            self.inner.is_write_vectored()
        }

        fn poll_flush(
            self: Pin<&mut Self>,
            cx: &mut Context,
        ) -> Poll<Result<(), tokio::io::Error>> {
            Write::poll_flush(self.project().inner, cx)
        }

        fn poll_shutdown(
            self: Pin<&mut Self>,
            cx: &mut Context,
        ) -> Poll<Result<(), tokio::io::Error>> {
            Write::poll_shutdown(self.project().inner, cx)
        }
    }

    // TLS info is whatever the wrapped TLS stream reports.
    impl<T> TlsInfoFactory for RustlsTlsConn<T>
    where
        TokioIo<TlsStream<T>>: TlsInfoFactory,
    {
        fn tls_info(&self) -> Option<crate::tls::TlsInfo> {
            self.inner.tls_info()
        }
    }
}
1161
#[cfg(feature = "socks")]
mod socks {
    use std::io;

    use http::Uri;
    use tokio::net::TcpStream;
    use tokio_socks::tcp::{Socks4Stream, Socks5Stream};

    use super::{BoxError, Scheme};
    use crate::proxy::ProxyScheme;

    /// Where the destination host name is resolved.
    pub(super) enum DnsResolve {
        /// Resolve locally and send the proxy an IP address.
        Local,
        /// Send the host name to the proxy and let it resolve.
        Proxy,
    }

    /// Opens a TCP stream to `dst` through the SOCKS4/SOCKS5 proxy `proxy`.
    ///
    /// The port defaults to 443 for `https` and 80 otherwise. With
    /// `DnsResolve::Local`, the host is resolved here first. If the lookup
    /// yields no address, the host name is sent to the proxy as-is.
    ///
    /// # Errors
    /// Fails if `dst` has no host, local resolution fails, or the SOCKS
    /// handshake fails.
    pub(super) async fn connect(
        proxy: ProxyScheme,
        dst: Uri,
        dns: DnsResolve,
    ) -> Result<TcpStream, BoxError> {
        let https = dst.scheme() == Some(&Scheme::HTTPS);
        let original_host = dst
            .host()
            .ok_or_else(|| io::Error::new(io::ErrorKind::Other, "no host in url"))?;
        // `Uri::host` keeps the brackets around IPv6 literals (`[::1]`). Neither
        // the resolver nor the SOCKS target address accepts them.
        let original_host = original_host
            .strip_prefix('[')
            .and_then(|h| h.strip_suffix(']'))
            .unwrap_or(original_host);
        let mut host = original_host.to_owned();
        let port = match dst.port() {
            Some(p) => p.as_u16(),
            None if https => 443u16,
            _ => 80u16,
        };

        if let DnsResolve::Local = dns {
            let maybe_new_target = tokio::net::lookup_host((host.as_str(), port)).await?.next();
            if let Some(new_target) = maybe_new_target {
                host = new_target.ip().to_string();
            }
        }

        match proxy {
            ProxyScheme::Socks4 { addr, .. } => {
                let stream = Socks4Stream::connect(addr, (host.as_str(), port))
                    .await
                    .map_err(|e| format!("socks connect error: {e}"))?;
                Ok(stream.into_inner())
            }
            ProxyScheme::Socks5 { addr, ref auth, .. } => {
                let stream = if let Some((username, password)) = auth {
                    Socks5Stream::connect_with_password(
                        addr,
                        (host.as_str(), port),
                        &username,
                        &password,
                    )
                    .await
                    .map_err(|e| format!("socks connect error: {e}"))?
                } else {
                    Socks5Stream::connect(addr, (host.as_str(), port))
                        .await
                        .map_err(|e| format!("socks connect error: {e}"))?
                };

                Ok(stream.into_inner())
            }
            _ => unreachable!("socks::connect is only called for socks proxies"),
        }
    }
}
1230
1231mod verbose {
1232 use hyper::rt::{Read, ReadBufCursor, Write};
1233 use hyper_util::client::legacy::connect::{Connected, Connection};
1234 use std::cmp::min;
1235 use std::fmt;
1236 use std::io::{self, IoSlice};
1237 use std::pin::Pin;
1238 use std::task::{Context, Poll};
1239
1240 pub(super) const OFF: Wrapper = Wrapper(false);
1241
1242 #[derive(Clone, Copy)]
1243 pub(super) struct Wrapper(pub(super) bool);
1244
1245 impl Wrapper {
1246 pub(super) fn wrap<T: super::AsyncConnWithInfo>(&self, conn: T) -> super::BoxConn {
1247 if self.0 && log::log_enabled!(log::Level::Trace) {
1248 Box::new(Verbose {
1249 id: crate::util::fast_random() as u32,
1251 inner: conn,
1252 })
1253 } else {
1254 Box::new(conn)
1255 }
1256 }
1257 }
1258
1259 struct Verbose<T> {
1260 id: u32,
1261 inner: T,
1262 }
1263
1264 impl<T: Connection + Read + Write + Unpin> Connection for Verbose<T> {
1265 fn connected(&self) -> Connected {
1266 self.inner.connected()
1267 }
1268 }
1269
1270 impl<T: Read + Write + Unpin> Read for Verbose<T> {
1271 fn poll_read(
1272 mut self: Pin<&mut Self>,
1273 cx: &mut Context,
1274 mut buf: ReadBufCursor<'_>,
1275 ) -> Poll<std::io::Result<()>> {
1276 let mut vbuf = hyper::rt::ReadBuf::uninit(unsafe { buf.as_mut() });
1280 match Pin::new(&mut self.inner).poll_read(cx, vbuf.unfilled()) {
1281 Poll::Ready(Ok(())) => {
1282 log::trace!("{:08x} read: {:?}", self.id, Escape(vbuf.filled()));
1283 let len = vbuf.filled().len();
1284 unsafe {
1287 buf.advance(len);
1288 }
1289 Poll::Ready(Ok(()))
1290 }
1291 Poll::Ready(Err(e)) => Poll::Ready(Err(e)),
1292 Poll::Pending => Poll::Pending,
1293 }
1294 }
1295 }
1296
1297 impl<T: Read + Write + Unpin> Write for Verbose<T> {
1298 fn poll_write(
1299 mut self: Pin<&mut Self>,
1300 cx: &mut Context,
1301 buf: &[u8],
1302 ) -> Poll<Result<usize, std::io::Error>> {
1303 match Pin::new(&mut self.inner).poll_write(cx, buf) {
1304 Poll::Ready(Ok(n)) => {
1305 log::trace!("{:08x} write: {:?}", self.id, Escape(&buf[..n]));
1306 Poll::Ready(Ok(n))
1307 }
1308 Poll::Ready(Err(e)) => Poll::Ready(Err(e)),
1309 Poll::Pending => Poll::Pending,
1310 }
1311 }
1312
1313 fn poll_write_vectored(
1314 mut self: Pin<&mut Self>,
1315 cx: &mut Context<'_>,
1316 bufs: &[IoSlice<'_>],
1317 ) -> Poll<Result<usize, io::Error>> {
1318 match Pin::new(&mut self.inner).poll_write_vectored(cx, bufs) {
1319 Poll::Ready(Ok(nwritten)) => {
1320 log::trace!(
1321 "{:08x} write (vectored): {:?}",
1322 self.id,
1323 Vectored { bufs, nwritten }
1324 );
1325 Poll::Ready(Ok(nwritten))
1326 }
1327 Poll::Ready(Err(e)) => Poll::Ready(Err(e)),
1328 Poll::Pending => Poll::Pending,
1329 }
1330 }
1331
1332 fn is_write_vectored(&self) -> bool {
1333 self.inner.is_write_vectored()
1334 }
1335
1336 fn poll_flush(
1337 mut self: Pin<&mut Self>,
1338 cx: &mut Context,
1339 ) -> Poll<Result<(), std::io::Error>> {
1340 Pin::new(&mut self.inner).poll_flush(cx)
1341 }
1342
1343 fn poll_shutdown(
1344 mut self: Pin<&mut Self>,
1345 cx: &mut Context,
1346 ) -> Poll<Result<(), std::io::Error>> {
1347 Pin::new(&mut self.inner).poll_shutdown(cx)
1348 }
1349 }
1350
1351 #[cfg(feature = "__tls")]
1352 impl<T: super::TlsInfoFactory> super::TlsInfoFactory for Verbose<T> {
1353 fn tls_info(&self) -> Option<crate::tls::TlsInfo> {
1354 self.inner.tls_info()
1355 }
1356 }
1357
1358 struct Escape<'a>(&'a [u8]);
1359
1360 impl fmt::Debug for Escape<'_> {
1361 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
1362 write!(f, "b\"")?;
1363 for &c in self.0 {
1364 if c == b'\n' {
1366 write!(f, "\\n")?;
1367 } else if c == b'\r' {
1368 write!(f, "\\r")?;
1369 } else if c == b'\t' {
1370 write!(f, "\\t")?;
1371 } else if c == b'\\' || c == b'"' {
1372 write!(f, "\\{}", c as char)?;
1373 } else if c == b'\0' {
1374 write!(f, "\\0")?;
1375 } else if c >= 0x20 && c < 0x7f {
1377 write!(f, "{}", c as char)?;
1378 } else {
1379 write!(f, "\\x{c:02x}")?;
1380 }
1381 }
1382 write!(f, "\"")?;
1383 Ok(())
1384 }
1385 }
1386
1387 struct Vectored<'a, 'b> {
1388 bufs: &'a [IoSlice<'b>],
1389 nwritten: usize,
1390 }
1391
1392 impl fmt::Debug for Vectored<'_, '_> {
1393 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
1394 let mut left = self.nwritten;
1395 for buf in self.bufs.iter() {
1396 if left == 0 {
1397 break;
1398 }
1399 let n = min(left, buf.len());
1400 Escape(&buf[..n]).fmt(f)?;
1401 left -= n;
1402 }
1403 Ok(())
1404 }
1405 }
1406}
1407
#[cfg(feature = "__tls")]
#[cfg(test)]
mod tests {
    use super::tunnel;
    use crate::proxy;
    use hyper_util::rt::TokioIo;
    use std::io::{Read, Write};
    use std::net::TcpListener;
    use std::thread;
    use tokio::net::TcpStream;
    use tokio::runtime;

    static TUNNEL_UA: &str = "tunnel-test/x.y";
    static TUNNEL_OK: &[u8] = b"\
        HTTP/1.1 200 OK\r\n\
        \r\n\
    ";

    // Starts a one-shot mock proxy on an ephemeral loopback port and evaluates
    // to its address. The proxy accepts one connection and asserts that the
    // client sent exactly the expected CONNECT request. `$auth` is spliced in
    // as extra header lines and is empty by default. The proxy then replies
    // with `$write` (default `TUNNEL_OK`) and closes the socket.
    macro_rules! mock_tunnel {
        () => {{
            mock_tunnel!(TUNNEL_OK)
        }};
        ($write:expr) => {{
            mock_tunnel!($write, "")
        }};
        ($write:expr, $auth:expr) => {{
            let listener = TcpListener::bind("127.0.0.1:0").unwrap();
            let addr = listener.local_addr().unwrap();
            let connect_expected = format!(
                "\
                 CONNECT {0}:{1} HTTP/1.1\r\n\
                 Host: {0}:{1}\r\n\
                 User-Agent: {2}\r\n\
                 {3}\
                 \r\n\
                 ",
                addr.ip(),
                addr.port(),
                TUNNEL_UA,
                $auth
            )
            .into_bytes();

            thread::spawn(move || {
                let (mut sock, _) = listener.accept().unwrap();
                let mut buf = [0u8; 4096];
                let mut n = 0;
                // A single `read` may return only part of the request, so keep
                // reading until the blank line that terminates the request head.
                while !buf[..n].ends_with(b"\r\n\r\n") {
                    let read = sock.read(&mut buf[n..]).unwrap();
                    assert_ne!(read, 0, "connection closed before the CONNECT request ended");
                    n += read;
                }
                assert_eq!(&buf[..n], &connect_expected[..]);

                sock.write_all($write).unwrap();
            });
            addr
        }};
    }

    fn ua() -> Option<http::header::HeaderValue> {
        Some(http::header::HeaderValue::from_static(TUNNEL_UA))
    }

    /// Connects to the mock proxy at `addr` and runs `tunnel` (with the test
    /// user agent and the given `auth`) on a fresh current-thread runtime.
    /// On success the tunnelled stream is discarded.
    fn run_tunnel(
        addr: std::net::SocketAddr,
        auth: Option<http::header::HeaderValue>,
    ) -> Result<(), crate::error::BoxError> {
        let rt = runtime::Builder::new_current_thread()
            .enable_all()
            .build()
            .expect("new rt");
        rt.block_on(async move {
            let tcp = TokioIo::new(TcpStream::connect(&addr).await?);
            let host = addr.ip().to_string();
            let port = addr.port();
            tunnel(tcp, host, port, ua(), auth).await.map(|_conn| ())
        })
    }

    #[test]
    fn test_tunnel() {
        let addr = mock_tunnel!();
        run_tunnel(addr, None).unwrap();
    }

    #[test]
    fn test_tunnel_eof() {
        let addr = mock_tunnel!(b"HTTP/1.1 200 OK");
        run_tunnel(addr, None).unwrap_err();
    }

    #[test]
    fn test_tunnel_non_http_response() {
        let addr = mock_tunnel!(b"foo bar baz hallo");
        run_tunnel(addr, None).unwrap_err();
    }

    #[test]
    fn test_tunnel_proxy_unauthorized() {
        let addr = mock_tunnel!(
            b"\
            HTTP/1.1 407 Proxy Authentication Required\r\n\
            Proxy-Authenticate: Basic realm=\"nope\"\r\n\
            \r\n\
            "
        );

        let error = run_tunnel(addr, None).unwrap_err();
        assert_eq!(error.to_string(), "proxy authentication required");
    }

    #[test]
    fn test_tunnel_basic_auth() {
        let addr = mock_tunnel!(
            TUNNEL_OK,
            "Proxy-Authorization: Basic QWxhZGRpbjpvcGVuIHNlc2FtZQ==\r\n"
        );

        run_tunnel(
            addr,
            Some(proxy::encode_basic_auth("Aladdin", "open sesame")),
        )
        .unwrap();
    }
}