turmoil/net/tcp/
split_owned.rs1use std::{
2 error::Error,
3 fmt, io,
4 net::SocketAddr,
5 pin::Pin,
6 sync::Arc,
7 task::{Context, Poll},
8};
9
10use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
11
12use crate::net::TcpStream;
13
14use super::stream::{ReadHalf, WriteHalf};
15
16#[derive(Debug)]
18pub struct OwnedReadHalf {
19 pub(crate) inner: ReadHalf,
20}
21
22impl OwnedReadHalf {
23 pub fn local_addr(&self) -> io::Result<SocketAddr> {
25 Ok(self.inner.pair.local)
26 }
27
28 pub fn peer_addr(&self) -> io::Result<SocketAddr> {
30 Ok(self.inner.pair.remote)
31 }
32
33 pub fn reunite(self, other: OwnedWriteHalf) -> Result<TcpStream, ReuniteError> {
37 reunite(self, other)
38 }
39
40 pub fn poll_peek(
44 mut self: Pin<&mut Self>,
45 cx: &mut Context<'_>,
46 buf: &mut ReadBuf,
47 ) -> Poll<io::Result<usize>> {
48 Pin::new(&mut self.inner).poll_peek(cx, buf)
49 }
50
51 pub async fn peek(&mut self, buf: &mut [u8]) -> io::Result<usize> {
57 self.inner.peek(buf).await
58 }
59}
60
61#[derive(Debug)]
70pub struct OwnedWriteHalf {
71 pub(crate) inner: WriteHalf,
72}
73
74impl OwnedWriteHalf {
75 pub fn local_addr(&self) -> io::Result<SocketAddr> {
77 Ok(self.inner.pair.local)
78 }
79
80 pub fn peer_addr(&self) -> io::Result<SocketAddr> {
82 Ok(self.inner.pair.remote)
83 }
84
85 pub fn reunite(self, other: OwnedReadHalf) -> Result<TcpStream, ReuniteError> {
89 reunite(other, self)
90 }
91}
92
93fn reunite(read: OwnedReadHalf, write: OwnedWriteHalf) -> Result<TcpStream, ReuniteError> {
94 if Arc::ptr_eq(&read.inner.pair, &write.inner.pair) {
95 Ok(TcpStream::reunite(read.inner, write.inner))
96 } else {
97 Err(ReuniteError(read, write))
98 }
99}
100
101#[derive(Debug)]
104pub struct ReuniteError(pub OwnedReadHalf, pub OwnedWriteHalf);
105
106impl fmt::Display for ReuniteError {
107 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
108 write!(
109 f,
110 "tried to reunite halves that are not from the same socket"
111 )
112 }
113}
114
115impl Error for ReuniteError {}
116
117impl AsyncRead for OwnedReadHalf {
118 fn poll_read(
119 mut self: Pin<&mut Self>,
120 cx: &mut Context<'_>,
121 buf: &mut ReadBuf,
122 ) -> Poll<io::Result<()>> {
123 Pin::new(&mut self.inner).poll_read(cx, buf)
124 }
125}
126
127impl AsyncWrite for OwnedWriteHalf {
128 fn poll_write(
129 mut self: Pin<&mut Self>,
130 cx: &mut Context<'_>,
131 buf: &[u8],
132 ) -> Poll<io::Result<usize>> {
133 Pin::new(&mut self.inner).poll_write(cx, buf)
134 }
135
136 fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
137 Pin::new(&mut self.inner).poll_flush(cx)
138 }
139
140 fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
141 Pin::new(&mut self.inner).poll_shutdown(cx)
142 }
143}