hyper_timeout/
lib.rs
1use std::future::Future;
2use std::io;
3use std::pin::Pin;
4use std::task::{Context, Poll};
5use std::time::Duration;
6
7use hyper::rt::{Read, Write};
8use tokio::time::timeout;
9
10use hyper::Uri;
11use hyper_util::client::legacy::connect::{Connected, Connection};
12use tower_service::Service;
13
14mod stream;
15use stream::TimeoutStream;
16
/// Boxed, thread-safe error type used as the connector's `Service::Error`.
type BoxError = Box<dyn std::error::Error + Send + Sync + 'static>;
18
/// Wraps another connector and applies optional connect, read and write
/// timeouts to the connections it produces.
///
/// All three timeouts are `None` (not configured) when the wrapper is built
/// with `new`.
#[derive(Clone, Debug)]
pub struct TimeoutConnector<T> {
    /// The inner connector that actually establishes connections.
    connector: T,
    /// Upper bound on how long establishing a connection may take.
    connect_timeout: Option<Duration>,
    /// Read timeout handed to each new connection's `TimeoutStream`.
    read_timeout: Option<Duration>,
    /// Write timeout handed to each new connection's `TimeoutStream`.
    write_timeout: Option<Duration>,
}
31
32impl<T> TimeoutConnector<T>
33where
34 T: Service<Uri> + Send,
35 T::Response: Read + Write + Send + Unpin,
36 T::Future: Send + 'static,
37 T::Error: Into<BoxError>,
38{
39 pub fn new(connector: T) -> Self {
41 TimeoutConnector {
42 connector,
43 connect_timeout: None,
44 read_timeout: None,
45 write_timeout: None,
46 }
47 }
48}
49
50impl<T> Service<Uri> for TimeoutConnector<T>
51where
52 T: Service<Uri> + Send,
53 T::Response: Read + Write + Connection + Send + Unpin,
54 T::Future: Send + 'static,
55 T::Error: Into<BoxError>,
56{
57 type Response = Pin<Box<TimeoutStream<T::Response>>>;
58 type Error = BoxError;
59 #[allow(clippy::type_complexity)]
60 type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;
61
62 fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
63 self.connector.poll_ready(cx).map_err(Into::into)
64 }
65
66 fn call(&mut self, dst: Uri) -> Self::Future {
67 let connect_timeout = self.connect_timeout;
68 let read_timeout = self.read_timeout;
69 let write_timeout = self.write_timeout;
70 let connecting = self.connector.call(dst);
71
72 let fut = async move {
73 let mut stream = match connect_timeout {
74 None => {
75 let io = connecting.await.map_err(Into::into)?;
76 TimeoutStream::new(io)
77 }
78 Some(connect_timeout) => {
79 let timeout = timeout(connect_timeout, connecting);
80 let connecting = timeout
81 .await
82 .map_err(|e| io::Error::new(io::ErrorKind::TimedOut, e))?;
83 let io = connecting.map_err(Into::into)?;
84 TimeoutStream::new(io)
85 }
86 };
87 stream.set_read_timeout(read_timeout);
88 stream.set_write_timeout(write_timeout);
89 Ok(Box::pin(stream))
90 };
91
92 Box::pin(fut)
93 }
94}
95
96impl<T> TimeoutConnector<T> {
97 #[inline]
101 pub fn set_connect_timeout(&mut self, val: Option<Duration>) {
102 self.connect_timeout = val;
103 }
104
105 #[inline]
109 pub fn set_read_timeout(&mut self, val: Option<Duration>) {
110 self.read_timeout = val;
111 }
112
113 #[inline]
117 pub fn set_write_timeout(&mut self, val: Option<Duration>) {
118 self.write_timeout = val;
119 }
120}
121
122impl<T> Connection for TimeoutConnector<T>
123where
124 T: Read + Write + Connection + Service<Uri> + Send + Unpin,
125 T::Response: Read + Write + Send + Unpin,
126 T::Future: Send + 'static,
127 T::Error: Into<BoxError>,
128{
129 fn connected(&self) -> Connected {
130 self.connector.connected()
131 }
132}
133
#[cfg(test)]
mod tests {
    use std::time::Duration;
    use std::{error::Error, io};

    use http_body_util::Empty;
    use hyper::body::Bytes;
    use hyper_util::{
        client::legacy::{connect::HttpConnector, Client},
        rt::TokioExecutor,
    };

    use super::TimeoutConnector;

    /// Returns `true` if `err`, or any error in its `source()` chain, is an
    /// `io::Error` of kind `TimedOut`.
    ///
    /// Walking the whole chain keeps the assertions independent of how many
    /// layers of wrapping the client adds around the underlying I/O error. It
    /// also avoids panicking on an `unwrap()` when a layer has no source.
    fn is_timed_out(err: &(dyn Error + 'static)) -> bool {
        std::iter::successors(Some(err), |e| e.source())
            .filter_map(|e| e.downcast_ref::<io::Error>())
            .any(|io_e| io_e.kind() == io::ErrorKind::TimedOut)
    }

    #[tokio::test]
    async fn test_timeout_connector() {
        // Assumes 10.255.255.1 does not answer, so the attempt outlives the
        // 1 ms connect timeout.
        let url = "http://10.255.255.1".parse().unwrap();

        let http = HttpConnector::new();
        let mut connector = TimeoutConnector::new(http);
        connector.set_connect_timeout(Some(Duration::from_millis(1)));

        let client = Client::builder(TokioExecutor::new()).build::<_, Empty<Bytes>>(connector);

        match client.get(url).await {
            Ok(_) => panic!("Expected a timeout"),
            Err(e) => assert!(is_timed_out(&e), "Expected timeout error, got: {:?}", e),
        }
    }

    #[tokio::test]
    async fn test_read_timeout() {
        // NOTE: requires outbound network access to example.com; without it
        // the request fails with a non-timeout error and this test fails.
        let url = "http://example.com".parse().unwrap();

        let http = HttpConnector::new();
        let mut connector = TimeoutConnector::new(http);
        connector.set_read_timeout(Some(Duration::from_millis(1)));

        let client = Client::builder(TokioExecutor::new()).build::<_, Empty<Bytes>>(connector);

        match client.get(url).await {
            Ok(_) => panic!("Expected timeout error"),
            Err(e) => assert!(is_timed_out(&e), "Expected timeout error, got: {:?}", e),
        }
    }
}