1use bytes::Bytes;
2use indexmap::{IndexMap, IndexSet};
3use tokio::{
4 sync::{mpsc, Mutex},
5 time::sleep,
6};
7
8use crate::{
9 envelope::{Datagram, Envelope, Protocol},
10 host::is_same,
11 ToSocketAddrs, World, TRACING_TARGET,
12};
13
14use std::{
15 cmp,
16 io::{self, Error, ErrorKind, Result},
17 net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr},
18};
19
/// A UDP socket bound on the current simulated host.
///
/// Outgoing datagrams are routed through the simulation `World`; incoming
/// datagrams arrive on an mpsc channel owned by the receive half.
pub struct UdpSocket {
    /// Address given to [`UdpSocket::new`]; returned by `local_addr()` and
    /// used as the source address of outgoing datagrams.
    local_addr: SocketAddr,
    /// Receive half, behind an async mutex so concurrent readers are served
    /// one at a time (`try_recv_from` reports `WouldBlock` while it is held).
    rx: Mutex<Rx>,
}
27
28impl std::fmt::Debug for UdpSocket {
29 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
30 f.debug_struct("UdpSocket")
31 .field("local_addr", &self.local_addr)
32 .finish()
33 }
34}
35
/// Registry of multicast group memberships.
///
/// Each key is a multicast IP combined with a port. The value is the set of
/// member socket addresses that joined that group on that port.
#[derive(Debug, Default)]
pub(crate) struct MulticastGroups(IndexMap<SocketAddr, IndexSet<SocketAddr>>);
38
39impl MulticastGroups {
40 fn destination_addresses(&self, group: SocketAddr) -> IndexSet<SocketAddr> {
41 self.0.get(&group).cloned().unwrap_or_default()
42 }
43
44 fn contains_destination_address(&self, group: IpAddr, member: SocketAddr) -> bool {
45 self.0
46 .get(&SocketAddr::new(group, member.port()))
47 .and_then(|members| members.get(&member))
48 .is_some()
49 }
50
51 fn join(&mut self, group: IpAddr, member: SocketAddr) {
52 self.0
53 .entry(SocketAddr::new(group, member.port()))
54 .and_modify(|members| {
55 members.insert(member);
56 tracing::info!(target: TRACING_TARGET, ?member, group = ?group, protocol = %"UDP", "Join multicast group");
57 })
58 .or_insert_with(|| IndexSet::from([member]));
59 }
60
61 fn leave(&mut self, group: IpAddr, member: SocketAddr) {
62 let index = self
63 .0
64 .entry(SocketAddr::new(group, member.port()))
65 .and_modify(|members| {
66 members.swap_remove(&member);
67 tracing::info!(target: TRACING_TARGET, ?member, group = ?group, protocol = %"UDP", "Leave multicast group");
68 })
69 .index();
70
71 if self
72 .0
73 .get_index(index)
74 .map(|(_, members)| members.is_empty())
75 .unwrap_or(false)
76 {
77 self.0.swap_remove_index(index);
78 }
79 }
80
81 fn leave_all(&mut self, member: SocketAddr) {
82 for (group, members) in self.0.iter_mut() {
83 members.swap_remove(&member);
84 tracing::info!(target: TRACING_TARGET, ?member, group = ?group, protocol = %"UDP", "Leave multicast group");
85 }
86 self.0.retain(|_, members| !members.is_empty());
87 }
88}
89
/// Receive half of a [`UdpSocket`].
struct Rx {
    /// Inbound datagrams, each paired with the address it came from.
    recv: mpsc::Receiver<(Datagram, SocketAddr)>,
    /// A datagram taken from `recv` by `readable` but not yet consumed by
    /// `try_recv_from`.
    buffer: Option<(Datagram, SocketAddr)>,
}
98
99impl Rx {
100 pub fn try_recv_from(&mut self, buf: &mut [u8]) -> Result<(usize, Datagram, SocketAddr)> {
102 let (datagram, origin) = if let Some(datagram) = self.buffer.take() {
103 datagram
104 } else {
105 self.recv.try_recv().map_err(|_| {
106 io::Error::new(io::ErrorKind::WouldBlock, "socket receive queue is empty")
107 })?
108 };
109
110 let bytes = &datagram.0;
111 let limit = cmp::min(buf.len(), bytes.len());
112
113 buf[..limit].copy_from_slice(&bytes[..limit]);
114
115 Ok((limit, datagram, origin))
116 }
117
118 pub async fn readable(&mut self) -> Result<()> {
133 if self.buffer.is_some() {
134 return Ok(());
135 }
136
137 let datagram = self
138 .recv
139 .recv()
140 .await
141 .expect("sender should never be dropped");
142
143 self.buffer = Some(datagram);
144
145 Ok(())
146 }
147}
148
impl UdpSocket {
    /// Creates a socket bound to `local_addr` that receives from `rx`.
    ///
    /// Each channel item is a datagram paired with its origin address.
    pub(crate) fn new(local_addr: SocketAddr, rx: mpsc::Receiver<(Datagram, SocketAddr)>) -> Self {
        Self {
            local_addr,
            rx: Mutex::new(Rx {
                recv: rx,
                buffer: None,
            }),
        }
    }

    /// Binds a UDP socket on the current host.
    ///
    /// `addr` is resolved through the world's DNS. A port of `0` is replaced
    /// by an ephemeral port assigned by the host before binding.
    ///
    /// # Errors
    ///
    /// - `AddrNotAvailable` if the IP is neither unspecified nor loopback.
    /// - Any error returned by the host's UDP bind.
    ///
    /// # Panics
    ///
    /// If the IP family of `addr` differs from the host's address family.
    pub async fn bind<A: ToSocketAddrs>(addr: A) -> Result<UdpSocket> {
        World::current(|world| {
            let mut addr = addr.to_socket_addr(&world.dns);
            let host = world.current_host_mut();

            verify_ipv4_bind_interface(addr.ip(), host.addr)?;

            if addr.port() == 0 {
                addr.set_port(host.assign_ephemeral_port());
            }

            host.udp.bind(addr)
        })
    }

    /// Sends `buf` as a single datagram to `target`, resolved via world DNS.
    ///
    /// Returns `buf.len()` on success. See `send` for routing rules and errors.
    pub async fn send_to<A: ToSocketAddrs>(&self, buf: &[u8], target: A) -> Result<usize> {
        World::current(|world| {
            let dst = target.to_socket_addr(&world.dns);
            self.send(world, dst, Datagram(Bytes::copy_from_slice(buf)))?;
            Ok(buf.len())
        })
    }

    /// Non-async counterpart of `send_to`, with an identical implementation.
    pub fn try_send_to<A: ToSocketAddrs>(&self, buf: &[u8], target: A) -> Result<usize> {
        World::current(|world| {
            let dst = target.to_socket_addr(&world.dns);
            self.send(world, dst, Datagram(Bytes::copy_from_slice(buf)))?;
            Ok(buf.len())
        })
    }

    /// Always ready: returns `Ok(())` immediately.
    pub async fn writable(&self) -> Result<()> {
        Ok(())
    }

    /// Waits for a datagram and copies up to `buf.len()` bytes of it into `buf`.
    ///
    /// Returns the number of bytes copied and the sender's address. Bytes
    /// that do not fit are dropped together with the datagram. The receive
    /// lock is held for the whole call, so concurrent readers take turns.
    ///
    /// # Panics
    ///
    /// If the receive channel closes, or if the queue is unexpectedly empty
    /// after `readable` returned (an internal invariant).
    pub async fn recv_from(&self, buf: &mut [u8]) -> Result<(usize, SocketAddr)> {
        let mut rx = self.rx.lock().await;
        rx.readable().await?;

        let (limit, datagram, origin) = rx
            .try_recv_from(buf)
            .expect("queue should be ready after readable yields");

        tracing::trace!(target: TRACING_TARGET, src = ?origin, dst = ?self.local_addr, protocol = %datagram, "Recv");

        Ok((limit, origin))
    }

    /// Non-waiting variant of `recv_from`.
    ///
    /// # Errors
    ///
    /// `WouldBlock` if another task holds the receive lock, or if no datagram
    /// is queued.
    pub fn try_recv_from(&self, buf: &mut [u8]) -> Result<(usize, SocketAddr)> {
        let mut rx = self.rx.try_lock().map_err(|_| {
            io::Error::new(
                io::ErrorKind::WouldBlock,
                "socket is being read by another task",
            )
        })?;

        let (limit, datagram, origin) = rx.try_recv_from(buf).map_err(|_| {
            io::Error::new(io::ErrorKind::WouldBlock, "socket receive queue is empty")
        })?;

        tracing::trace!(target: TRACING_TARGET, src = ?origin, dst = ?self.local_addr, protocol = %datagram, "Recv");

        Ok((limit, origin))
    }

    /// Waits until a datagram is available to read.
    ///
    /// The datagram is stashed internally, so a following `try_recv_from`
    /// returns it without blocking (unless another reader takes it first).
    pub async fn readable(&self) -> Result<()> {
        let mut rx = self.rx.lock().await;
        rx.readable().await?;
        Ok(())
    }

    /// Returns the address this socket was created with; never fails.
    pub fn local_addr(&self) -> Result<SocketAddr> {
        Ok(self.local_addr)
    }

    /// Routes `packet` from this socket to `dst`.
    ///
    /// The source IP is rewritten to `dst`'s IP when `dst` is loopback, and
    /// to the host's address when the bound IP is unspecified. Routing:
    /// - IPv4 broadcast: requires broadcast enabled on the source port
    ///   (otherwise `PermissionDenied`). Delivers to every host with
    ///   `dst.port()` assigned; deliveries to the sender's own IP go through
    ///   `send_loopback`.
    /// - Multicast: delivers to every joined member. Members on the sender's
    ///   IP are reached via loopback only when multicast loop is enabled for
    ///   the member's port.
    /// - Same address per `is_same`: delivered via `send_loopback`.
    /// - Anything else: handed to `world.send_message`.
    fn send(&self, world: &mut World, dst: SocketAddr, packet: Datagram) -> Result<()> {
        let mut src = self.local_addr;
        if dst.ip().is_loopback() {
            src.set_ip(dst.ip());
        }
        if src.ip().is_unspecified() {
            src.set_ip(world.current_host_mut().addr);
        }

        match dst {
            SocketAddr::V4(dst) if dst.ip().is_broadcast() => {
                let host = world.current_host();
                match host.udp.is_broadcast_enabled(src.port()) {
                    // The destinations are collected first so the shared borrow of
                    // `world.hosts` ends before `world.send_message` borrows `world`
                    // mutably inside `try_for_each`.
                    true => world
                        .hosts
                        .iter()
                        .filter(|(_, host)| host.udp.is_port_assigned(dst.port()))
                        .map(|(addr, _)| SocketAddr::new(*addr, dst.port()))
                        .collect::<Vec<_>>()
                        .into_iter()
                        .try_for_each(|dst| match dst {
                            dst if src.ip() == dst.ip() => {
                                send_loopback(src, dst, Protocol::Udp(packet.clone()));
                                Ok(())
                            }
                            dst => world.send_message(src, dst, Protocol::Udp(packet.clone())),
                        }),
                    false => Err(Error::new(
                        ErrorKind::PermissionDenied,
                        "Broadcast is not enabled",
                    )),
                }
            }
            // `destination_addresses` returns an owned set, so `world` may be
            // borrowed mutably while iterating the members.
            dst if dst.ip().is_multicast() => world
                .multicast_groups
                .destination_addresses(dst)
                .into_iter()
                .try_for_each(|dst| match dst {
                    dst if src.ip() == dst.ip() => {
                        let host = world.current_host();
                        if host.udp.is_multicast_loop_enabled(dst.port()) {
                            send_loopback(src, dst, Protocol::Udp(packet.clone()));
                        }
                        Ok(())
                    }
                    dst => world.send_message(src, dst, Protocol::Udp(packet.clone())),
                }),
            dst if is_same(src, dst) => {
                send_loopback(src, dst, Protocol::Udp(packet));
                Ok(())
            }
            _ => world.send_message(src, dst, Protocol::Udp(packet)),
        }
    }

    /// Returns whether broadcast is enabled for this socket's local port.
    pub fn broadcast(&self) -> io::Result<bool> {
        let local_port = self.local_addr.port();
        World::current(|world| Ok(world.current_host().udp.is_broadcast_enabled(local_port)))
    }

    /// Enables or disables broadcast for this socket's local port.
    ///
    /// Only IPv4 sockets are affected; for IPv6 sockets this returns `Ok(())`
    /// without changing anything.
    pub fn set_broadcast(&self, on: bool) -> io::Result<()> {
        let local_port = match self.local_addr {
            SocketAddr::V4(addr) => addr.port(),
            _ => return Ok(()),
        };
        World::current(|world| {
            world.current_host_mut().udp.set_broadcast(local_port, on);
            Ok(())
        })
    }

    /// Returns the multicast-loop flag for this socket's local port.
    ///
    /// This reads the same per-port flag as `multicast_loop_v6`.
    pub fn multicast_loop_v4(&self) -> io::Result<bool> {
        let local_port = self.local_addr.port();
        World::current(|world| {
            Ok(world
                .current_host()
                .udp
                .is_multicast_loop_enabled(local_port))
        })
    }

    /// Sets the multicast-loop flag for this socket's local port.
    ///
    /// Only IPv4 sockets are affected; for IPv6 sockets this returns `Ok(())`
    /// without changing anything.
    pub fn set_multicast_loop_v4(&self, on: bool) -> io::Result<()> {
        let local_port = match self.local_addr {
            SocketAddr::V4(addr) => addr.port(),
            _ => return Ok(()),
        };
        World::current(|world| {
            world
                .current_host_mut()
                .udp
                .set_multicast_loop(local_port, on);
            Ok(())
        })
    }

    /// Returns the multicast-loop flag for this socket's local port.
    ///
    /// This reads the same per-port flag as `multicast_loop_v4`.
    pub fn multicast_loop_v6(&self) -> io::Result<bool> {
        let local_port = self.local_addr.port();
        World::current(|world| {
            Ok(world
                .current_host()
                .udp
                .is_multicast_loop_enabled(local_port))
        })
    }

    /// Sets the multicast-loop flag for this socket's local port.
    ///
    /// Only IPv6 sockets are affected; for IPv4 sockets this returns `Ok(())`
    /// without changing anything.
    pub fn set_multicast_loop_v6(&self, on: bool) -> Result<()> {
        let local_port = match self.local_addr {
            SocketAddr::V6(addr) => addr.port(),
            _ => return Ok(()),
        };
        World::current(|world| {
            world
                .current_host_mut()
                .udp
                .set_multicast_loop(local_port, on);
            Ok(())
        })
    }

    /// Joins the IPv4 multicast group `multiaddr`.
    ///
    /// Membership is recorded under the host's address plus this socket's
    /// local port.
    ///
    /// # Errors
    ///
    /// - `InvalidInput` if `multiaddr` is not a multicast address.
    /// - `AddrNotAvailable` if `interface` is neither unspecified nor loopback.
    ///
    /// # Panics
    ///
    /// If the host's address is not IPv4.
    pub fn join_multicast_v4(&self, multiaddr: Ipv4Addr, interface: Ipv4Addr) -> Result<()> {
        if !multiaddr.is_multicast() {
            return Err(Error::new(
                ErrorKind::InvalidInput,
                "Invalid multicast address",
            ));
        }

        World::current(|world| {
            let dst = destination_address(world, self);
            verify_ipv4_bind_interface(interface, dst.ip())?;

            world.multicast_groups.join(IpAddr::V4(multiaddr), dst);

            Ok(())
        })
    }

    /// Joins the IPv6 multicast group `multiaddr`.
    ///
    /// # Errors
    ///
    /// - `AddrNotAvailable` if `interface` is not `0`.
    /// - `InvalidInput` if `multiaddr` is not a multicast address.
    pub fn join_multicast_v6(&self, multiaddr: &Ipv6Addr, interface: u32) -> Result<()> {
        verify_ipv6_bind_interface(interface)?;
        if !multiaddr.is_multicast() {
            return Err(Error::new(
                ErrorKind::InvalidInput,
                "Invalid multicast address",
            ));
        }

        World::current(|world| {
            let dst = destination_address(world, self);

            world.multicast_groups.join(IpAddr::V6(*multiaddr), dst);

            Ok(())
        })
    }

    /// Leaves the IPv4 multicast group `multiaddr`.
    ///
    /// # Errors
    ///
    /// - `InvalidInput` if `multiaddr` is not a multicast address.
    /// - `AddrNotAvailable` if `interface` is unsupported, or if this socket
    ///   has not joined the group.
    ///
    /// # Panics
    ///
    /// If the host's address is not IPv4.
    pub fn leave_multicast_v4(&self, multiaddr: Ipv4Addr, interface: Ipv4Addr) -> io::Result<()> {
        if !multiaddr.is_multicast() {
            return Err(Error::new(
                ErrorKind::InvalidInput,
                "Invalid multicast address",
            ));
        }

        World::current(|world| {
            let dst = destination_address(world, self);
            verify_ipv4_bind_interface(interface, dst.ip())?;

            if !world
                .multicast_groups
                .contains_destination_address(IpAddr::V4(multiaddr), dst)
            {
                return Err(Error::new(
                    ErrorKind::AddrNotAvailable,
                    "Leaving a multicast group that has not been previously joined",
                ));
            }

            world.multicast_groups.leave(IpAddr::V4(multiaddr), dst);

            Ok(())
        })
    }

    /// Leaves the IPv6 multicast group `multiaddr`.
    ///
    /// # Errors
    ///
    /// - `AddrNotAvailable` if `interface` is not `0`, or if this socket has
    ///   not joined the group.
    /// - `InvalidInput` if `multiaddr` is not a multicast address.
    pub fn leave_multicast_v6(&self, multiaddr: &Ipv6Addr, interface: u32) -> io::Result<()> {
        verify_ipv6_bind_interface(interface)?;
        if !multiaddr.is_multicast() {
            return Err(Error::new(
                ErrorKind::InvalidInput,
                "Invalid multicast address",
            ));
        }

        World::current(|world| {
            let dst = destination_address(world, self);

            if !world
                .multicast_groups
                .contains_destination_address(IpAddr::V6(*multiaddr), dst)
            {
                return Err(Error::new(
                    ErrorKind::AddrNotAvailable,
                    "Leaving a multicast group that has not been previously joined",
                ));
            }

            world.multicast_groups.leave(IpAddr::V6(*multiaddr), dst);

            Ok(())
        })
    }
}
607
608fn send_loopback(src: SocketAddr, dst: SocketAddr, message: Protocol) {
609 tokio::spawn(async move {
610 let tick_duration = World::current(|world| world.tick_duration);
614 sleep(tick_duration).await;
615
616 World::current(|world| {
617 world
618 .current_host_mut()
619 .receive_from_network(Envelope { src, dst, message })
620 .expect("UDP does not get feedback on delivery errors");
621 })
622 });
623}
624
/// Checks that `interface` is an address the simulation can bind or join on.
///
/// Only the unspecified address and loopback are supported. Despite the
/// name, it accepts either IP family. `addr` is the host's address.
///
/// # Errors
///
/// `AddrNotAvailable` if `interface` is neither unspecified nor loopback.
///
/// # Panics
///
/// If `interface` and `addr` belong to different IP families. This check
/// only runs after the interface itself has been accepted.
fn verify_ipv4_bind_interface<A>(interface: A, addr: IpAddr) -> Result<()>
where
    A: Into<IpAddr>,
{
    let interface: IpAddr = interface.into();

    let supported = interface.is_unspecified() || interface.is_loopback();
    if !supported {
        return Err(Error::new(
            ErrorKind::AddrNotAvailable,
            format!("{interface} is not supported"),
        ));
    }

    assert!(
        interface.is_ipv4() == addr.is_ipv4(),
        "ip version mismatch: {:?} host: {:?}",
        interface,
        addr
    );

    Ok(())
}
644
/// Checks that an IPv6 interface index is supported by the simulation.
///
/// Only interface `0` (no specific interface) is accepted.
///
/// # Errors
///
/// `AddrNotAvailable` for any non-zero index.
fn verify_ipv6_bind_interface(interface: u32) -> Result<()> {
    match interface {
        0 => Ok(()),
        _ => Err(Error::new(
            ErrorKind::AddrNotAvailable,
            format!("interface {interface} is not supported"),
        )),
    }
}
655
656fn destination_address(world: &World, socket: &UdpSocket) -> SocketAddr {
657 let local_port = socket
658 .local_addr()
659 .expect("local_addr is always present in simulation")
660 .port();
661 let host_addr = world.current_host().addr;
662 SocketAddr::from((host_addr, local_port))
663}
664
665impl Drop for UdpSocket {
666 fn drop(&mut self) {
667 World::current_if_set(|world| {
668 world
669 .multicast_groups
670 .leave_all(destination_address(world, self));
671 world.current_host_mut().udp.unbind(self.local_addr);
672 });
673 }
674}
675
#[cfg(test)]
mod tests {
    use super::*;

    mod multicast_group {
        use super::*;

        #[test]
        fn joining_does_not_produce_duplicate_addresses() {
            let member = "[fe80::1]:9000".parse().unwrap();
            let group = "ff08::1".parse().unwrap();
            let mut groups = MulticastGroups::default();
            groups.join(group, member);
            groups.join(group, member);

            let addrs = groups.0.values().flatten().collect::<Vec<_>>();
            assert_eq!(addrs.as_slice(), &[&member]);
        }

        #[test]
        fn leaving_does_not_remove_entire_group() {
            let member1 = "[fe80::1]:9000".parse().unwrap();
            let member2 = "[fe80::2]:9000".parse().unwrap();
            let group = "ff08::1".parse().unwrap();
            let mut groups = MulticastGroups::default();
            groups.join(group, member1);
            groups.join(group, member2);
            groups.leave(group, member2);

            let addrs = groups.0.values().flatten().collect::<Vec<_>>();
            assert_eq!(addrs.as_slice(), &[&member1]);
        }

        #[test]
        fn leaving_removes_empty_group() {
            let member = "[fe80::1]:9000".parse().unwrap();
            let group = "ff08::1".parse().unwrap();
            let mut groups = MulticastGroups::default();
            groups.join(group, member);
            groups.leave(group, member);

            assert_eq!(groups.0.len(), 0);
        }

        #[test]
        fn leaving_removes_empty_groups() {
            let member = "[fe80::1]:9000".parse().unwrap();
            let group1 = "ff08::1".parse().unwrap();
            let group2 = "ff08::2".parse().unwrap();
            let mut groups = MulticastGroups::default();
            groups.join(group1, member);
            groups.join(group2, member);
            groups.leave_all(member);

            assert_eq!(groups.0.len(), 0);
        }

        #[test]
        fn leaving_unknown_group_is_a_noop() {
            let member: SocketAddr = "[fe80::1]:9000".parse().unwrap();
            let joined: IpAddr = "ff08::1".parse().unwrap();
            let other: IpAddr = "ff08::2".parse().unwrap();
            let mut groups = MulticastGroups::default();
            groups.join(joined, member);
            groups.leave(other, member);

            assert_eq!(groups.0.len(), 1);
            assert!(groups.contains_destination_address(joined, member));
            assert!(!groups.contains_destination_address(other, member));
        }

        #[test]
        fn leave_all_keeps_other_members() {
            let member1: SocketAddr = "[fe80::1]:9000".parse().unwrap();
            let member2: SocketAddr = "[fe80::2]:9000".parse().unwrap();
            let group1: IpAddr = "ff08::1".parse().unwrap();
            let group2: IpAddr = "ff08::2".parse().unwrap();
            let mut groups = MulticastGroups::default();
            groups.join(group1, member1);
            groups.join(group1, member2);
            groups.join(group2, member1);
            groups.leave_all(member1);

            assert_eq!(groups.0.len(), 1);
            assert!(groups.contains_destination_address(group1, member2));
            assert!(!groups.contains_destination_address(group1, member1));
            assert!(!groups.contains_destination_address(group2, member1));
        }

        #[test]
        fn groups_are_keyed_by_member_port() {
            let member: SocketAddr = "[fe80::1]:9000".parse().unwrap();
            let group: IpAddr = "ff08::1".parse().unwrap();
            let mut groups = MulticastGroups::default();
            groups.join(group, member);

            let same_port = groups.destination_addresses(SocketAddr::new(group, 9000));
            assert!(same_port.contains(&member));
            assert!(groups
                .destination_addresses(SocketAddr::new(group, 9001))
                .is_empty());
        }
    }
}