1use crate::envelope::{hex, Datagram, Protocol, Segment, Syn};
2use crate::net::{SocketPair, TcpListener, UdpSocket};
3use crate::{Envelope, TRACING_TARGET};
4
5use bytes::Bytes;
6use indexmap::IndexMap;
7use std::collections::VecDeque;
8use std::fmt::Display;
9use std::io;
10use std::net::{IpAddr, SocketAddr};
11use std::ops::RangeInclusive;
12use std::sync::Arc;
13use tokio::sync::{mpsc, Notify};
14use tokio::time::{Duration, Instant};
15
/// Initial value of the per-bind UDP broadcast flag. It is also the answer
/// reported for ports that have no bind.
const DEFAULT_BROADCAST: bool = false;
/// Initial value of the per-bind UDP multicast-loop flag. It is also the
/// answer reported for ports that have no bind.
const DEFAULT_MULTICAST_LOOP: bool = true;
18
19pub(crate) struct Host {
25 pub(crate) nodename: String,
27
28 pub(crate) addr: IpAddr,
30
31 pub(crate) timer: HostTimer,
33
34 pub(crate) udp: Udp,
36
37 pub(crate) tcp: Tcp,
39
40 next_ephemeral_port: u16,
41 ephemeral_ports: RangeInclusive<u16>,
42}
43
44impl Host {
45 pub(crate) fn new(
46 nodename: impl Into<String>,
47 addr: IpAddr,
48 timer: HostTimer,
49 ephemeral_ports: RangeInclusive<u16>,
50 tcp_capacity: usize,
51 udp_capacity: usize,
52 ) -> Host {
53 Host {
54 nodename: nodename.into(),
55 addr,
56 udp: Udp::new(udp_capacity),
57 tcp: Tcp::new(tcp_capacity),
58 timer,
59 next_ephemeral_port: *ephemeral_ports.start(),
60 ephemeral_ports,
61 }
62 }
63
64 pub(crate) fn assign_ephemeral_port(&mut self) -> u16 {
65 for _ in self.ephemeral_ports.clone() {
66 let ret = self.next_ephemeral_port;
67
68 if self.next_ephemeral_port == *self.ephemeral_ports.end() {
69 self.next_ephemeral_port = *self.ephemeral_ports.start();
71 } else {
72 self.next_ephemeral_port += 1;
74 }
75
76 if self.udp.is_port_assigned(ret) || self.tcp.is_port_assigned(ret) {
78 continue;
79 }
80
81 return ret;
82 }
83
84 panic!("Host: '{}' ports exhausted", self.nodename)
85 }
86
87 pub(crate) fn receive_from_network(&mut self, envelope: Envelope) -> Result<(), Protocol> {
95 let Envelope { src, dst, message } = envelope;
96
97 tracing::trace!(target: TRACING_TARGET, ?src, ?dst, protocol = %message, "Delivered");
98
99 match message {
100 Protocol::Tcp(segment) => self.tcp.receive_from_network(src, dst, segment),
101 Protocol::Udp(datagram) => {
102 self.udp.receive_from_network(src, dst, datagram);
103 Ok(())
104 }
105 }
106 }
107}
108
109pub(crate) struct HostTimer {
110 elapsed: Duration,
112
113 now: Option<Instant>,
115
116 start_offset: Duration,
119}
120
121impl HostTimer {
122 pub(crate) fn new(start_offset: Duration) -> Self {
123 Self {
124 elapsed: Duration::ZERO,
125 now: None,
126 start_offset,
127 }
128 }
129
130 pub(crate) fn tick(&mut self, duration: Duration) {
131 self.elapsed += duration
132 }
133
134 pub(crate) fn now(&mut self, now: Instant) {
142 self.now.replace(now);
143 }
144
145 pub(crate) fn elapsed(&self) -> Duration {
147 let run_duration = self.now.expect("host instant not set").elapsed();
148 self.elapsed + run_duration
149 }
150
151 pub(crate) fn sim_elapsed(&self) -> Duration {
154 self.start_offset + self.elapsed()
155 }
156}
157
158pub(crate) struct Udp {
160 binds: IndexMap<u16, UdpBind>,
162
163 capacity: usize,
165}
166
167struct UdpBind {
168 bind_addr: SocketAddr,
169 broadcast: bool,
170 multicast_loop: bool,
171 queue: mpsc::Sender<(Datagram, SocketAddr)>,
172}
173
174impl Udp {
175 fn new(capacity: usize) -> Self {
176 Self {
177 binds: IndexMap::new(),
178 capacity,
179 }
180 }
181
182 pub(crate) fn is_port_assigned(&self, port: u16) -> bool {
183 self.binds.keys().any(|p| *p == port)
184 }
185
186 pub(crate) fn is_broadcast_enabled(&self, port: u16) -> bool {
187 self.binds
188 .get(&port)
189 .map(|bind| bind.broadcast)
190 .unwrap_or(DEFAULT_BROADCAST)
191 }
192
193 pub(crate) fn is_multicast_loop_enabled(&self, port: u16) -> bool {
194 self.binds
195 .get(&port)
196 .map(|bind| bind.multicast_loop)
197 .unwrap_or(DEFAULT_MULTICAST_LOOP)
198 }
199
200 pub(crate) fn set_broadcast(&mut self, port: u16, on: bool) {
201 self.binds
202 .entry(port)
203 .and_modify(|bind| bind.broadcast = on);
204 }
205
206 pub(crate) fn set_multicast_loop(&mut self, port: u16, on: bool) {
207 self.binds
208 .entry(port)
209 .and_modify(|bind| bind.multicast_loop = on);
210 }
211
212 pub(crate) fn bind(&mut self, addr: SocketAddr) -> io::Result<UdpSocket> {
213 let (tx, rx) = mpsc::channel(self.capacity);
214 let bind = UdpBind {
215 bind_addr: addr,
216 broadcast: DEFAULT_BROADCAST,
217 multicast_loop: DEFAULT_MULTICAST_LOOP,
218 queue: tx,
219 };
220
221 match self.binds.entry(addr.port()) {
222 indexmap::map::Entry::Occupied(_) => {
223 return Err(io::Error::new(io::ErrorKind::AddrInUse, addr.to_string()));
224 }
225 indexmap::map::Entry::Vacant(entry) => entry.insert(bind),
226 };
227
228 tracing::info!(target: TRACING_TARGET, ?addr, protocol = %"UDP", "Bind");
229
230 Ok(UdpSocket::new(addr, rx))
231 }
232
233 fn receive_from_network(&mut self, src: SocketAddr, dst: SocketAddr, datagram: Datagram) {
234 if let Some(bind) = self.binds.get_mut(&dst.port()) {
235 if !matches(bind.bind_addr, dst) {
236 tracing::trace!(target: TRACING_TARGET, ?src, ?dst, protocol = %Protocol::Udp(datagram), "Dropped (Addr not bound)");
237 return;
238 }
239 if let Err(err) = bind.queue.try_send((datagram, src)) {
240 match err {
242 mpsc::error::TrySendError::Full((datagram, _)) => {
243 tracing::trace!(target: TRACING_TARGET, ?src, ?dst, protocol = %Protocol::Udp(datagram), "Dropped (Full buffer)");
244 }
245 mpsc::error::TrySendError::Closed((datagram, _)) => {
246 tracing::trace!(target: TRACING_TARGET, ?src, ?dst, protocol = %Protocol::Udp(datagram), "Dropped (Receiver closed)");
247 }
248 }
249 }
250 }
251 }
252
253 pub(crate) fn unbind(&mut self, addr: SocketAddr) {
254 let exists = self.binds.swap_remove(&addr.port());
255
256 assert!(exists.is_some(), "unknown bind {addr}");
257
258 tracing::info!(target: TRACING_TARGET, ?addr, protocol = %"UDP", "Unbind");
259 }
260}
261
262pub(crate) struct Tcp {
263 binds: IndexMap<u16, ServerSocket>,
265
266 server_socket_capacity: usize,
268
269 sockets: IndexMap<SocketPair, StreamSocket>,
271
272 socket_capacity: usize,
274}
275
276struct ServerSocket {
277 bind_addr: SocketAddr,
278
279 notify: Arc<Notify>,
281
282 deque: VecDeque<(Syn, SocketAddr)>,
284}
285
286struct StreamSocket {
287 local_addr: SocketAddr,
288 buf: IndexMap<u64, SequencedSegment>,
289 next_send_seq: u64,
290 recv_seq: u64,
291 sender: mpsc::Sender<SequencedSegment>,
292 ref_ct: usize,
295}
296
297#[derive(Debug)]
300pub(crate) enum SequencedSegment {
301 Data(Bytes),
302 Fin,
303}
304
305impl Display for SequencedSegment {
306 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
307 match self {
308 SequencedSegment::Data(data) => hex("TCP", data, f),
309 SequencedSegment::Fin => write!(f, "TCP FIN"),
310 }
311 }
312}
313
314impl StreamSocket {
315 fn new(local_addr: SocketAddr, capacity: usize) -> (Self, mpsc::Receiver<SequencedSegment>) {
316 let (tx, rx) = mpsc::channel(capacity);
317 let sock = Self {
318 local_addr,
319 buf: IndexMap::new(),
320 next_send_seq: 1,
321 recv_seq: 0,
322 sender: tx,
323 ref_ct: 2,
324 };
325
326 (sock, rx)
327 }
328
329 fn assign_seq(&mut self) -> u64 {
330 let seq = self.next_send_seq;
331 self.next_send_seq += 1;
332 seq
333 }
334
335 fn buffer(&mut self, seq: u64, segment: SequencedSegment) -> Result<(), Protocol> {
338 use mpsc::error::TrySendError::*;
339
340 let exists = self.buf.insert(seq, segment);
341
342 assert!(exists.is_none(), "duplicate segment {seq}");
343
344 while self.buf.contains_key(&(self.recv_seq + 1)) {
345 self.recv_seq += 1;
346
347 let segment = self.buf.swap_remove(&self.recv_seq).unwrap();
348 self.sender.try_send(segment).map_err(|e| match e {
349 Closed(_) => Protocol::Tcp(Segment::Rst),
350 Full(_) => panic!("{} socket buffer full", self.local_addr),
351 })?;
352 }
353
354 Ok(())
355 }
356}
357
358impl Tcp {
359 fn new(capacity: usize) -> Self {
360 Self {
361 binds: IndexMap::new(),
362 sockets: IndexMap::new(),
363 server_socket_capacity: capacity,
364 socket_capacity: capacity,
365 }
366 }
367
368 fn is_port_assigned(&self, port: u16) -> bool {
369 self.binds.keys().any(|p| *p == port) || self.sockets.keys().any(|a| a.local.port() == port)
370 }
371
372 pub(crate) fn bind(&mut self, addr: SocketAddr) -> io::Result<TcpListener> {
373 let notify = Arc::new(Notify::new());
374 let sock = ServerSocket {
375 bind_addr: addr,
376 notify: notify.clone(),
377 deque: VecDeque::new(),
378 };
379
380 match self.binds.entry(addr.port()) {
381 indexmap::map::Entry::Occupied(_) => {
382 return Err(io::Error::new(io::ErrorKind::AddrInUse, addr.to_string()));
383 }
384 indexmap::map::Entry::Vacant(entry) => entry.insert(sock),
385 };
386
387 tracing::info!(target: TRACING_TARGET, ?addr, protocol = %"TCP", "Bind");
388
389 Ok(TcpListener::new(addr, notify))
390 }
391
392 pub(crate) fn new_stream(&mut self, pair: SocketPair) -> mpsc::Receiver<SequencedSegment> {
393 let (sock, rx) = StreamSocket::new(pair.local, self.socket_capacity);
394
395 let exists = self.sockets.insert(pair, sock);
396
397 assert!(exists.is_none(), "{pair:?} is already connected");
398
399 rx
400 }
401
402 pub(crate) fn stream_count(&self) -> usize {
403 self.sockets.len()
404 }
405
406 pub(crate) fn accept(&mut self, addr: SocketAddr) -> Option<(Syn, SocketAddr)> {
407 self.binds[&addr.port()].deque.pop_front()
408 }
409
410 pub(crate) fn assign_send_seq(&mut self, pair: SocketPair) -> Option<u64> {
413 let sock = self.sockets.get_mut(&pair)?;
414 Some(sock.assign_seq())
415 }
416
417 fn receive_from_network(
418 &mut self,
419 src: SocketAddr,
420 dst: SocketAddr,
421 segment: Segment,
422 ) -> Result<(), Protocol> {
423 match segment {
424 Segment::Syn(syn) => {
425 if let Some(b) = self.binds.get_mut(&dst.port()) {
428 if b.deque.len() == self.server_socket_capacity {
429 panic!("{} server socket buffer full", dst);
430 }
431
432 if matches(b.bind_addr, dst) {
433 b.deque.push_back((syn, src));
434 b.notify.notify_one();
435 }
436 }
437 }
438 Segment::Data(seq, data) => match self.sockets.get_mut(&SocketPair::new(dst, src)) {
439 Some(sock) => sock.buffer(seq, SequencedSegment::Data(data))?,
440 None => return Err(Protocol::Tcp(Segment::Rst)),
441 },
442 Segment::Fin(seq) => match self.sockets.get_mut(&SocketPair::new(dst, src)) {
443 Some(sock) => sock.buffer(seq, SequencedSegment::Fin)?,
444 None => return Err(Protocol::Tcp(Segment::Rst)),
445 },
446 Segment::Rst => {
447 if self.sockets.get(&SocketPair::new(dst, src)).is_some() {
448 self.sockets
449 .swap_remove(&SocketPair::new(dst, src))
450 .unwrap();
451 }
452 }
453 };
454
455 Ok(())
456 }
457
458 pub(crate) fn close_stream_half(&mut self, pair: SocketPair) {
459 if let Some(sock) = self.sockets.get_mut(&pair) {
462 sock.ref_ct -= 1;
463
464 if sock.ref_ct == 0 {
465 self.sockets.swap_remove(&pair).unwrap();
466 }
467 }
468 }
469
470 pub(crate) fn unbind(&mut self, addr: SocketAddr) {
471 let exists = self.binds.swap_remove(&addr.port());
472
473 assert!(exists.is_some(), "unknown bind {addr}");
474
475 tracing::info!(target: TRACING_TARGET, ?addr, protocol = %"TCP", "Unbind");
476 }
477}
478
/// Reports whether a socket bound to `bind` should receive traffic sent to
/// `dst`.
///
/// A wildcard (unspecified-IP) bind accepts any destination on the same
/// port; any other bind requires an exact address match.
pub fn matches(bind: SocketAddr, dst: SocketAddr) -> bool {
    let wildcard_on_same_port = bind.ip().is_unspecified() && bind.port() == dst.port();
    wildcard_on_same_port || bind == dst
}
487
/// Returns whether traffic from `src` to `dst` stays on the sending host:
/// either the destination is a loopback address, or both ends share an IP.
pub(crate) fn is_same(src: SocketAddr, dst: SocketAddr) -> bool {
    let dst_ip = dst.ip();
    src.ip() == dst_ip || dst_ip.is_loopback()
}
493
#[cfg(test)]
mod test {
    use std::time::Duration;

    use crate::{host::HostTimer, Host, Result};

    /// After the range is walked past already-bound ports, assignment wraps
    /// back to the start of the range.
    #[test]
    fn recycle_ports() -> Result {
        let mut host = Host::new(
            "host",
            std::net::Ipv4Addr::UNSPECIFIED.into(),
            HostTimer::new(Duration::ZERO),
            49152..=49162,
            1,
            1,
        );

        host.udp.bind((host.addr, 49161).into())?;
        host.udp.bind((host.addr, 49162).into())?;

        for _ in 49152..49161 {
            host.assign_ephemeral_port();
        }

        assert_eq!(49152, host.assign_ephemeral_port());

        Ok(())
    }

    /// With every port in the ephemeral range bound, assignment panics
    /// instead of looping forever or handing out a port already in use.
    #[test]
    #[should_panic(expected = "ports exhausted")]
    fn ports_exhausted() {
        let mut host = Host::new(
            "host",
            std::net::Ipv4Addr::UNSPECIFIED.into(),
            HostTimer::new(Duration::ZERO),
            49152..=49153,
            1,
            1,
        );

        host.udp.bind((host.addr, 49152).into()).unwrap();
        host.udp.bind((host.addr, 49153).into()).unwrap();

        host.assign_ephemeral_port();
    }
}