1use crate::response::{IntoResponse, Response};
2use axum_core::extract::FromRequestParts;
3use futures_util::future::BoxFuture;
4use http::Request;
5use std::{
6 any::type_name,
7 convert::Infallible,
8 fmt,
9 future::Future,
10 marker::PhantomData,
11 pin::Pin,
12 task::{Context, Poll},
13};
14use tower_layer::Layer;
15use tower_service::Service;
16
17pub fn map_response<F, T>(f: F) -> MapResponseLayer<F, (), T> {
100 map_response_with_state((), f)
101}
102
103pub fn map_response_with_state<F, S, T>(state: S, f: F) -> MapResponseLayer<F, S, T> {
142 MapResponseLayer {
143 f,
144 state,
145 _extractor: PhantomData,
146 }
147}
148
/// A [`Layer`] that wraps services in [`MapResponse`].
///
/// Created with [`map_response`] or [`map_response_with_state`].
#[must_use]
pub struct MapResponseLayer<F, S, T> {
    /// Async function applied to each response.
    f: F,
    /// State passed to the extractors.
    state: S,
    /// Records the extractor tuple type `T` without storing a value of it.
    /// `fn() -> T` is always `Send + Sync`, so the marker does not restrict
    /// the layer's auto traits.
    _extractor: PhantomData<fn() -> T>,
}
158
159impl<F, S, T> Clone for MapResponseLayer<F, S, T>
160where
161 F: Clone,
162 S: Clone,
163{
164 fn clone(&self) -> Self {
165 Self {
166 f: self.f.clone(),
167 state: self.state.clone(),
168 _extractor: self._extractor,
169 }
170 }
171}
172
173impl<S, I, F, T> Layer<I> for MapResponseLayer<F, S, T>
174where
175 F: Clone,
176 S: Clone,
177{
178 type Service = MapResponse<F, S, I, T>;
179
180 fn layer(&self, inner: I) -> Self::Service {
181 MapResponse {
182 f: self.f.clone(),
183 state: self.state.clone(),
184 inner,
185 _extractor: PhantomData,
186 }
187 }
188}
189
190impl<F, S, T> fmt::Debug for MapResponseLayer<F, S, T>
191where
192 S: fmt::Debug,
193{
194 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
195 f.debug_struct("MapResponseLayer")
196 .field("f", &format_args!("{}", type_name::<F>()))
198 .field("state", &self.state)
199 .finish()
200 }
201}
202
/// Middleware service produced by [`MapResponseLayer`] that maps the responses
/// of the wrapped service `I` with an async function.
pub struct MapResponse<F, S, I, T> {
    /// Async function applied to each response.
    f: F,
    /// The wrapped service.
    inner: I,
    /// State passed to the extractors.
    state: S,
    /// Marker for the extractor tuple type `T`; no `T` value is stored.
    _extractor: PhantomData<fn() -> T>,
}
212
213impl<F, S, I, T> Clone for MapResponse<F, S, I, T>
214where
215 F: Clone,
216 I: Clone,
217 S: Clone,
218{
219 fn clone(&self) -> Self {
220 Self {
221 f: self.f.clone(),
222 inner: self.inner.clone(),
223 state: self.state.clone(),
224 _extractor: self._extractor,
225 }
226 }
227}
228
// Generates the `Service` impl for `MapResponse` whose extractor parameter `T`
// is the tuple `($($ty,)*)`. It is invoked once per extractor arity.
macro_rules! impl_service {
    (
        $($ty:ident),*
    ) => {
        // `non_snake_case`: each extracted value is bound to a local named after
        // its type parameter (`let $ty = ...`).
        // `unused_mut`: with zero extractors, `parts` is never borrowed mutably.
        #[allow(non_snake_case, unused_mut)]
        impl<F, Fut, S, I, B, ResBody, $($ty,)*> Service<Request<B>> for MapResponse<F, S, I, ($($ty,)*)>
        where
            F: FnMut($($ty,)* Response<ResBody>) -> Fut + Clone + Send + 'static,
            $( $ty: FromRequestParts<S> + Send, )*
            Fut: Future + Send + 'static,
            Fut::Output: IntoResponse + Send + 'static,
            I: Service<Request<B>, Response = Response<ResBody>, Error = Infallible>
                + Clone
                + Send
                + 'static,
            I::Future: Send + 'static,
            B: Send + 'static,
            ResBody: Send + 'static,
            S: Clone + Send + Sync + 'static,
        {
            type Response = Response;
            type Error = Infallible;
            type Future = ResponseFuture;

            // Readiness is delegated entirely to the wrapped service.
            fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
                self.inner.poll_ready(cx)
            }


            fn call(&mut self, req: Request<B>) -> Self::Future {
                // `poll_ready` was driven on `self.inner`, so move that instance
                // into the future and leave a fresh clone behind for the next
                // `poll_ready` / `call` cycle.
                let not_ready_inner = self.inner.clone();
                let mut ready_inner = std::mem::replace(&mut self.inner, not_ready_inner);

                // Clones so the boxed future can be `'static`. `_state` is
                // underscore-prefixed because it is unused when there are no
                // extractors.
                let mut f = self.f.clone();
                let _state = self.state.clone();
                let (mut parts, body) = req.into_parts();

                let future = Box::pin(async move {
                    // Run the extractors in order; the first rejection becomes
                    // the response and short-circuits the rest.
                    $(
                        let $ty = match $ty::from_request_parts(&mut parts, &_state).await {
                            Ok(value) => value,
                            Err(rejection) => return rejection.into_response(),
                        };
                    )*

                    let req = Request::from_parts(parts, body);

                    match ready_inner.call(req).await {
                        Ok(res) => {
                            // Hand the extracted values and the inner response
                            // to the mapping function.
                            f($($ty,)* res).await.into_response()
                        }
                        // The inner error type is `Infallible`, so this arm can
                        // never run; the empty match proves it to the compiler.
                        Err(err) => match err {}
                    }
                });

                ResponseFuture {
                    inner: future
                }
            }
        }
    };
}
291
// One `Service` impl per extractor arity, from no extractors up to sixteen.
impl_service!();
impl_service!(T1);
impl_service!(T1, T2);
impl_service!(T1, T2, T3);
impl_service!(T1, T2, T3, T4);
impl_service!(T1, T2, T3, T4, T5);
impl_service!(T1, T2, T3, T4, T5, T6);
impl_service!(T1, T2, T3, T4, T5, T6, T7);
impl_service!(T1, T2, T3, T4, T5, T6, T7, T8);
impl_service!(T1, T2, T3, T4, T5, T6, T7, T8, T9);
impl_service!(T1, T2, T3, T4, T5, T6, T7, T8, T9, T10);
impl_service!(T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11);
impl_service!(T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12);
impl_service!(T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12, T13);
impl_service!(T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12, T13, T14);
impl_service!(T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12, T13, T14, T15);
impl_service!(T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12, T13, T14, T15, T16);
309
310impl<F, S, I, T> fmt::Debug for MapResponse<F, S, I, T>
311where
312 S: fmt::Debug,
313 I: fmt::Debug,
314{
315 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
316 f.debug_struct("MapResponse")
317 .field("f", &format_args!("{}", type_name::<F>()))
318 .field("inner", &self.inner)
319 .field("state", &self.state)
320 .finish()
321 }
322}
323
/// Response future for [`MapResponse`].
pub struct ResponseFuture {
    /// Boxed future that runs the extractors, the wrapped service and the
    /// mapping function; its output is always a `Response`, never an error.
    inner: BoxFuture<'static, Response>,
}
328
329impl Future for ResponseFuture {
330 type Output = Result<Response, Infallible>;
331
332 fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
333 self.inner.as_mut().poll(cx).map(Ok)
334 }
335}
336
337impl fmt::Debug for ResponseFuture {
338 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
339 f.debug_struct("ResponseFuture").finish()
340 }
341}
342
#[cfg(test)]
mod tests {
    #[allow(unused_imports)]
    use super::*;
    use crate::{test_helpers::TestClient, Router};

    #[crate::test]
    async fn works() {
        // Mapping function with no extractors: tags every response with a header.
        async fn insert_foo_header<B>(mut response: Response<B>) -> Response<B> {
            response
                .headers_mut()
                .insert("x-foo", "foo".parse().unwrap());
            response
        }

        let client = TestClient::new(Router::new().layer(map_response(insert_foo_header)));

        let response = client.get("/").await;
        assert_eq!(response.headers()["x-foo"], "foo");
    }
}