tonic/transport/service/
grpc_timeout.rs

1use crate::{metadata::GRPC_TIMEOUT_HEADER, TimeoutExpired};
2use http::{HeaderMap, HeaderValue, Request};
3use pin_project::pin_project;
4use std::{
5    future::Future,
6    pin::Pin,
7    task::{ready, Context, Poll},
8    time::Duration,
9};
10use tokio::time::Sleep;
11use tower_service::Service;
12
/// Service wrapper that bounds the time the wrapped service may take to produce a response.
///
/// The limit applied to each call is derived from the request's `grpc-timeout` header and the
/// optional `server_timeout` held here.
#[derive(Debug, Clone)]
pub(crate) struct GrpcTimeout<S> {
    /// The wrapped service that actually handles requests.
    inner: S,
    /// Timeout configured on the server side, if any.
    server_timeout: Option<Duration>,
}

impl<S> GrpcTimeout<S> {
    /// Wraps `inner`, optionally applying `server_timeout` to every call.
    pub(crate) fn new(inner: S, server_timeout: Option<Duration>) -> Self {
        GrpcTimeout {
            inner,
            server_timeout,
        }
    }
}
27
28impl<S, ReqBody> Service<Request<ReqBody>> for GrpcTimeout<S>
29where
30    S: Service<Request<ReqBody>>,
31    S::Error: Into<crate::Error>,
32{
33    type Response = S::Response;
34    type Error = crate::Error;
35    type Future = ResponseFuture<S::Future>;
36
37    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
38        self.inner.poll_ready(cx).map_err(Into::into)
39    }
40
41    fn call(&mut self, req: Request<ReqBody>) -> Self::Future {
42        let client_timeout = try_parse_grpc_timeout(req.headers()).unwrap_or_else(|e| {
43            tracing::trace!("Error parsing `grpc-timeout` header {:?}", e);
44            None
45        });
46
47        // Use the shorter of the two durations, if either are set
48        let timeout_duration = match (client_timeout, self.server_timeout) {
49            (None, None) => None,
50            (Some(dur), None) => Some(dur),
51            (None, Some(dur)) => Some(dur),
52            (Some(header), Some(server)) => {
53                let shorter_duration = std::cmp::min(header, server);
54                Some(shorter_duration)
55            }
56        };
57
58        ResponseFuture {
59            inner: self.inner.call(req),
60            sleep: timeout_duration.map(tokio::time::sleep),
61        }
62    }
63}
64
65#[pin_project]
66pub(crate) struct ResponseFuture<F> {
67    #[pin]
68    inner: F,
69    #[pin]
70    sleep: Option<Sleep>,
71}
72
73impl<F, Res, E> Future for ResponseFuture<F>
74where
75    F: Future<Output = Result<Res, E>>,
76    E: Into<crate::Error>,
77{
78    type Output = Result<Res, crate::Error>;
79
80    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
81        let this = self.project();
82
83        if let Poll::Ready(result) = this.inner.poll(cx) {
84            return Poll::Ready(result.map_err(Into::into));
85        }
86
87        if let Some(sleep) = this.sleep.as_pin_mut() {
88            ready!(sleep.poll(cx));
89            return Poll::Ready(Err(TimeoutExpired(()).into()));
90        }
91
92        Poll::Pending
93    }
94}
95
/// Number of seconds in one minute; converts the `M` timeout unit to seconds.
const SECONDS_IN_MINUTE: u64 = 60;
/// Number of seconds in one hour; converts the `H` timeout unit to seconds.
const SECONDS_IN_HOUR: u64 = 60 * SECONDS_IN_MINUTE;
98
99/// Tries to parse the `grpc-timeout` header if it is present. If we fail to parse, returns
100/// the value we attempted to parse.
101///
102/// Follows the [gRPC over HTTP2 spec](https://github.com/grpc/grpc/blob/master/doc/PROTOCOL-HTTP2.md).
103fn try_parse_grpc_timeout(
104    headers: &HeaderMap<HeaderValue>,
105) -> Result<Option<Duration>, &HeaderValue> {
106    match headers.get(GRPC_TIMEOUT_HEADER) {
107        Some(val) => {
108            let (timeout_value, timeout_unit) = val
109                .to_str()
110                .map_err(|_| val)
111                .and_then(|s| if s.is_empty() { Err(val) } else { Ok(s) })?
112                // `HeaderValue::to_str` only returns `Ok` if the header contains ASCII so this
113                // `split_at` will never panic from trying to split in the middle of a character.
114                // See https://docs.rs/http/0.2.4/http/header/struct.HeaderValue.html#method.to_str
115                //
116                // `len - 1` also wont panic since we just checked `s.is_empty`.
117                .split_at(val.len() - 1);
118
119            // gRPC spec specifies `TimeoutValue` will be at most 8 digits
120            // Caping this at 8 digits also prevents integer overflow from ever occurring
121            if timeout_value.len() > 8 {
122                return Err(val);
123            }
124
125            let timeout_value: u64 = timeout_value.parse().map_err(|_| val)?;
126
127            let duration = match timeout_unit {
128                // Hours
129                "H" => Duration::from_secs(timeout_value * SECONDS_IN_HOUR),
130                // Minutes
131                "M" => Duration::from_secs(timeout_value * SECONDS_IN_MINUTE),
132                // Seconds
133                "S" => Duration::from_secs(timeout_value),
134                // Milliseconds
135                "m" => Duration::from_millis(timeout_value),
136                // Microseconds
137                "u" => Duration::from_micros(timeout_value),
138                // Nanoseconds
139                "n" => Duration::from_nanos(timeout_value),
140                _ => return Err(val),
141            };
142
143            Ok(Some(duration))
144        }
145        None => Ok(None),
146    }
147}
148
#[cfg(test)]
mod tests {
    use super::*;
    use quickcheck::{Arbitrary, Gen};
    use quickcheck_macros::quickcheck;

    // Helper function to reduce the boilerplate of our test cases: builds a header map that
    // optionally contains a `grpc-timeout` header and runs the parser over it. The error is
    // cloned so it can outlive the temporary map.
    fn setup_map_try_parse(val: Option<&str>) -> Result<Option<Duration>, HeaderValue> {
        let mut hm = HeaderMap::new();
        if let Some(v) = val {
            let hv = HeaderValue::from_str(v).unwrap();
            hm.insert(GRPC_TIMEOUT_HEADER, hv);
        };

        try_parse_grpc_timeout(&hm).map_err(|e| e.clone())
    }

    #[test]
    fn test_hours() {
        let parsed_duration = setup_map_try_parse(Some("3H")).unwrap().unwrap();
        assert_eq!(Duration::from_secs(3 * 60 * 60), parsed_duration);
    }

    #[test]
    fn test_minutes() {
        let parsed_duration = setup_map_try_parse(Some("1M")).unwrap().unwrap();
        assert_eq!(Duration::from_secs(60), parsed_duration);
    }

    #[test]
    fn test_seconds() {
        let parsed_duration = setup_map_try_parse(Some("42S")).unwrap().unwrap();
        assert_eq!(Duration::from_secs(42), parsed_duration);
    }

    #[test]
    fn test_milliseconds() {
        let parsed_duration = setup_map_try_parse(Some("13m")).unwrap().unwrap();
        assert_eq!(Duration::from_millis(13), parsed_duration);
    }

    #[test]
    fn test_microseconds() {
        let parsed_duration = setup_map_try_parse(Some("2u")).unwrap().unwrap();
        assert_eq!(Duration::from_micros(2), parsed_duration);
    }

    #[test]
    fn test_nanoseconds() {
        let parsed_duration = setup_map_try_parse(Some("82n")).unwrap().unwrap();
        assert_eq!(Duration::from_nanos(82), parsed_duration);
    }

    #[test]
    fn test_max_digits() {
        // The largest value the 8-digit cap lets through must not overflow for any unit
        let parsed_duration = setup_map_try_parse(Some("99999999H")).unwrap().unwrap();
        assert_eq!(Duration::from_secs(99_999_999 * 60 * 60), parsed_duration);
    }

    #[test]
    fn test_header_not_present() {
        let parsed_duration = setup_map_try_parse(None).unwrap();
        assert!(parsed_duration.is_none());
    }

    #[test]
    fn test_empty_header() {
        // A header that is present but empty has neither a value nor a unit
        assert!(setup_map_try_parse(Some("")).is_err());
    }

    #[test]
    fn test_missing_value() {
        // A unit on its own, without any digits, is not a valid timeout
        assert!(setup_map_try_parse(Some("H")).is_err());
    }

    #[test]
    #[should_panic(expected = "82f")]
    fn test_invalid_unit() {
        // "f" is not a valid TimeoutUnit
        setup_map_try_parse(Some("82f")).unwrap().unwrap();
    }

    #[test]
    #[should_panic(expected = "123456789H")]
    fn test_too_many_digits() {
        // gRPC spec states TimeoutValue will be at most 8 digits
        setup_map_try_parse(Some("123456789H")).unwrap().unwrap();
    }

    #[test]
    #[should_panic(expected = "oneH")]
    fn test_invalid_digits() {
        // gRPC spec states TimeoutValue must consist of digits
        setup_map_try_parse(Some("oneH")).unwrap().unwrap();
    }

    #[quickcheck]
    fn fuzz(header_value: HeaderValueGen) -> bool {
        let header_value = header_value.0;

        // this just shouldn't panic
        let _ = setup_map_try_parse(Some(&header_value));

        true
    }

    /// Newtype to implement `Arbitrary` for generating `String`s that are valid `HeaderValue`s.
    #[derive(Clone, Debug)]
    struct HeaderValueGen(String);

    impl Arbitrary for HeaderValueGen {
        fn arbitrary(g: &mut Gen) -> Self {
            let max = g.choose(&(1..70).collect::<Vec<_>>()).copied().unwrap();
            Self(gen_string(g, 0, max))
        }
    }

    // adapted from https://github.com/hyperium/http/blob/master/tests/header_map_fuzz.rs
    fn gen_string(g: &mut Gen, min: usize, max: usize) -> String {
        let bytes: Vec<_> = (min..max)
            .map(|_| {
                // Chars to pick from. Digits and every unit letter are included so that the
                // generated values can reach the numeric and unit parsing paths, and `+`/`-`
                // so that sign handling is exercised as well.
                g.choose(b"0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz+-")
                    .copied()
                    .unwrap()
            })
            .collect();

        String::from_utf8(bytes).unwrap()
    }
}