1use std::error::Error;
17use std::net::SocketAddr as InetSocketAddr;
18use std::pin::Pin;
19use std::str::FromStr;
20use std::task::{Context, Poll};
21use std::{fmt, io};
22
23use async_trait::async_trait;
24use tokio::fs;
25use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
26use tokio::net::{self, TcpListener, TcpStream, UnixListener, UnixStream, tcp, unix};
27use tracing::warn;
28
29use crate::error::ErrorExt;
30
/// The kind of a [`SocketAddr`], as inferred from its textual form.
#[derive(Debug, Clone, Copy)]
pub enum SocketAddrType {
    /// An internet socket address, such as `1.2.3.4:5678`.
    Inet,
    /// A Unix domain socket address, written as a path starting with `/`.
    Unix,
    /// A turmoil simulator address, written with a `turmoil:` prefix.
    Turmoil,
}

impl SocketAddrType {
    /// Guesses which kind of socket address `s` denotes.
    ///
    /// Only the prefix is inspected: a leading `/` means [`SocketAddrType::Unix`],
    /// a leading `turmoil:` means [`SocketAddrType::Turmoil`], and anything else
    /// (including the empty string) means [`SocketAddrType::Inet`]. The string is
    /// not otherwise validated.
    pub fn guess(s: &str) -> SocketAddrType {
        match s {
            _ if s.starts_with('/') => SocketAddrType::Unix,
            _ if s.starts_with("turmoil:") => SocketAddrType::Turmoil,
            _ => SocketAddrType::Inet,
        }
    }
}
59
60#[derive(Debug, Clone)]
62pub enum SocketAddr {
63 Inet(InetSocketAddr),
65 Unix(UnixSocketAddr),
67 Turmoil(String),
69}
70
71impl PartialEq for SocketAddr {
72 fn eq(&self, other: &Self) -> bool {
73 match (self, other) {
74 (SocketAddr::Inet(addr1), SocketAddr::Inet(addr2)) => addr1 == addr2,
75 (
76 SocketAddr::Unix(UnixSocketAddr { path: Some(path1) }),
77 SocketAddr::Unix(UnixSocketAddr { path: Some(path2) }),
78 ) => path1 == path2,
79 (SocketAddr::Turmoil(addr1), SocketAddr::Turmoil(addr2)) => addr1 == addr2,
80 _ => false,
81 }
82 }
83}
84
85impl fmt::Display for SocketAddr {
86 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
87 match &self {
88 SocketAddr::Inet(addr) => addr.fmt(f),
89 SocketAddr::Unix(addr) => addr.fmt(f),
90 SocketAddr::Turmoil(addr) => write!(f, "turmoil:{addr}"),
91 }
92 }
93}
94
95impl FromStr for SocketAddr {
96 type Err = AddrParseError;
97
98 fn from_str(s: &str) -> Result<SocketAddr, AddrParseError> {
103 match SocketAddrType::guess(s) {
104 SocketAddrType::Unix => {
105 let addr = UnixSocketAddr::from_pathname(s).map_err(|e| AddrParseError {
106 kind: AddrParseErrorKind::Unix(e),
107 })?;
108 Ok(SocketAddr::Unix(addr))
109 }
110 SocketAddrType::Inet => {
111 let addr = s.parse().map_err(|_| AddrParseError {
112 kind: AddrParseErrorKind::Inet,
115 })?;
116 Ok(SocketAddr::Inet(addr))
117 }
118 SocketAddrType::Turmoil => {
119 let addr = s.strip_prefix("turmoil:").ok_or(AddrParseError {
120 kind: AddrParseErrorKind::Turmoil,
121 })?;
122 Ok(SocketAddr::Turmoil(addr.into()))
123 }
124 }
125 }
126}
127
/// The address of a Unix domain socket: either a filesystem path or unnamed.
#[derive(Debug, Clone)]
pub struct UnixSocketAddr {
    /// The socket's pathname, or `None` for an unnamed socket.
    path: Option<String>,
}

impl UnixSocketAddr {
    /// Constructs the address of the Unix socket at `path`.
    ///
    /// # Errors
    ///
    /// Returns an error if `path` is rejected by
    /// [`std::os::unix::net::SocketAddr::from_pathname`], for example because
    /// it contains an interior NUL byte or is too long for a socket address.
    pub fn from_pathname<S>(path: S) -> Result<UnixSocketAddr, io::Error>
    where
        S: Into<String>,
    {
        let owned: String = path.into();
        // Used purely as a validator; the std address is discarded.
        std::os::unix::net::SocketAddr::from_pathname(&owned)
            .map(|_| UnixSocketAddr { path: Some(owned) })
    }

    /// Constructs an unnamed Unix socket address.
    pub fn unnamed() -> UnixSocketAddr {
        UnixSocketAddr { path: None }
    }

    /// Returns the socket's pathname, or `None` if the address is unnamed.
    pub fn as_pathname(&self) -> Option<&str> {
        self.path.as_ref().map(String::as_str)
    }
}

/// Formats the pathname, or `<unnamed>` for an unnamed address.
impl fmt::Display for UnixSocketAddr {
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        f.write_str(self.as_pathname().unwrap_or("<unnamed>"))
    }
}
176
/// The error returned when a [`SocketAddr`] fails to parse.
#[derive(Debug)]
pub struct AddrParseError {
    kind: AddrParseErrorKind,
}

/// The kind of address that failed to parse, and why.
#[derive(Debug)]
pub enum AddrParseErrorKind {
    /// The input was not a valid internet socket address.
    Inet,
    /// The input was not a valid Unix socket path; carries the validation error.
    Unix(io::Error),
    /// The input lacked the `turmoil:` prefix.
    Turmoil,
}

impl fmt::Display for AddrParseError {
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        let message = match &self.kind {
            AddrParseErrorKind::Inet => "invalid internet socket address syntax",
            AddrParseErrorKind::Turmoil => "missing 'turmoil:' prefix",
            AddrParseErrorKind::Unix(cause) => {
                // The underlying validation error is appended inline.
                f.write_str("invalid unix socket address syntax: ")?;
                return fmt::Display::fmt(cause, f);
            }
        };
        f.write_str(message)
    }
}

impl Error for AddrParseError {}
204
205#[async_trait]
207pub trait ToSocketAddrs {
208 async fn to_socket_addrs(&self) -> Result<Vec<SocketAddr>, io::Error>;
210}
211
212#[async_trait]
213impl ToSocketAddrs for SocketAddr {
214 async fn to_socket_addrs(&self) -> Result<Vec<SocketAddr>, io::Error> {
215 Ok(vec![self.clone()])
216 }
217}
218
219#[async_trait]
220impl<'a> ToSocketAddrs for &'a [SocketAddr] {
221 async fn to_socket_addrs(&self) -> Result<Vec<SocketAddr>, io::Error> {
222 Ok(self.to_vec())
223 }
224}
225
226#[async_trait]
227impl ToSocketAddrs for InetSocketAddr {
228 async fn to_socket_addrs(&self) -> Result<Vec<SocketAddr>, io::Error> {
229 Ok(vec![SocketAddr::Inet(*self)])
230 }
231}
232
233#[async_trait]
234impl ToSocketAddrs for UnixSocketAddr {
235 async fn to_socket_addrs(&self) -> Result<Vec<SocketAddr>, io::Error> {
236 Ok(vec![SocketAddr::Unix(self.clone())])
237 }
238}
239
240#[async_trait]
241impl ToSocketAddrs for str {
242 async fn to_socket_addrs(&self) -> Result<Vec<SocketAddr>, io::Error> {
243 match self.parse() {
244 Ok(addr) => Ok(vec![addr]),
245 Err(_) => {
246 let addrs = net::lookup_host(self).await?;
247 Ok(addrs.map(SocketAddr::Inet).collect())
248 }
249 }
250 }
251}
252
253#[async_trait]
254impl ToSocketAddrs for String {
255 async fn to_socket_addrs(&self) -> Result<Vec<SocketAddr>, io::Error> {
256 (**self).to_socket_addrs().await
257 }
258}
259
260#[async_trait]
261impl<T> ToSocketAddrs for &T
262where
263 T: ToSocketAddrs + Send + Sync + ?Sized,
264{
265 async fn to_socket_addrs(&self) -> Result<Vec<SocketAddr>, io::Error> {
266 (**self).to_socket_addrs().await
267 }
268}
269
270pub enum Listener {
272 Tcp(TcpListener),
274 Unix(UnixListener),
276 #[cfg(feature = "turmoil")]
278 Turmoil(turmoil::net::TcpListener),
279}
280
281impl Listener {
282 pub async fn bind<A>(addr: A) -> Result<Listener, io::Error>
287 where
288 A: ToSocketAddrs,
289 {
290 let mut last_err = None;
291 for addr in addr.to_socket_addrs().await? {
292 match Listener::bind_addr(addr).await {
293 Ok(listener) => return Ok(listener),
294 Err(e) => last_err = Some(e),
295 }
296 }
297 Err(last_err.unwrap_or_else(|| {
298 io::Error::new(
299 io::ErrorKind::InvalidInput,
300 "could not resolve to any address",
301 )
302 }))
303 }
304
305 async fn bind_addr(addr: SocketAddr) -> Result<Listener, io::Error> {
306 match &addr {
307 SocketAddr::Inet(addr) => {
308 let listener = TcpListener::bind(addr).await?;
309 Ok(Listener::Tcp(listener))
310 }
311 SocketAddr::Unix(UnixSocketAddr { path: Some(path) }) => {
312 if let Err(e) = fs::remove_file(path).await {
317 if e.kind() != io::ErrorKind::NotFound {
318 warn!(
319 "unable to remove {path} while binding unix domain socket: {}",
320 e.display_with_causes(),
321 );
322 }
323 }
324 let listener = UnixListener::bind(path)?;
325 Ok(Listener::Unix(listener))
326 }
327 SocketAddr::Unix(UnixSocketAddr { path: None }) => Err(io::Error::new(
328 io::ErrorKind::Other,
329 "cannot bind to unnamed Unix socket",
330 )),
331 #[cfg(feature = "turmoil")]
332 SocketAddr::Turmoil(addr) => {
333 let listener = turmoil::net::TcpListener::bind(addr).await?;
334 Ok(Listener::Turmoil(listener))
335 }
336 #[cfg(not(feature = "turmoil"))]
337 SocketAddr::Turmoil(_) => panic!("`turmoil` feature not enabled"),
338 }
339 }
340
341 pub async fn accept(&self) -> Result<(Stream, SocketAddr), io::Error> {
346 match self {
347 Listener::Tcp(listener) => {
348 let (stream, addr) = listener.accept().await?;
349 stream.set_nodelay(true)?;
350 let stream = Stream::Tcp(stream);
351 let addr = SocketAddr::Inet(addr);
352 Ok((stream, addr))
353 }
354 Listener::Unix(listener) => {
355 let (stream, addr) = listener.accept().await?;
356 let stream = Stream::Unix(stream);
357 assert!(addr.is_unnamed());
358 let addr = SocketAddr::Unix(UnixSocketAddr::unnamed());
359 Ok((stream, addr))
360 }
361 #[cfg(feature = "turmoil")]
362 Listener::Turmoil(listener) => {
363 let (stream, addr) = listener.accept().await?;
364 let stream = Stream::Turmoil(stream);
365 let addr = SocketAddr::Inet(addr);
366 Ok((stream, addr))
367 }
368 }
369 }
370}
371
372impl fmt::Debug for Listener {
373 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
374 match self {
375 Self::Tcp(inner) => f.debug_tuple("Tcp").field(inner).finish(),
376 Self::Unix(inner) => f.debug_tuple("Unix").field(inner).finish(),
377 #[cfg(feature = "turmoil")]
378 Self::Turmoil(_) => f.debug_tuple("Turmoil").finish(),
379 }
380 }
381}
382
383#[derive(Debug)]
385pub enum Stream {
386 Tcp(TcpStream),
388 Unix(UnixStream),
390 #[cfg(feature = "turmoil")]
392 Turmoil(turmoil::net::TcpStream),
393}
394
395impl Stream {
396 pub async fn connect<A>(addr: A) -> Result<Stream, io::Error>
401 where
402 A: ToSocketAddrs,
403 {
404 let mut last_err = None;
405 for addr in addr.to_socket_addrs().await? {
406 match Stream::connect_addr(addr).await {
407 Ok(stream) => return Ok(stream),
408 Err(e) => last_err = Some(e),
409 }
410 }
411 Err(last_err.unwrap_or_else(|| {
412 io::Error::new(
413 io::ErrorKind::InvalidInput,
414 "could not resolve to any address",
415 )
416 }))
417 }
418
419 async fn connect_addr(addr: SocketAddr) -> Result<Stream, io::Error> {
420 match addr {
421 SocketAddr::Inet(addr) => {
422 let stream = TcpStream::connect(addr).await?;
423 stream.set_nodelay(true)?;
424 Ok(Stream::Tcp(stream))
425 }
426 SocketAddr::Unix(UnixSocketAddr { path: Some(path) }) => {
427 let stream = UnixStream::connect(path).await?;
428 Ok(Stream::Unix(stream))
429 }
430 SocketAddr::Unix(UnixSocketAddr { path: None }) => Err(io::Error::new(
431 io::ErrorKind::Other,
432 "cannot connected to unnamed Unix socket",
433 )),
434 #[cfg(feature = "turmoil")]
435 SocketAddr::Turmoil(addr) => {
436 let stream = turmoil::net::TcpStream::connect(addr).await?;
437 Ok(Stream::Turmoil(stream))
438 }
439 #[cfg(not(feature = "turmoil"))]
440 SocketAddr::Turmoil(_) => panic!("`turmoil` feature not enabled"),
441 }
442 }
443
444 pub fn is_tcp(&self) -> bool {
446 matches!(self, Stream::Tcp(_))
447 }
448
449 pub fn is_unix(&self) -> bool {
451 matches!(self, Stream::Unix(_))
452 }
453
454 pub fn unwrap_tcp(self) -> TcpStream {
460 match self {
461 Stream::Tcp(stream) => stream,
462 Stream::Unix(_) => panic!("Stream::unwrap_tcp called on a Unix stream"),
463 #[cfg(feature = "turmoil")]
464 Stream::Turmoil(_) => panic!("Stream::unwrap_tcp called on a `turmoil` stream"),
465 }
466 }
467
468 pub fn unwrap_unix(self) -> UnixStream {
474 match self {
475 Stream::Tcp(_) => panic!("Stream::unwrap_unix called on a TCP stream"),
476 Stream::Unix(stream) => stream,
477 #[cfg(feature = "turmoil")]
478 Stream::Turmoil(_) => panic!("Stream::unwrap_unix called on a `turmoil` stream"),
479 }
480 }
481
482 pub fn split(self) -> (StreamReadHalf, StreamWriteHalf) {
485 match self {
486 Stream::Tcp(stream) => {
487 let (rx, tx) = stream.into_split();
488 (StreamReadHalf::Tcp(rx), StreamWriteHalf::Tcp(tx))
489 }
490 Stream::Unix(stream) => {
491 let (rx, tx) = stream.into_split();
492 (StreamReadHalf::Unix(rx), StreamWriteHalf::Unix(tx))
493 }
494 #[cfg(feature = "turmoil")]
495 Stream::Turmoil(stream) => {
496 let (rx, tx) = stream.into_split();
497 (StreamReadHalf::Turmoil(rx), StreamWriteHalf::Turmoil(tx))
498 }
499 }
500 }
501}
502
503impl AsyncRead for Stream {
504 fn poll_read(
505 self: Pin<&mut Self>,
506 cx: &mut Context,
507 buf: &mut ReadBuf,
508 ) -> Poll<io::Result<()>> {
509 match self.get_mut() {
510 Stream::Tcp(stream) => Pin::new(stream).poll_read(cx, buf),
511 Stream::Unix(stream) => Pin::new(stream).poll_read(cx, buf),
512 #[cfg(feature = "turmoil")]
513 Stream::Turmoil(stream) => Pin::new(stream).poll_read(cx, buf),
514 }
515 }
516}
517
518impl AsyncWrite for Stream {
519 fn poll_write(self: Pin<&mut Self>, cx: &mut Context, buf: &[u8]) -> Poll<io::Result<usize>> {
520 match self.get_mut() {
521 Stream::Tcp(stream) => Pin::new(stream).poll_write(cx, buf),
522 Stream::Unix(stream) => Pin::new(stream).poll_write(cx, buf),
523 #[cfg(feature = "turmoil")]
524 Stream::Turmoil(stream) => Pin::new(stream).poll_write(cx, buf),
525 }
526 }
527
528 fn poll_flush(self: Pin<&mut Self>, cx: &mut Context) -> Poll<io::Result<()>> {
529 match self.get_mut() {
530 Stream::Tcp(stream) => Pin::new(stream).poll_flush(cx),
531 Stream::Unix(stream) => Pin::new(stream).poll_flush(cx),
532 #[cfg(feature = "turmoil")]
533 Stream::Turmoil(stream) => Pin::new(stream).poll_flush(cx),
534 }
535 }
536
537 fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context) -> Poll<io::Result<()>> {
538 match self.get_mut() {
539 Stream::Tcp(stream) => Pin::new(stream).poll_shutdown(cx),
540 Stream::Unix(stream) => Pin::new(stream).poll_shutdown(cx),
541 #[cfg(feature = "turmoil")]
542 Stream::Turmoil(stream) => Pin::new(stream).poll_shutdown(cx),
543 }
544 }
545}
546
547#[derive(Debug)]
549pub enum StreamReadHalf {
550 Tcp(tcp::OwnedReadHalf),
551 Unix(unix::OwnedReadHalf),
552 #[cfg(feature = "turmoil")]
553 Turmoil(turmoil::net::tcp::OwnedReadHalf),
554}
555
556impl AsyncRead for StreamReadHalf {
557 fn poll_read(
558 self: Pin<&mut Self>,
559 cx: &mut Context,
560 buf: &mut ReadBuf,
561 ) -> Poll<io::Result<()>> {
562 match self.get_mut() {
563 Self::Tcp(rx) => Pin::new(rx).poll_read(cx, buf),
564 Self::Unix(rx) => Pin::new(rx).poll_read(cx, buf),
565 #[cfg(feature = "turmoil")]
566 Self::Turmoil(rx) => Pin::new(rx).poll_read(cx, buf),
567 }
568 }
569}
570
571#[derive(Debug)]
573pub enum StreamWriteHalf {
574 Tcp(tcp::OwnedWriteHalf),
575 Unix(unix::OwnedWriteHalf),
576 #[cfg(feature = "turmoil")]
577 Turmoil(turmoil::net::tcp::OwnedWriteHalf),
578}
579
580impl AsyncWrite for StreamWriteHalf {
581 fn poll_write(self: Pin<&mut Self>, cx: &mut Context, buf: &[u8]) -> Poll<io::Result<usize>> {
582 match self.get_mut() {
583 Self::Tcp(tx) => Pin::new(tx).poll_write(cx, buf),
584 Self::Unix(tx) => Pin::new(tx).poll_write(cx, buf),
585 #[cfg(feature = "turmoil")]
586 Self::Turmoil(tx) => Pin::new(tx).poll_write(cx, buf),
587 }
588 }
589
590 fn poll_flush(self: Pin<&mut Self>, cx: &mut Context) -> Poll<io::Result<()>> {
591 match self.get_mut() {
592 Self::Tcp(tx) => Pin::new(tx).poll_flush(cx),
593 Self::Unix(tx) => Pin::new(tx).poll_flush(cx),
594 #[cfg(feature = "turmoil")]
595 Self::Turmoil(tx) => Pin::new(tx).poll_flush(cx),
596 }
597 }
598
599 fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context) -> Poll<io::Result<()>> {
600 match self.get_mut() {
601 Self::Tcp(tx) => Pin::new(tx).poll_shutdown(cx),
602 Self::Unix(tx) => Pin::new(tx).poll_shutdown(cx),
603 #[cfg(feature = "turmoil")]
604 Self::Turmoil(tx) => Pin::new(tx).poll_shutdown(cx),
605 }
606 }
607}
608
#[cfg(test)]
mod tests {
    use std::net::{Ipv4Addr, SocketAddrV4};

    use super::*;

    /// Parsing accepts each supported address kind and reports kind-specific
    /// errors.
    #[crate::test]
    fn test_parse() {
        for (input, expected) in [
            (
                "/valid/path",
                Ok(SocketAddr::Unix(
                    UnixSocketAddr::from_pathname("/valid/path").unwrap(),
                )),
            ),
            (
                "/",
                Ok(SocketAddr::Unix(
                    UnixSocketAddr::from_pathname("/").unwrap(),
                )),
            ),
            (
                "/\0",
                Err(
                    "invalid unix socket address syntax: paths must not contain interior null bytes",
                ),
            ),
            (
                "1.2.3.4:5678",
                Ok(SocketAddr::Inet(InetSocketAddr::V4(SocketAddrV4::new(
                    Ipv4Addr::new(1, 2, 3, 4),
                    5678,
                )))),
            ),
            ("1.2.3.4", Err("invalid internet socket address syntax")),
            ("bad", Err("invalid internet socket address syntax")),
            (
                "turmoil:1.2.3.4:5678",
                Ok(SocketAddr::Turmoil("1.2.3.4:5678".into())),
            ),
        ] {
            let actual = SocketAddr::from_str(input).map_err(|e| e.to_string());
            let expected = expected.map_err(|e| e.to_string());
            assert_eq!(actual, expected, "input: {}", input);
        }
    }

    /// Every successfully parsed address formats back to its input string, and
    /// that string parses back to an equal address.
    #[crate::test]
    fn test_display_roundtrip() {
        for input in [
            "/valid/path",
            "/",
            "1.2.3.4:5678",
            "[::1]:5432",
            "turmoil:1.2.3.4:5678",
            "turmoil:node",
        ] {
            let addr = SocketAddr::from_str(input).unwrap();
            let formatted = addr.to_string();
            assert_eq!(formatted, input, "input: {}", input);
            assert_eq!(
                SocketAddr::from_str(&formatted).unwrap(),
                addr,
                "input: {}",
                input
            );
        }
        assert_eq!(UnixSocketAddr::unnamed().to_string(), "<unnamed>");
    }
}
655}