turmoil/net/tcp/
stream.rs

1use bytes::{Buf, Bytes};
2use std::future::poll_fn;
3use std::{
4    fmt::Debug,
5    io::{self, Error, Result},
6    net::SocketAddr,
7    pin::Pin,
8    sync::Arc,
9    task::{ready, Context, Poll},
10};
11use tokio::{
12    io::{AsyncRead, AsyncWrite, ReadBuf},
13    runtime::Handle,
14    sync::{mpsc, oneshot},
15    time::sleep,
16};
17
18use crate::{
19    envelope::{Envelope, Protocol, Segment, Syn},
20    host::is_same,
21    host::SequencedSegment,
22    net::SocketPair,
23    world::World,
24    ToSocketAddrs, TRACING_TARGET,
25};
26
27use super::split_owned::{OwnedReadHalf, OwnedWriteHalf};
28
/// A simulated TCP stream between a local and a remote socket.
///
/// All methods must be called from a host within a Turmoil simulation.
///
/// The stream is stored as two halves so that it can be taken apart by
/// `into_split` and put back together by `reunite`.
#[derive(Debug)]
pub struct TcpStream {
    // Receiving side: owns the inbound segment channel and the EOF flag.
    read_half: ReadHalf,
    // Sending side: tracks whether a FIN has already been sent.
    write_half: WriteHalf,
}
37
38impl TcpStream {
39    pub(crate) fn new(pair: SocketPair, receiver: mpsc::Receiver<SequencedSegment>) -> Self {
40        let pair = Arc::new(pair);
41        let read_half = ReadHalf {
42            pair: pair.clone(),
43            rx: Rx {
44                recv: receiver,
45                buffer: None,
46            },
47            is_closed: false,
48        };
49
50        let write_half = WriteHalf {
51            pair,
52            is_shutdown: false,
53        };
54
55        Self {
56            read_half,
57            write_half,
58        }
59    }
60
61    /// Opens a TCP connection to a remote host.
62    pub async fn connect<A: ToSocketAddrs>(addr: A) -> Result<TcpStream> {
63        let (ack, syn_ack) = oneshot::channel();
64
65        let (pair, rx) = World::current(|world| {
66            let dst = addr.to_socket_addr(&world.dns);
67
68            let host = world.current_host_mut();
69            let mut local_addr = SocketAddr::new(host.addr, host.assign_ephemeral_port());
70            if dst.ip().is_loopback() {
71                local_addr.set_ip(dst.ip());
72            }
73
74            let pair = SocketPair::new(local_addr, dst);
75            let rx = host.tcp.new_stream(pair);
76
77            let syn = Protocol::Tcp(Segment::Syn(Syn { ack }));
78            if !is_same(local_addr, dst) {
79                world.send_message(local_addr, dst, syn)?;
80            } else {
81                send_loopback(local_addr, dst, syn);
82            };
83
84            Ok::<_, Error>((pair, rx))
85        })?;
86
87        syn_ack.await.map_err(|_| {
88            io::Error::new(io::ErrorKind::ConnectionRefused, pair.remote.to_string())
89        })?;
90
91        tracing::trace!(target: TRACING_TARGET, src = ?pair.remote, dst = ?pair.local, protocol = %"TCP SYN-ACK", "Recv");
92
93        Ok(TcpStream::new(pair, rx))
94    }
95
96    /// Try to write a buffer to the stream, returning how many bytes were
97    /// written.
98    ///
99    /// The function will attempt to write the entire contents of `buf`, but
100    /// only part of the buffer may be written.
101    ///
102    /// This function is usually paired with `writable()`.
103    ///
104    /// # Return
105    ///
106    /// If data is successfully written, `Ok(n)` is returned, where `n` is the
107    /// number of bytes written. If the stream is not ready to write data,
108    /// `Err(io::ErrorKind::WouldBlock)` is returned.
109    pub fn try_write(&self, buf: &[u8]) -> Result<usize> {
110        self.write_half.try_write(buf)
111    }
112
113    /// Returns the local address that this stream is bound to.
114    pub fn local_addr(&self) -> Result<SocketAddr> {
115        Ok(self.read_half.pair.local)
116    }
117
118    /// Returns the remote address that this stream is connected to.
119    pub fn peer_addr(&self) -> Result<SocketAddr> {
120        Ok(self.read_half.pair.remote)
121    }
122
123    pub(crate) fn reunite(read_half: ReadHalf, write_half: WriteHalf) -> Self {
124        Self {
125            read_half,
126            write_half,
127        }
128    }
129
130    /// Waits for the socket to become writable.
131    ///
132    /// This function is equivalent to `ready(Interest::WRITABLE)` and is usually
133    /// paired with `try_write()`.
134    ///
135    /// # Cancel safety
136    ///
137    /// This method is cancel safe. Once a readiness event occurs, the method
138    /// will continue to return immediately until the readiness event is
139    /// consumed by an attempt to write that fails with `WouldBlock` or
140    /// `Poll::Pending`.
141    pub async fn writable(&self) -> Result<()> {
142        Ok(())
143    }
144
145    /// Splits a `TcpStream` into a read half and a write half, which can be used
146    /// to read and write the stream concurrently.
147    ///
148    /// **Note:** Dropping the write half will shut down the write half of the TCP
149    /// stream. This is equivalent to calling [`shutdown()`] on the `TcpStream`.
150    ///
151    /// [`shutdown()`]: fn@tokio::io::AsyncWriteExt::shutdown
152    pub fn into_split(self) -> (OwnedReadHalf, OwnedWriteHalf) {
153        (
154            OwnedReadHalf {
155                inner: self.read_half,
156            },
157            OwnedWriteHalf {
158                inner: self.write_half,
159            },
160        )
161    }
162
163    /// Has no effect in turmoil. API parity with
164    /// https://docs.rs/tokio/latest/tokio/net/struct.TcpStream.html#method.set_nodelay
165    pub fn set_nodelay(&self, _nodelay: bool) -> Result<()> {
166        Ok(())
167    }
168
169    /// Receives data on the socket from the remote address to which it is
170    /// connected, without removing that data from the queue. On success,
171    /// returns the number of bytes peeked.
172    ///
173    /// Successive calls return the same data.
174    pub async fn peek(&mut self, buf: &mut [u8]) -> Result<usize> {
175        self.read_half.peek(buf).await
176    }
177
178    /// Attempts to receive data on the socket, without removing that data from
179    /// the queue, registering the current task for wakeup if data is not yet
180    /// available.
181    pub fn poll_peek(&mut self, cx: &mut Context<'_>, buf: &mut ReadBuf) -> Poll<Result<usize>> {
182        self.read_half.poll_peek(cx, buf)
183    }
184}
185
/// The receiving side of a simulated TCP stream.
pub(crate) struct ReadHalf {
    /// Local/remote socket addresses, shared with the write half.
    pub(crate) pair: Arc<SocketPair>,
    /// Inbound segment channel plus any bytes stashed from a previous read.
    rx: Rx,
    /// FIN received, EOF for reads
    is_closed: bool,
}
192
/// Receive-side state of a stream: the segment channel and leftover bytes.
struct Rx {
    /// Channel on which sequenced segments for this stream arrive.
    recv: mpsc::Receiver<SequencedSegment>,
    /// The remaining bytes of a received data segment.
    ///
    /// This is used to support read impls by stashing available bytes for
    /// subsequent reads.
    buffer: Option<Bytes>,
}
201
202impl ReadHalf {
203    fn poll_read_priv(&mut self, cx: &mut Context<'_>, buf: &mut ReadBuf) -> Poll<Result<()>> {
204        if self.is_closed || buf.capacity() == 0 {
205            return Poll::Ready(Ok(()));
206        }
207
208        if let Some(bytes) = self.rx.buffer.take() {
209            self.rx.buffer = Self::put_slice(bytes, buf);
210
211            return Poll::Ready(Ok(()));
212        }
213
214        match ready!(self.rx.recv.poll_recv(cx)) {
215            Some(seg) => {
216                tracing::trace!(target: TRACING_TARGET, src = ?self.pair.remote, dst = ?self.pair.local, protocol = %seg, "Recv");
217
218                match seg {
219                    SequencedSegment::Data(bytes) => {
220                        self.rx.buffer = Self::put_slice(bytes, buf);
221                    }
222                    SequencedSegment::Fin => {
223                        self.is_closed = true;
224                    }
225                }
226
227                Poll::Ready(Ok(()))
228            }
229            None => Poll::Ready(Err(io::Error::new(
230                io::ErrorKind::ConnectionReset,
231                "Connection reset",
232            ))),
233        }
234    }
235
236    /// Put bytes in `buf` based on the minimum of `avail` and its remaining
237    /// capacity.
238    ///
239    /// Returns an optional `Bytes` containing any remainder of `avail` that was
240    /// not consumed.
241    fn put_slice(mut avail: Bytes, buf: &mut ReadBuf) -> Option<Bytes> {
242        let amt = std::cmp::min(avail.len(), buf.remaining());
243
244        buf.put_slice(&avail[..amt]);
245        avail.advance(amt);
246
247        if avail.is_empty() {
248            None
249        } else {
250            Some(avail)
251        }
252    }
253
254    pub(crate) fn poll_peek(
255        &mut self,
256        cx: &mut Context<'_>,
257        buf: &mut ReadBuf,
258    ) -> Poll<Result<usize>> {
259        if self.is_closed || buf.capacity() == 0 {
260            return Poll::Ready(Ok(0));
261        }
262
263        // If we have buffered data, peek from it
264        if let Some(bytes) = &self.rx.buffer {
265            let len = std::cmp::min(bytes.len(), buf.remaining());
266            buf.put_slice(&bytes[..len]);
267            return Poll::Ready(Ok(len));
268        }
269
270        match ready!(self.rx.recv.poll_recv(cx)) {
271            Some(seg) => {
272                tracing::trace!(target: TRACING_TARGET, src = ?self.pair.remote, dst = ?self.pair.local, protocol = %seg, "Peek");
273
274                match seg {
275                    SequencedSegment::Data(bytes) => {
276                        let len = std::cmp::min(bytes.len(), buf.remaining());
277                        buf.put_slice(&bytes[..len]);
278                        self.rx.buffer = Some(bytes);
279
280                        Poll::Ready(Ok(len))
281                    }
282                    SequencedSegment::Fin => {
283                        self.is_closed = true;
284                        Poll::Ready(Ok(0))
285                    }
286                }
287            }
288            None => Poll::Ready(Err(io::Error::new(
289                io::ErrorKind::ConnectionReset,
290                "Connection reset",
291            ))),
292        }
293    }
294
295    pub(crate) async fn peek(&mut self, buf: &mut [u8]) -> Result<usize> {
296        let mut buf = ReadBuf::new(buf);
297        poll_fn(|cx| self.poll_peek(cx, &mut buf)).await
298    }
299}
300
301impl Debug for ReadHalf {
302    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
303        f.debug_struct("ReadHalf")
304            .field("pair", &self.pair)
305            .field("is_closed", &self.is_closed)
306            .finish()
307    }
308}
309
/// The sending side of a simulated TCP stream.
pub(crate) struct WriteHalf {
    /// Local/remote socket addresses, shared with the read half.
    pub(crate) pair: Arc<SocketPair>,
    /// FIN sent, closed for writes
    is_shutdown: bool,
}
315
316impl WriteHalf {
317    fn try_write(&self, buf: &[u8]) -> Result<usize> {
318        if buf.remaining() == 0 {
319            return Ok(0);
320        }
321
322        if self.is_shutdown {
323            return Err(io::Error::new(io::ErrorKind::BrokenPipe, "Broken pipe"));
324        }
325
326        World::current(|world| {
327            let bytes = Bytes::copy_from_slice(buf);
328            let len = bytes.len();
329
330            let seq = self.seq(world)?;
331            self.send(world, Segment::Data(seq, bytes))?;
332
333            Ok(len)
334        })
335    }
336
337    fn poll_write_priv(&self, _cx: &mut Context<'_>, buf: &[u8]) -> Poll<Result<usize>> {
338        Poll::Ready(self.try_write(buf))
339    }
340
341    fn poll_shutdown_priv(&mut self) -> Poll<Result<()>> {
342        if self.is_shutdown {
343            return Poll::Ready(Err(io::Error::new(
344                io::ErrorKind::NotConnected,
345                "Socket is not connected",
346            )));
347        }
348
349        let res = World::current(|world| {
350            let seq = self.seq(world)?;
351            self.send(world, Segment::Fin(seq))?;
352
353            self.is_shutdown = true;
354
355            Ok(())
356        });
357
358        Poll::Ready(res)
359    }
360
361    // If a seq is not assignable the connection has been reset by the
362    // peer.
363    fn seq(&self, world: &mut World) -> Result<u64> {
364        world
365            .current_host_mut()
366            .tcp
367            .assign_send_seq(*self.pair)
368            .ok_or_else(|| io::Error::new(io::ErrorKind::BrokenPipe, "Broken pipe"))
369    }
370
371    fn send(&self, world: &mut World, segment: Segment) -> Result<()> {
372        let message = Protocol::Tcp(segment);
373        if is_same(self.pair.local, self.pair.remote) {
374            send_loopback(self.pair.local, self.pair.remote, message);
375        } else {
376            world.send_message(self.pair.local, self.pair.remote, message)?;
377        }
378        Ok(())
379    }
380}
381
382fn send_loopback(src: SocketAddr, dst: SocketAddr, message: Protocol) {
383    // Check for a runtime before spawning as this code is hit in the drop path
384    // as streams attempt to send FINs.
385    // TODO: Investigate drop ordering within the Sim to ensure things are unrolling
386    // as expected.
387    if Handle::try_current().is_err() {
388        return;
389    }
390
391    tokio::spawn(async move {
392        // FIXME: Forces delivery on the next step which better aligns with the
393        // remote networking behavior.
394        // https://github.com/tokio-rs/turmoil/issues/132
395        let tick_duration = World::current(|world| world.tick_duration);
396        sleep(tick_duration).await;
397
398        World::current(|world| {
399            if let Err(rst) =
400                world
401                    .current_host_mut()
402                    .receive_from_network(Envelope { src, dst, message })
403            {
404                _ = world.current_host_mut().receive_from_network(Envelope {
405                    src: dst,
406                    dst: src,
407                    message: rst,
408                });
409            }
410        })
411    });
412}
413
414impl Debug for WriteHalf {
415    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
416        f.debug_struct("WriteHalf")
417            .field("pair", &self.pair)
418            .field("is_shutdown", &self.is_shutdown)
419            .finish()
420    }
421}
422
423impl AsyncRead for ReadHalf {
424    fn poll_read(
425        mut self: Pin<&mut Self>,
426        cx: &mut Context<'_>,
427        buf: &mut ReadBuf,
428    ) -> Poll<Result<()>> {
429        self.poll_read_priv(cx, buf)
430    }
431}
432
433impl AsyncRead for TcpStream {
434    fn poll_read(
435        mut self: Pin<&mut Self>,
436        cx: &mut Context<'_>,
437        buf: &mut ReadBuf,
438    ) -> Poll<Result<()>> {
439        Pin::new(&mut self.read_half).poll_read(cx, buf)
440    }
441}
442
443impl AsyncWrite for WriteHalf {
444    fn poll_write(self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &[u8]) -> Poll<Result<usize>> {
445        self.poll_write_priv(cx, buf)
446    }
447
448    fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Result<()>> {
449        Poll::Ready(Ok(()))
450    }
451
452    fn poll_shutdown(mut self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Result<()>> {
453        self.poll_shutdown_priv()
454    }
455}
456
457impl AsyncWrite for TcpStream {
458    fn poll_write(
459        mut self: Pin<&mut Self>,
460        cx: &mut Context<'_>,
461        buf: &[u8],
462    ) -> Poll<Result<usize>> {
463        Pin::new(&mut self.write_half).poll_write(cx, buf)
464    }
465
466    fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<()>> {
467        Pin::new(&mut self.write_half).poll_flush(cx)
468    }
469
470    fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<()>> {
471        Pin::new(&mut self.write_half).poll_shutdown(cx)
472    }
473}
474
475impl Drop for ReadHalf {
476    fn drop(&mut self) {
477        World::current_if_set(|world| {
478            world.current_host_mut().tcp.close_stream_half(*self.pair);
479        })
480    }
481}
482
483impl Drop for WriteHalf {
484    fn drop(&mut self) {
485        World::current_if_set(|world| {
486            // skip sending Fin if the write half is already shutdown
487            if !self.is_shutdown {
488                if let Ok(seq) = self.seq(world) {
489                    let _ = self.send(world, Segment::Fin(seq));
490                }
491            }
492            world.current_host_mut().tcp.close_stream_half(*self.pair);
493        })
494    }
495}