async_compression/tokio/write/generic/
encoder.rs1use std::{
2 io,
3 pin::Pin,
4 task::{Context, Poll},
5};
6
7use crate::{
8 codec::Encode,
9 tokio::write::{AsyncBufWrite, BufWriter},
10 util::PartialBuffer,
11};
12use futures_core::ready;
13use pin_project_lite::pin_project;
14use tokio::io::{AsyncBufRead, AsyncRead, AsyncWrite, ReadBuf};
15
/// Lifecycle of an `Encoder`.
///
/// Writes are only encoded in `Encoding`; `Flushing` records a flush that was
/// started but not yet completed; `Finishing` and `Done` reject further writes
/// and flushes.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum State {
    /// Normal operation: input is fed to the encoder.
    Encoding,
    /// `flush` has been called on the encoder but has not yet reported completion.
    Flushing,
    /// `finish` has been called on the encoder but has not yet reported completion.
    Finishing,
    /// `finish` reported completion; the encoded stream is complete.
    Done,
}
23
24pin_project! {
25 #[derive(Debug)]
26 pub struct Encoder<W, E> {
27 #[pin]
28 writer: BufWriter<W>,
29 encoder: E,
30 state: State,
31 }
32}
33
34impl<W: AsyncWrite, E: Encode> Encoder<W, E> {
35 pub fn new(writer: W, encoder: E) -> Self {
36 Self {
37 writer: BufWriter::new(writer),
38 encoder,
39 state: State::Encoding,
40 }
41 }
42
43 pub fn with_capacity(writer: W, encoder: E, cap: usize) -> Self {
44 Self {
45 writer: BufWriter::with_capacity(cap, writer),
46 encoder,
47 state: State::Encoding,
48 }
49 }
50}
51
52impl<W, E> Encoder<W, E> {
53 pub fn get_ref(&self) -> &W {
54 self.writer.get_ref()
55 }
56
57 pub fn get_mut(&mut self) -> &mut W {
58 self.writer.get_mut()
59 }
60
61 pub fn get_pin_mut(self: Pin<&mut Self>) -> Pin<&mut W> {
62 self.project().writer.get_pin_mut()
63 }
64
65 pub(crate) fn get_encoder_ref(&self) -> &E {
66 &self.encoder
67 }
68
69 pub fn into_inner(self) -> W {
70 self.writer.into_inner()
71 }
72}
73
74impl<W: AsyncWrite, E: Encode> Encoder<W, E> {
75 fn do_poll_write(
76 self: Pin<&mut Self>,
77 cx: &mut Context<'_>,
78 input: &mut PartialBuffer<&[u8]>,
79 ) -> Poll<io::Result<()>> {
80 let mut this = self.project();
81
82 loop {
83 let output = ready!(this.writer.as_mut().poll_partial_flush_buf(cx))?;
84 let mut output = PartialBuffer::new(output);
85
86 *this.state = match this.state {
87 State::Encoding => {
88 this.encoder.encode(input, &mut output)?;
89 State::Encoding
90 }
91
92 State::Flushing => match this.encoder.flush(&mut output)? {
94 true => State::Encoding,
95 false => State::Flushing,
96 },
97
98 State::Finishing | State::Done => {
99 return Poll::Ready(Err(io::Error::other("Write after shutdown")))
100 }
101 };
102
103 let produced = output.written().len();
104 this.writer.as_mut().produce(produced);
105
106 if input.unwritten().is_empty() {
107 return Poll::Ready(Ok(()));
108 }
109 }
110 }
111
112 fn do_poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
113 let mut this = self.project();
114
115 loop {
116 let output = ready!(this.writer.as_mut().poll_partial_flush_buf(cx))?;
117 let mut output = PartialBuffer::new(output);
118
119 let done = match this.state {
120 State::Encoding | State::Flushing => this.encoder.flush(&mut output)?,
121
122 State::Finishing | State::Done => {
123 return Poll::Ready(Err(io::Error::other("Flush after shutdown")))
124 }
125 };
126 *this.state = State::Flushing;
127
128 let produced = output.written().len();
129 this.writer.as_mut().produce(produced);
130
131 if done {
132 *this.state = State::Encoding;
133 return Poll::Ready(Ok(()));
134 }
135 }
136 }
137
138 fn do_poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
139 let mut this = self.project();
140
141 loop {
142 let output = ready!(this.writer.as_mut().poll_partial_flush_buf(cx))?;
143 let mut output = PartialBuffer::new(output);
144
145 *this.state = match this.state {
146 State::Encoding | State::Finishing => {
147 if this.encoder.finish(&mut output)? {
148 State::Done
149 } else {
150 State::Finishing
151 }
152 }
153
154 State::Flushing => match this.encoder.flush(&mut output)? {
156 true => State::Finishing,
157 false => State::Flushing,
158 },
159
160 State::Done => State::Done,
161 };
162
163 let produced = output.written().len();
164 this.writer.as_mut().produce(produced);
165
166 if let State::Done = this.state {
167 return Poll::Ready(Ok(()));
168 }
169 }
170 }
171}
172
173impl<W: AsyncWrite, E: Encode> AsyncWrite for Encoder<W, E> {
174 fn poll_write(
175 self: Pin<&mut Self>,
176 cx: &mut Context<'_>,
177 buf: &[u8],
178 ) -> Poll<io::Result<usize>> {
179 if buf.is_empty() {
180 return Poll::Ready(Ok(0));
181 }
182
183 let mut input = PartialBuffer::new(buf);
184
185 match self.do_poll_write(cx, &mut input)? {
186 Poll::Pending if input.written().is_empty() => Poll::Pending,
187 _ => Poll::Ready(Ok(input.written().len())),
188 }
189 }
190
191 fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
192 ready!(self.as_mut().do_poll_flush(cx))?;
193 ready!(self.project().writer.as_mut().poll_flush(cx))?;
194 Poll::Ready(Ok(()))
195 }
196
197 fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
198 ready!(self.as_mut().do_poll_shutdown(cx))?;
199 ready!(self.project().writer.as_mut().poll_shutdown(cx))?;
200 Poll::Ready(Ok(()))
201 }
202}
203
204impl<W: AsyncRead, E> AsyncRead for Encoder<W, E> {
205 fn poll_read(
206 self: Pin<&mut Self>,
207 cx: &mut Context<'_>,
208 buf: &mut ReadBuf<'_>,
209 ) -> Poll<io::Result<()>> {
210 self.get_pin_mut().poll_read(cx, buf)
211 }
212}
213
214impl<W: AsyncBufRead, E> AsyncBufRead for Encoder<W, E> {
215 fn poll_fill_buf(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<&[u8]>> {
216 self.get_pin_mut().poll_fill_buf(cx)
217 }
218
219 fn consume(self: Pin<&mut Self>, amt: usize) {
220 self.get_pin_mut().consume(amt)
221 }
222}