tonic/transport/server/mod.rs

//! Server implementation and builder.

mod conn;
mod incoming;
mod service;
#[cfg(feature = "tls")]
mod tls;
#[cfg(unix)]
mod unix;

use tokio_stream::StreamExt as _;
use tracing::{debug, trace};

use crate::service::Routes;

pub use conn::{Connected, TcpConnectInfo};
use hyper_util::{
    rt::{TokioExecutor, TokioIo, TokioTimer},
    server::conn::auto::{Builder as ConnectionBuilder, HttpServerConnExec},
    service::TowerToHyperService,
};
#[cfg(feature = "tls")]
pub use tls::ServerTlsConfig;

#[cfg(feature = "tls")]
pub use conn::TlsConnectInfo;

#[cfg(feature = "tls")]
use self::service::TlsAcceptor;

#[cfg(unix)]
pub use unix::UdsConnectInfo;

pub use incoming::TcpIncoming;

#[cfg(feature = "tls")]
use crate::transport::Error;

use self::service::{RecoverError, ServerIo};
use super::service::GrpcTimeout;
use crate::body::{boxed, BoxBody};
use crate::server::NamedService;
use bytes::Bytes;
use http::{Request, Response};
use http_body_util::BodyExt;
use hyper::{body::Incoming, service::Service as HyperService};
use pin_project::pin_project;
use std::future::pending;
use std::{
    convert::Infallible,
    fmt,
    future::{self, poll_fn, Future},
    marker::PhantomData,
    net::SocketAddr,
    pin::{pin, Pin},
    sync::Arc,
    task::{ready, Context, Poll},
    time::Duration,
};
use tokio::io::{AsyncRead, AsyncWrite};
use tokio::time::sleep;
use tokio_stream::Stream;
use tower::{
    layer::util::{Identity, Stack},
    layer::Layer,
    limit::concurrency::ConcurrencyLimitLayer,
    util::{BoxCloneService, Either},
    Service, ServiceBuilder, ServiceExt,
};

type BoxService = tower::util::BoxCloneService<Request<BoxBody>, Response<BoxBody>, crate::Error>;
type TraceInterceptor = Arc<dyn Fn(&http::Request<()>) -> tracing::Span + Send + Sync + 'static>;

const DEFAULT_HTTP2_KEEPALIVE_TIMEOUT_SECS: u64 = 20;

/// A default batteries-included `transport` server.
///
/// This provides a builder-pattern style [`Server`] built on top of `hyper`
/// connections. The builder exposes configuration parameters for a fully
/// featured HTTP2-based gRPC server. It should provide a good out-of-the-box
/// HTTP2 server for use with tonic, and it also serves as a reference
/// implementation and a good starting point for anyone wanting to create a
/// more complex and/or specialized implementation.
#[derive(Clone)]
pub struct Server<L = Identity> {
    trace_interceptor: Option<TraceInterceptor>,
    concurrency_limit: Option<usize>,
    timeout: Option<Duration>,
    #[cfg(feature = "tls")]
    tls: Option<TlsAcceptor>,
    init_stream_window_size: Option<u32>,
    init_connection_window_size: Option<u32>,
    max_concurrent_streams: Option<u32>,
    tcp_keepalive: Option<Duration>,
    tcp_nodelay: bool,
    http2_keepalive_interval: Option<Duration>,
    http2_keepalive_timeout: Option<Duration>,
    http2_adaptive_window: Option<bool>,
    http2_max_pending_accept_reset_streams: Option<usize>,
    http2_max_header_list_size: Option<u32>,
    max_frame_size: Option<u32>,
    accept_http1: bool,
    service_builder: ServiceBuilder<L>,
    max_connection_age: Option<Duration>,
}

impl Default for Server<Identity> {
    fn default() -> Self {
        Self {
            trace_interceptor: None,
            concurrency_limit: None,
            timeout: None,
            #[cfg(feature = "tls")]
            tls: None,
            init_stream_window_size: None,
            init_connection_window_size: None,
            max_concurrent_streams: None,
            tcp_keepalive: None,
            tcp_nodelay: false,
            http2_keepalive_interval: None,
            http2_keepalive_timeout: None,
            http2_adaptive_window: None,
            http2_max_pending_accept_reset_streams: None,
            http2_max_header_list_size: None,
            max_frame_size: None,
            accept_http1: false,
            service_builder: Default::default(),
            max_connection_age: None,
        }
    }
}

/// A stack based [`Service`] router.
#[derive(Debug)]
pub struct Router<L = Identity> {
    server: Server<L>,
    routes: Routes,
}

impl<S: NamedService, T> NamedService for Either<S, T> {
    const NAME: &'static str = S::NAME;
}

impl Server {
    /// Create a new server builder that can configure a [`Server`].
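    ///
    /// # Example
    ///
    /// A minimal configuration sketch using only setters defined on this
    /// builder; the values shown are illustrative, not recommended defaults:
    ///
    /// ```
    /// # use tonic::transport::Server;
    /// # use std::time::Duration;
    /// let builder = Server::builder()
    ///     .tcp_nodelay(true)
    ///     .timeout(Duration::from_secs(30))
    ///     .http2_keepalive_interval(Some(Duration::from_secs(60)));
    /// ```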
    pub fn builder() -> Self {
        Server {
            tcp_nodelay: true,
            accept_http1: false,
            ..Default::default()
        }
    }
}

impl<L> Server<L> {
    /// Configure TLS for this server.
    #[cfg(feature = "tls")]
    pub fn tls_config(self, tls_config: ServerTlsConfig) -> Result<Self, Error> {
        Ok(Server {
            tls: Some(tls_config.tls_acceptor().map_err(Error::from_source)?),
            ..self
        })
    }

    /// Set the concurrency limit applied to inbound requests on each connection.
    ///
    /// # Example
    ///
    /// ```
    /// # use tonic::transport::Server;
    /// # use tower_service::Service;
    /// # let builder = Server::builder();
    /// builder.concurrency_limit_per_connection(32);
    /// ```
    #[must_use]
    pub fn concurrency_limit_per_connection(self, limit: usize) -> Self {
        Server {
            concurrency_limit: Some(limit),
            ..self
        }
    }

    /// Set a timeout for all request handlers.
    ///
    /// # Example
    ///
    /// ```
    /// # use tonic::transport::Server;
    /// # use tower_service::Service;
    /// # use std::time::Duration;
    /// # let builder = Server::builder();
    /// builder.timeout(Duration::from_secs(30));
    /// ```
    #[must_use]
    pub fn timeout(self, timeout: Duration) -> Self {
        Server {
            timeout: Some(timeout),
            ..self
        }
    }

    /// Sets the [`SETTINGS_INITIAL_WINDOW_SIZE`][spec] option for HTTP2
    /// stream-level flow control.
    ///
    /// Default is 65,535
    ///
    /// [spec]: https://httpwg.org/specs/rfc9113.html#InitialWindowSize
    #[must_use]
    pub fn initial_stream_window_size(self, sz: impl Into<Option<u32>>) -> Self {
        Server {
            init_stream_window_size: sz.into(),
            ..self
        }
    }

    /// Sets the initial connection-level flow control window size for HTTP2.
    ///
    /// Default is 65,535
    #[must_use]
    pub fn initial_connection_window_size(self, sz: impl Into<Option<u32>>) -> Self {
        Server {
            init_connection_window_size: sz.into(),
            ..self
        }
    }

    /// Sets the [`SETTINGS_MAX_CONCURRENT_STREAMS`][spec] option for HTTP2
    /// connections.
    ///
    /// Default is no limit (`None`).
    ///
    /// [spec]: https://httpwg.org/specs/rfc9113.html#n-stream-concurrency
    #[must_use]
    pub fn max_concurrent_streams(self, max: impl Into<Option<u32>>) -> Self {
        Server {
            max_concurrent_streams: max.into(),
            ..self
        }
    }

    /// Sets the maximum time a connection may exist before it is gracefully shut down.
    ///
    /// Default is no limit (`None`).
    ///
    /// # Example
    ///
    /// ```
    /// # use tonic::transport::Server;
    /// # use tower_service::Service;
    /// # use std::time::Duration;
    /// # let builder = Server::builder();
    /// builder.max_connection_age(Duration::from_secs(60));
    /// ```
    #[must_use]
    pub fn max_connection_age(self, max_connection_age: Duration) -> Self {
        Server {
            max_connection_age: Some(max_connection_age),
            ..self
        }
    }

    /// Set whether HTTP2 Ping frames are enabled on accepted connections.
    ///
    /// If `None` is specified, HTTP2 keepalive is disabled, otherwise the duration
    /// specified will be the time interval between HTTP2 Ping frames.
    /// The timeout for receiving an acknowledgement of the keepalive ping
    /// can be set with [`Server::http2_keepalive_timeout`].
    ///
    /// Default is no HTTP2 keepalive (`None`)
    ///
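    /// # Example
    ///
    /// A minimal sketch (the values are illustrative only) that enables keepalive
    /// pings every 60 seconds and pairs them with a ping acknowledgement timeout:
    ///
    /// ```
    /// # use tonic::transport::Server;
    /// # use std::time::Duration;
    /// # let builder = Server::builder();
    /// builder
    ///     .http2_keepalive_interval(Some(Duration::from_secs(60)))
    ///     .http2_keepalive_timeout(Some(Duration::from_secs(20)));
    /// ```
    ///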
    #[must_use]
    pub fn http2_keepalive_interval(self, http2_keepalive_interval: Option<Duration>) -> Self {
        Server {
            http2_keepalive_interval,
            ..self
        }
    }

    /// Sets a timeout for receiving an acknowledgement of the keepalive ping.
    ///
    /// If the ping is not acknowledged within the timeout, the connection will be closed.
    /// Does nothing if [`Server::http2_keepalive_interval`] is disabled.
    ///
    /// Default is 20 seconds.
    ///
    #[must_use]
    pub fn http2_keepalive_timeout(self, http2_keepalive_timeout: Option<Duration>) -> Self {
        Server {
            http2_keepalive_timeout,
            ..self
        }
    }

    /// Sets whether to use an adaptive flow control. Defaults to false.
    /// Enabling this will override the limits set in
    /// [`Server::initial_stream_window_size`] and
    /// [`Server::initial_connection_window_size`].
    #[must_use]
    pub fn http2_adaptive_window(self, enabled: Option<bool>) -> Self {
        Server {
            http2_adaptive_window: enabled,
            ..self
        }
    }

    /// Configures the maximum number of pending reset streams allowed before a GOAWAY will be sent.
    ///
    /// This will default to whatever the default in h2 is. As of v0.3.17, it is 20.
    ///
    /// See <https://github.com/hyperium/hyper/issues/2877> for more information.
    #[must_use]
    pub fn http2_max_pending_accept_reset_streams(self, max: Option<usize>) -> Self {
        Server {
            http2_max_pending_accept_reset_streams: max,
            ..self
        }
    }

    /// Set whether TCP keepalive messages are enabled on accepted connections.
    ///
    /// If `None` is specified, keepalive is disabled, otherwise the duration
    /// specified will be the time to remain idle before sending TCP keepalive
    /// probes.
    ///
    /// Default is no keepalive (`None`)
    ///
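    /// # Example
    ///
    /// A sketch with an illustrative 10-minute idle time before probes are sent:
    ///
    /// ```
    /// # use tonic::transport::Server;
    /// # use std::time::Duration;
    /// # let builder = Server::builder();
    /// builder.tcp_keepalive(Some(Duration::from_secs(600)));
    /// ```
    ///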
    #[must_use]
    pub fn tcp_keepalive(self, tcp_keepalive: Option<Duration>) -> Self {
        Server {
            tcp_keepalive,
            ..self
        }
    }

    /// Set the value of the `TCP_NODELAY` option for accepted connections. Enabled by default.
    #[must_use]
    pub fn tcp_nodelay(self, enabled: bool) -> Self {
        Server {
            tcp_nodelay: enabled,
            ..self
        }
    }

    /// Sets the max size of received header frames.
    ///
    /// This will default to whatever the default in hyper is. As of v1.4.1, it is 16 KiB.
    #[must_use]
    pub fn http2_max_header_list_size(self, max: impl Into<Option<u32>>) -> Self {
        Server {
            http2_max_header_list_size: max.into(),
            ..self
        }
    }

    /// Sets the maximum frame size to use for HTTP2.
    ///
    /// Passing `None` will do nothing.
    ///
    /// If not set, the underlying transport's default is used.
    #[must_use]
    pub fn max_frame_size(self, frame_size: impl Into<Option<u32>>) -> Self {
        Server {
            max_frame_size: frame_size.into(),
            ..self
        }
    }

    /// Allow this server to accept http1 requests.
    ///
    /// Accepting http1 requests is only useful when developing `grpc-web`
    /// enabled services. If this setting is set to `true` but services are
    /// not correctly configured to handle grpc-web requests, your server may
    /// return confusing (but correct) protocol errors.
    ///
    /// Default is `false`.
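    ///
    /// # Example
    ///
    /// A sketch that only turns on HTTP1 support; wiring up an actual
    /// `grpc-web` layer is out of scope here:
    ///
    /// ```
    /// # use tonic::transport::Server;
    /// # let builder = Server::builder();
    /// builder.accept_http1(true);
    /// ```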
    #[must_use]
    pub fn accept_http1(self, accept_http1: bool) -> Self {
        Server {
            accept_http1,
            ..self
        }
    }

    /// Intercept inbound headers and add a [`tracing::Span`] to each response future.
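    ///
    /// # Example
    ///
    /// A sketch that attaches a span carrying the request URI; the span name
    /// and fields are illustrative only:
    ///
    /// ```
    /// # use tonic::transport::Server;
    /// # let builder = Server::builder();
    /// builder.trace_fn(|request| tracing::info_span!("grpc", uri = %request.uri()));
    /// ```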
    #[must_use]
    pub fn trace_fn<F>(self, f: F) -> Self
    where
        F: Fn(&http::Request<()>) -> tracing::Span + Send + Sync + 'static,
    {
        Server {
            trace_interceptor: Some(Arc::new(f)),
            ..self
        }
    }

    /// Create a router with the `S` typed service as the first service.
    ///
    /// This will clone the `Server` builder and create a router that
    /// routes requests between the added services.
    pub fn add_service<S>(&mut self, svc: S) -> Router<L>
    where
        S: Service<Request<BoxBody>, Response = Response<BoxBody>, Error = Infallible>
            + NamedService
            + Clone
            + Send
            + 'static,
        S::Future: Send + 'static,
        L: Clone,
    {
        Router::new(self.clone(), Routes::new(svc))
    }

    /// Create a router with the optional `S` typed service as the first service.
    ///
    /// This will clone the `Server` builder and create a router that
    /// routes requests between the added services.
    ///
    /// # Note
    /// Even when the argument given is `None` this will capture *all* requests to this service name.
    /// As a result, one cannot use this to toggle between two identically named implementations.
    pub fn add_optional_service<S>(&mut self, svc: Option<S>) -> Router<L>
    where
        S: Service<Request<BoxBody>, Response = Response<BoxBody>, Error = Infallible>
            + NamedService
            + Clone
            + Send
            + 'static,
        S::Future: Send + 'static,
        L: Clone,
    {
        let routes = svc.map(Routes::new).unwrap_or_default();
        Router::new(self.clone(), routes)
    }

    /// Create a router with given [`Routes`].
    ///
    /// This will clone the `Server` builder and create a router that routes
    /// requests between the services already added to the provided `routes`.
    pub fn add_routes(&mut self, routes: Routes) -> Router<L>
    where
        L: Clone,
    {
        Router::new(self.clone(), routes)
    }

    /// Set the [Tower] [`Layer`] all services will be wrapped in.
    ///
    /// This enables using middleware from the [Tower ecosystem][eco].
    ///
    /// # Example
    ///
    /// ```
    /// # use tonic::transport::Server;
    /// # use tower_service::Service;
    /// use tower::timeout::TimeoutLayer;
    /// use std::time::Duration;
    ///
    /// # let mut builder = Server::builder();
    /// builder.layer(TimeoutLayer::new(Duration::from_secs(30)));
    /// ```
    ///
    /// Note that timeouts should be set using [`Server::timeout`]. `TimeoutLayer` is only used
    /// here as an example.
    ///
    /// You can build more complex layers using [`ServiceBuilder`]. Those layers can include
    /// [interceptors]:
    ///
    /// ```
    /// # use tonic::transport::Server;
    /// # use tower_service::Service;
    /// use tower::ServiceBuilder;
    /// use std::time::Duration;
    /// use tonic::{Request, Status, service::interceptor};
    ///
    /// fn auth_interceptor(request: Request<()>) -> Result<Request<()>, Status> {
    ///     if valid_credentials(&request) {
    ///         Ok(request)
    ///     } else {
    ///         Err(Status::unauthenticated("invalid credentials"))
    ///     }
    /// }
    ///
    /// fn valid_credentials(request: &Request<()>) -> bool {
    ///     // ...
    ///     # true
    /// }
    ///
    /// fn some_other_interceptor(request: Request<()>) -> Result<Request<()>, Status> {
    ///     Ok(request)
    /// }
    ///
    /// let layer = ServiceBuilder::new()
    ///     .load_shed()
    ///     .timeout(Duration::from_secs(30))
    ///     .layer(interceptor(auth_interceptor))
    ///     .layer(interceptor(some_other_interceptor))
    ///     .into_inner();
    ///
    /// Server::builder().layer(layer);
    /// ```
    ///
    /// [Tower]: https://github.com/tower-rs/tower
    /// [`Layer`]: tower::layer::Layer
    /// [eco]: https://github.com/tower-rs
    /// [`ServiceBuilder`]: tower::ServiceBuilder
    /// [interceptors]: crate::service::Interceptor
    pub fn layer<NewLayer>(self, new_layer: NewLayer) -> Server<Stack<NewLayer, L>> {
        Server {
            service_builder: self.service_builder.layer(new_layer),
            trace_interceptor: self.trace_interceptor,
            concurrency_limit: self.concurrency_limit,
            timeout: self.timeout,
            #[cfg(feature = "tls")]
            tls: self.tls,
            init_stream_window_size: self.init_stream_window_size,
            init_connection_window_size: self.init_connection_window_size,
            max_concurrent_streams: self.max_concurrent_streams,
            tcp_keepalive: self.tcp_keepalive,
            tcp_nodelay: self.tcp_nodelay,
            http2_keepalive_interval: self.http2_keepalive_interval,
            http2_keepalive_timeout: self.http2_keepalive_timeout,
            http2_adaptive_window: self.http2_adaptive_window,
            http2_max_pending_accept_reset_streams: self.http2_max_pending_accept_reset_streams,
            http2_max_header_list_size: self.http2_max_header_list_size,
            max_frame_size: self.max_frame_size,
            accept_http1: self.accept_http1,
            max_connection_age: self.max_connection_age,
        }
    }

    pub(crate) async fn serve_with_shutdown<S, I, F, IO, IE, ResBody>(
        self,
        svc: S,
        incoming: I,
        signal: Option<F>,
    ) -> Result<(), super::Error>
    where
        L: Layer<S>,
        L::Service:
            Service<Request<BoxBody>, Response = Response<ResBody>> + Clone + Send + 'static,
        <<L as Layer<S>>::Service as Service<Request<BoxBody>>>::Future: Send + 'static,
        <<L as Layer<S>>::Service as Service<Request<BoxBody>>>::Error:
            Into<crate::Error> + Send + 'static,
        I: Stream<Item = Result<IO, IE>>,
        IO: AsyncRead + AsyncWrite + Connected + Unpin + Send + 'static,
        IO::ConnectInfo: Clone + Send + Sync + 'static,
        IE: Into<crate::Error>,
        F: Future<Output = ()>,
        ResBody: http_body::Body<Data = Bytes> + Send + 'static,
        ResBody::Error: Into<crate::Error>,
    {
        let trace_interceptor = self.trace_interceptor.clone();
        let concurrency_limit = self.concurrency_limit;
        let init_connection_window_size = self.init_connection_window_size;
        let init_stream_window_size = self.init_stream_window_size;
        let max_concurrent_streams = self.max_concurrent_streams;
        let timeout = self.timeout;
        let max_header_list_size = self.http2_max_header_list_size;
        let max_frame_size = self.max_frame_size;
        let http2_only = !self.accept_http1;

        let http2_keepalive_interval = self.http2_keepalive_interval;
        let http2_keepalive_timeout = self
            .http2_keepalive_timeout
            .unwrap_or_else(|| Duration::new(DEFAULT_HTTP2_KEEPALIVE_TIMEOUT_SECS, 0));
        let http2_adaptive_window = self.http2_adaptive_window;
        let http2_max_pending_accept_reset_streams = self.http2_max_pending_accept_reset_streams;
        let max_connection_age = self.max_connection_age;

        let svc = self.service_builder.service(svc);

        let incoming = incoming::tcp_incoming(
            incoming,
            #[cfg(feature = "tls")]
            self.tls,
        );
        let mut svc = MakeSvc {
            inner: svc,
            concurrency_limit,
            timeout,
            trace_interceptor,
            _io: PhantomData,
        };

        let server = {
            let mut builder = ConnectionBuilder::new(TokioExecutor::new());

            if http2_only {
                builder = builder.http2_only();
            }

            builder
                .http2()
                .timer(TokioTimer::new())
                .initial_connection_window_size(init_connection_window_size)
                .initial_stream_window_size(init_stream_window_size)
                .max_concurrent_streams(max_concurrent_streams)
                .keep_alive_interval(http2_keepalive_interval)
                .keep_alive_timeout(http2_keepalive_timeout)
                .adaptive_window(http2_adaptive_window.unwrap_or_default())
                .max_pending_accept_reset_streams(http2_max_pending_accept_reset_streams)
                .max_frame_size(max_frame_size);

            if let Some(max_header_list_size) = max_header_list_size {
                builder.http2().max_header_list_size(max_header_list_size);
            }

            builder
        };

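        // Graceful shutdown bookkeeping: every connection task spawned below holds a
        // clone of `signal_rx`. After the shutdown signal fires we drop our receiver and
        // wait on `signal_tx.closed()`, which resolves once all receivers are dropped,
        // i.e. once every in-flight connection has finished.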
        let (signal_tx, signal_rx) = tokio::sync::watch::channel(());
        let signal_tx = Arc::new(signal_tx);

        let graceful = signal.is_some();
        let mut sig = pin!(Fuse { inner: signal });
        let mut incoming = pin!(incoming);

        loop {
            tokio::select! {
                _ = &mut sig => {
                    trace!("signal received, shutting down");
                    break;
                },
                io = incoming.next() => {
                    let io = match io {
                        Some(Ok(io)) => io,
                        Some(Err(e)) => {
                            trace!("error accepting connection: {:#}", e);
                            continue;
                        },
                        None => {
                            break
                        },
                    };

                    trace!("connection accepted");

                    poll_fn(|cx| svc.poll_ready(cx))
                        .await
                        .map_err(super::Error::from_source)?;

                    let req_svc = svc
                        .call(&io)
                        .await
                        .map_err(super::Error::from_source)?;

                    let hyper_io = TokioIo::new(io);
                    let hyper_svc = TowerToHyperService::new(req_svc.map_request(|req: Request<Incoming>| req.map(boxed)));

                    serve_connection(hyper_io, hyper_svc, server.clone(), graceful.then(|| signal_rx.clone()), max_connection_age);
                }
            }
        }

        if graceful {
            let _ = signal_tx.send(());
            drop(signal_rx);
            trace!(
                "waiting for {} connections to close",
                signal_tx.receiver_count()
            );

            // Wait for all connections to close
            signal_tx.closed().await;
        }

        Ok(())
    }
}

// This is moved to its own function as a way to get around
// https://github.com/rust-lang/rust/issues/102211
fn serve_connection<B, IO, S, E>(
    hyper_io: IO,
    hyper_svc: S,
    builder: ConnectionBuilder<E>,
    mut watcher: Option<tokio::sync::watch::Receiver<()>>,
    max_connection_age: Option<Duration>,
) where
    B: http_body::Body + Send + 'static,
    B::Data: Send,
    B::Error: Into<Box<dyn std::error::Error + Send + Sync>> + Send + Sync,
    IO: hyper::rt::Read + hyper::rt::Write + Unpin + Send + 'static,
    S: HyperService<Request<Incoming>, Response = Response<B>> + Clone + Send + 'static,
    S::Future: Send + 'static,
    S::Error: Into<Box<dyn std::error::Error + Send + Sync>> + Send,
    E: HttpServerConnExec<S::Future, B> + Send + Sync + 'static,
{
    tokio::spawn(async move {
        {
            let mut sig = pin!(Fuse {
                inner: watcher.as_mut().map(|w| w.changed()),
            });

            let mut conn = pin!(builder.serve_connection(hyper_io, hyper_svc));

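            // `max_connection_age` (if any) arms a timer; when it fires we start a graceful
            // shutdown of this connection and park the timer forever so it only fires once.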
            let sleep = sleep_or_pending(max_connection_age);
            tokio::pin!(sleep);

            loop {
                tokio::select! {
                    rv = &mut conn => {
                        if let Err(err) = rv {
                            debug!("failed serving connection: {:#}", err);
                        }
                        break;
                    },
                    _ = &mut sleep  => {
                        conn.as_mut().graceful_shutdown();
                        sleep.set(sleep_or_pending(None));
                    },
                    _ = &mut sig => {
                        conn.as_mut().graceful_shutdown();
                    }
                }
            }
        }

        drop(watcher);
        trace!("connection closed");
    });
}

async fn sleep_or_pending(wait_for: Option<Duration>) {
    match wait_for {
        Some(wait) => sleep(wait).await,
        None => pending().await,
    };
}

impl<L> Router<L> {
    pub(crate) fn new(server: Server<L>, routes: Routes) -> Self {
        Self { server, routes }
    }
}

impl<L> Router<L> {
    /// Add a new service to this router.
    pub fn add_service<S>(mut self, svc: S) -> Self
    where
        S: Service<Request<BoxBody>, Response = Response<BoxBody>, Error = Infallible>
            + NamedService
            + Clone
            + Send
            + 'static,
        S::Future: Send + 'static,
    {
        self.routes = self.routes.add_service(svc);
        self
    }

    /// Add a new optional service to this router.
    ///
    /// # Note
    /// Even when the argument given is `None` this will capture *all* requests to this service name.
    /// As a result, one cannot use this to toggle between two identically named implementations.
    #[allow(clippy::type_complexity)]
    pub fn add_optional_service<S>(mut self, svc: Option<S>) -> Self
    where
        S: Service<Request<BoxBody>, Response = Response<BoxBody>, Error = Infallible>
            + NamedService
            + Clone
            + Send
            + 'static,
        S::Future: Send + 'static,
    {
        if let Some(svc) = svc {
            self.routes = self.routes.add_service(svc);
        }
        self
    }

    /// Convert this tonic `Router` into an axum `Router` consuming the tonic one.
    #[deprecated(since = "0.12.2", note = "Use `Routes::into_axum_router` instead.")]
    pub fn into_router(self) -> axum::Router {
        self.routes.into_axum_router()
    }

    /// Consume this [`Server`] creating a future that will execute the server
    /// on [tokio]'s default executor.
    ///
    /// [`Server`]: struct.Server.html
    /// [tokio]: https://docs.rs/tokio
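    ///
    /// # Example
    ///
    /// A sketch of a typical setup; `GreeterServer` and `MyGreeter` stand in
    /// for a code-generated service and its implementation and are not
    /// defined in this crate:
    ///
    /// ```ignore
    /// # use tonic::transport::Server;
    /// let addr = "[::1]:50051".parse()?;
    /// Server::builder()
    ///     .add_service(GreeterServer::new(MyGreeter::default()))
    ///     .serve(addr)
    ///     .await?;
    /// ```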
    pub async fn serve<ResBody>(self, addr: SocketAddr) -> Result<(), super::Error>
    where
        L: Layer<Routes> + Clone,
        L::Service:
            Service<Request<BoxBody>, Response = Response<ResBody>> + Clone + Send + 'static,
        <<L as Layer<Routes>>::Service as Service<Request<BoxBody>>>::Future: Send + 'static,
        <<L as Layer<Routes>>::Service as Service<Request<BoxBody>>>::Error:
            Into<crate::Error> + Send,
        ResBody: http_body::Body<Data = Bytes> + Send + 'static,
        ResBody::Error: Into<crate::Error>,
    {
        let incoming = TcpIncoming::new(addr, self.server.tcp_nodelay, self.server.tcp_keepalive)
            .map_err(super::Error::from_source)?;
        self.server
            .serve_with_shutdown::<_, _, future::Ready<()>, _, _, ResBody>(
                self.routes.prepare(),
                incoming,
                None,
            )
            .await
    }

    /// Consume this [`Server`] creating a future that will execute the server
    /// on [tokio]'s default executor, shutting down when the provided signal
    /// is received.
    ///
    /// [`Server`]: struct.Server.html
    /// [tokio]: https://docs.rs/tokio
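    ///
    /// # Example
    ///
    /// A sketch that shuts down on Ctrl+C; it assumes tokio's `signal`
    /// feature and a code-generated `GreeterServer` not defined here:
    ///
    /// ```ignore
    /// # use tonic::transport::Server;
    /// let addr = "[::1]:50051".parse()?;
    /// Server::builder()
    ///     .add_service(GreeterServer::new(MyGreeter::default()))
    ///     .serve_with_shutdown(addr, async {
    ///         tokio::signal::ctrl_c().await.ok();
    ///     })
    ///     .await?;
    /// ```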
    pub async fn serve_with_shutdown<F: Future<Output = ()>, ResBody>(
        self,
        addr: SocketAddr,
        signal: F,
    ) -> Result<(), super::Error>
    where
        L: Layer<Routes>,
        L::Service:
            Service<Request<BoxBody>, Response = Response<ResBody>> + Clone + Send + 'static,
        <<L as Layer<Routes>>::Service as Service<Request<BoxBody>>>::Future: Send + 'static,
        <<L as Layer<Routes>>::Service as Service<Request<BoxBody>>>::Error:
            Into<crate::Error> + Send,
        ResBody: http_body::Body<Data = Bytes> + Send + 'static,
        ResBody::Error: Into<crate::Error>,
    {
        let incoming = TcpIncoming::new(addr, self.server.tcp_nodelay, self.server.tcp_keepalive)
            .map_err(super::Error::from_source)?;
        self.server
            .serve_with_shutdown(self.routes.prepare(), incoming, Some(signal))
            .await
    }

    /// Consume this [`Server`] creating a future that will execute the server
    /// on the provided incoming stream of `AsyncRead + AsyncWrite`.
    ///
    /// This method discards any provided [`Server`] TCP configuration.
    ///
    /// [`Server`]: struct.Server.html
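    ///
    /// # Example
    ///
    /// A sketch feeding the server from a `TcpListener`; it assumes the
    /// `tokio-stream` crate's `net` feature for `TcpListenerStream` and a
    /// code-generated `GreeterServer` not defined here:
    ///
    /// ```ignore
    /// # use tonic::transport::Server;
    /// use tokio::net::TcpListener;
    /// use tokio_stream::wrappers::TcpListenerStream;
    ///
    /// let listener = TcpListener::bind("[::1]:50051").await?;
    /// Server::builder()
    ///     .add_service(GreeterServer::new(MyGreeter::default()))
    ///     .serve_with_incoming(TcpListenerStream::new(listener))
    ///     .await?;
    /// ```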
    pub async fn serve_with_incoming<I, IO, IE, ResBody>(
        self,
        incoming: I,
    ) -> Result<(), super::Error>
    where
        I: Stream<Item = Result<IO, IE>>,
        IO: AsyncRead + AsyncWrite + Connected + Unpin + Send + 'static,
        IO::ConnectInfo: Clone + Send + Sync + 'static,
        IE: Into<crate::Error>,
        L: Layer<Routes>,
        L::Service:
            Service<Request<BoxBody>, Response = Response<ResBody>> + Clone + Send + 'static,
        <<L as Layer<Routes>>::Service as Service<Request<BoxBody>>>::Future: Send + 'static,
        <<L as Layer<Routes>>::Service as Service<Request<BoxBody>>>::Error:
            Into<crate::Error> + Send,
        ResBody: http_body::Body<Data = Bytes> + Send + 'static,
        ResBody::Error: Into<crate::Error>,
    {
        self.server
            .serve_with_shutdown::<_, _, future::Ready<()>, _, _, ResBody>(
                self.routes.prepare(),
                incoming,
                None,
            )
            .await
    }

    /// Consume this [`Server`] creating a future that will execute the server
    /// on the provided incoming stream of `AsyncRead + AsyncWrite`. Like
    /// `serve_with_shutdown`, this method also takes a signal future used to
    /// gracefully shut down the server.
    ///
    /// This method discards any provided [`Server`] TCP configuration.
    ///
    /// [`Server`]: struct.Server.html
    pub async fn serve_with_incoming_shutdown<I, IO, IE, F, ResBody>(
        self,
        incoming: I,
        signal: F,
    ) -> Result<(), super::Error>
    where
        I: Stream<Item = Result<IO, IE>>,
        IO: AsyncRead + AsyncWrite + Connected + Unpin + Send + 'static,
        IO::ConnectInfo: Clone + Send + Sync + 'static,
        IE: Into<crate::Error>,
        F: Future<Output = ()>,
        L: Layer<Routes>,
        L::Service:
            Service<Request<BoxBody>, Response = Response<ResBody>> + Clone + Send + 'static,
        <<L as Layer<Routes>>::Service as Service<Request<BoxBody>>>::Future: Send + 'static,
        <<L as Layer<Routes>>::Service as Service<Request<BoxBody>>>::Error:
            Into<crate::Error> + Send,
        ResBody: http_body::Body<Data = Bytes> + Send + 'static,
        ResBody::Error: Into<crate::Error>,
    {
        self.server
            .serve_with_shutdown(self.routes.prepare(), incoming, Some(signal))
            .await
    }

    /// Create a tower service out of a router.
    pub fn into_service<ResBody>(self) -> L::Service
    where
        L: Layer<Routes>,
        L::Service:
            Service<Request<BoxBody>, Response = Response<ResBody>> + Clone + Send + 'static,
        <<L as Layer<Routes>>::Service as Service<Request<BoxBody>>>::Future: Send + 'static,
        <<L as Layer<Routes>>::Service as Service<Request<BoxBody>>>::Error:
            Into<crate::Error> + Send,
        ResBody: http_body::Body<Data = Bytes> + Send + 'static,
        ResBody::Error: Into<crate::Error>,
    {
        self.server.service_builder.service(self.routes.prepare())
    }
}

impl<L> fmt::Debug for Server<L> {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.debug_struct("Builder").finish()
    }
}

#[derive(Clone)]
struct Svc<S> {
    inner: S,
    trace_interceptor: Option<TraceInterceptor>,
}

impl<S, ResBody> Service<Request<BoxBody>> for Svc<S>
where
    S: Service<Request<BoxBody>, Response = Response<ResBody>>,
    S::Error: Into<crate::Error>,
    ResBody: http_body::Body<Data = Bytes> + Send + 'static,
    ResBody::Error: Into<crate::Error>,
{
    type Response = Response<BoxBody>;
    type Error = crate::Error;
    type Future = SvcFuture<S::Future>;

    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
        self.inner.poll_ready(cx).map_err(Into::into)
    }

    fn call(&mut self, mut req: Request<BoxBody>) -> Self::Future {
        let span = if let Some(trace_interceptor) = &self.trace_interceptor {
            let (parts, body) = req.into_parts();
            let bodyless_request = Request::from_parts(parts, ());

            let span = trace_interceptor(&bodyless_request);

            let (parts, _) = bodyless_request.into_parts();
            req = Request::from_parts(parts, body);

            span
        } else {
            tracing::Span::none()
        };

        SvcFuture {
            inner: self.inner.call(req),
            span,
        }
    }
}

#[pin_project]
struct SvcFuture<F> {
    #[pin]
    inner: F,
    span: tracing::Span,
}

impl<F, E, ResBody> Future for SvcFuture<F>
where
    F: Future<Output = Result<Response<ResBody>, E>>,
    E: Into<crate::Error>,
    ResBody: http_body::Body<Data = Bytes> + Send + 'static,
    ResBody::Error: Into<crate::Error>,
{
    type Output = Result<Response<BoxBody>, crate::Error>;

    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        let this = self.project();
        let _guard = this.span.enter();

        let response: Response<ResBody> = ready!(this.inner.poll(cx)).map_err(Into::into)?;
        let response = response.map(|body| boxed(body.map_err(Into::into)));
        Poll::Ready(Ok(response))
    }
}

impl<S> fmt::Debug for Svc<S> {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.debug_struct("Svc").finish()
    }
}

#[derive(Clone)]
struct MakeSvc<S, IO> {
    concurrency_limit: Option<usize>,
    timeout: Option<Duration>,
    inner: S,
    trace_interceptor: Option<TraceInterceptor>,
    _io: PhantomData<fn() -> IO>,
}

impl<S, ResBody, IO> Service<&ServerIo<IO>> for MakeSvc<S, IO>
where
    IO: Connected,
    S: Service<Request<BoxBody>, Response = Response<ResBody>> + Clone + Send + 'static,
    S::Future: Send + 'static,
    S::Error: Into<crate::Error> + Send,
    ResBody: http_body::Body<Data = Bytes> + Send + 'static,
    ResBody::Error: Into<crate::Error>,
{
    type Response = BoxService;
    type Error = crate::Error;
    type Future = future::Ready<Result<Self::Response, Self::Error>>;

    fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
        Ok(()).into()
    }

    fn call(&mut self, io: &ServerIo<IO>) -> Self::Future {
        let conn_info = io.connect_info();

        let svc = self.inner.clone();
        let concurrency_limit = self.concurrency_limit;
        let timeout = self.timeout;
        let trace_interceptor = self.trace_interceptor.clone();

        let svc = ServiceBuilder::new()
            .layer_fn(RecoverError::new)
            .option_layer(concurrency_limit.map(ConcurrencyLimitLayer::new))
            .layer_fn(|s| GrpcTimeout::new(s, timeout))
            .service(svc);

        let svc = ServiceBuilder::new()
            .layer(BoxCloneService::layer())
            .map_request(move |mut request: Request<BoxBody>| {
                match &conn_info {
                    tower::util::Either::A(inner) => {
                        request.extensions_mut().insert(inner.clone());
                    }
                    tower::util::Either::B(inner) => {
                        #[cfg(feature = "tls")]
                        {
                            request.extensions_mut().insert(inner.clone());
                            request.extensions_mut().insert(inner.get_ref().clone());
                        }

                        #[cfg(not(feature = "tls"))]
                        {
                            // just a type check to make sure we didn't forget to
                            // insert this into the extensions
                            let _: &() = inner;
                        }
                    }
                }

                request
            })
            .service(Svc {
                inner: svc,
                trace_interceptor,
            });

        future::ready(Ok(svc))
    }
}

// Borrowed from the `futures-util` crate, since this is the only piece of it that tonic requires.
// LICENSE: MIT or Apache-2.0
// A future which only yields `Poll::Ready` once, and thereafter yields `Poll::Pending`.
#[pin_project]
struct Fuse<F> {
    #[pin]
    inner: Option<F>,
}

impl<F> Future for Fuse<F>
where
    F: Future,
{
    type Output = F::Output;

    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        match self.as_mut().project().inner.as_pin_mut() {
            Some(fut) => fut.poll(cx).map(|output| {
                self.project().inner.set(None);
                output
            }),
            None => Poll::Pending,
        }
    }
}