turmoil/net/tcp/split_owned.rs

1use std::{
2    error::Error,
3    fmt, io,
4    net::SocketAddr,
5    pin::Pin,
6    sync::Arc,
7    task::{Context, Poll},
8};
9
10use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
11
12use crate::net::TcpStream;
13
14use super::stream::{ReadHalf, WriteHalf};
15
16/// Owned read half of a `TcpStream`, created by `into_split`.
17#[derive(Debug)]
18pub struct OwnedReadHalf {
19    pub(crate) inner: ReadHalf,
20}
21
22impl OwnedReadHalf {
23    /// Returns the local address that this stream is bound to.
24    pub fn local_addr(&self) -> io::Result<SocketAddr> {
25        Ok(self.inner.pair.local)
26    }
27
28    /// Returns the remote address that this stream is connected to.
29    pub fn peer_addr(&self) -> io::Result<SocketAddr> {
30        Ok(self.inner.pair.remote)
31    }
32
33    /// Attempts to put the two halves of a `TcpStream` back together and
34    /// recover the original socket. Succeeds only if the two halves
35    /// originated from the same call to `into_split`.
36    pub fn reunite(self, other: OwnedWriteHalf) -> Result<TcpStream, ReuniteError> {
37        reunite(self, other)
38    }
39
40    /// Attempts to receive data on the socket, without removing that data from
41    /// the queue, registering the current task for wakeup if data is not yet
42    /// available.
43    pub fn poll_peek(
44        mut self: Pin<&mut Self>,
45        cx: &mut Context<'_>,
46        buf: &mut ReadBuf,
47    ) -> Poll<io::Result<usize>> {
48        Pin::new(&mut self.inner).poll_peek(cx, buf)
49    }
50
51    /// Receives data on the socket from the remote address to which it is
52    /// connected, without removing that data from the queue. On success,
53    /// returns the number of bytes peeked.
54    ///
55    /// Successive calls return the same data.
56    pub async fn peek(&mut self, buf: &mut [u8]) -> io::Result<usize> {
57        self.inner.peek(buf).await
58    }
59}
60
61/// Owned write half of a `TcpStream`, created by `into_split`.
62///
63/// Note that in the [`AsyncWrite`] implementation of this type, [`poll_shutdown`] will
64/// shut down the TCP stream in the write direction. Dropping the write half
65/// will also shut down the write half of the TCP stream.
66///
67/// [`AsyncWrite`]: trait@tokio::io::AsyncWrite
68/// [`poll_shutdown`]: fn@tokio::io::AsyncWrite::poll_shutdown
69#[derive(Debug)]
70pub struct OwnedWriteHalf {
71    pub(crate) inner: WriteHalf,
72}
73
74impl OwnedWriteHalf {
75    /// Returns the local address that this stream is bound to.
76    pub fn local_addr(&self) -> io::Result<SocketAddr> {
77        Ok(self.inner.pair.local)
78    }
79
80    /// Returns the remote address that this stream is connected to.
81    pub fn peer_addr(&self) -> io::Result<SocketAddr> {
82        Ok(self.inner.pair.remote)
83    }
84
85    /// Attempts to put the two halves of a `TcpStream` back together and
86    /// recover the original socket. Succeeds only if the two halves
87    /// originated from the same call to `into_split`.
88    pub fn reunite(self, other: OwnedReadHalf) -> Result<TcpStream, ReuniteError> {
89        reunite(other, self)
90    }
91}
92
93fn reunite(read: OwnedReadHalf, write: OwnedWriteHalf) -> Result<TcpStream, ReuniteError> {
94    if Arc::ptr_eq(&read.inner.pair, &write.inner.pair) {
95        Ok(TcpStream::reunite(read.inner, write.inner))
96    } else {
97        Err(ReuniteError(read, write))
98    }
99}
100
101/// Error indicating that two halves were not from the same socket, and thus could
102/// not be reunited.
103#[derive(Debug)]
104pub struct ReuniteError(pub OwnedReadHalf, pub OwnedWriteHalf);
105
106impl fmt::Display for ReuniteError {
107    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
108        write!(
109            f,
110            "tried to reunite halves that are not from the same socket"
111        )
112    }
113}
114
115impl Error for ReuniteError {}
116
117impl AsyncRead for OwnedReadHalf {
118    fn poll_read(
119        mut self: Pin<&mut Self>,
120        cx: &mut Context<'_>,
121        buf: &mut ReadBuf,
122    ) -> Poll<io::Result<()>> {
123        Pin::new(&mut self.inner).poll_read(cx, buf)
124    }
125}
126
127impl AsyncWrite for OwnedWriteHalf {
128    fn poll_write(
129        mut self: Pin<&mut Self>,
130        cx: &mut Context<'_>,
131        buf: &[u8],
132    ) -> Poll<io::Result<usize>> {
133        Pin::new(&mut self.inner).poll_write(cx, buf)
134    }
135
136    fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
137        Pin::new(&mut self.inner).poll_flush(cx)
138    }
139
140    fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
141        Pin::new(&mut self.inner).poll_shutdown(cx)
142    }
143}