turmoil/
host.rs

1use crate::envelope::{hex, Datagram, Protocol, Segment, Syn};
2use crate::net::{SocketPair, TcpListener, UdpSocket};
3use crate::{Envelope, TRACING_TARGET};
4
5use bytes::Bytes;
6use indexmap::IndexMap;
7use std::collections::VecDeque;
8use std::fmt::Display;
9use std::io;
10use std::net::{IpAddr, SocketAddr};
11use std::ops::RangeInclusive;
12use std::sync::Arc;
13use tokio::sync::{mpsc, Notify};
14use tokio::time::{Duration, Instant};
15
/// Initial `broadcast` flag for a new UDP bind; also what is reported for a
/// port that has no bind.
const DEFAULT_BROADCAST: bool = false;
/// Initial `multicast_loop` flag for a new UDP bind; also what is reported
/// for a port that has no bind.
const DEFAULT_MULTICAST_LOOP: bool = true;
18
19/// A host in the simulated network.
20///
21/// Hosts have [`Udp`] and [`Tcp`] software available for networking.
22///
23/// Both modes may be used simultaneously.
24pub(crate) struct Host {
25    /// Host name.
26    pub(crate) nodename: String,
27
28    /// Host ip address.
29    pub(crate) addr: IpAddr,
30
31    /// Tracks elapsed time for host and overall simulation.
32    pub(crate) timer: HostTimer,
33
34    /// L4 User Datagram Protocol (UDP).
35    pub(crate) udp: Udp,
36
37    /// L4 Transmission Control Protocol (TCP).
38    pub(crate) tcp: Tcp,
39
40    next_ephemeral_port: u16,
41    ephemeral_ports: RangeInclusive<u16>,
42}
43
44impl Host {
45    pub(crate) fn new(
46        nodename: impl Into<String>,
47        addr: IpAddr,
48        timer: HostTimer,
49        ephemeral_ports: RangeInclusive<u16>,
50        tcp_capacity: usize,
51        udp_capacity: usize,
52    ) -> Host {
53        Host {
54            nodename: nodename.into(),
55            addr,
56            udp: Udp::new(udp_capacity),
57            tcp: Tcp::new(tcp_capacity),
58            timer,
59            next_ephemeral_port: *ephemeral_ports.start(),
60            ephemeral_ports,
61        }
62    }
63
64    pub(crate) fn assign_ephemeral_port(&mut self) -> u16 {
65        for _ in self.ephemeral_ports.clone() {
66            let ret = self.next_ephemeral_port;
67
68            if self.next_ephemeral_port == *self.ephemeral_ports.end() {
69                // re-load
70                self.next_ephemeral_port = *self.ephemeral_ports.start();
71            } else {
72                // advance
73                self.next_ephemeral_port += 1;
74            }
75
76            // Check for existing binds and connections to avoid port conflicts
77            if self.udp.is_port_assigned(ret) || self.tcp.is_port_assigned(ret) {
78                continue;
79            }
80
81            return ret;
82        }
83
84        panic!("Host: '{}' ports exhausted", self.nodename)
85    }
86
87    /// Receive the `envelope` from the network.
88    ///
89    /// Returns an Err if a message needs to be sent in response to a failed
90    /// delivery, e.g. TCP RST.
91    // FIXME: This funkiness is necessary due to how message sending works. The
92    // key problem is that the Host doesn't actually send messages, rather the
93    // World is borrowed, and it sends.
94    pub(crate) fn receive_from_network(&mut self, envelope: Envelope) -> Result<(), Protocol> {
95        let Envelope { src, dst, message } = envelope;
96
97        tracing::trace!(target: TRACING_TARGET, ?src, ?dst, protocol = %message, "Delivered");
98
99        match message {
100            Protocol::Tcp(segment) => self.tcp.receive_from_network(src, dst, segment),
101            Protocol::Udp(datagram) => {
102                self.udp.receive_from_network(src, dst, datagram);
103                Ok(())
104            }
105        }
106    }
107}
108
/// Virtual clock for one host.
///
/// Time from completed runs accumulates in `elapsed` via `tick`. Time inside
/// the current run is measured from the instant last supplied to `now`.
pub(crate) struct HostTimer {
    /// Virtual time accumulated through `tick`.
    elapsed: Duration,

    /// Instant recorded each time the host's software is run; `None` until
    /// `now` is first called.
    now: Option<Instant>,

    /// Simulation time that had already passed when this host was created.
    /// It is added to the host's own elapsed time to get total simulation
    /// time.
    start_offset: Duration,
}

impl HostTimer {
    /// Creates a clock with no accumulated time, for a host that joined the
    /// simulation `start_offset` after it began.
    pub(crate) fn new(start_offset: Duration) -> Self {
        HostTimer {
            start_offset,
            elapsed: Duration::ZERO,
            now: None,
        }
    }

    /// Adds `duration` to the host's accumulated run time.
    pub(crate) fn tick(&mut self, duration: Duration) {
        self.elapsed = self.elapsed + duration;
    }

    /// Records the instant at which the current simulation iteration began.
    ///
    /// `elapsed` only advances between iterations (via `tick`). This instant
    /// lets time be measured accurately while the software is running.
    ///
    /// Logical time must survive host restarts. A single `Instant` cannot do
    /// that because it resets when the tokio runtime is recreated.
    pub(crate) fn now(&mut self, now: Instant) {
        self.now = Some(now);
    }

    /// Virtual time this host has spent executing.
    ///
    /// # Panics
    /// If `now` has never been called.
    pub(crate) fn elapsed(&self) -> Duration {
        let run_started = self.now.expect("host instant not set");
        self.elapsed + run_started.elapsed()
    }

    /// Total simulation virtual time.
    ///
    /// This exceeds [`HostTimer::elapsed`] when the simulation ran for a
    /// while before this host was registered.
    pub(crate) fn sim_elapsed(&self) -> Duration {
        self.elapsed() + self.start_offset
    }
}
157
158/// Simulated UDP host software.
159pub(crate) struct Udp {
160    /// Bound udp sockets
161    binds: IndexMap<u16, UdpBind>,
162
163    /// UdpSocket channel capacity
164    capacity: usize,
165}
166
167struct UdpBind {
168    bind_addr: SocketAddr,
169    broadcast: bool,
170    multicast_loop: bool,
171    queue: mpsc::Sender<(Datagram, SocketAddr)>,
172}
173
174impl Udp {
175    fn new(capacity: usize) -> Self {
176        Self {
177            binds: IndexMap::new(),
178            capacity,
179        }
180    }
181
182    pub(crate) fn is_port_assigned(&self, port: u16) -> bool {
183        self.binds.keys().any(|p| *p == port)
184    }
185
186    pub(crate) fn is_broadcast_enabled(&self, port: u16) -> bool {
187        self.binds
188            .get(&port)
189            .map(|bind| bind.broadcast)
190            .unwrap_or(DEFAULT_BROADCAST)
191    }
192
193    pub(crate) fn is_multicast_loop_enabled(&self, port: u16) -> bool {
194        self.binds
195            .get(&port)
196            .map(|bind| bind.multicast_loop)
197            .unwrap_or(DEFAULT_MULTICAST_LOOP)
198    }
199
200    pub(crate) fn set_broadcast(&mut self, port: u16, on: bool) {
201        self.binds
202            .entry(port)
203            .and_modify(|bind| bind.broadcast = on);
204    }
205
206    pub(crate) fn set_multicast_loop(&mut self, port: u16, on: bool) {
207        self.binds
208            .entry(port)
209            .and_modify(|bind| bind.multicast_loop = on);
210    }
211
212    pub(crate) fn bind(&mut self, addr: SocketAddr) -> io::Result<UdpSocket> {
213        let (tx, rx) = mpsc::channel(self.capacity);
214        let bind = UdpBind {
215            bind_addr: addr,
216            broadcast: DEFAULT_BROADCAST,
217            multicast_loop: DEFAULT_MULTICAST_LOOP,
218            queue: tx,
219        };
220
221        match self.binds.entry(addr.port()) {
222            indexmap::map::Entry::Occupied(_) => {
223                return Err(io::Error::new(io::ErrorKind::AddrInUse, addr.to_string()));
224            }
225            indexmap::map::Entry::Vacant(entry) => entry.insert(bind),
226        };
227
228        tracing::info!(target: TRACING_TARGET, ?addr, protocol = %"UDP", "Bind");
229
230        Ok(UdpSocket::new(addr, rx))
231    }
232
233    fn receive_from_network(&mut self, src: SocketAddr, dst: SocketAddr, datagram: Datagram) {
234        if let Some(bind) = self.binds.get_mut(&dst.port()) {
235            if !matches(bind.bind_addr, dst) {
236                tracing::trace!(target: TRACING_TARGET, ?src, ?dst, protocol = %Protocol::Udp(datagram), "Dropped (Addr not bound)");
237                return;
238            }
239            if let Err(err) = bind.queue.try_send((datagram, src)) {
240                // drop any packets that exceed the capacity
241                match err {
242                    mpsc::error::TrySendError::Full((datagram, _)) => {
243                        tracing::trace!(target: TRACING_TARGET, ?src, ?dst, protocol = %Protocol::Udp(datagram), "Dropped (Full buffer)");
244                    }
245                    mpsc::error::TrySendError::Closed((datagram, _)) => {
246                        tracing::trace!(target: TRACING_TARGET, ?src, ?dst, protocol = %Protocol::Udp(datagram), "Dropped (Receiver closed)");
247                    }
248                }
249            }
250        }
251    }
252
253    pub(crate) fn unbind(&mut self, addr: SocketAddr) {
254        let exists = self.binds.swap_remove(&addr.port());
255
256        assert!(exists.is_some(), "unknown bind {addr}");
257
258        tracing::info!(target: TRACING_TARGET, ?addr, protocol = %"UDP", "Unbind");
259    }
260}
261
262pub(crate) struct Tcp {
263    /// Bound server sockets
264    binds: IndexMap<u16, ServerSocket>,
265
266    /// TcpListener channel capacity
267    server_socket_capacity: usize,
268
269    /// Active stream sockets
270    sockets: IndexMap<SocketPair, StreamSocket>,
271
272    /// TcpStream channel capacity
273    socket_capacity: usize,
274}
275
276struct ServerSocket {
277    bind_addr: SocketAddr,
278
279    /// Notify the TcpListener when SYNs are delivered
280    notify: Arc<Notify>,
281
282    /// Pending connections for the TcpListener to accept
283    deque: VecDeque<(Syn, SocketAddr)>,
284}
285
286struct StreamSocket {
287    local_addr: SocketAddr,
288    buf: IndexMap<u64, SequencedSegment>,
289    next_send_seq: u64,
290    recv_seq: u64,
291    sender: mpsc::Sender<SequencedSegment>,
292    /// A simple reference counter for tracking read/write half drops. Once 0, the
293    /// socket may be removed from the host.
294    ref_ct: usize,
295}
296
297/// Stripped down version of [`Segment`] for delivery out to the application
298/// layer.
299#[derive(Debug)]
300pub(crate) enum SequencedSegment {
301    Data(Bytes),
302    Fin,
303}
304
305impl Display for SequencedSegment {
306    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
307        match self {
308            SequencedSegment::Data(data) => hex("TCP", data, f),
309            SequencedSegment::Fin => write!(f, "TCP FIN"),
310        }
311    }
312}
313
314impl StreamSocket {
315    fn new(local_addr: SocketAddr, capacity: usize) -> (Self, mpsc::Receiver<SequencedSegment>) {
316        let (tx, rx) = mpsc::channel(capacity);
317        let sock = Self {
318            local_addr,
319            buf: IndexMap::new(),
320            next_send_seq: 1,
321            recv_seq: 0,
322            sender: tx,
323            ref_ct: 2,
324        };
325
326        (sock, rx)
327    }
328
329    fn assign_seq(&mut self) -> u64 {
330        let seq = self.next_send_seq;
331        self.next_send_seq += 1;
332        seq
333    }
334
335    // Buffer and re-order received segments by `seq` as the network may deliver
336    // them out of order.
337    fn buffer(&mut self, seq: u64, segment: SequencedSegment) -> Result<(), Protocol> {
338        use mpsc::error::TrySendError::*;
339
340        let exists = self.buf.insert(seq, segment);
341
342        assert!(exists.is_none(), "duplicate segment {seq}");
343
344        while self.buf.contains_key(&(self.recv_seq + 1)) {
345            self.recv_seq += 1;
346
347            let segment = self.buf.swap_remove(&self.recv_seq).unwrap();
348            self.sender.try_send(segment).map_err(|e| match e {
349                Closed(_) => Protocol::Tcp(Segment::Rst),
350                Full(_) => panic!("{} socket buffer full", self.local_addr),
351            })?;
352        }
353
354        Ok(())
355    }
356}
357
358impl Tcp {
359    fn new(capacity: usize) -> Self {
360        Self {
361            binds: IndexMap::new(),
362            sockets: IndexMap::new(),
363            server_socket_capacity: capacity,
364            socket_capacity: capacity,
365        }
366    }
367
368    fn is_port_assigned(&self, port: u16) -> bool {
369        self.binds.keys().any(|p| *p == port) || self.sockets.keys().any(|a| a.local.port() == port)
370    }
371
372    pub(crate) fn bind(&mut self, addr: SocketAddr) -> io::Result<TcpListener> {
373        let notify = Arc::new(Notify::new());
374        let sock = ServerSocket {
375            bind_addr: addr,
376            notify: notify.clone(),
377            deque: VecDeque::new(),
378        };
379
380        match self.binds.entry(addr.port()) {
381            indexmap::map::Entry::Occupied(_) => {
382                return Err(io::Error::new(io::ErrorKind::AddrInUse, addr.to_string()));
383            }
384            indexmap::map::Entry::Vacant(entry) => entry.insert(sock),
385        };
386
387        tracing::info!(target: TRACING_TARGET, ?addr, protocol = %"TCP", "Bind");
388
389        Ok(TcpListener::new(addr, notify))
390    }
391
392    pub(crate) fn new_stream(&mut self, pair: SocketPair) -> mpsc::Receiver<SequencedSegment> {
393        let (sock, rx) = StreamSocket::new(pair.local, self.socket_capacity);
394
395        let exists = self.sockets.insert(pair, sock);
396
397        assert!(exists.is_none(), "{pair:?} is already connected");
398
399        rx
400    }
401
402    pub(crate) fn stream_count(&self) -> usize {
403        self.sockets.len()
404    }
405
406    pub(crate) fn accept(&mut self, addr: SocketAddr) -> Option<(Syn, SocketAddr)> {
407        self.binds[&addr.port()].deque.pop_front()
408    }
409
410    // Ideally, we could "write through" the tcp software, but this is necessary
411    // due to borrowing the world to access the mut host and for sending.
412    pub(crate) fn assign_send_seq(&mut self, pair: SocketPair) -> Option<u64> {
413        let sock = self.sockets.get_mut(&pair)?;
414        Some(sock.assign_seq())
415    }
416
417    fn receive_from_network(
418        &mut self,
419        src: SocketAddr,
420        dst: SocketAddr,
421        segment: Segment,
422    ) -> Result<(), Protocol> {
423        match segment {
424            Segment::Syn(syn) => {
425                // If bound, queue the syn; else we drop the syn triggering
426                // connection refused on the client.
427                if let Some(b) = self.binds.get_mut(&dst.port()) {
428                    if b.deque.len() == self.server_socket_capacity {
429                        panic!("{} server socket buffer full", dst);
430                    }
431
432                    if matches(b.bind_addr, dst) {
433                        b.deque.push_back((syn, src));
434                        b.notify.notify_one();
435                    }
436                }
437            }
438            Segment::Data(seq, data) => match self.sockets.get_mut(&SocketPair::new(dst, src)) {
439                Some(sock) => sock.buffer(seq, SequencedSegment::Data(data))?,
440                None => return Err(Protocol::Tcp(Segment::Rst)),
441            },
442            Segment::Fin(seq) => match self.sockets.get_mut(&SocketPair::new(dst, src)) {
443                Some(sock) => sock.buffer(seq, SequencedSegment::Fin)?,
444                None => return Err(Protocol::Tcp(Segment::Rst)),
445            },
446            Segment::Rst => {
447                if self.sockets.get(&SocketPair::new(dst, src)).is_some() {
448                    self.sockets
449                        .swap_remove(&SocketPair::new(dst, src))
450                        .unwrap();
451                }
452            }
453        };
454
455        Ok(())
456    }
457
458    pub(crate) fn close_stream_half(&mut self, pair: SocketPair) {
459        // Receiving a RST removes the socket, so it's possible that has occurred
460        // when halves of the stream drop.
461        if let Some(sock) = self.sockets.get_mut(&pair) {
462            sock.ref_ct -= 1;
463
464            if sock.ref_ct == 0 {
465                self.sockets.swap_remove(&pair).unwrap();
466            }
467        }
468    }
469
470    pub(crate) fn unbind(&mut self, addr: SocketAddr) {
471        let exists = self.binds.swap_remove(&addr.port());
472
473        assert!(exists.is_some(), "unknown bind {addr}");
474
475        tracing::info!(target: TRACING_TARGET, ?addr, protocol = %"TCP", "Unbind");
476    }
477}
478
/// Reports whether a socket bound to `bind` should accept a packet routed
/// to `dst`.
///
/// A bind with an unspecified (wildcard) IP accepts any destination on the
/// same port. Any other bind accepts only an exact address match.
pub fn matches(bind: SocketAddr, dst: SocketAddr) -> bool {
    let wildcard_on_same_port = bind.ip().is_unspecified() && bind.port() == dst.port();
    wildcard_on_same_port || bind == dst
}
487
/// Reports whether traffic from `src` to `dst` stays on the local host.
///
/// This holds when `dst` is a loopback address, or when both endpoints share
/// an IP (turmoil treats that case like loopback).
pub(crate) fn is_same(src: SocketAddr, dst: SocketAddr) -> bool {
    let target = dst.ip();
    if target.is_loopback() {
        return true;
    }
    src.ip() == target
}
493
#[cfg(test)]
mod test {
    use std::time::Duration;

    use crate::{host::HostTimer, Host, Result};

    /// Ephemeral assignment skips ports that already have a bind and wraps
    /// back to the start of the range.
    #[test]
    fn recycle_ports() -> Result {
        let first = 49152;
        let last = 49162;

        let mut host = Host::new(
            "host",
            std::net::Ipv4Addr::UNSPECIFIED.into(),
            HostTimer::new(Duration::ZERO),
            first..=last,
            1,
            1,
        );

        // Occupy the two highest ports so the allocator must skip them.
        for port in [last - 1, last] {
            host.udp.bind((host.addr, port).into())?;
        }

        // Hand out every free port below the occupied ones.
        for _ in first..last - 1 {
            host.assign_ephemeral_port();
        }

        // The next assignment skips the occupied ports and wraps around.
        assert_eq!(first, host.assign_ephemeral_port());

        Ok(())
    }
}
522}