use std::error::Error;
use std::net::SocketAddr as InetSocketAddr;
use std::pin::Pin;
use std::str::FromStr;
use std::task::{ready, Context, Poll};
use std::{fmt, io};
use async_trait::async_trait;
use tokio::fs;
use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
use tokio::net::{self, TcpListener, TcpStream, UnixListener, UnixStream};
use tonic::transport::server::{Connected, TcpConnectInfo, UdsConnectInfo};
use tracing::warn;
use crate::error::ErrorExt;
/// The kind of socket address a string appears to describe.
#[derive(Debug, Clone, Copy)]
pub enum SocketAddrType {
    /// An internet socket address (parsed as `std::net::SocketAddr`).
    Inet,
    /// A Unix domain socket path.
    Unix,
}

impl SocketAddrType {
    /// Guesses which kind of socket address `s` describes, without validating it.
    ///
    /// A string beginning with `/` is classified as a Unix socket path; any
    /// other string (including the empty string) is classified as an internet
    /// address.
    pub fn guess(s: &str) -> SocketAddrType {
        if s.starts_with('/') {
            SocketAddrType::Unix
        } else {
            SocketAddrType::Inet
        }
    }
}
#[derive(Debug, Clone)]
pub enum SocketAddr {
Inet(InetSocketAddr),
Unix(UnixSocketAddr),
}
impl PartialEq for SocketAddr {
fn eq(&self, other: &Self) -> bool {
match (self, other) {
(SocketAddr::Inet(addr1), SocketAddr::Inet(addr2)) => addr1 == addr2,
(
SocketAddr::Unix(UnixSocketAddr { path: Some(path1) }),
SocketAddr::Unix(UnixSocketAddr { path: Some(path2) }),
) => path1 == path2,
_ => false,
}
}
}
/// Formats the address using the `Display` impl of the wrapped address, so
/// any width/alignment flags on the formatter are passed through.
impl fmt::Display for SocketAddr {
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        match self {
            SocketAddr::Inet(inet) => fmt::Display::fmt(inet, f),
            SocketAddr::Unix(unix) => fmt::Display::fmt(unix, f),
        }
    }
}
impl FromStr for SocketAddr {
type Err = AddrParseError;
fn from_str(s: &str) -> Result<SocketAddr, AddrParseError> {
match SocketAddrType::guess(s) {
SocketAddrType::Unix => {
let addr = UnixSocketAddr::from_pathname(s).map_err(|e| AddrParseError {
kind: AddrParseErrorKind::Unix(e),
})?;
Ok(SocketAddr::Unix(addr))
}
SocketAddrType::Inet => {
let addr = s.parse().map_err(|_| AddrParseError {
kind: AddrParseErrorKind::Inet,
})?;
Ok(SocketAddr::Inet(addr))
}
}
}
}
/// The address of a Unix domain socket: either a filesystem path or unnamed.
#[derive(Debug, Clone)]
pub struct UnixSocketAddr {
    /// The socket's path, or `None` for an unnamed socket.
    path: Option<String>,
}

impl UnixSocketAddr {
    /// Creates an address for the Unix socket at `path`.
    ///
    /// # Errors
    ///
    /// Returns the error from `std::os::unix::net::SocketAddr::from_pathname`
    /// if `path` is not usable as a socket path (for example, it contains an
    /// interior NUL byte or is too long).
    pub fn from_pathname<S>(path: S) -> Result<UnixSocketAddr, io::Error>
    where
        S: Into<String>,
    {
        let path: String = path.into();
        // Only the validation performed by std is wanted; the std address
        // itself is discarded.
        match std::os::unix::net::SocketAddr::from_pathname(&path) {
            Ok(_) => Ok(UnixSocketAddr { path: Some(path) }),
            Err(e) => Err(e),
        }
    }

    /// Creates an unnamed Unix socket address.
    pub fn unnamed() -> UnixSocketAddr {
        Self { path: None }
    }

    /// Returns the socket's path, or `None` if the address is unnamed.
    pub fn as_pathname(&self) -> Option<&str> {
        self.path.as_ref().map(String::as_str)
    }
}

/// Writes the path, or `<unnamed>` for an unnamed address.
impl fmt::Display for UnixSocketAddr {
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        f.write_str(self.as_pathname().unwrap_or("<unnamed>"))
    }
}
/// An error produced when parsing a [`SocketAddr`] from a string.
#[derive(Debug)]
pub struct AddrParseError {
    kind: AddrParseErrorKind,
}

/// The reason a socket address failed to parse.
#[derive(Debug)]
pub enum AddrParseErrorKind {
    /// The string was not a valid internet socket address.
    Inet,
    /// The string was not a valid Unix socket path; carries the validation error.
    Unix(io::Error),
}

impl fmt::Display for AddrParseError {
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        // Pick the message prefix and, for Unix paths, the underlying cause
        // whose `Display` output is appended verbatim.
        let (prefix, cause) = match &self.kind {
            AddrParseErrorKind::Inet => ("invalid internet socket address syntax", None),
            AddrParseErrorKind::Unix(e) => ("invalid unix socket address syntax: ", Some(e)),
        };
        f.write_str(prefix)?;
        match cause {
            Some(e) => fmt::Display::fmt(e, f),
            None => Ok(()),
        }
    }
}

impl Error for AddrParseError {}
#[async_trait]
pub trait ToSocketAddrs {
async fn to_socket_addrs(&self) -> Result<Vec<SocketAddr>, io::Error>;
}
#[async_trait]
impl ToSocketAddrs for SocketAddr {
async fn to_socket_addrs(&self) -> Result<Vec<SocketAddr>, io::Error> {
Ok(vec![self.clone()])
}
}
#[async_trait]
impl<'a> ToSocketAddrs for &'a [SocketAddr] {
async fn to_socket_addrs(&self) -> Result<Vec<SocketAddr>, io::Error> {
Ok(self.to_vec())
}
}
#[async_trait]
impl ToSocketAddrs for InetSocketAddr {
async fn to_socket_addrs(&self) -> Result<Vec<SocketAddr>, io::Error> {
Ok(vec![SocketAddr::Inet(*self)])
}
}
#[async_trait]
impl ToSocketAddrs for UnixSocketAddr {
async fn to_socket_addrs(&self) -> Result<Vec<SocketAddr>, io::Error> {
Ok(vec![SocketAddr::Unix(self.clone())])
}
}
#[async_trait]
impl ToSocketAddrs for str {
async fn to_socket_addrs(&self) -> Result<Vec<SocketAddr>, io::Error> {
match self.parse() {
Ok(addr) => Ok(vec![addr]),
Err(_) => {
let addrs = net::lookup_host(self).await?;
Ok(addrs.map(SocketAddr::Inet).collect())
}
}
}
}
/// Resolves exactly like the equivalent `str`.
#[async_trait]
impl ToSocketAddrs for String {
    async fn to_socket_addrs(&self) -> Result<Vec<SocketAddr>, io::Error> {
        <str as ToSocketAddrs>::to_socket_addrs(self.as_str()).await
    }
}

/// Resolves exactly like the referenced value.
#[async_trait]
impl<T> ToSocketAddrs for &T
where
    T: ToSocketAddrs + Send + Sync + ?Sized,
{
    async fn to_socket_addrs(&self) -> Result<Vec<SocketAddr>, io::Error> {
        <T as ToSocketAddrs>::to_socket_addrs(*self).await
    }
}
/// A listener that accepts connections on either a TCP socket or a Unix
/// domain socket.
#[derive(Debug)]
pub enum Listener {
    /// A TCP listener.
    Tcp(TcpListener),
    /// A Unix domain socket listener.
    Unix(UnixListener),
}

impl Listener {
    /// Binds to the first address produced by `addr` that can be bound.
    ///
    /// The resolved addresses are tried in order. For a Unix socket address,
    /// any existing file at the path is removed before binding.
    ///
    /// # Errors
    ///
    /// Returns any error from resolving `addr`. Otherwise returns the error
    /// from the last failed bind attempt, or an `InvalidInput` error if `addr`
    /// resolved to no addresses at all.
    pub async fn bind<A>(addr: A) -> Result<Listener, io::Error>
    where
        A: ToSocketAddrs,
    {
        let mut last_err = None;
        for addr in addr.to_socket_addrs().await? {
            match Listener::bind_addr(addr).await {
                Ok(listener) => return Ok(listener),
                Err(e) => last_err = Some(e),
            }
        }
        Err(last_err.unwrap_or_else(|| {
            io::Error::new(
                io::ErrorKind::InvalidInput,
                "could not resolve to any address",
            )
        }))
    }

    /// Binds a listener to a single, already-resolved address.
    async fn bind_addr(addr: SocketAddr) -> Result<Listener, io::Error> {
        match &addr {
            SocketAddr::Inet(addr) => {
                let listener = TcpListener::bind(addr).await?;
                Ok(Listener::Tcp(listener))
            }
            SocketAddr::Unix(UnixSocketAddr { path: Some(path) }) => {
                // Remove whatever is at `path` so the bind does not fail because
                // the file already exists. A missing file is expected. Other
                // removal failures are only logged; the bind below reports
                // the real error, if any.
                if let Err(e) = fs::remove_file(path).await {
                    if e.kind() != io::ErrorKind::NotFound {
                        warn!(
                            "unable to remove {path} while binding unix domain socket: {}",
                            e.display_with_causes(),
                        );
                    }
                }
                let listener = UnixListener::bind(path)?;
                Ok(Listener::Unix(listener))
            }
            SocketAddr::Unix(UnixSocketAddr { path: None }) => Err(io::Error::new(
                io::ErrorKind::Other,
                "cannot bind to unnamed Unix socket",
            )),
        }
    }

    /// Accepts a new incoming connection and returns it with the peer's address.
    ///
    /// Accepted TCP streams have `TCP_NODELAY` enabled. For Unix domain
    /// sockets, the returned address is the peer's path when the peer is bound
    /// to a UTF-8 pathname. Otherwise it is [`UnixSocketAddr::unnamed`].
    ///
    /// # Errors
    ///
    /// Returns any error from accepting the connection or from enabling
    /// `TCP_NODELAY`.
    pub async fn accept(&self) -> Result<(Stream, SocketAddr), io::Error> {
        match self {
            Listener::Tcp(listener) => {
                let (stream, addr) = listener.accept().await?;
                stream.set_nodelay(true)?;
                Ok((Stream::Tcp(stream), SocketAddr::Inet(addr)))
            }
            Listener::Unix(listener) => {
                let (stream, addr) = listener.accept().await?;
                // Clients usually connect from unnamed sockets, but one can
                // bind its socket before connecting (or be autobound to an
                // abstract address on Linux). The old `assert!(is_unnamed())`
                // let any such client panic the server. Abstract and non-UTF-8
                // peer addresses cannot be represented by `UnixSocketAddr`, so
                // they are reported as unnamed.
                let path = addr
                    .as_pathname()
                    .and_then(|path| path.to_str())
                    .map(str::to_owned);
                let addr = SocketAddr::Unix(UnixSocketAddr { path });
                Ok((Stream::Unix(stream), addr))
            }
        }
    }

    /// Polls for the next incoming connection, discarding the peer address.
    ///
    /// Always yields `Some`; this backs the `futures::stream::Stream`
    /// implementation. TCP streams get `TCP_NODELAY`, as in [`Listener::accept`].
    fn poll_accept(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
    ) -> Poll<Option<Result<Stream, io::Error>>> {
        match self.get_mut() {
            Listener::Tcp(listener) => {
                let (stream, _addr) = ready!(listener.poll_accept(cx))?;
                stream.set_nodelay(true)?;
                Poll::Ready(Some(Ok(Stream::Tcp(stream))))
            }
            Listener::Unix(listener) => {
                let (stream, _addr) = ready!(listener.poll_accept(cx))?;
                Poll::Ready(Some(Ok(Stream::Unix(stream))))
            }
        }
    }
}
impl futures::stream::Stream for Listener {
type Item = Result<Stream, io::Error>;
fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
self.poll_accept(cx)
}
}
/// A connected stream over either TCP or a Unix domain socket.
#[derive(Debug)]
pub enum Stream {
    /// A TCP stream.
    Tcp(TcpStream),
    /// A Unix domain socket stream.
    Unix(UnixStream),
}

impl Stream {
    /// Connects to the first address produced by `addr` that accepts a
    /// connection.
    ///
    /// The resolved addresses are tried in order.
    ///
    /// NOTE(review): unlike the TCP streams returned by `Listener::accept`,
    /// outgoing TCP streams are not switched to `TCP_NODELAY` here. Confirm
    /// this is intended.
    ///
    /// # Errors
    ///
    /// Returns any error from resolving `addr`. Otherwise returns the error
    /// from the last failed connection attempt, or an `InvalidInput` error if
    /// `addr` resolved to no addresses at all.
    pub async fn connect<A>(addr: A) -> Result<Stream, io::Error>
    where
        A: ToSocketAddrs,
    {
        let mut last_err = None;
        for addr in addr.to_socket_addrs().await? {
            match Stream::connect_addr(addr).await {
                Ok(stream) => return Ok(stream),
                Err(e) => last_err = Some(e),
            }
        }
        Err(last_err.unwrap_or_else(|| {
            io::Error::new(
                io::ErrorKind::InvalidInput,
                "could not resolve to any address",
            )
        }))
    }

    /// Connects to a single, already-resolved address.
    async fn connect_addr(addr: SocketAddr) -> Result<Stream, io::Error> {
        match addr {
            SocketAddr::Inet(addr) => {
                let stream = TcpStream::connect(addr).await?;
                Ok(Stream::Tcp(stream))
            }
            SocketAddr::Unix(UnixSocketAddr { path: Some(path) }) => {
                let stream = UnixStream::connect(path).await?;
                Ok(Stream::Unix(stream))
            }
            SocketAddr::Unix(UnixSocketAddr { path: None }) => Err(io::Error::new(
                io::ErrorKind::Other,
                "cannot connect to unnamed Unix socket",
            )),
        }
    }

    /// Reports whether this is a TCP stream.
    pub fn is_tcp(&self) -> bool {
        matches!(self, Stream::Tcp(_))
    }

    /// Reports whether this is a Unix domain socket stream.
    pub fn is_unix(&self) -> bool {
        matches!(self, Stream::Unix(_))
    }

    /// Returns the underlying TCP stream.
    ///
    /// # Panics
    ///
    /// Panics if this is a Unix domain socket stream.
    pub fn unwrap_tcp(self) -> TcpStream {
        match self {
            Stream::Tcp(stream) => stream,
            Stream::Unix(_) => panic!("Stream::unwrap_tcp called on a Unix stream"),
        }
    }

    /// Returns the underlying Unix domain socket stream.
    ///
    /// # Panics
    ///
    /// Panics if this is a TCP stream.
    pub fn unwrap_unix(self) -> UnixStream {
        match self {
            Stream::Tcp(_) => panic!("Stream::unwrap_unix called on a TCP stream"),
            Stream::Unix(stream) => stream,
        }
    }
}
impl AsyncRead for Stream {
fn poll_read(
self: Pin<&mut Self>,
cx: &mut Context,
buf: &mut ReadBuf,
) -> Poll<io::Result<()>> {
match self.get_mut() {
Stream::Tcp(stream) => Pin::new(stream).poll_read(cx, buf),
Stream::Unix(stream) => Pin::new(stream).poll_read(cx, buf),
}
}
}
/// Writes to whichever underlying stream this `Stream` wraps.
///
/// Every method, including the vectored-write ones, is forwarded. Without
/// `poll_write_vectored` / `is_write_vectored`, the trait defaults would apply:
/// only the first non-empty buffer would be written per call, and
/// `is_write_vectored()` would report `false`. That hides the inner streams'
/// support for vectored writes from callers.
impl AsyncWrite for Stream {
    fn poll_write(self: Pin<&mut Self>, cx: &mut Context, buf: &[u8]) -> Poll<io::Result<usize>> {
        match self.get_mut() {
            Stream::Tcp(stream) => Pin::new(stream).poll_write(cx, buf),
            Stream::Unix(stream) => Pin::new(stream).poll_write(cx, buf),
        }
    }

    fn poll_write_vectored(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        bufs: &[io::IoSlice<'_>],
    ) -> Poll<io::Result<usize>> {
        match self.get_mut() {
            Stream::Tcp(stream) => Pin::new(stream).poll_write_vectored(cx, bufs),
            Stream::Unix(stream) => Pin::new(stream).poll_write_vectored(cx, bufs),
        }
    }

    fn is_write_vectored(&self) -> bool {
        match self {
            Stream::Tcp(stream) => stream.is_write_vectored(),
            Stream::Unix(stream) => stream.is_write_vectored(),
        }
    }

    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context) -> Poll<io::Result<()>> {
        match self.get_mut() {
            Stream::Tcp(stream) => Pin::new(stream).poll_flush(cx),
            Stream::Unix(stream) => Pin::new(stream).poll_flush(cx),
        }
    }

    fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context) -> Poll<io::Result<()>> {
        match self.get_mut() {
            Stream::Tcp(stream) => Pin::new(stream).poll_shutdown(cx),
            Stream::Unix(stream) => Pin::new(stream).poll_shutdown(cx),
        }
    }
}
impl Connected for Stream {
type ConnectInfo = ConnectInfo;
fn connect_info(&self) -> Self::ConnectInfo {
match self {
Stream::Tcp(stream) => ConnectInfo::Tcp(stream.connect_info()),
Stream::Unix(stream) => ConnectInfo::Unix(stream.connect_info()),
}
}
}
#[derive(Debug, Clone)]
pub enum ConnectInfo {
Tcp(TcpConnectInfo),
Unix(UdsConnectInfo),
}
#[cfg(test)]
mod tests {
    use std::net::{Ipv4Addr, SocketAddrV4};

    use super::*;

    /// Parses `input` and asserts that the result matches `expected`.
    /// Errors are compared by their `Display` text.
    fn assert_parses_to(input: &str, expected: Result<SocketAddr, &str>) {
        let actual = input.parse::<SocketAddr>().map_err(|e| e.to_string());
        let expected = expected.map_err(String::from);
        assert_eq!(actual, expected, "input: {}", input);
    }

    /// Builds the expected successful result for the Unix socket path `path`.
    fn unix(path: &str) -> Result<SocketAddr, &'static str> {
        Ok(SocketAddr::Unix(UnixSocketAddr::from_pathname(path).unwrap()))
    }

    #[crate::test]
    fn test_parse() {
        assert_parses_to("/valid/path", unix("/valid/path"));
        assert_parses_to("/", unix("/"));
        assert_parses_to(
            "/\0",
            Err("invalid unix socket address syntax: paths must not contain interior null bytes"),
        );
        let inet = InetSocketAddr::V4(SocketAddrV4::new(Ipv4Addr::new(1, 2, 3, 4), 5678));
        assert_parses_to("1.2.3.4:5678", Ok(SocketAddr::Inet(inet)));
        assert_parses_to("1.2.3.4", Err("invalid internet socket address syntax"));
        assert_parses_to("bad", Err("invalid internet socket address syntax"));
    }
}