1use crate::{
2 extract::FromRequestParts,
3 response::{IntoResponse, Response},
4};
5use futures_util::future::BoxFuture;
6use http::Request;
7use pin_project_lite::pin_project;
8use std::{
9 fmt,
10 future::Future,
11 marker::PhantomData,
12 pin::Pin,
13 task::{ready, Context, Poll},
14};
15use tower_layer::Layer;
16use tower_service::Service;
17
18pub fn from_extractor<E>() -> FromExtractorLayer<E, ()> {
90 from_extractor_with_state(())
91}
92
93pub fn from_extractor_with_state<E, S>(state: S) -> FromExtractorLayer<E, S> {
97 FromExtractorLayer {
98 state,
99 _marker: PhantomData,
100 }
101}
102
/// [`Layer`] that wraps services in [`FromExtractor`], which runs the extractor
/// `E` (with state `S`) before each request and discards the extracted value.
///
/// Created with [`from_extractor`] or [`from_extractor_with_state`].
#[must_use]
pub struct FromExtractorLayer<E, S> {
    // State cloned into every `FromExtractor` this layer produces.
    state: S,
    // `fn() -> E` marks `E` without owning one, so `E` does not influence
    // whether the layer is `Send`/`Sync`.
    _marker: PhantomData<fn() -> E>,
}
114
115impl<E, S> Clone for FromExtractorLayer<E, S>
116where
117 S: Clone,
118{
119 fn clone(&self) -> Self {
120 Self {
121 state: self.state.clone(),
122 _marker: PhantomData,
123 }
124 }
125}
126
127impl<E, S> fmt::Debug for FromExtractorLayer<E, S>
128where
129 S: fmt::Debug,
130{
131 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
132 f.debug_struct("FromExtractorLayer")
133 .field("state", &self.state)
134 .field("extractor", &format_args!("{}", std::any::type_name::<E>()))
135 .finish()
136 }
137}
138
139impl<E, T, S> Layer<T> for FromExtractorLayer<E, S>
140where
141 S: Clone,
142{
143 type Service = FromExtractor<T, E, S>;
144
145 fn layer(&self, inner: T) -> Self::Service {
146 FromExtractor {
147 inner,
148 state: self.state.clone(),
149 _extractor: PhantomData,
150 }
151 }
152}
153
/// Middleware that runs an extractor and discards the value.
///
/// Produced by [`FromExtractorLayer`]. If the extractor fails, its rejection
/// becomes the response and the inner service is not called.
pub struct FromExtractor<T, E, S> {
    // The wrapped service; only called when extraction succeeds.
    inner: T,
    // State passed to `E::from_request_parts` (cloned per request).
    state: S,
    // `fn() -> E` marks `E` without owning one, so `E` does not influence
    // whether the service is `Send`/`Sync` (checked by the `traits` test).
    _extractor: PhantomData<fn() -> E>,
}
162
#[test]
fn traits() {
    use crate::test_helpers::*;

    // Even with an extractor type that is neither `Send` nor `Sync`, the
    // service must be both, because `E` is only held through
    // `PhantomData<fn() -> E>`.
    type Svc = FromExtractor<(), NotSendSync, ()>;
    assert_send::<Svc>();
    assert_sync::<Svc>();
}
169
170impl<T, E, S> Clone for FromExtractor<T, E, S>
171where
172 T: Clone,
173 S: Clone,
174{
175 fn clone(&self) -> Self {
176 Self {
177 inner: self.inner.clone(),
178 state: self.state.clone(),
179 _extractor: PhantomData,
180 }
181 }
182}
183
184impl<T, E, S> fmt::Debug for FromExtractor<T, E, S>
185where
186 T: fmt::Debug,
187 S: fmt::Debug,
188{
189 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
190 f.debug_struct("FromExtractor")
191 .field("inner", &self.inner)
192 .field("state", &self.state)
193 .field("extractor", &format_args!("{}", std::any::type_name::<E>()))
194 .finish()
195 }
196}
197
198impl<T, E, B, S> Service<Request<B>> for FromExtractor<T, E, S>
199where
200 E: FromRequestParts<S> + 'static,
201 B: Send + 'static,
202 T: Service<Request<B>> + Clone,
203 T::Response: IntoResponse,
204 S: Clone + Send + Sync + 'static,
205{
206 type Response = Response;
207 type Error = T::Error;
208 type Future = ResponseFuture<B, T, E, S>;
209
210 #[inline]
211 fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
212 self.inner.poll_ready(cx)
213 }
214
215 fn call(&mut self, req: Request<B>) -> Self::Future {
216 let state = self.state.clone();
217 let (mut parts, body) = req.into_parts();
218
219 let extract_future = Box::pin(async move {
220 let extracted = E::from_request_parts(&mut parts, &state).await;
221 let req = Request::from_parts(parts, body);
222 (req, extracted)
223 });
224
225 ResponseFuture {
226 state: State::Extracting {
227 future: extract_future,
228 },
229 svc: Some(self.inner.clone()),
230 }
231 }
232}
233
pin_project! {
    // No `Debug` impl: the `Extracting` state holds a boxed `dyn Future`,
    // which has no `Debug` implementation to delegate to.
    #[allow(missing_debug_implementations)]
    /// Response future for [`FromExtractor`].
    pub struct ResponseFuture<B, T, E, S>
    where
        E: FromRequestParts<S>,
        T: Service<Request<B>>,
    {
        // Current stage: running the extractor, or awaiting the inner service.
        #[pin]
        state: State<B, T, E, S>,
        // Inner service to call once extraction succeeds; taken (left `None`)
        // at the moment the call is made.
        svc: Option<T>,
    }
}
247
pin_project! {
    // Two-stage state machine driven by `ResponseFuture::poll`: first the
    // extractor runs, then (only if it succeeded) the inner service's future.
    #[project = StateProj]
    enum State<B, T, E, S>
    where
        E: FromRequestParts<S>,
        T: Service<Request<B>>,
    {
        // Running `E::from_request_parts`; resolves to the reassembled request
        // together with the extraction result.
        Extracting {
            future: BoxFuture<'static, (Request<B>, Result<E, E::Rejection>)>,
        },
        // Driving the future returned by the inner service's `call`.
        Call { #[pin] future: T::Future },
    }
}
261
262impl<B, T, E, S> Future for ResponseFuture<B, T, E, S>
263where
264 E: FromRequestParts<S>,
265 T: Service<Request<B>>,
266 T::Response: IntoResponse,
267{
268 type Output = Result<Response, T::Error>;
269
270 fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
271 loop {
272 let mut this = self.as_mut().project();
273
274 let new_state = match this.state.as_mut().project() {
275 StateProj::Extracting { future } => {
276 let (req, extracted) = ready!(future.as_mut().poll(cx));
277
278 match extracted {
279 Ok(_) => {
280 let mut svc = this.svc.take().expect("future polled after completion");
281 let future = svc.call(req);
282 State::Call { future }
283 }
284 Err(err) => {
285 let res = err.into_response();
286 return Poll::Ready(Ok(res));
287 }
288 }
289 }
290 StateProj::Call { future } => {
291 return future
292 .poll(cx)
293 .map(|result| result.map(IntoResponse::into_response));
294 }
295 };
296
297 this.state.set(new_state);
298 }
299 }
300}
301
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{handler::Handler, routing::get, test_helpers::*, Router};
    use axum_core::extract::FromRef;
    use http::{header, request::Parts, StatusCode};
    use tower_http::limit::RequestBodyLimitLayer;

    /// A failing extractor short-circuits with its rejection; a succeeding one
    /// lets the request reach the handler.
    #[crate::test]
    async fn test_from_extractor() {
        #[derive(Clone)]
        struct Secret(&'static str);

        struct RequireAuth;

        impl<S> FromRequestParts<S> for RequireAuth
        where
            S: Send + Sync,
            Secret: FromRef<S>,
        {
            type Rejection = StatusCode;

            async fn from_request_parts(
                parts: &mut Parts,
                state: &S,
            ) -> Result<Self, Self::Rejection> {
                let Secret(expected) = Secret::from_ref(state);
                let provided = parts
                    .headers
                    .get(header::AUTHORIZATION)
                    .and_then(|value| value.to_str().ok());

                match provided {
                    Some(auth) if auth == expected => Ok(Self),
                    _ => Err(StatusCode::UNAUTHORIZED),
                }
            }
        }

        async fn handler() {}

        let auth_layer = from_extractor_with_state::<RequireAuth, _>(Secret("secret"));
        let app = Router::new().route("/", get(handler.layer(auth_layer)));
        let client = TestClient::new(app);

        let without_auth = client.get("/").await;
        assert_eq!(without_auth.status(), StatusCode::UNAUTHORIZED);

        let with_auth = client
            .get("/")
            .header(header::AUTHORIZATION, "secret")
            .await;
        assert_eq!(with_auth.status(), StatusCode::OK);
    }

    // Compile-only check: a `from_extractor` layer must type-check together
    // with `RequestBodyLimitLayer` on a `Router`. The function is never called,
    // hence `dead_code` is allowed.
    #[allow(dead_code)]
    fn works_with_request_body_limit() {
        struct MyExtractor;

        impl<S> FromRequestParts<S> for MyExtractor
        where
            S: Send + Sync,
        {
            type Rejection = std::convert::Infallible;

            async fn from_request_parts(
                _parts: &mut Parts,
                _state: &S,
            ) -> Result<Self, Self::Rejection> {
                unimplemented!()
            }
        }

        let _: Router = Router::new()
            .layer(from_extractor::<MyExtractor>())
            .layer(RequestBodyLimitLayer::new(1));
    }
}