async_compression/tokio/write/generic/encoder.rs

1use std::{
2    io,
3    pin::Pin,
4    task::{Context, Poll},
5};
6
7use crate::{
8    codec::Encode,
9    tokio::write::{AsyncBufWrite, BufWriter},
10    util::PartialBuffer,
11};
12use futures_core::ready;
13use pin_project_lite::pin_project;
14use tokio::io::{AsyncBufRead, AsyncRead, AsyncWrite, ReadBuf};
15
/// Lifecycle of a generic `Encoder`, tracking which encoder operation is in progress.
#[derive(Debug)]
enum State {
    /// Normal operation: written input is handed to the encoder.
    Encoding,
    /// A flush of the encoder was started but has not completed yet; it must be driven to
    /// completion before any further input is encoded.
    Flushing,
    /// Shutdown is in progress: the encoder is being finished but has not yet emitted all of
    /// its trailing output.
    Finishing,
    /// The stream has been completely finished; writes and flushes are rejected.
    Done,
}
23
24pin_project! {
25    #[derive(Debug)]
26    pub struct Encoder<W, E> {
27        #[pin]
28        writer: BufWriter<W>,
29        encoder: E,
30        state: State,
31    }
32}
33
34impl<W: AsyncWrite, E: Encode> Encoder<W, E> {
35    pub fn new(writer: W, encoder: E) -> Self {
36        Self {
37            writer: BufWriter::new(writer),
38            encoder,
39            state: State::Encoding,
40        }
41    }
42
43    pub fn with_capacity(writer: W, encoder: E, cap: usize) -> Self {
44        Self {
45            writer: BufWriter::with_capacity(cap, writer),
46            encoder,
47            state: State::Encoding,
48        }
49    }
50}
51
52impl<W, E> Encoder<W, E> {
53    pub fn get_ref(&self) -> &W {
54        self.writer.get_ref()
55    }
56
57    pub fn get_mut(&mut self) -> &mut W {
58        self.writer.get_mut()
59    }
60
61    pub fn get_pin_mut(self: Pin<&mut Self>) -> Pin<&mut W> {
62        self.project().writer.get_pin_mut()
63    }
64
65    pub(crate) fn get_encoder_ref(&self) -> &E {
66        &self.encoder
67    }
68
69    pub fn into_inner(self) -> W {
70        self.writer.into_inner()
71    }
72}
73
74impl<W: AsyncWrite, E: Encode> Encoder<W, E> {
75    fn do_poll_write(
76        self: Pin<&mut Self>,
77        cx: &mut Context<'_>,
78        input: &mut PartialBuffer<&[u8]>,
79    ) -> Poll<io::Result<()>> {
80        let mut this = self.project();
81
82        loop {
83            let output = ready!(this.writer.as_mut().poll_partial_flush_buf(cx))?;
84            let mut output = PartialBuffer::new(output);
85
86            *this.state = match this.state {
87                State::Encoding => {
88                    this.encoder.encode(input, &mut output)?;
89                    State::Encoding
90                }
91
92                // Once a flush has been started, it must be completed.
93                State::Flushing => match this.encoder.flush(&mut output)? {
94                    true => State::Encoding,
95                    false => State::Flushing,
96                },
97
98                State::Finishing | State::Done => {
99                    return Poll::Ready(Err(io::Error::other("Write after shutdown")))
100                }
101            };
102
103            let produced = output.written().len();
104            this.writer.as_mut().produce(produced);
105
106            if input.unwritten().is_empty() {
107                return Poll::Ready(Ok(()));
108            }
109        }
110    }
111
112    fn do_poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
113        let mut this = self.project();
114
115        loop {
116            let output = ready!(this.writer.as_mut().poll_partial_flush_buf(cx))?;
117            let mut output = PartialBuffer::new(output);
118
119            let done = match this.state {
120                State::Encoding | State::Flushing => this.encoder.flush(&mut output)?,
121
122                State::Finishing | State::Done => {
123                    return Poll::Ready(Err(io::Error::other("Flush after shutdown")))
124                }
125            };
126            *this.state = State::Flushing;
127
128            let produced = output.written().len();
129            this.writer.as_mut().produce(produced);
130
131            if done {
132                *this.state = State::Encoding;
133                return Poll::Ready(Ok(()));
134            }
135        }
136    }
137
138    fn do_poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
139        let mut this = self.project();
140
141        loop {
142            let output = ready!(this.writer.as_mut().poll_partial_flush_buf(cx))?;
143            let mut output = PartialBuffer::new(output);
144
145            *this.state = match this.state {
146                State::Encoding | State::Finishing => {
147                    if this.encoder.finish(&mut output)? {
148                        State::Done
149                    } else {
150                        State::Finishing
151                    }
152                }
153
154                // Once a flush has been started, it must be completed.
155                State::Flushing => match this.encoder.flush(&mut output)? {
156                    true => State::Finishing,
157                    false => State::Flushing,
158                },
159
160                State::Done => State::Done,
161            };
162
163            let produced = output.written().len();
164            this.writer.as_mut().produce(produced);
165
166            if let State::Done = this.state {
167                return Poll::Ready(Ok(()));
168            }
169        }
170    }
171}
172
173impl<W: AsyncWrite, E: Encode> AsyncWrite for Encoder<W, E> {
174    fn poll_write(
175        self: Pin<&mut Self>,
176        cx: &mut Context<'_>,
177        buf: &[u8],
178    ) -> Poll<io::Result<usize>> {
179        if buf.is_empty() {
180            return Poll::Ready(Ok(0));
181        }
182
183        let mut input = PartialBuffer::new(buf);
184
185        match self.do_poll_write(cx, &mut input)? {
186            Poll::Pending if input.written().is_empty() => Poll::Pending,
187            _ => Poll::Ready(Ok(input.written().len())),
188        }
189    }
190
191    fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
192        ready!(self.as_mut().do_poll_flush(cx))?;
193        ready!(self.project().writer.as_mut().poll_flush(cx))?;
194        Poll::Ready(Ok(()))
195    }
196
197    fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
198        ready!(self.as_mut().do_poll_shutdown(cx))?;
199        ready!(self.project().writer.as_mut().poll_shutdown(cx))?;
200        Poll::Ready(Ok(()))
201    }
202}
203
204impl<W: AsyncRead, E> AsyncRead for Encoder<W, E> {
205    fn poll_read(
206        self: Pin<&mut Self>,
207        cx: &mut Context<'_>,
208        buf: &mut ReadBuf<'_>,
209    ) -> Poll<io::Result<()>> {
210        self.get_pin_mut().poll_read(cx, buf)
211    }
212}
213
214impl<W: AsyncBufRead, E> AsyncBufRead for Encoder<W, E> {
215    fn poll_fill_buf(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<&[u8]>> {
216        self.get_pin_mut().poll_fill_buf(cx)
217    }
218
219    fn consume(self: Pin<&mut Self>, amt: usize) {
220        self.get_pin_mut().consume(amt)
221    }
222}