asynchronous_codec/
framed_write.rs

1use super::fuse::Fuse;
2use super::Encoder;
3use bytes::{Buf, BytesMut};
4use futures_sink::Sink;
5use futures_util::io::{AsyncRead, AsyncWrite};
6use futures_util::ready;
7use pin_project_lite::pin_project;
8use std::io::{Error, ErrorKind};
9use std::marker::Unpin;
10use std::ops::{Deref, DerefMut};
11use std::pin::Pin;
12use std::task::{Context, Poll};
13
pin_project! {
    /// A `Sink` of frames encoded to an `AsyncWrite`.
    ///
    /// Each item sent into the sink is encoded by the `Encoder` into an
    /// internal write buffer, and the buffered bytes are then written to the
    /// underlying `AsyncWrite`. See
    /// [`FramedWrite::set_send_high_water_mark`] for how the amount of
    /// buffered data limits acceptance of new items.
    ///
    /// # Example
    /// ```
    /// use bytes::Bytes;
    /// use asynchronous_codec::{FramedWrite, BytesCodec};
    /// use futures::SinkExt;
    ///
    /// # futures::executor::block_on(async move {
    /// let mut buf = Vec::new();
    /// let mut framed = FramedWrite::new(&mut buf, BytesCodec {});
    ///
    /// let bytes = Bytes::from("Hello World!");
    /// framed.send(bytes.clone()).await?;
    ///
    /// assert_eq!(&buf[..], &bytes[..]);
    /// # Ok::<_, std::io::Error>(())
    /// # }).unwrap();
    /// ```
    #[derive(Debug)]
    pub struct FramedWrite<T, E> {
        // Framing state: the I/O object and encoder fused together plus the
        // write buffer. Pinned because the `Sink` impl polls it through a
        // pinned projection.
        #[pin]
        inner: FramedWrite2<Fuse<T, E>>,
    }
}
40
41impl<T, E> FramedWrite<T, E>
42where
43    T: AsyncWrite,
44    E: Encoder,
45{
46    /// Creates a new `FramedWrite` transport with the given `Encoder`.
47    pub fn new(inner: T, encoder: E) -> Self {
48        Self {
49            inner: framed_write_2(Fuse::new(inner, encoder), None),
50        }
51    }
52
53    /// Creates a new `FramedWrite` from [`FramedWriteParts`].
54    ///
55    /// See also [`FramedWrite::into_parts`].
56    pub fn from_parts(
57        FramedWriteParts {
58            io,
59            encoder,
60            buffer,
61            ..
62        }: FramedWriteParts<T, E>,
63    ) -> Self {
64        Self {
65            inner: framed_write_2(Fuse::new(io, encoder), Some(buffer)),
66        }
67    }
68
69    /// High-water mark for writes, in bytes
70    ///
71    /// The send *high-water mark* prevents the `FramedWrite`
72    /// from accepting additional messages to send when its
73    /// buffer exceeds this length, in bytes. Attempts to enqueue
74    /// additional messages will be deferred until progress is
75    /// made on the underlying `AsyncWrite`. This applies
76    /// back-pressure on fast senders and prevents unbounded
77    /// buffer growth.
78    ///
79    /// See [`set_send_high_water_mark()`](#method.set_send_high_water_mark).
80    pub fn send_high_water_mark(&self) -> usize {
81        self.inner.high_water_mark
82    }
83
84    /// Sets high-water mark for writes, in bytes
85    ///
86    /// The send *high-water mark* prevents the `FramedWrite`
87    /// from accepting additional messages to send when its
88    /// buffer exceeds this length, in bytes. Attempts to enqueue
89    /// additional messages will be deferred until progress is
90    /// made on the underlying `AsyncWrite`. This applies
91    /// back-pressure on fast senders and prevents unbounded
92    /// buffer growth.
93    ///
94    /// The default high-water mark is 2^17 bytes. Applications
95    /// which desire low latency may wish to reduce this value.
96    /// There is little point to increasing this value beyond
97    /// your socket's `SO_SNDBUF` size. On linux, this defaults
98    /// to 212992 bytes but is user-adjustable.
99    pub fn set_send_high_water_mark(&mut self, hwm: usize) {
100        self.inner.high_water_mark = hwm;
101    }
102
103    /// Consumes the `FramedWrite`, returning its parts such that
104    /// a new `FramedWrite` may be constructed, possibly with a different encoder.
105    ///
106    /// See also [`FramedWrite::from_parts`].
107    pub fn into_parts(self) -> FramedWriteParts<T, E> {
108        let (fuse, buffer) = self.inner.into_parts();
109        FramedWriteParts {
110            io: fuse.t,
111            encoder: fuse.u,
112            buffer,
113            _priv: (),
114        }
115    }
116
117    /// Consumes the `FramedWrite`, returning its underlying I/O stream.
118    ///
119    /// Note that data that has already been written but not yet flushed
120    /// is dropped. To retain any such potentially buffered data, use
121    /// [`FramedWrite::into_parts()`].
122    pub fn into_inner(self) -> T {
123        self.into_parts().io
124    }
125
126    /// Returns a reference to the underlying encoder.
127    ///
128    /// Note that care should be taken to not tamper with the underlying encoder
129    /// as it may corrupt the stream of frames otherwise being worked with.
130    pub fn encoder(&self) -> &E {
131        &self.inner.u
132    }
133
134    /// Returns a mutable reference to the underlying encoder.
135    ///
136    /// Note that care should be taken to not tamper with the underlying encoder
137    /// as it may corrupt the stream of frames otherwise being worked with.
138    pub fn encoder_mut(&mut self) -> &mut E {
139        &mut self.inner.u
140    }
141}
142
143impl<T, E> Sink<E::Item> for FramedWrite<T, E>
144where
145    T: AsyncWrite + Unpin,
146    E: Encoder,
147{
148    type Error = E::Error;
149
150    fn poll_ready(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<(), Self::Error>> {
151        self.project().inner.poll_ready(cx)
152    }
153    fn start_send(self: Pin<&mut Self>, item: E::Item) -> Result<(), Self::Error> {
154        self.project().inner.start_send(item)
155    }
156    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<(), Self::Error>> {
157        self.project().inner.poll_flush(cx)
158    }
159    fn poll_close(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<(), Self::Error>> {
160        self.project().inner.poll_close(cx)
161    }
162}
163
164impl<T, E> Deref for FramedWrite<T, E> {
165    type Target = T;
166
167    fn deref(&self) -> &T {
168        &self.inner
169    }
170}
171
172impl<T, E> DerefMut for FramedWrite<T, E> {
173    fn deref_mut(&mut self) -> &mut T {
174        &mut self.inner
175    }
176}
177
pin_project! {
    /// Write-side framing state used by `FramedWrite`.
    ///
    /// It holds a wrapped object (the `Sink` impl requires it to be both an
    /// `AsyncWrite` and an `Encoder`), the send high-water mark, and the
    /// buffer of encoded bytes that have not yet been written.
    #[derive(Debug)]
    pub struct FramedWrite2<T> {
        // The wrapped writer/encoder. Pinned because it is polled through a
        // pinned projection.
        #[pin]
        pub inner: T,
        // Buffered byte count at or above which `poll_ready` writes to
        // `inner` instead of reporting readiness.
        pub high_water_mark: usize,
        // Encoded bytes waiting to be written to `inner`.
        buffer: BytesMut,
    }
}
187
188impl<T> Deref for FramedWrite2<T> {
189    type Target = T;
190
191    fn deref(&self) -> &T {
192        &self.inner
193    }
194}
195
196impl<T> DerefMut for FramedWrite2<T> {
197    fn deref_mut(&mut self) -> &mut T {
198        &mut self.inner
199    }
200}
201
/// Default send high-water mark: 2^17 bytes (128 KiB). That is slightly over
/// 60% of Linux's default TCP send buffer size (`SO_SNDBUF`, 212992 bytes).
const DEFAULT_SEND_HIGH_WATER_MARK: usize = 1 << 17;
205
206pub fn framed_write_2<T>(inner: T, buffer: Option<BytesMut>) -> FramedWrite2<T> {
207    FramedWrite2 {
208        inner,
209        high_water_mark: DEFAULT_SEND_HIGH_WATER_MARK,
210        buffer: buffer.unwrap_or_else(|| BytesMut::with_capacity(1028 * 8)),
211    }
212}
213
214impl<T: AsyncRead + Unpin> AsyncRead for FramedWrite2<T> {
215    fn poll_read(
216        self: Pin<&mut Self>,
217        cx: &mut Context<'_>,
218        buf: &mut [u8],
219    ) -> Poll<Result<usize, Error>> {
220        self.project().inner.poll_read(cx, buf)
221    }
222}
223
224impl<T> Sink<T::Item> for FramedWrite2<T>
225where
226    T: AsyncWrite + Encoder + Unpin,
227{
228    type Error = T::Error;
229
230    fn poll_ready(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<(), Self::Error>> {
231        let this = &mut *self;
232        while this.buffer.len() >= this.high_water_mark {
233            let num_write = ready!(Pin::new(&mut this.inner).poll_write(cx, &this.buffer))?;
234
235            if num_write == 0 {
236                return Poll::Ready(Err(err_eof().into()));
237            }
238
239            this.buffer.advance(num_write);
240        }
241
242        Poll::Ready(Ok(()))
243    }
244    fn start_send(mut self: Pin<&mut Self>, item: T::Item) -> Result<(), Self::Error> {
245        let this = &mut *self;
246        this.inner.encode(item, &mut this.buffer)
247    }
248    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<(), Self::Error>> {
249        let mut this = self.project();
250
251        while !this.buffer.is_empty() {
252            let num_write = ready!(Pin::new(&mut this.inner).poll_write(cx, &this.buffer))?;
253
254            if num_write == 0 {
255                return Poll::Ready(Err(err_eof().into()));
256            }
257
258            this.buffer.advance(num_write);
259        }
260
261        this.inner.poll_flush(cx).map_err(Into::into)
262    }
263    fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<(), Self::Error>> {
264        ready!(self.as_mut().poll_flush(cx))?;
265        self.project().inner.poll_close(cx).map_err(Into::into)
266    }
267}
268
269impl<T> FramedWrite2<T> {
270    pub fn into_parts(self) -> (T, BytesMut) {
271        (self.inner, self.buffer)
272    }
273}
274
/// Builds the error returned when the underlying writer accepts zero bytes.
fn err_eof() -> Error {
    let kind = ErrorKind::UnexpectedEof;
    Error::new(kind, "End of file")
}
278
279/// The parts obtained from [`FramedWrite::into_parts`].
280pub struct FramedWriteParts<T, E> {
281    /// The underlying I/O stream.
282    pub io: T,
283    /// The frame encoder.
284    pub encoder: E,
285    /// The framed data that has been buffered but not yet flushed to `io`.
286    pub buffer: BytesMut,
287    /// Keep the constructor private.
288    _priv: (),
289}
290
291impl<T, E> FramedWriteParts<T, E> {
292    /// Changes the encoder used in `FramedWriteParts`.
293    pub fn map_encoder<G, F>(self, f: F) -> FramedWriteParts<T, G>
294    where
295        G: Encoder,
296        F: FnOnce(E) -> G,
297    {
298        FramedWriteParts {
299            io: self.io,
300            encoder: f(self.encoder),
301            buffer: self.buffer,
302            _priv: (),
303        }
304    }
305}