1use std::future::Future;
2use std::io;
3use std::pin::Pin;
4use std::task::{Context, Poll};
5use std::time::Duration;
6
7use hyper::rt::{Read, Write};
8use tokio::time::timeout;
9
10use hyper::Uri;
11use hyper_util::client::legacy::connect::{Connected, Connection};
12use tower_service::Service;
13
14mod stream;
15use stream::TimeoutStream;
16
17type BoxError = Box<dyn std::error::Error + Send + Sync>;
18
/// Wraps an inner connector together with optional connect, read and write
/// timeouts.
///
/// Each timeout is stored as an `Option<Duration>`; `None` means the value
/// has not been configured.
#[derive(Clone, Debug)]
pub struct TimeoutConnector<T> {
    /// The wrapped connector that actually establishes connections.
    connector: T,
    /// Upper bound on how long establishing a connection may take.
    connect_timeout: Option<Duration>,
    /// Read timeout handed to every stream this connector produces.
    read_timeout: Option<Duration>,
    /// Write timeout handed to every stream this connector produces.
    write_timeout: Option<Duration>,
}
31
32impl<T> TimeoutConnector<T>
33where
34    T: Service<Uri> + Send,
35    T::Response: Read + Write + Send + Unpin,
36    T::Future: Send + 'static,
37    T::Error: Into<BoxError>,
38{
39    pub fn new(connector: T) -> Self {
41        TimeoutConnector {
42            connector,
43            connect_timeout: None,
44            read_timeout: None,
45            write_timeout: None,
46        }
47    }
48}
49
50impl<T> Service<Uri> for TimeoutConnector<T>
51where
52    T: Service<Uri> + Send,
53    T::Response: Read + Write + Connection + Send + Unpin,
54    T::Future: Send + 'static,
55    T::Error: Into<BoxError>,
56{
57    type Response = Pin<Box<TimeoutStream<T::Response>>>;
58    type Error = BoxError;
59    #[allow(clippy::type_complexity)]
60    type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;
61
62    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
63        self.connector.poll_ready(cx).map_err(Into::into)
64    }
65
66    fn call(&mut self, dst: Uri) -> Self::Future {
67        let connect_timeout = self.connect_timeout;
68        let read_timeout = self.read_timeout;
69        let write_timeout = self.write_timeout;
70        let connecting = self.connector.call(dst);
71
72        let fut = async move {
73            let mut stream = match connect_timeout {
74                None => {
75                    let io = connecting.await.map_err(Into::into)?;
76                    TimeoutStream::new(io)
77                }
78                Some(connect_timeout) => {
79                    let timeout = timeout(connect_timeout, connecting);
80                    let connecting = timeout
81                        .await
82                        .map_err(|e| io::Error::new(io::ErrorKind::TimedOut, e))?;
83                    let io = connecting.map_err(Into::into)?;
84                    TimeoutStream::new(io)
85                }
86            };
87            stream.set_read_timeout(read_timeout);
88            stream.set_write_timeout(write_timeout);
89            Ok(Box::pin(stream))
90        };
91
92        Box::pin(fut)
93    }
94}
95
96impl<T> TimeoutConnector<T> {
97    #[inline]
101    pub fn set_connect_timeout(&mut self, val: Option<Duration>) {
102        self.connect_timeout = val;
103    }
104
105    #[inline]
109    pub fn set_read_timeout(&mut self, val: Option<Duration>) {
110        self.read_timeout = val;
111    }
112
113    #[inline]
117    pub fn set_write_timeout(&mut self, val: Option<Duration>) {
118        self.write_timeout = val;
119    }
120}
121
122impl<T> Connection for TimeoutConnector<T>
123where
124    T: Read + Write + Connection + Service<Uri> + Send + Unpin,
125    T::Response: Read + Write + Send + Unpin,
126    T::Future: Send + 'static,
127    T::Error: Into<BoxError>,
128{
129    fn connected(&self) -> Connected {
130        self.connector.connected()
131    }
132}
133
#[cfg(test)]
mod tests {
    use std::time::Duration;
    use std::{error::Error, io};

    use http_body_util::Empty;
    use hyper::body::Bytes;
    use hyper_util::{
        client::legacy::{connect::HttpConnector, Client},
        rt::TokioExecutor,
    };

    use super::TimeoutConnector;

    /// A 1 ms connect timeout must surface as an `io::Error` of kind
    /// `TimedOut` in the source of the client error.
    #[tokio::test]
    async fn test_timeout_connector() {
        // NOTE(review): presumably an address that never answers, so the
        // connect attempt cannot finish within the timeout — confirm.
        let url = "http://10.255.255.1".parse().unwrap();

        let mut connector = TimeoutConnector::new(HttpConnector::new());
        connector.set_connect_timeout(Some(Duration::from_millis(1)));

        let client = Client::builder(TokioExecutor::new()).build::<_, Empty<Bytes>>(connector);

        let client_err = match client.get(url).await {
            Ok(_) => panic!("Expected a timeout"),
            Err(err) => err,
        };

        match client_err.source().unwrap().downcast_ref::<io::Error>() {
            Some(io_err) => assert_eq!(io_err.kind(), io::ErrorKind::TimedOut),
            None => panic!("Expected timeout error"),
        }
    }

    /// A 1 ms read timeout must surface as an `io::Error` of kind `TimedOut`
    /// two levels down the error source chain (client error -> hyper error ->
    /// io error).
    ///
    /// NOTE(review): this test talks to `example.com`, so it needs network
    /// access to behave as intended.
    #[tokio::test]
    async fn test_read_timeout() {
        let url = "http://example.com".parse().unwrap();

        let mut connector = TimeoutConnector::new(HttpConnector::new());
        connector.set_read_timeout(Some(Duration::from_millis(1)));

        let client = Client::builder(TokioExecutor::new()).build::<_, Empty<Bytes>>(connector);

        let failure = client.get(url).await.err();

        // Walk the source chain; any missing link other than the unwrapped
        // one falls through to the final panic.
        let observed_kind = failure
            .as_ref()
            .and_then(|client_e| client_e.source())
            .and_then(|hyper_e| hyper_e.source().unwrap().downcast_ref::<io::Error>())
            .map(|io_e| io_e.kind());

        match observed_kind {
            Some(kind) => assert_eq!(kind, io::ErrorKind::TimedOut),
            None => panic!("Expected timeout error"),
        }
    }
}