async_compression/tokio/write/
buf_writer.rs
1use super::AsyncBufWrite;
6use futures_core::ready;
7use pin_project_lite::pin_project;
8use std::{
9 cmp::min,
10 fmt, io,
11 pin::Pin,
12 task::{Context, Poll},
13};
14use tokio::io::AsyncWrite;
15
/// Capacity, in bytes, of the internal buffer allocated by `BufWriter::new`.
const DEFAULT_BUF_SIZE: usize = 8 * 1024;
17
18pin_project! {
19 pub struct BufWriter<W> {
20 #[pin]
21 inner: W,
22 buf: Box<[u8]>,
23 written: usize,
24 buffered: usize,
25 }
26}
27
28impl<W: AsyncWrite> BufWriter<W> {
29 pub fn new(inner: W) -> Self {
32 Self::with_capacity(DEFAULT_BUF_SIZE, inner)
33 }
34
35 pub fn with_capacity(cap: usize, inner: W) -> Self {
37 Self {
38 inner,
39 buf: vec![0; cap].into(),
40 written: 0,
41 buffered: 0,
42 }
43 }
44
45 fn partial_flush_buf(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
46 let mut this = self.project();
47
48 let mut ret = Ok(());
49 while *this.written < *this.buffered {
50 match this
51 .inner
52 .as_mut()
53 .poll_write(cx, &this.buf[*this.written..*this.buffered])
54 {
55 Poll::Pending => {
56 break;
57 }
58 Poll::Ready(Ok(0)) => {
59 ret = Err(io::Error::new(
60 io::ErrorKind::WriteZero,
61 "failed to write the buffered data",
62 ));
63 break;
64 }
65 Poll::Ready(Ok(n)) => *this.written += n,
66 Poll::Ready(Err(e)) => {
67 ret = Err(e);
68 break;
69 }
70 }
71 }
72
73 if *this.written > 0 {
74 this.buf.copy_within(*this.written..*this.buffered, 0);
75 *this.buffered -= *this.written;
76 *this.written = 0;
77
78 Poll::Ready(ret)
79 } else if *this.buffered == 0 {
80 Poll::Ready(ret)
81 } else {
82 ret?;
83 Poll::Pending
84 }
85 }
86
87 fn flush_buf(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
88 let mut this = self.project();
89
90 let mut ret = Ok(());
91 while *this.written < *this.buffered {
92 match ready!(this
93 .inner
94 .as_mut()
95 .poll_write(cx, &this.buf[*this.written..*this.buffered]))
96 {
97 Ok(0) => {
98 ret = Err(io::Error::new(
99 io::ErrorKind::WriteZero,
100 "failed to write the buffered data",
101 ));
102 break;
103 }
104 Ok(n) => *this.written += n,
105 Err(e) => {
106 ret = Err(e);
107 break;
108 }
109 }
110 }
111 this.buf.copy_within(*this.written..*this.buffered, 0);
112 *this.buffered -= *this.written;
113 *this.written = 0;
114 Poll::Ready(ret)
115 }
116
117 pub fn get_ref(&self) -> &W {
119 &self.inner
120 }
121
122 pub fn get_mut(&mut self) -> &mut W {
126 &mut self.inner
127 }
128
129 pub fn get_pin_mut(self: Pin<&mut Self>) -> Pin<&mut W> {
133 self.project().inner
134 }
135
136 pub fn into_inner(self) -> W {
140 self.inner
141 }
142}
143
144impl<W: AsyncWrite> AsyncWrite for BufWriter<W> {
145 fn poll_write(
146 mut self: Pin<&mut Self>,
147 cx: &mut Context<'_>,
148 buf: &[u8],
149 ) -> Poll<io::Result<usize>> {
150 let this = self.as_mut().project();
151 if *this.buffered + buf.len() > this.buf.len() {
152 ready!(self.as_mut().partial_flush_buf(cx))?;
153 }
154
155 let this = self.as_mut().project();
156 if buf.len() >= this.buf.len() {
157 if *this.buffered == 0 {
158 this.inner.poll_write(cx, buf)
159 } else {
160 Poll::Pending
163 }
164 } else {
165 let len = min(this.buf.len() - *this.buffered, buf.len());
166 this.buf[*this.buffered..*this.buffered + len].copy_from_slice(&buf[..len]);
167 *this.buffered += len;
168 Poll::Ready(Ok(len))
169 }
170 }
171
172 fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
173 ready!(self.as_mut().flush_buf(cx))?;
174 self.project().inner.poll_flush(cx)
175 }
176
177 fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
178 ready!(self.as_mut().flush_buf(cx))?;
179 self.project().inner.poll_shutdown(cx)
180 }
181}
182
183impl<W: AsyncWrite> AsyncBufWrite for BufWriter<W> {
184 fn poll_partial_flush_buf(
185 mut self: Pin<&mut Self>,
186 cx: &mut Context<'_>,
187 ) -> Poll<io::Result<&mut [u8]>> {
188 ready!(self.as_mut().partial_flush_buf(cx))?;
189 let this = self.project();
190 Poll::Ready(Ok(&mut this.buf[*this.buffered..]))
191 }
192
193 fn produce(self: Pin<&mut Self>, amt: usize) {
194 *self.project().buffered += amt;
195 }
196}
197
198impl<W: fmt::Debug> fmt::Debug for BufWriter<W> {
199 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
200 f.debug_struct("BufWriter")
201 .field("writer", &self.inner)
202 .field(
203 "buffer",
204 &format_args!("{}/{}", self.buffered, self.buf.len()),
205 )
206 .field("written", &self.written)
207 .finish()
208 }
209}