// axum/extract/request_parts.rs

1use super::{Extension, FromRequestParts};
2use async_trait::async_trait;
3use http::{request::Parts, Uri};
4use std::convert::Infallible;
5
/// Extractor for the request URI as it was first received, before any nesting
/// stripped a path prefix from it.
///
/// Extracting a plain [`Uri`](http::Uri) inside a nested service gives you the
/// URI with the nesting prefix already removed. Reach for `OriginalUri` when you
/// need the full, unstripped URI instead.
///
/// # Example
///
/// ```
/// use axum::{
///     extract::OriginalUri,
///     http::Uri,
///     routing::get,
///     Router,
/// };
///
/// async fn handler(uri: Uri, OriginalUri(original_uri): OriginalUri) {
///     // Here `uri` is `/users`, whereas `original_uri` is `/api/users`.
/// }
///
/// let api_routes = Router::new().route("/users", get(handler));
///
/// let app = Router::new().nest("/api", api_routes);
/// # let _: Router = app;
/// ```
///
/// # Extracting via request extensions
///
/// Middleware can also read `OriginalUri` straight out of the request
/// extensions. One use for this is giving a [`Trace`](tower_http::trace::Trace)
/// span the complete path, even when the traced service ends up nested:
///
/// ```
/// use axum::{
///     extract::OriginalUri,
///     http::Request,
///     routing::get,
///     Router,
/// };
/// use tower_http::trace::TraceLayer;
///
/// let api_routes = Router::new()
///     .route("/users/:id", get(|| async { /* ... */ }))
///     .layer(
///         TraceLayer::new_for_http().make_span_with(|req: &Request<_>| {
///             let path = match req.extensions().get::<OriginalUri>() {
///                 // The full path, `/api` prefix included.
///                 Some(OriginalUri(original)) => original.path().to_owned(),
///                 // When going through `Router` the extension is always there,
///                 // so this arm only runs if some other extractor or
///                 // middleware took it out.
///                 None => req.uri().path().to_owned(),
///             };
///             tracing::info_span!("http-request", %path)
///         }),
///     );
///
/// let app = Router::new().nest("/api", api_routes);
/// # let _: Router = app;
/// ```
#[cfg(feature = "original-uri")]
#[derive(Debug, Clone)]
pub struct OriginalUri(pub Uri);
71
#[cfg(feature = "original-uri")]
#[async_trait]
impl<S> FromRequestParts<S> for OriginalUri
where
    S: Send + Sync,
{
    // Extraction cannot fail: a missing extension falls back to `parts.uri`.
    type Rejection = Infallible;

    /// Produces the `OriginalUri` stored in the request extensions.
    ///
    /// If no such extension is present, the URI currently on the request is
    /// returned instead.
    async fn from_request_parts(parts: &mut Parts, state: &S) -> Result<Self, Self::Rejection> {
        let from_extensions = Extension::<Self>::from_request_parts(parts, state).await;

        let original_uri = match from_extensions {
            Ok(Extension(stored)) => stored,
            // Absent (not routed through `Router`, or removed by another layer):
            // the current URI is the best remaining approximation. It is only
            // cloned on this path.
            Err(_) => OriginalUri(parts.uri.clone()),
        };

        Ok(original_uri)
    }
}

// Let `OriginalUri` be dereferenced to the wrapped `Uri` via axum_core's
// internal helper macro (per its name, a `Deref` impl; see axum_core for the
// exact traits generated).
#[cfg(feature = "original-uri")]
axum_core::__impl_deref!(OriginalUri: Uri);
91
#[cfg(test)]
mod tests {
    use crate::{extract::Extension, routing::get, test_helpers::*, Router};
    use http::{Method, StatusCode};

    /// `http::request::Parts` extracted as a whole exposes method, URI,
    /// version, headers and extensions added by layers.
    #[crate::test]
    async fn extract_request_parts() {
        #[derive(Clone)]
        struct Ext;

        async fn handler(parts: http::request::Parts) {
            assert_eq!(parts.method, Method::GET);
            assert_eq!(parts.uri, "/");
            assert_eq!(parts.version, http::Version::HTTP_11);
            assert_eq!(parts.headers["x-foo"], "123");
            assert!(
                parts.extensions.get::<Ext>().is_some(),
                "extension inserted by the `Extension` layer should be visible in `Parts`"
            );
        }

        let client = TestClient::new(Router::new().route("/", get(handler)).layer(Extension(Ext)));

        let res = client.get("/").header("x-foo", "123").await;
        assert_eq!(res.status(), StatusCode::OK);
    }

    /// Inside a nested router, `Uri` has the nesting prefix stripped while
    /// `OriginalUri` keeps the path exactly as the client requested it.
    #[cfg(feature = "original-uri")]
    #[crate::test]
    async fn original_uri_keeps_nest_prefix() {
        use super::OriginalUri;
        use http::Uri;

        let api = Router::new().route(
            "/users",
            get(|uri: Uri, OriginalUri(original_uri): OriginalUri| async move {
                format!("{uri} {original_uri}")
            }),
        );
        let client = TestClient::new(Router::new().nest("/api", api));

        let res = client.get("/api/users").await;
        assert_eq!(res.status(), StatusCode::OK);
        assert_eq!(res.text().await, "/users /api/users");
    }
}