1use crate::response::{IntoResponse, Response};
2use axum_core::extract::{FromRequest, FromRequestParts, Request};
3use futures_util::future::BoxFuture;
4use std::{
5 any::type_name,
6 convert::Infallible,
7 fmt,
8 future::Future,
9 marker::PhantomData,
10 pin::Pin,
11 task::{Context, Poll},
12};
13use tower::{util::BoxCloneService, ServiceBuilder};
14use tower_layer::Layer;
15use tower_service::Service;
16
17pub fn from_fn<F, T>(f: F) -> FromFnLayer<F, (), T> {
110 from_fn_with_state((), f)
111}
112
113pub fn from_fn_with_state<F, S, T>(state: S, f: F) -> FromFnLayer<F, S, T> {
158 FromFnLayer {
159 f,
160 state,
161 _extractor: PhantomData,
162 }
163}
164
/// A layer that wraps services in [`FromFn`].
///
/// Built by [`from_fn`] or [`from_fn_with_state`].
#[must_use]
pub struct FromFnLayer<F, S, T> {
    // State handed to the extractors; cloned into each produced service.
    state: S,
    // The async middleware function.
    f: F,
    // Records the extractor tuple type without storing a value of it.
    _extractor: PhantomData<fn() -> T>,
}
176
177impl<F, S, T> Clone for FromFnLayer<F, S, T>
178where
179 F: Clone,
180 S: Clone,
181{
182 fn clone(&self) -> Self {
183 Self {
184 f: self.f.clone(),
185 state: self.state.clone(),
186 _extractor: self._extractor,
187 }
188 }
189}
190
191impl<S, I, F, T> Layer<I> for FromFnLayer<F, S, T>
192where
193 F: Clone,
194 S: Clone,
195{
196 type Service = FromFn<F, S, I, T>;
197
198 fn layer(&self, inner: I) -> Self::Service {
199 FromFn {
200 f: self.f.clone(),
201 state: self.state.clone(),
202 inner,
203 _extractor: PhantomData,
204 }
205 }
206}
207
208impl<F, S, T> fmt::Debug for FromFnLayer<F, S, T>
209where
210 S: fmt::Debug,
211{
212 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
213 f.debug_struct("FromFnLayer")
214 .field("f", &format_args!("{}", type_name::<F>()))
216 .field("state", &self.state)
217 .finish()
218 }
219}
220
/// A middleware service built from an async function.
///
/// Produced by applying a [`FromFnLayer`] to an inner service.
pub struct FromFn<F, S, I, T> {
    // The async middleware function.
    f: F,
    // State handed to the extractors on every call.
    state: S,
    // The wrapped service that `Next` eventually forwards to.
    inner: I,
    // Records the extractor tuple type without storing a value of it.
    _extractor: PhantomData<fn() -> T>,
}
230
231impl<F, S, I, T> Clone for FromFn<F, S, I, T>
232where
233 F: Clone,
234 I: Clone,
235 S: Clone,
236{
237 fn clone(&self) -> Self {
238 Self {
239 f: self.f.clone(),
240 inner: self.inner.clone(),
241 state: self.state.clone(),
242 _extractor: self._extractor,
243 }
244 }
245}
246
// Generates a `Service<Request>` impl for `FromFn`, for one tuple of extractor types.
//
// `$ty` are the leading extractors. They only see the request parts
// (`FromRequestParts`). `$last` is the final extractor. It receives the whole
// reassembled request (`FromRequest`) and therefore may consume the body.
macro_rules! impl_service {
    (
        [$($ty:ident),*], $last:ident
    ) => {
        // `non_snake_case`: the extractor type parameter idents (`$ty`, `$last`) are
        // reused as local variable names in `call`.
        // `unused_mut`: when `$ty` is empty, `parts` is never mutably borrowed.
        #[allow(non_snake_case, unused_mut)]
        impl<F, Fut, Out, S, I, $($ty,)* $last> Service<Request> for FromFn<F, S, I, ($($ty,)* $last,)>
        where
            F: FnMut($($ty,)* $last, Next) -> Fut + Clone + Send + 'static,
            $( $ty: FromRequestParts<S> + Send, )*
            $last: FromRequest<S> + Send,
            Fut: Future<Output = Out> + Send + 'static,
            Out: IntoResponse + 'static,
            I: Service<Request, Error = Infallible>
                + Clone
                + Send
                + 'static,
            I::Response: IntoResponse,
            I::Future: Send + 'static,
            S: Clone + Send + Sync + 'static,
        {
            type Response = Response;
            type Error = Infallible;
            type Future = ResponseFuture;

            // Readiness is delegated entirely to the wrapped service.
            fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
                self.inner.poll_ready(cx)
            }

            fn call(&mut self, req: Request) -> Self::Future {
                // `self.inner` is the instance that `poll_ready` drove to readiness.
                // Move it into the future and leave a fresh clone behind in `self`, so
                // the service handed to `Next` is the one that was actually polled.
                let not_ready_inner = self.inner.clone();
                let ready_inner = std::mem::replace(&mut self.inner, not_ready_inner);

                // Owned copies, so the boxed future below can be `'static`.
                let mut f = self.f.clone();
                let state = self.state.clone();

                let future = Box::pin(async move {
                    let (mut parts, body) = req.into_parts();

                    // Run the leading extractors against the request parts. The first
                    // rejection is converted into the response and short-circuits.
                    $(
                        let $ty = match $ty::from_request_parts(&mut parts, &state).await {
                            Ok(value) => value,
                            Err(rejection) => return rejection.into_response(),
                        };
                    )*

                    // Reassemble the request so the last extractor can take ownership of it.
                    let req = Request::from_parts(parts, body);

                    let $last = match $last::from_request(req, &state).await {
                        Ok(value) => value,
                        Err(rejection) => return rejection.into_response(),
                    };

                    // Type-erase the ready inner service, with its responses mapped through
                    // `IntoResponse`. `Next` then holds one concrete boxed service type
                    // regardless of `I`.
                    let inner = ServiceBuilder::new()
                        .boxed_clone()
                        .map_response(IntoResponse::into_response)
                        .service(ready_inner);
                    let next = Next { inner };

                    f($($ty,)* $last, next).await.into_response()
                });

                ResponseFuture {
                    inner: future
                }
            }
        }
    };
}

// NOTE(review): `all_the_tuples!` is defined elsewhere in the crate. Presumably it
// invokes `impl_service!` once per supported extractor-tuple arity — confirm.
all_the_tuples!(impl_service);
317
318impl<F, S, I, T> fmt::Debug for FromFn<F, S, I, T>
319where
320 S: fmt::Debug,
321 I: fmt::Debug,
322{
323 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
324 f.debug_struct("FromFnLayer")
325 .field("f", &format_args!("{}", type_name::<F>()))
326 .field("inner", &self.inner)
327 .field("state", &self.state)
328 .finish()
329 }
330}
331
332#[derive(Debug, Clone)]
334pub struct Next {
335 inner: BoxCloneService<Request, Response, Infallible>,
336}
337
338impl Next {
339 pub async fn run(mut self, req: Request) -> Response {
341 match self.inner.call(req).await {
342 Ok(res) => res,
343 Err(err) => match err {},
344 }
345 }
346}
347
348impl Service<Request> for Next {
349 type Response = Response;
350 type Error = Infallible;
351 type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;
352
353 fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
354 self.inner.poll_ready(cx)
355 }
356
357 fn call(&mut self, req: Request) -> Self::Future {
358 self.inner.call(req)
359 }
360}
361
362pub struct ResponseFuture {
364 inner: BoxFuture<'static, Response>,
365}
366
367impl Future for ResponseFuture {
368 type Output = Result<Response, Infallible>;
369
370 fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
371 self.inner.as_mut().poll(cx).map(Ok)
372 }
373}
374
375impl fmt::Debug for ResponseFuture {
376 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
377 f.debug_struct("ResponseFuture").finish()
378 }
379}
380
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{body::Body, routing::get, Router};
    use http::{HeaderMap, StatusCode};
    use http_body_util::BodyExt;
    use tower::ServiceExt;

    /// A header inserted by `from_fn` middleware is visible to the downstream handler.
    #[crate::test]
    async fn basic() {
        // Middleware: stamp a header onto the request, then forward it.
        async fn insert_header(mut req: Request, next: Next) -> impl IntoResponse {
            let headers = req.headers_mut();
            headers.insert("x-axum-test", "ok".parse().unwrap());
            next.run(req).await
        }

        // Handler: echo back the value the middleware inserted.
        async fn handle(headers: HeaderMap) -> String {
            let value = &headers["x-axum-test"];
            value.to_str().unwrap().to_owned()
        }

        let app = Router::new()
            .route("/", get(handle))
            .layer(from_fn(insert_header));

        let request = Request::builder().uri("/").body(Body::empty()).unwrap();
        let res = app.oneshot(request).await.unwrap();
        assert_eq!(res.status(), StatusCode::OK);

        let bytes = res.collect().await.unwrap().to_bytes();
        assert_eq!(&bytes[..], b"ok");
    }
}