Skip to main content

turmoil/net/tcp/
stream.rs

1use bytes::{Buf, Bytes};
2use std::future::poll_fn;
3use std::sync::atomic::{AtomicUsize, Ordering};
4use std::sync::Mutex;
5use std::task::Waker;
6use std::{
7    fmt::Debug,
8    io::{self, Error, Result},
9    net::SocketAddr,
10    pin::Pin,
11    sync::Arc,
12    task::{ready, Context, Poll},
13};
14use tokio::{
15    io::{AsyncRead, AsyncWrite, ReadBuf},
16    runtime::Handle,
17    sync::{mpsc, oneshot},
18    time::sleep,
19};
20
21use crate::{
22    envelope::{Envelope, Protocol, Segment, Syn},
23    host::{is_same, SequencedSegment},
24    net::SocketPair,
25    world::World,
26    ToSocketAddrs, TRACING_TARGET,
27};
28
29use super::split_owned::{OwnedReadHalf, OwnedWriteHalf};
30
31/// A simulated TCP stream between a local and a remote socket.
32///
33/// All methods must be called from a host within a Turmoil simulation.
34#[derive(Debug)]
35pub struct TcpStream {
36    read_half: ReadHalf,
37    write_half: WriteHalf,
38}
39
40impl TcpStream {
41    pub(crate) fn new(
42        pair: SocketPair,
43        receiver: mpsc::Receiver<SequencedSegment>,
44        flow_control: BidiFlowControl,
45    ) -> Self {
46        let pair = Arc::new(pair);
47        let read_half = ReadHalf {
48            pair: pair.clone(),
49            rx: Rx {
50                recv: receiver,
51                buffer: None,
52            },
53            is_closed: false,
54            flow_control: flow_control.read,
55        };
56
57        let write_half = WriteHalf {
58            pair,
59            is_shutdown: false,
60            flow_control: flow_control.write,
61        };
62
63        Self {
64            read_half,
65            write_half,
66        }
67    }
68
69    /// Opens a TCP connection to a remote host.
70    pub async fn connect<A: ToSocketAddrs>(addr: A) -> Result<TcpStream> {
71        let (ack, syn_ack) = oneshot::channel();
72
73        let (pair, rx, bidi) = World::current(|world| {
74            let dst = addr.to_socket_addr(&world.dns)?;
75
76            let (pair, rx, bidi) = {
77                let host = world.current_host_mut();
78                let mut local_addr = SocketAddr::new(host.addr, host.assign_ephemeral_port());
79                if dst.ip().is_loopback() {
80                    local_addr.set_ip(dst.ip());
81                }
82
83                let pair = SocketPair::new(local_addr, dst);
84                let (rx, bidi) = host.tcp.new_stream(pair);
85                (pair, rx, bidi)
86            };
87
88            let syn = Protocol::Tcp(Segment::Syn(Syn { ack }));
89            if !is_same(pair.local, pair.remote) {
90                world.send_message(pair.local, pair.remote, syn)?;
91            } else {
92                send_loopback(pair.local, pair.remote, syn);
93            };
94
95            Ok::<_, Error>((pair, rx, bidi))
96        })?;
97
98        syn_ack.await.map_err(|_| {
99            io::Error::new(io::ErrorKind::ConnectionRefused, pair.remote.to_string())
100        })?;
101
102        tracing::trace!(target: TRACING_TARGET, src = ?pair.remote, dst = ?pair.local, protocol = %"TCP SYN-ACK", "Recv");
103
104        Ok(TcpStream::new(pair, rx, bidi))
105    }
106
107    /// Try to write a buffer to the stream, returning how many bytes were
108    /// written.
109    ///
110    /// The function will attempt to write the entire contents of `buf`, but
111    /// only part of the buffer may be written.
112    ///
113    /// This function is usually paired with `writable()`.
114    ///
115    /// # Return
116    ///
117    /// If data is successfully written, `Ok(n)` is returned, where `n` is the
118    /// number of bytes written. If the stream is not ready to write data,
119    /// `Err(io::ErrorKind::WouldBlock)` is returned.
120    pub fn try_write(&self, buf: &[u8]) -> Result<usize> {
121        self.write_half.try_write(buf)
122    }
123
124    /// Returns the local address that this stream is bound to.
125    pub fn local_addr(&self) -> Result<SocketAddr> {
126        Ok(self.read_half.pair.local)
127    }
128
129    /// Returns the remote address that this stream is connected to.
130    pub fn peer_addr(&self) -> Result<SocketAddr> {
131        Ok(self.read_half.pair.remote)
132    }
133
134    pub(crate) fn reunite(read_half: ReadHalf, write_half: WriteHalf) -> Self {
135        Self {
136            read_half,
137            write_half,
138        }
139    }
140
141    /// Waits for the socket to become writable.
142    ///
143    /// This function is equivalent to `ready(Interest::WRITABLE)` and is usually
144    /// paired with `try_write()`.
145    ///
146    /// # Cancel safety
147    ///
148    /// This method is cancel safe. Once a readiness event occurs, the method
149    /// will continue to return immediately until the readiness event is
150    /// consumed by an attempt to write that fails with `WouldBlock` or
151    /// `Poll::Pending`.
152    pub async fn writable(&self) -> Result<()> {
153        poll_fn(|cx| self.write_half.poll_writable(cx)).await
154    }
155
156    /// Splits a `TcpStream` into a read half and a write half, which can be used
157    /// to read and write the stream concurrently.
158    ///
159    /// **Note:** Dropping the write half will shut down the write half of the TCP
160    /// stream. This is equivalent to calling [`shutdown()`] on the `TcpStream`.
161    ///
162    /// [`shutdown()`]: fn@tokio::io::AsyncWriteExt::shutdown
163    pub fn into_split(self) -> (OwnedReadHalf, OwnedWriteHalf) {
164        (
165            OwnedReadHalf {
166                inner: self.read_half,
167            },
168            OwnedWriteHalf {
169                inner: self.write_half,
170            },
171        )
172    }
173
174    /// Has no effect in turmoil. API parity with
175    /// https://docs.rs/tokio/latest/tokio/net/struct.TcpStream.html#method.set_nodelay
176    pub fn set_nodelay(&self, _nodelay: bool) -> Result<()> {
177        Ok(())
178    }
179
180    /// Receives data on the socket from the remote address to which it is
181    /// connected, without removing that data from the queue. On success,
182    /// returns the number of bytes peeked.
183    ///
184    /// Successive calls return the same data.
185    pub async fn peek(&mut self, buf: &mut [u8]) -> Result<usize> {
186        self.read_half.peek(buf).await
187    }
188
189    /// Attempts to receive data on the socket, without removing that data from
190    /// the queue, registering the current task for wakeup if data is not yet
191    /// available.
192    pub fn poll_peek(&mut self, cx: &mut Context<'_>, buf: &mut ReadBuf) -> Poll<Result<usize>> {
193        self.read_half.poll_peek(cx, buf)
194    }
195}
196
197pub(crate) struct ReadHalf {
198    pub(crate) pair: Arc<SocketPair>,
199    rx: Rx,
200    /// FIN received, EOF for reads
201    is_closed: bool,
202    flow_control: Arc<FlowControl>,
203}
204
205struct Rx {
206    recv: mpsc::Receiver<SequencedSegment>,
207    /// The remaining bytes of a received data segment.
208    ///
209    /// This is used to support read impls by stashing available bytes for
210    /// subsequent reads.
211    buffer: Option<Bytes>,
212}
213
214impl ReadHalf {
215    fn poll_read_priv(&mut self, cx: &mut Context<'_>, buf: &mut ReadBuf) -> Poll<Result<()>> {
216        if self.is_closed || buf.capacity() == 0 {
217            return Poll::Ready(Ok(()));
218        }
219
220        if let Some(bytes) = self.rx.buffer.take() {
221            self.rx.buffer = Self::put_slice(bytes, buf);
222
223            return Poll::Ready(Ok(()));
224        }
225
226        match ready!(self.rx.recv.poll_recv(cx)) {
227            Some(seg) => {
228                tracing::trace!(target: TRACING_TARGET, src = ?self.pair.remote, dst = ?self.pair.local, protocol = %seg, "Recv");
229
230                match seg {
231                    SequencedSegment::Data(bytes) => {
232                        self.flow_control.release();
233                        self.rx.buffer = Self::put_slice(bytes, buf);
234                    }
235                    SequencedSegment::Fin => {
236                        self.is_closed = true;
237                    }
238                }
239
240                Poll::Ready(Ok(()))
241            }
242            None => Poll::Ready(Err(io::Error::new(
243                io::ErrorKind::ConnectionReset,
244                "Connection reset",
245            ))),
246        }
247    }
248
249    /// Put bytes in `buf` based on the minimum of `avail` and its remaining
250    /// capacity.
251    ///
252    /// Returns an optional `Bytes` containing any remainder of `avail` that was
253    /// not consumed.
254    fn put_slice(mut avail: Bytes, buf: &mut ReadBuf) -> Option<Bytes> {
255        let amt = std::cmp::min(avail.len(), buf.remaining());
256
257        buf.put_slice(&avail[..amt]);
258        avail.advance(amt);
259
260        if avail.is_empty() {
261            None
262        } else {
263            Some(avail)
264        }
265    }
266
267    pub(crate) fn poll_peek(
268        &mut self,
269        cx: &mut Context<'_>,
270        buf: &mut ReadBuf,
271    ) -> Poll<Result<usize>> {
272        if self.is_closed || buf.capacity() == 0 {
273            return Poll::Ready(Ok(0));
274        }
275
276        // If we have buffered data, peek from it
277        if let Some(bytes) = &self.rx.buffer {
278            let len = std::cmp::min(bytes.len(), buf.remaining());
279            buf.put_slice(&bytes[..len]);
280            return Poll::Ready(Ok(len));
281        }
282
283        match ready!(self.rx.recv.poll_recv(cx)) {
284            Some(seg) => {
285                tracing::trace!(target: TRACING_TARGET, src = ?self.pair.remote, dst = ?self.pair.local, protocol = %seg, "Peek");
286
287                match seg {
288                    SequencedSegment::Data(bytes) => {
289                        self.flow_control.release();
290                        let len = std::cmp::min(bytes.len(), buf.remaining());
291                        buf.put_slice(&bytes[..len]);
292                        self.rx.buffer = Some(bytes);
293
294                        Poll::Ready(Ok(len))
295                    }
296                    SequencedSegment::Fin => {
297                        self.is_closed = true;
298                        Poll::Ready(Ok(0))
299                    }
300                }
301            }
302            None => Poll::Ready(Err(io::Error::new(
303                io::ErrorKind::ConnectionReset,
304                "Connection reset",
305            ))),
306        }
307    }
308
309    pub(crate) async fn peek(&mut self, buf: &mut [u8]) -> Result<usize> {
310        let mut buf = ReadBuf::new(buf);
311        poll_fn(|cx| self.poll_peek(cx, &mut buf)).await
312    }
313}
314
315impl Debug for ReadHalf {
316    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
317        f.debug_struct("ReadHalf")
318            .field("pair", &self.pair)
319            .field("is_closed", &self.is_closed)
320            .finish()
321    }
322}
323
324pub(crate) struct WriteHalf {
325    pub(crate) pair: Arc<SocketPair>,
326    /// FIN sent, closed for writes
327    is_shutdown: bool,
328    flow_control: Arc<FlowControl>,
329}
330
331impl WriteHalf {
332    fn try_write(&self, buf: &[u8]) -> Result<usize> {
333        if buf.remaining() == 0 {
334            return Ok(0);
335        }
336
337        if self.is_shutdown {
338            return Err(io::Error::new(io::ErrorKind::BrokenPipe, "Broken pipe"));
339        }
340
341        if !self.flow_control.try_acquire() {
342            return Err(io::Error::new(
343                io::ErrorKind::WouldBlock,
344                "send buffer full",
345            ));
346        }
347
348        World::current(|world| {
349            let bytes = Bytes::copy_from_slice(buf);
350            let len = bytes.len();
351            let seq = self.seq(world)?;
352            self.send(world, Segment::Data(seq, bytes))?;
353            Ok(len)
354        })
355    }
356
357    fn poll_writable(&self, cx: &mut Context<'_>) -> Poll<Result<()>> {
358        if self.is_shutdown {
359            return Poll::Ready(Err(io::Error::new(
360                io::ErrorKind::BrokenPipe,
361                "Broken pipe",
362            )));
363        }
364        if self.flow_control.has_credits() {
365            return Poll::Ready(Ok(()));
366        }
367        self.flow_control.register_waker(cx.waker().clone());
368        Poll::Pending
369    }
370
371    fn poll_write_priv(&self, cx: &mut Context<'_>, buf: &[u8]) -> Poll<Result<usize>> {
372        if self.is_shutdown {
373            return Poll::Ready(Err(io::ErrorKind::BrokenPipe.into()));
374        }
375
376        match self.try_write(buf) {
377            // If the socket is full, behave like non-blocking port, and register a waker
378            // for the socket to become writable again.
379            Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => {
380                self.flow_control.register_waker(cx.waker().clone());
381                Poll::Pending
382            }
383            result => Poll::Ready(result),
384        }
385    }
386
387    fn poll_shutdown_priv(&mut self) -> Poll<Result<()>> {
388        if self.is_shutdown {
389            return Poll::Ready(Err(io::Error::new(
390                io::ErrorKind::NotConnected,
391                "Socket is not connected",
392            )));
393        }
394
395        let res = World::current(|world| {
396            let seq = self.seq(world)?;
397            self.send(world, Segment::Fin(seq))?;
398
399            self.is_shutdown = true;
400
401            Ok(())
402        });
403
404        Poll::Ready(res)
405    }
406
407    // If a seq is not assignable the connection has been reset by the
408    // peer.
409    fn seq(&self, world: &mut World) -> Result<u64> {
410        world
411            .current_host_mut()
412            .tcp
413            .assign_send_seq(*self.pair)
414            .ok_or_else(|| io::Error::new(io::ErrorKind::BrokenPipe, "Broken pipe"))
415    }
416
417    fn send(&self, world: &mut World, segment: Segment) -> Result<()> {
418        let message = Protocol::Tcp(segment);
419        if is_same(self.pair.local, self.pair.remote) {
420            send_loopback(self.pair.local, self.pair.remote, message);
421        } else {
422            world.send_message(self.pair.local, self.pair.remote, message)?;
423        }
424        Ok(())
425    }
426}
427
428fn send_loopback(src: SocketAddr, dst: SocketAddr, message: Protocol) {
429    // Check for a runtime before spawning as this code is hit in the drop path
430    // as streams attempt to send FINs.
431    // TODO: Investigate drop ordering within the Sim to ensure things are unrolling
432    // as expected.
433    if Handle::try_current().is_err() {
434        return;
435    }
436
437    tokio::spawn(async move {
438        // FIXME: Forces delivery on the next step which better aligns with the
439        // remote networking behavior.
440        // https://github.com/tokio-rs/turmoil/issues/132
441        let tick_duration = World::current(|world| world.tick_duration);
442        sleep(tick_duration).await;
443
444        World::current(|world| {
445            if let Err(rst) =
446                world
447                    .current_host_mut()
448                    .receive_from_network(Envelope { src, dst, message })
449            {
450                _ = world.current_host_mut().receive_from_network(Envelope {
451                    src: dst,
452                    dst: src,
453                    message: rst,
454                });
455            }
456        })
457    });
458}
459
460/// Bidirectional flow control for a TCP connection.
461///
462/// Wraps both directional `FlowControl` instances so they can be
463/// managed as a single unit and inverted for the accepting side.
464#[derive(Clone, Debug)]
465pub(crate) struct BidiFlowControl {
466    write: Arc<FlowControl>,
467    read: Arc<FlowControl>,
468}
469
470impl BidiFlowControl {
471    pub(crate) fn new(capacity: usize) -> Self {
472        Self {
473            write: Arc::new(FlowControl::new(capacity)),
474            read: Arc::new(FlowControl::new(capacity)),
475        }
476    }
477
478    pub(crate) fn invert(self) -> Self {
479        Self {
480            write: self.read,
481            read: self.write,
482        }
483    }
484}
485
/// End-to-end flow control for a single TCP stream direction.
///
/// Shared between the sender's `WriteHalf` and receiver's `ReadHalf` via
/// `Arc`. Credits are initialized to the mpsc channel capacity so the
/// invariant `credits + segments_in_flight = capacity` holds.
pub(crate) struct FlowControl {
    /// Segments the writer may still send before it has to wait.
    credits: AtomicUsize,
    /// Task to wake on the next `release`. Only one waker is kept.
    waker: Mutex<Option<Waker>>,
}

impl FlowControl {
    /// Creates a pool holding `capacity` credits and no registered waker.
    fn new(capacity: usize) -> Self {
        Self {
            credits: AtomicUsize::new(capacity),
            waker: Mutex::new(None),
        }
    }

    /// Takes one credit. Returns `false`, without taking anything, if none
    /// remain.
    fn try_acquire(&self) -> bool {
        self.credits
            .fetch_update(Ordering::AcqRel, Ordering::Acquire, |v| v.checked_sub(1))
            .is_ok()
    }

    /// Returns one credit and wakes the registered waker, if any.
    ///
    /// The mutex guard is dropped before `wake()` runs. A waker that
    /// synchronously touches this `FlowControl` (for example by
    /// re-registering) therefore cannot deadlock on the lock, and a panic
    /// inside `wake()` cannot poison it.
    fn release(&self) {
        self.credits.fetch_add(1, Ordering::Release);
        // A `let` statement drops the temporary guard at its end. In an
        // `if let` scrutinee the guard would stay alive through the block,
        // i.e. across `wake()`.
        let waker = self.waker.lock().unwrap().take();
        if let Some(waker) = waker {
            waker.wake();
        }
    }

    /// Stores `waker` to be woken by the next `release`, replacing any
    /// previously registered one.
    fn register_waker(&self, waker: Waker) {
        *self.waker.lock().unwrap() = Some(waker);
    }

    /// Reports whether at least one credit is currently available.
    fn has_credits(&self) -> bool {
        self.credits.load(Ordering::Acquire) > 0
    }
}

impl Debug for FlowControl {
    /// Shows the current credit count; the waker slot is omitted.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("FlowControl")
            .field("credits", &self.credits.load(Ordering::Relaxed))
            .finish()
    }
}
533
534impl Debug for WriteHalf {
535    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
536        f.debug_struct("WriteHalf")
537            .field("pair", &self.pair)
538            .field("is_shutdown", &self.is_shutdown)
539            .finish()
540    }
541}
542
543impl AsyncRead for ReadHalf {
544    fn poll_read(
545        mut self: Pin<&mut Self>,
546        cx: &mut Context<'_>,
547        buf: &mut ReadBuf,
548    ) -> Poll<Result<()>> {
549        self.poll_read_priv(cx, buf)
550    }
551}
552
553impl AsyncRead for TcpStream {
554    fn poll_read(
555        mut self: Pin<&mut Self>,
556        cx: &mut Context<'_>,
557        buf: &mut ReadBuf,
558    ) -> Poll<Result<()>> {
559        Pin::new(&mut self.read_half).poll_read(cx, buf)
560    }
561}
562
563impl AsyncWrite for WriteHalf {
564    fn poll_write(self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &[u8]) -> Poll<Result<usize>> {
565        self.poll_write_priv(cx, buf)
566    }
567
568    fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Result<()>> {
569        Poll::Ready(Ok(()))
570    }
571
572    fn poll_shutdown(mut self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Result<()>> {
573        self.poll_shutdown_priv()
574    }
575}
576
577impl AsyncWrite for TcpStream {
578    fn poll_write(
579        mut self: Pin<&mut Self>,
580        cx: &mut Context<'_>,
581        buf: &[u8],
582    ) -> Poll<Result<usize>> {
583        Pin::new(&mut self.write_half).poll_write(cx, buf)
584    }
585
586    fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<()>> {
587        Pin::new(&mut self.write_half).poll_flush(cx)
588    }
589
590    fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<()>> {
591        Pin::new(&mut self.write_half).poll_shutdown(cx)
592    }
593}
594
595impl Drop for ReadHalf {
596    fn drop(&mut self) {
597        World::current_if_set(|world| {
598            // RFC 9293 §3.10.4: closing with unread data MUST send a RST so
599            // the peer learns data was lost. "Unread data" = application
600            // bytes received but not consumed: partial segment bytes stashed
601            // in the rx buffer, a Data segment still queued on the mpsc, or
602            // a Data segment parked in the host's reorder buffer. A queued
603            // FIN is not a reset condition — a graceful close followed by a
604            // drop should stay graceful.
605            let has_unread = !self.is_closed
606                && (self.rx.buffer.is_some()
607                    || matches!(self.rx.recv.try_recv(), Ok(SequencedSegment::Data(_)))
608                    || world.current_host_mut().tcp.has_buffered_data(*self.pair));
609
610            if has_unread {
611                let pair = *self.pair;
612                let message = Protocol::Tcp(Segment::Rst);
613                if is_same(pair.local, pair.remote) {
614                    send_loopback(pair.local, pair.remote, message);
615                } else {
616                    let _ = world.send_message(pair.local, pair.remote, message);
617                }
618                // Tear down locally so the sibling WriteHalf drop does not
619                // also send a FIN — seq() will return Err after the socket
620                // is gone.
621                world.current_host_mut().tcp.reset_stream(pair);
622                return;
623            }
624
625            world.current_host_mut().tcp.close_stream_half(*self.pair);
626        })
627    }
628}
629
630impl Drop for WriteHalf {
631    fn drop(&mut self) {
632        World::current_if_set(|world| {
633            // skip sending Fin if the write half is already shutdown
634            if !self.is_shutdown {
635                if let Ok(seq) = self.seq(world) {
636                    let _ = self.send(world, Segment::Fin(seq));
637                }
638            }
639            world.current_host_mut().tcp.close_stream_half(*self.pair);
640        })
641    }
642}