tower_http/auth/async_require_authorization.rs
1//! Authorize requests using the [`Authorization`] header asynchronously.
2//!
3//! [`Authorization`]: https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Authorization
4//!
5//! # Example
6//!
7//! ```
8//! use tower_http::auth::{AsyncRequireAuthorizationLayer, AsyncAuthorizeRequest};
9//! use http::{Request, Response, StatusCode, header::AUTHORIZATION};
10//! use tower::{Service, ServiceExt, ServiceBuilder, service_fn, BoxError};
11//! use futures_core::future::BoxFuture;
12//! use bytes::Bytes;
13//! use http_body_util::Full;
14//!
15//! #[derive(Clone, Copy)]
16//! struct MyAuth;
17//!
18//! impl<B> AsyncAuthorizeRequest<B> for MyAuth
19//! where
20//! B: Send + Sync + 'static,
21//! {
22//! type RequestBody = B;
23//! type ResponseBody = Full<Bytes>;
24//! type Future = BoxFuture<'static, Result<Request<B>, Response<Self::ResponseBody>>>;
25//!
26//! fn authorize(&mut self, mut request: Request<B>) -> Self::Future {
27//! Box::pin(async {
28//! if let Some(user_id) = check_auth(&request).await {
29//! // Set `user_id` as a request extension so it can be accessed by other
30//! // services down the stack.
31//! request.extensions_mut().insert(user_id);
32//!
33//! Ok(request)
34//! } else {
35//! let unauthorized_response = Response::builder()
36//! .status(StatusCode::UNAUTHORIZED)
37//! .body(Full::<Bytes>::default())
38//! .unwrap();
39//!
40//! Err(unauthorized_response)
41//! }
42//! })
43//! }
44//! }
45//!
46//! async fn check_auth<B>(request: &Request<B>) -> Option<UserId> {
47//! // ...
48//! # None
49//! }
50//!
51//! #[derive(Debug, Clone)]
52//! struct UserId(String);
53//!
54//! async fn handle(request: Request<Full<Bytes>>) -> Result<Response<Full<Bytes>>, BoxError> {
//!     // Access the `UserId` that was set in `authorize`. If `handle` gets called the
//!     // request was authorized and `UserId` will be present.
57//! let user_id = request
58//! .extensions()
59//! .get::<UserId>()
60//! .expect("UserId will be there if request was authorized");
61//!
62//! println!("request from {:?}", user_id);
63//!
64//! Ok(Response::new(Full::default()))
65//! }
66//!
67//! # #[tokio::main]
68//! # async fn main() -> Result<(), BoxError> {
69//! let service = ServiceBuilder::new()
70//! // Authorize requests using `MyAuth`
71//! .layer(AsyncRequireAuthorizationLayer::new(MyAuth))
72//! .service_fn(handle);
73//! # Ok(())
74//! # }
75//! ```
76//!
77//! Or using a closure:
78//!
79//! ```
80//! use tower_http::auth::{AsyncRequireAuthorizationLayer, AsyncAuthorizeRequest};
81//! use http::{Request, Response, StatusCode};
82//! use tower::{Service, ServiceExt, ServiceBuilder, BoxError};
83//! use futures_core::future::BoxFuture;
84//! use http_body_util::Full;
85//! use bytes::Bytes;
86//!
87//! async fn check_auth<B>(request: &Request<B>) -> Option<UserId> {
88//! // ...
89//! # None
90//! }
91//!
92//! #[derive(Debug)]
93//! struct UserId(String);
94//!
95//! async fn handle(request: Request<Full<Bytes>>) -> Result<Response<Full<Bytes>>, BoxError> {
96//! # todo!();
97//! // ...
98//! }
99//!
100//! # #[tokio::main]
101//! # async fn main() -> Result<(), BoxError> {
102//! let service = ServiceBuilder::new()
103//! .layer(AsyncRequireAuthorizationLayer::new(|request: Request<Full<Bytes>>| async move {
104//! if let Some(user_id) = check_auth(&request).await {
105//! Ok(request)
106//! } else {
107//! let unauthorized_response = Response::builder()
108//! .status(StatusCode::UNAUTHORIZED)
109//! .body(Full::<Bytes>::default())
110//! .unwrap();
111//!
112//! Err(unauthorized_response)
113//! }
114//! }))
115//! .service_fn(handle);
116//! # Ok(())
117//! # }
118//! ```
119
120use http::{Request, Response};
121use pin_project_lite::pin_project;
122use std::{
123 future::Future,
124 mem,
125 pin::Pin,
126 task::{ready, Context, Poll},
127};
128use tower_layer::Layer;
129use tower_service::Service;
130
/// Layer that applies [`AsyncRequireAuthorization`], which authorizes every request with a
/// user-supplied [`AsyncAuthorizeRequest`] implementation (for example one that inspects the
/// [`Authorization`] header).
///
/// See the [module docs](crate::auth::async_require_authorization) for an example.
///
/// [`Authorization`]: https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Authorization
#[derive(Debug, Clone)]
pub struct AsyncRequireAuthorizationLayer<T> {
    // The authorizer; a clone of it is handed to every service this layer wraps.
    auth: T,
}
141
142impl<T> AsyncRequireAuthorizationLayer<T> {
143 /// Authorize requests using a custom scheme.
144 pub fn new(auth: T) -> AsyncRequireAuthorizationLayer<T> {
145 Self { auth }
146 }
147}
148
149impl<S, T> Layer<S> for AsyncRequireAuthorizationLayer<T>
150where
151 T: Clone,
152{
153 type Service = AsyncRequireAuthorization<S, T>;
154
155 fn layer(&self, inner: S) -> Self::Service {
156 AsyncRequireAuthorization::new(inner, self.auth.clone())
157 }
158}
159
/// Middleware that authorizes every request with an [`AsyncAuthorizeRequest`] implementation
/// (for example one that checks the [`Authorization`] header) before it reaches the inner
/// service.
///
/// Requests the authorizer accepts are forwarded to the inner service. For rejected requests
/// the authorizer's response is returned directly and the inner service is not called.
///
/// See the [module docs](crate::auth::async_require_authorization) for an example.
///
/// [`Authorization`]: https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Authorization
#[derive(Clone, Debug)]
pub struct AsyncRequireAuthorization<S, T> {
    // The wrapped service that handles authorized requests.
    inner: S,
    // The authorizer consulted for every request.
    auth: T,
}
170
// NOTE(review): `define_inner_service_accessors!` is a crate-local macro defined outside this
// file; presumably it generates accessors for the wrapped `inner` service (e.g. `get_ref`,
// `get_mut`, `into_inner`) — confirm against its definition.
impl<S, T> AsyncRequireAuthorization<S, T> {
    define_inner_service_accessors!();
}
174
175impl<S, T> AsyncRequireAuthorization<S, T> {
176 /// Authorize requests using a custom scheme.
177 ///
178 /// The `Authorization` header is required to have the value provided.
179 pub fn new(inner: S, auth: T) -> AsyncRequireAuthorization<S, T> {
180 Self { inner, auth }
181 }
182
183 /// Returns a new [`Layer`] that wraps services with an [`AsyncRequireAuthorizationLayer`]
184 /// middleware.
185 ///
186 /// [`Layer`]: tower_layer::Layer
187 pub fn layer(auth: T) -> AsyncRequireAuthorizationLayer<T> {
188 AsyncRequireAuthorizationLayer::new(auth)
189 }
190}
191
192impl<ReqBody, ResBody, S, Auth> Service<Request<ReqBody>> for AsyncRequireAuthorization<S, Auth>
193where
194 Auth: AsyncAuthorizeRequest<ReqBody, ResponseBody = ResBody>,
195 S: Service<Request<Auth::RequestBody>, Response = Response<ResBody>> + Clone,
196{
197 type Response = Response<ResBody>;
198 type Error = S::Error;
199 type Future = ResponseFuture<Auth, S, ReqBody>;
200
201 fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
202 self.inner.poll_ready(cx)
203 }
204
205 fn call(&mut self, req: Request<ReqBody>) -> Self::Future {
206 let mut inner = self.inner.clone();
207 let authorize = self.auth.authorize(req);
208 // mem::swap due to https://docs.rs/tower/latest/tower/trait.Service.html#be-careful-when-cloning-inner-services
209 mem::swap(&mut self.inner, &mut inner);
210
211 ResponseFuture {
212 state: State::Authorize { authorize },
213 service: inner,
214 }
215 }
216}
217
pin_project! {
    /// Response future for [`AsyncRequireAuthorization`].
    ///
    /// Polls the authorizer's future first. If it resolves to `Ok(request)`, the request is
    /// passed to the inner service and that service's future is polled to completion. If it
    /// resolves to `Err(response)`, that response is returned and the inner service is never
    /// called.
    pub struct ResponseFuture<Auth, S, ReqBody>
    where
        Auth: AsyncAuthorizeRequest<ReqBody>,
        S: Service<Request<Auth::RequestBody>>,
    {
        // Current stage: authorizing, or awaiting the inner service's response.
        #[pin]
        state: State<Auth::Future, S::Future>,
        // The inner service instance that was driven to readiness; it handles this request.
        service: S,
    }
}
230
// Internal state machine driven by `ResponseFuture::poll`: the request is first authorized and
// then, only if authorization succeeded, forwarded to the inner service.
pin_project! {
    #[project = StateProj]
    enum State<A, SFut> {
        // Waiting on the authorizer's future (`AsyncAuthorizeRequest::Future`).
        Authorize {
            #[pin]
            authorize: A,
        },
        // Authorization succeeded; waiting on the inner service's response future.
        Authorized {
            #[pin]
            fut: SFut,
        },
    }
}
244
245impl<Auth, S, ReqBody, B> Future for ResponseFuture<Auth, S, ReqBody>
246where
247 Auth: AsyncAuthorizeRequest<ReqBody, ResponseBody = B>,
248 S: Service<Request<Auth::RequestBody>, Response = Response<B>>,
249{
250 type Output = Result<Response<B>, S::Error>;
251
252 fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
253 let mut this = self.project();
254
255 loop {
256 match this.state.as_mut().project() {
257 StateProj::Authorize { authorize } => {
258 let auth = ready!(authorize.poll(cx));
259 match auth {
260 Ok(req) => {
261 let fut = this.service.call(req);
262 this.state.set(State::Authorized { fut })
263 }
264 Err(res) => {
265 return Poll::Ready(Ok(res));
266 }
267 };
268 }
269 StateProj::Authorized { fut } => {
270 return fut.poll(cx);
271 }
272 }
273 }
274 }
275}
276
277/// Trait for authorizing requests.
278pub trait AsyncAuthorizeRequest<B> {
279 /// The type of request body returned by `authorize`.
280 ///
281 /// Set this to `B` unless you need to change the request body type.
282 type RequestBody;
283
284 /// The body type used for responses to unauthorized requests.
285 type ResponseBody;
286
287 /// The Future type returned by `authorize`
288 type Future: Future<Output = Result<Request<Self::RequestBody>, Response<Self::ResponseBody>>>;
289
290 /// Authorize the request.
291 ///
292 /// If the future resolves to `Ok(request)` then the request is allowed through, otherwise not.
293 fn authorize(&mut self, request: Request<B>) -> Self::Future;
294}
295
296impl<B, F, Fut, ReqBody, ResBody> AsyncAuthorizeRequest<B> for F
297where
298 F: FnMut(Request<B>) -> Fut,
299 Fut: Future<Output = Result<Request<ReqBody>, Response<ResBody>>>,
300{
301 type RequestBody = ReqBody;
302 type ResponseBody = ResBody;
303 type Future = Fut;
304
305 fn authorize(&mut self, request: Request<B>) -> Self::Future {
306 self(request)
307 }
308}
309
#[cfg(test)]
mod tests {
    #[allow(unused_imports)]
    use super::*;
    use crate::test_helpers::Body;
    use futures_core::future::BoxFuture;
    use http::{header, StatusCode};
    use tower::{BoxError, ServiceBuilder, ServiceExt};

    /// Bearer token accepted by [`MyAuth`].
    const VALID_TOKEN: &str = "69420";

    /// Test authorizer that only accepts `Authorization: Bearer 69420`.
    #[derive(Clone, Copy)]
    struct MyAuth;

    impl<B> AsyncAuthorizeRequest<B> for MyAuth
    where
        B: Send + 'static,
    {
        type RequestBody = B;
        type ResponseBody = Body;
        type Future = BoxFuture<'static, Result<Request<B>, Response<Self::ResponseBody>>>;

        fn authorize(&mut self, request: Request<B>) -> Self::Future {
            Box::pin(async move {
                let token = request
                    .headers()
                    .get(header::AUTHORIZATION)
                    .and_then(|value| value.to_str().ok())
                    .and_then(|value| value.strip_prefix("Bearer "));

                if token != Some(VALID_TOKEN) {
                    return Err(Response::builder()
                        .status(StatusCode::UNAUTHORIZED)
                        .body(Body::empty())
                        .unwrap());
                }

                Ok(request)
            })
        }
    }

    /// Sends `GET /` with the given `Authorization` header value through the middleware
    /// (wrapping [`echo`]) and returns the response status.
    async fn status_for(authorization: &str) -> StatusCode {
        let mut service = ServiceBuilder::new()
            .layer(AsyncRequireAuthorizationLayer::new(MyAuth))
            .service_fn(echo);

        let request = Request::get("/")
            .header(header::AUTHORIZATION, authorization)
            .body(Body::empty())
            .unwrap();

        let response = service.ready().await.unwrap().call(request).await.unwrap();
        response.status()
    }

    #[tokio::test]
    async fn require_async_auth_works() {
        assert_eq!(status_for("Bearer 69420").await, StatusCode::OK);
    }

    #[tokio::test]
    async fn require_async_auth_401() {
        assert_eq!(status_for("Bearer deez").await, StatusCode::UNAUTHORIZED);
    }

    async fn echo(req: Request<Body>) -> Result<Response<Body>, BoxError> {
        Ok(Response::new(req.into_body()))
    }
}