// mz_ore/netio/socket.rs

// Copyright Materialize, Inc. and contributors. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License in the LICENSE file at the
// root of this repository, or online at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
15
16use std::error::Error;
17use std::net::SocketAddr as InetSocketAddr;
18use std::pin::Pin;
19use std::str::FromStr;
20use std::task::{Context, Poll, ready};
21use std::{fmt, io};
22
23use async_trait::async_trait;
24use tokio::fs;
25use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
26use tokio::net::{self, TcpListener, TcpStream, UnixListener, UnixStream, tcp, unix};
27use tonic::transport::server::{Connected, TcpConnectInfo, UdsConnectInfo};
28use tracing::warn;
29
30use crate::error::ErrorExt;
31
/// The kind of socket a [`SocketAddr`] refers to.
#[derive(Debug, Clone, Copy)]
pub enum SocketAddrType {
    /// An internet (IP address and port) socket address.
    Inet,
    /// A Unix domain socket address.
    Unix,
    /// A `turmoil` socket address.
    Turmoil,
}

impl SocketAddrType {
    /// Guesses which kind of address `s` denotes, based purely on its text.
    ///
    /// The rules are checked in order:
    ///
    /// 1. A leading `/` (an absolute path) selects a Unix domain socket.
    /// 2. A `turmoil:` prefix selects a `turmoil` socket.
    /// 3. Anything else is assumed to be an internet socket address.
    ///
    /// No validation of the remainder of `s` is performed.
    pub fn guess(s: &str) -> SocketAddrType {
        // `/` is ASCII, so checking the first byte is equivalent to
        // `starts_with('/')` for any UTF-8 string.
        match s.as_bytes().first() {
            Some(b'/') => SocketAddrType::Unix,
            _ if s.starts_with("turmoil:") => SocketAddrType::Turmoil,
            _ => SocketAddrType::Inet,
        }
    }
}
60
61/// An address associated with an internet or Unix domain socket.
62#[derive(Debug, Clone)]
63pub enum SocketAddr {
64    /// An internet socket address.
65    Inet(InetSocketAddr),
66    /// A Unix domain socket address.
67    Unix(UnixSocketAddr),
68    /// A `turmoil` socket address.
69    Turmoil(String),
70}
71
72impl PartialEq for SocketAddr {
73    fn eq(&self, other: &Self) -> bool {
74        match (self, other) {
75            (SocketAddr::Inet(addr1), SocketAddr::Inet(addr2)) => addr1 == addr2,
76            (
77                SocketAddr::Unix(UnixSocketAddr { path: Some(path1) }),
78                SocketAddr::Unix(UnixSocketAddr { path: Some(path2) }),
79            ) => path1 == path2,
80            (SocketAddr::Turmoil(addr1), SocketAddr::Turmoil(addr2)) => addr1 == addr2,
81            _ => false,
82        }
83    }
84}
85
86impl fmt::Display for SocketAddr {
87    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
88        match &self {
89            SocketAddr::Inet(addr) => addr.fmt(f),
90            SocketAddr::Unix(addr) => addr.fmt(f),
91            SocketAddr::Turmoil(addr) => write!(f, "turmoil:{addr}"),
92        }
93    }
94}
95
96impl FromStr for SocketAddr {
97    type Err = AddrParseError;
98
99    /// Parses a socket address from a string.
100    ///
101    /// Whether a socket address is taken as an internet socket address or a
102    /// Unix socket address is determined by [`SocketAddrType::guess`].
103    fn from_str(s: &str) -> Result<SocketAddr, AddrParseError> {
104        match SocketAddrType::guess(s) {
105            SocketAddrType::Unix => {
106                let addr = UnixSocketAddr::from_pathname(s).map_err(|e| AddrParseError {
107                    kind: AddrParseErrorKind::Unix(e),
108                })?;
109                Ok(SocketAddr::Unix(addr))
110            }
111            SocketAddrType::Inet => {
112                let addr = s.parse().map_err(|_| AddrParseError {
113                    // The underlying error message is always "invalid socket
114                    // address syntax", so there's no benefit to preserving it.
115                    kind: AddrParseErrorKind::Inet,
116                })?;
117                Ok(SocketAddr::Inet(addr))
118            }
119            SocketAddrType::Turmoil => {
120                let addr = s.strip_prefix("turmoil:").ok_or(AddrParseError {
121                    kind: AddrParseErrorKind::Turmoil,
122                })?;
123                Ok(SocketAddr::Turmoil(addr.into()))
124            }
125        }
126    }
127}
128
/// An address associated with a Unix domain socket.
///
/// The address is either a UTF-8 pathname or unnamed.
#[derive(Debug, Clone)]
pub struct UnixSocketAddr {
    path: Option<String>,
}

impl UnixSocketAddr {
    /// Builds a Unix domain socket address from `path`.
    ///
    /// This validates `path` the same way as the standard library's
    /// [`UnixSocketAddr::from_pathname`] does. Unlike that method, it requires
    /// `path` to be valid UTF-8.
    ///
    /// # Errors
    ///
    /// Fails if the path is longer than `SUN_LEN` or contains null bytes.
    ///
    /// [`UnixSocketAddr::from_pathname`]: std::os::unix::net::SocketAddr::from_pathname
    pub fn from_pathname<S>(path: S) -> Result<UnixSocketAddr, io::Error>
    where
        S: Into<String>,
    {
        let owned: String = path.into();
        // Only the standard library's verdict matters; its address is discarded.
        std::os::unix::net::SocketAddr::from_pathname(&owned)
            .map(move |_validated| UnixSocketAddr { path: Some(owned) })
    }

    /// Returns the address of an unnamed Unix socket.
    pub fn unnamed() -> UnixSocketAddr {
        UnixSocketAddr { path: None }
    }

    /// Returns the pathname this address was built from, or `None` if it is
    /// unnamed.
    pub fn as_pathname(&self) -> Option<&str> {
        self.path.as_ref().map(String::as_str)
    }
}

/// Shows the pathname, or `<unnamed>` for an unnamed address.
impl fmt::Display for UnixSocketAddr {
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        f.write_str(self.as_pathname().unwrap_or("<unnamed>"))
    }
}
177
/// Error produced when a string cannot be parsed as a [`SocketAddr`].
#[derive(Debug)]
pub struct AddrParseError {
    kind: AddrParseErrorKind,
}

/// The reason an [`AddrParseError`] occurred.
#[derive(Debug)]
pub enum AddrParseErrorKind {
    /// The string is not a valid internet socket address.
    Inet,
    /// The string is not a valid Unix domain socket path; carries the
    /// validation error.
    Unix(io::Error),
    /// The string lacks the `turmoil:` prefix.
    Turmoil,
}

impl fmt::Display for AddrParseError {
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        let (message, cause) = match &self.kind {
            AddrParseErrorKind::Inet => ("invalid internet socket address syntax", None),
            AddrParseErrorKind::Unix(e) => ("invalid unix socket address syntax: ", Some(e)),
            AddrParseErrorKind::Turmoil => ("missing 'turmoil:' prefix", None),
        };
        f.write_str(message)?;
        match cause {
            Some(inner) => fmt::Display::fmt(inner, f),
            None => Ok(()),
        }
    }
}

// `Display` already includes the `io::Error` text for Unix paths; `source` is
// left at its default (`None`).
impl Error for AddrParseError {}
205
206/// Converts or resolves without blocking to one or more [`SocketAddr`]s.
207#[async_trait]
208pub trait ToSocketAddrs {
209    /// Converts to resolved [`SocketAddr`]s.
210    async fn to_socket_addrs(&self) -> Result<Vec<SocketAddr>, io::Error>;
211}
212
213#[async_trait]
214impl ToSocketAddrs for SocketAddr {
215    async fn to_socket_addrs(&self) -> Result<Vec<SocketAddr>, io::Error> {
216        Ok(vec![self.clone()])
217    }
218}
219
220#[async_trait]
221impl<'a> ToSocketAddrs for &'a [SocketAddr] {
222    async fn to_socket_addrs(&self) -> Result<Vec<SocketAddr>, io::Error> {
223        Ok(self.to_vec())
224    }
225}
226
227#[async_trait]
228impl ToSocketAddrs for InetSocketAddr {
229    async fn to_socket_addrs(&self) -> Result<Vec<SocketAddr>, io::Error> {
230        Ok(vec![SocketAddr::Inet(*self)])
231    }
232}
233
234#[async_trait]
235impl ToSocketAddrs for UnixSocketAddr {
236    async fn to_socket_addrs(&self) -> Result<Vec<SocketAddr>, io::Error> {
237        Ok(vec![SocketAddr::Unix(self.clone())])
238    }
239}
240
241#[async_trait]
242impl ToSocketAddrs for str {
243    async fn to_socket_addrs(&self) -> Result<Vec<SocketAddr>, io::Error> {
244        match self.parse() {
245            Ok(addr) => Ok(vec![addr]),
246            Err(_) => {
247                let addrs = net::lookup_host(self).await?;
248                Ok(addrs.map(SocketAddr::Inet).collect())
249            }
250        }
251    }
252}
253
254#[async_trait]
255impl ToSocketAddrs for String {
256    async fn to_socket_addrs(&self) -> Result<Vec<SocketAddr>, io::Error> {
257        (**self).to_socket_addrs().await
258    }
259}
260
261#[async_trait]
262impl<T> ToSocketAddrs for &T
263where
264    T: ToSocketAddrs + Send + Sync + ?Sized,
265{
266    async fn to_socket_addrs(&self) -> Result<Vec<SocketAddr>, io::Error> {
267        (**self).to_socket_addrs().await
268    }
269}
270
271/// A listener bound to either a TCP socket or Unix domain socket.
272pub enum Listener {
273    /// A TCP listener.
274    Tcp(TcpListener),
275    /// A Unix domain socket listener.
276    Unix(UnixListener),
277    /// A `turmoil` socket listener.
278    #[cfg(feature = "turmoil")]
279    Turmoil(turmoil::net::TcpListener),
280}
281
282impl Listener {
283    /// Creates a new listener bound to the specified socket address.
284    ///
285    /// If `addr` is a Unix domain address, this function attempts to unlink the
286    /// socket at the address, if it exists, before binding.
287    pub async fn bind<A>(addr: A) -> Result<Listener, io::Error>
288    where
289        A: ToSocketAddrs,
290    {
291        let mut last_err = None;
292        for addr in addr.to_socket_addrs().await? {
293            match Listener::bind_addr(addr).await {
294                Ok(listener) => return Ok(listener),
295                Err(e) => last_err = Some(e),
296            }
297        }
298        Err(last_err.unwrap_or_else(|| {
299            io::Error::new(
300                io::ErrorKind::InvalidInput,
301                "could not resolve to any address",
302            )
303        }))
304    }
305
306    async fn bind_addr(addr: SocketAddr) -> Result<Listener, io::Error> {
307        match &addr {
308            SocketAddr::Inet(addr) => {
309                let listener = TcpListener::bind(addr).await?;
310                Ok(Listener::Tcp(listener))
311            }
312            SocketAddr::Unix(UnixSocketAddr { path: Some(path) }) => {
313                // We would ideally unlink the file only if we could prove that
314                // no process was still listening at the file, but there is no
315                // foolproof API for doing so.
316                // See: https://stackoverflow.com/q/7405932
317                if let Err(e) = fs::remove_file(path).await {
318                    if e.kind() != io::ErrorKind::NotFound {
319                        warn!(
320                            "unable to remove {path} while binding unix domain socket: {}",
321                            e.display_with_causes(),
322                        );
323                    }
324                }
325                let listener = UnixListener::bind(path)?;
326                Ok(Listener::Unix(listener))
327            }
328            SocketAddr::Unix(UnixSocketAddr { path: None }) => Err(io::Error::new(
329                io::ErrorKind::Other,
330                "cannot bind to unnamed Unix socket",
331            )),
332            #[cfg(feature = "turmoil")]
333            SocketAddr::Turmoil(addr) => {
334                let listener = turmoil::net::TcpListener::bind(addr).await?;
335                Ok(Listener::Turmoil(listener))
336            }
337            #[cfg(not(feature = "turmoil"))]
338            SocketAddr::Turmoil(_) => panic!("`turmoil` feature not enabled"),
339        }
340    }
341
342    /// Accepts a new incoming connection to this listener.
343    ///
344    /// If the connection protocol is TCP, the returned stream has `TCP_NODELAY` set, making it
345    /// suitable for low-latency communication by default.
346    pub async fn accept(&self) -> Result<(Stream, SocketAddr), io::Error> {
347        match self {
348            Listener::Tcp(listener) => {
349                let (stream, addr) = listener.accept().await?;
350                stream.set_nodelay(true)?;
351                let stream = Stream::Tcp(stream);
352                let addr = SocketAddr::Inet(addr);
353                Ok((stream, addr))
354            }
355            Listener::Unix(listener) => {
356                let (stream, addr) = listener.accept().await?;
357                let stream = Stream::Unix(stream);
358                assert!(addr.is_unnamed());
359                let addr = SocketAddr::Unix(UnixSocketAddr::unnamed());
360                Ok((stream, addr))
361            }
362            #[cfg(feature = "turmoil")]
363            Listener::Turmoil(listener) => {
364                let (stream, addr) = listener.accept().await?;
365                let stream = Stream::Turmoil(stream);
366                let addr = SocketAddr::Inet(addr);
367                Ok((stream, addr))
368            }
369        }
370    }
371
372    fn poll_accept(
373        self: Pin<&mut Self>,
374        cx: &mut Context<'_>,
375    ) -> Poll<Option<Result<Stream, io::Error>>> {
376        match self.get_mut() {
377            Listener::Tcp(listener) => {
378                let (stream, _addr) = ready!(listener.poll_accept(cx))?;
379                stream.set_nodelay(true)?;
380                Poll::Ready(Some(Ok(Stream::Tcp(stream))))
381            }
382            Listener::Unix(listener) => {
383                let (stream, _addr) = ready!(listener.poll_accept(cx))?;
384                Poll::Ready(Some(Ok(Stream::Unix(stream))))
385            }
386            #[cfg(feature = "turmoil")]
387            Listener::Turmoil(_) => {
388                unimplemented!("`turmoil::net::TcpListener::poll_accept`");
389            }
390        }
391    }
392}
393
394impl fmt::Debug for Listener {
395    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
396        match self {
397            Self::Tcp(inner) => f.debug_tuple("Tcp").field(inner).finish(),
398            Self::Unix(inner) => f.debug_tuple("Unix").field(inner).finish(),
399            #[cfg(feature = "turmoil")]
400            Self::Turmoil(_) => f.debug_tuple("Turmoil").finish(),
401        }
402    }
403}
404
405impl futures::stream::Stream for Listener {
406    type Item = Result<Stream, io::Error>;
407
408    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
409        self.poll_accept(cx)
410    }
411}
412
413/// A stream associated with either a TCP socket or a Unix domain socket.
414#[derive(Debug)]
415pub enum Stream {
416    /// A TCP stream.
417    Tcp(TcpStream),
418    /// A Unix domain socket stream.
419    Unix(UnixStream),
420    /// A `turmoil` socket stream.
421    #[cfg(feature = "turmoil")]
422    Turmoil(turmoil::net::TcpStream),
423}
424
425impl Stream {
426    /// Opens a connection to the specified socket address.
427    ///
428    /// If the connection protocol is TCP, the returned stream has `TCP_NODELAY` set, making it
429    /// suitable for low-latency communication by default.
430    pub async fn connect<A>(addr: A) -> Result<Stream, io::Error>
431    where
432        A: ToSocketAddrs,
433    {
434        let mut last_err = None;
435        for addr in addr.to_socket_addrs().await? {
436            match Stream::connect_addr(addr).await {
437                Ok(stream) => return Ok(stream),
438                Err(e) => last_err = Some(e),
439            }
440        }
441        Err(last_err.unwrap_or_else(|| {
442            io::Error::new(
443                io::ErrorKind::InvalidInput,
444                "could not resolve to any address",
445            )
446        }))
447    }
448
449    async fn connect_addr(addr: SocketAddr) -> Result<Stream, io::Error> {
450        match addr {
451            SocketAddr::Inet(addr) => {
452                let stream = TcpStream::connect(addr).await?;
453                stream.set_nodelay(true)?;
454                Ok(Stream::Tcp(stream))
455            }
456            SocketAddr::Unix(UnixSocketAddr { path: Some(path) }) => {
457                let stream = UnixStream::connect(path).await?;
458                Ok(Stream::Unix(stream))
459            }
460            SocketAddr::Unix(UnixSocketAddr { path: None }) => Err(io::Error::new(
461                io::ErrorKind::Other,
462                "cannot connected to unnamed Unix socket",
463            )),
464            #[cfg(feature = "turmoil")]
465            SocketAddr::Turmoil(addr) => {
466                let stream = turmoil::net::TcpStream::connect(addr).await?;
467                Ok(Stream::Turmoil(stream))
468            }
469            #[cfg(not(feature = "turmoil"))]
470            SocketAddr::Turmoil(_) => panic!("`turmoil` feature not enabled"),
471        }
472    }
473
474    /// Reports whether the underlying stream is a TCP stream.
475    pub fn is_tcp(&self) -> bool {
476        matches!(self, Stream::Tcp(_))
477    }
478
479    /// Reports whether the underlying stream is a Unix stream.
480    pub fn is_unix(&self) -> bool {
481        matches!(self, Stream::Unix(_))
482    }
483
484    /// Returns the underlying TCP stream.
485    ///
486    /// # Panics
487    ///
488    /// Panics if the stream is not a Unix stream.
489    pub fn unwrap_tcp(self) -> TcpStream {
490        match self {
491            Stream::Tcp(stream) => stream,
492            Stream::Unix(_) => panic!("Stream::unwrap_tcp called on a Unix stream"),
493            #[cfg(feature = "turmoil")]
494            Stream::Turmoil(_) => panic!("Stream::unwrap_tcp called on a `turmoil` stream"),
495        }
496    }
497
498    /// Returns the underlying Unix stream.
499    ///
500    /// # Panics
501    ///
502    /// Panics if the stream is not a Unix stream.
503    pub fn unwrap_unix(self) -> UnixStream {
504        match self {
505            Stream::Tcp(_) => panic!("Stream::unwrap_unix called on a TCP stream"),
506            Stream::Unix(stream) => stream,
507            #[cfg(feature = "turmoil")]
508            Stream::Turmoil(_) => panic!("Stream::unwrap_unix called on a `turmoil` stream"),
509        }
510    }
511
512    /// Splits a stream into a read half and a write half, which can be used to read and write the
513    /// stream concurrently.
514    pub fn split(self) -> (StreamReadHalf, StreamWriteHalf) {
515        match self {
516            Stream::Tcp(stream) => {
517                let (rx, tx) = stream.into_split();
518                (StreamReadHalf::Tcp(rx), StreamWriteHalf::Tcp(tx))
519            }
520            Stream::Unix(stream) => {
521                let (rx, tx) = stream.into_split();
522                (StreamReadHalf::Unix(rx), StreamWriteHalf::Unix(tx))
523            }
524            #[cfg(feature = "turmoil")]
525            Stream::Turmoil(stream) => {
526                let (rx, tx) = stream.into_split();
527                (StreamReadHalf::Turmoil(rx), StreamWriteHalf::Turmoil(tx))
528            }
529        }
530    }
531}
532
533impl AsyncRead for Stream {
534    fn poll_read(
535        self: Pin<&mut Self>,
536        cx: &mut Context,
537        buf: &mut ReadBuf,
538    ) -> Poll<io::Result<()>> {
539        match self.get_mut() {
540            Stream::Tcp(stream) => Pin::new(stream).poll_read(cx, buf),
541            Stream::Unix(stream) => Pin::new(stream).poll_read(cx, buf),
542            #[cfg(feature = "turmoil")]
543            Stream::Turmoil(stream) => Pin::new(stream).poll_read(cx, buf),
544        }
545    }
546}
547
548impl AsyncWrite for Stream {
549    fn poll_write(self: Pin<&mut Self>, cx: &mut Context, buf: &[u8]) -> Poll<io::Result<usize>> {
550        match self.get_mut() {
551            Stream::Tcp(stream) => Pin::new(stream).poll_write(cx, buf),
552            Stream::Unix(stream) => Pin::new(stream).poll_write(cx, buf),
553            #[cfg(feature = "turmoil")]
554            Stream::Turmoil(stream) => Pin::new(stream).poll_write(cx, buf),
555        }
556    }
557
558    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context) -> Poll<io::Result<()>> {
559        match self.get_mut() {
560            Stream::Tcp(stream) => Pin::new(stream).poll_flush(cx),
561            Stream::Unix(stream) => Pin::new(stream).poll_flush(cx),
562            #[cfg(feature = "turmoil")]
563            Stream::Turmoil(stream) => Pin::new(stream).poll_flush(cx),
564        }
565    }
566
567    fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context) -> Poll<io::Result<()>> {
568        match self.get_mut() {
569            Stream::Tcp(stream) => Pin::new(stream).poll_shutdown(cx),
570            Stream::Unix(stream) => Pin::new(stream).poll_shutdown(cx),
571            #[cfg(feature = "turmoil")]
572            Stream::Turmoil(stream) => Pin::new(stream).poll_shutdown(cx),
573        }
574    }
575}
576
577impl Connected for Stream {
578    type ConnectInfo = ConnectInfo;
579
580    fn connect_info(&self) -> Self::ConnectInfo {
581        match self {
582            Stream::Tcp(stream) => ConnectInfo::Tcp(stream.connect_info()),
583            Stream::Unix(stream) => ConnectInfo::Unix(stream.connect_info()),
584            #[cfg(feature = "turmoil")]
585            Stream::Turmoil(stream) => ConnectInfo::Turmoil(TcpConnectInfo {
586                local_addr: stream.local_addr().ok(),
587                remote_addr: stream.peer_addr().ok(),
588            }),
589        }
590    }
591}
592
593/// Read half of a [`Stream`], created by [`Stream::split`].
594#[derive(Debug)]
595pub enum StreamReadHalf {
596    Tcp(tcp::OwnedReadHalf),
597    Unix(unix::OwnedReadHalf),
598    #[cfg(feature = "turmoil")]
599    Turmoil(turmoil::net::tcp::OwnedReadHalf),
600}
601
602impl AsyncRead for StreamReadHalf {
603    fn poll_read(
604        self: Pin<&mut Self>,
605        cx: &mut Context,
606        buf: &mut ReadBuf,
607    ) -> Poll<io::Result<()>> {
608        match self.get_mut() {
609            Self::Tcp(rx) => Pin::new(rx).poll_read(cx, buf),
610            Self::Unix(rx) => Pin::new(rx).poll_read(cx, buf),
611            #[cfg(feature = "turmoil")]
612            Self::Turmoil(rx) => Pin::new(rx).poll_read(cx, buf),
613        }
614    }
615}
616
617/// Write half of a [`Stream`], created by [`Stream::split`].
618#[derive(Debug)]
619pub enum StreamWriteHalf {
620    Tcp(tcp::OwnedWriteHalf),
621    Unix(unix::OwnedWriteHalf),
622    #[cfg(feature = "turmoil")]
623    Turmoil(turmoil::net::tcp::OwnedWriteHalf),
624}
625
626impl AsyncWrite for StreamWriteHalf {
627    fn poll_write(self: Pin<&mut Self>, cx: &mut Context, buf: &[u8]) -> Poll<io::Result<usize>> {
628        match self.get_mut() {
629            Self::Tcp(tx) => Pin::new(tx).poll_write(cx, buf),
630            Self::Unix(tx) => Pin::new(tx).poll_write(cx, buf),
631            #[cfg(feature = "turmoil")]
632            Self::Turmoil(tx) => Pin::new(tx).poll_write(cx, buf),
633        }
634    }
635
636    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context) -> Poll<io::Result<()>> {
637        match self.get_mut() {
638            Self::Tcp(tx) => Pin::new(tx).poll_flush(cx),
639            Self::Unix(tx) => Pin::new(tx).poll_flush(cx),
640            #[cfg(feature = "turmoil")]
641            Self::Turmoil(tx) => Pin::new(tx).poll_flush(cx),
642        }
643    }
644
645    fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context) -> Poll<io::Result<()>> {
646        match self.get_mut() {
647            Self::Tcp(tx) => Pin::new(tx).poll_shutdown(cx),
648            Self::Unix(tx) => Pin::new(tx).poll_shutdown(cx),
649            #[cfg(feature = "turmoil")]
650            Self::Turmoil(tx) => Pin::new(tx).poll_shutdown(cx),
651        }
652    }
653}
654
655/// Connection information for a [`Stream`].
656#[derive(Debug, Clone)]
657pub enum ConnectInfo {
658    /// TCP connection information.
659    Tcp(TcpConnectInfo),
660    /// Unix domain socket connection information.
661    Unix(UdsConnectInfo),
662    /// `turmoil` socket conection information.
663    Turmoil(TcpConnectInfo),
664}
665
#[cfg(test)]
mod tests {
    use std::net::{Ipv4Addr, SocketAddrV4};

    use super::*;

    /// Table-driven check of `SocketAddr::from_str`. Each input is paired with
    /// either the address it should parse to or the exact error message.
    #[crate::test]
    fn test_parse() {
        let cases: Vec<(&str, Result<SocketAddr, &str>)> = vec![
            (
                "/valid/path",
                Ok(SocketAddr::Unix(
                    UnixSocketAddr::from_pathname("/valid/path").unwrap(),
                )),
            ),
            (
                "/",
                Ok(SocketAddr::Unix(UnixSocketAddr::from_pathname("/").unwrap())),
            ),
            (
                "/\0",
                Err(
                    "invalid unix socket address syntax: paths must not contain interior null bytes",
                ),
            ),
            (
                "1.2.3.4:5678",
                Ok(SocketAddr::Inet(InetSocketAddr::V4(SocketAddrV4::new(
                    Ipv4Addr::new(1, 2, 3, 4),
                    5678,
                )))),
            ),
            ("1.2.3.4", Err("invalid internet socket address syntax")),
            ("bad", Err("invalid internet socket address syntax")),
            (
                "turmoil:1.2.3.4:5678",
                Ok(SocketAddr::Turmoil("1.2.3.4:5678".into())),
            ),
        ];

        for (input, want) in cases {
            let got = input
                .parse::<SocketAddr>()
                .map_err(|err| err.to_string());
            let want = want.map_err(str::to_owned);
            assert_eq!(got, want, "input: {}", input);
        }
    }
}