mz_ore/netio/
socket.rs

1// Copyright Materialize, Inc. and contributors. All rights reserved.
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License in the LICENSE file at the
6// root of this repository, or online at
7//
8//     http://www.apache.org/licenses/LICENSE-2.0
9//
10// Unless required by applicable law or agreed to in writing, software
11// distributed under the License is distributed on an "AS IS" BASIS,
12// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13// See the License for the specific language governing permissions and
14// limitations under the License.
15
16use std::error::Error;
17use std::net::SocketAddr as InetSocketAddr;
18use std::pin::Pin;
19use std::str::FromStr;
20use std::task::{Context, Poll};
21use std::{fmt, io};
22
23use async_trait::async_trait;
24use tokio::fs;
25use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
26use tokio::net::{self, TcpListener, TcpStream, UnixListener, UnixStream, tcp, unix};
27use tracing::warn;
28
29use crate::error::ErrorExt;
30
/// Classifies a [`SocketAddr`] by the kind of socket it refers to.
#[derive(Debug, Clone, Copy)]
pub enum SocketAddrType {
    /// An internet socket address.
    Inet,
    /// A Unix domain socket address.
    Unix,
    /// A `turmoil` socket address.
    Turmoil,
}

impl SocketAddrType {
    /// Infers which kind of socket address the string `s` denotes.
    ///
    /// The rules are applied in order:
    ///
    /// * A leading `/` character marks an absolute path, which is taken to be a Unix socket
    ///   address.
    /// * A leading "turmoil:" prefix marks a `turmoil` socket address.
    /// * Anything else is assumed to be an internet socket address.
    ///
    /// Only the prefix of `s` is inspected; no further validation happens here.
    pub fn guess(s: &str) -> SocketAddrType {
        if s.starts_with('/') {
            return SocketAddrType::Unix;
        }
        if s.starts_with("turmoil:") {
            return SocketAddrType::Turmoil;
        }
        SocketAddrType::Inet
    }
}
59
/// An address associated with an internet, Unix domain, or `turmoil` socket.
#[derive(Debug, Clone)]
pub enum SocketAddr {
    /// An internet socket address.
    Inet(InetSocketAddr),
    /// A Unix domain socket address.
    Unix(UnixSocketAddr),
    /// A `turmoil` socket address.
    ///
    /// When produced by parsing, the string excludes the `turmoil:` prefix; the `Display`
    /// implementation adds the prefix back.
    Turmoil(String),
}
70
71impl PartialEq for SocketAddr {
72    fn eq(&self, other: &Self) -> bool {
73        match (self, other) {
74            (SocketAddr::Inet(addr1), SocketAddr::Inet(addr2)) => addr1 == addr2,
75            (
76                SocketAddr::Unix(UnixSocketAddr { path: Some(path1) }),
77                SocketAddr::Unix(UnixSocketAddr { path: Some(path2) }),
78            ) => path1 == path2,
79            (SocketAddr::Turmoil(addr1), SocketAddr::Turmoil(addr2)) => addr1 == addr2,
80            _ => false,
81        }
82    }
83}
84
85impl fmt::Display for SocketAddr {
86    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
87        match &self {
88            SocketAddr::Inet(addr) => addr.fmt(f),
89            SocketAddr::Unix(addr) => addr.fmt(f),
90            SocketAddr::Turmoil(addr) => write!(f, "turmoil:{addr}"),
91        }
92    }
93}
94
95impl FromStr for SocketAddr {
96    type Err = AddrParseError;
97
98    /// Parses a socket address from a string.
99    ///
100    /// Whether a socket address is taken as an internet socket address or a
101    /// Unix socket address is determined by [`SocketAddrType::guess`].
102    fn from_str(s: &str) -> Result<SocketAddr, AddrParseError> {
103        match SocketAddrType::guess(s) {
104            SocketAddrType::Unix => {
105                let addr = UnixSocketAddr::from_pathname(s).map_err(|e| AddrParseError {
106                    kind: AddrParseErrorKind::Unix(e),
107                })?;
108                Ok(SocketAddr::Unix(addr))
109            }
110            SocketAddrType::Inet => {
111                let addr = s.parse().map_err(|_| AddrParseError {
112                    // The underlying error message is always "invalid socket
113                    // address syntax", so there's no benefit to preserving it.
114                    kind: AddrParseErrorKind::Inet,
115                })?;
116                Ok(SocketAddr::Inet(addr))
117            }
118            SocketAddrType::Turmoil => {
119                let addr = s.strip_prefix("turmoil:").ok_or(AddrParseError {
120                    kind: AddrParseErrorKind::Turmoil,
121                })?;
122                Ok(SocketAddr::Turmoil(addr.into()))
123            }
124        }
125    }
126}
127
/// An address associated with a Unix domain socket.
///
/// The address is either a UTF-8 pathname or unnamed.
#[derive(Debug, Clone)]
pub struct UnixSocketAddr {
    /// The pathname of the socket, or `None` for an unnamed socket.
    path: Option<String>,
}

impl UnixSocketAddr {
    /// Constructs a Unix domain socket address from the provided path.
    ///
    /// Unlike the [`UnixSocketAddr::from_pathname`] method in the standard
    /// library, `path` is required to be valid UTF-8.
    ///
    /// # Errors
    ///
    /// Returns an error if the path is longer than `SUN_LEN` or if it contains
    /// null bytes.
    ///
    /// [`UnixSocketAddr::from_pathname`]: std::os::unix::net::SocketAddr::from_pathname
    pub fn from_pathname<S>(path: S) -> Result<UnixSocketAddr, io::Error>
    where
        S: Into<String>,
    {
        let path: String = path.into();
        // The standard library constructor is invoked purely for its validation; the address
        // it builds is discarded.
        match std::os::unix::net::SocketAddr::from_pathname(&path) {
            Ok(_) => Ok(UnixSocketAddr { path: Some(path) }),
            Err(err) => Err(err),
        }
    }

    /// Constructs a Unix domain socket address representing an unnamed Unix
    /// socket.
    pub fn unnamed() -> UnixSocketAddr {
        UnixSocketAddr { path: None }
    }

    /// Returns the pathname of this Unix domain socket address, if it was
    /// constructed from a pathname.
    pub fn as_pathname(&self) -> Option<&str> {
        self.path.as_deref()
    }
}

impl fmt::Display for UnixSocketAddr {
    /// Writes the pathname, or the placeholder `<unnamed>` when there is none.
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        let text = self.path.as_deref().unwrap_or("<unnamed>");
        f.write_str(text)
    }
}
176
/// The error returned when parsing a [`SocketAddr`] from a string.
#[derive(Debug)]
pub struct AddrParseError {
    /// What kind of address failed to parse, with any underlying cause.
    kind: AddrParseErrorKind,
}

/// The specific reason an address string failed to parse.
#[derive(Debug)]
pub enum AddrParseErrorKind {
    /// The string was not a valid internet socket address.
    Inet,
    /// The string was not a valid Unix socket path; carries the underlying error.
    Unix(io::Error),
    /// The string lacked the `turmoil:` prefix.
    Turmoil,
}

impl fmt::Display for AddrParseError {
    /// Writes a human-readable description of the parse failure.
    ///
    /// For Unix failures the underlying error's message is appended to the description.
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        match &self.kind {
            AddrParseErrorKind::Unix(cause) => {
                f.write_str("invalid unix socket address syntax: ")?;
                fmt::Display::fmt(cause, f)
            }
            AddrParseErrorKind::Inet => f.write_str("invalid internet socket address syntax"),
            AddrParseErrorKind::Turmoil => f.write_str("missing 'turmoil:' prefix"),
        }
    }
}

impl Error for AddrParseError {}
204
/// Converts or resolves without blocking to one or more [`SocketAddr`]s.
///
/// The `#[async_trait]` attribute is applied so that the trait can declare an `async fn`.
#[async_trait]
pub trait ToSocketAddrs {
    /// Converts to resolved [`SocketAddr`]s.
    ///
    /// Returns an [`io::Error`] if the conversion or resolution fails.
    async fn to_socket_addrs(&self) -> Result<Vec<SocketAddr>, io::Error>;
}
211
212#[async_trait]
213impl ToSocketAddrs for SocketAddr {
214    async fn to_socket_addrs(&self) -> Result<Vec<SocketAddr>, io::Error> {
215        Ok(vec![self.clone()])
216    }
217}
218
219#[async_trait]
220impl<'a> ToSocketAddrs for &'a [SocketAddr] {
221    async fn to_socket_addrs(&self) -> Result<Vec<SocketAddr>, io::Error> {
222        Ok(self.to_vec())
223    }
224}
225
226#[async_trait]
227impl ToSocketAddrs for InetSocketAddr {
228    async fn to_socket_addrs(&self) -> Result<Vec<SocketAddr>, io::Error> {
229        Ok(vec![SocketAddr::Inet(*self)])
230    }
231}
232
233#[async_trait]
234impl ToSocketAddrs for UnixSocketAddr {
235    async fn to_socket_addrs(&self) -> Result<Vec<SocketAddr>, io::Error> {
236        Ok(vec![SocketAddr::Unix(self.clone())])
237    }
238}
239
240#[async_trait]
241impl ToSocketAddrs for str {
242    async fn to_socket_addrs(&self) -> Result<Vec<SocketAddr>, io::Error> {
243        match self.parse() {
244            Ok(addr) => Ok(vec![addr]),
245            Err(_) => {
246                let addrs = net::lookup_host(self).await?;
247                Ok(addrs.map(SocketAddr::Inet).collect())
248            }
249        }
250    }
251}
252
253#[async_trait]
254impl ToSocketAddrs for String {
255    async fn to_socket_addrs(&self) -> Result<Vec<SocketAddr>, io::Error> {
256        (**self).to_socket_addrs().await
257    }
258}
259
260#[async_trait]
261impl<T> ToSocketAddrs for &T
262where
263    T: ToSocketAddrs + Send + Sync + ?Sized,
264{
265    async fn to_socket_addrs(&self) -> Result<Vec<SocketAddr>, io::Error> {
266        (**self).to_socket_addrs().await
267    }
268}
269
/// A listener bound to either a TCP socket or Unix domain socket.
///
/// When the `turmoil` feature is enabled, a listener may also be bound to a `turmoil`
/// socket.
pub enum Listener {
    /// A TCP listener.
    Tcp(TcpListener),
    /// A Unix domain socket listener.
    Unix(UnixListener),
    /// A `turmoil` socket listener.
    ///
    /// Only available when the `turmoil` feature is enabled.
    #[cfg(feature = "turmoil")]
    Turmoil(turmoil::net::TcpListener),
}
280
281impl Listener {
282    /// Creates a new listener bound to the specified socket address.
283    ///
284    /// If `addr` is a Unix domain address, this function attempts to unlink the
285    /// socket at the address, if it exists, before binding.
286    pub async fn bind<A>(addr: A) -> Result<Listener, io::Error>
287    where
288        A: ToSocketAddrs,
289    {
290        let mut last_err = None;
291        for addr in addr.to_socket_addrs().await? {
292            match Listener::bind_addr(addr).await {
293                Ok(listener) => return Ok(listener),
294                Err(e) => last_err = Some(e),
295            }
296        }
297        Err(last_err.unwrap_or_else(|| {
298            io::Error::new(
299                io::ErrorKind::InvalidInput,
300                "could not resolve to any address",
301            )
302        }))
303    }
304
305    async fn bind_addr(addr: SocketAddr) -> Result<Listener, io::Error> {
306        match &addr {
307            SocketAddr::Inet(addr) => {
308                let listener = TcpListener::bind(addr).await?;
309                Ok(Listener::Tcp(listener))
310            }
311            SocketAddr::Unix(UnixSocketAddr { path: Some(path) }) => {
312                // We would ideally unlink the file only if we could prove that
313                // no process was still listening at the file, but there is no
314                // foolproof API for doing so.
315                // See: https://stackoverflow.com/q/7405932
316                if let Err(e) = fs::remove_file(path).await {
317                    if e.kind() != io::ErrorKind::NotFound {
318                        warn!(
319                            "unable to remove {path} while binding unix domain socket: {}",
320                            e.display_with_causes(),
321                        );
322                    }
323                }
324                let listener = UnixListener::bind(path)?;
325                Ok(Listener::Unix(listener))
326            }
327            SocketAddr::Unix(UnixSocketAddr { path: None }) => Err(io::Error::new(
328                io::ErrorKind::Other,
329                "cannot bind to unnamed Unix socket",
330            )),
331            #[cfg(feature = "turmoil")]
332            SocketAddr::Turmoil(addr) => {
333                let listener = turmoil::net::TcpListener::bind(addr).await?;
334                Ok(Listener::Turmoil(listener))
335            }
336            #[cfg(not(feature = "turmoil"))]
337            SocketAddr::Turmoil(_) => panic!("`turmoil` feature not enabled"),
338        }
339    }
340
341    /// Accepts a new incoming connection to this listener.
342    ///
343    /// If the connection protocol is TCP, the returned stream has `TCP_NODELAY` set, making it
344    /// suitable for low-latency communication by default.
345    pub async fn accept(&self) -> Result<(Stream, SocketAddr), io::Error> {
346        match self {
347            Listener::Tcp(listener) => {
348                let (stream, addr) = listener.accept().await?;
349                stream.set_nodelay(true)?;
350                let stream = Stream::Tcp(stream);
351                let addr = SocketAddr::Inet(addr);
352                Ok((stream, addr))
353            }
354            Listener::Unix(listener) => {
355                let (stream, addr) = listener.accept().await?;
356                let stream = Stream::Unix(stream);
357                assert!(addr.is_unnamed());
358                let addr = SocketAddr::Unix(UnixSocketAddr::unnamed());
359                Ok((stream, addr))
360            }
361            #[cfg(feature = "turmoil")]
362            Listener::Turmoil(listener) => {
363                let (stream, addr) = listener.accept().await?;
364                let stream = Stream::Turmoil(stream);
365                let addr = SocketAddr::Inet(addr);
366                Ok((stream, addr))
367            }
368        }
369    }
370}
371
372impl fmt::Debug for Listener {
373    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
374        match self {
375            Self::Tcp(inner) => f.debug_tuple("Tcp").field(inner).finish(),
376            Self::Unix(inner) => f.debug_tuple("Unix").field(inner).finish(),
377            #[cfg(feature = "turmoil")]
378            Self::Turmoil(_) => f.debug_tuple("Turmoil").finish(),
379        }
380    }
381}
382
/// A stream associated with either a TCP socket or a Unix domain socket.
///
/// When the `turmoil` feature is enabled, a stream may also be associated with a `turmoil`
/// socket.
#[derive(Debug)]
pub enum Stream {
    /// A TCP stream.
    Tcp(TcpStream),
    /// A Unix domain socket stream.
    Unix(UnixStream),
    /// A `turmoil` socket stream.
    ///
    /// Only available when the `turmoil` feature is enabled.
    #[cfg(feature = "turmoil")]
    Turmoil(turmoil::net::TcpStream),
}
394
395impl Stream {
396    /// Opens a connection to the specified socket address.
397    ///
398    /// If the connection protocol is TCP, the returned stream has `TCP_NODELAY` set, making it
399    /// suitable for low-latency communication by default.
400    pub async fn connect<A>(addr: A) -> Result<Stream, io::Error>
401    where
402        A: ToSocketAddrs,
403    {
404        let mut last_err = None;
405        for addr in addr.to_socket_addrs().await? {
406            match Stream::connect_addr(addr).await {
407                Ok(stream) => return Ok(stream),
408                Err(e) => last_err = Some(e),
409            }
410        }
411        Err(last_err.unwrap_or_else(|| {
412            io::Error::new(
413                io::ErrorKind::InvalidInput,
414                "could not resolve to any address",
415            )
416        }))
417    }
418
419    async fn connect_addr(addr: SocketAddr) -> Result<Stream, io::Error> {
420        match addr {
421            SocketAddr::Inet(addr) => {
422                let stream = TcpStream::connect(addr).await?;
423                stream.set_nodelay(true)?;
424                Ok(Stream::Tcp(stream))
425            }
426            SocketAddr::Unix(UnixSocketAddr { path: Some(path) }) => {
427                let stream = UnixStream::connect(path).await?;
428                Ok(Stream::Unix(stream))
429            }
430            SocketAddr::Unix(UnixSocketAddr { path: None }) => Err(io::Error::new(
431                io::ErrorKind::Other,
432                "cannot connected to unnamed Unix socket",
433            )),
434            #[cfg(feature = "turmoil")]
435            SocketAddr::Turmoil(addr) => {
436                let stream = turmoil::net::TcpStream::connect(addr).await?;
437                Ok(Stream::Turmoil(stream))
438            }
439            #[cfg(not(feature = "turmoil"))]
440            SocketAddr::Turmoil(_) => panic!("`turmoil` feature not enabled"),
441        }
442    }
443
444    /// Reports whether the underlying stream is a TCP stream.
445    pub fn is_tcp(&self) -> bool {
446        matches!(self, Stream::Tcp(_))
447    }
448
449    /// Reports whether the underlying stream is a Unix stream.
450    pub fn is_unix(&self) -> bool {
451        matches!(self, Stream::Unix(_))
452    }
453
454    /// Returns the underlying TCP stream.
455    ///
456    /// # Panics
457    ///
458    /// Panics if the stream is not a Unix stream.
459    pub fn unwrap_tcp(self) -> TcpStream {
460        match self {
461            Stream::Tcp(stream) => stream,
462            Stream::Unix(_) => panic!("Stream::unwrap_tcp called on a Unix stream"),
463            #[cfg(feature = "turmoil")]
464            Stream::Turmoil(_) => panic!("Stream::unwrap_tcp called on a `turmoil` stream"),
465        }
466    }
467
468    /// Returns the underlying Unix stream.
469    ///
470    /// # Panics
471    ///
472    /// Panics if the stream is not a Unix stream.
473    pub fn unwrap_unix(self) -> UnixStream {
474        match self {
475            Stream::Tcp(_) => panic!("Stream::unwrap_unix called on a TCP stream"),
476            Stream::Unix(stream) => stream,
477            #[cfg(feature = "turmoil")]
478            Stream::Turmoil(_) => panic!("Stream::unwrap_unix called on a `turmoil` stream"),
479        }
480    }
481
482    /// Splits a stream into a read half and a write half, which can be used to read and write the
483    /// stream concurrently.
484    pub fn split(self) -> (StreamReadHalf, StreamWriteHalf) {
485        match self {
486            Stream::Tcp(stream) => {
487                let (rx, tx) = stream.into_split();
488                (StreamReadHalf::Tcp(rx), StreamWriteHalf::Tcp(tx))
489            }
490            Stream::Unix(stream) => {
491                let (rx, tx) = stream.into_split();
492                (StreamReadHalf::Unix(rx), StreamWriteHalf::Unix(tx))
493            }
494            #[cfg(feature = "turmoil")]
495            Stream::Turmoil(stream) => {
496                let (rx, tx) = stream.into_split();
497                (StreamReadHalf::Turmoil(rx), StreamWriteHalf::Turmoil(tx))
498            }
499        }
500    }
501}
502
503impl AsyncRead for Stream {
504    fn poll_read(
505        self: Pin<&mut Self>,
506        cx: &mut Context,
507        buf: &mut ReadBuf,
508    ) -> Poll<io::Result<()>> {
509        match self.get_mut() {
510            Stream::Tcp(stream) => Pin::new(stream).poll_read(cx, buf),
511            Stream::Unix(stream) => Pin::new(stream).poll_read(cx, buf),
512            #[cfg(feature = "turmoil")]
513            Stream::Turmoil(stream) => Pin::new(stream).poll_read(cx, buf),
514        }
515    }
516}
517
518impl AsyncWrite for Stream {
519    fn poll_write(self: Pin<&mut Self>, cx: &mut Context, buf: &[u8]) -> Poll<io::Result<usize>> {
520        match self.get_mut() {
521            Stream::Tcp(stream) => Pin::new(stream).poll_write(cx, buf),
522            Stream::Unix(stream) => Pin::new(stream).poll_write(cx, buf),
523            #[cfg(feature = "turmoil")]
524            Stream::Turmoil(stream) => Pin::new(stream).poll_write(cx, buf),
525        }
526    }
527
528    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context) -> Poll<io::Result<()>> {
529        match self.get_mut() {
530            Stream::Tcp(stream) => Pin::new(stream).poll_flush(cx),
531            Stream::Unix(stream) => Pin::new(stream).poll_flush(cx),
532            #[cfg(feature = "turmoil")]
533            Stream::Turmoil(stream) => Pin::new(stream).poll_flush(cx),
534        }
535    }
536
537    fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context) -> Poll<io::Result<()>> {
538        match self.get_mut() {
539            Stream::Tcp(stream) => Pin::new(stream).poll_shutdown(cx),
540            Stream::Unix(stream) => Pin::new(stream).poll_shutdown(cx),
541            #[cfg(feature = "turmoil")]
542            Stream::Turmoil(stream) => Pin::new(stream).poll_shutdown(cx),
543        }
544    }
545}
546
/// Read half of a [`Stream`], created by [`Stream::split`].
#[derive(Debug)]
pub enum StreamReadHalf {
    /// The owned read half of a TCP stream.
    Tcp(tcp::OwnedReadHalf),
    /// The owned read half of a Unix domain socket stream.
    Unix(unix::OwnedReadHalf),
    /// The owned read half of a `turmoil` stream.
    #[cfg(feature = "turmoil")]
    Turmoil(turmoil::net::tcp::OwnedReadHalf),
}
555
556impl AsyncRead for StreamReadHalf {
557    fn poll_read(
558        self: Pin<&mut Self>,
559        cx: &mut Context,
560        buf: &mut ReadBuf,
561    ) -> Poll<io::Result<()>> {
562        match self.get_mut() {
563            Self::Tcp(rx) => Pin::new(rx).poll_read(cx, buf),
564            Self::Unix(rx) => Pin::new(rx).poll_read(cx, buf),
565            #[cfg(feature = "turmoil")]
566            Self::Turmoil(rx) => Pin::new(rx).poll_read(cx, buf),
567        }
568    }
569}
570
/// Write half of a [`Stream`], created by [`Stream::split`].
#[derive(Debug)]
pub enum StreamWriteHalf {
    /// The owned write half of a TCP stream.
    Tcp(tcp::OwnedWriteHalf),
    /// The owned write half of a Unix domain socket stream.
    Unix(unix::OwnedWriteHalf),
    /// The owned write half of a `turmoil` stream.
    #[cfg(feature = "turmoil")]
    Turmoil(turmoil::net::tcp::OwnedWriteHalf),
}
579
580impl AsyncWrite for StreamWriteHalf {
581    fn poll_write(self: Pin<&mut Self>, cx: &mut Context, buf: &[u8]) -> Poll<io::Result<usize>> {
582        match self.get_mut() {
583            Self::Tcp(tx) => Pin::new(tx).poll_write(cx, buf),
584            Self::Unix(tx) => Pin::new(tx).poll_write(cx, buf),
585            #[cfg(feature = "turmoil")]
586            Self::Turmoil(tx) => Pin::new(tx).poll_write(cx, buf),
587        }
588    }
589
590    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context) -> Poll<io::Result<()>> {
591        match self.get_mut() {
592            Self::Tcp(tx) => Pin::new(tx).poll_flush(cx),
593            Self::Unix(tx) => Pin::new(tx).poll_flush(cx),
594            #[cfg(feature = "turmoil")]
595            Self::Turmoil(tx) => Pin::new(tx).poll_flush(cx),
596        }
597    }
598
599    fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context) -> Poll<io::Result<()>> {
600        match self.get_mut() {
601            Self::Tcp(tx) => Pin::new(tx).poll_shutdown(cx),
602            Self::Unix(tx) => Pin::new(tx).poll_shutdown(cx),
603            #[cfg(feature = "turmoil")]
604            Self::Turmoil(tx) => Pin::new(tx).poll_shutdown(cx),
605        }
606    }
607}
608
#[cfg(test)]
mod tests {
    use std::net::{Ipv4Addr, SocketAddrV4};

    use super::*;

    /// Parses a table of address strings and compares each outcome, success or error
    /// message, against the expected one.
    #[crate::test]
    fn test_parse() {
        let inet_addr = InetSocketAddr::V4(SocketAddrV4::new(Ipv4Addr::new(1, 2, 3, 4), 5678));
        let cases: Vec<(&str, Result<SocketAddr, &str>)> = vec![
            (
                "/valid/path",
                Ok(SocketAddr::Unix(
                    UnixSocketAddr::from_pathname("/valid/path").unwrap(),
                )),
            ),
            (
                "/",
                Ok(SocketAddr::Unix(
                    UnixSocketAddr::from_pathname("/").unwrap(),
                )),
            ),
            (
                "/\0",
                Err(
                    "invalid unix socket address syntax: paths must not contain interior null bytes",
                ),
            ),
            ("1.2.3.4:5678", Ok(SocketAddr::Inet(inet_addr))),
            ("1.2.3.4", Err("invalid internet socket address syntax")),
            ("bad", Err("invalid internet socket address syntax")),
            (
                "turmoil:1.2.3.4:5678",
                Ok(SocketAddr::Turmoil("1.2.3.4:5678".into())),
            ),
        ];

        for (input, expected) in cases {
            let actual = input.parse::<SocketAddr>().map_err(|e| e.to_string());
            let expected = expected.map_err(String::from);
            assert_eq!(actual, expected, "input: {}", input);
        }
    }
}