use axum::{
extract::Request,
response::{IntoResponse, Redirect, Response},
routing::{any, MethodRouter},
Router,
};
use http::{uri::PathAndQuery, StatusCode, Uri};
use std::{borrow::Cow, convert::Infallible};
use tower_service::Service;
mod resource;
#[cfg(feature = "typed-routing")]
mod typed;
pub use self::resource::Resource;
#[cfg(feature = "typed-routing")]
pub use self::typed::WithQueryParams;
#[cfg(feature = "typed-routing")]
pub use axum_macros::TypedPath;
#[cfg(feature = "typed-routing")]
pub use self::typed::{SecondElementIs, TypedPath};
/// Extension trait that adds additional methods to [`Router`].
///
/// The supertrait `sealed::Sealed` lives in a private module and is only
/// implemented for [`Router`], so this trait cannot be implemented outside
/// this crate.
pub trait RouterExt<S>: sealed::Sealed {
    /// Add a typed `GET` route to the router.
    ///
    /// The route is registered at `P::PATH`. The bound `T: SecondElementIs<P>`
    /// ties the handler's argument list to the [`TypedPath`] type `P`
    /// (presumably `P` must be the handler's first extractor — see the
    /// `typed` module to confirm).
    #[cfg(feature = "typed-routing")]
    fn typed_get<H, T, P>(self, handler: H) -> Self
    where
        H: axum::handler::Handler<T, S>,
        T: SecondElementIs<P> + 'static,
        P: TypedPath;

    /// Add a typed `DELETE` route to the router.
    ///
    /// See [`RouterExt::typed_get`] for how the path is determined.
    #[cfg(feature = "typed-routing")]
    fn typed_delete<H, T, P>(self, handler: H) -> Self
    where
        H: axum::handler::Handler<T, S>,
        T: SecondElementIs<P> + 'static,
        P: TypedPath;

    /// Add a typed `HEAD` route to the router.
    ///
    /// See [`RouterExt::typed_get`] for how the path is determined.
    #[cfg(feature = "typed-routing")]
    fn typed_head<H, T, P>(self, handler: H) -> Self
    where
        H: axum::handler::Handler<T, S>,
        T: SecondElementIs<P> + 'static,
        P: TypedPath;

    /// Add a typed `OPTIONS` route to the router.
    ///
    /// See [`RouterExt::typed_get`] for how the path is determined.
    #[cfg(feature = "typed-routing")]
    fn typed_options<H, T, P>(self, handler: H) -> Self
    where
        H: axum::handler::Handler<T, S>,
        T: SecondElementIs<P> + 'static,
        P: TypedPath;

    /// Add a typed `PATCH` route to the router.
    ///
    /// See [`RouterExt::typed_get`] for how the path is determined.
    #[cfg(feature = "typed-routing")]
    fn typed_patch<H, T, P>(self, handler: H) -> Self
    where
        H: axum::handler::Handler<T, S>,
        T: SecondElementIs<P> + 'static,
        P: TypedPath;

    /// Add a typed `POST` route to the router.
    ///
    /// See [`RouterExt::typed_get`] for how the path is determined.
    #[cfg(feature = "typed-routing")]
    fn typed_post<H, T, P>(self, handler: H) -> Self
    where
        H: axum::handler::Handler<T, S>,
        T: SecondElementIs<P> + 'static,
        P: TypedPath;

    /// Add a typed `PUT` route to the router.
    ///
    /// See [`RouterExt::typed_get`] for how the path is determined.
    #[cfg(feature = "typed-routing")]
    fn typed_put<H, T, P>(self, handler: H) -> Self
    where
        H: axum::handler::Handler<T, S>,
        T: SecondElementIs<P> + 'static,
        P: TypedPath;

    /// Add a typed `TRACE` route to the router.
    ///
    /// See [`RouterExt::typed_get`] for how the path is determined.
    #[cfg(feature = "typed-routing")]
    fn typed_trace<H, T, P>(self, handler: H) -> Self
    where
        H: axum::handler::Handler<T, S>,
        T: SecondElementIs<P> + 'static,
        P: TypedPath;

    /// Add a route plus an extra "trailing slash redirect" route.
    ///
    /// `method_router` is registered at `path`. A second route is also
    /// registered at `path` with its trailing slash toggled:
    ///
    /// - `"/foo/"` also gets `"/foo"`.
    /// - `"/foo"` also gets `"/foo/"`.
    ///
    /// Requests to that second route, for any HTTP method, receive a
    /// `308 Permanent Redirect`. The redirect targets the request path with
    /// its trailing slash toggled back, and the query string is kept. If the
    /// redirect URI cannot be rebuilt, the response is `400 Bad Request`.
    ///
    /// # Panics
    ///
    /// Panics if `path` is `/`, or if registering either route on the
    /// underlying [`Router`] fails (for example because it conflicts with an
    /// existing route).
    fn route_with_tsr(self, path: &str, method_router: MethodRouter<S>) -> Self
    where
        Self: Sized;

    /// Add a service route plus an extra "trailing slash redirect" route.
    ///
    /// Behaves like [`RouterExt::route_with_tsr`], but registers an
    /// arbitrary [`Service`] at `path` instead of a [`MethodRouter`].
    ///
    /// # Panics
    ///
    /// Panics if `path` is `/`, or if registering either route on the
    /// underlying [`Router`] fails.
    fn route_service_with_tsr<T>(self, path: &str, service: T) -> Self
    where
        T: Service<Request, Error = Infallible> + Clone + Send + 'static,
        T::Response: IntoResponse,
        T::Future: Send + 'static,
        Self: Sized;
}
/// [`RouterExt`] for axum's [`Router`]; each method delegates to
/// [`Router::route`] / [`Router::route_service`].
impl<S> RouterExt<S> for Router<S>
where
    S: Clone + Send + Sync + 'static,
{
    /// Mounts `handler` for `GET` at the typed path `P::PATH`.
    #[cfg(feature = "typed-routing")]
    fn typed_get<H, T, P>(self, handler: H) -> Self
    where
        H: axum::handler::Handler<T, S>,
        T: SecondElementIs<P> + 'static,
        P: TypedPath,
    {
        let method_router = axum::routing::get(handler);
        self.route(P::PATH, method_router)
    }

    /// Mounts `handler` for `DELETE` at the typed path `P::PATH`.
    #[cfg(feature = "typed-routing")]
    fn typed_delete<H, T, P>(self, handler: H) -> Self
    where
        H: axum::handler::Handler<T, S>,
        T: SecondElementIs<P> + 'static,
        P: TypedPath,
    {
        let method_router = axum::routing::delete(handler);
        self.route(P::PATH, method_router)
    }

    /// Mounts `handler` for `HEAD` at the typed path `P::PATH`.
    #[cfg(feature = "typed-routing")]
    fn typed_head<H, T, P>(self, handler: H) -> Self
    where
        H: axum::handler::Handler<T, S>,
        T: SecondElementIs<P> + 'static,
        P: TypedPath,
    {
        let method_router = axum::routing::head(handler);
        self.route(P::PATH, method_router)
    }

    /// Mounts `handler` for `OPTIONS` at the typed path `P::PATH`.
    #[cfg(feature = "typed-routing")]
    fn typed_options<H, T, P>(self, handler: H) -> Self
    where
        H: axum::handler::Handler<T, S>,
        T: SecondElementIs<P> + 'static,
        P: TypedPath,
    {
        let method_router = axum::routing::options(handler);
        self.route(P::PATH, method_router)
    }

    /// Mounts `handler` for `PATCH` at the typed path `P::PATH`.
    #[cfg(feature = "typed-routing")]
    fn typed_patch<H, T, P>(self, handler: H) -> Self
    where
        H: axum::handler::Handler<T, S>,
        T: SecondElementIs<P> + 'static,
        P: TypedPath,
    {
        let method_router = axum::routing::patch(handler);
        self.route(P::PATH, method_router)
    }

    /// Mounts `handler` for `POST` at the typed path `P::PATH`.
    #[cfg(feature = "typed-routing")]
    fn typed_post<H, T, P>(self, handler: H) -> Self
    where
        H: axum::handler::Handler<T, S>,
        T: SecondElementIs<P> + 'static,
        P: TypedPath,
    {
        let method_router = axum::routing::post(handler);
        self.route(P::PATH, method_router)
    }

    /// Mounts `handler` for `PUT` at the typed path `P::PATH`.
    #[cfg(feature = "typed-routing")]
    fn typed_put<H, T, P>(self, handler: H) -> Self
    where
        H: axum::handler::Handler<T, S>,
        T: SecondElementIs<P> + 'static,
        P: TypedPath,
    {
        let method_router = axum::routing::put(handler);
        self.route(P::PATH, method_router)
    }

    /// Mounts `handler` for `TRACE` at the typed path `P::PATH`.
    #[cfg(feature = "typed-routing")]
    fn typed_trace<H, T, P>(self, handler: H) -> Self
    where
        H: axum::handler::Handler<T, S>,
        T: SecondElementIs<P> + 'static,
        P: TypedPath,
    {
        let method_router = axum::routing::trace(handler);
        self.route(P::PATH, method_router)
    }

    /// Validates `path` first, so the `/` panic fires before anything is
    /// registered. Then mounts `method_router` and its redirect twin.
    #[track_caller]
    fn route_with_tsr(self, path: &str, method_router: MethodRouter<S>) -> Self
    where
        Self: Sized,
    {
        validate_tsr_path(path);
        let router = self.route(path, method_router);
        add_tsr_redirect_route(router, path)
    }

    /// Same flow as `route_with_tsr`, but mounts an arbitrary service.
    #[track_caller]
    fn route_service_with_tsr<T>(self, path: &str, service: T) -> Self
    where
        T: Service<Request, Error = Infallible> + Clone + Send + 'static,
        T::Response: IntoResponse,
        T::Future: Send + 'static,
        Self: Sized,
    {
        validate_tsr_path(path);
        let router = self.route_service(path, service);
        add_tsr_redirect_route(router, path)
    }
}
/// Rejects `/` as the base path of a trailing-slash redirect route.
///
/// Stripping the trailing slash from `/` would leave an empty path, so there
/// is no sensible redirect counterpart for the root.
///
/// # Panics
///
/// Panics (reporting the caller's location) if `path` is exactly `/`.
#[track_caller]
fn validate_tsr_path(path: &str) {
    assert!(
        path != "/",
        "Cannot add a trailing slash redirect route for `/`"
    );
}
/// Toggles a single trailing slash on `path`.
///
/// If `path` ends with `/`, the slash is removed (borrowing); otherwise a
/// `/` is appended (allocating).
fn toggle_trailing_slash(path: &str) -> Cow<'_, str> {
    match path.strip_suffix('/') {
        Some(trimmed) => Cow::Borrowed(trimmed),
        None => Cow::Owned(format!("{path}/")),
    }
}

/// Registers, for any HTTP method, a redirect route at `path` with its
/// trailing slash toggled.
///
/// The handler answers with a permanent redirect to the request URI with its
/// trailing slash toggled back, keeping the query string via `map_path`. It
/// answers `400 Bad Request` if that URI cannot be rebuilt.
fn add_tsr_redirect_route<S>(router: Router<S>, path: &str) -> Router<S>
where
    S: Clone + Send + Sync + 'static,
{
    async fn redirect_handler(uri: Uri) -> Response {
        match map_path(uri, toggle_trailing_slash) {
            Some(target) => Redirect::permanent(&target.to_string()).into_response(),
            None => StatusCode::BAD_REQUEST.into_response(),
        }
    }

    let redirect_path = toggle_trailing_slash(path);
    router.route(&redirect_path, any(redirect_handler))
}
/// Rebuilds `original_uri` with its path replaced by `f(path)`.
///
/// The query string (if any) and all other URI parts are carried over
/// unchanged.
///
/// Returns `None` in three cases:
///
/// - the URI has no path-and-query component;
/// - the rewritten path-and-query fails to parse;
/// - the reassembled URI is invalid.
fn map_path<F>(original_uri: Uri, f: F) -> Option<Uri>
where
    F: FnOnce(&str) -> Cow<'_, str>,
{
    let mut parts = original_uri.into_parts();
    let old = parts.path_and_query.as_ref()?;
    let mapped_path = f(old.path());

    let rebuilt: Cow<'_, str> = match old.query() {
        Some(q) => Cow::Owned(format!("{mapped_path}?{q}")),
        None => mapped_path,
    };
    let replacement = rebuilt.parse::<PathAndQuery>().ok()?;

    parts.path_and_query = Some(replacement);
    Uri::from_parts(parts).ok()
}
mod sealed {
pub trait Sealed {}
impl<S> Sealed for axum::Router<S> {}
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::test_helpers::*;
    use axum::{extract::Path, routing::get};

    #[tokio::test]
    async fn test_tsr() {
        let app = Router::new()
            .route_with_tsr("/foo", get(|| async {}))
            .route_with_tsr("/bar/", get(|| async {}));
        let client = TestClient::new(app);

        let res = client.get("/foo").await;
        assert_eq!(res.status(), StatusCode::OK);

        let res = client.get("/foo/").await;
        assert_eq!(res.status(), StatusCode::PERMANENT_REDIRECT);
        assert_eq!(res.headers()["location"], "/foo");

        let res = client.get("/bar/").await;
        assert_eq!(res.status(), StatusCode::OK);

        let res = client.get("/bar").await;
        assert_eq!(res.status(), StatusCode::PERMANENT_REDIRECT);
        assert_eq!(res.headers()["location"], "/bar/");
    }

    #[tokio::test]
    async fn tsr_with_params() {
        let app = Router::new()
            .route_with_tsr(
                "/a/:a",
                get(|Path(param): Path<String>| async move { param }),
            )
            .route_with_tsr(
                "/b/:b/",
                get(|Path(param): Path<String>| async move { param }),
            );
        let client = TestClient::new(app);

        let res = client.get("/a/foo").await;
        assert_eq!(res.status(), StatusCode::OK);
        assert_eq!(res.text().await, "foo");

        let res = client.get("/a/foo/").await;
        assert_eq!(res.status(), StatusCode::PERMANENT_REDIRECT);
        assert_eq!(res.headers()["location"], "/a/foo");

        let res = client.get("/b/foo/").await;
        assert_eq!(res.status(), StatusCode::OK);
        assert_eq!(res.text().await, "foo");

        let res = client.get("/b/foo").await;
        assert_eq!(res.status(), StatusCode::PERMANENT_REDIRECT);
        assert_eq!(res.headers()["location"], "/b/foo/");
    }

    #[tokio::test]
    async fn tsr_maintains_query_params() {
        let app = Router::new().route_with_tsr("/foo", get(|| async {}));
        let client = TestClient::new(app);

        let res = client.get("/foo/?a=a").await;
        assert_eq!(res.status(), StatusCode::PERMANENT_REDIRECT);
        assert_eq!(res.headers()["location"], "/foo?a=a");
    }

    /// The "append a slash" redirect direction must also keep the query.
    #[tokio::test]
    async fn tsr_maintains_query_params_when_adding_slash() {
        let app = Router::new().route_with_tsr("/foo/", get(|| async {}));
        let client = TestClient::new(app);

        let res = client.get("/foo?a=a").await;
        assert_eq!(res.status(), StatusCode::PERMANENT_REDIRECT);
        assert_eq!(res.headers()["location"], "/foo/?a=a");
    }

    #[test]
    #[should_panic = "Cannot add a trailing slash redirect route for `/`"]
    fn tsr_at_root() {
        let _: Router = Router::new().route_with_tsr("/", get(|| async move {}));
    }

    /// `route_service_with_tsr` serves the service and adds the redirect twin.
    #[tokio::test]
    async fn tsr_with_service() {
        // Annotated so the state type is `()`, for which `MethodRouter`
        // implements `Service`.
        let service: MethodRouter = get(|| async { "ok" });
        let app = Router::new().route_service_with_tsr("/foo", service);
        let client = TestClient::new(app);

        let res = client.get("/foo").await;
        assert_eq!(res.status(), StatusCode::OK);
        assert_eq!(res.text().await, "ok");

        let res = client.get("/foo/").await;
        assert_eq!(res.status(), StatusCode::PERMANENT_REDIRECT);
        assert_eq!(res.headers()["location"], "/foo");
    }

    /// `route_service_with_tsr` applies the same root-path validation.
    #[test]
    #[should_panic = "Cannot add a trailing slash redirect route for `/`"]
    fn tsr_service_at_root() {
        let service: MethodRouter = get(|| async move {});
        let _: Router = Router::new().route_service_with_tsr("/", service);
    }
}