1use crate::extension::AddExtension;
8
9use super::{Extension, FromRequestParts};
10use http::request::Parts;
11use std::{
12 convert::Infallible,
13 fmt,
14 future::ready,
15 marker::PhantomData,
16 net::SocketAddr,
17 task::{Context, Poll},
18};
19use tower_layer::Layer;
20use tower_service::Service;
21
/// A make-service wrapper around an inner service `S`.
///
/// Its [`Service`] implementation is called with a connection target: it derives a `C` from that
/// target via [`Connected::connect_info`] and returns a clone of `svc` with `ConnectInfo<C>`
/// attached through the [`Extension`] layer.
pub struct IntoMakeServiceWithConnectInfo<S, C> {
    /// The inner service; cloned for every call.
    svc: S,
    /// Records the connect-info type `C` without storing a value of it. `fn() -> C` is used
    /// instead of `C` so that the wrapper stays `Send` even when `C` is not (checked by the
    /// `traits` test in this file).
    _connect_info: PhantomData<fn() -> C>,
}
32
/// Compile-time check: the make-service is `Send` even when its connect-info type parameter is
/// neither `Send` nor `Sync`.
#[test]
fn traits() {
    use crate::test_helpers::{assert_send, NotSendSync};

    type MakeService = IntoMakeServiceWithConnectInfo<(), NotSendSync>;
    assert_send::<MakeService>();
}
38
39impl<S, C> IntoMakeServiceWithConnectInfo<S, C> {
40 pub(crate) fn new(svc: S) -> Self {
41 Self {
42 svc,
43 _connect_info: PhantomData,
44 }
45 }
46}
47
48impl<S, C> fmt::Debug for IntoMakeServiceWithConnectInfo<S, C>
49where
50 S: fmt::Debug,
51{
52 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
53 f.debug_struct("IntoMakeServiceWithConnectInfo")
54 .field("svc", &self.svc)
55 .finish()
56 }
57}
58
59impl<S, C> Clone for IntoMakeServiceWithConnectInfo<S, C>
60where
61 S: Clone,
62{
63 fn clone(&self) -> Self {
64 Self {
65 svc: self.svc.clone(),
66 _connect_info: PhantomData,
67 }
68 }
69}
70
/// Connection information that can be produced from a connection target of type `T`.
///
/// Implementors are the `C` in [`IntoMakeServiceWithConnectInfo<S, C>`]: when the make-service is
/// called with a target, it invokes [`Connected::connect_info`] on it and wraps the result in
/// [`ConnectInfo`].
pub trait Connected<T>: Clone + Send + Sync + 'static {
    /// Builds the connection information from the given target, consuming it.
    fn connect_info(stream: T) -> Self;
}
84
// `Connected` implementations for the stream types handed out by `crate::serve`. They live in an
// anonymous const so the imports they need stay scoped to this feature-gated block.
#[cfg(all(feature = "tokio", any(feature = "http1", feature = "http2")))]
const _: () = {
    use crate::serve::{IncomingStream, Listener, TapIo};
    use tokio::net::TcpListener;

    /// For a plain TCP listener the connection info is the peer's socket address.
    impl Connected<IncomingStream<'_, TcpListener>> for SocketAddr {
        fn connect_info(stream: IncomingStream<'_, TcpListener>) -> Self {
            let peer = stream.remote_addr();
            *peer
        }
    }

    /// For a listener wrapped in `TapIo` the connection info is the wrapped listener's address
    /// type, cloned out of the incoming stream.
    impl<'a, L, F> Connected<IncomingStream<'a, TapIo<L, F>>> for L::Addr
    where
        L: Listener,
        L::Addr: Clone + Sync + 'static,
        F: FnMut(&mut L::Io) + Send + 'static,
    {
        fn connect_info(stream: IncomingStream<'a, TapIo<L, F>>) -> Self {
            let peer = stream.remote_addr();
            peer.clone()
        }
    }
};
107
/// A `SocketAddr` target is its own connection info: the address is returned unchanged.
impl Connected<SocketAddr> for SocketAddr {
    fn connect_info(remote_addr: SocketAddr) -> Self {
        remote_addr
    }
}
113
114impl<S, C, T> Service<T> for IntoMakeServiceWithConnectInfo<S, C>
115where
116 S: Clone,
117 C: Connected<T>,
118{
119 type Response = AddExtension<S, ConnectInfo<C>>;
120 type Error = Infallible;
121 type Future = ResponseFuture<S, C>;
122
123 #[inline]
124 fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
125 Poll::Ready(Ok(()))
126 }
127
128 fn call(&mut self, target: T) -> Self::Future {
129 let connect_info = ConnectInfo(C::connect_info(target));
130 let svc = Extension(connect_info).layer(self.svc.clone());
131 ResponseFuture::new(ready(Ok(svc)))
132 }
133}
134
opaque_future! {
    /// Future returned by the [`Service`] implementation of [`IntoMakeServiceWithConnectInfo`].
    ///
    /// It wraps an already-resolved `std::future::Ready` that yields the inner service with
    /// `ConnectInfo<C>` attached.
    pub type ResponseFuture<S, C> =
        std::future::Ready<Result<AddExtension<S, ConnectInfo<C>>, Infallible>>;
}
140
/// Wrapper around connection information of type `T`; the value is available as field `0`.
///
/// As an extractor it is read back through the [`Extension`] extractor. If that is rejected,
/// extraction falls back to a [`MockConnectInfo<T>`] stored in the request extensions, and
/// only fails when neither is present.
#[derive(Clone, Copy, Debug)]
pub struct ConnectInfo<T>(pub T);
152
153impl<S, T> FromRequestParts<S> for ConnectInfo<T>
154where
155 S: Send + Sync,
156 T: Clone + Send + Sync + 'static,
157{
158 type Rejection = <Extension<Self> as FromRequestParts<S>>::Rejection;
159
160 async fn from_request_parts(parts: &mut Parts, state: &S) -> Result<Self, Self::Rejection> {
161 match Extension::<Self>::from_request_parts(parts, state).await {
162 Ok(Extension(connect_info)) => Ok(connect_info),
163 Err(err) => match parts.extensions.get::<MockConnectInfo<T>>() {
164 Some(MockConnectInfo(connect_info)) => Ok(Self(connect_info.clone())),
165 None => Err(err),
166 },
167 }
168 }
169}
170
// NOTE(review): judging by its name, this macro presumably generates `Deref`/`DerefMut` impls so
// a `ConnectInfo<T>` can be used like its inner `T`; it is defined in `axum_core` — confirm there.
axum_core::__impl_deref!(ConnectInfo);
172
/// Stand-in connection information of type `T`, applied as a layer.
///
/// The [`ConnectInfo<T>`] extractor uses this value only when no real `ConnectInfo<T>` can be
/// extracted, so an app that is not served through a connect-info make-service (see the
/// `mock_connect_info` test) can still run handlers that take `ConnectInfo<T>`.
#[derive(Clone, Copy, Debug)]
pub struct MockConnectInfo<T>(pub T);
222
223impl<S, T> Layer<S> for MockConnectInfo<T>
224where
225 T: Clone + Send + Sync + 'static,
226{
227 type Service = <Extension<Self> as Layer<S>>::Service;
228
229 fn layer(&self, inner: S) -> Self::Service {
230 Extension(self.clone()).layer(inner)
231 }
232}
233
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{routing::get, serve::IncomingStream, test_helpers::TestClient, Router};
    use tokio::net::TcpListener;

    /// Serving over real TCP exposes the peer's `SocketAddr` to handlers.
    #[crate::test]
    async fn socket_addr() {
        async fn handler(ConnectInfo(addr): ConnectInfo<SocketAddr>) -> String {
            format!("{addr}")
        }

        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
        let addr = listener.local_addr().unwrap();

        // The spawned task signals once it has started, before it begins serving.
        let (started_tx, started_rx) = tokio::sync::oneshot::channel();
        tokio::spawn(async move {
            let app = Router::new().route("/", get(handler));
            started_tx.send(()).unwrap();
            let make_service = app.into_make_service_with_connect_info::<SocketAddr>();
            crate::serve(listener, make_service).await.unwrap();
        });
        started_rx.await.unwrap();

        let http = reqwest::Client::new();

        let response = http.get(format!("http://{addr}")).send().await.unwrap();
        let body = response.text().await.unwrap();
        assert!(body.starts_with("127.0.0.1:"));
    }

    /// A user-defined `Connected` implementation is passed through to handlers.
    #[crate::test]
    async fn custom() {
        #[derive(Clone, Debug)]
        struct MyConnectInfo {
            value: &'static str,
        }

        impl Connected<IncomingStream<'_, TcpListener>> for MyConnectInfo {
            fn connect_info(_target: IncomingStream<'_, TcpListener>) -> Self {
                Self {
                    value: "it worked!",
                }
            }
        }

        async fn handler(ConnectInfo(info): ConnectInfo<MyConnectInfo>) -> &'static str {
            info.value
        }

        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
        let addr = listener.local_addr().unwrap();

        // The spawned task signals once it has started, before it begins serving.
        let (started_tx, started_rx) = tokio::sync::oneshot::channel();
        tokio::spawn(async move {
            let app = Router::new().route("/", get(handler));
            started_tx.send(()).unwrap();
            let make_service = app.into_make_service_with_connect_info::<MyConnectInfo>();
            crate::serve(listener, make_service).await.unwrap();
        });
        started_rx.await.unwrap();

        let http = reqwest::Client::new();

        let response = http.get(format!("http://{addr}")).send().await.unwrap();
        let body = response.text().await.unwrap();
        assert_eq!(body, "it worked!");
    }

    /// Without a real connection, the `MockConnectInfo` layer supplies the value.
    #[crate::test]
    async fn mock_connect_info() {
        async fn handler(ConnectInfo(addr): ConnectInfo<SocketAddr>) -> String {
            format!("{addr}")
        }

        let mocked = MockConnectInfo(SocketAddr::from(([0, 0, 0, 0], 1337)));
        let app = Router::new().route("/", get(handler)).layer(mocked);

        let test_client = TestClient::new(app);

        let response = test_client.get("/").await;
        let body = response.text().await;
        assert!(body.starts_with("0.0.0.0:1337"));
    }

    /// When both a mock and a real connection are present, the real address takes precedence.
    #[crate::test]
    async fn both_mock_and_real_connect_info() {
        async fn handler(ConnectInfo(addr): ConnectInfo<SocketAddr>) -> String {
            format!("{addr}")
        }

        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
        let addr = listener.local_addr().unwrap();

        tokio::spawn(async move {
            let mocked = MockConnectInfo(SocketAddr::from(([0, 0, 0, 0], 1337)));
            let app = Router::new().route("/", get(handler)).layer(mocked);

            let make_service = app.into_make_service_with_connect_info::<SocketAddr>();
            crate::serve(listener, make_service).await.unwrap();
        });

        let http = reqwest::Client::new();

        let response = http.get(format!("http://{addr}")).send().await.unwrap();
        let body = response.text().await.unwrap();
        assert!(body.starts_with("127.0.0.1:"));
    }
}