axum/extract/state.rs
1use async_trait::async_trait;
2use axum_core::extract::{FromRef, FromRequestParts};
3use http::request::Parts;
4use std::{
5 convert::Infallible,
6 ops::{Deref, DerefMut},
7};
8
9/// Extractor for state.
10///
11/// See ["Accessing state in middleware"][state-from-middleware] for how to
12/// access state in middleware.
13///
14/// [state-from-middleware]: crate::middleware#accessing-state-in-middleware
15///
16/// # With `Router`
17///
18/// ```
19/// use axum::{Router, routing::get, extract::State};
20///
21/// // the application state
22/// //
23/// // here you can put configuration, database connection pools, or whatever
24/// // state you need
25/// //
26/// // see "When states need to implement `Clone`" for more details on why we need
27/// // `#[derive(Clone)]` here.
28/// #[derive(Clone)]
29/// struct AppState {}
30///
31/// let state = AppState {};
32///
33/// // create a `Router` that holds our state
34/// let app = Router::new()
35/// .route("/", get(handler))
36/// // provide the state so the router can access it
37/// .with_state(state);
38///
39/// async fn handler(
40/// // access the state via the `State` extractor
41/// // extracting a state of the wrong type results in a compile error
42/// State(state): State<AppState>,
43/// ) {
44/// // use `state`...
45/// }
46/// # let _: axum::Router = app;
47/// ```
48///
/// Note that `State` is an extractor, so be sure to put it before any body
/// extractors; see ["the order of extractors"][order-of-extractors].
51///
52/// [order-of-extractors]: crate::extract#the-order-of-extractors
53///
54/// ## Combining stateful routers
55///
/// Multiple [`Router`]s can be combined with [`Router::nest`] or [`Router::merge`].
57/// When combining [`Router`]s with one of these methods, the [`Router`]s must have
58/// the same state type. Generally, this can be inferred automatically:
59///
60/// ```
61/// use axum::{Router, routing::get, extract::State};
62///
63/// #[derive(Clone)]
64/// struct AppState {}
65///
66/// let state = AppState {};
67///
68/// // create a `Router` that will be nested within another
69/// let api = Router::new()
70/// .route("/posts", get(posts_handler));
71///
72/// let app = Router::new()
73/// .nest("/api", api)
74/// .with_state(state);
75///
76/// async fn posts_handler(State(state): State<AppState>) {
77/// // use `state`...
78/// }
79/// # let _: axum::Router = app;
80/// ```
81///
82/// However, if you are composing [`Router`]s that are defined in separate scopes,
83/// you may need to annotate the [`State`] type explicitly:
84///
85/// ```
86/// use axum::{Router, routing::get, extract::State};
87///
88/// #[derive(Clone)]
89/// struct AppState {}
90///
91/// fn make_app() -> Router {
92/// let state = AppState {};
93///
94/// Router::new()
95/// .nest("/api", make_api())
96/// .with_state(state) // the outer Router's state is inferred
97/// }
98///
99/// // the inner Router must specify its state type to compose with the
100/// // outer router
101/// fn make_api() -> Router<AppState> {
102/// Router::new()
103/// .route("/posts", get(posts_handler))
104/// }
105///
106/// async fn posts_handler(State(state): State<AppState>) {
107/// // use `state`...
108/// }
109/// # let _: axum::Router = make_app();
110/// ```
111///
112/// In short, a [`Router`]'s generic state type defaults to `()`
113/// (no state) unless [`Router::with_state`] is called or the value
114/// of the generic type is given explicitly.
115///
116/// [`Router`]: crate::Router
117/// [`Router::merge`]: crate::Router::merge
118/// [`Router::nest`]: crate::Router::nest
119/// [`Router::with_state`]: crate::Router::with_state
120///
121/// # With `MethodRouter`
122///
123/// ```
124/// use axum::{routing::get, extract::State};
125///
126/// #[derive(Clone)]
127/// struct AppState {}
128///
129/// let state = AppState {};
130///
131/// let method_router_with_state = get(handler)
132/// // provide the state so the handler can access it
133/// .with_state(state);
134/// # let _: axum::routing::MethodRouter = method_router_with_state;
135///
136/// async fn handler(State(state): State<AppState>) {
137/// // use `state`...
138/// }
139/// ```
140///
141/// # With `Handler`
142///
143/// ```
144/// use axum::{routing::get, handler::Handler, extract::State};
145///
146/// #[derive(Clone)]
147/// struct AppState {}
148///
149/// let state = AppState {};
150///
151/// async fn handler(State(state): State<AppState>) {
152/// // use `state`...
153/// }
154///
155/// // provide the state so the handler can access it
156/// let handler_with_state = handler.with_state(state);
157///
158/// # async {
159/// let listener = tokio::net::TcpListener::bind("0.0.0.0:3000").await.unwrap();
160/// axum::serve(listener, handler_with_state.into_make_service()).await.unwrap();
161/// # };
162/// ```
163///
164/// # Substates
165///
/// [`State`] only allows a single state type, but you can use [`FromRef`] to extract "substates":
167///
168/// ```
169/// use axum::{Router, routing::get, extract::{State, FromRef}};
170///
171/// // the application state
172/// #[derive(Clone)]
173/// struct AppState {
174/// // that holds some api specific state
175/// api_state: ApiState,
176/// }
177///
178/// // the api specific state
179/// #[derive(Clone)]
180/// struct ApiState {}
181///
/// // support converting an `AppState` into an `ApiState`
183/// impl FromRef<AppState> for ApiState {
184/// fn from_ref(app_state: &AppState) -> ApiState {
185/// app_state.api_state.clone()
186/// }
187/// }
188///
189/// let state = AppState {
190/// api_state: ApiState {},
191/// };
192///
193/// let app = Router::new()
194/// .route("/", get(handler))
195/// .route("/api/users", get(api_users))
196/// .with_state(state);
197///
198/// async fn api_users(
199/// // access the api specific state
200/// State(api_state): State<ApiState>,
201/// ) {
202/// }
203///
204/// async fn handler(
/// // we can still access the top level state
206/// State(state): State<AppState>,
207/// ) {
208/// }
209/// # let _: axum::Router = app;
210/// ```
211///
212/// For convenience `FromRef` can also be derived using `#[derive(FromRef)]`.
213///
214/// # For library authors
215///
216/// If you're writing a library that has an extractor that needs state, this is the recommended way
217/// to do it:
218///
219/// ```rust
220/// use axum_core::extract::{FromRequestParts, FromRef};
221/// use http::request::Parts;
222/// use async_trait::async_trait;
223/// use std::convert::Infallible;
224///
225/// // the extractor your library provides
226/// struct MyLibraryExtractor;
227///
228/// #[async_trait]
229/// impl<S> FromRequestParts<S> for MyLibraryExtractor
230/// where
231/// // keep `S` generic but require that it can produce a `MyLibraryState`
232/// // this means users will have to implement `FromRef<UserState> for MyLibraryState`
233/// MyLibraryState: FromRef<S>,
234/// S: Send + Sync,
235/// {
236/// type Rejection = Infallible;
237///
238/// async fn from_request_parts(parts: &mut Parts, state: &S) -> Result<Self, Self::Rejection> {
239/// // get a `MyLibraryState` from a reference to the state
240/// let state = MyLibraryState::from_ref(state);
241///
242/// // ...
243/// # todo!()
244/// }
245/// }
246///
247/// // the state your library needs
248/// struct MyLibraryState {
249/// // ...
250/// }
251/// ```
252///
253/// # When states need to implement `Clone`
254///
255/// Your top level state type must implement `Clone` to be extractable with `State`:
256///
257/// ```
258/// use axum::extract::State;
259///
/// // no substates, so to extract `State<AppState>` we must implement `Clone` for `AppState`
261/// #[derive(Clone)]
262/// struct AppState {}
263///
264/// async fn handler(State(state): State<AppState>) {
265/// // ...
266/// }
267/// ```
268///
269/// This works because of [`impl<S> FromRef<S> for S where S: Clone`][`FromRef`].
270///
271/// This is also true if you're extracting substates, unless you _never_ extract the top level
272/// state itself:
273///
274/// ```
275/// use axum::extract::{State, FromRef};
276///
277/// // we never extract `State<AppState>`, just `State<InnerState>`. So `AppState` doesn't need to
278/// // implement `Clone`
279/// struct AppState {
280/// inner: InnerState,
281/// }
282///
283/// #[derive(Clone)]
284/// struct InnerState {}
285///
286/// impl FromRef<AppState> for InnerState {
287/// fn from_ref(app_state: &AppState) -> InnerState {
288/// app_state.inner.clone()
289/// }
290/// }
291///
292/// async fn api_users(State(inner): State<InnerState>) {
293/// // ...
294/// }
295/// ```
296///
/// In general, however, we recommend you implement `Clone` for all your state types to avoid
/// potential type errors.
299///
300/// # Shared mutable state
301///
/// [As state is global within a `Router`][global], you can't directly get a mutable reference to
/// the state.
304///
305/// The most basic solution is to use an `Arc<Mutex<_>>`. Which kind of mutex you need depends on
306/// your use case. See [the tokio docs] for more details.
307///
308/// Note that holding a locked `std::sync::Mutex` across `.await` points will result in `!Send`
309/// futures which are incompatible with axum. If you need to hold a mutex across `.await` points,
310/// consider using a `tokio::sync::Mutex` instead.
311///
312/// ## Example
313///
314/// ```
315/// use axum::{Router, routing::get, extract::State};
316/// use std::sync::{Arc, Mutex};
317///
318/// #[derive(Clone)]
319/// struct AppState {
320/// data: Arc<Mutex<String>>,
321/// }
322///
323/// async fn handler(State(state): State<AppState>) {
324/// {
325/// let mut data = state.data.lock().expect("mutex was poisoned");
326/// *data = "updated foo".to_owned();
327/// }
328///
329/// // ...
330/// }
331///
332/// let state = AppState {
333/// data: Arc::new(Mutex::new("foo".to_owned())),
334/// };
335///
336/// let app = Router::new()
337/// .route("/", get(handler))
338/// .with_state(state);
339/// # let _: Router = app;
340/// ```
341///
342/// [global]: crate::Router::with_state
343/// [the tokio docs]: https://docs.rs/tokio/1.25.0/tokio/sync/struct.Mutex.html#which-kind-of-mutex-should-you-use
// `State` is a plain newtype around the extracted state value. Each derived
// impl forwards to the wrapped `S`, so e.g. `State<S>: Clone` holds when `S: Clone`.
#[derive(Clone, Copy)]
#[derive(Debug, Default)]
pub struct State<S>(pub S);
346
347#[async_trait]
348impl<OuterState, InnerState> FromRequestParts<OuterState> for State<InnerState>
349where
350 InnerState: FromRef<OuterState>,
351 OuterState: Send + Sync,
352{
353 type Rejection = Infallible;
354
355 async fn from_request_parts(
356 _parts: &mut Parts,
357 state: &OuterState,
358 ) -> Result<Self, Self::Rejection> {
359 let inner_state = InnerState::from_ref(state);
360 Ok(Self(inner_state))
361 }
362}
363
364impl<S> Deref for State<S> {
365 type Target = S;
366
367 fn deref(&self) -> &Self::Target {
368 &self.0
369 }
370}
371
372impl<S> DerefMut for State<S> {
373 fn deref_mut(&mut self) -> &mut Self::Target {
374 &mut self.0
375 }
376}