tokio_util/io/
simplex.rs

1//! Unidirectional byte-oriented channel.
2
3use crate::util::poll_proceed;
4
5use bytes::Buf;
6use bytes::BytesMut;
7use futures_core::ready;
8use std::io::Error as IoError;
9use std::io::ErrorKind as IoErrorKind;
10use std::io::IoSlice;
11use std::pin::Pin;
12use std::sync::{Arc, Mutex};
13use std::task::{Context, Poll, Waker};
14use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
15
/// Result type returned by the I/O trait implementations in this module.
type IoResult<T> = std::io::Result<T>;

/// Message carried by the `BrokenPipe` errors reported once the channel is closed.
const CLOSED_ERROR_MSG: &str = "simplex has been closed";
19
20#[derive(Debug)]
21struct Inner {
22    /// `poll_write` will return [`Poll::Pending`] if the backpressure boundary is reached
23    backpressure_boundary: usize,
24
25    /// either [`Sender`] or [`Receiver`] is closed
26    is_closed: bool,
27
28    /// Waker used to wake the [`Receiver`]
29    receiver_waker: Option<Waker>,
30
31    /// Waker used to wake the [`Sender`]
32    sender_waker: Option<Waker>,
33
34    /// Buffer used to read and write data
35    buf: BytesMut,
36}
37
38impl Inner {
39    fn with_capacity(capacity: usize) -> Self {
40        Self {
41            backpressure_boundary: capacity,
42            is_closed: false,
43            receiver_waker: None,
44            sender_waker: None,
45            buf: BytesMut::with_capacity(capacity),
46        }
47    }
48
49    fn register_receiver_waker(&mut self, waker: &Waker) -> Option<Waker> {
50        match self.receiver_waker.as_mut() {
51            Some(old) if old.will_wake(waker) => None,
52            _ => self.receiver_waker.replace(waker.clone()),
53        }
54    }
55
56    fn register_sender_waker(&mut self, waker: &Waker) -> Option<Waker> {
57        match self.sender_waker.as_mut() {
58            Some(old) if old.will_wake(waker) => None,
59            _ => self.sender_waker.replace(waker.clone()),
60        }
61    }
62
63    fn take_receiver_waker(&mut self) -> Option<Waker> {
64        self.receiver_waker.take()
65    }
66
67    fn take_sender_waker(&mut self) -> Option<Waker> {
68        self.sender_waker.take()
69    }
70
71    fn is_closed(&self) -> bool {
72        self.is_closed
73    }
74
75    fn close_receiver(&mut self) -> Option<Waker> {
76        self.is_closed = true;
77        self.take_sender_waker()
78    }
79
80    fn close_sender(&mut self) -> Option<Waker> {
81        self.is_closed = true;
82        self.take_receiver_waker()
83    }
84}
85
86/// Receiver of the simplex channel.
87///
88/// # Cancellation safety
89///
90/// The `Receiver` is cancel safe. If it is used as the event in a
91/// [`tokio::select!`](macro@tokio::select) statement and some other branch
92/// completes first, it is guaranteed that no bytes were received on this
93/// channel.
94///
95/// You can still read the remaining data from the buffer
96/// even if the write half has been dropped.
97/// See [`Sender::poll_shutdown`] and [`Sender::drop`] for more details.
98#[derive(Debug)]
99pub struct Receiver {
100    inner: Arc<Mutex<Inner>>,
101}
102
103impl Drop for Receiver {
104    /// This also wakes up the [`Sender`].
105    fn drop(&mut self) {
106        let maybe_waker = {
107            let mut inner = self.inner.lock().unwrap();
108            inner.close_receiver()
109        };
110
111        if let Some(waker) = maybe_waker {
112            waker.wake();
113        }
114    }
115}
116
117impl AsyncRead for Receiver {
118    fn poll_read(
119        self: Pin<&mut Self>,
120        cx: &mut Context<'_>,
121        buf: &mut ReadBuf<'_>,
122    ) -> Poll<IoResult<()>> {
123        let coop = ready!(poll_proceed(cx));
124
125        let mut inner = self.inner.lock().unwrap();
126
127        let to_read = buf.remaining().min(inner.buf.remaining());
128        if to_read == 0 {
129            if inner.is_closed() || buf.remaining() == 0 {
130                return Poll::Ready(Ok(()));
131            }
132
133            let old_waker = inner.register_receiver_waker(cx.waker());
134            let maybe_waker = inner.take_sender_waker();
135
136            // unlock before waking up and dropping old waker
137            drop(inner);
138            drop(old_waker);
139            if let Some(waker) = maybe_waker {
140                waker.wake();
141            }
142            return Poll::Pending;
143        }
144
145        // this is to avoid starving other tasks
146        coop.made_progress();
147
148        buf.put_slice(&inner.buf[..to_read]);
149        inner.buf.advance(to_read);
150
151        let waker = inner.take_sender_waker();
152        drop(inner); // unlock before waking up
153        if let Some(waker) = waker {
154            waker.wake();
155        }
156
157        Poll::Ready(Ok(()))
158    }
159}
160
161/// Sender of the simplex channel.
162///
163/// # Cancellation safety
164///
165/// The `Sender` is cancel safe. If it is used as the event in a
166/// [`tokio::select!`](macro@tokio::select) statement and some other branch
167/// completes first, it is guaranteed that no bytes were sent on this
168/// channel.
169///
170/// # Shutdown
171///
172/// See [`Sender::poll_shutdown`].
173#[derive(Debug)]
174pub struct Sender {
175    inner: Arc<Mutex<Inner>>,
176}
177
178impl Drop for Sender {
179    /// This also wakes up the [`Receiver`].
180    fn drop(&mut self) {
181        let maybe_waker = {
182            let mut inner = self.inner.lock().unwrap();
183            inner.close_sender()
184        };
185
186        if let Some(waker) = maybe_waker {
187            waker.wake();
188        }
189    }
190}
191
192impl AsyncWrite for Sender {
193    /// # Errors
194    ///
195    /// This method will return [`IoErrorKind::BrokenPipe`]
196    /// if the channel has been closed.
197    fn poll_write(self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &[u8]) -> Poll<IoResult<usize>> {
198        let coop = ready!(poll_proceed(cx));
199
200        let mut inner = self.inner.lock().unwrap();
201
202        if inner.is_closed() {
203            return Poll::Ready(Err(IoError::new(IoErrorKind::BrokenPipe, CLOSED_ERROR_MSG)));
204        }
205
206        let free = inner
207            .backpressure_boundary
208            .checked_sub(inner.buf.len())
209            .expect("backpressure boundary overflow");
210        let to_write = buf.len().min(free);
211        if to_write == 0 {
212            if buf.is_empty() {
213                return Poll::Ready(Ok(0));
214            }
215
216            let old_waker = inner.register_sender_waker(cx.waker());
217            let waker = inner.take_receiver_waker();
218
219            // unlock before waking up and dropping old waker
220            drop(inner);
221            drop(old_waker);
222            if let Some(waker) = waker {
223                waker.wake();
224            }
225
226            return Poll::Pending;
227        }
228
229        // this is to avoid starving other tasks
230        coop.made_progress();
231
232        inner.buf.extend_from_slice(&buf[..to_write]);
233
234        let waker = inner.take_receiver_waker();
235        drop(inner); // unlock before waking up
236        if let Some(waker) = waker {
237            waker.wake();
238        }
239
240        Poll::Ready(Ok(to_write))
241    }
242
243    /// # Errors
244    ///
245    /// This method will return [`IoErrorKind::BrokenPipe`]
246    /// if the channel has been closed.
247    fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<IoResult<()>> {
248        let inner = self.inner.lock().unwrap();
249        if inner.is_closed() {
250            Poll::Ready(Err(IoError::new(IoErrorKind::BrokenPipe, CLOSED_ERROR_MSG)))
251        } else {
252            Poll::Ready(Ok(()))
253        }
254    }
255
256    /// After returns [`Poll::Ready`], all the following call to
257    /// [`Sender::poll_write`] and [`Sender::poll_flush`]
258    /// will return error.
259    ///
260    /// The [`Receiver`] can still be used to read remaining data
261    /// until all bytes have been consumed.
262    fn poll_shutdown(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<IoResult<()>> {
263        let maybe_waker = {
264            let mut inner = self.inner.lock().unwrap();
265            inner.close_sender()
266        };
267
268        if let Some(waker) = maybe_waker {
269            waker.wake();
270        }
271
272        Poll::Ready(Ok(()))
273    }
274
275    fn is_write_vectored(&self) -> bool {
276        true
277    }
278
279    fn poll_write_vectored(
280        self: Pin<&mut Self>,
281        cx: &mut Context<'_>,
282        bufs: &[IoSlice<'_>],
283    ) -> Poll<Result<usize, IoError>> {
284        let coop = ready!(poll_proceed(cx));
285
286        let mut inner = self.inner.lock().unwrap();
287        if inner.is_closed() {
288            return Poll::Ready(Err(IoError::new(IoErrorKind::BrokenPipe, CLOSED_ERROR_MSG)));
289        }
290
291        let free = inner
292            .backpressure_boundary
293            .checked_sub(inner.buf.len())
294            .expect("backpressure boundary overflow");
295        if free == 0 {
296            let old_waker = inner.register_sender_waker(cx.waker());
297            let maybe_waker = inner.take_receiver_waker();
298
299            // unlock before waking up and dropping old waker
300            drop(inner);
301            drop(old_waker);
302            if let Some(waker) = maybe_waker {
303                waker.wake();
304            }
305
306            return Poll::Pending;
307        }
308
309        // this is to avoid starving other tasks
310        coop.made_progress();
311
312        let mut rem = free;
313        for buf in bufs {
314            if rem == 0 {
315                break;
316            }
317
318            let to_write = buf.len().min(rem);
319            if to_write == 0 {
320                assert_ne!(rem, 0);
321                assert_eq!(buf.len(), 0);
322                continue;
323            }
324
325            inner.buf.extend_from_slice(&buf[..to_write]);
326            rem -= to_write;
327        }
328
329        let waker = inner.take_receiver_waker();
330        drop(inner); // unlock before waking up
331        if let Some(waker) = waker {
332            waker.wake();
333        }
334
335        Poll::Ready(Ok(free - rem))
336    }
337}
338
339/// Create a simplex channel.
340///
341/// The `capacity` parameter specifies the maximum number of bytes that can be
342/// stored in the channel without making the [`Sender::poll_write`]
343/// return [`Poll::Pending`].
344///
345/// # Panics
346///
347/// This function will panic if `capacity` is zero.
348pub fn new(capacity: usize) -> (Sender, Receiver) {
349    assert_ne!(capacity, 0, "capacity must be greater than zero");
350
351    let inner = Arc::new(Mutex::new(Inner::with_capacity(capacity)));
352    let tx = Sender {
353        inner: Arc::clone(&inner),
354    };
355    let rx = Receiver { inner };
356    (tx, rx)
357}