Skip to main content

turmoil/
dns.rs

1use indexmap::IndexMap;
2#[cfg(feature = "regex")]
3use regex::Regex;
4use std::io;
5use std::net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr, SocketAddrV4, SocketAddrV6};
6
7use crate::ip::IpVersionAddrIter;
8
9/// Each new host has an IP in the subnet defined by the
10/// ip version of the simulation.
11///
12/// Ipv4 simulations use the subnet 192.168.0.0/16.
13/// Ipv6 simulations use the link local subnet fe80:::/64
14pub struct Dns {
15    addrs: IpVersionAddrIter,
16    names: IndexMap<String, IpAddr>,
17}
18
19/// Converts or resolves to an [`IpAddr`].
20pub trait ToIpAddr: sealed::Sealed {
21    #[doc(hidden)]
22    fn to_ip_addr(&self, dns: &mut Dns) -> IpAddr;
23}
24
25/// Converts or resolves to one or more [`IpAddr`] values.
26pub trait ToIpAddrs: sealed::Sealed {
27    #[doc(hidden)]
28    fn to_ip_addrs(&self, dns: &mut Dns) -> Vec<IpAddr>;
29}
30
31/// A simulated version of `tokio::net::ToSocketAddrs`.
32pub trait ToSocketAddrs: sealed::Sealed {
33    #[doc(hidden)]
34    fn to_socket_addr(&self, dns: &Dns) -> io::Result<SocketAddr>;
35}
36
37impl Dns {
38    pub(crate) fn new(addrs: IpVersionAddrIter) -> Dns {
39        Dns {
40            addrs,
41            names: IndexMap::new(),
42        }
43    }
44
45    pub(crate) fn lookup(&mut self, addr: impl ToIpAddr) -> IpAddr {
46        addr.to_ip_addr(self)
47    }
48
49    pub(crate) fn lookup_many(&mut self, addrs: impl ToIpAddrs) -> Vec<IpAddr> {
50        addrs.to_ip_addrs(self)
51    }
52
53    pub(crate) fn reverse(&self, addr: IpAddr) -> Option<&str> {
54        self.names
55            .iter()
56            .find(|(_, a)| **a == addr)
57            .map(|(name, _)| name.as_str())
58    }
59}
60
61impl ToIpAddr for String {
62    fn to_ip_addr(&self, dns: &mut Dns) -> IpAddr {
63        (&self[..]).to_ip_addr(dns)
64    }
65}
66
67impl ToIpAddr for &str {
68    fn to_ip_addr(&self, dns: &mut Dns) -> IpAddr {
69        if let Ok(ipaddr) = self.parse() {
70            return ipaddr;
71        }
72
73        *dns.names
74            .entry(self.to_string())
75            .or_insert_with(|| dns.addrs.next())
76    }
77}
78
79impl ToIpAddr for IpAddr {
80    fn to_ip_addr(&self, _: &mut Dns) -> IpAddr {
81        *self
82    }
83}
84
85impl ToIpAddr for Ipv4Addr {
86    fn to_ip_addr(&self, _: &mut Dns) -> IpAddr {
87        IpAddr::V4(*self)
88    }
89}
90
91impl ToIpAddr for Ipv6Addr {
92    fn to_ip_addr(&self, _: &mut Dns) -> IpAddr {
93        IpAddr::V6(*self)
94    }
95}
96
97impl<T> ToIpAddrs for T
98where
99    T: ToIpAddr,
100{
101    fn to_ip_addrs(&self, dns: &mut Dns) -> Vec<IpAddr> {
102        vec![self.to_ip_addr(dns)]
103    }
104}
105
#[cfg(feature = "regex")]
impl ToIpAddrs for Regex {
    /// Returns the addresses of every registered host whose name matches this
    /// pattern, in registration order.
    ///
    /// Only names already present in the table are considered; a pattern
    /// never registers new hosts.
    fn to_ip_addrs(&self, dns: &mut Dns) -> Vec<IpAddr> {
        // Each entry already stores the resolved address, so read it directly
        // instead of cloning every key and re-resolving each match by name.
        dns.names
            .iter()
            .filter(|(name, _)| self.is_match(name.as_str()))
            .map(|(_, ip)| *ip)
            .collect()
    }
}
117
118// Hostname and port
119impl ToSocketAddrs for (String, u16) {
120    fn to_socket_addr(&self, dns: &Dns) -> io::Result<SocketAddr> {
121        (&self.0[..], self.1).to_socket_addr(dns)
122    }
123}
124
125impl ToSocketAddrs for (&str, u16) {
126    fn to_socket_addr(&self, dns: &Dns) -> io::Result<SocketAddr> {
127        // When IP address is passed directly as a str.
128        if let Ok(ip) = self.0.parse::<IpAddr>() {
129            return Ok((ip, self.1).into());
130        }
131
132        match dns.names.get(self.0) {
133            Some(ip) => Ok((*ip, self.1).into()),
134            None => Err(io::Error::new(
135                io::ErrorKind::NotFound,
136                format!("no ip address found for a hostname: {}", self.0),
137            )),
138        }
139    }
140}
141
142impl ToSocketAddrs for SocketAddr {
143    fn to_socket_addr(&self, _: &Dns) -> io::Result<SocketAddr> {
144        Ok(*self)
145    }
146}
147
148impl ToSocketAddrs for SocketAddrV4 {
149    fn to_socket_addr(&self, _: &Dns) -> io::Result<SocketAddr> {
150        Ok(SocketAddr::V4(*self))
151    }
152}
153
154impl ToSocketAddrs for SocketAddrV6 {
155    fn to_socket_addr(&self, _: &Dns) -> io::Result<SocketAddr> {
156        Ok(SocketAddr::V6(*self))
157    }
158}
159
160impl ToSocketAddrs for (IpAddr, u16) {
161    fn to_socket_addr(&self, _: &Dns) -> io::Result<SocketAddr> {
162        Ok((*self).into())
163    }
164}
165
166impl ToSocketAddrs for (Ipv4Addr, u16) {
167    fn to_socket_addr(&self, _: &Dns) -> io::Result<SocketAddr> {
168        Ok((*self).into())
169    }
170}
171
172impl ToSocketAddrs for (Ipv6Addr, u16) {
173    fn to_socket_addr(&self, _: &Dns) -> io::Result<SocketAddr> {
174        Ok((*self).into())
175    }
176}
177
178impl<T: ToSocketAddrs + ?Sized> ToSocketAddrs for &T {
179    fn to_socket_addr(&self, dns: &Dns) -> io::Result<SocketAddr> {
180        (**self).to_socket_addr(dns)
181    }
182}
183
184impl ToSocketAddrs for str {
185    fn to_socket_addr(&self, dns: &Dns) -> io::Result<SocketAddr> {
186        let socketaddr: Result<SocketAddr, _> = self.parse();
187
188        if let Ok(s) = socketaddr {
189            return Ok(s);
190        }
191
192        // Borrowed from std
193        // https://github.com/rust-lang/rust/blob/1b225414f325593f974c6b41e671a0a0dc5d7d5e/library/std/src/sys_common/net.rs#L175
194        macro_rules! try_opt {
195            ($e:expr, $msg:expr) => {
196                match $e {
197                    Some(r) => r,
198                    None => return Err(io::Error::new(io::ErrorKind::InvalidInput, $msg)),
199                }
200            };
201        }
202
203        // split the string by ':' and convert the second part to u16
204        let (host, port_str) = try_opt!(self.rsplit_once(':'), "invalid socket address");
205        let port: u16 = try_opt!(port_str.parse().ok(), "invalid port value");
206
207        (host, port).to_socket_addr(dns)
208    }
209}
210
211impl ToSocketAddrs for String {
212    fn to_socket_addr(&self, dns: &Dns) -> io::Result<SocketAddr> {
213        self.as_str().to_socket_addr(dns)
214    }
215}
216
mod sealed {
    /// Marker supertrait required by the public conversion traits.
    ///
    /// NOTE(review): because of the blanket impl below, every type (sized or
    /// not) satisfies this bound, so it does not by itself restrict which
    /// types implement the public traits.
    pub trait Sealed {}

    impl<T> Sealed for T where T: ?Sized {}
}
223
#[cfg(test)]
mod tests {
    use crate::{dns::Dns, ip::IpVersionAddrIter, ToSocketAddrs};
    use std::io;
    use std::net::Ipv4Addr;

    #[test]
    fn parse_str() {
        let mut dns = Dns::new(IpVersionAddrIter::default());
        let generated_addr = dns.lookup("foo");

        let hostname_port = "foo:5000";
        let ipv4_port = "127.0.0.1:5000";
        let ipv6_port = "[::1]:5000";

        assert_eq!(
            hostname_port.to_socket_addr(&dns).unwrap(),
            format!("{generated_addr}:5000").parse().unwrap()
        );
        assert_eq!(
            ipv4_port.to_socket_addr(&dns).unwrap(),
            ipv4_port.parse().unwrap()
        );
        assert_eq!(
            ipv6_port.to_socket_addr(&dns).unwrap(),
            ipv6_port.parse().unwrap()
        );
    }

    #[test]
    fn raw_value_parsing() {
        // lookups of raw ip addrs should be consistent
        // between to_ip_addr() and to_socket_addr()
        // for &str and IpAddr
        let mut dns = Dns::new(IpVersionAddrIter::default());
        let addr = dns.lookup(Ipv4Addr::new(192, 168, 2, 2));
        assert_eq!(addr, Ipv4Addr::new(192, 168, 2, 2));

        let addr = dns.lookup("192.168.3.3");
        assert_eq!(addr, Ipv4Addr::new(192, 168, 3, 3));

        let addr = "192.168.3.3:0".to_socket_addr(&dns).unwrap();
        assert_eq!(addr.ip(), Ipv4Addr::new(192, 168, 3, 3));
    }

    #[test]
    fn hostname_lookup_is_stable_and_reversible() {
        // Repeated lookups of the same name (as &str or String) must return
        // the address assigned on first use, and reverse lookup finds it.
        let mut dns = Dns::new(IpVersionAddrIter::default());
        let first = dns.lookup("foo");

        assert_eq!(dns.lookup("foo"), first);
        assert_eq!(dns.lookup(String::from("foo")), first);
        assert_eq!(dns.reverse(first), Some("foo"));
    }

    #[test]
    fn parse_str_errors() {
        let mut dns = Dns::new(IpVersionAddrIter::default());
        dns.lookup("foo");

        let kind = |s: &str| s.to_socket_addr(&dns).unwrap_err().kind();

        // Missing port separator.
        assert_eq!(kind("foo"), io::ErrorKind::InvalidInput);
        // Port that is not a number, or does not fit in a u16.
        assert_eq!(kind("foo:port"), io::ErrorKind::InvalidInput);
        assert_eq!(kind("foo:70000"), io::ErrorKind::InvalidInput);
        // Well-formed, but the host was never registered.
        assert_eq!(kind("unknown:80"), io::ErrorKind::NotFound);
    }
}