lz4_flex/block/compress.rs

//! The compression algorithm.
//!
//! We make use of hash tables to find duplicates. This gives a reasonable compression ratio with
//! high performance. It has fixed memory usage, which, contrary to other approaches, makes it less
//! memory hungry.

use crate::block::hashtable::HashTable;
use crate::block::END_OFFSET;
use crate::block::LZ4_MIN_LENGTH;
use crate::block::MAX_DISTANCE;
use crate::block::MFLIMIT;
use crate::block::MINMATCH;
#[cfg(not(feature = "safe-encode"))]
use crate::sink::PtrSink;
use crate::sink::Sink;
use crate::sink::SliceSink;
#[allow(unused_imports)]
use alloc::vec;

#[allow(unused_imports)]
use alloc::vec::Vec;

use super::hashtable::HashTable4K;
use super::hashtable::HashTable4KU16;
use super::{CompressError, WINDOW_SIZE};

/// Increase the step size after 1 << INCREASE_STEPSIZE_BITSHIFT consecutive non-matches.
const INCREASE_STEPSIZE_BITSHIFT: usize = 5;
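// For example (derived from the search loop in `compress_internal` below): `non_match_count`
// starts at 1 << 5 = 32, so the step size stays 1 for the first 32 probes, grows to 2 after
// 32 more consecutive non-matches, to 3 after another 32, and so on.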

/// Read a 4-byte "batch" from some position.
///
/// This will read a native-endian 4-byte integer from some position.
#[inline]
#[cfg(not(feature = "safe-encode"))]
pub(super) fn get_batch(input: &[u8], n: usize) -> u32 {
    unsafe { read_u32_ptr(input.as_ptr().add(n)) }
}

#[inline]
#[cfg(feature = "safe-encode")]
pub(super) fn get_batch(input: &[u8], n: usize) -> u32 {
    u32::from_ne_bytes(input[n..n + 4].try_into().unwrap())
}

/// Read a usize-sized "batch" from some position.
///
/// This will read a native-endian usize from some position.
#[inline]
#[allow(dead_code)]
#[cfg(not(feature = "safe-encode"))]
pub(super) fn get_batch_arch(input: &[u8], n: usize) -> usize {
    unsafe { read_usize_ptr(input.as_ptr().add(n)) }
}

#[inline]
#[allow(dead_code)]
#[cfg(feature = "safe-encode")]
pub(super) fn get_batch_arch(input: &[u8], n: usize) -> usize {
    const USIZE_SIZE: usize = core::mem::size_of::<usize>();
    let arr: &[u8; USIZE_SIZE] = input[n..n + USIZE_SIZE].try_into().unwrap();
    usize::from_ne_bytes(*arr)
}

#[inline]
fn token_from_literal(lit_len: usize) -> u8 {
    if lit_len < 0xF {
        // Since we can fit the literals length into it, there is no need for saturation.
        (lit_len as u8) << 4
    } else {
        // We were unable to fit the literals into it, so we saturate to 0xF. We will later
        // write the extensional value.
        0xF0
    }
}

#[inline]
fn token_from_literal_and_match_length(lit_len: usize, duplicate_length: usize) -> u8 {
    let mut token = if lit_len < 0xF {
        // Since we can fit the literals length into it, there is no need for saturation.
        (lit_len as u8) << 4
    } else {
        // We were unable to fit the literals into it, so we saturate to 0xF. We will later
        // write the extensional value.
        0xF0
    };

    token |= if duplicate_length < 0xF {
        // We could fit it in.
        duplicate_length as u8
    } else {
        // We were unable to fit it in, so we default to 0xF, which will later be extended.
        0xF
    };

    token
}
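
// For example (illustrative values): `token_from_literal_and_match_length(3, 6)` returns 0x36:
// the literal length 3 goes into the high nibble and the extra match length 6 (on top of
// MINMATCH) into the low nibble. A length of 15 or more saturates its nibble to 0xF and is
// extended later via `write_integer`.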

/// Counts the number of identical bytes in two byte streams.
/// `input` is the complete input.
/// `cur` is the current position in the input; it will be incremented by the number of matched
/// bytes.
/// `source` is either the same slice as `input` or an external slice.
/// `candidate` is the candidate position in `source`.
///
/// The function ignores the last END_OFFSET bytes in `input`, as those should be literals.
#[inline]
#[cfg(feature = "safe-encode")]
fn count_same_bytes(input: &[u8], cur: &mut usize, source: &[u8], candidate: usize) -> usize {
    const USIZE_SIZE: usize = core::mem::size_of::<usize>();
    let cur_slice = &input[*cur..input.len() - END_OFFSET];
    let cand_slice = &source[candidate..];

    let mut num = 0;
    for (block1, block2) in cur_slice
        .chunks_exact(USIZE_SIZE)
        .zip(cand_slice.chunks_exact(USIZE_SIZE))
    {
        let input_block = usize::from_ne_bytes(block1.try_into().unwrap());
        let match_block = usize::from_ne_bytes(block2.try_into().unwrap());

        if input_block == match_block {
            num += USIZE_SIZE;
        } else {
            let diff = input_block ^ match_block;
            num += (diff.to_le().trailing_zeros() / 8) as usize;
            *cur += num;
            return num;
        }
    }

    // If we're here we may have 1 to 7 bytes left to check close to the end of the input
    // or source slices. Since this is a rare occurrence, we mark it cold to get
    // ~5% better performance.
    #[cold]
    fn count_same_bytes_tail(a: &[u8], b: &[u8], offset: usize) -> usize {
        a.iter()
            .zip(b)
            .skip(offset)
            .take_while(|(a, b)| a == b)
            .count()
    }
    num += count_same_bytes_tail(cur_slice, cand_slice, num);

    *cur += num;
    num
}

/// Counts the number of identical bytes in two byte streams.
/// `input` is the complete input.
/// `cur` is the current position in the input; it will be incremented by the number of matched
/// bytes.
/// `source` is either the same slice as `input` or an external slice.
/// `candidate` is the candidate position in `source`.
///
/// The function ignores the last END_OFFSET bytes in `input`, as those should be literals.
#[inline]
#[cfg(not(feature = "safe-encode"))]
fn count_same_bytes(input: &[u8], cur: &mut usize, source: &[u8], candidate: usize) -> usize {
    let max_input_match = input.len().saturating_sub(*cur + END_OFFSET);
    let max_candidate_match = source.len() - candidate;
    // Considering both limits, calculate how far we may match in input.
    let input_end = *cur + max_input_match.min(max_candidate_match);

    let start = *cur;
    let mut source_ptr = unsafe { source.as_ptr().add(candidate) };

    // Compare 4/8 byte blocks depending on the arch.
    const STEP_SIZE: usize = core::mem::size_of::<usize>();
    while *cur + STEP_SIZE <= input_end {
        let diff = read_usize_ptr(unsafe { input.as_ptr().add(*cur) }) ^ read_usize_ptr(source_ptr);

        if diff == 0 {
            *cur += STEP_SIZE;
            unsafe {
                source_ptr = source_ptr.add(STEP_SIZE);
            }
        } else {
            *cur += (diff.to_le().trailing_zeros() / 8) as usize;
            return *cur - start;
        }
    }

    // Compare a 4 byte block.
    #[cfg(target_pointer_width = "64")]
    {
        if input_end - *cur >= 4 {
            let diff = read_u32_ptr(unsafe { input.as_ptr().add(*cur) }) ^ read_u32_ptr(source_ptr);

            if diff == 0 {
                *cur += 4;
                unsafe {
                    source_ptr = source_ptr.add(4);
                }
            } else {
                *cur += (diff.to_le().trailing_zeros() / 8) as usize;
                return *cur - start;
            }
        }
    }

    // Compare a 2 byte block.
    if input_end - *cur >= 2
        && unsafe { read_u16_ptr(input.as_ptr().add(*cur)) == read_u16_ptr(source_ptr) }
    {
        *cur += 2;
        unsafe {
            source_ptr = source_ptr.add(2);
        }
    }

    if *cur < input_end
        && unsafe { input.as_ptr().add(*cur).read() } == unsafe { source_ptr.read() }
    {
        *cur += 1;
    }

    *cur - start
}

/// Write an integer to the output.
///
/// Each additional byte then represents a value from 0 to 255, which is added to the previous
/// value to produce a total length. When the byte value is 255, another byte must be read and
/// added, and so on. There can be any number of bytes of value 255 following the token.
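///
/// For example, a remaining length of 600 is written as the bytes `0xFF, 0xFF, 0x5A`
/// (255 + 255 + 90 = 600).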
#[inline]
#[cfg(feature = "safe-encode")]
fn write_integer(output: &mut impl Sink, mut n: usize) {
    // Note: Since `n` is usually < 0xFF and writing multiple bytes to the output
    // requires 2 branches of bound check (due to the possibility of add overflows)
    // the simple byte at a time implementation below is faster in most cases.
    while n >= 0xFF {
        n -= 0xFF;
        push_byte(output, 0xFF);
    }
    push_byte(output, n as u8);
}

/// Write an integer to the output.
///
/// Each additional byte then represents a value from 0 to 255, which is added to the previous
/// value to produce a total length. When the byte value is 255, another byte must be read and
/// added, and so on. There can be any number of bytes of value 255 following the token.
#[inline]
#[cfg(not(feature = "safe-encode"))]
fn write_integer(output: &mut impl Sink, mut n: usize) {
    // Write the 0xFF bytes as long as the integer is higher than said value.
    if n >= 4 * 0xFF {
        // In this unlikely branch we use a fill instead of a loop,
        // otherwise rustc may output a large unrolled/vectorized loop.
        let bulk = n / (4 * 0xFF);
        n %= 4 * 0xFF;
        unsafe {
            core::ptr::write_bytes(output.pos_mut_ptr(), 0xFF, 4 * bulk);
            output.set_pos(output.pos() + 4 * bulk);
        }
    }

    // Handle the last 1 to 4 bytes.
    push_u32(output, 0xFFFFFFFF);
    // Update the output length for the remainder.
    unsafe {
        output.set_pos(output.pos() - 4 + 1 + n / 255);
        // Write the remaining byte.
        *output.pos_mut_ptr().sub(1) = (n % 255) as u8;
    }
}

/// Handle the last bytes from the input as literals.
#[cold]
fn handle_last_literals(output: &mut impl Sink, input: &[u8], start: usize) {
    let lit_len = input.len() - start;

    let token = token_from_literal(lit_len);
    push_byte(output, token);
    if lit_len >= 0xF {
        write_integer(output, lit_len - 0xF);
    }
    // Now, write the actual literals.
    output.extend_from_slice(&input[start..]);
}

/// Moves the cursors back as long as the bytes match, to find additional bytes in a duplicate.
#[inline]
#[cfg(feature = "safe-encode")]
fn backtrack_match(
    input: &[u8],
    cur: &mut usize,
    literal_start: usize,
    source: &[u8],
    candidate: &mut usize,
) {
    // Note: Even though the iterator version of this loop has fewer branches inside the loop,
    // it has more branches before the loop. In practice that seems to make it slower than the
    // while version below.
    // TODO: It should be possible to remove all bounds checks, since we are walking backwards.
    while *candidate > 0 && *cur > literal_start && input[*cur - 1] == source[*candidate - 1] {
        *cur -= 1;
        *candidate -= 1;
    }
}

/// Moves the cursors back as long as the bytes match, to find additional bytes in a duplicate.
#[inline]
#[cfg(not(feature = "safe-encode"))]
fn backtrack_match(
    input: &[u8],
    cur: &mut usize,
    literal_start: usize,
    source: &[u8],
    candidate: &mut usize,
) {
    while unsafe {
        *candidate > 0
            && *cur > literal_start
            && input.get_unchecked(*cur - 1) == source.get_unchecked(*candidate - 1)
    } {
        *cur -= 1;
        *candidate -= 1;
    }
}

/// Compress all bytes of `input[input_pos..]` into `output`.
///
/// Bytes in `input[..input_pos]` are treated as a preamble and can be used for lookback.
/// This part is known as the compressor "prefix".
/// Bytes in `ext_dict` logically precede the bytes in `input` and can also be used for lookback.
///
/// `input_stream_offset` is the logical position of the first byte of `input`. This allows the
/// same `dict` to be used for many calls to `compress_internal`, as we can "readdress" the first
/// byte of `input` to be something other than 0.
///
/// `dict` is the dictionary of previously encoded sequences.
///
/// This is used to find duplicates in the stream so they are not written multiple times.
///
/// Every four bytes are hashed, and in the resulting slot their position in the input buffer
/// is placed in the dict. This way we can easily look up a candidate for back references.
///
/// Returns the number of bytes written (compressed) into `output`.
///
/// # Const parameters
/// `USE_DICT`: When false, disables usage of `ext_dict` (the function asserts that an empty
/// slice is passed). In other words, this generates more optimized code when an external
/// dictionary isn't used.
///
/// A similar const argument could be used to disable the prefix mode (e.g. USE_PREFIX),
/// which would impose `input_pos == 0 && input_stream_offset == 0`. Experiments didn't
/// show significant improvement though.
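///
/// For example (values chosen purely for illustration): with `input_stream_offset == 1000` and
/// `ext_dict.len() == 300`, dictionary entries in the range `700..1000` refer to positions in
/// `ext_dict`, while entries `>= 1000` refer to positions in `input`.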
// Intentionally avoid inlining.
// Empirical tests revealed it to be rarely better but often significantly detrimental.
#[inline(never)]
pub(crate) fn compress_internal<T: HashTable, const USE_DICT: bool, S: Sink>(
    input: &[u8],
    input_pos: usize,
    output: &mut S,
    dict: &mut T,
    ext_dict: &[u8],
    input_stream_offset: usize,
) -> Result<usize, CompressError> {
    assert!(input_pos <= input.len());
    if USE_DICT {
        assert!(ext_dict.len() <= super::WINDOW_SIZE);
        assert!(ext_dict.len() <= input_stream_offset);
        // Check for overflow hazard when using ext_dict
        assert!(input_stream_offset
            .checked_add(input.len())
            .and_then(|i| i.checked_add(ext_dict.len()))
            .map_or(false, |i| i <= isize::MAX as usize));
    } else {
        assert!(ext_dict.is_empty());
    }
    if output.capacity() - output.pos() < get_maximum_output_size(input.len() - input_pos) {
        return Err(CompressError::OutputTooSmall);
    }

    let output_start_pos = output.pos();
    if input.len() - input_pos < LZ4_MIN_LENGTH {
        handle_last_literals(output, input, input_pos);
        return Ok(output.pos() - output_start_pos);
    }

    let ext_dict_stream_offset = input_stream_offset - ext_dict.len();
    let end_pos_check = input.len() - MFLIMIT;
    let mut literal_start = input_pos;
    let mut cur = input_pos;

    if cur == 0 && input_stream_offset == 0 {
        // According to the spec we can't start with a match,
        // except when referencing another block.
        let hash = T::get_hash_at(input, 0);
        dict.put_at(hash, 0);
        cur = 1;
    }

    loop {
        // Read the next block into two sections, the literals and the duplicates.
        let mut step_size;
        let mut candidate;
        let mut candidate_source;
        let mut offset;
        let mut non_match_count = 1 << INCREASE_STEPSIZE_BITSHIFT;
        // The number of bytes before our cursor, where the duplicate starts.
        let mut next_cur = cur;

        // In this loop we search for duplicates via the hashtable. 4 or 8 bytes are hashed and
        // compared.
        loop {
            step_size = non_match_count >> INCREASE_STEPSIZE_BITSHIFT;
            non_match_count += 1;

            cur = next_cur;
            next_cur += step_size;

            // Same as cur + MFLIMIT > input.len()
            if cur > end_pos_check {
                handle_last_literals(output, input, literal_start);
                return Ok(output.pos() - output_start_pos);
            }
            // Find a candidate in the dictionary with the hash of the current four bytes.
            // Unchecked is safe as long as the values from the hash function don't exceed the size
            // of the table. This is ensured by right shifting the hash values
            // (`dict_bitshift`) to fit them in the table.

            // [Bounds Check]: Can be elided due to `end_pos_check` above
            let hash = T::get_hash_at(input, cur);
            candidate = dict.get_at(hash);
            dict.put_at(hash, cur + input_stream_offset);

            // Sanity check: Matches can't be ahead of `cur`.
            debug_assert!(candidate <= input_stream_offset + cur);

            // Two requirements for the candidate exist:
            // - We should not return a position which is merely a hash collision, so the
            //   candidate must actually match what we search for.
            // - We can only address offsets of up to 16 bits, hence we are only able to use the
            //   candidate if its offset is less than or equal to 0xFFFF.
            if input_stream_offset + cur - candidate > MAX_DISTANCE {
                continue;
            }

            if candidate >= input_stream_offset {
                // match within input
                offset = (input_stream_offset + cur - candidate) as u16;
                candidate -= input_stream_offset;
                candidate_source = input;
            } else if USE_DICT {
                // Sanity check, which may fail if we lost history beyond MAX_DISTANCE
                debug_assert!(
                    candidate >= ext_dict_stream_offset,
                    "Lost history in ext dict mode"
                );
                // match within ext dict
                offset = (input_stream_offset + cur - candidate) as u16;
                candidate -= ext_dict_stream_offset;
                candidate_source = ext_dict;
            } else {
                // The match is not reachable anymore,
                // e.g. when compressing an independent block frame without clearing
                // the match table, only increasing input_stream_offset.
                // Sanity check
                debug_assert!(input_pos == 0, "Lost history in prefix mode");
                continue;
            }
            // [Bounds Check]: The candidate is coming from the hashmap. It can't be out of
            // bounds, but that is impossible for the compiler to prove, so the bounds checks
            // remain.
            let cand_bytes: u32 = get_batch(candidate_source, candidate);
            // [Bounds Check]: Should be able to be elided due to `end_pos_check`.
            let curr_bytes: u32 = get_batch(input, cur);

            if cand_bytes == curr_bytes {
                break;
            }
        }

        // Extend the match backwards if we can
        backtrack_match(
            input,
            &mut cur,
            literal_start,
            candidate_source,
            &mut candidate,
        );

        // The length (in bytes) of the literals section.
        let lit_len = cur - literal_start;

        // Generate the higher half of the token.
        cur += MINMATCH;
        candidate += MINMATCH;
        let duplicate_length = count_same_bytes(input, &mut cur, candidate_source, candidate);

        // Note: The `- 2` offset was copied from the reference implementation; it could be
        // arbitrary.
        let hash = T::get_hash_at(input, cur - 2);
        dict.put_at(hash, cur - 2 + input_stream_offset);

        let token = token_from_literal_and_match_length(lit_len, duplicate_length);

        // Push the token to the output stream.
        push_byte(output, token);
        // If we were unable to fit the literals length into the token, write the extensional
        // part.
        if lit_len >= 0xF {
            write_integer(output, lit_len - 0xF);
        }

        // Now, write the actual literals.
        //
        // The unsafe version copies blocks of 8 bytes, and may therefore copy up to 7 bytes more
        // than needed. This is safe, because the last 12 bytes (MFLIMIT) are handled in
        // handle_last_literals.
        copy_literals_wild(output, input, literal_start, lit_len);
        // Write the offset in little endian.
        push_u16(output, offset);

        // If we were unable to fit the duplicates length into the token, write the
        // extensional part.
        if duplicate_length >= 0xF {
            write_integer(output, duplicate_length - 0xF);
        }
        literal_start = cur;
    }
}

#[inline]
#[cfg(feature = "safe-encode")]
fn push_byte(output: &mut impl Sink, el: u8) {
    output.push(el);
}

#[inline]
#[cfg(not(feature = "safe-encode"))]
fn push_byte(output: &mut impl Sink, el: u8) {
    unsafe {
        core::ptr::write(output.pos_mut_ptr(), el);
        output.set_pos(output.pos() + 1);
    }
}

#[inline]
#[cfg(feature = "safe-encode")]
fn push_u16(output: &mut impl Sink, el: u16) {
    output.extend_from_slice(&el.to_le_bytes());
}

#[inline]
#[cfg(not(feature = "safe-encode"))]
fn push_u16(output: &mut impl Sink, el: u16) {
    unsafe {
        core::ptr::copy_nonoverlapping(el.to_le_bytes().as_ptr(), output.pos_mut_ptr(), 2);
        output.set_pos(output.pos() + 2);
    }
}

#[inline]
#[cfg(not(feature = "safe-encode"))]
fn push_u32(output: &mut impl Sink, el: u32) {
    unsafe {
        core::ptr::copy_nonoverlapping(el.to_le_bytes().as_ptr(), output.pos_mut_ptr(), 4);
        output.set_pos(output.pos() + 4);
    }
}

#[inline(always)] // (always) is necessary, otherwise the compiler fails to inline it
#[cfg(feature = "safe-encode")]
fn copy_literals_wild(output: &mut impl Sink, input: &[u8], input_start: usize, len: usize) {
    output.extend_from_slice_wild(&input[input_start..input_start + len], len)
}

#[inline]
#[cfg(not(feature = "safe-encode"))]
fn copy_literals_wild(output: &mut impl Sink, input: &[u8], input_start: usize, len: usize) {
    debug_assert!(input_start + len / 8 * 8 + ((len % 8) != 0) as usize * 8 <= input.len());
    debug_assert!(output.pos() + len / 8 * 8 + ((len % 8) != 0) as usize * 8 <= output.capacity());
    unsafe {
        // Note: This used to be a wild copy loop of 8 bytes, but the compiler consistently
        // transformed it into a call to memcpy, which hurts performance significantly for
        // small copies, which are common.
        let start_ptr = input.as_ptr().add(input_start);
        match len {
            0..=8 => core::ptr::copy_nonoverlapping(start_ptr, output.pos_mut_ptr(), 8),
            9..=16 => core::ptr::copy_nonoverlapping(start_ptr, output.pos_mut_ptr(), 16),
            17..=24 => core::ptr::copy_nonoverlapping(start_ptr, output.pos_mut_ptr(), 24),
            _ => core::ptr::copy_nonoverlapping(start_ptr, output.pos_mut_ptr(), len),
        }
        output.set_pos(output.pos() + len);
    }
}

/// Compress all bytes of `input` into `output`.
/// The method chooses an appropriate hashtable to look up duplicates.
/// `output` should be preallocated with a capacity of `get_maximum_output_size`.
///
/// Returns the number of bytes written (compressed) into `output`.
#[inline]
pub(crate) fn compress_into_sink_with_dict<const USE_DICT: bool>(
    input: &[u8],
    output: &mut impl Sink,
    mut dict_data: &[u8],
) -> Result<usize, CompressError> {
    if dict_data.len() + input.len() < u16::MAX as usize {
        let mut dict = HashTable4KU16::new();
        init_dict(&mut dict, &mut dict_data);
        compress_internal::<_, USE_DICT, _>(input, 0, output, &mut dict, dict_data, dict_data.len())
    } else {
        let mut dict = HashTable4K::new();
        init_dict(&mut dict, &mut dict_data);
        compress_internal::<_, USE_DICT, _>(input, 0, output, &mut dict, dict_data, dict_data.len())
    }
}

#[inline]
fn init_dict<T: HashTable>(dict: &mut T, dict_data: &mut &[u8]) {
    if dict_data.len() > WINDOW_SIZE {
        *dict_data = &dict_data[dict_data.len() - WINDOW_SIZE..];
    }
    let mut i = 0usize;
    while i + core::mem::size_of::<usize>() <= dict_data.len() {
        let hash = T::get_hash_at(dict_data, i);
        dict.put_at(hash, i);
        // Note: The 3 byte step was copied from the reference implementation; it could be
        // arbitrary.
        i += 3;
    }
}

/// Returns the maximum output size of the compressed data.
/// Can be used to preallocate capacity on the output vector.
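///
/// For example, `get_maximum_output_size(100)` returns `16 + 4 + 110 = 130`.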
#[inline]
pub const fn get_maximum_output_size(input_len: usize) -> usize {
    16 + 4 + (input_len * 110 / 100) as usize
}

/// Compress all bytes of `input` into `output`.
/// The method chooses an appropriate hashtable to look up duplicates.
/// `output` should be preallocated with a capacity of `get_maximum_output_size`.
///
/// Returns the number of bytes written (compressed) into `output`.
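///
/// A minimal usage sketch (assuming the crate's public re-exports
/// `lz4_flex::block::compress_into` and `lz4_flex::block::get_maximum_output_size`):
/// ```
/// use lz4_flex::block::{compress_into, get_maximum_output_size};
/// let input = b"Hello people, what's up?";
/// let mut buf = vec![0u8; get_maximum_output_size(input.len())];
/// let written = compress_into(input, &mut buf).unwrap();
/// buf.truncate(written);
/// ```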
#[inline]
pub fn compress_into(input: &[u8], output: &mut [u8]) -> Result<usize, CompressError> {
    compress_into_sink_with_dict::<false>(input, &mut SliceSink::new(output, 0), b"")
}

/// Compress all bytes of `input` into `output` using an external dictionary.
/// The method chooses an appropriate hashtable to look up duplicates.
/// `output` should be preallocated with a capacity of `get_maximum_output_size`.
///
/// Returns the number of bytes written (compressed) into `output`.
#[inline]
pub fn compress_into_with_dict(
    input: &[u8],
    output: &mut [u8],
    dict_data: &[u8],
) -> Result<usize, CompressError> {
    compress_into_sink_with_dict::<true>(input, &mut SliceSink::new(output, 0), dict_data)
}

#[inline]
fn compress_into_vec_with_dict<const USE_DICT: bool>(
    input: &[u8],
    prepend_size: bool,
    mut dict_data: &[u8],
) -> Vec<u8> {
    let prepend_size_num_bytes = if prepend_size { 4 } else { 0 };
    let max_compressed_size = get_maximum_output_size(input.len()) + prepend_size_num_bytes;
    if dict_data.len() <= 3 {
        dict_data = b"";
    }
    #[cfg(feature = "safe-encode")]
    let mut compressed = {
        let mut compressed: Vec<u8> = vec![0u8; max_compressed_size];
        let out = if prepend_size {
            compressed[..4].copy_from_slice(&(input.len() as u32).to_le_bytes());
            &mut compressed[4..]
        } else {
            &mut compressed
        };
        let compressed_len =
            compress_into_sink_with_dict::<USE_DICT>(input, &mut SliceSink::new(out, 0), dict_data)
                .unwrap();

        compressed.truncate(prepend_size_num_bytes + compressed_len);
        compressed
    };
    #[cfg(not(feature = "safe-encode"))]
    let mut compressed = {
        let mut vec = Vec::with_capacity(max_compressed_size);
        let start_pos = if prepend_size {
            vec.extend_from_slice(&(input.len() as u32).to_le_bytes());
            4
        } else {
            0
        };
        let compressed_len = compress_into_sink_with_dict::<USE_DICT>(
            input,
            &mut PtrSink::from_vec(&mut vec, start_pos),
            dict_data,
        )
        .unwrap();
        unsafe {
            vec.set_len(prepend_size_num_bytes + compressed_len);
        }
        vec
    };

    compressed.shrink_to_fit();
    compressed
}

/// Compress all bytes of `input`. The uncompressed size will be prepended as a little-endian
/// u32. Can be used in conjunction with `decompress_size_prepended`.
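///
/// A minimal round-trip sketch (assuming the crate-root re-exports
/// `lz4_flex::compress_prepend_size` and `lz4_flex::decompress_size_prepended`):
/// ```
/// let compressed = lz4_flex::compress_prepend_size(b"Hello people, what's up?");
/// let decompressed = lz4_flex::decompress_size_prepended(&compressed).unwrap();
/// assert_eq!(decompressed.as_slice(), &b"Hello people, what's up?"[..]);
/// ```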
#[inline]
pub fn compress_prepend_size(input: &[u8]) -> Vec<u8> {
    compress_into_vec_with_dict::<false>(input, true, b"")
}

/// Compress all bytes of `input`.
#[inline]
pub fn compress(input: &[u8]) -> Vec<u8> {
    compress_into_vec_with_dict::<false>(input, false, b"")
}

/// Compress all bytes of `input` with an external dictionary.
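///
/// A minimal sketch (assuming the block-module re-export `lz4_flex::block::compress_with_dict`;
/// the same dictionary must later be passed to the dictionary-aware decompression functions):
/// ```
/// let dict = b"the quick brown fox jumps over the lazy dog";
/// let _compressed = lz4_flex::block::compress_with_dict(b"the lazy dog jumps", dict);
/// ```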
#[inline]
pub fn compress_with_dict(input: &[u8], ext_dict: &[u8]) -> Vec<u8> {
    compress_into_vec_with_dict::<true>(input, false, ext_dict)
}

/// Compress all bytes of `input` with an external dictionary. The uncompressed size will be
/// prepended as a little-endian u32. Can be used in conjunction with
/// `decompress_size_prepended_with_dict`.
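///
/// A minimal round-trip sketch (assuming the block-module re-exports
/// `lz4_flex::block::compress_prepend_size_with_dict` and
/// `lz4_flex::block::decompress_size_prepended_with_dict`):
/// ```
/// use lz4_flex::block::{compress_prepend_size_with_dict, decompress_size_prepended_with_dict};
/// let dict = b"aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa";
/// let input = b"aaaaaaaaaaaaaaaaaaaaaaaaaaaaa";
/// let compressed = compress_prepend_size_with_dict(input, dict);
/// let decompressed = decompress_size_prepended_with_dict(&compressed, dict).unwrap();
/// assert_eq!(decompressed.as_slice(), &input[..]);
/// ```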
#[inline]
pub fn compress_prepend_size_with_dict(input: &[u8], ext_dict: &[u8]) -> Vec<u8> {
    compress_into_vec_with_dict::<true>(input, true, ext_dict)
}

#[inline]
#[cfg(not(feature = "safe-encode"))]
fn read_u16_ptr(input: *const u8) -> u16 {
    let mut num: u16 = 0;
    unsafe {
        core::ptr::copy_nonoverlapping(input, &mut num as *mut u16 as *mut u8, 2);
    }
    num
}

#[inline]
#[cfg(not(feature = "safe-encode"))]
fn read_u32_ptr(input: *const u8) -> u32 {
    let mut num: u32 = 0;
    unsafe {
        core::ptr::copy_nonoverlapping(input, &mut num as *mut u32 as *mut u8, 4);
    }
    num
}

#[inline]
#[cfg(not(feature = "safe-encode"))]
fn read_usize_ptr(input: *const u8) -> usize {
    let mut num: usize = 0;
    unsafe {
        core::ptr::copy_nonoverlapping(
            input,
            &mut num as *mut usize as *mut u8,
            core::mem::size_of::<usize>(),
        );
    }
    num
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_count_same_bytes() {
        // 8-byte aligned block; the trailing zeros and ones are padding for the end offset.
        let first: &[u8] = &[
            1, 2, 3, 4, 1, 2, 3, 4, 1, 2, 3, 4, 1, 2, 3, 4, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        ];
        let second: &[u8] = &[
            1, 2, 3, 4, 1, 2, 3, 4, 1, 2, 3, 4, 1, 2, 3, 4, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
        ];
        assert_eq!(count_same_bytes(first, &mut 0, second, 0), 16);

        // 4-byte aligned block
        let first: &[u8] = &[
            1, 2, 3, 4, 1, 2, 3, 4, 1, 2, 3, 4, 1, 2, 3, 4, 1, 2, 3, 4, 0, 0, 0, 0, 0, 0, 0, 0, 0,
            0, 0, 0,
        ];
        let second: &[u8] = &[
            1, 2, 3, 4, 1, 2, 3, 4, 1, 2, 3, 4, 1, 2, 3, 4, 1, 2, 3, 4, 1, 1, 1, 1, 1, 1, 1, 1, 1,
            1, 1, 1,
        ];
        assert_eq!(count_same_bytes(first, &mut 0, second, 0), 20);

        // 2-byte aligned block
        let first: &[u8] = &[
            1, 2, 3, 4, 1, 2, 3, 4, 1, 2, 3, 4, 1, 2, 3, 4, 1, 2, 3, 4, 3, 4, 0, 0, 0, 0, 0, 0, 0,
            0, 0, 0, 0, 0,
        ];
        let second: &[u8] = &[
            1, 2, 3, 4, 1, 2, 3, 4, 1, 2, 3, 4, 1, 2, 3, 4, 1, 2, 3, 4, 3, 4, 1, 1, 1, 1, 1, 1, 1,
            1, 1, 1, 1, 1,
        ];
        assert_eq!(count_same_bytes(first, &mut 0, second, 0), 22);

        // 1-byte aligned block
        let first: &[u8] = &[
            1, 2, 3, 4, 1, 2, 3, 4, 1, 2, 3, 4, 1, 2, 3, 4, 1, 2, 3, 4, 3, 4, 5, 0, 0, 0, 0, 0, 0,
            0, 0, 0, 0, 0, 0,
        ];
        let second: &[u8] = &[
            1, 2, 3, 4, 1, 2, 3, 4, 1, 2, 3, 4, 1, 2, 3, 4, 1, 2, 3, 4, 3, 4, 5, 1, 1, 1, 1, 1, 1,
            1, 1, 1, 1, 1, 1,
        ];
        assert_eq!(count_same_bytes(first, &mut 0, second, 0), 23);

        // 1-byte aligned block - last byte different
        let first: &[u8] = &[
            1, 2, 3, 4, 1, 2, 3, 4, 1, 2, 3, 4, 1, 2, 3, 4, 1, 2, 3, 4, 3, 4, 5, 0, 0, 0, 0, 0, 0,
            0, 0, 0, 0, 0, 0,
        ];
        let second: &[u8] = &[
            1, 2, 3, 4, 1, 2, 3, 4, 1, 2, 3, 4, 1, 2, 3, 4, 1, 2, 3, 4, 3, 4, 6, 1, 1, 1, 1, 1, 1,
            1, 1, 1, 1, 1, 1,
        ];
        assert_eq!(count_same_bytes(first, &mut 0, second, 0), 22);

        // 1-byte aligned block
        let first: &[u8] = &[
            1, 2, 3, 4, 1, 2, 3, 4, 1, 2, 3, 4, 1, 2, 3, 4, 1, 2, 3, 4, 3, 9, 5, 0, 0, 0, 0, 0, 0,
            0, 0, 0, 0, 0, 0,
        ];
        let second: &[u8] = &[
            1, 2, 3, 4, 1, 2, 3, 4, 1, 2, 3, 4, 1, 2, 3, 4, 1, 2, 3, 4, 3, 4, 6, 1, 1, 1, 1, 1, 1,
            1, 1, 1, 1, 1, 1,
        ];
        assert_eq!(count_same_bytes(first, &mut 0, second, 0), 21);

        for diff_idx in 8..100 {
            let first: Vec<u8> = (0u8..255).cycle().take(100 + 12).collect();
            let mut second = first.clone();
            second[diff_idx] = 255;
            for start in 0..=diff_idx {
                let same_bytes = count_same_bytes(&first, &mut start.clone(), &second, start);
                assert_eq!(same_bytes, diff_idx - start);
            }
        }
    }

    #[test]
    fn test_bug() {
        let input: &[u8] = &[
            10, 12, 14, 16, 18, 10, 12, 14, 16, 18, 10, 12, 14, 16, 18, 10, 12, 14, 16, 18,
        ];
        let _out = compress(input);
    }

    #[test]
    fn test_dict() {
        let input: &[u8] = &[
            10, 12, 14, 16, 18, 10, 12, 14, 16, 18, 10, 12, 14, 16, 18, 10, 12, 14, 16, 18,
        ];
        let dict = input;
        let compressed = compress_with_dict(input, dict);
        assert_lt!(compressed.len(), compress(input).len());

        assert!(compressed.len() < compress(input).len());
        let mut uncompressed = vec![0u8; input.len()];
        let uncomp_size = crate::block::decompress::decompress_into_with_dict(
            &compressed,
            &mut uncompressed,
            dict,
        )
        .unwrap();
        uncompressed.truncate(uncomp_size);
        assert_eq!(input, uncompressed);
    }

    #[test]
    fn test_dict_no_panic() {
        let input: &[u8] = &[
            10, 12, 14, 16, 18, 10, 12, 14, 16, 18, 10, 12, 14, 16, 18, 10, 12, 14, 16, 18,
        ];
        let dict = &[10, 12, 14];
        let _compressed = compress_with_dict(input, dict);
    }

    #[test]
    fn test_dict_match_crossing() {
        let input: &[u8] = &[
            10, 12, 14, 16, 18, 10, 12, 14, 16, 18, 10, 12, 14, 16, 18, 10, 12, 14, 16, 18,
        ];
        let dict = input;
        let compressed = compress_with_dict(input, dict);
        assert_lt!(compressed.len(), compress(input).len());

        let mut uncompressed = vec![0u8; input.len() * 2];
        // Copy the second half of the dict into the start of the output.
        let dict_cutoff = dict.len() / 2;
        let output_start = dict.len() - dict_cutoff;
        uncompressed[..output_start].copy_from_slice(&dict[dict_cutoff..]);
        let uncomp_len = {
            let mut sink = SliceSink::new(&mut uncompressed[..], output_start);
            crate::block::decompress::decompress_internal::<true, _>(
                &compressed,
                &mut sink,
                &dict[..dict_cutoff],
            )
            .unwrap()
        };
        assert_eq!(input.len(), uncomp_len);
        assert_eq!(
            input,
            &uncompressed[output_start..output_start + uncomp_len]
        );
    }

    #[test]
    fn test_conformant_last_block() {
        // From the spec:
        // The last match must start at least 12 bytes before the end of block.
        // The last match is part of the penultimate sequence. It is followed by the last sequence,
        // which contains only literals. Note that, as a consequence, an independent block <
        // 13 bytes cannot be compressed, because the match must copy "something",
        // so it needs at least one prior byte.
        // When a block can reference data from another block, it can start immediately with a match
        // and no literal, so a block of 12 bytes can be compressed.
        let aaas: &[u8] = b"aaaaaaaaaaaaaaa";

        // uncompressible
        let out = compress(&aaas[..12]);
        assert_gt!(out.len(), 12);
        // compressible
        let out = compress(&aaas[..13]);
        assert_le!(out.len(), 13);
        let out = compress(&aaas[..14]);
        assert_le!(out.len(), 14);
        let out = compress(&aaas[..15]);
        assert_le!(out.len(), 15);

        // dict uncompressible
        let out = compress_with_dict(&aaas[..11], aaas);
        assert_gt!(out.len(), 11);
        // compressible
        let out = compress_with_dict(&aaas[..12], aaas);
        // According to the spec this _could_ compress, but it doesn't in this lib
        // as it aborts compression for any input len < LZ4_MIN_LENGTH.
        assert_gt!(out.len(), 12);
        let out = compress_with_dict(&aaas[..13], aaas);
        assert_le!(out.len(), 13);
        let out = compress_with_dict(&aaas[..14], aaas);
        assert_le!(out.len(), 14);
        let out = compress_with_dict(&aaas[..15], aaas);
        assert_le!(out.len(), 15);
    }

    #[test]
    fn test_dict_size() {
        let dict = vec![b'a'; 1024 * 1024];
        let input = &b"aaaaaaaaaaaaaaaaaaaaaaaaaaaaa"[..];
        let compressed = compress_prepend_size_with_dict(input, &dict);
        let decompressed =
            crate::block::decompress_size_prepended_with_dict(&compressed, &dict).unwrap();
        assert_eq!(decompressed, input);
    }
}