1use bytes::{Buf, Bytes};
2use std::future::poll_fn;
3use std::sync::atomic::{AtomicUsize, Ordering};
4use std::sync::Mutex;
5use std::task::Waker;
6use std::{
7 fmt::Debug,
8 io::{self, Error, Result},
9 net::SocketAddr,
10 pin::Pin,
11 sync::Arc,
12 task::{ready, Context, Poll},
13};
14use tokio::{
15 io::{AsyncRead, AsyncWrite, ReadBuf},
16 runtime::Handle,
17 sync::{mpsc, oneshot},
18 time::sleep,
19};
20
21use crate::{
22 envelope::{Envelope, Protocol, Segment, Syn},
23 host::{is_same, SequencedSegment},
24 net::SocketPair,
25 world::World,
26 ToSocketAddrs, TRACING_TARGET,
27};
28
29use super::split_owned::{OwnedReadHalf, OwnedWriteHalf};
30
/// A TCP stream whose traffic is carried through the crate's [`World`]
/// rather than through an operating-system socket.
///
/// The stream is composed of a read half and a write half that share the same
/// socket pair; [`TcpStream::into_split`] hands the two halves out separately.
#[derive(Debug)]
pub struct TcpStream {
    /// Receiving side: incoming segments and the buffered remainder of a read.
    read_half: ReadHalf,
    /// Sending side: emits data and FIN segments for this socket pair.
    write_half: WriteHalf,
}
39
40impl TcpStream {
41 pub(crate) fn new(
42 pair: SocketPair,
43 receiver: mpsc::Receiver<SequencedSegment>,
44 flow_control: BidiFlowControl,
45 ) -> Self {
46 let pair = Arc::new(pair);
47 let read_half = ReadHalf {
48 pair: pair.clone(),
49 rx: Rx {
50 recv: receiver,
51 buffer: None,
52 },
53 is_closed: false,
54 flow_control: flow_control.read,
55 };
56
57 let write_half = WriteHalf {
58 pair,
59 is_shutdown: false,
60 flow_control: flow_control.write,
61 };
62
63 Self {
64 read_half,
65 write_half,
66 }
67 }
68
69 pub async fn connect<A: ToSocketAddrs>(addr: A) -> Result<TcpStream> {
71 let (ack, syn_ack) = oneshot::channel();
72
73 let (pair, rx, bidi) = World::current(|world| {
74 let dst = addr.to_socket_addr(&world.dns)?;
75
76 let (pair, rx, bidi) = {
77 let host = world.current_host_mut();
78 let mut local_addr = SocketAddr::new(host.addr, host.assign_ephemeral_port());
79 if dst.ip().is_loopback() {
80 local_addr.set_ip(dst.ip());
81 }
82
83 let pair = SocketPair::new(local_addr, dst);
84 let (rx, bidi) = host.tcp.new_stream(pair);
85 (pair, rx, bidi)
86 };
87
88 let syn = Protocol::Tcp(Segment::Syn(Syn { ack }));
89 if !is_same(pair.local, pair.remote) {
90 world.send_message(pair.local, pair.remote, syn)?;
91 } else {
92 send_loopback(pair.local, pair.remote, syn);
93 };
94
95 Ok::<_, Error>((pair, rx, bidi))
96 })?;
97
98 syn_ack.await.map_err(|_| {
99 io::Error::new(io::ErrorKind::ConnectionRefused, pair.remote.to_string())
100 })?;
101
102 tracing::trace!(target: TRACING_TARGET, src = ?pair.remote, dst = ?pair.local, protocol = %"TCP SYN-ACK", "Recv");
103
104 Ok(TcpStream::new(pair, rx, bidi))
105 }
106
107 pub fn try_write(&self, buf: &[u8]) -> Result<usize> {
121 self.write_half.try_write(buf)
122 }
123
124 pub fn local_addr(&self) -> Result<SocketAddr> {
126 Ok(self.read_half.pair.local)
127 }
128
129 pub fn peer_addr(&self) -> Result<SocketAddr> {
131 Ok(self.read_half.pair.remote)
132 }
133
134 pub(crate) fn reunite(read_half: ReadHalf, write_half: WriteHalf) -> Self {
135 Self {
136 read_half,
137 write_half,
138 }
139 }
140
141 pub async fn writable(&self) -> Result<()> {
153 poll_fn(|cx| self.write_half.poll_writable(cx)).await
154 }
155
156 pub fn into_split(self) -> (OwnedReadHalf, OwnedWriteHalf) {
164 (
165 OwnedReadHalf {
166 inner: self.read_half,
167 },
168 OwnedWriteHalf {
169 inner: self.write_half,
170 },
171 )
172 }
173
174 pub fn set_nodelay(&self, _nodelay: bool) -> Result<()> {
177 Ok(())
178 }
179
180 pub async fn peek(&mut self, buf: &mut [u8]) -> Result<usize> {
186 self.read_half.peek(buf).await
187 }
188
189 pub fn poll_peek(&mut self, cx: &mut Context<'_>, buf: &mut ReadBuf) -> Poll<Result<usize>> {
193 self.read_half.poll_peek(cx, buf)
194 }
195}
196
/// Receiving side of a [`TcpStream`].
pub(crate) struct ReadHalf {
    /// Local/remote address pair this half belongs to.
    pub(crate) pair: Arc<SocketPair>,
    /// Incoming segment channel plus the unread remainder of the last segment.
    rx: Rx,
    /// Set once a FIN segment has been received; later reads report end of stream.
    is_closed: bool,
    /// Credit pool that gets one credit back for every data segment taken off
    /// the channel.
    flow_control: Arc<FlowControl>,
}
204
/// Receive-side state of a [`ReadHalf`].
struct Rx {
    /// Channel on which the segments for this stream arrive.
    recv: mpsc::Receiver<SequencedSegment>,
    /// Bytes of an already-received data segment that the reader has not
    /// consumed yet (left over from a short read, or stored by a peek).
    buffer: Option<Bytes>,
}
213
214impl ReadHalf {
215 fn poll_read_priv(&mut self, cx: &mut Context<'_>, buf: &mut ReadBuf) -> Poll<Result<()>> {
216 if self.is_closed || buf.capacity() == 0 {
217 return Poll::Ready(Ok(()));
218 }
219
220 if let Some(bytes) = self.rx.buffer.take() {
221 self.rx.buffer = Self::put_slice(bytes, buf);
222
223 return Poll::Ready(Ok(()));
224 }
225
226 match ready!(self.rx.recv.poll_recv(cx)) {
227 Some(seg) => {
228 tracing::trace!(target: TRACING_TARGET, src = ?self.pair.remote, dst = ?self.pair.local, protocol = %seg, "Recv");
229
230 match seg {
231 SequencedSegment::Data(bytes) => {
232 self.flow_control.release();
233 self.rx.buffer = Self::put_slice(bytes, buf);
234 }
235 SequencedSegment::Fin => {
236 self.is_closed = true;
237 }
238 }
239
240 Poll::Ready(Ok(()))
241 }
242 None => Poll::Ready(Err(io::Error::new(
243 io::ErrorKind::ConnectionReset,
244 "Connection reset",
245 ))),
246 }
247 }
248
249 fn put_slice(mut avail: Bytes, buf: &mut ReadBuf) -> Option<Bytes> {
255 let amt = std::cmp::min(avail.len(), buf.remaining());
256
257 buf.put_slice(&avail[..amt]);
258 avail.advance(amt);
259
260 if avail.is_empty() {
261 None
262 } else {
263 Some(avail)
264 }
265 }
266
267 pub(crate) fn poll_peek(
268 &mut self,
269 cx: &mut Context<'_>,
270 buf: &mut ReadBuf,
271 ) -> Poll<Result<usize>> {
272 if self.is_closed || buf.capacity() == 0 {
273 return Poll::Ready(Ok(0));
274 }
275
276 if let Some(bytes) = &self.rx.buffer {
278 let len = std::cmp::min(bytes.len(), buf.remaining());
279 buf.put_slice(&bytes[..len]);
280 return Poll::Ready(Ok(len));
281 }
282
283 match ready!(self.rx.recv.poll_recv(cx)) {
284 Some(seg) => {
285 tracing::trace!(target: TRACING_TARGET, src = ?self.pair.remote, dst = ?self.pair.local, protocol = %seg, "Peek");
286
287 match seg {
288 SequencedSegment::Data(bytes) => {
289 self.flow_control.release();
290 let len = std::cmp::min(bytes.len(), buf.remaining());
291 buf.put_slice(&bytes[..len]);
292 self.rx.buffer = Some(bytes);
293
294 Poll::Ready(Ok(len))
295 }
296 SequencedSegment::Fin => {
297 self.is_closed = true;
298 Poll::Ready(Ok(0))
299 }
300 }
301 }
302 None => Poll::Ready(Err(io::Error::new(
303 io::ErrorKind::ConnectionReset,
304 "Connection reset",
305 ))),
306 }
307 }
308
309 pub(crate) async fn peek(&mut self, buf: &mut [u8]) -> Result<usize> {
310 let mut buf = ReadBuf::new(buf);
311 poll_fn(|cx| self.poll_peek(cx, &mut buf)).await
312 }
313}
314
315impl Debug for ReadHalf {
316 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
317 f.debug_struct("ReadHalf")
318 .field("pair", &self.pair)
319 .field("is_closed", &self.is_closed)
320 .finish()
321 }
322}
323
/// Sending side of a [`TcpStream`].
pub(crate) struct WriteHalf {
    /// Local/remote address pair this half belongs to.
    pub(crate) pair: Arc<SocketPair>,
    /// Set once a FIN has been sent through an explicit shutdown; further
    /// writes fail with `BrokenPipe`.
    is_shutdown: bool,
    /// Credit pool from which one credit is taken for every data segment sent.
    flow_control: Arc<FlowControl>,
}
330
331impl WriteHalf {
332 fn try_write(&self, buf: &[u8]) -> Result<usize> {
333 if buf.remaining() == 0 {
334 return Ok(0);
335 }
336
337 if self.is_shutdown {
338 return Err(io::Error::new(io::ErrorKind::BrokenPipe, "Broken pipe"));
339 }
340
341 if !self.flow_control.try_acquire() {
342 return Err(io::Error::new(
343 io::ErrorKind::WouldBlock,
344 "send buffer full",
345 ));
346 }
347
348 World::current(|world| {
349 let bytes = Bytes::copy_from_slice(buf);
350 let len = bytes.len();
351 let seq = self.seq(world)?;
352 self.send(world, Segment::Data(seq, bytes))?;
353 Ok(len)
354 })
355 }
356
357 fn poll_writable(&self, cx: &mut Context<'_>) -> Poll<Result<()>> {
358 if self.is_shutdown {
359 return Poll::Ready(Err(io::Error::new(
360 io::ErrorKind::BrokenPipe,
361 "Broken pipe",
362 )));
363 }
364 if self.flow_control.has_credits() {
365 return Poll::Ready(Ok(()));
366 }
367 self.flow_control.register_waker(cx.waker().clone());
368 Poll::Pending
369 }
370
371 fn poll_write_priv(&self, cx: &mut Context<'_>, buf: &[u8]) -> Poll<Result<usize>> {
372 if self.is_shutdown {
373 return Poll::Ready(Err(io::ErrorKind::BrokenPipe.into()));
374 }
375
376 match self.try_write(buf) {
377 Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => {
380 self.flow_control.register_waker(cx.waker().clone());
381 Poll::Pending
382 }
383 result => Poll::Ready(result),
384 }
385 }
386
387 fn poll_shutdown_priv(&mut self) -> Poll<Result<()>> {
388 if self.is_shutdown {
389 return Poll::Ready(Err(io::Error::new(
390 io::ErrorKind::NotConnected,
391 "Socket is not connected",
392 )));
393 }
394
395 let res = World::current(|world| {
396 let seq = self.seq(world)?;
397 self.send(world, Segment::Fin(seq))?;
398
399 self.is_shutdown = true;
400
401 Ok(())
402 });
403
404 Poll::Ready(res)
405 }
406
407 fn seq(&self, world: &mut World) -> Result<u64> {
410 world
411 .current_host_mut()
412 .tcp
413 .assign_send_seq(*self.pair)
414 .ok_or_else(|| io::Error::new(io::ErrorKind::BrokenPipe, "Broken pipe"))
415 }
416
417 fn send(&self, world: &mut World, segment: Segment) -> Result<()> {
418 let message = Protocol::Tcp(segment);
419 if is_same(self.pair.local, self.pair.remote) {
420 send_loopback(self.pair.local, self.pair.remote, message);
421 } else {
422 world.send_message(self.pair.local, self.pair.remote, message)?;
423 }
424 Ok(())
425 }
426}
427
428fn send_loopback(src: SocketAddr, dst: SocketAddr, message: Protocol) {
429 if Handle::try_current().is_err() {
434 return;
435 }
436
437 tokio::spawn(async move {
438 let tick_duration = World::current(|world| world.tick_duration);
442 sleep(tick_duration).await;
443
444 World::current(|world| {
445 if let Err(rst) =
446 world
447 .current_host_mut()
448 .receive_from_network(Envelope { src, dst, message })
449 {
450 _ = world.current_host_mut().receive_from_network(Envelope {
451 src: dst,
452 dst: src,
453 message: rst,
454 });
455 }
456 })
457 });
458}
459
/// A pair of flow-control credit pools, one per direction of a connection.
///
/// Cloning shares the underlying pools, since both fields are `Arc`s.
#[derive(Clone, Debug)]
pub(crate) struct BidiFlowControl {
    /// Pool handed to the write half of a stream.
    write: Arc<FlowControl>,
    /// Pool handed to the read half of a stream.
    read: Arc<FlowControl>,
}
469
470impl BidiFlowControl {
471 pub(crate) fn new(capacity: usize) -> Self {
472 Self {
473 write: Arc::new(FlowControl::new(capacity)),
474 read: Arc::new(FlowControl::new(capacity)),
475 }
476 }
477
478 pub(crate) fn invert(self) -> Self {
479 Self {
480 write: self.read,
481 read: self.write,
482 }
483 }
484}
485
/// A pool of credits with a single parked waker.
///
/// A credit is taken with [`FlowControl::try_acquire`] and handed back with
/// [`FlowControl::release`]; handing one back wakes the registered waker, if
/// there is one.
pub(crate) struct FlowControl {
    /// Number of credits currently available.
    credits: AtomicUsize,
    /// Waker to notify on the next release. Only one is remembered at a time.
    waker: Mutex<Option<Waker>>,
}

impl FlowControl {
    /// Creates a pool holding `capacity` credits and no waker.
    fn new(capacity: usize) -> Self {
        Self {
            credits: AtomicUsize::new(capacity),
            waker: Mutex::new(None),
        }
    }

    /// Takes one credit. Returns `false`, leaving the pool unchanged, when
    /// none is available.
    fn try_acquire(&self) -> bool {
        self.credits
            .fetch_update(Ordering::AcqRel, Ordering::Acquire, |v| v.checked_sub(1))
            .is_ok()
    }

    /// Returns one credit and wakes the registered waker, if any.
    ///
    /// # Panics
    ///
    /// Panics if the waker mutex is poisoned.
    fn release(&self) {
        self.credits.fetch_add(1, Ordering::Release);

        // Take the waker in a statement of its own so the mutex guard is gone
        // before `wake` runs. Temporaries in an `if let` scrutinee live for
        // the whole body, so waking from inside one would call into foreign
        // code with the lock held and deadlock a waker that re-enters
        // `register_waker` or `release`.
        let parked = self.waker.lock().unwrap().take();
        if let Some(waker) = parked {
            waker.wake();
        }
    }

    /// Stores `waker` as the one to notify on the next release, replacing any
    /// previously registered waker.
    ///
    /// # Panics
    ///
    /// Panics if the waker mutex is poisoned.
    fn register_waker(&self, waker: Waker) {
        // The replaced waker is dropped after the guard, for the same reason
        // `release` wakes outside the lock.
        let previous = self.waker.lock().unwrap().replace(waker);
        drop(previous);
    }

    /// Reports whether at least one credit is currently available.
    fn has_credits(&self) -> bool {
        self.credits.load(Ordering::Acquire) > 0
    }
}

impl Debug for FlowControl {
    /// Formats the pool as `FlowControl { credits }`; the waker is left out.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("FlowControl")
            .field("credits", &self.credits.load(Ordering::Relaxed))
            .finish()
    }
}
533
534impl Debug for WriteHalf {
535 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
536 f.debug_struct("WriteHalf")
537 .field("pair", &self.pair)
538 .field("is_shutdown", &self.is_shutdown)
539 .finish()
540 }
541}
542
543impl AsyncRead for ReadHalf {
544 fn poll_read(
545 mut self: Pin<&mut Self>,
546 cx: &mut Context<'_>,
547 buf: &mut ReadBuf,
548 ) -> Poll<Result<()>> {
549 self.poll_read_priv(cx, buf)
550 }
551}
552
553impl AsyncRead for TcpStream {
554 fn poll_read(
555 mut self: Pin<&mut Self>,
556 cx: &mut Context<'_>,
557 buf: &mut ReadBuf,
558 ) -> Poll<Result<()>> {
559 Pin::new(&mut self.read_half).poll_read(cx, buf)
560 }
561}
562
563impl AsyncWrite for WriteHalf {
564 fn poll_write(self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &[u8]) -> Poll<Result<usize>> {
565 self.poll_write_priv(cx, buf)
566 }
567
568 fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Result<()>> {
569 Poll::Ready(Ok(()))
570 }
571
572 fn poll_shutdown(mut self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Result<()>> {
573 self.poll_shutdown_priv()
574 }
575}
576
577impl AsyncWrite for TcpStream {
578 fn poll_write(
579 mut self: Pin<&mut Self>,
580 cx: &mut Context<'_>,
581 buf: &[u8],
582 ) -> Poll<Result<usize>> {
583 Pin::new(&mut self.write_half).poll_write(cx, buf)
584 }
585
586 fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<()>> {
587 Pin::new(&mut self.write_half).poll_flush(cx)
588 }
589
590 fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<()>> {
591 Pin::new(&mut self.write_half).poll_shutdown(cx)
592 }
593}
594
595impl Drop for ReadHalf {
596 fn drop(&mut self) {
597 World::current_if_set(|world| {
598 let has_unread = !self.is_closed
606 && (self.rx.buffer.is_some()
607 || matches!(self.rx.recv.try_recv(), Ok(SequencedSegment::Data(_)))
608 || world.current_host_mut().tcp.has_buffered_data(*self.pair));
609
610 if has_unread {
611 let pair = *self.pair;
612 let message = Protocol::Tcp(Segment::Rst);
613 if is_same(pair.local, pair.remote) {
614 send_loopback(pair.local, pair.remote, message);
615 } else {
616 let _ = world.send_message(pair.local, pair.remote, message);
617 }
618 world.current_host_mut().tcp.reset_stream(pair);
622 return;
623 }
624
625 world.current_host_mut().tcp.close_stream_half(*self.pair);
626 })
627 }
628}
629
630impl Drop for WriteHalf {
631 fn drop(&mut self) {
632 World::current_if_set(|world| {
633 if !self.is_shutdown {
635 if let Ok(seq) = self.seq(world) {
636 let _ = self.send(world, Segment::Fin(seq));
637 }
638 }
639 world.current_host_mut().tcp.close_stream_half(*self.pair);
640 })
641 }
642}