turmoil/
dns.rs

1use indexmap::IndexMap;
2#[cfg(feature = "regex")]
3use regex::Regex;
4use std::net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr, SocketAddrV4, SocketAddrV6};
5
6use crate::ip::IpVersionAddrIter;
7
8/// Each new host has an IP in the subnet defined by the
9/// ip version of the simulation.
10///
11/// Ipv4 simulations use the subnet 192.168.0.0/16.
12/// Ipv6 simulations use the link local subnet fe80:::/64
13pub struct Dns {
14    addrs: IpVersionAddrIter,
15    names: IndexMap<String, IpAddr>,
16}
17
18/// Converts or resolves to an [`IpAddr`].
19pub trait ToIpAddr: sealed::Sealed {
20    #[doc(hidden)]
21    fn to_ip_addr(&self, dns: &mut Dns) -> IpAddr;
22}
23
24/// Converts or resolves to one or more [`IpAddr`] values.
25pub trait ToIpAddrs: sealed::Sealed {
26    #[doc(hidden)]
27    fn to_ip_addrs(&self, dns: &mut Dns) -> Vec<IpAddr>;
28}
29
30/// A simulated version of `tokio::net::ToSocketAddrs`.
31pub trait ToSocketAddrs: sealed::Sealed {
32    #[doc(hidden)]
33    fn to_socket_addr(&self, dns: &Dns) -> SocketAddr;
34}
35
36impl Dns {
37    pub(crate) fn new(addrs: IpVersionAddrIter) -> Dns {
38        Dns {
39            addrs,
40            names: IndexMap::new(),
41        }
42    }
43
44    pub(crate) fn lookup(&mut self, addr: impl ToIpAddr) -> IpAddr {
45        addr.to_ip_addr(self)
46    }
47
48    pub(crate) fn lookup_many(&mut self, addrs: impl ToIpAddrs) -> Vec<IpAddr> {
49        addrs.to_ip_addrs(self)
50    }
51
52    pub(crate) fn reverse(&self, addr: IpAddr) -> Option<&str> {
53        self.names
54            .iter()
55            .find(|(_, a)| **a == addr)
56            .map(|(name, _)| name.as_str())
57    }
58}
59
60impl ToIpAddr for String {
61    fn to_ip_addr(&self, dns: &mut Dns) -> IpAddr {
62        (&self[..]).to_ip_addr(dns)
63    }
64}
65
66impl ToIpAddr for &str {
67    fn to_ip_addr(&self, dns: &mut Dns) -> IpAddr {
68        if let Ok(ipaddr) = self.parse() {
69            return ipaddr;
70        }
71
72        *dns.names
73            .entry(self.to_string())
74            .or_insert_with(|| dns.addrs.next())
75    }
76}
77
78impl ToIpAddr for IpAddr {
79    fn to_ip_addr(&self, _: &mut Dns) -> IpAddr {
80        *self
81    }
82}
83
84impl ToIpAddr for Ipv4Addr {
85    fn to_ip_addr(&self, _: &mut Dns) -> IpAddr {
86        IpAddr::V4(*self)
87    }
88}
89
90impl ToIpAddr for Ipv6Addr {
91    fn to_ip_addr(&self, _: &mut Dns) -> IpAddr {
92        IpAddr::V6(*self)
93    }
94}
95
96impl<T> ToIpAddrs for T
97where
98    T: ToIpAddr,
99{
100    fn to_ip_addrs(&self, dns: &mut Dns) -> Vec<IpAddr> {
101        vec![self.to_ip_addr(dns)]
102    }
103}
104
#[cfg(feature = "regex")]
impl ToIpAddrs for Regex {
    /// Returns the addresses of all registered hosts whose name matches the
    /// pattern, in the iteration order of the host table.
    ///
    /// Only names that are already registered are considered; nothing is
    /// added to the table, and the result is empty when no name matches.
    fn to_ip_addrs(&self, dns: &mut Dns) -> Vec<IpAddr> {
        // Every registered name has its address stored right next to it, so
        // the table can be read directly. Names are only ever stored after
        // failing to parse as an IP literal (see the `&str` implementation
        // of `ToIpAddr`), hence resolving a stored name always yields its
        // stored address and there is no need to clone the names and
        // resolve them again through the mutable lookup path.
        dns.names
            .iter()
            .filter(|(name, _)| self.is_match(name.as_str()))
            .map(|(_, ip)| *ip)
            .collect()
    }
}
116
117// Hostname and port
118impl ToSocketAddrs for (String, u16) {
119    fn to_socket_addr(&self, dns: &Dns) -> SocketAddr {
120        (&self.0[..], self.1).to_socket_addr(dns)
121    }
122}
123
124impl ToSocketAddrs for (&str, u16) {
125    fn to_socket_addr(&self, dns: &Dns) -> SocketAddr {
126        // When IP address is passed directly as a str.
127        if let Ok(ip) = self.0.parse::<IpAddr>() {
128            return (ip, self.1).into();
129        }
130
131        match dns.names.get(self.0) {
132            Some(ip) => (*ip, self.1).into(),
133            None => panic!("no ip address found for a hostname: {}", self.0),
134        }
135    }
136}
137
138impl ToSocketAddrs for SocketAddr {
139    fn to_socket_addr(&self, _: &Dns) -> SocketAddr {
140        *self
141    }
142}
143
144impl ToSocketAddrs for SocketAddrV4 {
145    fn to_socket_addr(&self, _: &Dns) -> SocketAddr {
146        SocketAddr::V4(*self)
147    }
148}
149
150impl ToSocketAddrs for SocketAddrV6 {
151    fn to_socket_addr(&self, _: &Dns) -> SocketAddr {
152        SocketAddr::V6(*self)
153    }
154}
155
156impl ToSocketAddrs for (IpAddr, u16) {
157    fn to_socket_addr(&self, _: &Dns) -> SocketAddr {
158        (*self).into()
159    }
160}
161
162impl ToSocketAddrs for (Ipv4Addr, u16) {
163    fn to_socket_addr(&self, _: &Dns) -> SocketAddr {
164        (*self).into()
165    }
166}
167
168impl ToSocketAddrs for (Ipv6Addr, u16) {
169    fn to_socket_addr(&self, _: &Dns) -> SocketAddr {
170        (*self).into()
171    }
172}
173
174impl<T: ToSocketAddrs + ?Sized> ToSocketAddrs for &T {
175    fn to_socket_addr(&self, dns: &Dns) -> SocketAddr {
176        (**self).to_socket_addr(dns)
177    }
178}
179
180impl ToSocketAddrs for str {
181    fn to_socket_addr(&self, dns: &Dns) -> SocketAddr {
182        let socketaddr: Result<SocketAddr, _> = self.parse();
183
184        if let Ok(s) = socketaddr {
185            return s;
186        }
187
188        // Borrowed from std
189        // https://github.com/rust-lang/rust/blob/1b225414f325593f974c6b41e671a0a0dc5d7d5e/library/std/src/sys_common/net.rs#L175
190        macro_rules! try_opt {
191            ($e:expr, $msg:expr) => {
192                match $e {
193                    Some(r) => r,
194                    None => panic!("Unable to parse dns: {}", $msg),
195                }
196            };
197        }
198
199        // split the string by ':' and convert the second part to u16
200        let (host, port_str) = try_opt!(self.rsplit_once(':'), "invalid socket address");
201        let port: u16 = try_opt!(port_str.parse().ok(), "invalid port value");
202
203        (host, port).to_socket_addr(dns)
204    }
205}
206
207impl ToSocketAddrs for String {
208    fn to_socket_addr(&self, dns: &Dns) -> SocketAddr {
209        self.as_str().to_socket_addr(dns)
210    }
211}
212
mod sealed {
    /// Supertrait of `ToIpAddr`, `ToIpAddrs` and `ToSocketAddrs`, declared
    /// in a private module.
    pub trait Sealed {}

    // Blanket implementation: every type, sized or not, is `Sealed`.
    impl<T> Sealed for T where T: ?Sized {}
}
219
#[cfg(test)]
mod tests {
    use crate::{dns::Dns, ip::IpVersionAddrIter, ToSocketAddrs};
    use std::net::{Ipv4Addr, SocketAddr};

    /// `str` inputs resolve whether they name a registered host or spell out
    /// an IPv4 / IPv6 socket address.
    #[test]
    fn parse_str() {
        let mut dns = Dns::new(IpVersionAddrIter::default());
        let generated_addr = dns.lookup("foo");

        let expected: SocketAddr = format!("{generated_addr}:5000").parse().unwrap();
        assert_eq!("foo:5000".to_socket_addr(&dns), expected);

        for literal in ["127.0.0.1:5000", "[::1]:5000"] {
            let expected: SocketAddr = literal.parse().unwrap();
            assert_eq!(literal.to_socket_addr(&dns), expected);
        }
    }

    /// Raw IP addresses resolve consistently through `to_ip_addr()` and
    /// `to_socket_addr()`, both as `&str` and as typed values.
    #[test]
    fn raw_value_parsing() {
        let mut dns = Dns::new(IpVersionAddrIter::default());

        let typed = Ipv4Addr::new(192, 168, 2, 2);
        assert_eq!(dns.lookup(typed), typed);

        let textual = Ipv4Addr::new(192, 168, 3, 3);
        assert_eq!(dns.lookup("192.168.3.3"), textual);
        assert_eq!("192.168.3.3:0".to_socket_addr(&dns).ip(), textual);
    }
}