1use std::error::Error;
17use std::net::SocketAddr as InetSocketAddr;
18use std::pin::Pin;
19use std::str::FromStr;
20use std::task::{Context, Poll, ready};
21use std::{fmt, io};
22
23use async_trait::async_trait;
24use tokio::fs;
25use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
26use tokio::net::{self, TcpListener, TcpStream, UnixListener, UnixStream};
27use tonic::transport::server::{Connected, TcpConnectInfo, UdsConnectInfo};
28use tracing::warn;
29
30use crate::error::ErrorExt;
31
/// The kind of socket address a string appears to describe.
///
/// See [`SocketAddrType::guess`] for how a string is classified.
#[derive(Clone, Copy, Debug)]
pub enum SocketAddrType {
    /// An internet socket address (IP address plus port).
    Inet,
    /// A Unix domain socket address, given as a filesystem path.
    Unix,
    /// An address for the `turmoil` network simulator.
    Turmoil,
}

impl SocketAddrType {
    /// Guesses the address type from the syntax of `s`.
    ///
    /// Strings starting with `/` are classified as Unix socket paths, strings
    /// starting with `turmoil:` as turmoil addresses, and everything else as
    /// internet addresses. Only the prefix is inspected; the rest of the
    /// string is not validated.
    pub fn guess(s: &str) -> SocketAddrType {
        if s.starts_with('/') {
            return SocketAddrType::Unix;
        }
        if s.starts_with("turmoil:") {
            return SocketAddrType::Turmoil;
        }
        SocketAddrType::Inet
    }
}
60
61#[derive(Debug, Clone)]
63pub enum SocketAddr {
64 Inet(InetSocketAddr),
66 Unix(UnixSocketAddr),
68 Turmoil(String),
70}
71
72impl PartialEq for SocketAddr {
73 fn eq(&self, other: &Self) -> bool {
74 match (self, other) {
75 (SocketAddr::Inet(addr1), SocketAddr::Inet(addr2)) => addr1 == addr2,
76 (
77 SocketAddr::Unix(UnixSocketAddr { path: Some(path1) }),
78 SocketAddr::Unix(UnixSocketAddr { path: Some(path2) }),
79 ) => path1 == path2,
80 (SocketAddr::Turmoil(addr1), SocketAddr::Turmoil(addr2)) => addr1 == addr2,
81 _ => false,
82 }
83 }
84}
85
86impl fmt::Display for SocketAddr {
87 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
88 match &self {
89 SocketAddr::Inet(addr) => addr.fmt(f),
90 SocketAddr::Unix(addr) => addr.fmt(f),
91 SocketAddr::Turmoil(addr) => write!(f, "turmoil:{addr}"),
92 }
93 }
94}
95
96impl FromStr for SocketAddr {
97 type Err = AddrParseError;
98
99 fn from_str(s: &str) -> Result<SocketAddr, AddrParseError> {
104 match SocketAddrType::guess(s) {
105 SocketAddrType::Unix => {
106 let addr = UnixSocketAddr::from_pathname(s).map_err(|e| AddrParseError {
107 kind: AddrParseErrorKind::Unix(e),
108 })?;
109 Ok(SocketAddr::Unix(addr))
110 }
111 SocketAddrType::Inet => {
112 let addr = s.parse().map_err(|_| AddrParseError {
113 kind: AddrParseErrorKind::Inet,
116 })?;
117 Ok(SocketAddr::Inet(addr))
118 }
119 SocketAddrType::Turmoil => {
120 let addr = s.strip_prefix("turmoil:").ok_or(AddrParseError {
121 kind: AddrParseErrorKind::Turmoil,
122 })?;
123 Ok(SocketAddr::Turmoil(addr.into()))
124 }
125 }
126 }
127}
128
/// A Unix domain socket address: either a filesystem path or unnamed.
#[derive(Clone, Debug)]
pub struct UnixSocketAddr {
    /// The socket's path, or `None` for an unnamed socket.
    path: Option<String>,
}

impl UnixSocketAddr {
    /// Builds a named Unix socket address from `path`.
    ///
    /// # Errors
    ///
    /// Returns the error from [`std::os::unix::net::SocketAddr::from_pathname`]
    /// when `path` is not a valid socket path (for example, it contains an
    /// interior NUL byte or is too long).
    pub fn from_pathname<S>(path: S) -> Result<UnixSocketAddr, io::Error>
    where
        S: Into<String>,
    {
        let path: String = path.into();
        // Validation only; the resulting std address is discarded.
        std::os::unix::net::SocketAddr::from_pathname(&path)?;
        Ok(UnixSocketAddr { path: Some(path) })
    }

    /// Returns an address that represents an unnamed Unix socket.
    pub fn unnamed() -> UnixSocketAddr {
        UnixSocketAddr { path: None }
    }

    /// Returns the socket's path, or `None` if the address is unnamed.
    pub fn as_pathname(&self) -> Option<&str> {
        self.path.as_ref().map(String::as_str)
    }
}

/// Displays the path, or `<unnamed>` for an unnamed address.
impl fmt::Display for UnixSocketAddr {
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        f.write_str(self.path.as_deref().unwrap_or("<unnamed>"))
    }
}
177
/// The error returned when a string cannot be parsed as a [`SocketAddr`].
#[derive(Debug)]
pub struct AddrParseError {
    kind: AddrParseErrorKind,
}

/// The reason a socket address failed to parse.
#[derive(Debug)]
pub enum AddrParseErrorKind {
    /// The string was not a literal internet socket address.
    Inet,
    /// The string looked like a Unix path but was rejected; holds the cause.
    Unix(io::Error),
    /// The string lacked the `turmoil:` prefix.
    Turmoil,
}

impl fmt::Display for AddrParseError {
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        let cause = match &self.kind {
            AddrParseErrorKind::Inet => {
                return f.write_str("invalid internet socket address syntax");
            }
            AddrParseErrorKind::Turmoil => return f.write_str("missing 'turmoil:' prefix"),
            AddrParseErrorKind::Unix(cause) => cause,
        };
        // The underlying cause is appended inline rather than exposed via
        // `Error::source`.
        f.write_str("invalid unix socket address syntax: ")?;
        fmt::Display::fmt(cause, f)
    }
}

impl Error for AddrParseError {}
205
206#[async_trait]
208pub trait ToSocketAddrs {
209 async fn to_socket_addrs(&self) -> Result<Vec<SocketAddr>, io::Error>;
211}
212
213#[async_trait]
214impl ToSocketAddrs for SocketAddr {
215 async fn to_socket_addrs(&self) -> Result<Vec<SocketAddr>, io::Error> {
216 Ok(vec![self.clone()])
217 }
218}
219
220#[async_trait]
221impl<'a> ToSocketAddrs for &'a [SocketAddr] {
222 async fn to_socket_addrs(&self) -> Result<Vec<SocketAddr>, io::Error> {
223 Ok(self.to_vec())
224 }
225}
226
227#[async_trait]
228impl ToSocketAddrs for InetSocketAddr {
229 async fn to_socket_addrs(&self) -> Result<Vec<SocketAddr>, io::Error> {
230 Ok(vec![SocketAddr::Inet(*self)])
231 }
232}
233
234#[async_trait]
235impl ToSocketAddrs for UnixSocketAddr {
236 async fn to_socket_addrs(&self) -> Result<Vec<SocketAddr>, io::Error> {
237 Ok(vec![SocketAddr::Unix(self.clone())])
238 }
239}
240
241#[async_trait]
242impl ToSocketAddrs for str {
243 async fn to_socket_addrs(&self) -> Result<Vec<SocketAddr>, io::Error> {
244 match self.parse() {
245 Ok(addr) => Ok(vec![addr]),
246 Err(_) => {
247 let addrs = net::lookup_host(self).await?;
248 Ok(addrs.map(SocketAddr::Inet).collect())
249 }
250 }
251 }
252}
253
254#[async_trait]
255impl ToSocketAddrs for String {
256 async fn to_socket_addrs(&self) -> Result<Vec<SocketAddr>, io::Error> {
257 (**self).to_socket_addrs().await
258 }
259}
260
261#[async_trait]
262impl<T> ToSocketAddrs for &T
263where
264 T: ToSocketAddrs + Send + Sync + ?Sized,
265{
266 async fn to_socket_addrs(&self) -> Result<Vec<SocketAddr>, io::Error> {
267 (**self).to_socket_addrs().await
268 }
269}
270
271pub enum Listener {
273 Tcp(TcpListener),
275 Unix(UnixListener),
277 #[cfg(feature = "turmoil")]
279 Turmoil(turmoil::net::TcpListener),
280}
281
282impl Listener {
283 pub async fn bind<A>(addr: A) -> Result<Listener, io::Error>
288 where
289 A: ToSocketAddrs,
290 {
291 let mut last_err = None;
292 for addr in addr.to_socket_addrs().await? {
293 match Listener::bind_addr(addr).await {
294 Ok(listener) => return Ok(listener),
295 Err(e) => last_err = Some(e),
296 }
297 }
298 Err(last_err.unwrap_or_else(|| {
299 io::Error::new(
300 io::ErrorKind::InvalidInput,
301 "could not resolve to any address",
302 )
303 }))
304 }
305
306 async fn bind_addr(addr: SocketAddr) -> Result<Listener, io::Error> {
307 match &addr {
308 SocketAddr::Inet(addr) => {
309 let listener = TcpListener::bind(addr).await?;
310 Ok(Listener::Tcp(listener))
311 }
312 SocketAddr::Unix(UnixSocketAddr { path: Some(path) }) => {
313 if let Err(e) = fs::remove_file(path).await {
318 if e.kind() != io::ErrorKind::NotFound {
319 warn!(
320 "unable to remove {path} while binding unix domain socket: {}",
321 e.display_with_causes(),
322 );
323 }
324 }
325 let listener = UnixListener::bind(path)?;
326 Ok(Listener::Unix(listener))
327 }
328 SocketAddr::Unix(UnixSocketAddr { path: None }) => Err(io::Error::new(
329 io::ErrorKind::Other,
330 "cannot bind to unnamed Unix socket",
331 )),
332 #[cfg(feature = "turmoil")]
333 SocketAddr::Turmoil(addr) => {
334 let listener = turmoil::net::TcpListener::bind(addr).await?;
335 Ok(Listener::Turmoil(listener))
336 }
337 #[cfg(not(feature = "turmoil"))]
338 SocketAddr::Turmoil(_) => panic!("`turmoil` feature not enabled"),
339 }
340 }
341
342 pub async fn accept(&self) -> Result<(Stream, SocketAddr), io::Error> {
344 match self {
345 Listener::Tcp(listener) => {
346 let (stream, addr) = listener.accept().await?;
347 stream.set_nodelay(true)?;
348 let stream = Stream::Tcp(stream);
349 let addr = SocketAddr::Inet(addr);
350 Ok((stream, addr))
351 }
352 Listener::Unix(listener) => {
353 let (stream, addr) = listener.accept().await?;
354 let stream = Stream::Unix(stream);
355 assert!(addr.is_unnamed());
356 let addr = SocketAddr::Unix(UnixSocketAddr::unnamed());
357 Ok((stream, addr))
358 }
359 #[cfg(feature = "turmoil")]
360 Listener::Turmoil(listener) => {
361 let (stream, addr) = listener.accept().await?;
362 let stream = Stream::Turmoil(stream);
363 let addr = SocketAddr::Inet(addr);
364 Ok((stream, addr))
365 }
366 }
367 }
368
369 fn poll_accept(
370 self: Pin<&mut Self>,
371 cx: &mut Context<'_>,
372 ) -> Poll<Option<Result<Stream, io::Error>>> {
373 match self.get_mut() {
374 Listener::Tcp(listener) => {
375 let (stream, _addr) = ready!(listener.poll_accept(cx))?;
376 stream.set_nodelay(true)?;
377 Poll::Ready(Some(Ok(Stream::Tcp(stream))))
378 }
379 Listener::Unix(listener) => {
380 let (stream, _addr) = ready!(listener.poll_accept(cx))?;
381 Poll::Ready(Some(Ok(Stream::Unix(stream))))
382 }
383 #[cfg(feature = "turmoil")]
384 Listener::Turmoil(_) => {
385 unimplemented!("`turmoil::net::TcpListener::poll_accept`");
386 }
387 }
388 }
389}
390
391impl fmt::Debug for Listener {
392 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
393 match self {
394 Self::Tcp(inner) => f.debug_tuple("Tcp").field(inner).finish(),
395 Self::Unix(inner) => f.debug_tuple("Unix").field(inner).finish(),
396 #[cfg(feature = "turmoil")]
397 Self::Turmoil(_) => f.debug_tuple("Turmoil").finish(),
398 }
399 }
400}
401
402impl futures::stream::Stream for Listener {
403 type Item = Result<Stream, io::Error>;
404
405 fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
406 self.poll_accept(cx)
407 }
408}
409
410#[derive(Debug)]
412pub enum Stream {
413 Tcp(TcpStream),
415 Unix(UnixStream),
417 #[cfg(feature = "turmoil")]
419 Turmoil(turmoil::net::TcpStream),
420}
421
422impl Stream {
423 pub async fn connect<A>(addr: A) -> Result<Stream, io::Error>
425 where
426 A: ToSocketAddrs,
427 {
428 let mut last_err = None;
429 for addr in addr.to_socket_addrs().await? {
430 match Stream::connect_addr(addr).await {
431 Ok(stream) => return Ok(stream),
432 Err(e) => last_err = Some(e),
433 }
434 }
435 Err(last_err.unwrap_or_else(|| {
436 io::Error::new(
437 io::ErrorKind::InvalidInput,
438 "could not resolve to any address",
439 )
440 }))
441 }
442
443 async fn connect_addr(addr: SocketAddr) -> Result<Stream, io::Error> {
444 match addr {
445 SocketAddr::Inet(addr) => {
446 let stream = TcpStream::connect(addr).await?;
447 Ok(Stream::Tcp(stream))
448 }
449 SocketAddr::Unix(UnixSocketAddr { path: Some(path) }) => {
450 let stream = UnixStream::connect(path).await?;
451 Ok(Stream::Unix(stream))
452 }
453 SocketAddr::Unix(UnixSocketAddr { path: None }) => Err(io::Error::new(
454 io::ErrorKind::Other,
455 "cannot connected to unnamed Unix socket",
456 )),
457 #[cfg(feature = "turmoil")]
458 SocketAddr::Turmoil(addr) => {
459 let stream = turmoil::net::TcpStream::connect(addr).await?;
460 Ok(Stream::Turmoil(stream))
461 }
462 #[cfg(not(feature = "turmoil"))]
463 SocketAddr::Turmoil(_) => panic!("`turmoil` feature not enabled"),
464 }
465 }
466
467 pub fn is_tcp(&self) -> bool {
469 matches!(self, Stream::Tcp(_))
470 }
471
472 pub fn is_unix(&self) -> bool {
474 matches!(self, Stream::Unix(_))
475 }
476
477 pub fn unwrap_tcp(self) -> TcpStream {
483 match self {
484 Stream::Tcp(stream) => stream,
485 Stream::Unix(_) => panic!("Stream::unwrap_tcp called on a Unix stream"),
486 #[cfg(feature = "turmoil")]
487 Stream::Turmoil(_) => panic!("Stream::unwrap_tcp called on a `turmoil` stream"),
488 }
489 }
490
491 pub fn unwrap_unix(self) -> UnixStream {
497 match self {
498 Stream::Tcp(_) => panic!("Stream::unwrap_unix called on a TCP stream"),
499 Stream::Unix(stream) => stream,
500 #[cfg(feature = "turmoil")]
501 Stream::Turmoil(_) => panic!("Stream::unwrap_unix called on a `turmoil` stream"),
502 }
503 }
504}
505
506impl AsyncRead for Stream {
507 fn poll_read(
508 self: Pin<&mut Self>,
509 cx: &mut Context,
510 buf: &mut ReadBuf,
511 ) -> Poll<io::Result<()>> {
512 match self.get_mut() {
513 Stream::Tcp(stream) => Pin::new(stream).poll_read(cx, buf),
514 Stream::Unix(stream) => Pin::new(stream).poll_read(cx, buf),
515 #[cfg(feature = "turmoil")]
516 Stream::Turmoil(stream) => Pin::new(stream).poll_read(cx, buf),
517 }
518 }
519}
520
521impl AsyncWrite for Stream {
522 fn poll_write(self: Pin<&mut Self>, cx: &mut Context, buf: &[u8]) -> Poll<io::Result<usize>> {
523 match self.get_mut() {
524 Stream::Tcp(stream) => Pin::new(stream).poll_write(cx, buf),
525 Stream::Unix(stream) => Pin::new(stream).poll_write(cx, buf),
526 #[cfg(feature = "turmoil")]
527 Stream::Turmoil(stream) => Pin::new(stream).poll_write(cx, buf),
528 }
529 }
530
531 fn poll_flush(self: Pin<&mut Self>, cx: &mut Context) -> Poll<io::Result<()>> {
532 match self.get_mut() {
533 Stream::Tcp(stream) => Pin::new(stream).poll_flush(cx),
534 Stream::Unix(stream) => Pin::new(stream).poll_flush(cx),
535 #[cfg(feature = "turmoil")]
536 Stream::Turmoil(stream) => Pin::new(stream).poll_flush(cx),
537 }
538 }
539
540 fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context) -> Poll<io::Result<()>> {
541 match self.get_mut() {
542 Stream::Tcp(stream) => Pin::new(stream).poll_shutdown(cx),
543 Stream::Unix(stream) => Pin::new(stream).poll_shutdown(cx),
544 #[cfg(feature = "turmoil")]
545 Stream::Turmoil(stream) => Pin::new(stream).poll_shutdown(cx),
546 }
547 }
548}
549
550impl Connected for Stream {
551 type ConnectInfo = ConnectInfo;
552
553 fn connect_info(&self) -> Self::ConnectInfo {
554 match self {
555 Stream::Tcp(stream) => ConnectInfo::Tcp(stream.connect_info()),
556 Stream::Unix(stream) => ConnectInfo::Unix(stream.connect_info()),
557 #[cfg(feature = "turmoil")]
558 Stream::Turmoil(stream) => ConnectInfo::Turmoil(TcpConnectInfo {
559 local_addr: stream.local_addr().ok(),
560 remote_addr: stream.peer_addr().ok(),
561 }),
562 }
563 }
564}
565
566#[derive(Debug, Clone)]
568pub enum ConnectInfo {
569 Tcp(TcpConnectInfo),
571 Unix(UdsConnectInfo),
573 Turmoil(TcpConnectInfo),
575}
576
#[cfg(test)]
mod tests {
    use std::net::{Ipv4Addr, SocketAddrV4};

    use super::*;

    /// Checks `SocketAddr` parsing for each address type, covering both
    /// successful parses and the exact error message of each failure.
    #[crate::test]
    fn test_parse() {
        let unix = |path: &str| -> Result<SocketAddr, &'static str> {
            Ok(SocketAddr::Unix(UnixSocketAddr::from_pathname(path).unwrap()))
        };
        let inet_v4 = SocketAddr::Inet(InetSocketAddr::V4(SocketAddrV4::new(
            Ipv4Addr::new(1, 2, 3, 4),
            5678,
        )));
        let cases: [(&str, Result<SocketAddr, &str>); 7] = [
            ("/valid/path", unix("/valid/path")),
            ("/", unix("/")),
            (
                "/\0",
                Err("invalid unix socket address syntax: paths must not contain interior null bytes"),
            ),
            ("1.2.3.4:5678", Ok(inet_v4)),
            ("1.2.3.4", Err("invalid internet socket address syntax")),
            ("bad", Err("invalid internet socket address syntax")),
            (
                "turmoil:1.2.3.4:5678",
                Ok(SocketAddr::Turmoil("1.2.3.4:5678".into())),
            ),
        ];
        for (input, expected) in cases {
            let actual = input.parse::<SocketAddr>().map_err(|e| e.to_string());
            let expected = expected.map_err(str::to_string);
            assert_eq!(actual, expected, "input: {input}");
        }
    }
}