1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166
use axum::{
async_trait,
extract::{Extension, FromRequestParts},
};
use http::request::Parts;
/// Cache results of other extractors.
///
/// `Cached` wraps another extractor and caches its result in [request extensions].
///
/// This is useful if you have a tree of extractors that share common sub-extractors that
/// you only want to run once, perhaps because they're expensive.
///
/// The cache is purely type-based, so you can only cache one value of each type. The cache is
/// also local to the current request and not reused across requests.
///
/// `T` must implement [`Clone`] because a copy of the extracted value is stored in the request
/// extensions. If the wrapped extractor rejects the request, its rejection is returned as-is
/// and nothing is cached, so a later `Cached<T>` on the same request runs the extractor again.
///
/// # Example
///
/// ```rust
/// use axum_extra::extract::Cached;
/// use axum::{
///     async_trait,
///     extract::FromRequestParts,
///     response::{IntoResponse, Response},
///     http::{StatusCode, request::Parts},
/// };
///
/// #[derive(Clone)]
/// struct Session { /* ... */ }
///
/// #[async_trait]
/// impl<S> FromRequestParts<S> for Session
/// where
///     S: Send + Sync,
/// {
///     type Rejection = (StatusCode, String);
///
///     async fn from_request_parts(parts: &mut Parts, state: &S) -> Result<Self, Self::Rejection> {
///         // load session...
///         # unimplemented!()
///     }
/// }
///
/// struct CurrentUser { /* ... */ }
///
/// #[async_trait]
/// impl<S> FromRequestParts<S> for CurrentUser
/// where
///     S: Send + Sync,
/// {
///     type Rejection = Response;
///
///     async fn from_request_parts(parts: &mut Parts, state: &S) -> Result<Self, Self::Rejection> {
///         // loading a `CurrentUser` requires first loading the `Session`
///         //
///         // by using `Cached<Session>` we avoid extracting the session more than
///         // once, in case other extractors for the same request also load the session
///         let session: Session = Cached::<Session>::from_request_parts(parts, state)
///             .await
///             .map_err(|err| err.into_response())?
///             .0;
///
///         // load user from session...
///         # unimplemented!()
///     }
/// }
///
/// // handler that extracts the current user and the session
/// //
/// // the session will only be loaded once, even though `CurrentUser`
/// // also loads it
/// async fn handler(
///     current_user: CurrentUser,
///     // we have to use `Cached<Session>` here otherwise the
///     // cached session would not be used
///     Cached(session): Cached<Session>,
/// ) {
///     // ...
/// }
/// ```
///
/// [request extensions]: http::Extensions
#[derive(Debug, Clone, Default)]
pub struct Cached<T>(pub T);
/// Private newtype under which `Cached<T>` stores extracted values in the request extensions.
///
/// Because this type is not `pub`, only code in this module can insert or look up these
/// entries. A plain `T` that other code places in the extensions therefore lives under a
/// different type key and is neither read nor overwritten by the cache.
///
/// NOTE(review): `Clone` is presumably required because the `Extension` extractor clones
/// values out of the extensions on lookup — confirm against axum's `Extension` bounds.
#[derive(Clone)]
struct CachedEntry<T>(T);
#[async_trait]
impl<S, T> FromRequestParts<S> for Cached<T>
where
    S: Send + Sync,
    T: FromRequestParts<S> + Clone + Send + Sync + 'static,
{
    type Rejection = T::Rejection;

    /// Returns the `T` cached earlier for this request, or runs `T`'s own extractor on a miss.
    ///
    /// On a miss, a clone of the freshly extracted value is stored in the request extensions
    /// before it is returned. A rejection from `T` is propagated unchanged and nothing is
    /// stored, so a later `Cached<T>` on the same request will try `T` again.
    async fn from_request_parts(parts: &mut Parts, state: &S) -> Result<Self, Self::Rejection> {
        // Cache hit: an earlier `Cached<T>` already ran `T` for this request.
        if let Ok(Extension(CachedEntry(hit))) =
            Extension::<CachedEntry<T>>::from_request_parts(parts, state).await
        {
            return Ok(Self(hit));
        }

        // Cache miss: run the wrapped extractor. `?` returns before the insert below,
        // so failed extractions are never cached.
        let fresh = T::from_request_parts(parts, state).await?;
        parts.extensions.insert(CachedEntry(fresh.clone()));
        Ok(Self(fresh))
    }
}
// NOTE(review): this axum_core helper macro presumably generates `Deref`/`DerefMut`
// from `Cached<T>` to its inner `T`, so the wrapper can be used like the value it
// holds — confirm against the macro's definition in axum_core.
axum_core::__impl_deref!(Cached);
#[cfg(test)]
mod tests {
    use super::*;
    use axum::{http::Request, routing::get, Router};
    use std::{
        convert::Infallible,
        sync::atomic::{AtomicU32, Ordering},
        time::Instant,
    };

    /// Extracting `Cached<T>` twice on the same request must run `T`'s extractor
    /// exactly once and yield equal values both times.
    #[tokio::test]
    async fn works() {
        // How many times the inner extractor actually ran.
        static RUNS: AtomicU32 = AtomicU32::new(0);

        #[derive(Clone, Debug, PartialEq, Eq)]
        struct Extractor(Instant);

        #[async_trait]
        impl<S> FromRequestParts<S> for Extractor
        where
            S: Send + Sync,
        {
            type Rejection = Infallible;

            async fn from_request_parts(
                _parts: &mut Parts,
                _state: &S,
            ) -> Result<Self, Self::Rejection> {
                RUNS.fetch_add(1, Ordering::SeqCst);
                Ok(Self(Instant::now()))
            }
        }

        // Extract `Cached<Extractor>` with unit state and unwrap to the inner value.
        async fn extract_cached(parts: &mut Parts) -> Extractor {
            let Cached(value) = Cached::<Extractor>::from_request_parts(parts, &())
                .await
                .unwrap();
            value
        }

        let (mut parts, _) = Request::new(()).into_parts();

        let first = extract_cached(&mut parts).await;
        assert_eq!(RUNS.load(Ordering::SeqCst), 1);

        // Second extraction must be served from the cache.
        let second = extract_cached(&mut parts).await;
        assert_eq!(RUNS.load(Ordering::SeqCst), 1);

        assert_eq!(first, second);
    }

    // Compile-only check, deliberately not a #[test]: a handler whose last argument is
    // `Cached<_>` must be routable with `get`.
    async fn _last_handler_argument() {
        async fn handler(_: http::Method, _: Cached<http::HeaderMap>) {}
        let _r: Router = Router::new().route("/", get(handler));
    }
}