1use std::{
2 io,
3 net::{SocketAddr, TcpListener as StdTcpListener},
4 ops::ControlFlow,
5 pin::{pin, Pin},
6 task::{ready, Context, Poll},
7 time::Duration,
8};
9
10use tokio::{
11 io::{AsyncRead, AsyncWrite},
12 net::{TcpListener, TcpStream},
13};
14use tokio_stream::wrappers::TcpListenerStream;
15use tokio_stream::{Stream, StreamExt};
16use tracing::warn;
17
18use super::service::ServerIo;
19#[cfg(feature = "tls")]
20use super::service::TlsAcceptor;
21
22#[cfg(not(feature = "tls"))]
23pub(crate) fn tcp_incoming<IO, IE>(
24 incoming: impl Stream<Item = Result<IO, IE>>,
25) -> impl Stream<Item = Result<ServerIo<IO>, crate::Error>>
26where
27 IO: AsyncRead + AsyncWrite + Unpin + Send + 'static,
28 IE: Into<crate::Error>,
29{
30 async_stream::try_stream! {
31 let mut incoming = pin!(incoming);
32
33 while let Some(item) = incoming.next().await {
34 yield match item {
35 Ok(_) => item.map(ServerIo::new_io)?,
36 Err(e) => match handle_accept_error(e) {
37 ControlFlow::Continue(()) => continue,
38 ControlFlow::Break(e) => Err(e)?,
39 }
40 }
41 }
42 }
43}
44
/// Turns a stream of accepted connections into a stream of [`ServerIo`]s.
///
/// When `tls` is `Some`, each connection's TLS handshake runs as its own task in a
/// [`tokio::task::JoinSet`], so a slow handshake does not block accepting further
/// connections. Finished handshakes are yielded as they complete, interleaved with new
/// connections. When `tls` is `None`, connections are wrapped and yielded immediately.
///
/// Errors from the incoming stream, from a handshake, or from a handshake task are passed
/// to `handle_accept_error`. Recoverable ones are skipped; any other error is yielded and
/// ends the stream.
///
/// When the incoming stream ends, handshakes that are still in flight are awaited and their
/// connections yielded before this stream ends, rather than being aborted.
#[cfg(feature = "tls")]
pub(crate) fn tcp_incoming<IO, IE>(
    incoming: impl Stream<Item = Result<IO, IE>>,
    tls: Option<TlsAcceptor>,
) -> impl Stream<Item = Result<ServerIo<IO>, crate::Error>>
where
    IO: AsyncRead + AsyncWrite + Unpin + Send + 'static,
    IE: Into<crate::Error>,
{
    async_stream::try_stream! {
        let mut incoming = pin!(incoming);

        let mut tasks = tokio::task::JoinSet::new();

        loop {
            match select(&mut incoming, &mut tasks).await {
                SelectOutput::Incoming(stream) => {
                    if let Some(tls) = &tls {
                        let tls = tls.clone();
                        tasks.spawn(async move {
                            let io = tls.accept(stream).await?;
                            Ok(ServerIo::new_tls_io(io))
                        });
                    } else {
                        yield ServerIo::new_io(stream);
                    }
                }

                SelectOutput::Io(io) => {
                    yield io;
                }

                SelectOutput::Err(e) => match handle_accept_error(e) {
                    ControlFlow::Continue(()) => continue,
                    ControlFlow::Break(e) => Err(e)?,
                }

                SelectOutput::Done => {
                    // The listener is exhausted, but handshakes may still be in flight.
                    // Dropping `tasks` would abort them and silently lose those
                    // connections. With a finite incoming stream, `select` almost always
                    // reports `Done` before the last handshake finishes. So drain the set
                    // here first.
                    // `select` must not be called again: the incoming stream has already
                    // returned `None` and is not guaranteed to be fused.
                    while let Some(joined) = tasks.join_next().await {
                        let e: crate::Error = match joined {
                            Ok(Ok(io)) => {
                                yield io;
                                continue;
                            }
                            Ok(Err(e)) => e,
                            Err(e) => e.into(),
                        };
                        match handle_accept_error(e) {
                            ControlFlow::Continue(()) => {}
                            ControlFlow::Break(e) => Err(e)?,
                        }
                    }
                    break;
                }
            }
        }
    }
}
89
90fn handle_accept_error(e: impl Into<crate::Error>) -> ControlFlow<crate::Error> {
91 let e = e.into();
92 tracing::debug!(error = %e, "accept loop error");
93 if let Some(e) = e.downcast_ref::<io::Error>() {
94 if matches!(
95 e.kind(),
96 io::ErrorKind::ConnectionAborted
97 | io::ErrorKind::ConnectionReset
98 | io::ErrorKind::BrokenPipe
99 | io::ErrorKind::Interrupted
100 | io::ErrorKind::InvalidData | io::ErrorKind::UnexpectedEof | io::ErrorKind::WouldBlock
103 ) {
104 return ControlFlow::Continue(());
105 }
106 }
107
108 ControlFlow::Break(e)
109}
110
/// Races the next item from `incoming` against the next finished handshake task in
/// `tasks`. It returns whichever completes first as a [`SelectOutput`].
///
/// When `tasks` is empty, only `incoming` is awaited. `JoinSet::join_next` returns `None`
/// for an empty set, which would trip the `expect` below. Task join errors (the handshake
/// task panicked or was cancelled) are reported as `SelectOutput::Err`.
#[cfg(feature = "tls")]
async fn select<IO: 'static, IE>(
    incoming: &mut (impl Stream<Item = Result<IO, IE>> + Unpin),
    tasks: &mut tokio::task::JoinSet<Result<ServerIo<IO>, crate::Error>>,
) -> SelectOutput<IO>
where
    IE: Into<crate::Error>,
{
    // One definition of how an incoming-stream result maps to a `SelectOutput`, shared by
    // both paths below so they cannot drift apart.
    let next_incoming = async {
        match incoming.try_next().await {
            Ok(Some(stream)) => SelectOutput::Incoming(stream),
            Ok(None) => SelectOutput::Done,
            Err(e) => SelectOutput::Err(e.into()),
        }
    };

    if tasks.is_empty() {
        return next_incoming.await;
    }

    tokio::select! {
        output = next_incoming => output,

        accept = tasks.join_next() => {
            match accept.expect("JoinSet should never end") {
                Ok(Ok(io)) => SelectOutput::Io(io),
                Ok(Err(e)) => SelectOutput::Err(e),
                Err(e) => SelectOutput::Err(e.into()),
            }
        }
    }
}
145
/// Outcome of one `select` step in the TLS accept loop.
#[cfg(feature = "tls")]
enum SelectOutput<A> {
    /// A newly accepted connection from the incoming stream, not yet wrapped.
    Incoming(A),
    /// A connection whose handshake task completed successfully.
    Io(ServerIo<A>),
    /// An error from the incoming stream, from a handshake, or from joining a handshake task.
    Err(crate::Error),
    /// The incoming stream returned `None`.
    Done,
}
153
/// A [`Stream`] of [`TcpStream`]s accepted from a bound TCP listener.
///
/// Before each accepted connection is yielded, `TCP_NODELAY` and TCP keepalive are applied
/// according to the `nodelay` and `keepalive` settings. A failure to apply either option is
/// logged as a warning and does not reject the connection.
#[derive(Debug)]
pub struct TcpIncoming {
    // The Tokio listener, adapted into a stream of accepted connections.
    inner: TcpListenerStream,
    // Whether to enable `TCP_NODELAY` on each accepted connection.
    nodelay: bool,
    // When `Some`, TCP keepalive is enabled on each accepted connection with this idle time
    // (passed to `socket2::TcpKeepalive::with_time`).
    keepalive: Option<Duration>,
}
164
165impl TcpIncoming {
166 pub fn new(
198 addr: SocketAddr,
199 nodelay: bool,
200 keepalive: Option<Duration>,
201 ) -> Result<Self, crate::Error> {
202 let std_listener = StdTcpListener::bind(addr)?;
203 std_listener.set_nonblocking(true)?;
204
205 let inner = TcpListenerStream::new(TcpListener::from_std(std_listener)?);
206 Ok(Self {
207 inner,
208 nodelay,
209 keepalive,
210 })
211 }
212
213 pub fn from_listener(
215 listener: TcpListener,
216 nodelay: bool,
217 keepalive: Option<Duration>,
218 ) -> Result<Self, crate::Error> {
219 Ok(Self {
220 inner: TcpListenerStream::new(listener),
221 nodelay,
222 keepalive,
223 })
224 }
225}
226
227impl Stream for TcpIncoming {
228 type Item = Result<TcpStream, std::io::Error>;
229
230 fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
231 match ready!(Pin::new(&mut self.inner).poll_next(cx)) {
232 Some(Ok(stream)) => {
233 set_accepted_socket_options(&stream, self.nodelay, self.keepalive);
234 Some(Ok(stream)).into()
235 }
236 other => Poll::Ready(other),
237 }
238 }
239}
240
241fn set_accepted_socket_options(stream: &TcpStream, nodelay: bool, keepalive: Option<Duration>) {
243 if nodelay {
244 if let Err(e) = stream.set_nodelay(true) {
245 warn!("error trying to set TCP nodelay: {}", e);
246 }
247 }
248
249 if let Some(timeout) = keepalive {
250 let sock_ref = socket2::SockRef::from(&stream);
251 let sock_keepalive = socket2::TcpKeepalive::new().with_time(timeout);
252
253 if let Err(e) = sock_ref.set_tcp_keepalive(&sock_keepalive) {
254 warn!("error trying to set TCP keepalive: {}", e);
255 }
256 }
257}
258
#[cfg(test)]
mod tests {
    use std::net::SocketAddr;

    use crate::transport::server::TcpIncoming;

    /// A second `TcpIncoming` on an address that is already bound must fail. Once the first
    /// listener is dropped, the address must be bindable again.
    #[tokio::test]
    async fn one_tcpincoming_at_a_time() {
        // Ask the OS for a free port rather than hard-coding one. With a fixed port, the
        // test fails whenever an unrelated process or a parallel run already holds it.
        let any_port: SocketAddr = "127.0.0.1:0".parse().unwrap();
        let addr = {
            let t1 = TcpIncoming::new(any_port, true, None).unwrap();
            let listener: &tokio::net::TcpListener = t1.inner.as_ref();
            let addr = listener.local_addr().unwrap();
            let _t2 = TcpIncoming::new(addr, true, None).unwrap_err();
            addr
        };
        let _t3 = TcpIncoming::new(addr, true, None).unwrap();
    }
}