tower_http/auth/async_require_authorization.rs
1//! Authorize requests using the [`Authorization`] header asynchronously.
2//!
3//! [`Authorization`]: https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Authorization
4//!
5//! # Example
6//!
7//! ```
8//! use tower_http::auth::{AsyncRequireAuthorizationLayer, AsyncAuthorizeRequest};
9//! use http::{Request, Response, StatusCode, header::AUTHORIZATION};
10//! use tower::{Service, ServiceExt, ServiceBuilder, service_fn, BoxError};
11//! use futures_core::future::BoxFuture;
12//! use bytes::Bytes;
13//! use http_body_util::Full;
14//!
15//! #[derive(Clone, Copy)]
16//! struct MyAuth;
17//!
18//! impl<B> AsyncAuthorizeRequest<B> for MyAuth
19//! where
20//!     B: Send + Sync + 'static,
21//! {
22//!     type RequestBody = B;
23//!     type ResponseBody = Full<Bytes>;
24//!     type Future = BoxFuture<'static, Result<Request<B>, Response<Self::ResponseBody>>>;
25//!
26//!     fn authorize(&mut self, mut request: Request<B>) -> Self::Future {
27//!         Box::pin(async {
28//!             if let Some(user_id) = check_auth(&request).await {
29//!                 // Set `user_id` as a request extension so it can be accessed by other
30//!                 // services down the stack.
31//!                 request.extensions_mut().insert(user_id);
32//!
33//!                 Ok(request)
34//!             } else {
35//!                 let unauthorized_response = Response::builder()
36//!                     .status(StatusCode::UNAUTHORIZED)
37//!                     .body(Full::<Bytes>::default())
38//!                     .unwrap();
39//!
40//!                 Err(unauthorized_response)
41//!             }
42//!         })
43//!     }
44//! }
45//!
46//! async fn check_auth<B>(request: &Request<B>) -> Option<UserId> {
47//!     // ...
48//!     # None
49//! }
50//!
51//! #[derive(Debug, Clone)]
52//! struct UserId(String);
53//!
54//! async fn handle(request: Request<Full<Bytes>>) -> Result<Response<Full<Bytes>>, BoxError> {
//!     // Access the `UserId` that was set in `MyAuth::authorize`. If `handle` gets called the
//!     // request was authorized and `UserId` will be present.
57//!     let user_id = request
58//!         .extensions()
59//!         .get::<UserId>()
60//!         .expect("UserId will be there if request was authorized");
61//!
62//!     println!("request from {:?}", user_id);
63//!
64//!     Ok(Response::new(Full::default()))
65//! }
66//!
67//! # #[tokio::main]
68//! # async fn main() -> Result<(), BoxError> {
69//! let service = ServiceBuilder::new()
70//!     // Authorize requests using `MyAuth`
71//!     .layer(AsyncRequireAuthorizationLayer::new(MyAuth))
72//!     .service_fn(handle);
73//! # Ok(())
74//! # }
75//! ```
76//!
77//! Or using a closure:
78//!
79//! ```
80//! use tower_http::auth::{AsyncRequireAuthorizationLayer, AsyncAuthorizeRequest};
81//! use http::{Request, Response, StatusCode};
82//! use tower::{Service, ServiceExt, ServiceBuilder, BoxError};
83//! use futures_core::future::BoxFuture;
84//! use http_body_util::Full;
85//! use bytes::Bytes;
86//!
87//! async fn check_auth<B>(request: &Request<B>) -> Option<UserId> {
88//!     // ...
89//!     # None
90//! }
91//!
92//! #[derive(Debug)]
93//! struct UserId(String);
94//!
95//! async fn handle(request: Request<Full<Bytes>>) -> Result<Response<Full<Bytes>>, BoxError> {
96//!     # todo!();
97//!     // ...
98//! }
99//!
100//! # #[tokio::main]
101//! # async fn main() -> Result<(), BoxError> {
102//! let service = ServiceBuilder::new()
103//!     .layer(AsyncRequireAuthorizationLayer::new(|request: Request<Full<Bytes>>| async move {
104//!         if let Some(user_id) = check_auth(&request).await {
105//!             Ok(request)
106//!         } else {
107//!             let unauthorized_response = Response::builder()
108//!                 .status(StatusCode::UNAUTHORIZED)
109//!                 .body(Full::<Bytes>::default())
110//!                 .unwrap();
111//!
112//!             Err(unauthorized_response)
113//!         }
114//!     }))
115//!     .service_fn(handle);
116//! # Ok(())
117//! # }
118//! ```
119
120use http::{Request, Response};
121use pin_project_lite::pin_project;
122use std::{
123    future::Future,
124    mem,
125    pin::Pin,
126    task::{ready, Context, Poll},
127};
128use tower_layer::Layer;
129use tower_service::Service;
130
/// [`Layer`](tower_layer::Layer) that wraps services in [`AsyncRequireAuthorization`].
///
/// Every request reaching the produced middleware is first handed to the configured
/// [`AsyncAuthorizeRequest`] implementation, usually one that inspects the [`Authorization`]
/// header.
///
/// See the [module docs](crate::auth::async_require_authorization) for an example.
///
/// [`Authorization`]: https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Authorization
#[derive(Clone, Debug)]
pub struct AsyncRequireAuthorizationLayer<T> {
    /// Authorizer that is cloned into every service this layer wraps.
    auth: T,
}
141
142impl<T> AsyncRequireAuthorizationLayer<T> {
143    /// Authorize requests using a custom scheme.
144    pub fn new(auth: T) -> AsyncRequireAuthorizationLayer<T> {
145        Self { auth }
146    }
147}
148
149impl<S, T> Layer<S> for AsyncRequireAuthorizationLayer<T>
150where
151    T: Clone,
152{
153    type Service = AsyncRequireAuthorization<S, T>;
154
155    fn layer(&self, inner: S) -> Self::Service {
156        AsyncRequireAuthorization::new(inner, self.auth.clone())
157    }
158}
159
/// Middleware that runs every request through an [`AsyncAuthorizeRequest`] before it may reach
/// the inner service, usually by checking the [`Authorization`] header.
///
/// See the [module docs](crate::auth::async_require_authorization) for an example.
///
/// [`Authorization`]: https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Authorization
#[derive(Debug, Clone)]
pub struct AsyncRequireAuthorization<S, T> {
    /// The wrapped service; only requests accepted by `auth` are forwarded to it.
    inner: S,
    /// The authorizer consulted for each request.
    auth: T,
}
170
// Accessors for the wrapped `inner` service, generated by a macro defined elsewhere in the crate.
// NOTE(review): presumably `get_ref` / `get_mut` / `into_inner`; confirm against the macro definition.
impl<S, T> AsyncRequireAuthorization<S, T> {
    define_inner_service_accessors!();
}
174
175impl<S, T> AsyncRequireAuthorization<S, T> {
176    /// Authorize requests using a custom scheme.
177    ///
178    /// The `Authorization` header is required to have the value provided.
179    pub fn new(inner: S, auth: T) -> AsyncRequireAuthorization<S, T> {
180        Self { inner, auth }
181    }
182
183    /// Returns a new [`Layer`] that wraps services with an [`AsyncRequireAuthorizationLayer`]
184    /// middleware.
185    ///
186    /// [`Layer`]: tower_layer::Layer
187    pub fn layer(auth: T) -> AsyncRequireAuthorizationLayer<T> {
188        AsyncRequireAuthorizationLayer::new(auth)
189    }
190}
191
192impl<ReqBody, ResBody, S, Auth> Service<Request<ReqBody>> for AsyncRequireAuthorization<S, Auth>
193where
194    Auth: AsyncAuthorizeRequest<ReqBody, ResponseBody = ResBody>,
195    S: Service<Request<Auth::RequestBody>, Response = Response<ResBody>> + Clone,
196{
197    type Response = Response<ResBody>;
198    type Error = S::Error;
199    type Future = ResponseFuture<Auth, S, ReqBody>;
200
201    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
202        self.inner.poll_ready(cx)
203    }
204
205    fn call(&mut self, req: Request<ReqBody>) -> Self::Future {
206        let mut inner = self.inner.clone();
207        let authorize = self.auth.authorize(req);
208        // mem::swap due to https://docs.rs/tower/latest/tower/trait.Service.html#be-careful-when-cloning-inner-services
209        mem::swap(&mut self.inner, &mut inner);
210
211        ResponseFuture {
212            state: State::Authorize { authorize },
213            service: inner,
214        }
215    }
216}
217
pin_project! {
    /// Response future for [`AsyncRequireAuthorization`].
    ///
    /// It first drives the authorization future. If that future accepts the request, it then
    /// drives the inner service's future.
    pub struct ResponseFuture<Auth, S, ReqBody>
    where
        Auth: AsyncAuthorizeRequest<ReqBody>,
        S: Service<Request<Auth::RequestBody>>,
    {
        // Current phase. It is pinned because it holds the authorization future or the inner
        // service's future.
        #[pin]
        state: State<Auth::Future, S::Future>,
        // The inner service instance that `poll_ready` was called on. It is called once, after
        // authorization succeeds.
        service: S,
    }
}
230
pin_project! {
    // Two-phase state machine driven by `ResponseFuture::poll`.
    #[project = StateProj]
    enum State<A, SFut> {
        // Waiting for the authorization future `A` to resolve.
        Authorize {
            #[pin]
            authorize: A,
        },
        // Authorization succeeded; waiting for the inner service's response future.
        Authorized {
            #[pin]
            fut: SFut,
        },
    }
}
244
245impl<Auth, S, ReqBody, B> Future for ResponseFuture<Auth, S, ReqBody>
246where
247    Auth: AsyncAuthorizeRequest<ReqBody, ResponseBody = B>,
248    S: Service<Request<Auth::RequestBody>, Response = Response<B>>,
249{
250    type Output = Result<Response<B>, S::Error>;
251
252    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
253        let mut this = self.project();
254
255        loop {
256            match this.state.as_mut().project() {
257                StateProj::Authorize { authorize } => {
258                    let auth = ready!(authorize.poll(cx));
259                    match auth {
260                        Ok(req) => {
261                            let fut = this.service.call(req);
262                            this.state.set(State::Authorized { fut })
263                        }
264                        Err(res) => {
265                            return Poll::Ready(Ok(res));
266                        }
267                    };
268                }
269                StateProj::Authorized { fut } => {
270                    return fut.poll(cx);
271                }
272            }
273        }
274    }
275}
276
/// Trait for authorizing requests.
///
/// Implementations receive each incoming request and resolve either to the (possibly modified)
/// request, which is then forwarded to the inner service, or to a response that is returned to
/// the client instead. Closures of the form
/// `FnMut(Request<B>) -> impl Future<Output = Result<Request<_>, Response<_>>>` implement this
/// trait through a blanket implementation.
pub trait AsyncAuthorizeRequest<B> {
    /// The type of request body returned by `authorize`.
    ///
    /// Set this to `B` unless you need to change the request body type.
    type RequestBody;

    /// The body type used for responses to unauthorized requests.
    type ResponseBody;

    /// The future type returned by `authorize`.
    type Future: Future<Output = Result<Request<Self::RequestBody>, Response<Self::ResponseBody>>>;

    /// Authorize the request.
    ///
    /// If the future resolves to `Ok(request)`, the request is forwarded to the inner service.
    /// If it resolves to `Err(response)`, the inner service is not called and `response` is
    /// returned to the caller as-is.
    fn authorize(&mut self, request: Request<B>) -> Self::Future;
}
295
296impl<B, F, Fut, ReqBody, ResBody> AsyncAuthorizeRequest<B> for F
297where
298    F: FnMut(Request<B>) -> Fut,
299    Fut: Future<Output = Result<Request<ReqBody>, Response<ResBody>>>,
300{
301    type RequestBody = ReqBody;
302    type ResponseBody = ResBody;
303    type Future = Fut;
304
305    fn authorize(&mut self, request: Request<B>) -> Self::Future {
306        self(request)
307    }
308}
309
#[cfg(test)]
mod tests {
    use super::*;
    use crate::test_helpers::Body;
    use futures_core::future::BoxFuture;
    use http::{header, StatusCode};
    use tower::{BoxError, ServiceBuilder, ServiceExt};

    /// Test authorizer that only accepts `Authorization: Bearer 69420`.
    #[derive(Clone, Copy)]
    struct MyAuth;

    impl<B> AsyncAuthorizeRequest<B> for MyAuth
    where
        B: Send + 'static,
    {
        type RequestBody = B;
        type ResponseBody = Body;
        type Future = BoxFuture<'static, Result<Request<B>, Response<Self::ResponseBody>>>;

        fn authorize(&mut self, request: Request<B>) -> Self::Future {
            Box::pin(async move {
                let authorized = request
                    .headers()
                    .get(header::AUTHORIZATION)
                    .and_then(|auth| auth.to_str().ok()?.strip_prefix("Bearer "))
                    == Some("69420");

                if authorized {
                    Ok(request)
                } else {
                    Err(Response::builder()
                        .status(StatusCode::UNAUTHORIZED)
                        .body(Body::empty())
                        .unwrap())
                }
            })
        }
    }

    #[tokio::test]
    async fn require_async_auth_works() {
        let mut service = ServiceBuilder::new()
            .layer(AsyncRequireAuthorizationLayer::new(MyAuth))
            .service_fn(echo);

        let request = Request::get("/")
            .header(header::AUTHORIZATION, "Bearer 69420")
            .body(Body::empty())
            .unwrap();

        let res = service.ready().await.unwrap().call(request).await.unwrap();

        assert_eq!(res.status(), StatusCode::OK);
    }

    #[tokio::test]
    async fn require_async_auth_401() {
        let mut service = ServiceBuilder::new()
            .layer(AsyncRequireAuthorizationLayer::new(MyAuth))
            .service_fn(echo);

        let request = Request::get("/")
            .header(header::AUTHORIZATION, "Bearer deez")
            .body(Body::empty())
            .unwrap();

        let res = service.ready().await.unwrap().call(request).await.unwrap();

        assert_eq!(res.status(), StatusCode::UNAUTHORIZED);
    }

    /// A request that carries no `Authorization` header at all must be rejected.
    #[tokio::test]
    async fn require_async_auth_401_without_header() {
        let mut service = ServiceBuilder::new()
            .layer(AsyncRequireAuthorizationLayer::new(MyAuth))
            .service_fn(echo);

        let request = Request::get("/").body(Body::empty()).unwrap();

        let res = service.ready().await.unwrap().call(request).await.unwrap();

        assert_eq!(res.status(), StatusCode::UNAUTHORIZED);
    }

    /// Exercises the blanket `AsyncAuthorizeRequest` impl for closures.
    #[tokio::test]
    async fn require_async_auth_with_closure() {
        let mut service = ServiceBuilder::new()
            .layer(AsyncRequireAuthorizationLayer::new(
                |request: Request<Body>| async move {
                    if request.headers().contains_key(header::AUTHORIZATION) {
                        Ok(request)
                    } else {
                        Err(Response::builder()
                            .status(StatusCode::UNAUTHORIZED)
                            .body(Body::empty())
                            .unwrap())
                    }
                },
            ))
            .service_fn(echo);

        let authorized = Request::get("/")
            .header(header::AUTHORIZATION, "Bearer anything")
            .body(Body::empty())
            .unwrap();
        let res = service.ready().await.unwrap().call(authorized).await.unwrap();
        assert_eq!(res.status(), StatusCode::OK);

        let anonymous = Request::get("/").body(Body::empty()).unwrap();
        let res = service.ready().await.unwrap().call(anonymous).await.unwrap();
        assert_eq!(res.status(), StatusCode::UNAUTHORIZED);
    }

    async fn echo(req: Request<Body>) -> Result<Response<Body>, BoxError> {
        Ok(Response::new(req.into_body()))
    }
}