// async_compression/tokio/write/buf_writer.rs

// Originally sourced from `futures_util::io::buf_writer`. It needs to be redefined locally so
// that the `AsyncBufWrite` impl can access its internals, and it has been changed a bit to make
// it more efficient when used through those methods.

5use super::AsyncBufWrite;
6use futures_core::ready;
7use pin_project_lite::pin_project;
8use std::{
9    cmp::min,
10    fmt, io,
11    pin::Pin,
12    task::{Context, Poll},
13};
14use tokio::io::AsyncWrite;
15
/// Capacity, in bytes, of the internal buffer allocated by `BufWriter::new` (8 KiB).
const DEFAULT_BUF_SIZE: usize = 8 * 1024;

18pin_project! {
19    pub struct BufWriter<W> {
20        #[pin]
21        inner: W,
22        buf: Box<[u8]>,
23        written: usize,
24        buffered: usize,
25    }
26}
27
28impl<W: AsyncWrite> BufWriter<W> {
29    /// Creates a new `BufWriter` with a default buffer capacity. The default is currently 8 KB,
30    /// but may change in the future.
31    pub fn new(inner: W) -> Self {
32        Self::with_capacity(DEFAULT_BUF_SIZE, inner)
33    }
34
35    /// Creates a new `BufWriter` with the specified buffer capacity.
36    pub fn with_capacity(cap: usize, inner: W) -> Self {
37        Self {
38            inner,
39            buf: vec![0; cap].into(),
40            written: 0,
41            buffered: 0,
42        }
43    }
44
45    fn partial_flush_buf(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
46        let mut this = self.project();
47
48        let mut ret = Ok(());
49        while *this.written < *this.buffered {
50            match this
51                .inner
52                .as_mut()
53                .poll_write(cx, &this.buf[*this.written..*this.buffered])
54            {
55                Poll::Pending => {
56                    break;
57                }
58                Poll::Ready(Ok(0)) => {
59                    ret = Err(io::Error::new(
60                        io::ErrorKind::WriteZero,
61                        "failed to write the buffered data",
62                    ));
63                    break;
64                }
65                Poll::Ready(Ok(n)) => *this.written += n,
66                Poll::Ready(Err(e)) => {
67                    ret = Err(e);
68                    break;
69                }
70            }
71        }
72
73        if *this.written > 0 {
74            this.buf.copy_within(*this.written..*this.buffered, 0);
75            *this.buffered -= *this.written;
76            *this.written = 0;
77
78            Poll::Ready(ret)
79        } else if *this.buffered == 0 {
80            Poll::Ready(ret)
81        } else {
82            ret?;
83            Poll::Pending
84        }
85    }
86
87    fn flush_buf(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
88        let mut this = self.project();
89
90        let mut ret = Ok(());
91        while *this.written < *this.buffered {
92            match ready!(this
93                .inner
94                .as_mut()
95                .poll_write(cx, &this.buf[*this.written..*this.buffered]))
96            {
97                Ok(0) => {
98                    ret = Err(io::Error::new(
99                        io::ErrorKind::WriteZero,
100                        "failed to write the buffered data",
101                    ));
102                    break;
103                }
104                Ok(n) => *this.written += n,
105                Err(e) => {
106                    ret = Err(e);
107                    break;
108                }
109            }
110        }
111        this.buf.copy_within(*this.written..*this.buffered, 0);
112        *this.buffered -= *this.written;
113        *this.written = 0;
114        Poll::Ready(ret)
115    }
116
117    /// Gets a reference to the underlying writer.
118    pub fn get_ref(&self) -> &W {
119        &self.inner
120    }
121
122    /// Gets a mutable reference to the underlying writer.
123    ///
124    /// It is inadvisable to directly write to the underlying writer.
125    pub fn get_mut(&mut self) -> &mut W {
126        &mut self.inner
127    }
128
129    /// Gets a pinned mutable reference to the underlying writer.
130    ///
131    /// It is inadvisable to directly write to the underlying writer.
132    pub fn get_pin_mut(self: Pin<&mut Self>) -> Pin<&mut W> {
133        self.project().inner
134    }
135
136    /// Consumes this `BufWriter`, returning the underlying writer.
137    ///
138    /// Note that any leftover data in the internal buffer is lost.
139    pub fn into_inner(self) -> W {
140        self.inner
141    }
142}
143
144impl<W: AsyncWrite> AsyncWrite for BufWriter<W> {
145    fn poll_write(
146        mut self: Pin<&mut Self>,
147        cx: &mut Context<'_>,
148        buf: &[u8],
149    ) -> Poll<io::Result<usize>> {
150        let this = self.as_mut().project();
151        if *this.buffered + buf.len() > this.buf.len() {
152            ready!(self.as_mut().partial_flush_buf(cx))?;
153        }
154
155        let this = self.as_mut().project();
156        if buf.len() >= this.buf.len() {
157            if *this.buffered == 0 {
158                this.inner.poll_write(cx, buf)
159            } else {
160                // The only way that `partial_flush_buf` would have returned with
161                // `this.buffered != 0` is if it were Pending, so our waker was already queued
162                Poll::Pending
163            }
164        } else {
165            let len = min(this.buf.len() - *this.buffered, buf.len());
166            this.buf[*this.buffered..*this.buffered + len].copy_from_slice(&buf[..len]);
167            *this.buffered += len;
168            Poll::Ready(Ok(len))
169        }
170    }
171
172    fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
173        ready!(self.as_mut().flush_buf(cx))?;
174        self.project().inner.poll_flush(cx)
175    }
176
177    fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
178        ready!(self.as_mut().flush_buf(cx))?;
179        self.project().inner.poll_shutdown(cx)
180    }
181}
182
183impl<W: AsyncWrite> AsyncBufWrite for BufWriter<W> {
184    fn poll_partial_flush_buf(
185        mut self: Pin<&mut Self>,
186        cx: &mut Context<'_>,
187    ) -> Poll<io::Result<&mut [u8]>> {
188        ready!(self.as_mut().partial_flush_buf(cx))?;
189        let this = self.project();
190        Poll::Ready(Ok(&mut this.buf[*this.buffered..]))
191    }
192
193    fn produce(self: Pin<&mut Self>, amt: usize) {
194        *self.project().buffered += amt;
195    }
196}
197
198impl<W: fmt::Debug> fmt::Debug for BufWriter<W> {
199    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
200        f.debug_struct("BufWriter")
201            .field("writer", &self.inner)
202            .field(
203                "buffer",
204                &format_args!("{}/{}", self.buffered, self.buf.len()),
205            )
206            .field("written", &self.written)
207            .finish()
208    }
209}