async_compression/codec/zstd/
decoder.rs

1use std::io;
2use std::io::Result;
3
4use crate::{codec::Decode, unshared::Unshared, util::PartialBuffer};
5use libzstd::stream::raw::{Decoder, Operation};
6
7#[derive(Debug)]
8pub struct ZstdDecoder {
9    decoder: Unshared<Decoder<'static>>,
10}
11
12impl ZstdDecoder {
13    pub(crate) fn new() -> Self {
14        Self {
15            decoder: Unshared::new(Decoder::new().unwrap()),
16        }
17    }
18
19    pub(crate) fn new_with_dict(dictionary: &[u8]) -> io::Result<Self> {
20        let mut decoder = Decoder::with_dictionary(dictionary)?;
21        Ok(Self {
22            decoder: Unshared::new(decoder),
23        })
24    }
25}
26
27impl Decode for ZstdDecoder {
28    fn reinit(&mut self) -> Result<()> {
29        self.decoder.get_mut().reinit()?;
30        Ok(())
31    }
32
33    fn decode(
34        &mut self,
35        input: &mut PartialBuffer<impl AsRef<[u8]>>,
36        output: &mut PartialBuffer<impl AsRef<[u8]> + AsMut<[u8]>>,
37    ) -> Result<bool> {
38        let status = self
39            .decoder
40            .get_mut()
41            .run_on_buffers(input.unwritten(), output.unwritten_mut())?;
42        input.advance(status.bytes_read);
43        output.advance(status.bytes_written);
44        Ok(status.remaining == 0)
45    }
46
47    fn flush(
48        &mut self,
49        output: &mut PartialBuffer<impl AsRef<[u8]> + AsMut<[u8]>>,
50    ) -> Result<bool> {
51        let mut out_buf = zstd_safe::OutBuffer::around(output.unwritten_mut());
52        let bytes_left = self.decoder.get_mut().flush(&mut out_buf)?;
53        let len = out_buf.as_slice().len();
54        output.advance(len);
55        Ok(bytes_left == 0)
56    }
57
58    fn finish(
59        &mut self,
60        output: &mut PartialBuffer<impl AsRef<[u8]> + AsMut<[u8]>>,
61    ) -> Result<bool> {
62        let mut out_buf = zstd_safe::OutBuffer::around(output.unwritten_mut());
63        let bytes_left = self.decoder.get_mut().finish(&mut out_buf, true)?;
64        let len = out_buf.as_slice().len();
65        output.advance(len);
66        Ok(bytes_left == 0)
67    }
68}