1use std::error::Error;
17use std::net::SocketAddr as InetSocketAddr;
18use std::pin::Pin;
19use std::str::FromStr;
20use std::task::{Context, Poll, ready};
21use std::{fmt, io};
22
23use async_trait::async_trait;
24use tokio::fs;
25use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
26use tokio::net::{self, TcpListener, TcpStream, UnixListener, UnixStream, tcp, unix};
27use tonic::transport::server::{Connected, TcpConnectInfo, UdsConnectInfo};
28use tracing::warn;
29
30use crate::error::ErrorExt;
31
/// The kind of socket a textual address refers to.
#[derive(Debug, Clone, Copy)]
pub enum SocketAddrType {
    /// An internet (IP + port) socket address.
    Inet,
    /// A Unix domain socket address, given as a filesystem path.
    Unix,
    /// A `turmoil` simulated socket address.
    Turmoil,
}

impl SocketAddrType {
    /// Guesses which kind of socket address `s` denotes, without validating it.
    ///
    /// Strings beginning with `/` are treated as Unix socket paths, strings
    /// beginning with `turmoil:` as `turmoil` addresses, and everything else
    /// as an internet socket address.
    pub fn guess(s: &str) -> SocketAddrType {
        let looks_like_path = s.starts_with('/');
        let has_turmoil_prefix = s.starts_with("turmoil:");
        match (looks_like_path, has_turmoil_prefix) {
            (true, _) => SocketAddrType::Unix,
            (false, true) => SocketAddrType::Turmoil,
            (false, false) => SocketAddrType::Inet,
        }
    }
}
60
61#[derive(Debug, Clone)]
63pub enum SocketAddr {
64 Inet(InetSocketAddr),
66 Unix(UnixSocketAddr),
68 Turmoil(String),
70}
71
72impl PartialEq for SocketAddr {
73 fn eq(&self, other: &Self) -> bool {
74 match (self, other) {
75 (SocketAddr::Inet(addr1), SocketAddr::Inet(addr2)) => addr1 == addr2,
76 (
77 SocketAddr::Unix(UnixSocketAddr { path: Some(path1) }),
78 SocketAddr::Unix(UnixSocketAddr { path: Some(path2) }),
79 ) => path1 == path2,
80 (SocketAddr::Turmoil(addr1), SocketAddr::Turmoil(addr2)) => addr1 == addr2,
81 _ => false,
82 }
83 }
84}
85
86impl fmt::Display for SocketAddr {
87 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
88 match &self {
89 SocketAddr::Inet(addr) => addr.fmt(f),
90 SocketAddr::Unix(addr) => addr.fmt(f),
91 SocketAddr::Turmoil(addr) => write!(f, "turmoil:{addr}"),
92 }
93 }
94}
95
96impl FromStr for SocketAddr {
97 type Err = AddrParseError;
98
99 fn from_str(s: &str) -> Result<SocketAddr, AddrParseError> {
104 match SocketAddrType::guess(s) {
105 SocketAddrType::Unix => {
106 let addr = UnixSocketAddr::from_pathname(s).map_err(|e| AddrParseError {
107 kind: AddrParseErrorKind::Unix(e),
108 })?;
109 Ok(SocketAddr::Unix(addr))
110 }
111 SocketAddrType::Inet => {
112 let addr = s.parse().map_err(|_| AddrParseError {
113 kind: AddrParseErrorKind::Inet,
116 })?;
117 Ok(SocketAddr::Inet(addr))
118 }
119 SocketAddrType::Turmoil => {
120 let addr = s.strip_prefix("turmoil:").ok_or(AddrParseError {
121 kind: AddrParseErrorKind::Turmoil,
122 })?;
123 Ok(SocketAddr::Turmoil(addr.into()))
124 }
125 }
126 }
127}
128
/// The address of a Unix domain socket.
///
/// The path, when present, is stored as a UTF-8 [`String`]; a `None` path
/// denotes an unnamed socket.
#[derive(Debug, Clone)]
pub struct UnixSocketAddr {
    path: Option<String>,
}

impl UnixSocketAddr {
    /// Constructs a Unix domain socket address from the given path.
    ///
    /// # Errors
    ///
    /// Returns an error when [`std::os::unix::net::SocketAddr::from_pathname`]
    /// rejects the path (for example because it contains an interior NUL byte
    /// or is too long).
    pub fn from_pathname<S>(path: S) -> Result<UnixSocketAddr, io::Error>
    where
        S: Into<String>,
    {
        let path: String = path.into();
        // Validation only: the standard library's address value is discarded.
        std::os::unix::net::SocketAddr::from_pathname(&path)?;
        Ok(UnixSocketAddr { path: Some(path) })
    }

    /// Constructs an address representing an unnamed Unix domain socket.
    pub fn unnamed() -> UnixSocketAddr {
        Self { path: None }
    }

    /// Returns the socket's path, or `None` if the socket is unnamed.
    pub fn as_pathname(&self) -> Option<&str> {
        self.path.as_ref().map(String::as_str)
    }
}

impl fmt::Display for UnixSocketAddr {
    /// Writes the path, or `<unnamed>` for an unnamed socket.
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        let text = self.path.as_deref().unwrap_or("<unnamed>");
        f.write_str(text)
    }
}
177
/// The error returned when a [`SocketAddr`] cannot be parsed from a string.
#[derive(Debug)]
pub struct AddrParseError {
    kind: AddrParseErrorKind,
}

/// The reason a socket address failed to parse.
#[derive(Debug)]
pub enum AddrParseErrorKind {
    /// The string was not a valid internet socket address.
    Inet,
    /// The string was not a valid Unix socket path; carries the underlying error.
    Unix(io::Error),
    /// The string lacked the `turmoil:` prefix.
    Turmoil,
}

impl fmt::Display for AddrParseError {
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        let message = match &self.kind {
            AddrParseErrorKind::Inet => "invalid internet socket address syntax",
            AddrParseErrorKind::Unix(_) => "invalid unix socket address syntax: ",
            AddrParseErrorKind::Turmoil => "missing 'turmoil:' prefix",
        };
        f.write_str(message)?;
        // For Unix paths, append the underlying cause, forwarding the
        // formatter so any caller-supplied formatting flags still apply.
        if let AddrParseErrorKind::Unix(cause) = &self.kind {
            fmt::Display::fmt(cause, f)?;
        }
        Ok(())
    }
}

impl Error for AddrParseError {}
205
206#[async_trait]
208pub trait ToSocketAddrs {
209 async fn to_socket_addrs(&self) -> Result<Vec<SocketAddr>, io::Error>;
211}
212
213#[async_trait]
214impl ToSocketAddrs for SocketAddr {
215 async fn to_socket_addrs(&self) -> Result<Vec<SocketAddr>, io::Error> {
216 Ok(vec![self.clone()])
217 }
218}
219
220#[async_trait]
221impl<'a> ToSocketAddrs for &'a [SocketAddr] {
222 async fn to_socket_addrs(&self) -> Result<Vec<SocketAddr>, io::Error> {
223 Ok(self.to_vec())
224 }
225}
226
227#[async_trait]
228impl ToSocketAddrs for InetSocketAddr {
229 async fn to_socket_addrs(&self) -> Result<Vec<SocketAddr>, io::Error> {
230 Ok(vec![SocketAddr::Inet(*self)])
231 }
232}
233
234#[async_trait]
235impl ToSocketAddrs for UnixSocketAddr {
236 async fn to_socket_addrs(&self) -> Result<Vec<SocketAddr>, io::Error> {
237 Ok(vec![SocketAddr::Unix(self.clone())])
238 }
239}
240
241#[async_trait]
242impl ToSocketAddrs for str {
243 async fn to_socket_addrs(&self) -> Result<Vec<SocketAddr>, io::Error> {
244 match self.parse() {
245 Ok(addr) => Ok(vec![addr]),
246 Err(_) => {
247 let addrs = net::lookup_host(self).await?;
248 Ok(addrs.map(SocketAddr::Inet).collect())
249 }
250 }
251 }
252}
253
254#[async_trait]
255impl ToSocketAddrs for String {
256 async fn to_socket_addrs(&self) -> Result<Vec<SocketAddr>, io::Error> {
257 (**self).to_socket_addrs().await
258 }
259}
260
261#[async_trait]
262impl<T> ToSocketAddrs for &T
263where
264 T: ToSocketAddrs + Send + Sync + ?Sized,
265{
266 async fn to_socket_addrs(&self) -> Result<Vec<SocketAddr>, io::Error> {
267 (**self).to_socket_addrs().await
268 }
269}
270
271pub enum Listener {
273 Tcp(TcpListener),
275 Unix(UnixListener),
277 #[cfg(feature = "turmoil")]
279 Turmoil(turmoil::net::TcpListener),
280}
281
282impl Listener {
283 pub async fn bind<A>(addr: A) -> Result<Listener, io::Error>
288 where
289 A: ToSocketAddrs,
290 {
291 let mut last_err = None;
292 for addr in addr.to_socket_addrs().await? {
293 match Listener::bind_addr(addr).await {
294 Ok(listener) => return Ok(listener),
295 Err(e) => last_err = Some(e),
296 }
297 }
298 Err(last_err.unwrap_or_else(|| {
299 io::Error::new(
300 io::ErrorKind::InvalidInput,
301 "could not resolve to any address",
302 )
303 }))
304 }
305
306 async fn bind_addr(addr: SocketAddr) -> Result<Listener, io::Error> {
307 match &addr {
308 SocketAddr::Inet(addr) => {
309 let listener = TcpListener::bind(addr).await?;
310 Ok(Listener::Tcp(listener))
311 }
312 SocketAddr::Unix(UnixSocketAddr { path: Some(path) }) => {
313 if let Err(e) = fs::remove_file(path).await {
318 if e.kind() != io::ErrorKind::NotFound {
319 warn!(
320 "unable to remove {path} while binding unix domain socket: {}",
321 e.display_with_causes(),
322 );
323 }
324 }
325 let listener = UnixListener::bind(path)?;
326 Ok(Listener::Unix(listener))
327 }
328 SocketAddr::Unix(UnixSocketAddr { path: None }) => Err(io::Error::new(
329 io::ErrorKind::Other,
330 "cannot bind to unnamed Unix socket",
331 )),
332 #[cfg(feature = "turmoil")]
333 SocketAddr::Turmoil(addr) => {
334 let listener = turmoil::net::TcpListener::bind(addr).await?;
335 Ok(Listener::Turmoil(listener))
336 }
337 #[cfg(not(feature = "turmoil"))]
338 SocketAddr::Turmoil(_) => panic!("`turmoil` feature not enabled"),
339 }
340 }
341
342 pub async fn accept(&self) -> Result<(Stream, SocketAddr), io::Error> {
347 match self {
348 Listener::Tcp(listener) => {
349 let (stream, addr) = listener.accept().await?;
350 stream.set_nodelay(true)?;
351 let stream = Stream::Tcp(stream);
352 let addr = SocketAddr::Inet(addr);
353 Ok((stream, addr))
354 }
355 Listener::Unix(listener) => {
356 let (stream, addr) = listener.accept().await?;
357 let stream = Stream::Unix(stream);
358 assert!(addr.is_unnamed());
359 let addr = SocketAddr::Unix(UnixSocketAddr::unnamed());
360 Ok((stream, addr))
361 }
362 #[cfg(feature = "turmoil")]
363 Listener::Turmoil(listener) => {
364 let (stream, addr) = listener.accept().await?;
365 let stream = Stream::Turmoil(stream);
366 let addr = SocketAddr::Inet(addr);
367 Ok((stream, addr))
368 }
369 }
370 }
371
372 fn poll_accept(
373 self: Pin<&mut Self>,
374 cx: &mut Context<'_>,
375 ) -> Poll<Option<Result<Stream, io::Error>>> {
376 match self.get_mut() {
377 Listener::Tcp(listener) => {
378 let (stream, _addr) = ready!(listener.poll_accept(cx))?;
379 stream.set_nodelay(true)?;
380 Poll::Ready(Some(Ok(Stream::Tcp(stream))))
381 }
382 Listener::Unix(listener) => {
383 let (stream, _addr) = ready!(listener.poll_accept(cx))?;
384 Poll::Ready(Some(Ok(Stream::Unix(stream))))
385 }
386 #[cfg(feature = "turmoil")]
387 Listener::Turmoil(_) => {
388 unimplemented!("`turmoil::net::TcpListener::poll_accept`");
389 }
390 }
391 }
392}
393
394impl fmt::Debug for Listener {
395 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
396 match self {
397 Self::Tcp(inner) => f.debug_tuple("Tcp").field(inner).finish(),
398 Self::Unix(inner) => f.debug_tuple("Unix").field(inner).finish(),
399 #[cfg(feature = "turmoil")]
400 Self::Turmoil(_) => f.debug_tuple("Turmoil").finish(),
401 }
402 }
403}
404
405impl futures::stream::Stream for Listener {
406 type Item = Result<Stream, io::Error>;
407
408 fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
409 self.poll_accept(cx)
410 }
411}
412
413#[derive(Debug)]
415pub enum Stream {
416 Tcp(TcpStream),
418 Unix(UnixStream),
420 #[cfg(feature = "turmoil")]
422 Turmoil(turmoil::net::TcpStream),
423}
424
425impl Stream {
426 pub async fn connect<A>(addr: A) -> Result<Stream, io::Error>
431 where
432 A: ToSocketAddrs,
433 {
434 let mut last_err = None;
435 for addr in addr.to_socket_addrs().await? {
436 match Stream::connect_addr(addr).await {
437 Ok(stream) => return Ok(stream),
438 Err(e) => last_err = Some(e),
439 }
440 }
441 Err(last_err.unwrap_or_else(|| {
442 io::Error::new(
443 io::ErrorKind::InvalidInput,
444 "could not resolve to any address",
445 )
446 }))
447 }
448
449 async fn connect_addr(addr: SocketAddr) -> Result<Stream, io::Error> {
450 match addr {
451 SocketAddr::Inet(addr) => {
452 let stream = TcpStream::connect(addr).await?;
453 stream.set_nodelay(true)?;
454 Ok(Stream::Tcp(stream))
455 }
456 SocketAddr::Unix(UnixSocketAddr { path: Some(path) }) => {
457 let stream = UnixStream::connect(path).await?;
458 Ok(Stream::Unix(stream))
459 }
460 SocketAddr::Unix(UnixSocketAddr { path: None }) => Err(io::Error::new(
461 io::ErrorKind::Other,
462 "cannot connected to unnamed Unix socket",
463 )),
464 #[cfg(feature = "turmoil")]
465 SocketAddr::Turmoil(addr) => {
466 let stream = turmoil::net::TcpStream::connect(addr).await?;
467 Ok(Stream::Turmoil(stream))
468 }
469 #[cfg(not(feature = "turmoil"))]
470 SocketAddr::Turmoil(_) => panic!("`turmoil` feature not enabled"),
471 }
472 }
473
474 pub fn is_tcp(&self) -> bool {
476 matches!(self, Stream::Tcp(_))
477 }
478
479 pub fn is_unix(&self) -> bool {
481 matches!(self, Stream::Unix(_))
482 }
483
484 pub fn unwrap_tcp(self) -> TcpStream {
490 match self {
491 Stream::Tcp(stream) => stream,
492 Stream::Unix(_) => panic!("Stream::unwrap_tcp called on a Unix stream"),
493 #[cfg(feature = "turmoil")]
494 Stream::Turmoil(_) => panic!("Stream::unwrap_tcp called on a `turmoil` stream"),
495 }
496 }
497
498 pub fn unwrap_unix(self) -> UnixStream {
504 match self {
505 Stream::Tcp(_) => panic!("Stream::unwrap_unix called on a TCP stream"),
506 Stream::Unix(stream) => stream,
507 #[cfg(feature = "turmoil")]
508 Stream::Turmoil(_) => panic!("Stream::unwrap_unix called on a `turmoil` stream"),
509 }
510 }
511
512 pub fn split(self) -> (StreamReadHalf, StreamWriteHalf) {
515 match self {
516 Stream::Tcp(stream) => {
517 let (rx, tx) = stream.into_split();
518 (StreamReadHalf::Tcp(rx), StreamWriteHalf::Tcp(tx))
519 }
520 Stream::Unix(stream) => {
521 let (rx, tx) = stream.into_split();
522 (StreamReadHalf::Unix(rx), StreamWriteHalf::Unix(tx))
523 }
524 #[cfg(feature = "turmoil")]
525 Stream::Turmoil(stream) => {
526 let (rx, tx) = stream.into_split();
527 (StreamReadHalf::Turmoil(rx), StreamWriteHalf::Turmoil(tx))
528 }
529 }
530 }
531}
532
533impl AsyncRead for Stream {
534 fn poll_read(
535 self: Pin<&mut Self>,
536 cx: &mut Context,
537 buf: &mut ReadBuf,
538 ) -> Poll<io::Result<()>> {
539 match self.get_mut() {
540 Stream::Tcp(stream) => Pin::new(stream).poll_read(cx, buf),
541 Stream::Unix(stream) => Pin::new(stream).poll_read(cx, buf),
542 #[cfg(feature = "turmoil")]
543 Stream::Turmoil(stream) => Pin::new(stream).poll_read(cx, buf),
544 }
545 }
546}
547
548impl AsyncWrite for Stream {
549 fn poll_write(self: Pin<&mut Self>, cx: &mut Context, buf: &[u8]) -> Poll<io::Result<usize>> {
550 match self.get_mut() {
551 Stream::Tcp(stream) => Pin::new(stream).poll_write(cx, buf),
552 Stream::Unix(stream) => Pin::new(stream).poll_write(cx, buf),
553 #[cfg(feature = "turmoil")]
554 Stream::Turmoil(stream) => Pin::new(stream).poll_write(cx, buf),
555 }
556 }
557
558 fn poll_flush(self: Pin<&mut Self>, cx: &mut Context) -> Poll<io::Result<()>> {
559 match self.get_mut() {
560 Stream::Tcp(stream) => Pin::new(stream).poll_flush(cx),
561 Stream::Unix(stream) => Pin::new(stream).poll_flush(cx),
562 #[cfg(feature = "turmoil")]
563 Stream::Turmoil(stream) => Pin::new(stream).poll_flush(cx),
564 }
565 }
566
567 fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context) -> Poll<io::Result<()>> {
568 match self.get_mut() {
569 Stream::Tcp(stream) => Pin::new(stream).poll_shutdown(cx),
570 Stream::Unix(stream) => Pin::new(stream).poll_shutdown(cx),
571 #[cfg(feature = "turmoil")]
572 Stream::Turmoil(stream) => Pin::new(stream).poll_shutdown(cx),
573 }
574 }
575}
576
577impl Connected for Stream {
578 type ConnectInfo = ConnectInfo;
579
580 fn connect_info(&self) -> Self::ConnectInfo {
581 match self {
582 Stream::Tcp(stream) => ConnectInfo::Tcp(stream.connect_info()),
583 Stream::Unix(stream) => ConnectInfo::Unix(stream.connect_info()),
584 #[cfg(feature = "turmoil")]
585 Stream::Turmoil(stream) => ConnectInfo::Turmoil(TcpConnectInfo {
586 local_addr: stream.local_addr().ok(),
587 remote_addr: stream.peer_addr().ok(),
588 }),
589 }
590 }
591}
592
593#[derive(Debug)]
595pub enum StreamReadHalf {
596 Tcp(tcp::OwnedReadHalf),
597 Unix(unix::OwnedReadHalf),
598 #[cfg(feature = "turmoil")]
599 Turmoil(turmoil::net::tcp::OwnedReadHalf),
600}
601
602impl AsyncRead for StreamReadHalf {
603 fn poll_read(
604 self: Pin<&mut Self>,
605 cx: &mut Context,
606 buf: &mut ReadBuf,
607 ) -> Poll<io::Result<()>> {
608 match self.get_mut() {
609 Self::Tcp(rx) => Pin::new(rx).poll_read(cx, buf),
610 Self::Unix(rx) => Pin::new(rx).poll_read(cx, buf),
611 #[cfg(feature = "turmoil")]
612 Self::Turmoil(rx) => Pin::new(rx).poll_read(cx, buf),
613 }
614 }
615}
616
617#[derive(Debug)]
619pub enum StreamWriteHalf {
620 Tcp(tcp::OwnedWriteHalf),
621 Unix(unix::OwnedWriteHalf),
622 #[cfg(feature = "turmoil")]
623 Turmoil(turmoil::net::tcp::OwnedWriteHalf),
624}
625
626impl AsyncWrite for StreamWriteHalf {
627 fn poll_write(self: Pin<&mut Self>, cx: &mut Context, buf: &[u8]) -> Poll<io::Result<usize>> {
628 match self.get_mut() {
629 Self::Tcp(tx) => Pin::new(tx).poll_write(cx, buf),
630 Self::Unix(tx) => Pin::new(tx).poll_write(cx, buf),
631 #[cfg(feature = "turmoil")]
632 Self::Turmoil(tx) => Pin::new(tx).poll_write(cx, buf),
633 }
634 }
635
636 fn poll_flush(self: Pin<&mut Self>, cx: &mut Context) -> Poll<io::Result<()>> {
637 match self.get_mut() {
638 Self::Tcp(tx) => Pin::new(tx).poll_flush(cx),
639 Self::Unix(tx) => Pin::new(tx).poll_flush(cx),
640 #[cfg(feature = "turmoil")]
641 Self::Turmoil(tx) => Pin::new(tx).poll_flush(cx),
642 }
643 }
644
645 fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context) -> Poll<io::Result<()>> {
646 match self.get_mut() {
647 Self::Tcp(tx) => Pin::new(tx).poll_shutdown(cx),
648 Self::Unix(tx) => Pin::new(tx).poll_shutdown(cx),
649 #[cfg(feature = "turmoil")]
650 Self::Turmoil(tx) => Pin::new(tx).poll_shutdown(cx),
651 }
652 }
653}
654
655#[derive(Debug, Clone)]
657pub enum ConnectInfo {
658 Tcp(TcpConnectInfo),
660 Unix(UdsConnectInfo),
662 Turmoil(TcpConnectInfo),
664}
665
#[cfg(test)]
mod tests {
    use std::net::{Ipv4Addr, SocketAddrV4};

    use super::*;

    /// Checks that each input either parses to the expected address or fails
    /// with the expected error message.
    #[crate::test]
    fn test_parse() {
        let cases: [(&str, Result<SocketAddr, &str>); 7] = [
            (
                "/valid/path",
                Ok(SocketAddr::Unix(
                    UnixSocketAddr::from_pathname("/valid/path").unwrap(),
                )),
            ),
            (
                "/",
                Ok(SocketAddr::Unix(
                    UnixSocketAddr::from_pathname("/").unwrap(),
                )),
            ),
            (
                "/\0",
                Err(
                    "invalid unix socket address syntax: paths must not contain interior null bytes",
                ),
            ),
            (
                "1.2.3.4:5678",
                Ok(SocketAddr::Inet(InetSocketAddr::V4(SocketAddrV4::new(
                    Ipv4Addr::new(1, 2, 3, 4),
                    5678,
                )))),
            ),
            ("1.2.3.4", Err("invalid internet socket address syntax")),
            ("bad", Err("invalid internet socket address syntax")),
            (
                "turmoil:1.2.3.4:5678",
                Ok(SocketAddr::Turmoil("1.2.3.4:5678".into())),
            ),
        ];

        for (input, expected) in cases {
            let actual = input.parse::<SocketAddr>().map_err(|e| e.to_string());
            let expected = expected.map_err(String::from);
            assert_eq!(actual, expected, "input: {input}");
        }
    }
}