mz_ore/netio/
socket.rs

1// Copyright Materialize, Inc. and contributors. All rights reserved.
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License in the LICENSE file at the
6// root of this repository, or online at
7//
8//     http://www.apache.org/licenses/LICENSE-2.0
9//
10// Unless required by applicable law or agreed to in writing, software
11// distributed under the License is distributed on an "AS IS" BASIS,
12// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13// See the License for the specific language governing permissions and
14// limitations under the License.
15
16use std::error::Error;
17use std::net::SocketAddr as InetSocketAddr;
18use std::pin::Pin;
19use std::str::FromStr;
20use std::task::{Context, Poll, ready};
21use std::{fmt, io};
22
23use async_trait::async_trait;
24use tokio::fs;
25use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
26use tokio::net::{self, TcpListener, TcpStream, UnixListener, UnixStream};
27use tonic::transport::server::{Connected, TcpConnectInfo, UdsConnectInfo};
28use tracing::warn;
29
30use crate::error::ErrorExt;
31
/// Discriminates the flavors of address that a [`SocketAddr`] can hold.
#[derive(Clone, Copy, Debug)]
pub enum SocketAddrType {
    /// An internet socket address.
    Inet,
    /// A Unix domain socket address.
    Unix,
    /// A `turmoil` socket address.
    Turmoil,
}
42
43impl SocketAddrType {
44    /// Guesses the type of socket address specified by `s`.
45    ///
46    /// * Socket addresses that are absolute paths, as determined by a leading `/` character, are
47    ///   determined to be Unix socket addresses.
48    /// * Addresses with a "turmoil:" prefix are determined to be `turmoil` socket addresses.
49    /// * All other addresses are assumed to be internet socket addresses.
50    pub fn guess(s: &str) -> SocketAddrType {
51        if s.starts_with('/') {
52            SocketAddrType::Unix
53        } else if s.starts_with("turmoil:") {
54            SocketAddrType::Turmoil
55        } else {
56            SocketAddrType::Inet
57        }
58    }
59}
60
61/// An address associated with an internet or Unix domain socket.
62#[derive(Debug, Clone)]
63pub enum SocketAddr {
64    /// An internet socket address.
65    Inet(InetSocketAddr),
66    /// A Unix domain socket address.
67    Unix(UnixSocketAddr),
68    /// A `turmoil` socket address.
69    Turmoil(String),
70}
71
72impl PartialEq for SocketAddr {
73    fn eq(&self, other: &Self) -> bool {
74        match (self, other) {
75            (SocketAddr::Inet(addr1), SocketAddr::Inet(addr2)) => addr1 == addr2,
76            (
77                SocketAddr::Unix(UnixSocketAddr { path: Some(path1) }),
78                SocketAddr::Unix(UnixSocketAddr { path: Some(path2) }),
79            ) => path1 == path2,
80            (SocketAddr::Turmoil(addr1), SocketAddr::Turmoil(addr2)) => addr1 == addr2,
81            _ => false,
82        }
83    }
84}
85
86impl fmt::Display for SocketAddr {
87    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
88        match &self {
89            SocketAddr::Inet(addr) => addr.fmt(f),
90            SocketAddr::Unix(addr) => addr.fmt(f),
91            SocketAddr::Turmoil(addr) => write!(f, "turmoil:{addr}"),
92        }
93    }
94}
95
96impl FromStr for SocketAddr {
97    type Err = AddrParseError;
98
99    /// Parses a socket address from a string.
100    ///
101    /// Whether a socket address is taken as an internet socket address or a
102    /// Unix socket address is determined by [`SocketAddrType::guess`].
103    fn from_str(s: &str) -> Result<SocketAddr, AddrParseError> {
104        match SocketAddrType::guess(s) {
105            SocketAddrType::Unix => {
106                let addr = UnixSocketAddr::from_pathname(s).map_err(|e| AddrParseError {
107                    kind: AddrParseErrorKind::Unix(e),
108                })?;
109                Ok(SocketAddr::Unix(addr))
110            }
111            SocketAddrType::Inet => {
112                let addr = s.parse().map_err(|_| AddrParseError {
113                    // The underlying error message is always "invalid socket
114                    // address syntax", so there's no benefit to preserving it.
115                    kind: AddrParseErrorKind::Inet,
116                })?;
117                Ok(SocketAddr::Inet(addr))
118            }
119            SocketAddrType::Turmoil => {
120                let addr = s.strip_prefix("turmoil:").ok_or(AddrParseError {
121                    kind: AddrParseErrorKind::Turmoil,
122                })?;
123                Ok(SocketAddr::Turmoil(addr.into()))
124            }
125        }
126    }
127}
128
/// An address associated with a Unix domain socket.
///
/// The address is either a filesystem path or unnamed.
#[derive(Clone, Debug)]
pub struct UnixSocketAddr {
    /// The socket's path, or `None` for an unnamed socket.
    path: Option<String>,
}
134
135impl UnixSocketAddr {
136    /// Constructs a Unix domain socket address from the provided path.
137    ///
138    /// Unlike the [`UnixSocketAddr::from_pathname`] method in the standard
139    /// library, `path` is required to be valid UTF-8.
140    ///
141    /// # Errors
142    ///
143    /// Returns an error if the path is longer than `SUN_LEN` or if it contains
144    /// null bytes.
145    ///
146    /// [`UnixSocketAddr::from_pathname`]: std::os::unix::net::SocketAddr::from_pathname
147    pub fn from_pathname<S>(path: S) -> Result<UnixSocketAddr, io::Error>
148    where
149        S: Into<String>,
150    {
151        let path = path.into();
152        let _ = std::os::unix::net::SocketAddr::from_pathname(&path)?;
153        Ok(UnixSocketAddr { path: Some(path) })
154    }
155
156    /// Constructs a Unix domain socket address representing an unnamed Unix
157    /// socket.
158    pub fn unnamed() -> UnixSocketAddr {
159        UnixSocketAddr { path: None }
160    }
161
162    /// Returns the pathname of this Unix domain socket address, if it was
163    /// constructed from a pathname.
164    pub fn as_pathname(&self) -> Option<&str> {
165        self.path.as_deref()
166    }
167}
168
169impl fmt::Display for UnixSocketAddr {
170    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
171        match &self.path {
172            None => f.write_str("<unnamed>"),
173            Some(path) => f.write_str(path),
174        }
175    }
176}
177
/// The error returned when parsing a [`SocketAddr`] from a string.
///
/// The specific failure is recorded in a private [`AddrParseErrorKind`] and
/// is surfaced through this type's `Display` implementation.
#[derive(Debug)]
pub struct AddrParseError {
    // Which kind of address failed to parse, plus any underlying cause.
    kind: AddrParseErrorKind,
}
183
/// The reason an [`AddrParseError`] was produced.
#[derive(Debug)]
pub enum AddrParseErrorKind {
    /// The string was treated as an internet socket address but did not parse
    /// as one.
    Inet,
    /// The string was treated as a Unix socket path but was rejected; the
    /// wrapped error is the one reported by [`UnixSocketAddr::from_pathname`].
    Unix(io::Error),
    /// The string was treated as a `turmoil` address but lacked the
    /// `turmoil:` prefix.
    Turmoil,
}
190
191impl fmt::Display for AddrParseError {
192    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
193        match &self.kind {
194            AddrParseErrorKind::Inet => f.write_str("invalid internet socket address syntax"),
195            AddrParseErrorKind::Unix(e) => {
196                f.write_str("invalid unix socket address syntax: ")?;
197                e.fmt(f)
198            }
199            AddrParseErrorKind::Turmoil => f.write_str("missing 'turmoil:' prefix"),
200        }
201    }
202}
203
// `source` is intentionally left at its default: the `Display` implementation
// already appends the underlying I/O error for Unix socket failures.
impl Error for AddrParseError {}
205
/// Converts or resolves without blocking to one or more [`SocketAddr`]s.
///
/// Implementations in this module cover already-constructed addresses, slices
/// of addresses, and strings. A string that does not parse as a
/// [`SocketAddr`] is resolved with `tokio::net::lookup_host`.
#[async_trait]
pub trait ToSocketAddrs {
    /// Converts to resolved [`SocketAddr`]s.
    ///
    /// # Errors
    ///
    /// Returns an I/O error if resolution fails.
    async fn to_socket_addrs(&self) -> Result<Vec<SocketAddr>, io::Error>;
}
212
213#[async_trait]
214impl ToSocketAddrs for SocketAddr {
215    async fn to_socket_addrs(&self) -> Result<Vec<SocketAddr>, io::Error> {
216        Ok(vec![self.clone()])
217    }
218}
219
220#[async_trait]
221impl<'a> ToSocketAddrs for &'a [SocketAddr] {
222    async fn to_socket_addrs(&self) -> Result<Vec<SocketAddr>, io::Error> {
223        Ok(self.to_vec())
224    }
225}
226
227#[async_trait]
228impl ToSocketAddrs for InetSocketAddr {
229    async fn to_socket_addrs(&self) -> Result<Vec<SocketAddr>, io::Error> {
230        Ok(vec![SocketAddr::Inet(*self)])
231    }
232}
233
234#[async_trait]
235impl ToSocketAddrs for UnixSocketAddr {
236    async fn to_socket_addrs(&self) -> Result<Vec<SocketAddr>, io::Error> {
237        Ok(vec![SocketAddr::Unix(self.clone())])
238    }
239}
240
241#[async_trait]
242impl ToSocketAddrs for str {
243    async fn to_socket_addrs(&self) -> Result<Vec<SocketAddr>, io::Error> {
244        match self.parse() {
245            Ok(addr) => Ok(vec![addr]),
246            Err(_) => {
247                let addrs = net::lookup_host(self).await?;
248                Ok(addrs.map(SocketAddr::Inet).collect())
249            }
250        }
251    }
252}
253
254#[async_trait]
255impl ToSocketAddrs for String {
256    async fn to_socket_addrs(&self) -> Result<Vec<SocketAddr>, io::Error> {
257        (**self).to_socket_addrs().await
258    }
259}
260
261#[async_trait]
262impl<T> ToSocketAddrs for &T
263where
264    T: ToSocketAddrs + Send + Sync + ?Sized,
265{
266    async fn to_socket_addrs(&self) -> Result<Vec<SocketAddr>, io::Error> {
267        (**self).to_socket_addrs().await
268    }
269}
270
/// A listener bound to a TCP socket, a Unix domain socket, or (when the
/// `turmoil` feature is enabled) a `turmoil` socket.
pub enum Listener {
    /// A TCP listener.
    Tcp(TcpListener),
    /// A Unix domain socket listener.
    Unix(UnixListener),
    /// A `turmoil` socket listener.
    #[cfg(feature = "turmoil")]
    Turmoil(turmoil::net::TcpListener),
}
281
282impl Listener {
283    /// Creates a new listener bound to the specified socket address.
284    ///
285    /// If `addr` is a Unix domain address, this function attempts to unlink the
286    /// socket at the address, if it exists, before binding.
287    pub async fn bind<A>(addr: A) -> Result<Listener, io::Error>
288    where
289        A: ToSocketAddrs,
290    {
291        let mut last_err = None;
292        for addr in addr.to_socket_addrs().await? {
293            match Listener::bind_addr(addr).await {
294                Ok(listener) => return Ok(listener),
295                Err(e) => last_err = Some(e),
296            }
297        }
298        Err(last_err.unwrap_or_else(|| {
299            io::Error::new(
300                io::ErrorKind::InvalidInput,
301                "could not resolve to any address",
302            )
303        }))
304    }
305
306    async fn bind_addr(addr: SocketAddr) -> Result<Listener, io::Error> {
307        match &addr {
308            SocketAddr::Inet(addr) => {
309                let listener = TcpListener::bind(addr).await?;
310                Ok(Listener::Tcp(listener))
311            }
312            SocketAddr::Unix(UnixSocketAddr { path: Some(path) }) => {
313                // We would ideally unlink the file only if we could prove that
314                // no process was still listening at the file, but there is no
315                // foolproof API for doing so.
316                // See: https://stackoverflow.com/q/7405932
317                if let Err(e) = fs::remove_file(path).await {
318                    if e.kind() != io::ErrorKind::NotFound {
319                        warn!(
320                            "unable to remove {path} while binding unix domain socket: {}",
321                            e.display_with_causes(),
322                        );
323                    }
324                }
325                let listener = UnixListener::bind(path)?;
326                Ok(Listener::Unix(listener))
327            }
328            SocketAddr::Unix(UnixSocketAddr { path: None }) => Err(io::Error::new(
329                io::ErrorKind::Other,
330                "cannot bind to unnamed Unix socket",
331            )),
332            #[cfg(feature = "turmoil")]
333            SocketAddr::Turmoil(addr) => {
334                let listener = turmoil::net::TcpListener::bind(addr).await?;
335                Ok(Listener::Turmoil(listener))
336            }
337            #[cfg(not(feature = "turmoil"))]
338            SocketAddr::Turmoil(_) => panic!("`turmoil` feature not enabled"),
339        }
340    }
341
342    /// Accepts a new incoming connection to this listener.
343    pub async fn accept(&self) -> Result<(Stream, SocketAddr), io::Error> {
344        match self {
345            Listener::Tcp(listener) => {
346                let (stream, addr) = listener.accept().await?;
347                stream.set_nodelay(true)?;
348                let stream = Stream::Tcp(stream);
349                let addr = SocketAddr::Inet(addr);
350                Ok((stream, addr))
351            }
352            Listener::Unix(listener) => {
353                let (stream, addr) = listener.accept().await?;
354                let stream = Stream::Unix(stream);
355                assert!(addr.is_unnamed());
356                let addr = SocketAddr::Unix(UnixSocketAddr::unnamed());
357                Ok((stream, addr))
358            }
359            #[cfg(feature = "turmoil")]
360            Listener::Turmoil(listener) => {
361                let (stream, addr) = listener.accept().await?;
362                let stream = Stream::Turmoil(stream);
363                let addr = SocketAddr::Inet(addr);
364                Ok((stream, addr))
365            }
366        }
367    }
368
369    fn poll_accept(
370        self: Pin<&mut Self>,
371        cx: &mut Context<'_>,
372    ) -> Poll<Option<Result<Stream, io::Error>>> {
373        match self.get_mut() {
374            Listener::Tcp(listener) => {
375                let (stream, _addr) = ready!(listener.poll_accept(cx))?;
376                stream.set_nodelay(true)?;
377                Poll::Ready(Some(Ok(Stream::Tcp(stream))))
378            }
379            Listener::Unix(listener) => {
380                let (stream, _addr) = ready!(listener.poll_accept(cx))?;
381                Poll::Ready(Some(Ok(Stream::Unix(stream))))
382            }
383            #[cfg(feature = "turmoil")]
384            Listener::Turmoil(_) => {
385                unimplemented!("`turmoil::net::TcpListener::poll_accept`");
386            }
387        }
388    }
389}
390
391impl fmt::Debug for Listener {
392    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
393        match self {
394            Self::Tcp(inner) => f.debug_tuple("Tcp").field(inner).finish(),
395            Self::Unix(inner) => f.debug_tuple("Unix").field(inner).finish(),
396            #[cfg(feature = "turmoil")]
397            Self::Turmoil(_) => f.debug_tuple("Turmoil").finish(),
398        }
399    }
400}
401
402impl futures::stream::Stream for Listener {
403    type Item = Result<Stream, io::Error>;
404
405    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
406        self.poll_accept(cx)
407    }
408}
409
/// A stream associated with a TCP socket, a Unix domain socket, or (when the
/// `turmoil` feature is enabled) a `turmoil` socket.
#[derive(Debug)]
pub enum Stream {
    /// A TCP stream.
    Tcp(TcpStream),
    /// A Unix domain socket stream.
    Unix(UnixStream),
    /// A `turmoil` socket stream.
    #[cfg(feature = "turmoil")]
    Turmoil(turmoil::net::TcpStream),
}
421
422impl Stream {
423    /// Opens a connection to the specified socket address.
424    pub async fn connect<A>(addr: A) -> Result<Stream, io::Error>
425    where
426        A: ToSocketAddrs,
427    {
428        let mut last_err = None;
429        for addr in addr.to_socket_addrs().await? {
430            match Stream::connect_addr(addr).await {
431                Ok(stream) => return Ok(stream),
432                Err(e) => last_err = Some(e),
433            }
434        }
435        Err(last_err.unwrap_or_else(|| {
436            io::Error::new(
437                io::ErrorKind::InvalidInput,
438                "could not resolve to any address",
439            )
440        }))
441    }
442
443    async fn connect_addr(addr: SocketAddr) -> Result<Stream, io::Error> {
444        match addr {
445            SocketAddr::Inet(addr) => {
446                let stream = TcpStream::connect(addr).await?;
447                Ok(Stream::Tcp(stream))
448            }
449            SocketAddr::Unix(UnixSocketAddr { path: Some(path) }) => {
450                let stream = UnixStream::connect(path).await?;
451                Ok(Stream::Unix(stream))
452            }
453            SocketAddr::Unix(UnixSocketAddr { path: None }) => Err(io::Error::new(
454                io::ErrorKind::Other,
455                "cannot connected to unnamed Unix socket",
456            )),
457            #[cfg(feature = "turmoil")]
458            SocketAddr::Turmoil(addr) => {
459                let stream = turmoil::net::TcpStream::connect(addr).await?;
460                Ok(Stream::Turmoil(stream))
461            }
462            #[cfg(not(feature = "turmoil"))]
463            SocketAddr::Turmoil(_) => panic!("`turmoil` feature not enabled"),
464        }
465    }
466
467    /// Reports whether the underlying stream is a TCP stream.
468    pub fn is_tcp(&self) -> bool {
469        matches!(self, Stream::Tcp(_))
470    }
471
472    /// Reports whether the underlying stream is a Unix stream.
473    pub fn is_unix(&self) -> bool {
474        matches!(self, Stream::Unix(_))
475    }
476
477    /// Returns the underlying TCP stream.
478    ///
479    /// # Panics
480    ///
481    /// Panics if the stream is not a Unix stream.
482    pub fn unwrap_tcp(self) -> TcpStream {
483        match self {
484            Stream::Tcp(stream) => stream,
485            Stream::Unix(_) => panic!("Stream::unwrap_tcp called on a Unix stream"),
486            #[cfg(feature = "turmoil")]
487            Stream::Turmoil(_) => panic!("Stream::unwrap_tcp called on a `turmoil` stream"),
488        }
489    }
490
491    /// Returns the underlying Unix stream.
492    ///
493    /// # Panics
494    ///
495    /// Panics if the stream is not a Unix stream.
496    pub fn unwrap_unix(self) -> UnixStream {
497        match self {
498            Stream::Tcp(_) => panic!("Stream::unwrap_unix called on a TCP stream"),
499            Stream::Unix(stream) => stream,
500            #[cfg(feature = "turmoil")]
501            Stream::Turmoil(_) => panic!("Stream::unwrap_unix called on a `turmoil` stream"),
502        }
503    }
504}
505
506impl AsyncRead for Stream {
507    fn poll_read(
508        self: Pin<&mut Self>,
509        cx: &mut Context,
510        buf: &mut ReadBuf,
511    ) -> Poll<io::Result<()>> {
512        match self.get_mut() {
513            Stream::Tcp(stream) => Pin::new(stream).poll_read(cx, buf),
514            Stream::Unix(stream) => Pin::new(stream).poll_read(cx, buf),
515            #[cfg(feature = "turmoil")]
516            Stream::Turmoil(stream) => Pin::new(stream).poll_read(cx, buf),
517        }
518    }
519}
520
521impl AsyncWrite for Stream {
522    fn poll_write(self: Pin<&mut Self>, cx: &mut Context, buf: &[u8]) -> Poll<io::Result<usize>> {
523        match self.get_mut() {
524            Stream::Tcp(stream) => Pin::new(stream).poll_write(cx, buf),
525            Stream::Unix(stream) => Pin::new(stream).poll_write(cx, buf),
526            #[cfg(feature = "turmoil")]
527            Stream::Turmoil(stream) => Pin::new(stream).poll_write(cx, buf),
528        }
529    }
530
531    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context) -> Poll<io::Result<()>> {
532        match self.get_mut() {
533            Stream::Tcp(stream) => Pin::new(stream).poll_flush(cx),
534            Stream::Unix(stream) => Pin::new(stream).poll_flush(cx),
535            #[cfg(feature = "turmoil")]
536            Stream::Turmoil(stream) => Pin::new(stream).poll_flush(cx),
537        }
538    }
539
540    fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context) -> Poll<io::Result<()>> {
541        match self.get_mut() {
542            Stream::Tcp(stream) => Pin::new(stream).poll_shutdown(cx),
543            Stream::Unix(stream) => Pin::new(stream).poll_shutdown(cx),
544            #[cfg(feature = "turmoil")]
545            Stream::Turmoil(stream) => Pin::new(stream).poll_shutdown(cx),
546        }
547    }
548}
549
550impl Connected for Stream {
551    type ConnectInfo = ConnectInfo;
552
553    fn connect_info(&self) -> Self::ConnectInfo {
554        match self {
555            Stream::Tcp(stream) => ConnectInfo::Tcp(stream.connect_info()),
556            Stream::Unix(stream) => ConnectInfo::Unix(stream.connect_info()),
557            #[cfg(feature = "turmoil")]
558            Stream::Turmoil(stream) => ConnectInfo::Turmoil(TcpConnectInfo {
559                local_addr: stream.local_addr().ok(),
560                remote_addr: stream.peer_addr().ok(),
561            }),
562        }
563    }
564}
565
566/// Connection information for a [`Stream`].
567#[derive(Debug, Clone)]
568pub enum ConnectInfo {
569    /// TCP connection information.
570    Tcp(TcpConnectInfo),
571    /// Unix domain socket connection information.
572    Unix(UdsConnectInfo),
573    /// `turmoil` socket conection information.
574    Turmoil(TcpConnectInfo),
575}
576
#[cfg(test)]
mod tests {
    use std::net::{Ipv4Addr, SocketAddrV4};

    use super::*;

    /// Checks that `SocketAddr::from_str` accepts valid Unix, internet, and
    /// `turmoil` addresses and reports the expected message for invalid ones.
    #[crate::test]
    fn test_parse() {
        // Each case pairs an input string with either the address it should
        // parse to or the `Display` text of the error it should produce.
        for (input, expected) in [
            (
                "/valid/path",
                Ok(SocketAddr::Unix(
                    UnixSocketAddr::from_pathname("/valid/path").unwrap(),
                )),
            ),
            (
                "/",
                Ok(SocketAddr::Unix(
                    UnixSocketAddr::from_pathname("/").unwrap(),
                )),
            ),
            (
                "/\0",
                Err(
                    "invalid unix socket address syntax: paths must not contain interior null bytes",
                ),
            ),
            (
                "1.2.3.4:5678",
                Ok(SocketAddr::Inet(InetSocketAddr::V4(SocketAddrV4::new(
                    Ipv4Addr::new(1, 2, 3, 4),
                    5678,
                )))),
            ),
            // An internet address without a port is rejected.
            ("1.2.3.4", Err("invalid internet socket address syntax")),
            ("bad", Err("invalid internet socket address syntax")),
            (
                "turmoil:1.2.3.4:5678",
                Ok(SocketAddr::Turmoil("1.2.3.4:5678".into())),
            ),
        ] {
            // Errors are compared by their rendered message so that both sides
            // have the same `Result<SocketAddr, String>` type.
            let actual = SocketAddr::from_str(input).map_err(|e| e.to_string());
            let expected = expected.map_err(|e| e.to_string());
            assert_eq!(actual, expected, "input: {}", input);
        }
    }
}