// axum/extract/connect_info.rs

//! Extractor for getting connection information from a client.
//!
//! See [`Router::into_make_service_with_connect_info`] for more details.
//!
//! [`Router::into_make_service_with_connect_info`]: crate::routing::Router::into_make_service_with_connect_info
6
7use crate::extension::AddExtension;
8
9use super::{Extension, FromRequestParts};
10use http::request::Parts;
11use std::{
12    convert::Infallible,
13    fmt,
14    future::ready,
15    marker::PhantomData,
16    net::SocketAddr,
17    task::{Context, Poll},
18};
19use tower_layer::Layer;
20use tower_service::Service;
21
/// A [`MakeService`] created from a router that attaches [`ConnectInfo`]
/// to every service it hands out.
///
/// The connection information of type `C` is derived through [`Connected`]
/// from the target the make-service is called with. See
/// [`Router::into_make_service_with_connect_info`] for more details.
///
/// [`MakeService`]: tower::make::MakeService
/// [`Router::into_make_service_with_connect_info`]: crate::routing::Router::into_make_service_with_connect_info
pub struct IntoMakeServiceWithConnectInfo<S, C> {
    /// The wrapped service; a clone of it is produced on every call.
    svc: S,
    // `fn() -> C` instead of `C`: no value of `C` is ever stored, so the
    // auto traits of this struct (e.g. `Send`) must not depend on `C`.
    _connect_info: PhantomData<fn() -> C>,
}
32
/// Compile-time check: the make-service stays `Send` even when the
/// connection-info type `C` is neither `Send` nor `Sync`.
#[test]
fn traits() {
    use crate::test_helpers::{assert_send, NotSendSync};

    assert_send::<IntoMakeServiceWithConnectInfo<(), NotSendSync>>();
}
38
39impl<S, C> IntoMakeServiceWithConnectInfo<S, C> {
40    pub(crate) fn new(svc: S) -> Self {
41        Self {
42            svc,
43            _connect_info: PhantomData,
44        }
45    }
46}
47
48impl<S, C> fmt::Debug for IntoMakeServiceWithConnectInfo<S, C>
49where
50    S: fmt::Debug,
51{
52    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
53        f.debug_struct("IntoMakeServiceWithConnectInfo")
54            .field("svc", &self.svc)
55            .finish()
56    }
57}
58
59impl<S, C> Clone for IntoMakeServiceWithConnectInfo<S, C>
60where
61    S: Clone,
62{
63    fn clone(&self) -> Self {
64        Self {
65            svc: self.svc.clone(),
66            _connect_info: PhantomData,
67        }
68    }
69}
70
/// Produces connection metadata from a connected IO resource.
///
/// `T` is the target that [`IntoMakeServiceWithConnectInfo`] is called with;
/// implementors convert it into a value that handlers can then extract with
/// [`ConnectInfo`]. Implementing this trait lets custom IO types expose the
/// same kind of connection metadata as the built-in ones.
///
/// See [`Router::into_make_service_with_connect_info`] for more details.
///
/// [`Router::into_make_service_with_connect_info`]: crate::routing::Router::into_make_service_with_connect_info
pub trait Connected<T>: Clone + Send + Sync + 'static {
    /// Builds the connection information from the connection target.
    fn connect_info(stream: T) -> Self;
}
84
// `Connected` impls for the connection targets produced by `crate::serve`.
// They sit in an anonymous const so their imports stay local, and they are
// only compiled when `serve` itself is available.
#[cfg(all(feature = "tokio", any(feature = "http1", feature = "http2")))]
const _: () = {
    use crate::serve::{IncomingStream, Listener, TapIo};
    use tokio::net::TcpListener;

    // A plain TCP connection reports the peer's socket address.
    impl Connected<IncomingStream<'_, TcpListener>> for SocketAddr {
        fn connect_info(stream: IncomingStream<'_, TcpListener>) -> Self {
            let peer = stream.remote_addr();
            *peer
        }
    }

    // A tapped listener reports whatever address type the wrapped listener uses.
    impl<'a, L, F> Connected<IncomingStream<'a, TapIo<L, F>>> for L::Addr
    where
        L: Listener,
        L::Addr: Clone + Sync + 'static,
        F: FnMut(&mut L::Io) + Send + 'static,
    {
        fn connect_info(stream: IncomingStream<'a, TapIo<L, F>>) -> Self {
            Clone::clone(stream.remote_addr())
        }
    }
};
107
108impl Connected<SocketAddr> for SocketAddr {
109    fn connect_info(remote_addr: SocketAddr) -> Self {
110        remote_addr
111    }
112}
113
114impl<S, C, T> Service<T> for IntoMakeServiceWithConnectInfo<S, C>
115where
116    S: Clone,
117    C: Connected<T>,
118{
119    type Response = AddExtension<S, ConnectInfo<C>>;
120    type Error = Infallible;
121    type Future = ResponseFuture<S, C>;
122
123    #[inline]
124    fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
125        Poll::Ready(Ok(()))
126    }
127
128    fn call(&mut self, target: T) -> Self::Future {
129        let connect_info = ConnectInfo(C::connect_info(target));
130        let svc = Extension(connect_info).layer(self.svc.clone());
131        ResponseFuture::new(ready(Ok(svc)))
132    }
133}
134
135opaque_future! {
136    /// Response future for [`IntoMakeServiceWithConnectInfo`].
137    pub type ResponseFuture<S, C> =
138        std::future::Ready<Result<AddExtension<S, ConnectInfo<C>>, Infallible>>;
139}
140
/// Extractor for the connection information produced by a [`Connected`]
/// implementation.
///
/// Extraction succeeds when the app is served with
/// [`Router::into_make_service_with_connect_info`], or when a
/// [`MockConnectInfo`] layer (intended for tests) has been added to the
/// router. Otherwise the request is rejected at runtime.
///
/// See [`Router::into_make_service_with_connect_info`] for more details.
///
/// [`Router::into_make_service_with_connect_info`]: crate::routing::Router::into_make_service_with_connect_info
#[derive(Debug, Clone, Copy)]
pub struct ConnectInfo<T>(pub T);
152
153impl<S, T> FromRequestParts<S> for ConnectInfo<T>
154where
155    S: Send + Sync,
156    T: Clone + Send + Sync + 'static,
157{
158    type Rejection = <Extension<Self> as FromRequestParts<S>>::Rejection;
159
160    async fn from_request_parts(parts: &mut Parts, state: &S) -> Result<Self, Self::Rejection> {
161        match Extension::<Self>::from_request_parts(parts, state).await {
162            Ok(Extension(connect_info)) => Ok(connect_info),
163            Err(err) => match parts.extensions.get::<MockConnectInfo<T>>() {
164                Some(MockConnectInfo(connect_info)) => Ok(Self(connect_info.clone())),
165                None => Err(err),
166            },
167        }
168    }
169}
170
171axum_core::__impl_deref!(ConnectInfo);
172
/// Middleware that supplies a fixed [`ConnectInfo`] value, intended for tests.
///
/// The [`ConnectInfo`] extractor only falls back to the mocked value when no
/// real connection information is present. So if [`MockConnectInfo`] and
/// [`Router::into_make_service_with_connect_info`] are accidentally used
/// together, the real connection information from
/// [`Router::into_make_service_with_connect_info`] takes precedence.
///
/// # Example
///
/// ```
/// use axum::{
///     Router,
///     extract::connect_info::{MockConnectInfo, ConnectInfo},
///     body::Body,
///     routing::get,
///     http::{Request, StatusCode},
/// };
/// use std::net::SocketAddr;
/// use tower::ServiceExt;
///
/// async fn handler(ConnectInfo(addr): ConnectInfo<SocketAddr>) {}
///
/// // Run this router with `app.into_make_service_with_connect_info::<SocketAddr>()`.
/// fn app() -> Router {
///     Router::new().route("/", get(handler))
/// }
///
/// // Use this router in tests.
/// fn test_app() -> Router {
///     app().layer(MockConnectInfo(SocketAddr::from(([0, 0, 0, 0], 1337))))
/// }
///
/// // #[tokio::test]
/// async fn some_test() {
///     let app = test_app();
///
///     let request = Request::new(Body::empty());
///     let response = app.oneshot(request).await.unwrap();
///     assert_eq!(response.status(), StatusCode::OK);
/// }
/// #
/// # #[tokio::main]
/// # async fn main() {
/// #     some_test().await;
/// # }
/// ```
///
/// [`Router::into_make_service_with_connect_info`]: crate::routing::Router::into_make_service_with_connect_info
#[derive(Debug, Clone, Copy)]
pub struct MockConnectInfo<T>(pub T);
222
223impl<S, T> Layer<S> for MockConnectInfo<T>
224where
225    T: Clone + Send + Sync + 'static,
226{
227    type Service = <Extension<Self> as Layer<S>>::Service;
228
229    fn layer(&self, inner: S) -> Self::Service {
230        Extension(self.clone()).layer(inner)
231    }
232}
233
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{routing::get, serve::IncomingStream, test_helpers::TestClient, Router};
    use std::net::Ipv4Addr;
    use tokio::net::TcpListener;

    #[crate::test]
    async fn socket_addr() {
        async fn handler(ConnectInfo(addr): ConnectInfo<SocketAddr>) -> String {
            format!("{addr}")
        }

        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
        let addr = listener.local_addr().unwrap();

        let (tx, rx) = tokio::sync::oneshot::channel();
        tokio::spawn(async move {
            let app = Router::new().route("/", get(handler));
            tx.send(()).unwrap();
            crate::serve(
                listener,
                app.into_make_service_with_connect_info::<SocketAddr>(),
            )
            .await
            .unwrap();
        });
        rx.await.unwrap();

        let client = reqwest::Client::new();

        let res = client.get(format!("http://{addr}")).send().await.unwrap();
        let body = res.text().await.unwrap();
        // The handler echoes the whole peer address, so the body must parse
        // back into a `SocketAddr` on the loopback interface.
        let peer: SocketAddr = body.parse().unwrap();
        assert_eq!(peer.ip(), Ipv4Addr::LOCALHOST);
    }

    #[crate::test]
    async fn custom() {
        #[derive(Clone, Debug)]
        struct MyConnectInfo {
            value: &'static str,
        }

        impl Connected<IncomingStream<'_, TcpListener>> for MyConnectInfo {
            fn connect_info(_target: IncomingStream<'_, TcpListener>) -> Self {
                Self {
                    value: "it worked!",
                }
            }
        }

        async fn handler(ConnectInfo(addr): ConnectInfo<MyConnectInfo>) -> &'static str {
            addr.value
        }

        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
        let addr = listener.local_addr().unwrap();

        let (tx, rx) = tokio::sync::oneshot::channel();
        tokio::spawn(async move {
            let app = Router::new().route("/", get(handler));
            tx.send(()).unwrap();
            crate::serve(
                listener,
                app.into_make_service_with_connect_info::<MyConnectInfo>(),
            )
            .await
            .unwrap();
        });
        rx.await.unwrap();

        let client = reqwest::Client::new();

        let res = client.get(format!("http://{addr}")).send().await.unwrap();
        let body = res.text().await.unwrap();
        assert_eq!(body, "it worked!");
    }

    #[crate::test]
    async fn mock_connect_info() {
        async fn handler(ConnectInfo(addr): ConnectInfo<SocketAddr>) -> String {
            format!("{addr}")
        }

        let app = Router::new()
            .route("/", get(handler))
            .layer(MockConnectInfo(SocketAddr::from(([0, 0, 0, 0], 1337))));

        let client = TestClient::new(app);

        let res = client.get("/").await;
        let body = res.text().await;
        // Exact match: a prefix check would also accept e.g. "0.0.0.0:13370".
        assert_eq!(body, "0.0.0.0:1337");
    }

    #[crate::test]
    async fn both_mock_and_real_connect_info() {
        async fn handler(ConnectInfo(addr): ConnectInfo<SocketAddr>) -> String {
            format!("{addr}")
        }

        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
        let addr = listener.local_addr().unwrap();

        let (tx, rx) = tokio::sync::oneshot::channel();
        tokio::spawn(async move {
            let app = Router::new()
                .route("/", get(handler))
                .layer(MockConnectInfo(SocketAddr::from(([0, 0, 0, 0], 1337))));
            tx.send(()).unwrap();

            crate::serve(
                listener,
                app.into_make_service_with_connect_info::<SocketAddr>(),
            )
            .await
            .unwrap();
        });
        rx.await.unwrap();

        let client = reqwest::Client::new();

        let res = client.get(format!("http://{addr}")).send().await.unwrap();
        let body = res.text().await.unwrap();
        // The real loopback peer address must win over the mocked 0.0.0.0:1337.
        let peer: SocketAddr = body.parse().unwrap();
        assert_eq!(peer.ip(), Ipv4Addr::LOCALHOST);
    }
}