zstd/
dict.rs

1//! Train a dictionary from various sources.
2//!
3//! A dictionary can help improve the compression of small files.
4//! The dictionary must be present during decompression,
5//! but can be shared across multiple "similar" files.
6//!
//! Dictionaries created with the `zstd` C library, with the `zstd`
//! command-line interface, with this library, or with the provided `train`
//! binary should all give the same result, and are therefore completely
//! compatible.
11//!
12//! To use, see [`Encoder::with_dictionary`] or [`Decoder::with_dictionary`].
13//!
14//! [`Encoder::with_dictionary`]: ../struct.Encoder.html#method.with_dictionary
15//! [`Decoder::with_dictionary`]: ../struct.Decoder.html#method.with_dictionary
16
17#[cfg(feature = "zdict_builder")]
18use std::io::{self, Read};
19
20pub use zstd_safe::{CDict, DDict};
21
22/// Prepared dictionary for compression
23///
24/// A dictionary can include its own copy of the data (if it is `'static`), or it can merely point
25/// to a separate buffer (if it has another lifetime).
26pub struct EncoderDictionary<'a> {
27    cdict: CDict<'a>,
28}
29
30impl EncoderDictionary<'static> {
31    /// Creates a prepared dictionary for compression.
32    ///
33    /// This will copy the dictionary internally.
34    pub fn copy(dictionary: &[u8], level: i32) -> Self {
35        Self {
36            cdict: zstd_safe::create_cdict(dictionary, level),
37        }
38    }
39}
40
41impl<'a> EncoderDictionary<'a> {
42    #[cfg(feature = "experimental")]
43    #[cfg_attr(feature = "doc-cfg", doc(cfg(feature = "experimental")))]
44    /// Create prepared dictionary for compression
45    ///
46    /// A level of `0` uses zstd's default (currently `3`).
47    ///
48    /// Only available with the `experimental` feature. Use `EncoderDictionary::copy` otherwise.
49    pub fn new(dictionary: &'a [u8], level: i32) -> Self {
50        Self {
51            cdict: zstd_safe::CDict::create_by_reference(dictionary, level),
52        }
53    }
54
55    /// Returns reference to `CDict` inner object
56    pub fn as_cdict(&self) -> &CDict<'a> {
57        &self.cdict
58    }
59}
60
61/// Prepared dictionary for decompression
62pub struct DecoderDictionary<'a> {
63    ddict: DDict<'a>,
64}
65
66impl DecoderDictionary<'static> {
67    /// Create a prepared dictionary for decompression.
68    ///
69    /// This will copy the dictionary internally.
70    pub fn copy(dictionary: &[u8]) -> Self {
71        Self {
72            ddict: zstd_safe::DDict::create(dictionary),
73        }
74    }
75}
76
77impl<'a> DecoderDictionary<'a> {
78    #[cfg(feature = "experimental")]
79    #[cfg_attr(feature = "doc-cfg", doc(cfg(feature = "experimental")))]
80    /// Create prepared dictionary for decompression
81    ///
82    /// Only available with the `experimental` feature. Use `DecoderDictionary::copy` otherwise.
83    pub fn new(dict: &'a [u8]) -> Self {
84        Self {
85            ddict: zstd_safe::DDict::create_by_reference(dict),
86        }
87    }
88
89    /// Returns reference to `DDict` inner object
90    pub fn as_ddict(&self) -> &DDict<'a> {
91        &self.ddict
92    }
93}
94
/// Train a dictionary from a big continuous chunk of data, with all samples
/// contiguous in memory.
///
/// This is the most efficient way to train a dictionary,
/// since this is directly fed into `zstd`.
///
/// * `sample_data` is the concatenation of all sample data.
/// * `sample_sizes` is the size of each sample in `sample_data`.
///   The sum of all `sample_sizes` must equal the length of `sample_data`.
/// * `max_size` is the maximum size of the dictionary to generate.
///
/// The result is the dictionary data. You can, for example, feed it to [`CDict::create`].
///
/// # Errors
///
/// Returns an error if the sum of `sample_sizes` does not match the length of
/// `sample_data` (including when that sum does not fit in a `usize`), or if
/// the underlying training call reports a failure.
#[cfg(feature = "zdict_builder")]
#[cfg_attr(feature = "doc-cfg", doc(cfg(feature = "zdict_builder")))]
pub fn from_continuous(
    sample_data: &[u8],
    sample_sizes: &[usize],
    max_size: usize,
) -> io::Result<Vec<u8>> {
    use crate::map_error_code;

    // Add the sizes up with overflow detection: a plain `sum()` would panic
    // in debug builds and silently wrap in release builds, where a wrapped
    // total could accidentally match the data length.
    let declared_total = sample_sizes
        .iter()
        .try_fold(0usize, |total, &size| total.checked_add(size));

    // Complain if the lengths don't add up to the entire data.
    if declared_total != Some(sample_data.len()) {
        return Err(io::Error::new(
            io::ErrorKind::Other,
            "sample sizes don't add up",
        ));
    }

    // The trainer writes the dictionary into this buffer.
    let mut result = Vec::with_capacity(max_size);
    zstd_safe::train_from_buffer(&mut result, sample_data, sample_sizes)
        .map_err(map_error_code)?;
    Ok(result)
}
129
/// Train a dictionary from multiple samples.
///
/// The samples will internally be copied to a single continuous buffer,
/// so make sure you have enough memory available.
///
/// If you need to stretch your system's limits,
/// [`from_continuous`] directly uses the given slice.
///
/// [`from_continuous`]: ./fn.from_continuous.html
///
/// * `samples` is a list of individual samples to train on.
/// * `max_size` is the maximum size of the dictionary to generate.
///
/// The result is the dictionary data. You can, for example, feed it to [`CDict::create`].
///
/// # Errors
///
/// Returns whatever error [`from_continuous`] reports for the concatenated
/// samples.
#[cfg(feature = "zdict_builder")]
#[cfg_attr(feature = "doc-cfg", doc(cfg(feature = "zdict_builder")))]
pub fn from_samples<S: AsRef<[u8]>>(
    samples: &[S],
    max_size: usize,
) -> io::Result<Vec<u8>> {
    // Pre-allocate the entire required size.
    let total_length: usize =
        samples.iter().map(|sample| sample.as_ref().len()).sum();

    let mut data = Vec::with_capacity(total_length);
    let mut sizes = Vec::with_capacity(samples.len());

    // Copy every sample to a big chunk of memory, one slice at a time, and
    // record its size from the very same borrow so that `sizes` always
    // describes exactly what ended up in `data`.
    for sample in samples {
        let bytes = sample.as_ref();
        data.extend_from_slice(bytes);
        sizes.push(bytes.len());
    }

    from_continuous(&data, &sizes, max_size)
}
163
/// Train a dictionary from an iterator of readable samples.
///
/// In contrast to [`from_samples`], the full list of samples does not have to
/// exist up front, and the iterator itself is allowed to yield errors; the
/// first error encountered (from the iterator or while reading a sample) is
/// returned.
///
/// Every sample is read to its end into one continuous array, which is then
/// handed to [`from_continuous`].
///
/// * `samples` is an iterator of individual samples to train on.
/// * `max_size` is the maximum size of the dictionary to generate.
///
/// The result is the dictionary data. You can, for example, feed it to [`CDict::create`].
///
/// # Examples
///
/// ```rust,no_run
/// // Train from a couple of json files.
/// let dict_buffer = zstd::dict::from_sample_iterator(
///     ["file_a.json", "file_b.json"]
///         .into_iter()
///         .map(|filename| std::fs::File::open(filename)),
///     10_000,  // 10kB dictionary
/// ).unwrap();
/// ```
///
/// ```rust,no_run
/// use std::io::BufRead as _;
/// // Treat each line from stdin as a separate sample.
/// let dict_buffer = zstd::dict::from_sample_iterator(
///     std::io::stdin().lock().lines().map(|line: std::io::Result<String>| {
///         // Transform each line into a `Cursor<Vec<u8>>` so they implement Read.
///         line.map(String::into_bytes)
///             .map(std::io::Cursor::new)
///     }),
///     10_000,  // 10kB dictionary
/// ).unwrap();
/// ```
#[cfg(feature = "zdict_builder")]
#[cfg_attr(feature = "doc-cfg", doc(cfg(feature = "zdict_builder")))]
pub fn from_sample_iterator<I, R>(
    samples: I,
    max_size: usize,
) -> io::Result<Vec<u8>>
where
    I: IntoIterator<Item = io::Result<R>>,
    R: Read,
{
    let mut buffer = Vec::new();

    // Append each sample to `buffer` and keep the number of bytes it
    // contributed; collecting into a `Result` stops at the first error.
    let lengths = samples
        .into_iter()
        .map(|source| {
            source.and_then(|mut reader| reader.read_to_end(&mut buffer))
        })
        .collect::<io::Result<Vec<usize>>>()?;

    from_continuous(&buffer, &lengths, max_size)
}
221
/// Train a dictionary from the content of a list of files.
///
/// * `filenames` is an iterator of paths to open. Each file is treated as
///   one individual sample.
/// * `max_size` is the maximum size of the dictionary to generate.
///
/// The result is the dictionary data. You can, for example, feed it to [`CDict::create`].
#[cfg(feature = "zdict_builder")]
#[cfg_attr(feature = "doc-cfg", doc(cfg(feature = "zdict_builder")))]
pub fn from_files<I, P>(filenames: I, max_size: usize) -> io::Result<Vec<u8>>
where
    P: AsRef<std::path::Path>,
    I: IntoIterator<Item = P>,
{
    // Files are opened lazily, one by one, as the trainer pulls samples.
    let opened_files = filenames.into_iter().map(std::fs::File::open);
    from_sample_iterator(opened_files, max_size)
}
243
#[cfg(test)]
#[cfg(feature = "zdict_builder")]
mod tests {
    use std::fs;
    use std::io;

    use walkdir;

    /// Trains a dictionary on the `.rs` files found under `src`, then checks
    /// that each of those files survives a compression/decompression round
    /// trip using that dictionary.
    #[test]
    fn test_dict_training() {
        // Gather the training samples.
        let paths: Vec<_> = walkdir::WalkDir::new("src")
            .into_iter()
            .map(|entry| entry.unwrap().into_path())
            .filter(|path| path.to_str().unwrap().ends_with(".rs"))
            .collect();

        let dict = super::from_files(&paths, 4000).unwrap();

        for path in paths {
            let content = fs::read(path).unwrap();

            // Compress; the scope ends the encoder so the frame is finished
            // before `buffer` is read back.
            let mut buffer = Vec::new();
            {
                let mut encoder = crate::stream::Encoder::with_dictionary(
                    &mut buffer,
                    1,
                    &dict,
                )
                .unwrap()
                .auto_finish();
                io::copy(&mut &content[..], &mut encoder).unwrap();
            }

            // Decompress with the same dictionary.
            let mut decoder = crate::stream::Decoder::with_dictionary(
                &buffer[..],
                &dict[..],
            )
            .unwrap();
            let mut result = Vec::new();
            io::copy(&mut decoder, &mut result).unwrap();

            assert_eq!(&content, &result);
        }
    }
}