tokio_tungstenite/
compat.rs

1use log::*;
2use std::{
3    io::{Read, Write},
4    pin::Pin,
5    task::{Context, Poll},
6};
7
8use futures_util::task;
9use std::sync::Arc;
10use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
11use tungstenite::Error as WsError;
12
/// Selects which side of the connection an operation belongs to.
///
/// `Read` is used by read-side operations (the `Stream` impl and handshaking).
/// `Write` is used by write-side operations (the `Sink` impl).
///
/// - In `AllowStd::set_waker` it picks which waker slots get registered.
/// - In `AllowStd::with_context` it picks which waker proxy is handed to the inner stream.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub(crate) enum ContextWaker {
    /// Read side: read waker slots / `read_waker_proxy`.
    Read,
    /// Write side: write waker slots / `write_waker_proxy`.
    Write,
}
17
18#[derive(Debug)]
19pub(crate) struct AllowStd<S> {
20    inner: S,
21    // We have the problem that external read operations (i.e. the Stream impl)
22    // can trigger both read (AsyncRead) and write (AsyncWrite) operations on
23    // the underyling stream. At the same time write operations (i.e. the Sink
24    // impl) can trigger write operations (AsyncWrite) too.
25    // Both the Stream and the Sink can be used on two different tasks, but it
26    // is required that AsyncRead and AsyncWrite are only ever used by a single
27    // task (or better: with a single waker) at a time.
28    //
29    // Doing otherwise would cause only the latest waker to be remembered, so
30    // in our case either the Stream or the Sink impl would potentially wait
31    // forever to be woken up because only the other one would've been woken
32    // up.
33    //
34    // To solve this we implement a waker proxy that has two slots (one for
35    // read, one for write) to store wakers. One waker proxy is always passed
36    // to the AsyncRead, the other to AsyncWrite so that they will only ever
37    // have to store a single waker, but internally we dispatch any wakeups to
38    // up to two actual wakers (one from the Sink impl and one from the Stream
39    // impl).
40    //
41    // write_waker_proxy is always used for AsyncWrite, read_waker_proxy for
42    // AsyncRead. The read_waker slots of both are used for the Stream impl
43    // (and handshaking), the write_waker slots for the Sink impl.
44    write_waker_proxy: Arc<WakerProxy>,
45    read_waker_proxy: Arc<WakerProxy>,
46}
47
48// Internal trait used only in the Handshake module for registering
49// the waker for the context used during handshaking. We're using the
50// read waker slot for this, but any would do.
51//
52// Don't ever use this from multiple tasks at the same time!
53pub(crate) trait SetWaker {
54    fn set_waker(&self, waker: &task::Waker);
55}
56
57impl<S> SetWaker for AllowStd<S> {
58    fn set_waker(&self, waker: &task::Waker) {
59        self.set_waker(ContextWaker::Read, waker);
60    }
61}
62
63impl<S> AllowStd<S> {
64    pub(crate) fn new(inner: S, waker: &task::Waker) -> Self {
65        let res = Self {
66            inner,
67            write_waker_proxy: Default::default(),
68            read_waker_proxy: Default::default(),
69        };
70
71        // Register the handshake waker as read waker for both proxies,
72        // see also the SetWaker trait.
73        res.write_waker_proxy.read_waker.register(waker);
74        res.read_waker_proxy.read_waker.register(waker);
75
76        res
77    }
78
79    // Set the read or write waker for our proxies.
80    //
81    // Read: this is only supposed to be called by read (or handshake) operations, i.e. the Stream
82    // impl on the WebSocketStream.
83    // Reading can also cause writes to happen, e.g. in case of Message::Ping handling.
84    //
85    // Write: this is only supposde to be called by write operations, i.e. the Sink impl on the
86    // WebSocketStream.
87    pub(crate) fn set_waker(&self, kind: ContextWaker, waker: &task::Waker) {
88        match kind {
89            ContextWaker::Read => {
90                self.write_waker_proxy.read_waker.register(waker);
91                self.read_waker_proxy.read_waker.register(waker);
92            }
93            ContextWaker::Write => {
94                self.write_waker_proxy.write_waker.register(waker);
95                self.read_waker_proxy.write_waker.register(waker);
96            }
97        }
98    }
99}
100
101// Proxy Waker that we pass to the internal AsyncRead/Write of the
102// stream underlying the websocket. We have two slots here for the
103// actual wakers to allow external read operations to trigger both
104// reads and writes, and the same for writes.
105#[derive(Debug, Default)]
106struct WakerProxy {
107    read_waker: task::AtomicWaker,
108    write_waker: task::AtomicWaker,
109}
110
111impl task::ArcWake for WakerProxy {
112    fn wake_by_ref(arc_self: &Arc<Self>) {
113        arc_self.read_waker.wake();
114        arc_self.write_waker.wake();
115    }
116}
117
118impl<S> AllowStd<S>
119where
120    S: Unpin,
121{
122    fn with_context<F, R>(&mut self, kind: ContextWaker, f: F) -> Poll<std::io::Result<R>>
123    where
124        F: FnOnce(&mut Context<'_>, Pin<&mut S>) -> Poll<std::io::Result<R>>,
125    {
126        trace!("{}:{} AllowStd.with_context", file!(), line!());
127        let waker = match kind {
128            ContextWaker::Read => task::waker_ref(&self.read_waker_proxy),
129            ContextWaker::Write => task::waker_ref(&self.write_waker_proxy),
130        };
131        let mut context = task::Context::from_waker(&waker);
132        f(&mut context, Pin::new(&mut self.inner))
133    }
134
135    pub(crate) fn get_mut(&mut self) -> &mut S {
136        &mut self.inner
137    }
138
139    pub(crate) fn get_ref(&self) -> &S {
140        &self.inner
141    }
142}
143
144impl<S> Read for AllowStd<S>
145where
146    S: AsyncRead + Unpin,
147{
148    fn read(&mut self, buf: &mut [u8]) -> std::io::Result<usize> {
149        trace!("{}:{} Read.read", file!(), line!());
150        let mut buf = ReadBuf::new(buf);
151        match self.with_context(ContextWaker::Read, |ctx, stream| {
152            trace!("{}:{} Read.with_context read -> poll_read", file!(), line!());
153            stream.poll_read(ctx, &mut buf)
154        }) {
155            Poll::Ready(Ok(_)) => Ok(buf.filled().len()),
156            Poll::Ready(Err(err)) => Err(err),
157            Poll::Pending => Err(std::io::Error::from(std::io::ErrorKind::WouldBlock)),
158        }
159    }
160}
161
162impl<S> Write for AllowStd<S>
163where
164    S: AsyncWrite + Unpin,
165{
166    fn write(&mut self, buf: &[u8]) -> std::io::Result<usize> {
167        trace!("{}:{} Write.write", file!(), line!());
168        match self.with_context(ContextWaker::Write, |ctx, stream| {
169            trace!("{}:{} Write.with_context write -> poll_write", file!(), line!());
170            stream.poll_write(ctx, buf)
171        }) {
172            Poll::Ready(r) => r,
173            Poll::Pending => Err(std::io::Error::from(std::io::ErrorKind::WouldBlock)),
174        }
175    }
176
177    fn flush(&mut self) -> std::io::Result<()> {
178        trace!("{}:{} Write.flush", file!(), line!());
179        match self.with_context(ContextWaker::Write, |ctx, stream| {
180            trace!("{}:{} Write.with_context flush -> poll_flush", file!(), line!());
181            stream.poll_flush(ctx)
182        }) {
183            Poll::Ready(r) => r,
184            Poll::Pending => Err(std::io::Error::from(std::io::ErrorKind::WouldBlock)),
185        }
186    }
187}
188
189pub(crate) fn cvt<T>(r: Result<T, WsError>) -> Poll<Result<T, WsError>> {
190    match r {
191        Ok(v) => Poll::Ready(Ok(v)),
192        Err(WsError::Io(ref e)) if e.kind() == std::io::ErrorKind::WouldBlock => {
193            trace!("WouldBlock");
194            Poll::Pending
195        }
196        Err(e) => Poll::Ready(Err(e)),
197    }
198}