zstd/stream/
raw.rs

1//! Raw in-memory stream compression/decompression.
2//!
3//! This module defines a `Decoder` and an `Encoder` to decode/encode streams
4//! of data using buffers.
5//!
6//! They are mostly thin wrappers around `zstd_safe::{DCtx, CCtx}`.
7use std::io;
8
9pub use zstd_safe::{CParameter, DParameter, InBuffer, OutBuffer, WriteBuf};
10
11use crate::dict::{DecoderDictionary, EncoderDictionary};
12use crate::map_error_code;
13
14/// Represents an abstract compression/decompression operation.
15///
16/// This trait covers both `Encoder` and `Decoder`.
17pub trait Operation {
18    /// Performs a single step of this operation.
19    ///
20    /// Should return a hint for the next input size.
21    ///
22    /// If the result is `Ok(0)`, it may indicate that a frame was just
23    /// finished.
24    fn run<C: WriteBuf + ?Sized>(
25        &mut self,
26        input: &mut InBuffer<'_>,
27        output: &mut OutBuffer<'_, C>,
28    ) -> io::Result<usize>;
29
30    /// Performs a single step of this operation.
31    ///
32    /// This is a comvenience wrapper around `Operation::run` if you don't
33    /// want to deal with `InBuffer`/`OutBuffer`.
34    fn run_on_buffers(
35        &mut self,
36        input: &[u8],
37        output: &mut [u8],
38    ) -> io::Result<Status> {
39        let mut input = InBuffer::around(input);
40        let mut output = OutBuffer::around(output);
41
42        let remaining = self.run(&mut input, &mut output)?;
43
44        Ok(Status {
45            remaining,
46            bytes_read: input.pos(),
47            bytes_written: output.pos(),
48        })
49    }
50
51    /// Flushes any internal buffer, if any.
52    ///
53    /// Returns the number of bytes still in the buffer.
54    /// To flush entirely, keep calling until it returns `Ok(0)`.
55    fn flush<C: WriteBuf + ?Sized>(
56        &mut self,
57        output: &mut OutBuffer<'_, C>,
58    ) -> io::Result<usize> {
59        let _ = output;
60        Ok(0)
61    }
62
63    /// Prepares the operation for a new frame.
64    ///
65    /// This is hopefully cheaper than creating a new operation.
66    fn reinit(&mut self) -> io::Result<()> {
67        Ok(())
68    }
69
70    /// Finishes the operation, writing any footer if necessary.
71    ///
72    /// Returns the number of bytes still to write.
73    ///
74    /// Keep calling this method until it returns `Ok(0)`,
75    /// and then don't ever call this method.
76    fn finish<C: WriteBuf + ?Sized>(
77        &mut self,
78        output: &mut OutBuffer<'_, C>,
79        finished_frame: bool,
80    ) -> io::Result<usize> {
81        let _ = output;
82        let _ = finished_frame;
83        Ok(0)
84    }
85}
86
/// Dummy operation that just copies its input to the output.
///
/// The type carries no state, so it is trivially copyable, comparable and
/// default-constructible.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)]
pub struct NoOp;
89
90impl Operation for NoOp {
91    fn run<C: WriteBuf + ?Sized>(
92        &mut self,
93        input: &mut InBuffer<'_>,
94        output: &mut OutBuffer<'_, C>,
95    ) -> io::Result<usize> {
96        // Skip the prelude
97        let src = &input.src[input.pos..];
98        // Safe because `output.pos() <= output.capacity()`.
99        let output_pos = output.pos();
100        let dst = unsafe { output.as_mut_ptr().add(output_pos) };
101
102        // Ignore anything past the end
103        let len = usize::min(src.len(), output.capacity() - output_pos);
104        let src = &src[..len];
105
106        // Safe because:
107        // * `len` is less than either of the two lengths
108        // * `src` and `dst` do not overlap because we have `&mut` to each.
109        unsafe { std::ptr::copy_nonoverlapping(src.as_ptr(), dst, len) };
110        input.set_pos(input.pos() + len);
111        unsafe { output.set_pos(output_pos + len) };
112
113        Ok(0)
114    }
115}
116
/// Describes the result of an operation.
///
/// This is plain data (three counters), so it can be freely copied, compared
/// and printed for diagnostics.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct Status {
    /// Number of bytes expected for next input.
    ///
    /// * If `remaining = 0`, then we are at the end of a frame.
    /// * If `remaining > 0`, then it's just a hint for how much there is still
    ///   to read.
    pub remaining: usize,

    /// Number of bytes read from the input.
    pub bytes_read: usize,

    /// Number of bytes written to the output.
    pub bytes_written: usize,
}
132
133/// An in-memory decoder for streams of data.
134pub struct Decoder<'a> {
135    context: MaybeOwnedDCtx<'a>,
136}
137
138impl Decoder<'static> {
139    /// Creates a new decoder.
140    pub fn new() -> io::Result<Self> {
141        Self::with_dictionary(&[])
142    }
143
144    /// Creates a new decoder initialized with the given dictionary.
145    pub fn with_dictionary(dictionary: &[u8]) -> io::Result<Self> {
146        let mut context = zstd_safe::DCtx::create();
147        context.init().map_err(map_error_code)?;
148        context
149            .load_dictionary(dictionary)
150            .map_err(map_error_code)?;
151        Ok(Decoder {
152            context: MaybeOwnedDCtx::Owned(context),
153        })
154    }
155}
156
157impl<'a> Decoder<'a> {
158    /// Creates a new decoder which employs the provided context for deserialization.
159    pub fn with_context(context: &'a mut zstd_safe::DCtx<'static>) -> Self {
160        Self {
161            context: MaybeOwnedDCtx::Borrowed(context),
162        }
163    }
164
165    /// Creates a new decoder, using an existing `DecoderDictionary`.
166    pub fn with_prepared_dictionary<'b>(
167        dictionary: &DecoderDictionary<'b>,
168    ) -> io::Result<Self>
169    where
170        'b: 'a,
171    {
172        let mut context = zstd_safe::DCtx::create();
173        context
174            .ref_ddict(dictionary.as_ddict())
175            .map_err(map_error_code)?;
176        Ok(Decoder {
177            context: MaybeOwnedDCtx::Owned(context),
178        })
179    }
180
181    /// Creates a new decoder, using a ref prefix
182    pub fn with_ref_prefix<'b>(ref_prefix: &'b [u8]) -> io::Result<Self>
183    where
184        'b: 'a,
185    {
186        let mut context = zstd_safe::DCtx::create();
187        context.ref_prefix(ref_prefix).map_err(map_error_code)?;
188        Ok(Decoder {
189            context: MaybeOwnedDCtx::Owned(context),
190        })
191    }
192
193    /// Sets a decompression parameter for this decoder.
194    pub fn set_parameter(&mut self, parameter: DParameter) -> io::Result<()> {
195        match &mut self.context {
196            MaybeOwnedDCtx::Owned(x) => x.set_parameter(parameter),
197            MaybeOwnedDCtx::Borrowed(x) => x.set_parameter(parameter),
198        }
199        .map_err(map_error_code)?;
200        Ok(())
201    }
202}
203
204impl Operation for Decoder<'_> {
205    fn run<C: WriteBuf + ?Sized>(
206        &mut self,
207        input: &mut InBuffer<'_>,
208        output: &mut OutBuffer<'_, C>,
209    ) -> io::Result<usize> {
210        match &mut self.context {
211            MaybeOwnedDCtx::Owned(x) => x.decompress_stream(output, input),
212            MaybeOwnedDCtx::Borrowed(x) => x.decompress_stream(output, input),
213        }
214        .map_err(map_error_code)
215    }
216
217    fn flush<C: WriteBuf + ?Sized>(
218        &mut self,
219        output: &mut OutBuffer<'_, C>,
220    ) -> io::Result<usize> {
221        // To flush, we just offer no additional input.
222        self.run(&mut InBuffer::around(&[]), output)?;
223
224        // We don't _know_ how much (decompressed data) there is still in buffer.
225        if output.pos() < output.capacity() {
226            // We only know when there's none (the output buffer is not full).
227            Ok(0)
228        } else {
229            // Otherwise, pretend there's still "1 byte" remaining.
230            Ok(1)
231        }
232    }
233
234    fn reinit(&mut self) -> io::Result<()> {
235        match &mut self.context {
236            MaybeOwnedDCtx::Owned(x) => {
237                x.reset(zstd_safe::ResetDirective::SessionOnly)
238            }
239            MaybeOwnedDCtx::Borrowed(x) => {
240                x.reset(zstd_safe::ResetDirective::SessionOnly)
241            }
242        }
243        .map_err(map_error_code)?;
244        Ok(())
245    }
246
247    fn finish<C: WriteBuf + ?Sized>(
248        &mut self,
249        _output: &mut OutBuffer<'_, C>,
250        finished_frame: bool,
251    ) -> io::Result<usize> {
252        if finished_frame {
253            Ok(0)
254        } else {
255            Err(io::Error::new(
256                io::ErrorKind::UnexpectedEof,
257                "incomplete frame",
258            ))
259        }
260    }
261}
262
263/// An in-memory encoder for streams of data.
264pub struct Encoder<'a> {
265    context: MaybeOwnedCCtx<'a>,
266}
267
268impl Encoder<'static> {
269    /// Creates a new encoder.
270    pub fn new(level: i32) -> io::Result<Self> {
271        Self::with_dictionary(level, &[])
272    }
273
274    /// Creates a new encoder initialized with the given dictionary.
275    pub fn with_dictionary(level: i32, dictionary: &[u8]) -> io::Result<Self> {
276        let mut context = zstd_safe::CCtx::create();
277
278        context
279            .set_parameter(CParameter::CompressionLevel(level))
280            .map_err(map_error_code)?;
281
282        context
283            .load_dictionary(dictionary)
284            .map_err(map_error_code)?;
285
286        Ok(Encoder {
287            context: MaybeOwnedCCtx::Owned(context),
288        })
289    }
290}
291
292impl<'a> Encoder<'a> {
293    /// Creates a new encoder that uses the provided context for serialization.
294    pub fn with_context(context: &'a mut zstd_safe::CCtx<'static>) -> Self {
295        Self {
296            context: MaybeOwnedCCtx::Borrowed(context),
297        }
298    }
299
300    /// Creates a new encoder using an existing `EncoderDictionary`.
301    pub fn with_prepared_dictionary<'b>(
302        dictionary: &EncoderDictionary<'b>,
303    ) -> io::Result<Self>
304    where
305        'b: 'a,
306    {
307        let mut context = zstd_safe::CCtx::create();
308        context
309            .ref_cdict(dictionary.as_cdict())
310            .map_err(map_error_code)?;
311        Ok(Encoder {
312            context: MaybeOwnedCCtx::Owned(context),
313        })
314    }
315
316    /// Creates a new encoder initialized with the given ref prefix.
317    pub fn with_ref_prefix<'b>(
318        level: i32,
319        ref_prefix: &'b [u8],
320    ) -> io::Result<Self>
321    where
322        'b: 'a,
323    {
324        let mut context = zstd_safe::CCtx::create();
325
326        context
327            .set_parameter(CParameter::CompressionLevel(level))
328            .map_err(map_error_code)?;
329
330        context.ref_prefix(ref_prefix).map_err(map_error_code)?;
331
332        Ok(Encoder {
333            context: MaybeOwnedCCtx::Owned(context),
334        })
335    }
336
337    /// Sets a compression parameter for this encoder.
338    pub fn set_parameter(&mut self, parameter: CParameter) -> io::Result<()> {
339        match &mut self.context {
340            MaybeOwnedCCtx::Owned(x) => x.set_parameter(parameter),
341            MaybeOwnedCCtx::Borrowed(x) => x.set_parameter(parameter),
342        }
343        .map_err(map_error_code)?;
344        Ok(())
345    }
346
347    /// Sets the size of the input expected by zstd.
348    ///
349    /// May affect compression ratio.
350    ///
351    /// It is an error to give an incorrect size (an error _will_ be returned when closing the
352    /// stream).
353    ///
354    /// If `None` is given, it assume the size is not known (default behaviour).
355    pub fn set_pledged_src_size(
356        &mut self,
357        pledged_src_size: Option<u64>,
358    ) -> io::Result<()> {
359        match &mut self.context {
360            MaybeOwnedCCtx::Owned(x) => {
361                x.set_pledged_src_size(pledged_src_size)
362            }
363            MaybeOwnedCCtx::Borrowed(x) => {
364                x.set_pledged_src_size(pledged_src_size)
365            }
366        }
367        .map_err(map_error_code)?;
368        Ok(())
369    }
370}
371
372impl<'a> Operation for Encoder<'a> {
373    fn run<C: WriteBuf + ?Sized>(
374        &mut self,
375        input: &mut InBuffer<'_>,
376        output: &mut OutBuffer<'_, C>,
377    ) -> io::Result<usize> {
378        match &mut self.context {
379            MaybeOwnedCCtx::Owned(x) => x.compress_stream(output, input),
380            MaybeOwnedCCtx::Borrowed(x) => x.compress_stream(output, input),
381        }
382        .map_err(map_error_code)
383    }
384
385    fn flush<C: WriteBuf + ?Sized>(
386        &mut self,
387        output: &mut OutBuffer<'_, C>,
388    ) -> io::Result<usize> {
389        match &mut self.context {
390            MaybeOwnedCCtx::Owned(x) => x.flush_stream(output),
391            MaybeOwnedCCtx::Borrowed(x) => x.flush_stream(output),
392        }
393        .map_err(map_error_code)
394    }
395
396    fn finish<C: WriteBuf + ?Sized>(
397        &mut self,
398        output: &mut OutBuffer<'_, C>,
399        _finished_frame: bool,
400    ) -> io::Result<usize> {
401        match &mut self.context {
402            MaybeOwnedCCtx::Owned(x) => x.end_stream(output),
403            MaybeOwnedCCtx::Borrowed(x) => x.end_stream(output),
404        }
405        .map_err(map_error_code)
406    }
407
408    fn reinit(&mut self) -> io::Result<()> {
409        match &mut self.context {
410            MaybeOwnedCCtx::Owned(x) => {
411                x.reset(zstd_safe::ResetDirective::SessionOnly)
412            }
413            MaybeOwnedCCtx::Borrowed(x) => {
414                x.reset(zstd_safe::ResetDirective::SessionOnly)
415            }
416        }
417        .map_err(map_error_code)?;
418        Ok(())
419    }
420}
421
422enum MaybeOwnedCCtx<'a> {
423    Owned(zstd_safe::CCtx<'a>),
424    Borrowed(&'a mut zstd_safe::CCtx<'static>),
425}
426
427enum MaybeOwnedDCtx<'a> {
428    Owned(zstd_safe::DCtx<'a>),
429    Borrowed(&'a mut zstd_safe::DCtx<'static>),
430}
431
#[cfg(test)]
mod tests {

    // This requires impl for [u8; N] which is currently behind a feature.
    #[cfg(feature = "arrays")]
    #[test]
    fn test_cycle() {
        use super::{Decoder, Encoder, InBuffer, Operation, OutBuffer};

        let mut encoder = Encoder::new(1).unwrap();
        let mut decoder = Decoder::new().unwrap();

        // Step 1: compress the whole input, then close the frame.
        let mut plain = InBuffer::around(b"AbcdefAbcdefabcdef");

        let mut compressed_storage = [0u8; 128];
        let mut compressed = OutBuffer::around(&mut compressed_storage);

        // The input is non-empty, so at least one step always runs.
        while plain.pos < plain.src.len() {
            encoder.run(&mut plain, &mut compressed).unwrap();
        }
        encoder.finish(&mut compressed, true).unwrap();

        let initial_data = plain.src;

        // Step 2: decompress everything that was produced.
        let mut encoded = InBuffer::around(compressed.as_slice());
        let mut decoded_storage = [0u8; 128];
        let mut decoded = OutBuffer::around(&mut decoded_storage);

        // The compressed frame is non-empty, so at least one step always runs.
        while encoded.pos < encoded.src.len() {
            decoder.run(&mut encoded, &mut decoded).unwrap();
        }

        // The round trip must give back exactly the original bytes.
        assert_eq!(initial_data, decoded.as_slice());
    }
}
476}