tonic/service/
interceptor.rs

1//! gRPC interceptors which are a kind of middleware.
2//!
3//! See [`Interceptor`] for more details.
4
5use crate::{
6    body::{boxed, BoxBody},
7    request::SanitizeHeaders,
8    Status,
9};
10use bytes::Bytes;
11use pin_project::pin_project;
12use std::{
13    fmt,
14    future::Future,
15    pin::Pin,
16    task::{Context, Poll},
17};
18use tower_layer::Layer;
19use tower_service::Service;
20
/// A gRPC interceptor.
///
/// gRPC interceptors are similar to middleware but have less flexibility. An interceptor allows
/// you to do two main things: first, add/remove/check items in the `MetadataMap` of each
/// request; second, cancel a request with a `Status`.
///
/// Any function that satisfies the bound `FnMut(Request<()>) -> Result<Request<()>, Status>` can be
/// used as an `Interceptor`.
///
/// An interceptor can be used on both the server and client side through the `tonic-build` crate's
/// generated structs.
///
/// See the [interceptor example][example] for more details.
///
/// If you need more powerful middleware, [tower] is the recommended approach. You can find
/// examples of how to use tower with tonic [here][tower-example].
///
/// Additionally, interceptors are not the recommended way to add logging to your service. For that
/// a [tower] middleware is more appropriate since it can also act on the response. For example
/// tower-http's [`Trace`](https://docs.rs/tower-http/latest/tower_http/trace/index.html)
/// middleware supports gRPC out of the box.
///
/// [tower]: https://crates.io/crates/tower
/// [example]: https://github.com/hyperium/tonic/tree/master/examples/src/interceptor
/// [tower-example]: https://github.com/hyperium/tonic/tree/master/examples/src/tower
pub trait Interceptor {
    /// Intercept a request before it is sent, optionally cancelling it.
    ///
    /// Return `Ok` with the (possibly modified) request to let it proceed, or `Err` with a
    /// [`Status`] to cancel it.
    fn call(&mut self, request: crate::Request<()>) -> Result<crate::Request<()>, Status>;
}
50
51impl<F> Interceptor for F
52where
53    F: FnMut(crate::Request<()>) -> Result<crate::Request<()>, Status>,
54{
55    fn call(&mut self, request: crate::Request<()>) -> Result<crate::Request<()>, Status> {
56        self(request)
57    }
58}
59
60/// Create a new interceptor layer.
61///
62/// See [`Interceptor`] for more details.
63pub fn interceptor<F>(f: F) -> InterceptorLayer<F>
64where
65    F: Interceptor,
66{
67    InterceptorLayer { f }
68}
69
/// A [`Layer`] that applies a gRPC interceptor to the services it wraps.
///
/// Obtain one by calling [`interceptor`]. See [`Interceptor`] for more details.
#[derive(Clone, Copy, Debug)]
pub struct InterceptorLayer<F> {
    // The interceptor handed to each service this layer produces.
    f: F,
}
78
79impl<S, F> Layer<S> for InterceptorLayer<F>
80where
81    F: Interceptor + Clone,
82{
83    type Service = InterceptedService<S, F>;
84
85    fn layer(&self, service: S) -> Self::Service {
86        InterceptedService::new(service, self.f.clone())
87    }
88}
89
/// Middleware that runs an interceptor on every request before passing it to the wrapped
/// service.
///
/// See [`Interceptor`] for more details.
#[derive(Copy, Clone)]
pub struct InterceptedService<S, F> {
    // The wrapped service that receives requests the interceptor lets through.
    inner: S,
    // The interceptor applied to each incoming request.
    f: F,
}
98
99impl<S, F> InterceptedService<S, F> {
100    /// Create a new `InterceptedService` that wraps `S` and intercepts each request with the
101    /// function `F`.
102    pub fn new(service: S, f: F) -> Self
103    where
104        F: Interceptor,
105    {
106        Self { inner: service, f }
107    }
108}
109
110impl<S, F> fmt::Debug for InterceptedService<S, F>
111where
112    S: fmt::Debug,
113{
114    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
115        f.debug_struct("InterceptedService")
116            .field("inner", &self.inner)
117            .field("f", &format_args!("{}", std::any::type_name::<F>()))
118            .finish()
119    }
120}
121
122impl<S, F, ReqBody, ResBody> Service<http::Request<ReqBody>> for InterceptedService<S, F>
123where
124    ResBody: Default + http_body::Body<Data = Bytes> + Send + 'static,
125    F: Interceptor,
126    S: Service<http::Request<ReqBody>, Response = http::Response<ResBody>>,
127    S::Error: Into<crate::Error>,
128    ResBody: http_body::Body<Data = bytes::Bytes> + Send + 'static,
129    ResBody::Error: Into<crate::Error>,
130{
131    type Response = http::Response<BoxBody>;
132    type Error = S::Error;
133    type Future = ResponseFuture<S::Future>;
134
135    #[inline]
136    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
137        self.inner.poll_ready(cx)
138    }
139
140    fn call(&mut self, req: http::Request<ReqBody>) -> Self::Future {
141        // It is bad practice to modify the body (i.e. Message) of the request via an interceptor.
142        // To avoid exposing the body of the request to the interceptor function, we first remove it
143        // here, allow the interceptor to modify the metadata and extensions, and then recreate the
144        // HTTP request with the body. Tonic requests do not preserve the URI, HTTP version, and
145        // HTTP method of the HTTP request, so we extract them here and then add them back in below.
146        let uri = req.uri().clone();
147        let method = req.method().clone();
148        let version = req.version();
149        let req = crate::Request::from_http(req);
150        let (metadata, extensions, msg) = req.into_parts();
151
152        match self
153            .f
154            .call(crate::Request::from_parts(metadata, extensions, ()))
155        {
156            Ok(req) => {
157                let (metadata, extensions, _) = req.into_parts();
158                let req = crate::Request::from_parts(metadata, extensions, msg);
159                let req = req.into_http(uri, method, version, SanitizeHeaders::No);
160                ResponseFuture::future(self.inner.call(req))
161            }
162            Err(status) => ResponseFuture::status(status),
163        }
164    }
165}
166
167// required to use `InterceptedService` with `Router`
168impl<S, F> crate::server::NamedService for InterceptedService<S, F>
169where
170    S: crate::server::NamedService,
171{
172    const NAME: &'static str = S::NAME;
173}
174
/// Response future for [`InterceptedService`].
///
/// Holds either the wrapped service's future or a [`Status`] produced by the interceptor.
#[pin_project]
#[derive(Debug)]
pub struct ResponseFuture<F> {
    // Which of the two cases this future represents; pinned because it may hold a future.
    #[pin]
    kind: Kind<F>,
}
182
183impl<F> ResponseFuture<F> {
184    fn future(future: F) -> Self {
185        Self {
186            kind: Kind::Future(future),
187        }
188    }
189
190    fn status(status: Status) -> Self {
191        Self {
192            kind: Kind::Status(Some(status)),
193        }
194    }
195}
196
// The two states of a `ResponseFuture`.
#[pin_project(project = KindProj)]
#[derive(Debug)]
enum Kind<F> {
    // The request was accepted; this is the wrapped service's future.
    Future(#[pin] F),
    // The interceptor rejected the request. The status is held in an `Option` so it can be
    // moved out when the future is polled.
    Status(Option<Status>),
}
203
204impl<F, E, B> Future for ResponseFuture<F>
205where
206    F: Future<Output = Result<http::Response<B>, E>>,
207    E: Into<crate::Error>,
208    B: Default + http_body::Body<Data = Bytes> + Send + 'static,
209    B::Error: Into<crate::Error>,
210{
211    type Output = Result<http::Response<BoxBody>, E>;
212
213    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
214        match self.project().kind.project() {
215            KindProj::Future(future) => future
216                .poll(cx)
217                .map(|result| result.map(|res| res.map(boxed))),
218            KindProj::Status(status) => {
219                let response = status
220                    .take()
221                    .unwrap()
222                    .into_http()
223                    .map(|_| B::default())
224                    .map(boxed);
225                Poll::Ready(Ok(response))
226            }
227        }
228    }
229}
230
#[cfg(test)]
mod tests {
    use super::*;
    use http_body::Frame;
    use http_body_util::Empty;
    use tower::ServiceExt;

    // Minimal body used by the tests: it never yields a frame.
    #[derive(Debug, Default)]
    struct TestBody;

    impl http_body::Body for TestBody {
        type Data = Bytes;
        type Error = Status;

        fn poll_frame(
            self: Pin<&mut Self>,
            _cx: &mut Context<'_>,
        ) -> Poll<Option<Result<Frame<Self::Data>, Self::Error>>> {
            // Signal end-of-stream immediately.
            Poll::Ready(None)
        }
    }

    // A header set on the HTTP request must be visible both to the interceptor (as metadata)
    // and to the leaf service after the request has been rebuilt.
    #[tokio::test]
    async fn doesnt_remove_headers_from_requests() {
        let svc = tower::service_fn(|request: http::Request<TestBody>| async move {
            assert_eq!(
                request
                    .headers()
                    .get("user-agent")
                    .expect("missing in leaf service"),
                "test-tonic"
            );

            Ok::<_, Status>(http::Response::new(TestBody))
        });

        let svc = InterceptedService::new(svc, |request: crate::Request<()>| {
            assert_eq!(
                request
                    .metadata()
                    .get("user-agent")
                    .expect("missing in interceptor"),
                "test-tonic"
            );

            Ok(request)
        });

        let request = http::Request::builder()
            .header("user-agent", "test-tonic")
            .body(TestBody)
            .unwrap();

        svc.oneshot(request).await.unwrap();
    }

    // When the interceptor returns `Err(status)`, the service must answer with a response whose
    // status code, version and headers match `Status::into_http` for that same status.
    #[tokio::test]
    async fn handles_intercepted_status_as_response() {
        let message = "Blocked by the interceptor";
        let expected = Status::permission_denied(message).into_http();

        let svc = tower::service_fn(|_: http::Request<TestBody>| async {
            Ok::<_, Status>(http::Response::new(TestBody))
        });

        let svc = InterceptedService::new(svc, |_: crate::Request<()>| {
            Err(Status::permission_denied(message))
        });

        let request = http::Request::builder().body(TestBody).unwrap();
        let response = svc.oneshot(request).await.unwrap();

        assert_eq!(expected.status(), response.status());
        assert_eq!(expected.version(), response.version());
        assert_eq!(expected.headers(), response.headers());
    }

    // The HTTP method of the incoming request must reach the leaf service unchanged.
    #[tokio::test]
    async fn doesnt_change_http_method() {
        let svc = tower::service_fn(|request: http::Request<Empty<()>>| async move {
            assert_eq!(request.method(), http::Method::OPTIONS);

            Ok::<_, hyper::Error>(hyper::Response::new(Empty::new()))
        });

        // `Ok` serves as a pass-through interceptor that accepts every request as-is.
        let svc = InterceptedService::new(svc, Ok);

        let request = http::Request::builder()
            .method(http::Method::OPTIONS)
            .body(Empty::new())
            .unwrap();

        svc.oneshot(request).await.unwrap();
    }
}