1use crate::{
6 body::{boxed, BoxBody},
7 request::SanitizeHeaders,
8 Status,
9};
10use bytes::Bytes;
11use pin_project::pin_project;
12use std::{
13 fmt,
14 future::Future,
15 pin::Pin,
16 task::{Context, Poll},
17};
18use tower_layer::Layer;
19use tower_service::Service;
20
21pub trait Interceptor {
47 fn call(&mut self, request: crate::Request<()>) -> Result<crate::Request<()>, Status>;
49}
50
51impl<F> Interceptor for F
52where
53 F: FnMut(crate::Request<()>) -> Result<crate::Request<()>, Status>,
54{
55 fn call(&mut self, request: crate::Request<()>) -> Result<crate::Request<()>, Status> {
56 self(request)
57 }
58}
59
60pub fn interceptor<F>(f: F) -> InterceptorLayer<F>
64where
65 F: Interceptor,
66{
67 InterceptorLayer { f }
68}
69
/// A [`Layer`] that wraps each service it is applied to in an [`InterceptedService`].
#[derive(Debug, Clone, Copy)]
pub struct InterceptorLayer<F> {
    /// The interceptor; every service produced by this layer receives its own clone.
    f: F,
}
78
79impl<S, F> Layer<S> for InterceptorLayer<F>
80where
81 F: Interceptor + Clone,
82{
83 type Service = InterceptedService<S, F>;
84
85 fn layer(&self, service: S) -> Self::Service {
86 InterceptedService::new(service, self.f.clone())
87 }
88}
89
/// A service that runs an [`Interceptor`] on every request before forwarding it to the
/// inner service.
#[derive(Clone, Copy)]
pub struct InterceptedService<S, F> {
    /// The wrapped service; it only sees requests the interceptor accepted.
    inner: S,
    /// The interceptor invoked on each request.
    f: F,
}
98
99impl<S, F> InterceptedService<S, F> {
100 pub fn new(service: S, f: F) -> Self
103 where
104 F: Interceptor,
105 {
106 Self { inner: service, f }
107 }
108}
109
110impl<S, F> fmt::Debug for InterceptedService<S, F>
111where
112 S: fmt::Debug,
113{
114 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
115 f.debug_struct("InterceptedService")
116 .field("inner", &self.inner)
117 .field("f", &format_args!("{}", std::any::type_name::<F>()))
118 .finish()
119 }
120}
121
122impl<S, F, ReqBody, ResBody> Service<http::Request<ReqBody>> for InterceptedService<S, F>
123where
124 ResBody: Default + http_body::Body<Data = Bytes> + Send + 'static,
125 F: Interceptor,
126 S: Service<http::Request<ReqBody>, Response = http::Response<ResBody>>,
127 S::Error: Into<crate::Error>,
128 ResBody: http_body::Body<Data = bytes::Bytes> + Send + 'static,
129 ResBody::Error: Into<crate::Error>,
130{
131 type Response = http::Response<BoxBody>;
132 type Error = S::Error;
133 type Future = ResponseFuture<S::Future>;
134
135 #[inline]
136 fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
137 self.inner.poll_ready(cx)
138 }
139
140 fn call(&mut self, req: http::Request<ReqBody>) -> Self::Future {
141 let uri = req.uri().clone();
147 let method = req.method().clone();
148 let version = req.version();
149 let req = crate::Request::from_http(req);
150 let (metadata, extensions, msg) = req.into_parts();
151
152 match self
153 .f
154 .call(crate::Request::from_parts(metadata, extensions, ()))
155 {
156 Ok(req) => {
157 let (metadata, extensions, _) = req.into_parts();
158 let req = crate::Request::from_parts(metadata, extensions, msg);
159 let req = req.into_http(uri, method, version, SanitizeHeaders::No);
160 ResponseFuture::future(self.inner.call(req))
161 }
162 Err(status) => ResponseFuture::status(status),
163 }
164 }
165}
166
167impl<S, F> crate::server::NamedService for InterceptedService<S, F>
169where
170 S: crate::server::NamedService,
171{
172 const NAME: &'static str = S::NAME;
173}
174
175#[pin_project]
177#[derive(Debug)]
178pub struct ResponseFuture<F> {
179 #[pin]
180 kind: Kind<F>,
181}
182
183impl<F> ResponseFuture<F> {
184 fn future(future: F) -> Self {
185 Self {
186 kind: Kind::Future(future),
187 }
188 }
189
190 fn status(status: Status) -> Self {
191 Self {
192 kind: Kind::Status(Some(status)),
193 }
194 }
195}
196
197#[pin_project(project = KindProj)]
198#[derive(Debug)]
199enum Kind<F> {
200 Future(#[pin] F),
201 Status(Option<Status>),
202}
203
204impl<F, E, B> Future for ResponseFuture<F>
205where
206 F: Future<Output = Result<http::Response<B>, E>>,
207 E: Into<crate::Error>,
208 B: Default + http_body::Body<Data = Bytes> + Send + 'static,
209 B::Error: Into<crate::Error>,
210{
211 type Output = Result<http::Response<BoxBody>, E>;
212
213 fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
214 match self.project().kind.project() {
215 KindProj::Future(future) => future
216 .poll(cx)
217 .map(|result| result.map(|res| res.map(boxed))),
218 KindProj::Status(status) => {
219 let response = status
220 .take()
221 .unwrap()
222 .into_http()
223 .map(|_| B::default())
224 .map(boxed);
225 Poll::Ready(Ok(response))
226 }
227 }
228 }
229}
230
#[cfg(test)]
mod tests {
    use super::*;
    use http_body::Frame;
    use http_body_util::Empty;
    use tower::ServiceExt;

    /// A body that yields no frames; used for both requests and responses.
    #[derive(Debug, Default)]
    struct TestBody;

    impl http_body::Body for TestBody {
        type Data = Bytes;
        type Error = Status;

        fn poll_frame(
            self: Pin<&mut Self>,
            _cx: &mut Context<'_>,
        ) -> Poll<Option<Result<Frame<Self::Data>, Self::Error>>> {
            // Immediately exhausted: there is never any frame to yield.
            Poll::Ready(None)
        }
    }

    /// A header on the incoming request must be visible to the interceptor as metadata. It must
    /// also still be present when the request reaches the inner service.
    #[tokio::test]
    async fn doesnt_remove_headers_from_requests() {
        let leaf = tower::service_fn(|req: http::Request<TestBody>| async move {
            let agent = req
                .headers()
                .get("user-agent")
                .expect("missing in leaf service");
            assert_eq!(agent, "test-tonic");

            Ok::<_, Status>(http::Response::new(TestBody))
        });

        let intercepted = InterceptedService::new(leaf, |req: crate::Request<()>| {
            let agent = req
                .metadata()
                .get("user-agent")
                .expect("missing in interceptor");
            assert_eq!(agent, "test-tonic");

            Ok(req)
        });

        let incoming = http::Request::builder()
            .header("user-agent", "test-tonic")
            .body(TestBody)
            .unwrap();

        intercepted.oneshot(incoming).await.unwrap();
    }

    /// An interceptor rejection must come back as an `Ok` response built from the status.
    #[tokio::test]
    async fn handles_intercepted_status_as_response() {
        let message = "Blocked by the interceptor";
        let expected = Status::permission_denied(message).into_http();

        let leaf = tower::service_fn(|_: http::Request<TestBody>| async {
            Ok::<_, Status>(http::Response::new(TestBody))
        });

        let intercepted = InterceptedService::new(leaf, |_: crate::Request<()>| {
            Err(Status::permission_denied(message))
        });

        let incoming = http::Request::builder().body(TestBody).unwrap();
        let response = intercepted.oneshot(incoming).await.unwrap();

        assert_eq!(expected.status(), response.status());
        assert_eq!(expected.version(), response.version());
        assert_eq!(expected.headers(), response.headers());
    }

    /// The HTTP method must survive the round-trip through `crate::Request`.
    #[tokio::test]
    async fn doesnt_change_http_method() {
        let leaf = tower::service_fn(|req: http::Request<Empty<()>>| async move {
            assert_eq!(req.method(), http::Method::OPTIONS);

            Ok::<_, hyper::Error>(hyper::Response::new(Empty::new()))
        });

        // `Ok` is the identity interceptor: every request passes through unchanged.
        let intercepted = InterceptedService::new(leaf, Ok);

        let options_request = http::Request::builder()
            .method(http::Method::OPTIONS)
            .body(Empty::new())
            .unwrap();

        intercepted.oneshot(options_request).await.unwrap();
    }
}