1use crate::body::{Body, Bytes, HttpBody};
2use crate::response::{IntoResponse, Response};
3use crate::BoxError;
4use axum_core::extract::{FromRequest, FromRequestParts};
5use futures_util::future::BoxFuture;
6use http::Request;
7use std::{
8 any::type_name,
9 convert::Infallible,
10 fmt,
11 future::Future,
12 marker::PhantomData,
13 pin::Pin,
14 task::{Context, Poll},
15};
16use tower_layer::Layer;
17use tower_service::Service;
18
19pub fn map_request<F, T>(f: F) -> MapRequestLayer<F, (), T> {
118 map_request_with_state((), f)
119}
120
121pub fn map_request_with_state<F, S, T>(state: S, f: F) -> MapRequestLayer<F, S, T> {
160 MapRequestLayer {
161 f,
162 state,
163 _extractor: PhantomData,
164 }
165}
166
167#[must_use]
171pub struct MapRequestLayer<F, S, T> {
172 f: F,
173 state: S,
174 _extractor: PhantomData<fn() -> T>,
175}
176
177impl<F, S, T> Clone for MapRequestLayer<F, S, T>
178where
179 F: Clone,
180 S: Clone,
181{
182 fn clone(&self) -> Self {
183 Self {
184 f: self.f.clone(),
185 state: self.state.clone(),
186 _extractor: self._extractor,
187 }
188 }
189}
190
191impl<S, I, F, T> Layer<I> for MapRequestLayer<F, S, T>
192where
193 F: Clone,
194 S: Clone,
195{
196 type Service = MapRequest<F, S, I, T>;
197
198 fn layer(&self, inner: I) -> Self::Service {
199 MapRequest {
200 f: self.f.clone(),
201 state: self.state.clone(),
202 inner,
203 _extractor: PhantomData,
204 }
205 }
206}
207
208impl<F, S, T> fmt::Debug for MapRequestLayer<F, S, T>
209where
210 S: fmt::Debug,
211{
212 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
213 f.debug_struct("MapRequestLayer")
214 .field("f", &format_args!("{}", type_name::<F>()))
216 .field("state", &self.state)
217 .finish()
218 }
219}
220
/// A middleware that runs an async function on each request before passing the
/// result to an inner service.
///
/// This is the service produced by [`MapRequestLayer`].
pub struct MapRequest<F, S, I, T> {
    // The user-supplied transformation function.
    f: F,
    // The wrapped service that receives the transformed request.
    inner: I,
    // State handed (by reference) to the function's extractors.
    state: S,
    // Records the extractor tuple type `T` without storing a `T`.
    _extractor: PhantomData<fn() -> T>,
}

impl<F, S, I, T> Clone for MapRequest<F, S, I, T>
where
    F: Clone,
    I: Clone,
    S: Clone,
{
    fn clone(&self) -> Self {
        // Implemented by hand so that `T` does not have to be `Clone`.
        let Self {
            f,
            inner,
            state,
            _extractor: marker,
        } = self;
        Self {
            f: f.clone(),
            inner: inner.clone(),
            state: state.clone(),
            _extractor: *marker,
        }
    }
}
246
// Implements `Service` for `MapRequest` when `T` is a tuple of extractor types.
//
// Every tuple element except the last must implement `FromRequestParts<S>` and only
// sees the request head (`parts`). The last element implements `FromRequest<S>` and
// receives the reassembled request, body included. The function's output then
// decides whether a request is forwarded to the inner service or a response is
// returned immediately.
macro_rules! impl_service {
    (
        [$($ty:ident),*], $last:ident
    ) => {
        #[allow(non_snake_case, unused_mut)]
        impl<F, Fut, S, I, B, $($ty,)* $last> Service<Request<B>> for MapRequest<F, S, I, ($($ty,)* $last,)>
        where
            F: FnMut($($ty,)* $last) -> Fut + Clone + Send + 'static,
            $( $ty: FromRequestParts<S> + Send, )*
            $last: FromRequest<S> + Send,
            Fut: Future + Send + 'static,
            // NOTE(review): the function's output and the inner service are typed on the
            // incoming body `B`, while the extractors see the request after
            // `req.map(Body::new)`. Unless `f` builds a `Request<B>` itself, this
            // presumably only type-checks for `B = Body` — confirm whether
            // `IntoMapRequestResult<Body>` / `Service<Request<Body>>` was intended.
            Fut::Output: IntoMapRequestResult<B> + Send + 'static,
            I: Service<Request<B>, Error = Infallible>
                + Clone
                + Send
                + 'static,
            I::Response: IntoResponse,
            I::Future: Send + 'static,
            B: HttpBody<Data = Bytes> + Send + 'static,
            B::Error: Into<BoxError>,
            S: Clone + Send + Sync + 'static,
        {
            type Response = Response;
            type Error = Infallible;
            type Future = ResponseFuture;

            // Readiness is delegated entirely to the inner service.
            fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
                self.inner.poll_ready(cx)
            }

            fn call(&mut self, req: Request<B>) -> Self::Future {
                // Erase the concrete body type so the extractors receive a `Body`.
                let req = req.map(Body::new);

                // `self.inner` is the instance `poll_ready` was driven on: move it into the
                // future and leave a fresh clone behind for the next request.
                let not_ready_inner = self.inner.clone();
                let mut ready_inner = std::mem::replace(&mut self.inner, not_ready_inner);

                let mut f = self.f.clone();
                let state = self.state.clone();
                let (mut parts, body) = req.into_parts();

                let future = Box::pin(async move {
                    // Run the head-only extractors in order; the first rejection becomes
                    // the response.
                    $(
                        let $ty = match $ty::from_request_parts(&mut parts, &state).await {
                            Ok(value) => value,
                            Err(rejection) => return rejection.into_response(),
                        };
                    )*

                    // Reassemble the request so the last extractor can take the body.
                    let req = Request::from_parts(parts, body);

                    let $last = match $last::from_request(req, &state).await {
                        Ok(value) => value,
                        Err(rejection) => return rejection.into_response(),
                    };

                    // `Ok(req)` continues to the inner service; `Err(res)` short-circuits.
                    match f($($ty,)* $last).await.into_map_request_result() {
                        Ok(req) => {
                            ready_inner.call(req).await.into_response()
                        }
                        Err(res) => {
                            res
                        }
                    }
                });

                ResponseFuture {
                    inner: future
                }
            }
        }
    };
}

// `all_the_tuples!` is defined elsewhere; presumably it invokes `impl_service!` once
// per supported extractor-tuple arity — TODO confirm.
all_the_tuples!(impl_service);
321
322impl<F, S, I, T> fmt::Debug for MapRequest<F, S, I, T>
323where
324 S: fmt::Debug,
325 I: fmt::Debug,
326{
327 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
328 f.debug_struct("MapRequest")
329 .field("f", &format_args!("{}", type_name::<F>()))
330 .field("inner", &self.inner)
331 .field("state", &self.state)
332 .finish()
333 }
334}
335
336pub struct ResponseFuture {
338 inner: BoxFuture<'static, Response>,
339}
340
341impl Future for ResponseFuture {
342 type Output = Result<Response, Infallible>;
343
344 fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
345 self.inner.as_mut().poll(cx).map(Ok)
346 }
347}
348
349impl fmt::Debug for ResponseFuture {
350 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
351 f.debug_struct("ResponseFuture").finish()
352 }
353}
354
355mod private {
356 use crate::{http::Request, response::IntoResponse};
357
358 pub trait Sealed<B> {}
359 impl<B, E> Sealed<B> for Result<Request<B>, E> where E: IntoResponse {}
360 impl<B> Sealed<B> for Request<B> {}
361}
362
363pub trait IntoMapRequestResult<B>: private::Sealed<B> {
368 fn into_map_request_result(self) -> Result<Request<B>, Response>;
370}
371
372impl<B, E> IntoMapRequestResult<B> for Result<Request<B>, E>
373where
374 E: IntoResponse,
375{
376 fn into_map_request_result(self) -> Result<Request<B>, Response> {
377 self.map_err(IntoResponse::into_response)
378 }
379}
380
381impl<B> IntoMapRequestResult<B> for Request<B> {
382 fn into_map_request_result(self) -> Result<Request<B>, Response> {
383 Ok(self)
384 }
385}
386
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{routing::get, test_helpers::TestClient, Router};
    use http::{HeaderMap, StatusCode};

    #[crate::test]
    async fn works() {
        // Adds a header that the handler echoes back, proving the mapped request
        // reaches the inner service.
        async fn add_header<B>(mut req: Request<B>) -> Request<B> {
            req.headers_mut().insert("x-foo", "foo".parse().unwrap());
            req
        }

        async fn handler(headers: HeaderMap) -> Response {
            headers["x-foo"]
                .to_str()
                .unwrap()
                .to_owned()
                .into_response()
        }

        let app = Router::new()
            .route("/", get(handler))
            .layer(map_request(add_header));
        let client = TestClient::new(app);

        let res = client.get("/").await;

        assert_eq!(res.text().await, "foo");
    }

    #[crate::test]
    async fn works_for_short_circuiting() {
        // Always fails, so the error response must be returned without ever
        // reaching the handler.
        async fn reject<B>(_req: Request<B>) -> Result<Request<B>, (StatusCode, &'static str)> {
            Err((StatusCode::INTERNAL_SERVER_ERROR, "something went wrong"))
        }

        async fn handler(_headers: HeaderMap) -> Response {
            unreachable!()
        }

        let app = Router::new()
            .route("/", get(handler))
            .layer(map_request(reject));
        let client = TestClient::new(app);

        let res = client.get("/").await;

        assert_eq!(res.status(), StatusCode::INTERNAL_SERVER_ERROR);
        assert_eq!(res.text().await, "something went wrong");
    }
}