// tonic/transport/service/grpc_timeout.rs
1use crate::{metadata::GRPC_TIMEOUT_HEADER, TimeoutExpired};
2use http::{HeaderMap, HeaderValue, Request};
3use pin_project::pin_project;
4use std::{
5 future::Future,
6 pin::Pin,
7 task::{ready, Context, Poll},
8 time::Duration,
9};
10use tokio::time::Sleep;
11use tower_service::Service;
12
/// Service wrapper that bounds each request by a deadline.
///
/// The deadline comes from the client's `grpc-timeout` request header, from a
/// server-configured timeout, or from both (in which case the shorter applies).
#[derive(Debug, Clone)]
pub(crate) struct GrpcTimeout<S> {
    /// The wrapped service that actually handles requests.
    inner: S,
    /// Optional server-side upper bound on how long a request may run.
    server_timeout: Option<Duration>,
}

impl<S> GrpcTimeout<S> {
    /// Wraps `inner`, optionally enforcing `server_timeout` on every request.
    pub(crate) fn new(inner: S, server_timeout: Option<Duration>) -> Self {
        GrpcTimeout {
            server_timeout,
            inner,
        }
    }
}
27
28impl<S, ReqBody> Service<Request<ReqBody>> for GrpcTimeout<S>
29where
30 S: Service<Request<ReqBody>>,
31 S::Error: Into<crate::Error>,
32{
33 type Response = S::Response;
34 type Error = crate::Error;
35 type Future = ResponseFuture<S::Future>;
36
37 fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
38 self.inner.poll_ready(cx).map_err(Into::into)
39 }
40
41 fn call(&mut self, req: Request<ReqBody>) -> Self::Future {
42 let client_timeout = try_parse_grpc_timeout(req.headers()).unwrap_or_else(|e| {
43 tracing::trace!("Error parsing `grpc-timeout` header {:?}", e);
44 None
45 });
46
47 let timeout_duration = match (client_timeout, self.server_timeout) {
49 (None, None) => None,
50 (Some(dur), None) => Some(dur),
51 (None, Some(dur)) => Some(dur),
52 (Some(header), Some(server)) => {
53 let shorter_duration = std::cmp::min(header, server);
54 Some(shorter_duration)
55 }
56 };
57
58 ResponseFuture {
59 inner: self.inner.call(req),
60 sleep: timeout_duration.map(tokio::time::sleep),
61 }
62 }
63}
64
65#[pin_project]
66pub(crate) struct ResponseFuture<F> {
67 #[pin]
68 inner: F,
69 #[pin]
70 sleep: Option<Sleep>,
71}
72
73impl<F, Res, E> Future for ResponseFuture<F>
74where
75 F: Future<Output = Result<Res, E>>,
76 E: Into<crate::Error>,
77{
78 type Output = Result<Res, crate::Error>;
79
80 fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
81 let this = self.project();
82
83 if let Poll::Ready(result) = this.inner.poll(cx) {
84 return Poll::Ready(result.map_err(Into::into));
85 }
86
87 if let Some(sleep) = this.sleep.as_pin_mut() {
88 ready!(sleep.poll(cx));
89 return Poll::Ready(Err(TimeoutExpired(()).into()));
90 }
91
92 Poll::Pending
93 }
94}
95
96const SECONDS_IN_HOUR: u64 = 60 * 60;
97const SECONDS_IN_MINUTE: u64 = 60;
98
99fn try_parse_grpc_timeout(
104 headers: &HeaderMap<HeaderValue>,
105) -> Result<Option<Duration>, &HeaderValue> {
106 match headers.get(GRPC_TIMEOUT_HEADER) {
107 Some(val) => {
108 let (timeout_value, timeout_unit) = val
109 .to_str()
110 .map_err(|_| val)
111 .and_then(|s| if s.is_empty() { Err(val) } else { Ok(s) })?
112 .split_at(val.len() - 1);
118
119 if timeout_value.len() > 8 {
122 return Err(val);
123 }
124
125 let timeout_value: u64 = timeout_value.parse().map_err(|_| val)?;
126
127 let duration = match timeout_unit {
128 "H" => Duration::from_secs(timeout_value * SECONDS_IN_HOUR),
130 "M" => Duration::from_secs(timeout_value * SECONDS_IN_MINUTE),
132 "S" => Duration::from_secs(timeout_value),
134 "m" => Duration::from_millis(timeout_value),
136 "u" => Duration::from_micros(timeout_value),
138 "n" => Duration::from_nanos(timeout_value),
140 _ => return Err(val),
141 };
142
143 Ok(Some(duration))
144 }
145 None => Ok(None),
146 }
147}
148
#[cfg(test)]
mod tests {
    use super::*;
    use quickcheck::{Arbitrary, Gen};
    use quickcheck_macros::quickcheck;

    /// Runs the parser over a header map that carries `val` as its
    /// `grpc-timeout` header, or no such header when `val` is `None`.
    ///
    /// A rejected header value is cloned so the returned error owns it.
    fn setup_map_try_parse(val: Option<&str>) -> Result<Option<Duration>, HeaderValue> {
        let mut hm = HeaderMap::new();
        if let Some(raw) = val {
            hm.insert(GRPC_TIMEOUT_HEADER, HeaderValue::from_str(raw).unwrap());
        }
        try_parse_grpc_timeout(&hm).map_err(HeaderValue::clone)
    }

    /// Parses `raw` as a present header and unwraps both result layers.
    ///
    /// If the parser rejects the value, the panic message contains it.
    fn parse_present(raw: &str) -> Duration {
        setup_map_try_parse(Some(raw)).unwrap().unwrap()
    }

    #[test]
    fn test_hours() {
        assert_eq!(parse_present("3H"), Duration::from_secs(3 * 60 * 60));
    }

    #[test]
    fn test_minutes() {
        assert_eq!(parse_present("1M"), Duration::from_secs(60));
    }

    #[test]
    fn test_seconds() {
        assert_eq!(parse_present("42S"), Duration::from_secs(42));
    }

    #[test]
    fn test_milliseconds() {
        assert_eq!(parse_present("13m"), Duration::from_millis(13));
    }

    #[test]
    fn test_microseconds() {
        assert_eq!(parse_present("2u"), Duration::from_micros(2));
    }

    #[test]
    fn test_nanoseconds() {
        assert_eq!(parse_present("82n"), Duration::from_nanos(82));
    }

    #[test]
    fn test_header_not_present() {
        assert_eq!(setup_map_try_parse(None), Ok(None));
    }

    #[test]
    #[should_panic(expected = "82f")]
    fn test_invalid_unit() {
        parse_present("82f");
    }

    #[test]
    #[should_panic(expected = "123456789H")]
    fn test_too_many_digits() {
        parse_present("123456789H");
    }

    #[test]
    #[should_panic(expected = "oneH")]
    fn test_invalid_digits() {
        parse_present("oneH");
    }

    /// Property: parsing arbitrary header text must never panic; the result
    /// itself is irrelevant.
    #[quickcheck]
    fn fuzz(header_value: HeaderValueGen) -> bool {
        let HeaderValueGen(raw) = header_value;
        let _ = setup_map_try_parse(Some(&raw));
        true
    }

    /// Random header text drawn from a small alphabet of letters and dashes.
    #[derive(Clone, Debug)]
    struct HeaderValueGen(String);

    impl Arbitrary for HeaderValueGen {
        fn arbitrary(g: &mut Gen) -> Self {
            let lengths: Vec<usize> = (1..70).collect();
            let max = *g.choose(&lengths).unwrap();
            HeaderValueGen(gen_string(g, 0, max))
        }
    }

    /// Builds a string of `max - min` characters picked at random from a fixed
    /// alphabet, which is always valid UTF-8 and a valid header value.
    fn gen_string(g: &mut Gen, min: usize, max: usize) -> String {
        const ALPHABET: &[u8] = b"ABCDEFGHIJKLMNOPQRSTUVabcdefghilpqrstuvwxyz----";
        let mut bytes = Vec::with_capacity(max.saturating_sub(min));
        for _ in min..max {
            bytes.push(*g.choose(ALPHABET).unwrap());
        }
        String::from_utf8(bytes).unwrap()
    }
}