axum_extra/extract/cached.rs
1use axum::{
2 async_trait,
3 extract::{Extension, FromRequestParts},
4};
5use http::request::Parts;
6
/// Cache results of other extractors.
///
/// `Cached` wraps another extractor and caches its result in [request extensions].
///
/// This is useful if you have a tree of extractors that share common sub-extractors that
/// you only want to run once, perhaps because they're expensive.
///
/// The cache is purely type-based, so you can only cache one value of each type. The cache is
/// also local to the current request and not reused across requests.
///
/// Every cache hit returns a clone of the stored value, which is why `T` must implement
/// `Clone`. If the wrapped extractor fails, its rejection is returned unchanged and nothing is
/// cached, so a later `Cached<T>` in the same request will run the extractor again.
///
/// # Example
///
/// ```rust
/// use axum_extra::extract::Cached;
/// use axum::{
///     async_trait,
///     extract::FromRequestParts,
///     response::{IntoResponse, Response},
///     http::{StatusCode, request::Parts},
/// };
///
/// #[derive(Clone)]
/// struct Session { /* ... */ }
///
/// #[async_trait]
/// impl<S> FromRequestParts<S> for Session
/// where
///     S: Send + Sync,
/// {
///     type Rejection = (StatusCode, String);
///
///     async fn from_request_parts(parts: &mut Parts, state: &S) -> Result<Self, Self::Rejection> {
///         // load session...
///         # unimplemented!()
///     }
/// }
///
/// struct CurrentUser { /* ... */ }
///
/// #[async_trait]
/// impl<S> FromRequestParts<S> for CurrentUser
/// where
///     S: Send + Sync,
/// {
///     type Rejection = Response;
///
///     async fn from_request_parts(parts: &mut Parts, state: &S) -> Result<Self, Self::Rejection> {
///         // loading a `CurrentUser` requires first loading the `Session`
///         //
///         // by using `Cached<Session>` we avoid extracting the session more than
///         // once, in case other extractors for the same request also load the session
///         let session: Session = Cached::<Session>::from_request_parts(parts, state)
///             .await
///             .map_err(|err| err.into_response())?
///             .0;
///
///         // load user from session...
///         # unimplemented!()
///     }
/// }
///
/// // handler that extracts the current user and the session
/// //
/// // the session will only be loaded once, even though `CurrentUser`
/// // also loads it
/// async fn handler(
///     current_user: CurrentUser,
///     // we have to use `Cached<Session>` here, otherwise the
///     // cached session would not be used
///     Cached(session): Cached<Session>,
/// ) {
///     // ...
/// }
/// ```
///
/// [request extensions]: http::Extensions
#[derive(Debug, Clone, Default)]
pub struct Cached<T>(pub T);
85
/// Private wrapper under which a [`Cached`] value is stored in the request extensions.
///
/// Request extensions are keyed by type, so storing the value as `CachedEntry<T>` rather
/// than as a bare `T` keeps the cache separate from any plain `T` that other code may insert
/// into, or look up from, the same extensions.
#[derive(Clone)]
struct CachedEntry<T>(T);
88
89#[async_trait]
90impl<S, T> FromRequestParts<S> for Cached<T>
91where
92 S: Send + Sync,
93 T: FromRequestParts<S> + Clone + Send + Sync + 'static,
94{
95 type Rejection = T::Rejection;
96
97 async fn from_request_parts(parts: &mut Parts, state: &S) -> Result<Self, Self::Rejection> {
98 match Extension::<CachedEntry<T>>::from_request_parts(parts, state).await {
99 Ok(Extension(CachedEntry(value))) => Ok(Self(value)),
100 Err(_) => {
101 let value = T::from_request_parts(parts, state).await?;
102 parts.extensions.insert(CachedEntry(value.clone()));
103 Ok(Self(value))
104 }
105 }
106 }
107}
108
109axum_core::__impl_deref!(Cached);
110
#[cfg(test)]
mod tests {
    use super::*;
    use axum::{http::Request, routing::get, Router};
    use std::{
        convert::Infallible,
        sync::atomic::{AtomicU32, Ordering},
        time::Instant,
    };

    #[tokio::test]
    async fn works() {
        // How many times the wrapped extractor has actually been invoked.
        static RUNS: AtomicU32 = AtomicU32::new(0);

        // Records each invocation and yields a timestamp, so two results compare equal
        // only if they came from the same invocation.
        #[derive(Clone, Debug, PartialEq, Eq)]
        struct Stamp(Instant);

        #[async_trait]
        impl<S> FromRequestParts<S> for Stamp
        where
            S: Send + Sync,
        {
            type Rejection = Infallible;

            async fn from_request_parts(
                _parts: &mut Parts,
                _state: &S,
            ) -> Result<Self, Self::Rejection> {
                RUNS.fetch_add(1, Ordering::SeqCst);
                Ok(Self(Instant::now()))
            }
        }

        let mut parts = Request::new(()).into_parts().0;

        let mut stamps = Vec::with_capacity(2);
        for _ in 0..2 {
            let Cached(stamp) = Cached::<Stamp>::from_request_parts(&mut parts, &())
                .await
                .unwrap();
            // The wrapped extractor runs on the first pass only; the second is a cache hit.
            assert_eq!(RUNS.load(Ordering::SeqCst), 1);
            stamps.push(stamp);
        }

        assert_eq!(stamps[0], stamps[1]);
    }

    // Compile-only check (deliberately not a `#[test]`): `Cached<_>` must be accepted as
    // the last handler argument.
    async fn _last_handler_argument() {
        async fn handler(_: http::Method, _: Cached<http::HeaderMap>) {}
        let _r: Router = Router::new().route("/", get(handler));
    }
}
166}