turmoil/net/tcp/listener.rs
1use std::{
2 io::{Error, ErrorKind, Result},
3 net::SocketAddr,
4 sync::Arc,
5};
6
7use tokio::sync::Notify;
8
9use crate::{
10 net::{SocketPair, TcpStream},
11 world::World,
12 ToSocketAddrs, TRACING_TARGET,
13};
14
15pub struct TcpListener {
19 local_addr: SocketAddr,
20 notify: Arc<Notify>,
21}
22
23impl TcpListener {
24 pub(crate) fn new(local_addr: SocketAddr, notify: Arc<Notify>) -> Self {
25 Self { local_addr, notify }
26 }
27
28 pub async fn bind<A: ToSocketAddrs>(addr: A) -> Result<TcpListener> {
38 World::current(|world| {
39 let mut addr = addr.to_socket_addr(&world.dns);
40 let host = world.current_host_mut();
41
42 if !addr.ip().is_unspecified() && !addr.ip().is_loopback() {
43 return Err(Error::new(
44 ErrorKind::AddrNotAvailable,
45 format!("{addr} is not supported"),
46 ));
47 }
48
49 if addr.is_ipv4() != host.addr.is_ipv4() {
50 panic!("ip version mismatch: {:?} host: {:?}", addr, host.addr)
51 }
52
53 if addr.port() == 0 {
54 addr.set_port(host.assign_ephemeral_port());
55 }
56
57 host.tcp.bind(addr)
58 })
59 }
60
61 pub async fn accept(&self) -> Result<(TcpStream, SocketAddr)> {
67 let origin = loop {
68 let maybe_accept = World::current(|world| {
69 let host = world.current_host_mut();
70 host.tcp.accept(self.local_addr)
71 });
72
73 let Some((syn, origin)) = maybe_accept else {
74 self.notify.notified().await;
76 continue;
77 };
78
79 tracing::trace!(target: TRACING_TARGET, src = ?origin, dst = ?self.local_addr, protocol = %"TCP SYN", "Recv");
80
81 let ack = syn.ack.send(());
84 tracing::trace!(target: TRACING_TARGET, src = ?self.local_addr, dst = ?origin, protocol = %"TCP SYN-ACK", "Send");
85
86 if ack.is_ok() {
87 break origin;
88 }
89 };
90
91 let stream = World::current(|world| {
92 let host = world.current_host_mut();
93
94 let mut my_addr = self.local_addr;
95 if origin.ip().is_loopback() {
96 my_addr.set_ip(origin.ip());
97 }
98 if my_addr.ip().is_unspecified() {
99 my_addr.set_ip(host.addr);
100 }
101
102 let pair = SocketPair::new(my_addr, origin);
103 let rx = host.tcp.new_stream(pair);
104 TcpStream::new(pair, rx)
105 });
106
107 tracing::trace!(target: TRACING_TARGET, src = ?self.local_addr, dst = ?origin, "Accepted");
108 Ok((stream, origin))
109 }
110
111 pub fn local_addr(&self) -> Result<SocketAddr> {
113 Ok(self.local_addr)
114 }
115}
116
117impl Drop for TcpListener {
118 fn drop(&mut self) {
119 World::current_if_set(|world| world.current_host_mut().tcp.unbind(self.local_addr));
120 }
121}