1use std::{
4 convert::Infallible,
5 fmt::Debug,
6 future::{poll_fn, Future, IntoFuture},
7 io,
8 marker::PhantomData,
9 net::SocketAddr,
10 pin::Pin,
11 sync::Arc,
12 task::{Context, Poll},
13 time::Duration,
14};
15
16use axum_core::{body::Body, extract::Request, response::Response};
17use futures_util::{pin_mut, FutureExt};
18use hyper::body::Incoming;
19use hyper_util::rt::{TokioExecutor, TokioIo};
20#[cfg(any(feature = "http1", feature = "http2"))]
21use hyper_util::server::conn::auto::Builder;
22use pin_project_lite::pin_project;
23use tokio::{
24 net::{TcpListener, TcpStream},
25 sync::watch,
26};
27use tower::util::{Oneshot, ServiceExt};
28use tower_service::Service;
29
#[cfg(all(feature = "tokio", any(feature = "http1", feature = "http2")))]
/// Serve the services produced by `make_service` on connections accepted from `tcp_listener`.
///
/// Nothing happens until the returned [`Serve`] is awaited (it implements [`IntoFuture`]).
/// Once awaited it accepts connections in an endless loop. Use
/// [`Serve::with_graceful_shutdown`] to stop the server when a signal future completes.
pub fn serve<M, S>(tcp_listener: TcpListener, make_service: M) -> Serve<M, S>
where
    M: for<'a> Service<IncomingStream<'a>, Error = Infallible, Response = S>,
    S: Service<Request, Response = Response, Error = Infallible> + Clone + Send + 'static,
    S::Future: Send,
{
    let _marker = PhantomData;
    Serve {
        tcp_listener,
        make_service,
        _marker,
    }
}
107
#[cfg(all(feature = "tokio", any(feature = "http1", feature = "http2")))]
/// Server value returned by [`serve`].
///
/// Await it (through its [`IntoFuture`] impl) to start accepting connections.
#[must_use = "futures must be awaited or polled"]
pub struct Serve<M, S> {
    /// Listener that new connections are accepted from.
    tcp_listener: TcpListener,
    /// Called once per accepted connection to produce that connection's service.
    make_service: M,
    /// Records the per-connection service type `S` without storing a value of it.
    _marker: PhantomData<S>,
}
116
#[cfg(all(feature = "tokio", any(feature = "http1", feature = "http2")))]
impl<M, S> Serve<M, S> {
    /// Attach a graceful-shutdown trigger to this server.
    ///
    /// When `signal` completes, the server stops accepting new connections.
    /// Open connections are asked to shut down gracefully.
    /// The returned future resolves once all of them have finished.
    pub fn with_graceful_shutdown<F>(self, signal: F) -> WithGracefulShutdown<M, S, F>
    where
        F: Future<Output = ()> + Send + 'static,
    {
        let Self {
            tcp_listener,
            make_service,
            _marker,
        } = self;

        WithGracefulShutdown {
            tcp_listener,
            make_service,
            signal,
            _marker,
        }
    }
}
152
#[cfg(all(feature = "tokio", any(feature = "http1", feature = "http2")))]
impl<M, S> Debug for Serve<M, S>
where
    M: Debug,
{
    /// Prints the listener and the make-service; the `PhantomData` marker is skipped.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Exhaustive destructuring: a newly added field becomes a compile error here.
        let Self {
            tcp_listener: listener,
            make_service: maker,
            _marker: _,
        } = self;

        let mut out = f.debug_struct("Serve");
        out.field("tcp_listener", listener);
        out.field("make_service", maker);
        out.finish()
    }
}
171
#[cfg(all(feature = "tokio", any(feature = "http1", feature = "http2")))]
impl<M, S> IntoFuture for Serve<M, S>
where
    M: for<'a> Service<IncomingStream<'a>, Error = Infallible, Response = S> + Send + 'static,
    for<'a> <M as Service<IncomingStream<'a>>>::Future: Send,
    S: Service<Request, Response = Response, Error = Infallible> + Clone + Send + 'static,
    S::Future: Send,
{
    type Output = io::Result<()>;
    type IntoFuture = private::ServeFuture;

    /// Run the accept loop.
    ///
    /// Each accepted connection gets a fresh service from `make_service`.
    /// That connection is then served on its own spawned task.
    /// The loop has no exit, so the future never resolves by itself.
    fn into_future(self) -> Self::IntoFuture {
        private::ServeFuture(Box::pin(async move {
            let Self {
                tcp_listener,
                mut make_service,
                _marker: _,
            } = self;

            loop {
                // `tcp_accept` yields `None` when accepting failed; simply try again.
                let (tcp_stream, remote_addr) = match tcp_accept(&tcp_listener).await {
                    Some(conn) => conn,
                    None => continue,
                };
                let tcp_stream = TokioIo::new(tcp_stream);

                trace!("connection {remote_addr} accepted");

                // `Error = Infallible`: the empty match proves these calls cannot fail.
                poll_fn(|cx| make_service.poll_ready(cx))
                    .await
                    .unwrap_or_else(|err| match err {});

                let tower_service = make_service
                    .call(IncomingStream {
                        tcp_stream: &tcp_stream,
                        remote_addr,
                    })
                    .await
                    .unwrap_or_else(|err| match err {});

                let hyper_service = TowerToHyperService {
                    service: tower_service,
                };

                tokio::spawn(async move {
                    let builder = Builder::new(TokioExecutor::new());
                    // Serve with HTTP upgrade support (needed e.g. for WebSockets).
                    let result = builder
                        .serve_connection_with_upgrades(tcp_stream, hyper_service)
                        .await;

                    // An error here only concerns this one connection. Trace it instead of
                    // silently discarding it, matching the graceful-shutdown server.
                    if let Err(_err) = result {
                        trace!("failed to serve connection: {_err:#}");
                    }

                    trace!("connection {remote_addr} closed");
                });
            }
        }))
    }
}
234
#[cfg(all(feature = "tokio", any(feature = "http1", feature = "http2")))]
/// Server value with a shutdown signal, created by [`Serve::with_graceful_shutdown`].
///
/// Await it (through its [`IntoFuture`] impl) to start accepting connections.
#[must_use = "futures must be awaited or polled"]
pub struct WithGracefulShutdown<M, S, F> {
    /// Listener that new connections are accepted from.
    tcp_listener: TcpListener,
    /// Called once per accepted connection to produce that connection's service.
    make_service: M,
    /// Future whose completion starts the graceful shutdown.
    signal: F,
    /// Records the per-connection service type `S` without storing a value of it.
    _marker: PhantomData<S>,
}
244
#[cfg(all(feature = "tokio", any(feature = "http1", feature = "http2")))]
impl<M, S, F> Debug for WithGracefulShutdown<M, S, F>
where
    M: Debug,
    F: Debug,
{
    /// Prints the listener, the make-service and the shutdown signal.
    ///
    /// `S` only appears inside a `PhantomData` marker and is never printed, so no
    /// `S: Debug` bound is required (the same bounds as `Serve`'s `Debug` impl).
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Exhaustive destructuring: a newly added field becomes a compile error here.
        let Self {
            tcp_listener,
            make_service,
            signal,
            _marker: _,
        } = self;

        f.debug_struct("WithGracefulShutdown")
            .field("tcp_listener", tcp_listener)
            .field("make_service", make_service)
            .field("signal", signal)
            .finish()
    }
}
267
#[cfg(all(feature = "tokio", any(feature = "http1", feature = "http2")))]
impl<M, S, F> IntoFuture for WithGracefulShutdown<M, S, F>
where
    M: for<'a> Service<IncomingStream<'a>, Error = Infallible, Response = S> + Send + 'static,
    for<'a> <M as Service<IncomingStream<'a>>>::Future: Send,
    S: Service<Request, Response = Response, Error = Infallible> + Clone + Send + 'static,
    S::Future: Send,
    F: Future<Output = ()> + Send + 'static,
{
    type Output = io::Result<()>;
    type IntoFuture = private::ServeFuture;

    /// Run the accept loop until `signal` completes, then wait for open connections to end.
    ///
    /// Shutdown coordination uses two watch channels:
    /// - `signal_tx`: when `signal` resolves, the watcher task drops `signal_rx`.
    ///   Every `signal_tx.closed()` then completes: the accept loop stops and each
    ///   connection task calls `graceful_shutdown` on its connection.
    /// - `close_tx`: each connection task holds a clone of `close_rx`, dropped when
    ///   it finishes. `close_tx.closed()` resolves once all of them are gone.
    fn into_future(self) -> Self::IntoFuture {
        private::ServeFuture(Box::pin(async move {
            let Self {
                tcp_listener,
                mut make_service,
                signal,
                _marker: _,
            } = self;

            // Start watching `signal` only once this future is first polled. Spawning eagerly
            // in `into_future` would panic outside a Tokio runtime and would run the watcher
            // even if the server future were never awaited.
            let (signal_tx, signal_rx) = watch::channel(());
            let signal_tx = Arc::new(signal_tx);
            tokio::spawn(async move {
                signal.await;
                trace!("received graceful shutdown signal. Telling tasks to shutdown");
                drop(signal_rx);
            });

            let (close_tx, close_rx) = watch::channel(());

            loop {
                let (tcp_stream, remote_addr) = tokio::select! {
                    conn = tcp_accept(&tcp_listener) => {
                        match conn {
                            Some(conn) => conn,
                            None => continue,
                        }
                    }
                    _ = signal_tx.closed() => {
                        trace!("signal received, not accepting new connections");
                        break;
                    }
                };
                let tcp_stream = TokioIo::new(tcp_stream);

                trace!("connection {remote_addr} accepted");

                // `Error = Infallible`: the empty match proves these calls cannot fail.
                poll_fn(|cx| make_service.poll_ready(cx))
                    .await
                    .unwrap_or_else(|err| match err {});

                let tower_service = make_service
                    .call(IncomingStream {
                        tcp_stream: &tcp_stream,
                        remote_addr,
                    })
                    .await
                    .unwrap_or_else(|err| match err {});

                let hyper_service = TowerToHyperService {
                    service: tower_service,
                };

                let signal_tx = Arc::clone(&signal_tx);

                // Held for the lifetime of the connection task; see `close_tx.closed()` below.
                let close_rx = close_rx.clone();

                tokio::spawn(async move {
                    let builder = Builder::new(TokioExecutor::new());
                    let conn = builder.serve_connection_with_upgrades(tcp_stream, hyper_service);
                    pin_mut!(conn);

                    // Fused so that, after it fired once, it is never polled to completion again.
                    let signal_closed = signal_tx.closed().fuse();
                    pin_mut!(signal_closed);

                    loop {
                        tokio::select! {
                            result = conn.as_mut() => {
                                if let Err(_err) = result {
                                    trace!("failed to serve connection: {_err:#}");
                                }
                                break;
                            }
                            _ = &mut signal_closed => {
                                trace!("signal received in task, starting graceful shutdown");
                                // Keep looping afterwards so the connection can finish cleanly.
                                conn.as_mut().graceful_shutdown();
                            }
                        }
                    }

                    trace!("connection {remote_addr} closed");

                    drop(close_rx);
                });
            }

            // Drop our own receiver so only connection tasks keep `close_tx` open, and stop
            // listening for new connections.
            drop(close_rx);
            drop(tcp_listener);

            trace!(
                "waiting for {} task(s) to finish",
                close_tx.receiver_count()
            );
            close_tx.closed().await;

            Ok(())
        }))
    }
}
378
/// Returns `true` for the accept-error kinds that `tcp_accept` retries immediately.
///
/// These kinds are refused, aborted and reset connections. They are retried without
/// logging and without backing off.
fn is_connection_error(e: &io::Error) -> bool {
    const PER_CONNECTION_KINDS: [io::ErrorKind; 3] = [
        io::ErrorKind::ConnectionRefused,
        io::ErrorKind::ConnectionAborted,
        io::ErrorKind::ConnectionReset,
    ];

    PER_CONNECTION_KINDS.contains(&e.kind())
}
387
388async fn tcp_accept(listener: &TcpListener) -> Option<(TcpStream, SocketAddr)> {
389 match listener.accept().await {
390 Ok(conn) => Some(conn),
391 Err(e) => {
392 if is_connection_error(&e) {
393 return None;
394 }
395
396 error!("accept error: {e}");
408 tokio::time::sleep(Duration::from_secs(1)).await;
409 None
410 }
411 }
412}
413
414mod private {
415 use std::{
416 future::Future,
417 io,
418 pin::Pin,
419 task::{Context, Poll},
420 };
421
422 pub struct ServeFuture(pub(super) futures_util::future::BoxFuture<'static, io::Result<()>>);
423
424 impl Future for ServeFuture {
425 type Output = io::Result<()>;
426
427 #[inline]
428 fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
429 self.0.as_mut().poll(cx)
430 }
431 }
432
433 impl std::fmt::Debug for ServeFuture {
434 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
435 f.debug_struct("ServeFuture").finish_non_exhaustive()
436 }
437 }
438}
439
/// Adapter exposing a tower [`Service`] through hyper's service trait.
#[derive(Debug, Copy, Clone)]
struct TowerToHyperService<S> {
    /// The wrapped tower service; a clone of it handles each request.
    service: S,
}
444
445impl<S> hyper::service::Service<Request<Incoming>> for TowerToHyperService<S>
446where
447 S: tower_service::Service<Request> + Clone,
448{
449 type Response = S::Response;
450 type Error = S::Error;
451 type Future = TowerToHyperServiceFuture<S, Request>;
452
453 fn call(&self, req: Request<Incoming>) -> Self::Future {
454 let req = req.map(Body::new);
455 TowerToHyperServiceFuture {
456 future: self.service.clone().oneshot(req),
457 }
458 }
459}
460
461pin_project! {
462 struct TowerToHyperServiceFuture<S, R>
463 where
464 S: tower_service::Service<R>,
465 {
466 #[pin]
467 future: Oneshot<S, R>,
468 }
469}
470
471impl<S, R> Future for TowerToHyperServiceFuture<S, R>
472where
473 S: tower_service::Service<R>,
474{
475 type Output = Result<S::Response, S::Error>;
476
477 #[inline]
478 fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
479 self.project().future.poll(cx)
480 }
481}
482
483#[derive(Debug)]
489pub struct IncomingStream<'a> {
490 tcp_stream: &'a TokioIo<TcpStream>,
491 remote_addr: SocketAddr,
492}
493
494impl IncomingStream<'_> {
495 pub fn local_addr(&self) -> std::io::Result<SocketAddr> {
497 self.tcp_stream.inner().local_addr()
498 }
499
500 pub fn remote_addr(&self) -> SocketAddr {
502 self.remote_addr
503 }
504}
505
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{
        handler::{Handler, HandlerWithoutStateExt},
        routing::get,
        Router,
    };

    /// Compile-only check that `serve` accepts every service / make-service flavour.
    ///
    /// The function is never executed, so the returned `Serve` values are intentionally
    /// left unawaited.
    #[allow(dead_code, unused_must_use)]
    async fn if_it_compiles_it_works() {
        // Binds an OS-assigned port on all interfaces; only the types matter here.
        async fn listener() -> TcpListener {
            TcpListener::bind("0.0.0.0:0").await.unwrap()
        }

        let router: Router = Router::new();

        // `Router` and its make-service wrappers.
        serve(listener().await, router.clone());
        serve(listener().await, router.clone().into_make_service());
        serve(
            listener().await,
            router.into_make_service_with_connect_info::<SocketAddr>(),
        );

        // `MethodRouter` and its make-service wrappers.
        serve(listener().await, get(handler));
        serve(listener().await, get(handler).into_make_service());
        serve(
            listener().await,
            get(handler).into_make_service_with_connect_info::<SocketAddr>(),
        );

        // Plain handlers.
        serve(listener().await, handler.into_service());
        serve(listener().await, handler.with_state(()));
        serve(listener().await, handler.into_make_service());
        serve(
            listener().await,
            handler.into_make_service_with_connect_info::<SocketAddr>(),
        );
    }

    async fn handler() {}
}