axum/
serve.rs

1//! Serve services.
2
3use std::{
4    convert::Infallible,
5    fmt::Debug,
6    future::{poll_fn, Future, IntoFuture},
7    io,
8    marker::PhantomData,
9    net::SocketAddr,
10    pin::Pin,
11    sync::Arc,
12    task::{Context, Poll},
13    time::Duration,
14};
15
16use axum_core::{body::Body, extract::Request, response::Response};
17use futures_util::{pin_mut, FutureExt};
18use hyper::body::Incoming;
19use hyper_util::rt::{TokioExecutor, TokioIo};
20#[cfg(any(feature = "http1", feature = "http2"))]
21use hyper_util::server::conn::auto::Builder;
22use pin_project_lite::pin_project;
23use tokio::{
24    net::{TcpListener, TcpStream},
25    sync::watch,
26};
27use tower::util::{Oneshot, ServiceExt};
28use tower_service::Service;
29
/// Serves `make_service` on the given TCP listener.
///
/// This is a deliberately minimal entry point that exposes no configuration knobs.
/// If you need to tune the server, use hyper or hyper-util directly.
///
/// Both HTTP/1 and HTTP/2 are supported.
///
/// # Examples
///
/// Serving a [`Router`]:
///
/// ```
/// use axum::{Router, routing::get};
///
/// # async {
/// let router = Router::new().route("/", get(|| async { "Hello, World!" }));
///
/// let listener = tokio::net::TcpListener::bind("0.0.0.0:3000").await.unwrap();
/// axum::serve(listener, router).await.unwrap();
/// # };
/// ```
///
/// See also [`Router::into_make_service_with_connect_info`].
///
/// Serving a [`MethodRouter`]:
///
/// ```
/// use axum::routing::get;
///
/// # async {
/// let router = get(|| async { "Hello, World!" });
///
/// let listener = tokio::net::TcpListener::bind("0.0.0.0:3000").await.unwrap();
/// axum::serve(listener, router).await.unwrap();
/// # };
/// ```
///
/// See also [`MethodRouter::into_make_service_with_connect_info`].
///
/// Serving a [`Handler`]:
///
/// ```
/// use axum::handler::HandlerWithoutStateExt;
///
/// # async {
/// async fn handler() -> &'static str {
///     "Hello, World!"
/// }
///
/// let listener = tokio::net::TcpListener::bind("0.0.0.0:3000").await.unwrap();
/// axum::serve(listener, handler.into_make_service()).await.unwrap();
/// # };
/// ```
///
/// See also [`HandlerWithoutStateExt::into_make_service_with_connect_info`] and
/// [`HandlerService::into_make_service_with_connect_info`].
///
/// [`Router`]: crate::Router
/// [`Router::into_make_service_with_connect_info`]: crate::Router::into_make_service_with_connect_info
/// [`MethodRouter`]: crate::routing::MethodRouter
/// [`MethodRouter::into_make_service_with_connect_info`]: crate::routing::MethodRouter::into_make_service_with_connect_info
/// [`Handler`]: crate::handler::Handler
/// [`HandlerWithoutStateExt::into_make_service_with_connect_info`]: crate::handler::HandlerWithoutStateExt::into_make_service_with_connect_info
/// [`HandlerService::into_make_service_with_connect_info`]: crate::handler::HandlerService::into_make_service_with_connect_info
#[cfg(all(feature = "tokio", any(feature = "http1", feature = "http2")))]
pub fn serve<M, S>(tcp_listener: TcpListener, make_service: M) -> Serve<M, S>
where
    M: for<'a> Service<IncomingStream<'a>, Error = Infallible, Response = S>,
    S: Service<Request, Response = Response, Error = Infallible> + Clone + Send + 'static,
    S::Future: Send,
{
    // Only the value is built here. Connections are accepted once it is awaited.
    let _marker = PhantomData;
    Serve {
        tcp_listener,
        make_service,
        _marker,
    }
}
107
/// The server value returned by [`serve`].
///
/// This type is not itself a [`Future`]. It implements [`IntoFuture`], so `.await`ing it
/// starts the accept loop. To enable graceful shutdown, call
/// [`Serve::with_graceful_shutdown`] before awaiting.
#[cfg(all(feature = "tokio", any(feature = "http1", feature = "http2")))]
#[must_use = "futures must be awaited or polled"]
pub struct Serve<M, S> {
    /// Listener that new connections are accepted from.
    tcp_listener: TcpListener,
    /// Called once per accepted connection to produce that connection's service.
    make_service: M,
    /// Records the per-connection service type `S` without storing a value of it.
    _marker: PhantomData<S>,
}
116
#[cfg(all(feature = "tokio", any(feature = "http1", feature = "http2")))]
impl<M, S> Serve<M, S> {
    /// Turns this server into one that shuts down gracefully once `signal` completes.
    ///
    /// After `signal` resolves, no new connections are accepted. The server then resolves
    /// once the connections that were already open have finished.
    ///
    /// # Example
    ///
    /// ```
    /// use axum::{Router, routing::get};
    ///
    /// # async {
    /// let router = Router::new().route("/", get(|| async { "Hello, World!" }));
    ///
    /// let listener = tokio::net::TcpListener::bind("0.0.0.0:3000").await.unwrap();
    /// axum::serve(listener, router)
    ///     .with_graceful_shutdown(shutdown_signal())
    ///     .await
    ///     .unwrap();
    /// # };
    ///
    /// async fn shutdown_signal() {
    ///     // ...
    /// }
    /// ```
    pub fn with_graceful_shutdown<F>(self, signal: F) -> WithGracefulShutdown<M, S, F>
    where
        F: Future<Output = ()> + Send + 'static,
    {
        let Self {
            tcp_listener,
            make_service,
            _marker,
        } = self;

        WithGracefulShutdown {
            tcp_listener,
            make_service,
            signal,
            _marker,
        }
    }
}
152
#[cfg(all(feature = "tokio", any(feature = "http1", feature = "http2")))]
impl<M, S> Debug for Serve<M, S>
where
    M: Debug,
{
    /// Shows the listener and the make-service. The `PhantomData` marker is omitted.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Exhaustive destructuring: adding a field forces this impl to be revisited.
        let Serve {
            tcp_listener,
            make_service,
            _marker: _,
        } = self;

        let mut out = f.debug_struct("Serve");
        out.field("tcp_listener", tcp_listener);
        out.field("make_service", make_service);
        out.finish()
    }
}
171
#[cfg(all(feature = "tokio", any(feature = "http1", feature = "http2")))]
impl<M, S> IntoFuture for Serve<M, S>
where
    M: for<'a> Service<IncomingStream<'a>, Error = Infallible, Response = S> + Send + 'static,
    for<'a> <M as Service<IncomingStream<'a>>>::Future: Send,
    S: Service<Request, Response = Response, Error = Infallible> + Clone + Send + 'static,
    S::Future: Send,
{
    type Output = io::Result<()>;
    type IntoFuture = private::ServeFuture;

    /// Builds the lazy accept loop.
    ///
    /// Once polled, the loop never ends. Each accepted connection gets its own service from
    /// `make_service` and is served on a freshly spawned Tokio task.
    fn into_future(self) -> Self::IntoFuture {
        private::ServeFuture(Box::pin(async move {
            let Self {
                tcp_listener,
                mut make_service,
                ..
            } = self;

            loop {
                let (tcp_stream, remote_addr) =
                    if let Some(conn) = tcp_accept(&tcp_listener).await {
                        conn
                    } else {
                        // `tcp_accept` already dealt with the failure; just try again.
                        continue;
                    };
                let tcp_stream = TokioIo::new(tcp_stream);

                // Both error types are `Infallible`, so the empty matches are unreachable.
                poll_fn(|cx| make_service.poll_ready(cx))
                    .await
                    .unwrap_or_else(|never| match never {});
                let tower_service = make_service
                    .call(IncomingStream {
                        tcp_stream: &tcp_stream,
                        remote_addr,
                    })
                    .await
                    .unwrap_or_else(|never| match never {});

                let hyper_service = TowerToHyperService {
                    service: tower_service,
                };

                tokio::spawn(async move {
                    // Per-connection errors are deliberately discarded. One example is a
                    // client that disconnects before sending a request. These errors never
                    // affect the accept loop. Upgrades stay enabled so websockets work.
                    let _ = Builder::new(TokioExecutor::new())
                        .serve_connection_with_upgrades(tcp_stream, hyper_service)
                        .await;
                });
            }
        }))
    }
}
234
/// A [`Serve`] with graceful shutdown, created by [`Serve::with_graceful_shutdown`].
///
/// Like [`Serve`], this implements [`IntoFuture`] rather than [`Future`], so `.await` it to
/// run the server. Once `signal` completes, no new connections are accepted. The awaited
/// server then resolves after all in-flight connections have closed.
#[cfg(all(feature = "tokio", any(feature = "http1", feature = "http2")))]
#[must_use = "futures must be awaited or polled"]
pub struct WithGracefulShutdown<M, S, F> {
    /// Listener that new connections are accepted from.
    tcp_listener: TcpListener,
    /// Called once per accepted connection to produce that connection's service.
    make_service: M,
    /// Future whose completion triggers the graceful shutdown.
    signal: F,
    /// Records the per-connection service type `S` without storing a value of it.
    _marker: PhantomData<S>,
}
244
// `S` is only held in `PhantomData` and is never printed, so (as in `Debug for Serve`)
// it needs no `Debug` bound. Requiring one needlessly hid this impl for most services.
#[cfg(all(feature = "tokio", any(feature = "http1", feature = "http2")))]
impl<M, S, F> Debug for WithGracefulShutdown<M, S, F>
where
    M: Debug,
    F: Debug,
{
    /// Shows the listener, the make-service and the shutdown signal.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Exhaustive destructuring: adding a field forces this impl to be revisited.
        let Self {
            tcp_listener,
            make_service,
            signal,
            _marker: _,
        } = self;

        f.debug_struct("WithGracefulShutdown")
            .field("tcp_listener", tcp_listener)
            .field("make_service", make_service)
            .field("signal", signal)
            .finish()
    }
}
267
#[cfg(all(feature = "tokio", any(feature = "http1", feature = "http2")))]
impl<M, S, F> IntoFuture for WithGracefulShutdown<M, S, F>
where
    M: for<'a> Service<IncomingStream<'a>, Error = Infallible, Response = S> + Send + 'static,
    for<'a> <M as Service<IncomingStream<'a>>>::Future: Send,
    S: Service<Request, Response = Response, Error = Infallible> + Clone + Send + 'static,
    S::Future: Send,
    F: Future<Output = ()> + Send + 'static,
{
    type Output = io::Result<()>;
    type IntoFuture = private::ServeFuture;

    /// Builds the accept loop with graceful shutdown.
    ///
    /// All setup happens only when the returned future is first polled. That includes
    /// spawning the task that awaits `signal`. Spawning eagerly here would panic if
    /// `into_future` were called outside a Tokio runtime. It would also start driving
    /// `signal` even if the server future is never awaited.
    ///
    /// Shutdown uses two `watch` channels:
    /// - `signal_tx`/`signal_rx`: the receiver is dropped when `signal` completes. That
    ///   resolves every `signal_tx.closed()`: the accept loop stops accepting, and each
    ///   connection task starts hyper's graceful shutdown.
    /// - `close_tx`/`close_rx`: each connection task holds a `close_rx` clone and drops it
    ///   when the connection ends. `close_tx.closed()` therefore resolves once every
    ///   connection is done.
    fn into_future(self) -> Self::IntoFuture {
        private::ServeFuture(Box::pin(async move {
            let Self {
                tcp_listener,
                mut make_service,
                signal,
                _marker: _,
            } = self;

            let (signal_tx, signal_rx) = watch::channel(());
            let signal_tx = Arc::new(signal_tx);
            tokio::spawn(async move {
                signal.await;
                trace!("received graceful shutdown signal. Telling tasks to shutdown");
                drop(signal_rx);
            });

            let (close_tx, close_rx) = watch::channel(());

            loop {
                let (tcp_stream, remote_addr) = tokio::select! {
                    conn = tcp_accept(&tcp_listener) => {
                        match conn {
                            Some(conn) => conn,
                            None => continue,
                        }
                    }
                    _ = signal_tx.closed() => {
                        trace!("signal received, not accepting new connections");
                        break;
                    }
                };
                let tcp_stream = TokioIo::new(tcp_stream);

                trace!("connection {remote_addr} accepted");

                // Both error types are `Infallible`, so the empty matches are unreachable.
                poll_fn(|cx| make_service.poll_ready(cx))
                    .await
                    .unwrap_or_else(|err| match err {});

                let tower_service = make_service
                    .call(IncomingStream {
                        tcp_stream: &tcp_stream,
                        remote_addr,
                    })
                    .await
                    .unwrap_or_else(|err| match err {});

                let hyper_service = TowerToHyperService {
                    service: tower_service,
                };

                let signal_tx = Arc::clone(&signal_tx);

                let close_rx = close_rx.clone();

                tokio::spawn(async move {
                    let builder = Builder::new(TokioExecutor::new());
                    let conn = builder.serve_connection_with_upgrades(tcp_stream, hyper_service);
                    pin_mut!(conn);

                    // Fused so that, once it has fired, `select!` does not poll it again.
                    let signal_closed = signal_tx.closed().fuse();
                    pin_mut!(signal_closed);

                    loop {
                        tokio::select! {
                            result = conn.as_mut() => {
                                if let Err(_err) = result {
                                    trace!("failed to serve connection: {_err:#}");
                                }
                                break;
                            }
                            _ = &mut signal_closed => {
                                trace!("signal received in task, starting graceful shutdown");
                                conn.as_mut().graceful_shutdown();
                            }
                        }
                    }

                    trace!("connection {remote_addr} closed");

                    drop(close_rx);
                });
            }

            // Drop our own receiver so only connection tasks keep `close_tx` open.
            drop(close_rx);
            drop(tcp_listener);

            trace!(
                "waiting for {} task(s) to finish",
                close_tx.receiver_count()
            );
            close_tx.closed().await;

            Ok(())
        }))
    }
}
378
/// Returns `true` if `e` is a connection-level error: refused, aborted or reset.
///
/// `tcp_accept` retries such errors immediately, without logging or sleeping.
fn is_connection_error(e: &io::Error) -> bool {
    const CONNECTION_KINDS: [io::ErrorKind; 3] = [
        io::ErrorKind::ConnectionRefused,
        io::ErrorKind::ConnectionAborted,
        io::ErrorKind::ConnectionReset,
    ];
    CONNECTION_KINDS.contains(&e.kind())
}
387
388async fn tcp_accept(listener: &TcpListener) -> Option<(TcpStream, SocketAddr)> {
389    match listener.accept().await {
390        Ok(conn) => Some(conn),
391        Err(e) => {
392            if is_connection_error(&e) {
393                return None;
394            }
395
396            // [From `hyper::Server` in 0.14](https://github.com/hyperium/hyper/blob/v0.14.27/src/server/tcp.rs#L186)
397            //
398            // > A possible scenario is that the process has hit the max open files
399            // > allowed, and so trying to accept a new connection will fail with
400            // > `EMFILE`. In some cases, it's preferable to just wait for some time, if
401            // > the application will likely close some files (or connections), and try
402            // > to accept the connection again. If this option is `true`, the error
403            // > will be logged at the `error` level, since it is still a big deal,
404            // > and then the listener will sleep for 1 second.
405            //
406            // hyper allowed customizing this but axum does not.
407            error!("accept error: {e}");
408            tokio::time::sleep(Duration::from_secs(1)).await;
409            None
410        }
411    }
412}
413
414mod private {
415    use std::{
416        future::Future,
417        io,
418        pin::Pin,
419        task::{Context, Poll},
420    };
421
422    pub struct ServeFuture(pub(super) futures_util::future::BoxFuture<'static, io::Result<()>>);
423
424    impl Future for ServeFuture {
425        type Output = io::Result<()>;
426
427        #[inline]
428        fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
429            self.0.as_mut().poll(cx)
430        }
431    }
432
433    impl std::fmt::Debug for ServeFuture {
434        fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
435            f.debug_struct("ServeFuture").finish_non_exhaustive()
436        }
437    }
438}
439
/// Adapter that exposes a `tower_service::Service` as a `hyper::service::Service`.
#[derive(Clone, Copy, Debug)]
struct TowerToHyperService<S> {
    /// The wrapped tower service, cloned for every request.
    service: S,
}
444
445impl<S> hyper::service::Service<Request<Incoming>> for TowerToHyperService<S>
446where
447    S: tower_service::Service<Request> + Clone,
448{
449    type Response = S::Response;
450    type Error = S::Error;
451    type Future = TowerToHyperServiceFuture<S, Request>;
452
453    fn call(&self, req: Request<Incoming>) -> Self::Future {
454        let req = req.map(Body::new);
455        TowerToHyperServiceFuture {
456            future: self.service.clone().oneshot(req),
457        }
458    }
459}
460
461pin_project! {
462    struct TowerToHyperServiceFuture<S, R>
463    where
464        S: tower_service::Service<R>,
465    {
466        #[pin]
467        future: Oneshot<S, R>,
468    }
469}
470
471impl<S, R> Future for TowerToHyperServiceFuture<S, R>
472where
473    S: tower_service::Service<R>,
474{
475    type Output = Result<S::Response, S::Error>;
476
477    #[inline]
478    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
479        self.project().future.poll(cx)
480    }
481}
482
483/// An incoming stream.
484///
485/// Used with [`serve`] and [`IntoMakeServiceWithConnectInfo`].
486///
487/// [`IntoMakeServiceWithConnectInfo`]: crate::extract::connect_info::IntoMakeServiceWithConnectInfo
488#[derive(Debug)]
489pub struct IncomingStream<'a> {
490    tcp_stream: &'a TokioIo<TcpStream>,
491    remote_addr: SocketAddr,
492}
493
494impl IncomingStream<'_> {
495    /// Returns the local address that this stream is bound to.
496    pub fn local_addr(&self) -> std::io::Result<SocketAddr> {
497        self.tcp_stream.inner().local_addr()
498    }
499
500    /// Returns the remote address that this stream is bound to.
501    pub fn remote_addr(&self) -> SocketAddr {
502        self.remote_addr
503    }
504}
505
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{
        handler::{Handler, HandlerWithoutStateExt},
        routing::get,
        Router,
    };

    /// Not a `#[test]`: it only has to type-check. It shows that every supported kind of
    /// service value can be passed to [`serve`].
    #[allow(dead_code, unused_must_use)]
    async fn if_it_compiles_it_works() {
        // Port 0 lets the OS choose a free port for every listener.
        async fn bind() -> TcpListener {
            TcpListener::bind("0.0.0.0:0").await.unwrap()
        }

        let router: Router = Router::new();

        // Router: as-is and as both make-service flavours.
        serve(bind().await, router.clone());
        serve(bind().await, router.clone().into_make_service());
        serve(
            bind().await,
            router.into_make_service_with_connect_info::<SocketAddr>(),
        );

        // MethodRouter
        serve(bind().await, get(handler));
        serve(bind().await, get(handler).into_make_service());
        serve(
            bind().await,
            get(handler).into_make_service_with_connect_info::<SocketAddr>(),
        );

        // Handler
        serve(bind().await, handler.into_service());
        serve(bind().await, handler.with_state(()));
        serve(bind().await, handler.into_make_service());
        serve(
            bind().await,
            handler.into_make_service_with_connect_info::<SocketAddr>(),
        );
    }

    async fn handler() {}
}