axum_extra/
typed_header.rs

1//! Extractor and response for typed headers.
2
3use axum::{
4    async_trait,
5    extract::FromRequestParts,
6    response::{IntoResponse, IntoResponseParts, Response, ResponseParts},
7};
8use headers::{Header, HeaderMapExt};
9use http::request::Parts;
10use std::convert::Infallible;
11
/// A single strongly typed header, usable both as an extractor and as part of a response.
///
/// The header type `T` is one of the typed header values from [`headers`].
///
/// # As extractor
///
/// Prefer asking for exactly the headers a handler needs through `TypedHeader` over taking
/// the complete header set with the `HeaderMap` extractor.
///
/// Extraction is rejected with [`TypedHeaderRejection`] when the header cannot be decoded;
/// the rejection records whether the header was absent from the request altogether.
///
/// ```rust,no_run
/// use axum::{
///     routing::get,
///     Router,
/// };
/// use headers::UserAgent;
/// use axum_extra::TypedHeader;
///
/// async fn users_teams_show(
///     TypedHeader(user_agent): TypedHeader<UserAgent>,
/// ) {
///     // ...
/// }
///
/// let app = Router::new().route("/users/:user_id/team/:team_id", get(users_teams_show));
/// # let _: Router = app;
/// ```
///
/// # As response
///
/// Returning a `TypedHeader` from a handler, on its own or as a leading element of a
/// response tuple, inserts the wrapped header into the response headers.
///
/// ```rust
/// use axum::{
///     response::IntoResponse,
/// };
/// use headers::ContentType;
/// use axum_extra::TypedHeader;
///
/// async fn handler() -> (TypedHeader<ContentType>, &'static str) {
///     (
///         TypedHeader(ContentType::text_utf8()),
///         "Hello, World!",
///     )
/// }
/// ```
#[cfg(feature = "typed-header")]
#[derive(Debug, Clone, Copy)]
#[must_use]
pub struct TypedHeader<T>(pub T);
57
58#[async_trait]
59impl<T, S> FromRequestParts<S> for TypedHeader<T>
60where
61    T: Header,
62    S: Send + Sync,
63{
64    type Rejection = TypedHeaderRejection;
65
66    async fn from_request_parts(parts: &mut Parts, _state: &S) -> Result<Self, Self::Rejection> {
67        let mut values = parts.headers.get_all(T::name()).iter();
68        let is_missing = values.size_hint() == (0, Some(0));
69        T::decode(&mut values)
70            .map(Self)
71            .map_err(|err| TypedHeaderRejection {
72                name: T::name(),
73                reason: if is_missing {
74                    // Report a more precise rejection for the missing header case.
75                    TypedHeaderRejectionReason::Missing
76                } else {
77                    TypedHeaderRejectionReason::Error(err)
78                },
79            })
80    }
81}
82
// NOTE(review): presumably expands to `Deref`/`DerefMut` impls exposing the wrapped header
// value; the macro lives in `axum_core`, not in this file — confirm against its definition.
83axum_core::__impl_deref!(TypedHeader);
84
85impl<T> IntoResponseParts for TypedHeader<T>
86where
87    T: Header,
88{
89    type Error = Infallible;
90
91    fn into_response_parts(self, mut res: ResponseParts) -> Result<ResponseParts, Self::Error> {
92        res.headers_mut().typed_insert(self.0);
93        Ok(res)
94    }
95}
96
97impl<T> IntoResponse for TypedHeader<T>
98where
99    T: Header,
100{
101    fn into_response(self) -> Response {
102        let mut res = ().into_response();
103        res.headers_mut().typed_insert(self.0);
104        res
105    }
106}
107
108/// Rejection used for [`TypedHeader`].
109#[cfg(feature = "typed-header")]
110#[derive(Debug)]
111pub struct TypedHeaderRejection {
112    name: &'static http::header::HeaderName,
113    reason: TypedHeaderRejectionReason,
114}
115
116impl TypedHeaderRejection {
117    /// Name of the header that caused the rejection
118    pub fn name(&self) -> &http::header::HeaderName {
119        self.name
120    }
121
122    /// Reason why the header extraction has failed
123    pub fn reason(&self) -> &TypedHeaderRejectionReason {
124        &self.reason
125    }
126
127    /// Returns `true` if the typed header rejection reason is [`Missing`].
128    ///
129    /// [`Missing`]: TypedHeaderRejectionReason::Missing
130    #[must_use]
131    pub fn is_missing(&self) -> bool {
132        self.reason.is_missing()
133    }
134}
135
136/// Additional information regarding a [`TypedHeaderRejection`]
137#[cfg(feature = "typed-header")]
138#[derive(Debug)]
139#[non_exhaustive]
140pub enum TypedHeaderRejectionReason {
141    /// The header was missing from the HTTP request
142    Missing,
143    /// An error occurred when parsing the header from the HTTP request
144    Error(headers::Error),
145}
146
147impl TypedHeaderRejectionReason {
148    /// Returns `true` if the typed header rejection reason is [`Missing`].
149    ///
150    /// [`Missing`]: TypedHeaderRejectionReason::Missing
151    #[must_use]
152    pub fn is_missing(&self) -> bool {
153        matches!(self, Self::Missing)
154    }
155}
156
157impl IntoResponse for TypedHeaderRejection {
158    fn into_response(self) -> Response {
159        (http::StatusCode::BAD_REQUEST, self.to_string()).into_response()
160    }
161}
162
163impl std::fmt::Display for TypedHeaderRejection {
164    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
165        match &self.reason {
166            TypedHeaderRejectionReason::Missing => {
167                write!(f, "Header of type `{}` was missing", self.name)
168            }
169            TypedHeaderRejectionReason::Error(err) => {
170                write!(f, "{} ({})", err, self.name)
171            }
172        }
173    }
174}
175
176impl std::error::Error for TypedHeaderRejection {
177    fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
178        match &self.reason {
179            TypedHeaderRejectionReason::Error(err) => Some(err),
180            TypedHeaderRejectionReason::Missing => None,
181        }
182    }
183}
184
#[cfg(test)]
mod tests {
    use super::*;
    use crate::test_helpers::*;
    use axum::{routing::get, Router};

    /// Exercises `TypedHeader` extraction with both headers set, with only `user-agent`
    /// set, and with `user-agent` missing.
    #[tokio::test]
    async fn typed_header() {
        async fn handle(
            TypedHeader(user_agent): TypedHeader<headers::UserAgent>,
            TypedHeader(cookies): TypedHeader<headers::Cookie>,
        ) -> impl IntoResponse {
            format!(
                "User-Agent={user_agent:?}, Cookie={cookies:?}",
                user_agent = user_agent.as_str(),
                cookies = cookies.iter().collect::<Vec<_>>(),
            )
        }

        let client = TestClient::new(Router::new().route("/", get(handle)));

        // Both headers present: values from every `cookie` header line are collected.
        let body = client
            .get("/")
            .header("user-agent", "foobar")
            .header("cookie", "a=1; b=2")
            .header("cookie", "c=3")
            .await
            .text()
            .await;
        assert_eq!(
            body,
            r#"User-Agent="foobar", Cookie=[("a", "1"), ("b", "2"), ("c", "3")]"#
        );

        // No `cookie` header: extraction still succeeds and yields no cookies.
        let body = client
            .get("/")
            .header("user-agent", "foobar")
            .await
            .text()
            .await;
        assert_eq!(body, r#"User-Agent="foobar", Cookie=[]"#);

        // No `user-agent` header: the request is rejected with the "missing" message.
        let body = client.get("/").header("cookie", "a=1").await.text().await;
        assert_eq!(body, "Header of type `user-agent` was missing");
    }
}