1use crate::util::poll_proceed;
4
5use bytes::Buf;
6use bytes::BytesMut;
7use futures_core::ready;
8use std::io::Error as IoError;
9use std::io::ErrorKind as IoErrorKind;
10use std::io::IoSlice;
11use std::pin::Pin;
12use std::sync::{Arc, Mutex};
13use std::task::{Context, Poll, Waker};
14use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
15
/// Result type used by this module's I/O trait implementations.
type IoResult<T> = Result<T, IoError>;

/// Message attached to the `BrokenPipe` error returned when writing to or
/// flushing a simplex whose other half (or this half) has been closed.
const CLOSED_ERROR_MSG: &str = "simplex has been closed";
19
/// State shared (behind `Arc<Mutex<_>>`) between a [`Sender`] and its [`Receiver`].
#[derive(Debug)]
struct Inner {
    /// Maximum number of unread bytes `buf` may hold; writers are parked once
    /// `buf.len()` reaches this value.
    backpressure_boundary: usize,

    /// Set once either half is dropped or the sender is shut down. Writes then
    /// fail with `BrokenPipe`; reads drain what is left and then report EOF.
    is_closed: bool,

    /// Waker of the task parked in `poll_read` waiting for data, if any.
    receiver_waker: Option<Waker>,

    /// Waker of the task parked in a write waiting for free space, if any.
    sender_waker: Option<Waker>,

    /// Bytes written by the sender that the receiver has not consumed yet.
    buf: BytesMut,
}
37
38impl Inner {
39 fn with_capacity(capacity: usize) -> Self {
40 Self {
41 backpressure_boundary: capacity,
42 is_closed: false,
43 receiver_waker: None,
44 sender_waker: None,
45 buf: BytesMut::with_capacity(capacity),
46 }
47 }
48
49 fn register_receiver_waker(&mut self, waker: &Waker) -> Option<Waker> {
50 match self.receiver_waker.as_mut() {
51 Some(old) if old.will_wake(waker) => None,
52 _ => self.receiver_waker.replace(waker.clone()),
53 }
54 }
55
56 fn register_sender_waker(&mut self, waker: &Waker) -> Option<Waker> {
57 match self.sender_waker.as_mut() {
58 Some(old) if old.will_wake(waker) => None,
59 _ => self.sender_waker.replace(waker.clone()),
60 }
61 }
62
63 fn take_receiver_waker(&mut self) -> Option<Waker> {
64 self.receiver_waker.take()
65 }
66
67 fn take_sender_waker(&mut self) -> Option<Waker> {
68 self.sender_waker.take()
69 }
70
71 fn is_closed(&self) -> bool {
72 self.is_closed
73 }
74
75 fn close_receiver(&mut self) -> Option<Waker> {
76 self.is_closed = true;
77 self.take_sender_waker()
78 }
79
80 fn close_sender(&mut self) -> Option<Waker> {
81 self.is_closed = true;
82 self.take_receiver_waker()
83 }
84}
85
/// Reading half of a simplex created by [`new`].
///
/// Implements `AsyncRead` over the bytes written by the paired [`Sender`].
/// Dropping it closes the simplex, after which the sender's writes fail with
/// `BrokenPipe`.
#[derive(Debug)]
pub struct Receiver {
    /// State shared with the paired `Sender`.
    inner: Arc<Mutex<Inner>>,
}
102
103impl Drop for Receiver {
104 fn drop(&mut self) {
106 let maybe_waker = {
107 let mut inner = self.inner.lock().unwrap();
108 inner.close_receiver()
109 };
110
111 if let Some(waker) = maybe_waker {
112 waker.wake();
113 }
114 }
115}
116
117impl AsyncRead for Receiver {
118 fn poll_read(
119 self: Pin<&mut Self>,
120 cx: &mut Context<'_>,
121 buf: &mut ReadBuf<'_>,
122 ) -> Poll<IoResult<()>> {
123 let coop = ready!(poll_proceed(cx));
124
125 let mut inner = self.inner.lock().unwrap();
126
127 let to_read = buf.remaining().min(inner.buf.remaining());
128 if to_read == 0 {
129 if inner.is_closed() || buf.remaining() == 0 {
130 return Poll::Ready(Ok(()));
131 }
132
133 let old_waker = inner.register_receiver_waker(cx.waker());
134 let maybe_waker = inner.take_sender_waker();
135
136 drop(inner);
138 drop(old_waker);
139 if let Some(waker) = maybe_waker {
140 waker.wake();
141 }
142 return Poll::Pending;
143 }
144
145 coop.made_progress();
147
148 buf.put_slice(&inner.buf[..to_read]);
149 inner.buf.advance(to_read);
150
151 let waker = inner.take_sender_waker();
152 drop(inner); if let Some(waker) = waker {
154 waker.wake();
155 }
156
157 Poll::Ready(Ok(()))
158 }
159}
160
/// Writing half of a simplex created by [`new`].
///
/// Implements `AsyncWrite`. Writes are parked once the shared buffer holds
/// `capacity` unread bytes. Dropping or shutting down the sender closes the
/// simplex, so the paired [`Receiver`] sees EOF once it has drained the buffer.
#[derive(Debug)]
pub struct Sender {
    /// State shared with the paired `Receiver`.
    inner: Arc<Mutex<Inner>>,
}
177
178impl Drop for Sender {
179 fn drop(&mut self) {
181 let maybe_waker = {
182 let mut inner = self.inner.lock().unwrap();
183 inner.close_sender()
184 };
185
186 if let Some(waker) = maybe_waker {
187 waker.wake();
188 }
189 }
190}
191
192impl AsyncWrite for Sender {
193 fn poll_write(self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &[u8]) -> Poll<IoResult<usize>> {
198 let coop = ready!(poll_proceed(cx));
199
200 let mut inner = self.inner.lock().unwrap();
201
202 if inner.is_closed() {
203 return Poll::Ready(Err(IoError::new(IoErrorKind::BrokenPipe, CLOSED_ERROR_MSG)));
204 }
205
206 let free = inner
207 .backpressure_boundary
208 .checked_sub(inner.buf.len())
209 .expect("backpressure boundary overflow");
210 let to_write = buf.len().min(free);
211 if to_write == 0 {
212 if buf.is_empty() {
213 return Poll::Ready(Ok(0));
214 }
215
216 let old_waker = inner.register_sender_waker(cx.waker());
217 let waker = inner.take_receiver_waker();
218
219 drop(inner);
221 drop(old_waker);
222 if let Some(waker) = waker {
223 waker.wake();
224 }
225
226 return Poll::Pending;
227 }
228
229 coop.made_progress();
231
232 inner.buf.extend_from_slice(&buf[..to_write]);
233
234 let waker = inner.take_receiver_waker();
235 drop(inner); if let Some(waker) = waker {
237 waker.wake();
238 }
239
240 Poll::Ready(Ok(to_write))
241 }
242
243 fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<IoResult<()>> {
248 let inner = self.inner.lock().unwrap();
249 if inner.is_closed() {
250 Poll::Ready(Err(IoError::new(IoErrorKind::BrokenPipe, CLOSED_ERROR_MSG)))
251 } else {
252 Poll::Ready(Ok(()))
253 }
254 }
255
256 fn poll_shutdown(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<IoResult<()>> {
263 let maybe_waker = {
264 let mut inner = self.inner.lock().unwrap();
265 inner.close_sender()
266 };
267
268 if let Some(waker) = maybe_waker {
269 waker.wake();
270 }
271
272 Poll::Ready(Ok(()))
273 }
274
275 fn is_write_vectored(&self) -> bool {
276 true
277 }
278
279 fn poll_write_vectored(
280 self: Pin<&mut Self>,
281 cx: &mut Context<'_>,
282 bufs: &[IoSlice<'_>],
283 ) -> Poll<Result<usize, IoError>> {
284 let coop = ready!(poll_proceed(cx));
285
286 let mut inner = self.inner.lock().unwrap();
287 if inner.is_closed() {
288 return Poll::Ready(Err(IoError::new(IoErrorKind::BrokenPipe, CLOSED_ERROR_MSG)));
289 }
290
291 let free = inner
292 .backpressure_boundary
293 .checked_sub(inner.buf.len())
294 .expect("backpressure boundary overflow");
295 if free == 0 {
296 let old_waker = inner.register_sender_waker(cx.waker());
297 let maybe_waker = inner.take_receiver_waker();
298
299 drop(inner);
301 drop(old_waker);
302 if let Some(waker) = maybe_waker {
303 waker.wake();
304 }
305
306 return Poll::Pending;
307 }
308
309 coop.made_progress();
311
312 let mut rem = free;
313 for buf in bufs {
314 if rem == 0 {
315 break;
316 }
317
318 let to_write = buf.len().min(rem);
319 if to_write == 0 {
320 assert_ne!(rem, 0);
321 assert_eq!(buf.len(), 0);
322 continue;
323 }
324
325 inner.buf.extend_from_slice(&buf[..to_write]);
326 rem -= to_write;
327 }
328
329 let waker = inner.take_receiver_waker();
330 drop(inner); if let Some(waker) = waker {
332 waker.wake();
333 }
334
335 Poll::Ready(Ok(free - rem))
336 }
337}
338
339pub fn new(capacity: usize) -> (Sender, Receiver) {
349 assert_ne!(capacity, 0, "capacity must be greater than zero");
350
351 let inner = Arc::new(Mutex::new(Inner::with_capacity(capacity)));
352 let tx = Sender {
353 inner: Arc::clone(&inner),
354 };
355 let rx = Receiver { inner };
356 (tx, rx)
357}