1use bytes::Bytes;
2use indexmap::{IndexMap, IndexSet};
3use tokio::{
4 sync::{mpsc, Mutex},
5 time::sleep,
6};
7
8use crate::{
9 envelope::{Datagram, Envelope, Protocol},
10 host::is_same,
11 ToSocketAddrs, World, TRACING_TARGET,
12};
13
14use std::{
15 cmp,
16 io::{self, Error, ErrorKind, Result},
17 net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr},
18};
19
/// A simulated UDP socket.
///
/// Sends are routed through the simulation `World`; receives are read from a
/// per-socket channel. Dropping the socket leaves every multicast group it
/// joined and unbinds its local address from the current host.
pub struct UdpSocket {
    /// Local address given at construction. It may be unspecified
    /// (e.g. `0.0.0.0`), in which case `send` substitutes the host address.
    local_addr: SocketAddr,
    /// Receive side, behind an async mutex so only one task at a time
    /// consumes datagrams.
    rx: Mutex<Rx>,
}
27
28impl std::fmt::Debug for UdpSocket {
29 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
30 f.debug_struct("UdpSocket")
31 .field("local_addr", &self.local_addr)
32 .finish()
33 }
34}
35
/// Multicast group membership for the simulated network.
///
/// Each key is a group socket address (group IP + port) and each value is the
/// set of member socket addresses that joined that group on that port.
#[derive(Debug, Default)]
pub(crate) struct MulticastGroups(IndexMap<SocketAddr, IndexSet<SocketAddr>>);
38
39impl MulticastGroups {
40 fn destination_addresses(&self, group: SocketAddr) -> IndexSet<SocketAddr> {
41 self.0.get(&group).cloned().unwrap_or_default()
42 }
43
44 fn contains_destination_address(&self, group: IpAddr, member: SocketAddr) -> bool {
45 self.0
46 .get(&SocketAddr::new(group, member.port()))
47 .and_then(|members| members.get(&member))
48 .is_some()
49 }
50
51 fn join(&mut self, group: IpAddr, member: SocketAddr) {
52 self.0
53 .entry(SocketAddr::new(group, member.port()))
54 .and_modify(|members| {
55 members.insert(member);
56 tracing::info!(target: TRACING_TARGET, ?member, group = ?group, protocol = %"UDP", "Join multicast group");
57 })
58 .or_insert_with(|| IndexSet::from([member]));
59 }
60
61 fn leave(&mut self, group: IpAddr, member: SocketAddr) {
62 let index = self
63 .0
64 .entry(SocketAddr::new(group, member.port()))
65 .and_modify(|members| {
66 members.swap_remove(&member);
67 tracing::info!(target: TRACING_TARGET, ?member, group = ?group, protocol = %"UDP", "Leave multicast group");
68 })
69 .index();
70
71 if self
72 .0
73 .get_index(index)
74 .map(|(_, members)| members.is_empty())
75 .unwrap_or(false)
76 {
77 self.0.swap_remove_index(index);
78 }
79 }
80
81 fn leave_all(&mut self, member: SocketAddr) {
82 for (group, members) in self.0.iter_mut() {
83 members.swap_remove(&member);
84 tracing::info!(target: TRACING_TARGET, ?member, group = ?group, protocol = %"UDP", "Leave multicast group");
85 }
86 self.0.retain(|_, members| !members.is_empty());
87 }
88}
89
/// Receive half of a [`UdpSocket`].
struct Rx {
    /// Inbound datagrams, each paired with the address it was sent from.
    recv: mpsc::Receiver<(Datagram, SocketAddr)>,
    /// A datagram pulled off `recv` by `readable` but not yet consumed by
    /// `try_recv_from`.
    buffer: Option<(Datagram, SocketAddr)>,
}
98
99impl Rx {
100 pub fn try_recv_from(&mut self, buf: &mut [u8]) -> Result<(usize, Datagram, SocketAddr)> {
102 let (datagram, origin) = if let Some(datagram) = self.buffer.take() {
103 datagram
104 } else {
105 self.recv.try_recv().map_err(|_| {
106 io::Error::new(io::ErrorKind::WouldBlock, "socket receive queue is empty")
107 })?
108 };
109
110 let bytes = &datagram.0;
111 let limit = cmp::min(buf.len(), bytes.len());
112
113 buf[..limit].copy_from_slice(&bytes[..limit]);
114
115 Ok((limit, datagram, origin))
116 }
117
118 async fn readable(&mut self) -> Result<()> {
133 if self.buffer.is_some() {
134 return Ok(());
135 }
136
137 let datagram = self
138 .recv
139 .recv()
140 .await
141 .expect("sender should never be dropped");
142
143 self.buffer = Some(datagram);
144
145 Ok(())
146 }
147}
148
149impl UdpSocket {
150 pub(crate) fn new(local_addr: SocketAddr, rx: mpsc::Receiver<(Datagram, SocketAddr)>) -> Self {
151 Self {
152 local_addr,
153 rx: Mutex::new(Rx {
154 recv: rx,
155 buffer: None,
156 }),
157 }
158 }
159 pub async fn connect<A: ToSocketAddrs>(&self, addr: A) -> Result<()> {
160 World::current(|world| {
161 let addr = addr.to_socket_addr(&world.dns)?;
162 let host = world.current_host_mut();
163
164 host.udp.connect(self.local_addr, addr);
165 Ok(())
166 })
167 }
168
169 pub async fn bind<A: ToSocketAddrs>(addr: A) -> Result<UdpSocket> {
178 World::current(|world| {
179 let mut addr = addr.to_socket_addr(&world.dns)?;
180 let host = world.current_host_mut();
181
182 verify_ipv4_bind_interface(addr.ip(), host.addr)?;
183
184 if addr.port() == 0 {
185 addr.set_port(host.assign_ephemeral_port());
186 }
187
188 host.udp.bind(addr)
189 })
190 }
191
192 pub async fn send_to<A: ToSocketAddrs>(&self, buf: &[u8], target: A) -> Result<usize> {
212 World::current(|world| {
213 let dst = target.to_socket_addr(&world.dns)?;
214 self.send(world, dst, Datagram(Bytes::copy_from_slice(buf)))?;
215 Ok(buf.len())
216 })
217 }
218
219 pub fn try_send_to<A: ToSocketAddrs>(&self, buf: &[u8], target: A) -> Result<usize> {
234 World::current(|world| {
235 let dst = target.to_socket_addr(&world.dns)?;
236 self.send(world, dst, Datagram(Bytes::copy_from_slice(buf)))?;
237 Ok(buf.len())
238 })
239 }
240
241 pub async fn writable(&self) -> Result<()> {
256 Ok(())
258 }
259
260 pub async fn recv_from(&self, buf: &mut [u8]) -> Result<(usize, SocketAddr)> {
267 let mut rx = self.rx.lock().await;
268 rx.readable().await?;
269
270 let (limit, datagram, origin) = rx
271 .try_recv_from(buf)
272 .expect("queue should be ready after readable yields");
273
274 tracing::trace!(target: TRACING_TARGET, src = ?origin, dst = ?self.local_addr, protocol = %datagram, "Recv");
275
276 Ok((limit, origin))
277 }
278
279 pub fn try_recv_from(&self, buf: &mut [u8]) -> Result<(usize, SocketAddr)> {
289 let mut rx = self.rx.try_lock().map_err(|_| {
290 io::Error::new(
291 io::ErrorKind::WouldBlock,
292 "socket is being read by another task",
293 )
294 })?;
295
296 let (limit, datagram, origin) = rx.try_recv_from(buf).map_err(|_| {
297 io::Error::new(io::ErrorKind::WouldBlock, "socket receive queue is empty")
298 })?;
299
300 tracing::trace!(target: TRACING_TARGET, src = ?origin, dst = ?self.local_addr, protocol = %datagram, "Recv");
301
302 Ok((limit, origin))
303 }
304 pub fn try_recv(&self, buf: &mut [u8]) -> io::Result<usize> {
315 self.try_recv_from(buf).map(|(size, _)| size)
316 }
317
318 pub async fn readable(&self) -> Result<()> {
333 let mut rx = self.rx.lock().await;
334 rx.readable().await?;
335 Ok(())
336 }
337
338 pub fn local_addr(&self) -> Result<SocketAddr> {
356 Ok(self.local_addr)
357 }
358
359 fn send(&self, world: &mut World, dst: SocketAddr, packet: Datagram) -> Result<()> {
360 let mut src = self.local_addr;
361 if dst.ip().is_loopback() {
362 src.set_ip(dst.ip());
363 }
364 if src.ip().is_unspecified() {
365 src.set_ip(world.current_host_mut().addr);
366 }
367
368 match dst {
369 SocketAddr::V4(dst) if dst.ip().is_broadcast() => {
370 let host = world.current_host();
371 match host.udp.is_broadcast_enabled(src.port()) {
372 true => world
373 .hosts
374 .iter()
375 .filter(|(_, host)| host.udp.is_port_assigned(dst.port()))
376 .map(|(addr, _)| SocketAddr::new(*addr, dst.port()))
377 .collect::<Vec<_>>()
378 .into_iter()
379 .try_for_each(|dst| match dst {
380 dst if src.ip() == dst.ip() => {
381 send_loopback(src, dst, Protocol::Udp(packet.clone()));
382 Ok(())
383 }
384 dst => world.send_message(src, dst, Protocol::Udp(packet.clone())),
385 }),
386 false => Err(Error::new(
387 ErrorKind::PermissionDenied,
388 "Broadcast is not enabled",
389 )),
390 }
391 }
392 dst if dst.ip().is_multicast() => world
393 .multicast_groups
394 .destination_addresses(dst)
395 .into_iter()
396 .try_for_each(|dst| match dst {
397 dst if src.ip() == dst.ip() => {
398 let host = world.current_host();
399 if host.udp.is_multicast_loop_enabled(dst.port()) {
400 send_loopback(src, dst, Protocol::Udp(packet.clone()));
401 }
402 Ok(())
403 }
404 dst => world.send_message(src, dst, Protocol::Udp(packet.clone())),
405 }),
406 dst if is_same(src, dst) => {
407 send_loopback(src, dst, Protocol::Udp(packet));
408 Ok(())
409 }
410 _ => world.send_message(src, dst, Protocol::Udp(packet)),
411 }
412 }
413
414 pub fn broadcast(&self) -> io::Result<bool> {
420 let local_port = self.local_addr.port();
421 World::current(|world| Ok(world.current_host().udp.is_broadcast_enabled(local_port)))
422 }
423
424 pub fn set_broadcast(&self, on: bool) -> io::Result<()> {
429 let local_port = match self.local_addr {
430 SocketAddr::V4(addr) => addr.port(),
431 _ => return Ok(()),
432 };
433 World::current(|world| {
434 world.current_host_mut().udp.set_broadcast(local_port, on);
435 Ok(())
436 })
437 }
438
439 pub fn multicast_loop_v4(&self) -> io::Result<bool> {
445 let local_port = self.local_addr.port();
446 World::current(|world| {
447 Ok(world
448 .current_host()
449 .udp
450 .is_multicast_loop_enabled(local_port))
451 })
452 }
453
454 pub fn set_multicast_loop_v4(&self, on: bool) -> io::Result<()> {
462 let local_port = match self.local_addr {
463 SocketAddr::V4(addr) => addr.port(),
464 _ => return Ok(()),
465 };
466 World::current(|world| {
467 world
468 .current_host_mut()
469 .udp
470 .set_multicast_loop(local_port, on);
471 Ok(())
472 })
473 }
474
475 pub fn multicast_loop_v6(&self) -> io::Result<bool> {
481 let local_port = self.local_addr.port();
482 World::current(|world| {
483 Ok(world
484 .current_host()
485 .udp
486 .is_multicast_loop_enabled(local_port))
487 })
488 }
489
490 pub fn set_multicast_loop_v6(&self, on: bool) -> Result<()> {
498 let local_port = match self.local_addr {
499 SocketAddr::V6(addr) => addr.port(),
500 _ => return Ok(()),
501 };
502 World::current(|world| {
503 world
504 .current_host_mut()
505 .udp
506 .set_multicast_loop(local_port, on);
507 Ok(())
508 })
509 }
510
511 pub fn join_multicast_v4(&self, multiaddr: Ipv4Addr, interface: Ipv4Addr) -> Result<()> {
521 if !multiaddr.is_multicast() {
522 return Err(Error::new(
523 ErrorKind::InvalidInput,
524 "Invalid multicast address",
525 ));
526 }
527
528 World::current(|world| {
529 let dst = destination_address(world, self);
530 verify_ipv4_bind_interface(interface, dst.ip())?;
531
532 world.multicast_groups.join(IpAddr::V4(multiaddr), dst);
533
534 Ok(())
535 })
536 }
537
538 pub fn join_multicast_v6(&self, multiaddr: &Ipv6Addr, interface: u32) -> Result<()> {
546 verify_ipv6_bind_interface(interface)?;
547 if !multiaddr.is_multicast() {
548 return Err(Error::new(
549 ErrorKind::InvalidInput,
550 "Invalid multicast address",
551 ));
552 }
553
554 World::current(|world| {
555 let dst = destination_address(world, self);
556
557 world.multicast_groups.join(IpAddr::V6(*multiaddr), dst);
558
559 Ok(())
560 })
561 }
562
563 pub fn leave_multicast_v4(&self, multiaddr: Ipv4Addr, interface: Ipv4Addr) -> io::Result<()> {
569 if !multiaddr.is_multicast() {
570 return Err(Error::new(
571 ErrorKind::InvalidInput,
572 "Invalid multicast address",
573 ));
574 }
575
576 World::current(|world| {
577 let dst = destination_address(world, self);
578 verify_ipv4_bind_interface(interface, dst.ip())?;
579
580 if !world
581 .multicast_groups
582 .contains_destination_address(IpAddr::V4(multiaddr), dst)
583 {
584 return Err(Error::new(
585 ErrorKind::AddrNotAvailable,
586 "Leaving a multicast group that has not been previously joined",
587 ));
588 }
589
590 world.multicast_groups.leave(IpAddr::V4(multiaddr), dst);
591
592 Ok(())
593 })
594 }
595
596 pub fn leave_multicast_v6(&self, multiaddr: &Ipv6Addr, interface: u32) -> io::Result<()> {
602 verify_ipv6_bind_interface(interface)?;
603 if !multiaddr.is_multicast() {
604 return Err(Error::new(
605 ErrorKind::InvalidInput,
606 "Invalid multicast address",
607 ));
608 }
609
610 World::current(|world| {
611 let dst = destination_address(world, self);
612
613 if !world
614 .multicast_groups
615 .contains_destination_address(IpAddr::V6(*multiaddr), dst)
616 {
617 return Err(Error::new(
618 ErrorKind::AddrNotAvailable,
619 "Leaving a multicast group that has not been previously joined",
620 ));
621 }
622
623 world.multicast_groups.leave(IpAddr::V6(*multiaddr), dst);
624
625 Ok(())
626 })
627 }
628}
629
630fn send_loopback(src: SocketAddr, dst: SocketAddr, message: Protocol) {
631 tokio::spawn(async move {
632 let tick_duration = World::current(|world| world.tick_duration);
636 sleep(tick_duration).await;
637
638 World::current(|world| {
639 world
640 .current_host_mut()
641 .receive_from_network(Envelope { src, dst, message })
642 .expect("UDP does not get feedback on delivery errors");
643 })
644 });
645}
646
/// Checks that `interface` is an address the simulation can bind to.
///
/// Only unspecified and loopback addresses are accepted. Despite the name,
/// both IPv4 and IPv6 values are handled. `addr` is the host address at both
/// visible call sites.
///
/// # Errors
///
/// `AddrNotAvailable` for any other interface address.
///
/// # Panics
///
/// If `interface` and `addr` differ in IP version.
fn verify_ipv4_bind_interface<A>(interface: A, addr: IpAddr) -> Result<()>
where
    A: Into<IpAddr>,
{
    let interface: IpAddr = interface.into();

    let supported = interface.is_unspecified() || interface.is_loopback();
    if !supported {
        return Err(Error::new(
            ErrorKind::AddrNotAvailable,
            format!("{interface} is not supported"),
        ));
    }

    assert!(
        interface.is_ipv4() == addr.is_ipv4(),
        "ip version mismatch: {interface:?} host: {addr:?}"
    );

    Ok(())
}
666
/// Checks that the IPv6 interface index is one the simulation supports.
///
/// Only index `0` (no specific interface) is accepted.
///
/// # Errors
///
/// `AddrNotAvailable` for any non-zero index.
fn verify_ipv6_bind_interface(interface: u32) -> Result<()> {
    match interface {
        0 => Ok(()),
        other => Err(Error::new(
            ErrorKind::AddrNotAvailable,
            format!("interface {other} is not supported"),
        )),
    }
}
677
678fn destination_address(world: &World, socket: &UdpSocket) -> SocketAddr {
679 let local_port = socket
680 .local_addr()
681 .expect("local_addr is always present in simulation")
682 .port();
683 let host_addr = world.current_host().addr;
684 SocketAddr::from((host_addr, local_port))
685}
686
687impl Drop for UdpSocket {
688 fn drop(&mut self) {
689 World::current_if_set(|world| {
690 world
691 .multicast_groups
692 .leave_all(destination_address(world, self));
693 world.current_host_mut().udp.unbind(self.local_addr);
694 });
695 }
696}
697
#[cfg(test)]
mod tests {
    use super::*;

    /// Unit tests for `MulticastGroups` membership bookkeeping.
    mod multicast_group {
        use super::*;
        use std::net::{IpAddr, SocketAddr};

        #[test]
        fn joining_does_not_produce_duplicate_addresses() {
            let member = "[fe80::1]:9000".parse().unwrap();
            let group = "ff08::1".parse().unwrap();
            let mut groups = MulticastGroups::default();
            groups.join(group, member);
            groups.join(group, member);

            let addrs = groups.0.values().flatten().collect::<Vec<_>>();
            assert_eq!(addrs.as_slice(), &[&member]);
        }

        #[test]
        fn leaving_does_not_remove_entire_group() {
            let member1 = "[fe80::1]:9000".parse().unwrap();
            let member2 = "[fe80::2]:9000".parse().unwrap();
            let group = "ff08::1".parse().unwrap();
            let mut groups = MulticastGroups::default();
            groups.join(group, member1);
            groups.join(group, member2);
            groups.leave(group, member2);

            let addrs = groups.0.values().flatten().collect::<Vec<_>>();
            assert_eq!(addrs.as_slice(), &[&member1]);
        }

        #[test]
        fn leaving_removes_empty_group() {
            let member = "[fe80::1]:9000".parse().unwrap();
            let group = "ff08::1".parse().unwrap();
            let mut groups = MulticastGroups::default();
            groups.join(group, member);
            groups.leave(group, member);

            assert_eq!(groups.0.len(), 0);
        }

        #[test]
        fn leaving_removes_empty_groups() {
            let member = "[fe80::1]:9000".parse().unwrap();
            let group1 = "ff08::1".parse().unwrap();
            let group2 = "ff08::2".parse().unwrap();
            let mut groups = MulticastGroups::default();
            groups.join(group1, member);
            groups.join(group2, member);
            groups.leave_all(member);

            assert_eq!(groups.0.len(), 0);
        }

        #[test]
        fn leaving_unknown_group_is_a_no_op() {
            let member: SocketAddr = "[fe80::1]:9000".parse().unwrap();
            let joined: IpAddr = "ff08::1".parse().unwrap();
            let unknown: IpAddr = "ff08::2".parse().unwrap();
            let mut groups = MulticastGroups::default();
            groups.join(joined, member);
            groups.leave(unknown, member);

            assert_eq!(groups.0.len(), 1);
            assert!(groups.contains_destination_address(joined, member));
        }

        #[test]
        fn leave_all_keeps_other_members() {
            let member1: SocketAddr = "[fe80::1]:9000".parse().unwrap();
            let member2: SocketAddr = "[fe80::2]:9000".parse().unwrap();
            let group1: IpAddr = "ff08::1".parse().unwrap();
            let group2: IpAddr = "ff08::2".parse().unwrap();
            let mut groups = MulticastGroups::default();
            groups.join(group1, member1);
            groups.join(group1, member2);
            groups.join(group2, member1);
            groups.leave_all(member1);

            assert_eq!(groups.0.len(), 1);
            assert!(groups.contains_destination_address(group1, member2));
            assert!(!groups.contains_destination_address(group1, member1));
            assert!(!groups.contains_destination_address(group2, member1));
        }

        #[test]
        fn membership_is_keyed_by_member_port() {
            let member: SocketAddr = "[fe80::1]:9000".parse().unwrap();
            let group: IpAddr = "ff08::1".parse().unwrap();
            let mut groups = MulticastGroups::default();
            groups.join(group, member);

            assert_eq!(
                groups.destination_addresses(SocketAddr::new(group, 9000)).len(),
                1
            );
            assert!(groups
                .destination_addresses(SocketAddr::new(group, 9001))
                .is_empty());
        }
    }
}