use axum::{
async_trait,
extract::FromRequestParts,
response::{IntoResponse, IntoResponseParts, Response, ResponseParts},
};
use headers::{Header, HeaderMapExt};
use http::request::Parts;
use std::convert::Infallible;
/// Newtype wrapper around a typed header value `T`.
///
/// Within this file the wrapper is implemented as:
///
/// - a request-parts extractor (`FromRequestParts`) for any `T: Header`, which
///   decodes the request header values stored under `T::name()`;
/// - a response component (`IntoResponseParts` and `IntoResponse`), which
///   inserts the wrapped header into the response's header map.
///
/// The inner value is a public tuple field, so it can be destructured directly
/// with the pattern `TypedHeader(value)`.
#[cfg(feature = "typed-header")]
#[derive(Debug, Clone, Copy)]
#[must_use]
pub struct TypedHeader<T>(pub T);
#[async_trait]
impl<T, S> FromRequestParts<S> for TypedHeader<T>
where
T: Header,
S: Send + Sync,
{
type Rejection = TypedHeaderRejection;
async fn from_request_parts(parts: &mut Parts, _state: &S) -> Result<Self, Self::Rejection> {
let mut values = parts.headers.get_all(T::name()).iter();
let is_missing = values.size_hint() == (0, Some(0));
T::decode(&mut values)
.map(Self)
.map_err(|err| TypedHeaderRejection {
name: T::name(),
reason: if is_missing {
TypedHeaderRejectionReason::Missing
} else {
TypedHeaderRejectionReason::Error(err)
},
})
}
}
// Invokes axum_core's `__impl_deref!` macro for `TypedHeader`.
// NOTE(review): presumably this generates `Deref`/`DerefMut` impls targeting the
// wrapped value; the macro is defined outside this file, so confirm there.
axum_core::__impl_deref!(TypedHeader);
impl<T> IntoResponseParts for TypedHeader<T>
where
    T: Header,
{
    type Error = Infallible;

    /// Inserts the wrapped header into the response parts' header map and
    /// returns the parts. This implementation always returns `Ok`.
    fn into_response_parts(self, mut res: ResponseParts) -> Result<ResponseParts, Self::Error> {
        let Self(header) = self;
        res.headers_mut().typed_insert(header);
        Ok(res)
    }
}
impl<T> IntoResponse for TypedHeader<T>
where
    T: Header,
{
    /// Builds a response from `()` and inserts the wrapped header into its
    /// header map.
    fn into_response(self) -> Response {
        let Self(header) = self;
        let mut response = ().into_response();
        response.headers_mut().typed_insert(header);
        response
    }
}
/// Rejection returned when extracting a [`TypedHeader`] from a request fails.
///
/// Records which header was being extracted and why the extraction failed.
/// As a response it is rendered as `400 Bad Request` with the `Display` text
/// of the rejection as the body.
#[cfg(feature = "typed-header")]
#[derive(Debug)]
pub struct TypedHeaderRejection {
/// Name of the header that failed to extract (taken from `T::name()`).
name: &'static http::header::HeaderName,
/// Why the extraction failed.
reason: TypedHeaderRejectionReason,
}
impl TypedHeaderRejection {
pub fn name(&self) -> &http::header::HeaderName {
self.name
}
pub fn reason(&self) -> &TypedHeaderRejectionReason {
&self.reason
}
#[must_use]
pub fn is_missing(&self) -> bool {
self.reason.is_missing()
}
}
/// The reason a [`TypedHeader`] extraction was rejected.
///
/// Marked `#[non_exhaustive]`, so code in other crates must include a wildcard
/// arm when matching on it.
#[cfg(feature = "typed-header")]
#[derive(Debug)]
#[non_exhaustive]
pub enum TypedHeaderRejectionReason {
/// Decoding failed and the request carried no values for the header.
Missing,
/// Decoding failed although values were present; carries the error
/// returned by the header's `decode` implementation.
Error(headers::Error),
}
impl TypedHeaderRejectionReason {
    /// Returns `true` if this reason is [`Missing`](Self::Missing).
    #[must_use]
    pub fn is_missing(&self) -> bool {
        // Exhaustive on purpose: a newly added variant must be classified here.
        match self {
            Self::Missing => true,
            Self::Error(_) => false,
        }
    }
}
impl IntoResponse for TypedHeaderRejection {
    /// Renders the rejection as `400 Bad Request` with its `Display` text as
    /// the response body.
    fn into_response(self) -> Response {
        let body = self.to_string();
        (http::StatusCode::BAD_REQUEST, body).into_response()
    }
}
impl std::fmt::Display for TypedHeaderRejection {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match &self.reason {
TypedHeaderRejectionReason::Missing => {
write!(f, "Header of type `{}` was missing", self.name)
}
TypedHeaderRejectionReason::Error(err) => {
write!(f, "{} ({})", err, self.name)
}
}
}
}
impl std::error::Error for TypedHeaderRejection {
fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
match &self.reason {
TypedHeaderRejectionReason::Error(err) => Some(err),
TypedHeaderRejectionReason::Missing => None,
}
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::test_helpers::*;
    use axum::{routing::get, Router};

    /// Drives the `TypedHeader` extractor through a router with three
    /// requests: both headers present, cookie absent, and user-agent absent.
    #[tokio::test]
    async fn typed_header() {
        // Echoes both extracted headers back in the response body.
        async fn echo_headers(
            TypedHeader(ua): TypedHeader<headers::UserAgent>,
            TypedHeader(cookie_header): TypedHeader<headers::Cookie>,
        ) -> impl IntoResponse {
            let user_agent = ua.as_str();
            let cookies: Vec<_> = cookie_header.iter().collect();
            format!("User-Agent={user_agent:?}, Cookie={cookies:?}")
        }

        let client = TestClient::new(Router::new().route("/", get(echo_headers)));

        // Both headers present; the two `cookie` header lines are combined.
        let body = client
            .get("/")
            .header("user-agent", "foobar")
            .header("cookie", "a=1; b=2")
            .header("cookie", "c=3")
            .await
            .text()
            .await;
        assert_eq!(
            body,
            r#"User-Agent="foobar", Cookie=[("a", "1"), ("b", "2"), ("c", "3")]"#
        );

        // No cookie header: extraction still succeeds with an empty cookie list.
        let body = client
            .get("/")
            .header("user-agent", "foobar")
            .await
            .text()
            .await;
        assert_eq!(body, r#"User-Agent="foobar", Cookie=[]"#);

        // No user-agent header: the response body is the "missing" rejection text.
        let body = client
            .get("/")
            .header("cookie", "a=1")
            .await
            .text()
            .await;
        assert_eq!(body, "Header of type `user-agent` was missing");
    }
}