axum/extract/
connect_info.rs
1use crate::extension::AddExtension;
8
9use super::{Extension, FromRequestParts};
10use async_trait::async_trait;
11use http::request::Parts;
12use std::{
13 convert::Infallible,
14 fmt,
15 future::ready,
16 marker::PhantomData,
17 net::SocketAddr,
18 task::{Context, Poll},
19};
20use tower_layer::Layer;
21use tower_service::Service;
22
/// A make-service wrapper that carries an inner service `S` together with the
/// *type* of connection info `C` that should be produced for each connection.
///
/// No value of `C` is stored: it only appears inside a
/// `PhantomData<fn() -> C>` marker, so the auto traits of this struct (such as
/// `Send`) depend solely on `S` and never on `C`.
pub struct IntoMakeServiceWithConnectInfo<S, C> {
    /// The wrapped service; cloned once per call to the make-service.
    svc: S,
    /// Type-level marker tying this value to the connect-info type `C`.
    _connect_info: PhantomData<fn() -> C>,
}
33
// Compile-time check: the make-service must be `Send` even when the
// connect-info type parameter is neither `Send` nor `Sync`.
#[test]
fn traits() {
    use crate::test_helpers::{assert_send, NotSendSync};

    type Subject = IntoMakeServiceWithConnectInfo<(), NotSendSync>;
    assert_send::<Subject>();
}
39
40impl<S, C> IntoMakeServiceWithConnectInfo<S, C> {
41 pub(crate) fn new(svc: S) -> Self {
42 Self {
43 svc,
44 _connect_info: PhantomData,
45 }
46 }
47}
48
49impl<S, C> fmt::Debug for IntoMakeServiceWithConnectInfo<S, C>
50where
51 S: fmt::Debug,
52{
53 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
54 f.debug_struct("IntoMakeServiceWithConnectInfo")
55 .field("svc", &self.svc)
56 .finish()
57 }
58}
59
60impl<S, C> Clone for IntoMakeServiceWithConnectInfo<S, C>
61where
62 S: Clone,
63{
64 fn clone(&self) -> Self {
65 Self {
66 svc: self.svc.clone(),
67 _connect_info: PhantomData,
68 }
69 }
70}
71
/// Types that can be constructed from a connection `target` of type `T`.
///
/// Implementors describe how to derive their connection information from the
/// value handed to the make-service for each new connection. Every
/// implementor must also be `Clone + Send + Sync + 'static`.
pub trait Connected<T>
where
    Self: Clone + Send + Sync + 'static,
{
    /// Builds the connection info from the given `target`.
    fn connect_info(target: T) -> Self;
}
85
// Only compiled when the `tokio` feature and at least one HTTP version
// feature are enabled. The anonymous `const` keeps the `IncomingStream`
// import scoped to this impl.
#[cfg(all(feature = "tokio", any(feature = "http1", feature = "http2")))]
const _: () = {
    use crate::serve::IncomingStream;

    /// A `SocketAddr` can be obtained from an incoming stream by asking it
    /// for its remote address.
    impl Connected<IncomingStream<'_>> for SocketAddr {
        fn connect_info(stream: IncomingStream<'_>) -> Self {
            stream.remote_addr()
        }
    }
};
96
97impl Connected<SocketAddr> for SocketAddr {
98 fn connect_info(remote_addr: SocketAddr) -> Self {
99 remote_addr
100 }
101}
102
103impl<S, C, T> Service<T> for IntoMakeServiceWithConnectInfo<S, C>
104where
105 S: Clone,
106 C: Connected<T>,
107{
108 type Response = AddExtension<S, ConnectInfo<C>>;
109 type Error = Infallible;
110 type Future = ResponseFuture<S, C>;
111
112 #[inline]
113 fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
114 Poll::Ready(Ok(()))
115 }
116
117 fn call(&mut self, target: T) -> Self::Future {
118 let connect_info = ConnectInfo(C::connect_info(target));
119 let svc = Extension(connect_info).layer(self.svc.clone());
120 ResponseFuture::new(ready(Ok(svc)))
121 }
122}
123
124opaque_future! {
125 pub type ResponseFuture<S, C> =
127 std::future::Ready<Result<AddExtension<S, ConnectInfo<C>>, Infallible>>;
128}
129
/// Newtype wrapper around connection information of type `T`.
///
/// `IntoMakeServiceWithConnectInfo` attaches a `ConnectInfo<C>` to each
/// service it produces, and the `FromRequestParts` implementation in this
/// module reads it back out of the request. The inner value is public and can
/// be taken by destructuring: `ConnectInfo(value)`.
#[derive(Debug, Clone, Copy)]
pub struct ConnectInfo<T>(pub T);
141
142#[async_trait]
143impl<S, T> FromRequestParts<S> for ConnectInfo<T>
144where
145 S: Send + Sync,
146 T: Clone + Send + Sync + 'static,
147{
148 type Rejection = <Extension<Self> as FromRequestParts<S>>::Rejection;
149
150 async fn from_request_parts(parts: &mut Parts, state: &S) -> Result<Self, Self::Rejection> {
151 match Extension::<Self>::from_request_parts(parts, state).await {
152 Ok(Extension(connect_info)) => Ok(connect_info),
153 Err(err) => match parts.extensions.get::<MockConnectInfo<T>>() {
154 Some(MockConnectInfo(connect_info)) => Ok(Self(connect_info.clone())),
155 None => Err(err),
156 },
157 }
158 }
159}
160
161axum_core::__impl_deref!(ConnectInfo);
162
/// Stand-in connection information of type `T`.
///
/// Used as a layer, it stores itself in the request extensions. The
/// `ConnectInfo<T>` extractor in this module falls back to this value when no
/// real `ConnectInfo<T>` is present, as happens when a router is exercised
/// without going through `IntoMakeServiceWithConnectInfo` (for example in
/// tests).
#[derive(Debug, Clone, Copy)]
pub struct MockConnectInfo<T>(pub T);
212
213impl<S, T> Layer<S> for MockConnectInfo<T>
214where
215 T: Clone + Send + Sync + 'static,
216{
217 type Service = <Extension<Self> as Layer<S>>::Service;
218
219 fn layer(&self, inner: S) -> Self::Service {
220 Extension(self.clone()).layer(inner)
221 }
222}
223
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{routing::get, serve::IncomingStream, test_helpers::TestClient, Router};
    use tokio::net::TcpListener;

    // Served over a real TCP listener, the handler must observe the client's
    // loopback address as `ConnectInfo<SocketAddr>`.
    #[crate::test]
    async fn socket_addr() {
        async fn handler(ConnectInfo(addr): ConnectInfo<SocketAddr>) -> String {
            format!("{addr}")
        }

        // Port 0 lets the OS pick a free port; remember where we ended up.
        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
        let addr = listener.local_addr().unwrap();

        // The channel signals that the server task has started running
        // (the message is sent just before `serve` is awaited).
        let (started_tx, started_rx) = tokio::sync::oneshot::channel();
        tokio::spawn(async move {
            let app = Router::new().route("/", get(handler));
            started_tx.send(()).unwrap();
            let make_service = app.into_make_service_with_connect_info::<SocketAddr>();
            crate::serve(listener, make_service).await.unwrap();
        });
        started_rx.await.unwrap();

        let http_client = reqwest::Client::new();

        let url = format!("http://{addr}");
        let response = http_client.get(url).send().await.unwrap();
        let text = response.text().await.unwrap();
        assert!(text.starts_with("127.0.0.1:"));
    }

    // A user-defined `Connected` implementation is used when it is named as
    // the connect-info type of the make-service.
    #[crate::test]
    async fn custom() {
        #[derive(Clone, Debug)]
        struct MyConnectInfo {
            value: &'static str,
        }

        impl Connected<IncomingStream<'_>> for MyConnectInfo {
            fn connect_info(_target: IncomingStream<'_>) -> Self {
                Self {
                    value: "it worked!",
                }
            }
        }

        async fn handler(ConnectInfo(addr): ConnectInfo<MyConnectInfo>) -> &'static str {
            addr.value
        }

        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
        let addr = listener.local_addr().unwrap();

        let (started_tx, started_rx) = tokio::sync::oneshot::channel();
        tokio::spawn(async move {
            let app = Router::new().route("/", get(handler));
            started_tx.send(()).unwrap();
            let make_service = app.into_make_service_with_connect_info::<MyConnectInfo>();
            crate::serve(listener, make_service).await.unwrap();
        });
        started_rx.await.unwrap();

        let http_client = reqwest::Client::new();

        let url = format!("http://{addr}");
        let response = http_client.get(url).send().await.unwrap();
        let text = response.text().await.unwrap();
        assert_eq!(text, "it worked!");
    }

    // Without a real connection, the extractor falls back to the value
    // installed by the `MockConnectInfo` layer.
    #[crate::test]
    async fn mock_connect_info() {
        async fn handler(ConnectInfo(addr): ConnectInfo<SocketAddr>) -> String {
            format!("{addr}")
        }

        let mocked = MockConnectInfo(SocketAddr::from(([0, 0, 0, 0], 1337)));
        let app = Router::new().route("/", get(handler)).layer(mocked);

        let test_client = TestClient::new(app);

        let response = test_client.get("/").await;
        let text = response.text().await;
        assert!(text.starts_with("0.0.0.0:1337"));
    }

    // When both a mock and real connection info are available, the real
    // connection info must win.
    #[crate::test]
    async fn both_mock_and_real_connect_info() {
        async fn handler(ConnectInfo(addr): ConnectInfo<SocketAddr>) -> String {
            format!("{addr}")
        }

        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
        let addr = listener.local_addr().unwrap();

        tokio::spawn(async move {
            let mocked = MockConnectInfo(SocketAddr::from(([0, 0, 0, 0], 1337)));
            let app = Router::new().route("/", get(handler)).layer(mocked);

            let make_service = app.into_make_service_with_connect_info::<SocketAddr>();
            crate::serve(listener, make_service).await.unwrap();
        });

        let http_client = reqwest::Client::new();

        let url = format!("http://{addr}");
        let response = http_client.get(url).send().await.unwrap();
        let text = response.text().await.unwrap();
        assert!(text.starts_with("127.0.0.1:"));
    }
}