// tonic/transport/server/incoming.rs

1use std::{
2    io,
3    net::{SocketAddr, TcpListener as StdTcpListener},
4    ops::ControlFlow,
5    pin::{pin, Pin},
6    task::{ready, Context, Poll},
7    time::Duration,
8};
9
10use tokio::{
11    io::{AsyncRead, AsyncWrite},
12    net::{TcpListener, TcpStream},
13};
14use tokio_stream::wrappers::TcpListenerStream;
15use tokio_stream::{Stream, StreamExt};
16use tracing::warn;
17
18use super::service::ServerIo;
19#[cfg(feature = "tls")]
20use super::service::TlsAcceptor;
21
22#[cfg(not(feature = "tls"))]
23pub(crate) fn tcp_incoming<IO, IE>(
24    incoming: impl Stream<Item = Result<IO, IE>>,
25) -> impl Stream<Item = Result<ServerIo<IO>, crate::Error>>
26where
27    IO: AsyncRead + AsyncWrite + Unpin + Send + 'static,
28    IE: Into<crate::Error>,
29{
30    async_stream::try_stream! {
31        let mut incoming = pin!(incoming);
32
33        while let Some(item) = incoming.next().await {
34            yield match item {
35                Ok(_) => item.map(ServerIo::new_io)?,
36                Err(e) => match handle_accept_error(e) {
37                    ControlFlow::Continue(()) => continue,
38                    ControlFlow::Break(e) => Err(e)?,
39                }
40            }
41        }
42    }
43}
44
45#[cfg(feature = "tls")]
46pub(crate) fn tcp_incoming<IO, IE>(
47    incoming: impl Stream<Item = Result<IO, IE>>,
48    tls: Option<TlsAcceptor>,
49) -> impl Stream<Item = Result<ServerIo<IO>, crate::Error>>
50where
51    IO: AsyncRead + AsyncWrite + Unpin + Send + 'static,
52    IE: Into<crate::Error>,
53{
54    async_stream::try_stream! {
55        let mut incoming = pin!(incoming);
56
57        let mut tasks = tokio::task::JoinSet::new();
58
59        loop {
60            match select(&mut incoming, &mut tasks).await {
61                SelectOutput::Incoming(stream) => {
62                    if let Some(tls) = &tls {
63                        let tls = tls.clone();
64                        tasks.spawn(async move {
65                            let io = tls.accept(stream).await?;
66                            Ok(ServerIo::new_tls_io(io))
67                        });
68                    } else {
69                        yield ServerIo::new_io(stream);
70                    }
71                }
72
73                SelectOutput::Io(io) => {
74                    yield io;
75                }
76
77                SelectOutput::Err(e) => match handle_accept_error(e) {
78                    ControlFlow::Continue(()) => continue,
79                    ControlFlow::Break(e) => Err(e)?,
80                }
81
82                SelectOutput::Done => {
83                    break;
84                }
85            }
86        }
87    }
88}
89
90fn handle_accept_error(e: impl Into<crate::Error>) -> ControlFlow<crate::Error> {
91    let e = e.into();
92    tracing::debug!(error = %e, "accept loop error");
93    if let Some(e) = e.downcast_ref::<io::Error>() {
94        if matches!(
95            e.kind(),
96            io::ErrorKind::ConnectionAborted
97                | io::ErrorKind::ConnectionReset
98                | io::ErrorKind::BrokenPipe
99                | io::ErrorKind::Interrupted
100                | io::ErrorKind::InvalidData // Raised if TLS handshake failed
101                | io::ErrorKind::UnexpectedEof // Raised if TLS handshake failed
102                | io::ErrorKind::WouldBlock
103        ) {
104            return ControlFlow::Continue(());
105        }
106    }
107
108    ControlFlow::Break(e)
109}
110
111#[cfg(feature = "tls")]
112async fn select<IO: 'static, IE>(
113    incoming: &mut (impl Stream<Item = Result<IO, IE>> + Unpin),
114    tasks: &mut tokio::task::JoinSet<Result<ServerIo<IO>, crate::Error>>,
115) -> SelectOutput<IO>
116where
117    IE: Into<crate::Error>,
118{
119    if tasks.is_empty() {
120        return match incoming.try_next().await {
121            Ok(Some(stream)) => SelectOutput::Incoming(stream),
122            Ok(None) => SelectOutput::Done,
123            Err(e) => SelectOutput::Err(e.into()),
124        };
125    }
126
127    tokio::select! {
128        stream = incoming.try_next() => {
129            match stream {
130                Ok(Some(stream)) => SelectOutput::Incoming(stream),
131                Ok(None) => SelectOutput::Done,
132                Err(e) => SelectOutput::Err(e.into()),
133            }
134        }
135
136        accept = tasks.join_next() => {
137            match accept.expect("JoinSet should never end") {
138                Ok(Ok(io)) => SelectOutput::Io(io),
139                Ok(Err(e)) => SelectOutput::Err(e),
140                Err(e) => SelectOutput::Err(e.into()),
141            }
142        }
143    }
144}
145
146#[cfg(feature = "tls")]
147enum SelectOutput<A> {
148    Incoming(A),
149    Io(ServerIo<A>),
150    Err(crate::Error),
151    Done,
152}
153
/// Binds a socket address for a [Router](super::Router).
///
/// This is an incoming stream, usable with
/// [Router::serve_with_incoming](super::Router::serve_with_incoming). It yields the TCP
/// connections accepted on the bound address as `Result<TcpStream, std::io::Error>` items.
/// Each connection has the configured `nodelay` and `keepalive` options applied before it is
/// yielded.
#[derive(Debug)]
pub struct TcpIncoming {
    /// The underlying listener, viewed as a stream of accepted connections.
    inner: TcpListenerStream,
    /// Whether `TCP_NODELAY` is enabled on each accepted connection.
    nodelay: bool,
    /// TCP keepalive idle time to set on each accepted connection; `None` leaves it unset.
    keepalive: Option<Duration>,
}
164
165impl TcpIncoming {
166    /// Creates an instance by binding (opening) the specified socket address
167    /// to which the specified TCP 'nodelay' and 'keepalive' parameters are applied.
168    /// Returns a TcpIncoming if the socket address was successfully bound.
169    ///
170    /// # Examples
171    /// ```no_run
172    /// # use tower_service::Service;
173    /// # use http::{request::Request, response::Response};
174    /// # use tonic::{body::BoxBody, server::NamedService, transport::{Server, server::TcpIncoming}};
175    /// # use core::convert::Infallible;
176    /// # use std::error::Error;
177    /// # fn main() { }  // Cannot have type parameters, hence instead define:
178    /// # fn run<S>(some_service: S) -> Result<(), Box<dyn Error + Send + Sync>>
179    /// # where
180    /// #   S: Service<Request<BoxBody>, Response = Response<BoxBody>, Error = Infallible> + NamedService + Clone + Send + 'static,
181    /// #   S::Future: Send + 'static,
182    /// # {
183    /// // Find a free port
184    /// let mut port = 1322;
185    /// let tinc = loop {
186    ///    let addr = format!("127.0.0.1:{}", port).parse().unwrap();
187    ///    match TcpIncoming::new(addr, true, None) {
188    ///       Ok(t) => break t,
189    ///       Err(_) => port += 1
190    ///    }
191    /// };
192    /// Server::builder()
193    ///    .add_service(some_service)
194    ///    .serve_with_incoming(tinc);
195    /// # Ok(())
196    /// # }
197    pub fn new(
198        addr: SocketAddr,
199        nodelay: bool,
200        keepalive: Option<Duration>,
201    ) -> Result<Self, crate::Error> {
202        let std_listener = StdTcpListener::bind(addr)?;
203        std_listener.set_nonblocking(true)?;
204
205        let inner = TcpListenerStream::new(TcpListener::from_std(std_listener)?);
206        Ok(Self {
207            inner,
208            nodelay,
209            keepalive,
210        })
211    }
212
213    /// Creates a new `TcpIncoming` from an existing `tokio::net::TcpListener`.
214    pub fn from_listener(
215        listener: TcpListener,
216        nodelay: bool,
217        keepalive: Option<Duration>,
218    ) -> Result<Self, crate::Error> {
219        Ok(Self {
220            inner: TcpListenerStream::new(listener),
221            nodelay,
222            keepalive,
223        })
224    }
225}
226
227impl Stream for TcpIncoming {
228    type Item = Result<TcpStream, std::io::Error>;
229
230    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
231        match ready!(Pin::new(&mut self.inner).poll_next(cx)) {
232            Some(Ok(stream)) => {
233                set_accepted_socket_options(&stream, self.nodelay, self.keepalive);
234                Some(Ok(stream)).into()
235            }
236            other => Poll::Ready(other),
237        }
238    }
239}
240
241// Consistent with hyper-0.14, this function does not return an error.
242fn set_accepted_socket_options(stream: &TcpStream, nodelay: bool, keepalive: Option<Duration>) {
243    if nodelay {
244        if let Err(e) = stream.set_nodelay(true) {
245            warn!("error trying to set TCP nodelay: {}", e);
246        }
247    }
248
249    if let Some(timeout) = keepalive {
250        let sock_ref = socket2::SockRef::from(&stream);
251        let sock_keepalive = socket2::TcpKeepalive::new().with_time(timeout);
252
253        if let Err(e) = sock_ref.set_tcp_keepalive(&sock_keepalive) {
254            warn!("error trying to set TCP keepalive: {}", e);
255        }
256    }
257}
258
#[cfg(test)]
mod tests {
    use crate::transport::server::TcpIncoming;
    use std::net::{SocketAddr, TcpListener as StdTcpListener};

    /// Returns a loopback address whose port the OS reported as free. Using it instead of a
    /// hard-coded port keeps the test from failing when that port is already taken on the host.
    fn free_local_addr() -> SocketAddr {
        let probe = StdTcpListener::bind("127.0.0.1:0").expect("bind ephemeral port");
        // `probe` is dropped on return, releasing the port for the test to bind.
        probe.local_addr().expect("read ephemeral port")
    }

    #[tokio::test]
    async fn one_tcpincoming_at_a_time() {
        let addr = free_local_addr();
        {
            let _t1 = TcpIncoming::new(addr, true, None).unwrap();
            // Binding the same address while `_t1` still holds it must fail.
            let _t2 = TcpIncoming::new(addr, true, None).unwrap_err();
        }
        // After `_t1` is dropped, the address can be bound again.
        let _t3 = TcpIncoming::new(addr, true, None).unwrap();
    }
}