1use bytes::{Buf, Bytes};
2use std::future::poll_fn;
3use std::{
4 fmt::Debug,
5 io::{self, Error, Result},
6 net::SocketAddr,
7 pin::Pin,
8 sync::Arc,
9 task::{ready, Context, Poll},
10};
11use tokio::{
12 io::{AsyncRead, AsyncWrite, ReadBuf},
13 runtime::Handle,
14 sync::{mpsc, oneshot},
15 time::sleep,
16};
17
18use crate::{
19 envelope::{Envelope, Protocol, Segment, Syn},
20 host::is_same,
21 host::SequencedSegment,
22 net::SocketPair,
23 world::World,
24 ToSocketAddrs, TRACING_TARGET,
25};
26
27use super::split_owned::{OwnedReadHalf, OwnedWriteHalf};
28
/// A TCP stream between a local and a remote socket on the `World`'s network.
///
/// Internally the stream is a read half and a write half that share one
/// `SocketPair` (see `TcpStream::new`); `into_split` hands them out separately.
#[derive(Debug)]
pub struct TcpStream {
    // Receives segments from the peer and buffers partially consumed data.
    read_half: ReadHalf,
    // Sends data and FIN segments to the peer.
    write_half: WriteHalf,
}
37
38impl TcpStream {
39 pub(crate) fn new(pair: SocketPair, receiver: mpsc::Receiver<SequencedSegment>) -> Self {
40 let pair = Arc::new(pair);
41 let read_half = ReadHalf {
42 pair: pair.clone(),
43 rx: Rx {
44 recv: receiver,
45 buffer: None,
46 },
47 is_closed: false,
48 };
49
50 let write_half = WriteHalf {
51 pair,
52 is_shutdown: false,
53 };
54
55 Self {
56 read_half,
57 write_half,
58 }
59 }
60
61 pub async fn connect<A: ToSocketAddrs>(addr: A) -> Result<TcpStream> {
63 let (ack, syn_ack) = oneshot::channel();
64
65 let (pair, rx) = World::current(|world| {
66 let dst = addr.to_socket_addr(&world.dns);
67
68 let host = world.current_host_mut();
69 let mut local_addr = SocketAddr::new(host.addr, host.assign_ephemeral_port());
70 if dst.ip().is_loopback() {
71 local_addr.set_ip(dst.ip());
72 }
73
74 let pair = SocketPair::new(local_addr, dst);
75 let rx = host.tcp.new_stream(pair);
76
77 let syn = Protocol::Tcp(Segment::Syn(Syn { ack }));
78 if !is_same(local_addr, dst) {
79 world.send_message(local_addr, dst, syn)?;
80 } else {
81 send_loopback(local_addr, dst, syn);
82 };
83
84 Ok::<_, Error>((pair, rx))
85 })?;
86
87 syn_ack.await.map_err(|_| {
88 io::Error::new(io::ErrorKind::ConnectionRefused, pair.remote.to_string())
89 })?;
90
91 tracing::trace!(target: TRACING_TARGET, src = ?pair.remote, dst = ?pair.local, protocol = %"TCP SYN-ACK", "Recv");
92
93 Ok(TcpStream::new(pair, rx))
94 }
95
96 pub fn try_write(&self, buf: &[u8]) -> Result<usize> {
110 self.write_half.try_write(buf)
111 }
112
113 pub fn local_addr(&self) -> Result<SocketAddr> {
115 Ok(self.read_half.pair.local)
116 }
117
118 pub fn peer_addr(&self) -> Result<SocketAddr> {
120 Ok(self.read_half.pair.remote)
121 }
122
123 pub(crate) fn reunite(read_half: ReadHalf, write_half: WriteHalf) -> Self {
124 Self {
125 read_half,
126 write_half,
127 }
128 }
129
130 pub async fn writable(&self) -> Result<()> {
142 Ok(())
143 }
144
145 pub fn into_split(self) -> (OwnedReadHalf, OwnedWriteHalf) {
153 (
154 OwnedReadHalf {
155 inner: self.read_half,
156 },
157 OwnedWriteHalf {
158 inner: self.write_half,
159 },
160 )
161 }
162
163 pub fn set_nodelay(&self, _nodelay: bool) -> Result<()> {
166 Ok(())
167 }
168
169 pub async fn peek(&mut self, buf: &mut [u8]) -> Result<usize> {
175 self.read_half.peek(buf).await
176 }
177
178 pub fn poll_peek(&mut self, cx: &mut Context<'_>, buf: &mut ReadBuf) -> Poll<Result<usize>> {
182 self.read_half.poll_peek(cx, buf)
183 }
184}
185
/// Receiving side of a `TcpStream`.
pub(crate) struct ReadHalf {
    /// Local/remote addresses of the connection, shared with the matching `WriteHalf`.
    pub(crate) pair: Arc<SocketPair>,
    /// Inbound segment channel plus any partially consumed data segment.
    rx: Rx,
    /// Set once a FIN has been received; reads then complete without data (EOF).
    is_closed: bool,
}
192
/// Inbound state of a `ReadHalf`.
struct Rx {
    /// Channel on which this connection's segments arrive (the receiver half
    /// returned by `host.tcp.new_stream` in `TcpStream::connect`).
    recv: mpsc::Receiver<SequencedSegment>,
    /// Bytes of a data segment not yet handed to the reader. Set when a segment
    /// did not fit in the caller's buffer or was only peeked. Drained before
    /// the next segment is received.
    buffer: Option<Bytes>,
}
201
202impl ReadHalf {
203 fn poll_read_priv(&mut self, cx: &mut Context<'_>, buf: &mut ReadBuf) -> Poll<Result<()>> {
204 if self.is_closed || buf.capacity() == 0 {
205 return Poll::Ready(Ok(()));
206 }
207
208 if let Some(bytes) = self.rx.buffer.take() {
209 self.rx.buffer = Self::put_slice(bytes, buf);
210
211 return Poll::Ready(Ok(()));
212 }
213
214 match ready!(self.rx.recv.poll_recv(cx)) {
215 Some(seg) => {
216 tracing::trace!(target: TRACING_TARGET, src = ?self.pair.remote, dst = ?self.pair.local, protocol = %seg, "Recv");
217
218 match seg {
219 SequencedSegment::Data(bytes) => {
220 self.rx.buffer = Self::put_slice(bytes, buf);
221 }
222 SequencedSegment::Fin => {
223 self.is_closed = true;
224 }
225 }
226
227 Poll::Ready(Ok(()))
228 }
229 None => Poll::Ready(Err(io::Error::new(
230 io::ErrorKind::ConnectionReset,
231 "Connection reset",
232 ))),
233 }
234 }
235
236 fn put_slice(mut avail: Bytes, buf: &mut ReadBuf) -> Option<Bytes> {
242 let amt = std::cmp::min(avail.len(), buf.remaining());
243
244 buf.put_slice(&avail[..amt]);
245 avail.advance(amt);
246
247 if avail.is_empty() {
248 None
249 } else {
250 Some(avail)
251 }
252 }
253
254 pub(crate) fn poll_peek(
255 &mut self,
256 cx: &mut Context<'_>,
257 buf: &mut ReadBuf,
258 ) -> Poll<Result<usize>> {
259 if self.is_closed || buf.capacity() == 0 {
260 return Poll::Ready(Ok(0));
261 }
262
263 if let Some(bytes) = &self.rx.buffer {
265 let len = std::cmp::min(bytes.len(), buf.remaining());
266 buf.put_slice(&bytes[..len]);
267 return Poll::Ready(Ok(len));
268 }
269
270 match ready!(self.rx.recv.poll_recv(cx)) {
271 Some(seg) => {
272 tracing::trace!(target: TRACING_TARGET, src = ?self.pair.remote, dst = ?self.pair.local, protocol = %seg, "Peek");
273
274 match seg {
275 SequencedSegment::Data(bytes) => {
276 let len = std::cmp::min(bytes.len(), buf.remaining());
277 buf.put_slice(&bytes[..len]);
278 self.rx.buffer = Some(bytes);
279
280 Poll::Ready(Ok(len))
281 }
282 SequencedSegment::Fin => {
283 self.is_closed = true;
284 Poll::Ready(Ok(0))
285 }
286 }
287 }
288 None => Poll::Ready(Err(io::Error::new(
289 io::ErrorKind::ConnectionReset,
290 "Connection reset",
291 ))),
292 }
293 }
294
295 pub(crate) async fn peek(&mut self, buf: &mut [u8]) -> Result<usize> {
296 let mut buf = ReadBuf::new(buf);
297 poll_fn(|cx| self.poll_peek(cx, &mut buf)).await
298 }
299}
300
301impl Debug for ReadHalf {
302 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
303 f.debug_struct("ReadHalf")
304 .field("pair", &self.pair)
305 .field("is_closed", &self.is_closed)
306 .finish()
307 }
308}
309
/// Sending side of a `TcpStream`.
pub(crate) struct WriteHalf {
    /// Local/remote addresses of the connection, shared with the matching `ReadHalf`.
    pub(crate) pair: Arc<SocketPair>,
    /// Set after shutdown has sent a FIN; further writes fail with `BrokenPipe`.
    is_shutdown: bool,
}
315
316impl WriteHalf {
317 fn try_write(&self, buf: &[u8]) -> Result<usize> {
318 if buf.remaining() == 0 {
319 return Ok(0);
320 }
321
322 if self.is_shutdown {
323 return Err(io::Error::new(io::ErrorKind::BrokenPipe, "Broken pipe"));
324 }
325
326 World::current(|world| {
327 let bytes = Bytes::copy_from_slice(buf);
328 let len = bytes.len();
329
330 let seq = self.seq(world)?;
331 self.send(world, Segment::Data(seq, bytes))?;
332
333 Ok(len)
334 })
335 }
336
337 fn poll_write_priv(&self, _cx: &mut Context<'_>, buf: &[u8]) -> Poll<Result<usize>> {
338 Poll::Ready(self.try_write(buf))
339 }
340
341 fn poll_shutdown_priv(&mut self) -> Poll<Result<()>> {
342 if self.is_shutdown {
343 return Poll::Ready(Err(io::Error::new(
344 io::ErrorKind::NotConnected,
345 "Socket is not connected",
346 )));
347 }
348
349 let res = World::current(|world| {
350 let seq = self.seq(world)?;
351 self.send(world, Segment::Fin(seq))?;
352
353 self.is_shutdown = true;
354
355 Ok(())
356 });
357
358 Poll::Ready(res)
359 }
360
361 fn seq(&self, world: &mut World) -> Result<u64> {
364 world
365 .current_host_mut()
366 .tcp
367 .assign_send_seq(*self.pair)
368 .ok_or_else(|| io::Error::new(io::ErrorKind::BrokenPipe, "Broken pipe"))
369 }
370
371 fn send(&self, world: &mut World, segment: Segment) -> Result<()> {
372 let message = Protocol::Tcp(segment);
373 if is_same(self.pair.local, self.pair.remote) {
374 send_loopback(self.pair.local, self.pair.remote, message);
375 } else {
376 world.send_message(self.pair.local, self.pair.remote, message)?;
377 }
378 Ok(())
379 }
380}
381
382fn send_loopback(src: SocketAddr, dst: SocketAddr, message: Protocol) {
383 if Handle::try_current().is_err() {
388 return;
389 }
390
391 tokio::spawn(async move {
392 let tick_duration = World::current(|world| world.tick_duration);
396 sleep(tick_duration).await;
397
398 World::current(|world| {
399 if let Err(rst) =
400 world
401 .current_host_mut()
402 .receive_from_network(Envelope { src, dst, message })
403 {
404 _ = world.current_host_mut().receive_from_network(Envelope {
405 src: dst,
406 dst: src,
407 message: rst,
408 });
409 }
410 })
411 });
412}
413
414impl Debug for WriteHalf {
415 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
416 f.debug_struct("WriteHalf")
417 .field("pair", &self.pair)
418 .field("is_shutdown", &self.is_shutdown)
419 .finish()
420 }
421}
422
423impl AsyncRead for ReadHalf {
424 fn poll_read(
425 mut self: Pin<&mut Self>,
426 cx: &mut Context<'_>,
427 buf: &mut ReadBuf,
428 ) -> Poll<Result<()>> {
429 self.poll_read_priv(cx, buf)
430 }
431}
432
433impl AsyncRead for TcpStream {
434 fn poll_read(
435 mut self: Pin<&mut Self>,
436 cx: &mut Context<'_>,
437 buf: &mut ReadBuf,
438 ) -> Poll<Result<()>> {
439 Pin::new(&mut self.read_half).poll_read(cx, buf)
440 }
441}
442
443impl AsyncWrite for WriteHalf {
444 fn poll_write(self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &[u8]) -> Poll<Result<usize>> {
445 self.poll_write_priv(cx, buf)
446 }
447
448 fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Result<()>> {
449 Poll::Ready(Ok(()))
450 }
451
452 fn poll_shutdown(mut self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Result<()>> {
453 self.poll_shutdown_priv()
454 }
455}
456
457impl AsyncWrite for TcpStream {
458 fn poll_write(
459 mut self: Pin<&mut Self>,
460 cx: &mut Context<'_>,
461 buf: &[u8],
462 ) -> Poll<Result<usize>> {
463 Pin::new(&mut self.write_half).poll_write(cx, buf)
464 }
465
466 fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<()>> {
467 Pin::new(&mut self.write_half).poll_flush(cx)
468 }
469
470 fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<()>> {
471 Pin::new(&mut self.write_half).poll_shutdown(cx)
472 }
473}
474
475impl Drop for ReadHalf {
476 fn drop(&mut self) {
477 World::current_if_set(|world| {
478 world.current_host_mut().tcp.close_stream_half(*self.pair);
479 })
480 }
481}
482
483impl Drop for WriteHalf {
484 fn drop(&mut self) {
485 World::current_if_set(|world| {
486 if !self.is_shutdown {
488 if let Ok(seq) = self.seq(world) {
489 let _ = self.send(world, Segment::Fin(seq));
490 }
491 }
492 world.current_host_mut().tcp.close_stream_half(*self.pair);
493 })
494 }
495}