// axum/extract/connect_info.rs

//! Extractor for getting connection information from a client.
//!
//! See [`Router::into_make_service_with_connect_info`] for more details.
//!
//! [`Router::into_make_service_with_connect_info`]: crate::routing::Router::into_make_service_with_connect_info
6
7use crate::extension::AddExtension;
8
9use super::{Extension, FromRequestParts};
10use async_trait::async_trait;
11use http::request::Parts;
12use std::{
13    convert::Infallible,
14    fmt,
15    future::ready,
16    marker::PhantomData,
17    net::SocketAddr,
18    task::{Context, Poll},
19};
20use tower_layer::Layer;
21use tower_service::Service;
22
/// A [`MakeService`] that hands every new connection its own clone of a router,
/// wrapped with the connection metadata produced by the [`Connected`] type `C`.
///
/// Obtained from [`Router::into_make_service_with_connect_info`]; see that method
/// for usage details.
///
/// [`MakeService`]: tower::make::MakeService
/// [`Router::into_make_service_with_connect_info`]: crate::routing::Router::into_make_service_with_connect_info
pub struct IntoMakeServiceWithConnectInfo<S, C> {
    /// The service that is cloned for each incoming connection.
    svc: S,
    // `fn() -> C` records the `C` type parameter without storing a `C`, so the
    // `Send`/`Sync` auto traits of this struct do not depend on `C`.
    _connect_info: PhantomData<fn() -> C>,
}
33
/// Compile-time check: the make-service stays `Send` even when the
/// connect-info type `C` itself is neither `Send` nor `Sync`.
#[test]
fn traits() {
    use crate::test_helpers::{assert_send, NotSendSync};

    assert_send::<IntoMakeServiceWithConnectInfo<(), NotSendSync>>();
}
39
40impl<S, C> IntoMakeServiceWithConnectInfo<S, C> {
41    pub(crate) fn new(svc: S) -> Self {
42        Self {
43            svc,
44            _connect_info: PhantomData,
45        }
46    }
47}
48
49impl<S, C> fmt::Debug for IntoMakeServiceWithConnectInfo<S, C>
50where
51    S: fmt::Debug,
52{
53    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
54        f.debug_struct("IntoMakeServiceWithConnectInfo")
55            .field("svc", &self.svc)
56            .finish()
57    }
58}
59
60impl<S, C> Clone for IntoMakeServiceWithConnectInfo<S, C>
61where
62    S: Clone,
63{
64    fn clone(&self) -> Self {
65        Self {
66            svc: self.svc.clone(),
67            _connect_info: PhantomData,
68        }
69    }
70}
71
/// Builds per-connection metadata from the target handed to a make-service.
///
/// Implementing this trait for your own type (and your own IO/target type `T`) lets
/// custom connection details flow through the [`ConnectInfo`] extractor, using the
/// same plumbing as the built-in implementations.
///
/// See [`Router::into_make_service_with_connect_info`] for more details.
///
/// [`Router::into_make_service_with_connect_info`]: crate::routing::Router::into_make_service_with_connect_info
pub trait Connected<T>: Clone + Send + Sync + 'static {
    /// Derives the connection info value from the connection target.
    fn connect_info(target: T) -> Self;
}
85
// Only compiled when `crate::serve` (and therefore `IncomingStream`) is available.
#[cfg(all(feature = "tokio", any(feature = "http1", feature = "http2")))]
const _: () = {
    use crate::serve::IncomingStream;

    /// Uses the remote (peer) address of the accepted stream as the connection info.
    impl Connected<IncomingStream<'_>> for SocketAddr {
        fn connect_info(stream: IncomingStream<'_>) -> Self {
            stream.remote_addr()
        }
    }
};
96
97impl Connected<SocketAddr> for SocketAddr {
98    fn connect_info(remote_addr: SocketAddr) -> Self {
99        remote_addr
100    }
101}
102
103impl<S, C, T> Service<T> for IntoMakeServiceWithConnectInfo<S, C>
104where
105    S: Clone,
106    C: Connected<T>,
107{
108    type Response = AddExtension<S, ConnectInfo<C>>;
109    type Error = Infallible;
110    type Future = ResponseFuture<S, C>;
111
112    #[inline]
113    fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
114        Poll::Ready(Ok(()))
115    }
116
117    fn call(&mut self, target: T) -> Self::Future {
118        let connect_info = ConnectInfo(C::connect_info(target));
119        let svc = Extension(connect_info).layer(self.svc.clone());
120        ResponseFuture::new(ready(Ok(svc)))
121    }
122}
123
124opaque_future! {
125    /// Response future for [`IntoMakeServiceWithConnectInfo`].
126    pub type ResponseFuture<S, C> =
127        std::future::Ready<Result<AddExtension<S, ConnectInfo<C>>, Infallible>>;
128}
129
/// Extractor for the connection information produced by a [`Connected`] implementation.
///
/// The value is only available when either:
///
/// - the app is served with [`Router::into_make_service_with_connect_info`], or
/// - a [`MockConnectInfo`] layer supplies a stand-in value (useful in tests).
///
/// If neither is present, extraction is rejected at runtime. If both are present, the
/// real value from [`Router::into_make_service_with_connect_info`] takes precedence.
///
/// See [`Router::into_make_service_with_connect_info`] for more details.
///
/// [`Router::into_make_service_with_connect_info`]: crate::routing::Router::into_make_service_with_connect_info
#[derive(Clone, Copy, Debug)]
pub struct ConnectInfo<T>(pub T);
141
142#[async_trait]
143impl<S, T> FromRequestParts<S> for ConnectInfo<T>
144where
145    S: Send + Sync,
146    T: Clone + Send + Sync + 'static,
147{
148    type Rejection = <Extension<Self> as FromRequestParts<S>>::Rejection;
149
150    async fn from_request_parts(parts: &mut Parts, state: &S) -> Result<Self, Self::Rejection> {
151        match Extension::<Self>::from_request_parts(parts, state).await {
152            Ok(Extension(connect_info)) => Ok(connect_info),
153            Err(err) => match parts.extensions.get::<MockConnectInfo<T>>() {
154                Some(MockConnectInfo(connect_info)) => Ok(Self(connect_info.clone())),
155                None => Err(err),
156            },
157        }
158    }
159}
160
161axum_core::__impl_deref!(ConnectInfo);
162
/// Layer that supplies a fixed [`ConnectInfo`] value, meant for tests.
///
/// With this layer applied, handlers can use the [`ConnectInfo`] extractor even when the
/// app is not served through [`Router::into_make_service_with_connect_info`] — for
/// example when the router is driven directly with `tower::ServiceExt::oneshot`.
///
/// If this layer and [`Router::into_make_service_with_connect_info`] end up being used
/// together, the real connection info from
/// [`Router::into_make_service_with_connect_info`] wins and the mocked value is ignored.
///
/// # Example
///
/// ```
/// use axum::{
///     Router,
///     extract::connect_info::{MockConnectInfo, ConnectInfo},
///     body::Body,
///     routing::get,
///     http::{Request, StatusCode},
/// };
/// use std::net::SocketAddr;
/// use tower::ServiceExt;
///
/// async fn handler(ConnectInfo(addr): ConnectInfo<SocketAddr>) {}
///
/// // in production, serve this router with
/// // `app.into_make_service_with_connect_info::<SocketAddr>()`
/// fn app() -> Router {
///     Router::new().route("/", get(handler))
/// }
///
/// // in tests, use this router, which provides a fixed `ConnectInfo`
/// fn test_app() -> Router {
///     app().layer(MockConnectInfo(SocketAddr::from(([0, 0, 0, 0], 1337))))
/// }
///
/// // #[tokio::test]
/// async fn some_test() {
///     let app = test_app();
///
///     let request = Request::new(Body::empty());
///     let response = app.oneshot(request).await.unwrap();
///     assert_eq!(response.status(), StatusCode::OK);
/// }
/// #
/// # #[tokio::main]
/// # async fn main() {
/// #     some_test().await;
/// # }
/// ```
///
/// [`Router::into_make_service_with_connect_info`]: crate::Router::into_make_service_with_connect_info
#[derive(Clone, Copy, Debug)]
pub struct MockConnectInfo<T>(pub T);
212
213impl<S, T> Layer<S> for MockConnectInfo<T>
214where
215    T: Clone + Send + Sync + 'static,
216{
217    type Service = <Extension<Self> as Layer<S>>::Service;
218
219    fn layer(&self, inner: S) -> Self::Service {
220        Extension(self.clone()).layer(inner)
221    }
222}
223
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{routing::get, serve::IncomingStream, test_helpers::TestClient, Router};
    use tokio::net::TcpListener;

    /// The built-in `SocketAddr` connect info reports the client's loopback address.
    #[crate::test]
    async fn socket_addr() {
        async fn handler(ConnectInfo(remote): ConnectInfo<SocketAddr>) -> String {
            remote.to_string()
        }

        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
        let server_addr = listener.local_addr().unwrap();

        let (ready_tx, ready_rx) = tokio::sync::oneshot::channel();
        tokio::spawn(async move {
            let app = Router::new().route("/", get(handler));
            ready_tx.send(()).unwrap();
            let make_svc = app.into_make_service_with_connect_info::<SocketAddr>();
            crate::serve(listener, make_svc).await.unwrap();
        });
        ready_rx.await.unwrap();

        let body = reqwest::Client::new()
            .get(format!("http://{server_addr}"))
            .send()
            .await
            .unwrap()
            .text()
            .await
            .unwrap();
        assert!(body.starts_with("127.0.0.1:"));
    }

    /// A user-defined `Connected` type is produced per connection and reaches the handler.
    #[crate::test]
    async fn custom() {
        #[derive(Clone, Debug)]
        struct MyConnectInfo {
            value: &'static str,
        }

        impl Connected<IncomingStream<'_>> for MyConnectInfo {
            fn connect_info(_target: IncomingStream<'_>) -> Self {
                Self {
                    value: "it worked!",
                }
            }
        }

        async fn handler(ConnectInfo(info): ConnectInfo<MyConnectInfo>) -> &'static str {
            info.value
        }

        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
        let server_addr = listener.local_addr().unwrap();

        let (ready_tx, ready_rx) = tokio::sync::oneshot::channel();
        tokio::spawn(async move {
            let app = Router::new().route("/", get(handler));
            ready_tx.send(()).unwrap();
            let make_svc = app.into_make_service_with_connect_info::<MyConnectInfo>();
            crate::serve(listener, make_svc).await.unwrap();
        });
        ready_rx.await.unwrap();

        let body = reqwest::Client::new()
            .get(format!("http://{server_addr}"))
            .send()
            .await
            .unwrap()
            .text()
            .await
            .unwrap();
        assert_eq!(body, "it worked!");
    }

    /// `MockConnectInfo` alone is enough for the extractor to succeed.
    #[crate::test]
    async fn mock_connect_info() {
        async fn handler(ConnectInfo(remote): ConnectInfo<SocketAddr>) -> String {
            remote.to_string()
        }

        let mocked = SocketAddr::from(([0, 0, 0, 0], 1337));
        let app = Router::new()
            .route("/", get(handler))
            .layer(MockConnectInfo(mocked));

        let client = TestClient::new(app);

        let response = client.get("/").await;
        let body = response.text().await;
        assert!(body.starts_with("0.0.0.0:1337"));
    }

    /// When both are present, the real connection info takes precedence over the mock.
    #[crate::test]
    async fn both_mock_and_real_connect_info() {
        async fn handler(ConnectInfo(remote): ConnectInfo<SocketAddr>) -> String {
            remote.to_string()
        }

        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
        let server_addr = listener.local_addr().unwrap();

        tokio::spawn(async move {
            let mocked = SocketAddr::from(([0, 0, 0, 0], 1337));
            let app = Router::new()
                .route("/", get(handler))
                .layer(MockConnectInfo(mocked));

            let make_svc = app.into_make_service_with_connect_info::<SocketAddr>();
            crate::serve(listener, make_svc).await.unwrap();
        });

        let body = reqwest::Client::new()
            .get(format!("http://{server_addr}"))
            .send()
            .await
            .unwrap()
            .text()
            .await
            .unwrap();
        assert!(body.starts_with("127.0.0.1:"));
    }
}