axum_extra/extract/
cached.rs

1use axum::{
2    async_trait,
3    extract::{Extension, FromRequestParts},
4};
5use http::request::Parts;
6
/// Extractor wrapper that caches the result of another extractor.
///
/// The first time `Cached<T>` is extracted for a request it runs `T` and stores a clone of
/// the result in the [request extensions]. Every later `Cached<T>` extraction on the same
/// request returns a clone of that stored value instead of running `T` again.
///
/// This is useful when a tree of extractors shares common sub-extractors that should only
/// run once, for instance because they are expensive.
///
/// The cache is purely type based, so only one value of each type can be cached. The cache
/// is also local to the current request and is never reused across requests.
///
/// # Example
///
/// ```rust
/// use axum_extra::extract::Cached;
/// use axum::{
///     async_trait,
///     extract::FromRequestParts,
///     response::{IntoResponse, Response},
///     http::{StatusCode, request::Parts},
/// };
///
/// #[derive(Clone)]
/// struct Session { /* ... */ }
///
/// #[async_trait]
/// impl<S> FromRequestParts<S> for Session
/// where
///     S: Send + Sync,
/// {
///     type Rejection = (StatusCode, String);
///
///     async fn from_request_parts(parts: &mut Parts, state: &S) -> Result<Self, Self::Rejection> {
///         // load session...
///         # unimplemented!()
///     }
/// }
///
/// struct CurrentUser { /* ... */ }
///
/// #[async_trait]
/// impl<S> FromRequestParts<S> for CurrentUser
/// where
///     S: Send + Sync,
/// {
///     type Rejection = Response;
///
///     async fn from_request_parts(parts: &mut Parts, state: &S) -> Result<Self, Self::Rejection> {
///         // loading a `CurrentUser` requires first loading the `Session`
///         //
///         // going through `Cached<Session>` means the session is extracted at most
///         // once, even if other extractors for the same request also load it
///         let session: Session = Cached::<Session>::from_request_parts(parts, state)
///             .await
///             .map_err(|err| err.into_response())?
///             .0;
///
///         // load user from session...
///         # unimplemented!()
///     }
/// }
///
/// // handler that extracts both the current user and the session
/// //
/// // the session is only loaded once, even though `CurrentUser`
/// // loads it as well
/// async fn handler(
///     current_user: CurrentUser,
///     // this has to be `Cached<Session>`; a plain `Session` would
///     // bypass the cache and run the extractor again
///     Cached(session): Cached<Session>,
/// ) {
///     // ...
/// }
/// ```
///
/// [request extensions]: http::Extensions
#[derive(Clone, Debug, Default)]
pub struct Cached<T>(pub T);
85
/// Private wrapper under which `Cached<T>` stores its value in the request extensions.
///
/// Because the cache is keyed by type, storing `CachedEntry<T>` rather than a bare `T`
/// keeps the cached value separate from any `T` that other code may have put into the
/// extensions.
#[derive(Clone)]
struct CachedEntry<T>(T);
88
89#[async_trait]
90impl<S, T> FromRequestParts<S> for Cached<T>
91where
92    S: Send + Sync,
93    T: FromRequestParts<S> + Clone + Send + Sync + 'static,
94{
95    type Rejection = T::Rejection;
96
97    async fn from_request_parts(parts: &mut Parts, state: &S) -> Result<Self, Self::Rejection> {
98        match Extension::<CachedEntry<T>>::from_request_parts(parts, state).await {
99            Ok(Extension(CachedEntry(value))) => Ok(Self(value)),
100            Err(_) => {
101                let value = T::from_request_parts(parts, state).await?;
102                parts.extensions.insert(CachedEntry(value.clone()));
103                Ok(Self(value))
104            }
105        }
106    }
107}
108
// NOTE(review): presumably generates `Deref`/`DerefMut` impls so that `Cached<T>` can be
// used like its inner `T` — confirm against the macro definition in `axum_core`.
axum_core::__impl_deref!(Cached);
110
#[cfg(test)]
mod tests {
    use super::*;
    use axum::{http::Request, routing::get, Router};
    use std::{
        convert::Infallible,
        sync::atomic::{AtomicU32, Ordering},
        time::Instant,
    };

    /// Extracting `Cached<T>` twice from the same request must run the wrapped extractor
    /// once and yield the same value both times.
    #[tokio::test]
    async fn works() {
        // Counts how many times the inner extractor actually ran.
        static RUNS: AtomicU32 = AtomicU32::new(0);

        // Records when it was extracted, so a second real extraction would normally
        // produce a different value than the first.
        #[derive(Clone, Debug, PartialEq, Eq)]
        struct Stamp(Instant);

        #[async_trait]
        impl<S> FromRequestParts<S> for Stamp
        where
            S: Send + Sync,
        {
            type Rejection = Infallible;

            async fn from_request_parts(
                _parts: &mut Parts,
                _state: &S,
            ) -> Result<Self, Self::Rejection> {
                RUNS.fetch_add(1, Ordering::SeqCst);
                Ok(Stamp(Instant::now()))
            }
        }

        let (mut parts, _) = Request::new(()).into_parts();

        // First extraction: cache miss, the inner extractor runs.
        let Cached(first) = Cached::<Stamp>::from_request_parts(&mut parts, &())
            .await
            .unwrap();
        assert_eq!(RUNS.load(Ordering::SeqCst), 1);

        // Second extraction: served from the cache, the run count must not change.
        let Cached(second) = Cached::<Stamp>::from_request_parts(&mut parts, &())
            .await
            .unwrap();
        assert_eq!(RUNS.load(Ordering::SeqCst), 1);

        assert_eq!(first, second);
    }

    // Deliberately not a #[test]: it only has to compile, which checks that `Cached<_>`
    // is accepted as the last argument of a handler.
    async fn _last_handler_argument() {
        async fn handler(_: http::Method, _: Cached<http::HeaderMap>) {}
        let _router: Router = Router::new().route("/", get(handler));
    }
}
164        let _r: Router = Router::new().route("/", get(handler));
165    }
166}