zstd/stream/zio/
writer.rs

1use std::io::{self, Write};
2
3use crate::stream::raw::{InBuffer, Operation, OutBuffer};
4
// Data flow: input -> [ zstd operation -> internal buffer -> wrapped writer ]
6
/// A [`Write`] adapter built around an [`Operation`].
///
/// Wraps a raw in-memory operation so that it can be driven through a
/// write-oriented API.
///
/// The operation may be a compressor or a decompressor; either way, whatever
/// it produces is forwarded to the wrapped `Write`.
pub struct Writer<W, D> {
    /// The operation being driven: either an encoder or a decoder.
    operation: D,

    /// Destination for the output of the operation.
    writer: W,

    /// Position in `buffer` of the first byte not yet sent to `writer`.
    ///
    /// Everything before this offset has already been forwarded; only the
    /// bytes after it still matter.
    offset: usize,

    /// Output buffer.
    ///
    /// The operation writes here first; the content is then flushed to
    /// `writer`.
    buffer: Vec<u8>,

    /// When `true`, nothing more should be added to the buffer.
    ///
    /// All that is left to do is to empty the buffer.
    finished: bool,

    /// When `true`, the operation has just completed a frame.
    ///
    /// Only happens when decompressing. The context must be re-initialized
    /// before the next frame can be processed.
    finished_frame: bool,
}
40
41impl<W, D> Writer<W, D>
42where
43    W: Write,
44    D: Operation,
45{
46    /// Creates a new `Writer` with a fixed buffer capacity of 32KB
47    ///
48    /// All output from the given operation will be forwarded to `writer`.
49    pub fn new(writer: W, operation: D) -> Self {
50        // 32KB buffer? That's what flate2 uses
51        Self::new_with_capacity(writer, operation, 32 * 1024)
52    }
53
54    /// Creates a new `Writer` with user defined capacity.
55    ///
56    /// All output from the given operation will be forwarded to `writer`.
57    pub fn new_with_capacity(
58        writer: W,
59        operation: D,
60        capacity: usize,
61    ) -> Self {
62        Self::with_output_buffer(
63            Vec::with_capacity(capacity),
64            writer,
65            operation,
66        )
67    }
68
69    /// Creates a new `Writer` using the given output buffer.
70    ///
71    /// The output buffer _must_ have pre-allocated capacity (its capacity will not be changed after).
72    ///
73    /// Usually you would use `Vec::with_capacity(desired_buffer_size)`.
74    pub fn with_output_buffer(
75        output_buffer: Vec<u8>,
76        writer: W,
77        operation: D,
78    ) -> Self {
79        Writer {
80            writer,
81            operation,
82
83            offset: 0,
84            // 32KB buffer? That's what flate2 uses
85            buffer: output_buffer,
86
87            finished: false,
88            finished_frame: false,
89        }
90    }
91
92    /// Ends the stream.
93    ///
94    /// This *must* be called after all data has been written to finish the
95    /// stream.
96    ///
97    /// If you forget to call this and just drop the `Writer`, you *will* have
98    /// an incomplete output.
99    ///
100    /// Keep calling it until it returns `Ok(())`, then don't call it again.
101    pub fn finish(&mut self) -> io::Result<()> {
102        loop {
103            // Keep trying until we're really done.
104            self.write_from_offset()?;
105
106            // At this point the buffer has been fully written out.
107
108            if self.finished {
109                return Ok(());
110            }
111
112            // Let's fill this buffer again!
113
114            let finished_frame = self.finished_frame;
115            let hint =
116                self.with_buffer(|dst, op| op.finish(dst, finished_frame));
117            self.offset = 0;
118            // println!("Hint: {:?}\nOut:{:?}", hint, &self.buffer);
119
120            // We return here if zstd had a problem.
121            // Could happen with invalid data, ...
122            let hint = hint?;
123
124            if hint != 0 && self.buffer.is_empty() {
125                // This happens if we are decoding an incomplete frame.
126                return Err(io::Error::new(
127                    io::ErrorKind::UnexpectedEof,
128                    "incomplete frame",
129                ));
130            }
131
132            // println!("Finishing {}, {}", bytes_written, hint);
133
134            self.finished = hint == 0;
135        }
136    }
137
138    /// Run the given closure on `self.buffer`.
139    ///
140    /// The buffer will be cleared, and made available wrapped in an `OutBuffer`.
141    fn with_buffer<F, T>(&mut self, f: F) -> T
142    where
143        F: FnOnce(&mut OutBuffer<'_, Vec<u8>>, &mut D) -> T,
144    {
145        self.buffer.clear();
146        let mut output = OutBuffer::around(&mut self.buffer);
147        // eprintln!("Output: {:?}", output);
148        f(&mut output, &mut self.operation)
149    }
150
151    /// Attempt to write `self.buffer` to the wrapped writer.
152    ///
153    /// Returns `Ok(())` once all the buffer has been written.
154    fn write_from_offset(&mut self) -> io::Result<()> {
155        // The code looks a lot like `write_all`, but keeps track of what has
156        // been written in case we're interrupted.
157        while self.offset < self.buffer.len() {
158            match self.writer.write(&self.buffer[self.offset..]) {
159                Ok(0) => {
160                    return Err(io::Error::new(
161                        io::ErrorKind::WriteZero,
162                        "writer will not accept any more data",
163                    ))
164                }
165                Ok(n) => self.offset += n,
166                Err(ref e) if e.kind() == io::ErrorKind::Interrupted => (),
167                Err(e) => return Err(e),
168            }
169        }
170        Ok(())
171    }
172
173    /// Return the wrapped `Writer` and `Operation`.
174    ///
175    /// Careful: if you call this before calling [`Writer::finish()`], the
176    /// output may be incomplete.
177    pub fn into_inner(self) -> (W, D) {
178        (self.writer, self.operation)
179    }
180
181    /// Gives a reference to the inner writer.
182    pub fn writer(&self) -> &W {
183        &self.writer
184    }
185
186    /// Gives a mutable reference to the inner writer.
187    pub fn writer_mut(&mut self) -> &mut W {
188        &mut self.writer
189    }
190
191    /// Gives a reference to the inner operation.
192    pub fn operation(&self) -> &D {
193        &self.operation
194    }
195
196    /// Gives a mutable reference to the inner operation.
197    pub fn operation_mut(&mut self) -> &mut D {
198        &mut self.operation
199    }
200
201    /// Returns the offset in the current buffer. Only useful for debugging.
202    #[cfg(test)]
203    pub fn offset(&self) -> usize {
204        self.offset
205    }
206
207    /// Returns the current buffer. Only useful for debugging.
208    #[cfg(test)]
209    pub fn buffer(&self) -> &[u8] {
210        &self.buffer
211    }
212}
213
214impl<W, D> Write for Writer<W, D>
215where
216    W: Write,
217    D: Operation,
218{
219    fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
220        if self.finished {
221            return Err(io::Error::new(
222                io::ErrorKind::Other,
223                "encoder is finished",
224            ));
225        }
226        // Keep trying until _something_ has been consumed.
227        // As soon as some input has been taken, we cannot afford
228        // to take any chance: if an error occurs, the user couldn't know
229        // that some data _was_ successfully written.
230        loop {
231            // First, write any pending data from `self.buffer`.
232            self.write_from_offset()?;
233            // At this point `self.buffer` can safely be discarded.
234
235            // Support writing concatenated frames by re-initializing the
236            // context.
237            if self.finished_frame {
238                self.operation.reinit()?;
239                self.finished_frame = false;
240            }
241
242            let mut src = InBuffer::around(buf);
243            let hint = self.with_buffer(|dst, op| op.run(&mut src, dst));
244            let bytes_read = src.pos;
245
246            // eprintln!(
247            //     "Write Hint: {:?}\n src: {:?}\n dst: {:?}",
248            //     hint, src, self.buffer
249            // );
250
251            self.offset = 0;
252            let hint = hint?;
253
254            if hint == 0 {
255                self.finished_frame = true;
256            }
257
258            // As we said, as soon as we've consumed something, return.
259            if bytes_read > 0 || buf.is_empty() {
260                // println!("Returning {}", bytes_read);
261                return Ok(bytes_read);
262            }
263        }
264    }
265
266    fn flush(&mut self) -> io::Result<()> {
267        let mut finished = self.finished;
268        loop {
269            // If the output is blocked or has an error, return now.
270            self.write_from_offset()?;
271
272            if finished {
273                break;
274            }
275
276            let hint = self.with_buffer(|dst, op| op.flush(dst));
277
278            self.offset = 0;
279            let hint = hint?;
280
281            finished = hint == 0;
282        }
283
284        self.writer.flush()
285    }
286}
287
#[cfg(test)]
mod tests {
    use super::Writer;
    use std::io::Write;

    /// Payload pushed through the writer by every test in this module.
    const SAMPLE: &[u8] = b"AbcdefghAbcdefgh.";

    /// A no-op operation must forward the input unchanged.
    #[test]
    fn test_noop() {
        use crate::stream::raw::NoOp;

        let mut sink: Vec<u8> = Vec::new();
        {
            let mut passthrough = Writer::new(&mut sink, NoOp);
            passthrough.write_all(SAMPLE).unwrap();
            passthrough.finish().unwrap();
        }
        assert_eq!(sink.as_slice(), SAMPLE);
    }

    /// Data compressed through the writer must decode back to the input.
    #[test]
    fn test_compress() {
        use crate::stream::raw::Encoder;

        let mut sink: Vec<u8> = Vec::new();
        {
            let encoder = Encoder::new(1).unwrap();
            let mut compressor = Writer::new(&mut sink, encoder);
            compressor.write_all(SAMPLE).unwrap();
            compressor.finish().unwrap();
        }
        let roundtrip = crate::decode_all(sink.as_slice()).unwrap();
        assert_eq!(roundtrip.as_slice(), SAMPLE);
    }

    /// Same as `test_compress`, with an explicit 64-byte output buffer.
    #[test]
    fn test_compress_with_capacity() {
        use crate::stream::raw::Encoder;

        let mut sink: Vec<u8> = Vec::new();
        {
            let encoder = Encoder::new(1).unwrap();
            let mut compressor =
                Writer::new_with_capacity(&mut sink, encoder, 64);
            assert_eq!(compressor.buffer.capacity(), 64);
            compressor.write_all(SAMPLE).unwrap();
            compressor.finish().unwrap();
        }
        let roundtrip = crate::decode_all(sink.as_slice()).unwrap();
        assert_eq!(roundtrip.as_slice(), SAMPLE);
    }

    /// Compressed data written to a decoding writer must yield the input.
    #[test]
    fn test_decompress() {
        use crate::stream::raw::Decoder;

        let packed = crate::encode_all(SAMPLE, 1).unwrap();

        let mut sink: Vec<u8> = Vec::new();
        {
            let decoder = Decoder::new().unwrap();
            let mut decompressor = Writer::new(&mut sink, decoder);
            decompressor.write_all(&packed).unwrap();
            decompressor.finish().unwrap();
        }
        assert_eq!(sink.as_slice(), SAMPLE);
    }

    /// Same as `test_decompress`, with an explicit 64-byte output buffer.
    #[test]
    fn test_decompress_with_capacity() {
        use crate::stream::raw::Decoder;

        let packed = crate::encode_all(SAMPLE, 1).unwrap();

        let mut sink: Vec<u8> = Vec::new();
        {
            let decoder = Decoder::new().unwrap();
            let mut decompressor =
                Writer::new_with_capacity(&mut sink, decoder, 64);
            assert_eq!(decompressor.buffer.capacity(), 64);
            decompressor.write_all(&packed).unwrap();
            decompressor.finish().unwrap();
        }
        assert_eq!(sink.as_slice(), SAMPLE);
    }
}