// axum_extra/typed_header.rs
1use axum::{
4 async_trait,
5 extract::FromRequestParts,
6 response::{IntoResponse, IntoResponseParts, Response, ResponseParts},
7};
8use headers::{Header, HeaderMapExt};
9use http::request::Parts;
10use std::convert::Infallible;
11
/// Newtype wrapper around a typed header value `T`.
///
/// Within this module the wrapper is implemented as a request-parts
/// extractor (for any `T: Header`) and as both [`IntoResponseParts`] and
/// [`IntoResponse`], so the same type is used to read a header from a
/// request and to set one on a response.
///
/// The wrapped value is public and can be taken out by pattern matching,
/// e.g. `TypedHeader(value)`.
#[cfg(feature = "typed-header")]
#[derive(Clone, Copy, Debug)]
#[must_use]
pub struct TypedHeader<T>(pub T);
57
58#[async_trait]
59impl<T, S> FromRequestParts<S> for TypedHeader<T>
60where
61 T: Header,
62 S: Send + Sync,
63{
64 type Rejection = TypedHeaderRejection;
65
66 async fn from_request_parts(parts: &mut Parts, _state: &S) -> Result<Self, Self::Rejection> {
67 let mut values = parts.headers.get_all(T::name()).iter();
68 let is_missing = values.size_hint() == (0, Some(0));
69 T::decode(&mut values)
70 .map(Self)
71 .map_err(|err| TypedHeaderRejection {
72 name: T::name(),
73 reason: if is_missing {
74 TypedHeaderRejectionReason::Missing
76 } else {
77 TypedHeaderRejectionReason::Error(err)
78 },
79 })
80 }
81}
82
// Invokes axum_core's `__impl_deref!` helper macro for `TypedHeader`.
// NOTE(review): presumably this generates `Deref`/`DerefMut` impls that
// target the wrapped value — confirm against the macro definition in axum_core.
83axum_core::__impl_deref!(TypedHeader);
84
85impl<T> IntoResponseParts for TypedHeader<T>
86where
87 T: Header,
88{
89 type Error = Infallible;
90
91 fn into_response_parts(self, mut res: ResponseParts) -> Result<ResponseParts, Self::Error> {
92 res.headers_mut().typed_insert(self.0);
93 Ok(res)
94 }
95}
96
97impl<T> IntoResponse for TypedHeader<T>
98where
99 T: Header,
100{
101 fn into_response(self) -> Response {
102 let mut res = ().into_response();
103 res.headers_mut().typed_insert(self.0);
104 res
105 }
106}
107
/// Rejection returned when extracting a [`TypedHeader`] from a request fails.
///
/// It records which header was being extracted and why the extraction
/// failed; both are exposed through the accessor methods on this type.
#[cfg(feature = "typed-header")]
#[derive(Debug)]
pub struct TypedHeaderRejection {
    /// Name of the header whose extraction failed.
    name: &'static http::header::HeaderName,
    /// Why the extraction failed.
    reason: TypedHeaderRejectionReason,
}
115
116impl TypedHeaderRejection {
117 pub fn name(&self) -> &http::header::HeaderName {
119 self.name
120 }
121
122 pub fn reason(&self) -> &TypedHeaderRejectionReason {
124 &self.reason
125 }
126
127 #[must_use]
131 pub fn is_missing(&self) -> bool {
132 self.reason.is_missing()
133 }
134}
135
/// The reason a [`TypedHeaderRejection`] was produced.
///
/// Marked `#[non_exhaustive]`, so code outside this crate must include a
/// wildcard arm when matching on it.
#[cfg(feature = "typed-header")]
#[derive(Debug)]
#[non_exhaustive]
pub enum TypedHeaderRejectionReason {
    /// Decoding failed and the request carried no value for the header.
    Missing,
    /// The header was present but decoding it failed; wraps the decode error.
    Error(headers::Error),
}
146
147impl TypedHeaderRejectionReason {
148 #[must_use]
152 pub fn is_missing(&self) -> bool {
153 matches!(self, Self::Missing)
154 }
155}
156
157impl IntoResponse for TypedHeaderRejection {
158 fn into_response(self) -> Response {
159 (http::StatusCode::BAD_REQUEST, self.to_string()).into_response()
160 }
161}
162
163impl std::fmt::Display for TypedHeaderRejection {
164 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
165 match &self.reason {
166 TypedHeaderRejectionReason::Missing => {
167 write!(f, "Header of type `{}` was missing", self.name)
168 }
169 TypedHeaderRejectionReason::Error(err) => {
170 write!(f, "{} ({})", err, self.name)
171 }
172 }
173 }
174}
175
176impl std::error::Error for TypedHeaderRejection {
177 fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
178 match &self.reason {
179 TypedHeaderRejectionReason::Error(err) => Some(err),
180 TypedHeaderRejectionReason::Missing => None,
181 }
182 }
183}
184
#[cfg(test)]
mod tests {
    use super::*;
    use crate::test_helpers::*;
    use axum::{routing::get, Router};

    /// End-to-end check of the `TypedHeader` extractor through a router.
    #[tokio::test]
    async fn typed_header() {
        // Echoes the two extracted headers back in the response body.
        async fn handle(
            TypedHeader(user_agent): TypedHeader<headers::UserAgent>,
            TypedHeader(cookies): TypedHeader<headers::Cookie>,
        ) -> impl IntoResponse {
            let user_agent = user_agent.as_str();
            let cookies: Vec<_> = cookies.iter().collect();
            format!("User-Agent={user_agent:?}, Cookie={cookies:?}")
        }

        let client = TestClient::new(Router::new().route("/", get(handle)));

        // Both headers present; the cookie header is sent twice and all
        // pairs are expected to show up in a single decoded list.
        let both_present = client
            .get("/")
            .header("user-agent", "foobar")
            .header("cookie", "a=1; b=2")
            .header("cookie", "c=3")
            .await
            .text()
            .await;
        assert_eq!(
            both_present,
            r#"User-Agent="foobar", Cookie=[("a", "1"), ("b", "2"), ("c", "3")]"#
        );

        // No cookie header: extraction still succeeds with an empty list.
        let without_cookie = client
            .get("/")
            .header("user-agent", "foobar")
            .await
            .text()
            .await;
        assert_eq!(without_cookie, r#"User-Agent="foobar", Cookie=[]"#);

        // No user-agent header: the extractor rejects with the "missing" message.
        let without_user_agent = client.get("/").header("cookie", "a=1").await.text().await;
        assert_eq!(
            without_user_agent,
            "Header of type `user-agent` was missing"
        );
    }
}