lz4_flex/block/compress.rs

//! The compression algorithm.
//!
//! We make use of hash tables to find duplicates. This gives a reasonable compression ratio with
//! high performance. It has fixed memory usage, which, contrary to other approaches, makes it
//! less memory hungry.

use crate::block::hashtable::HashTable;
use crate::block::END_OFFSET;
use crate::block::LZ4_MIN_LENGTH;
use crate::block::MAX_DISTANCE;
use crate::block::MFLIMIT;
use crate::block::MINMATCH;
#[cfg(not(feature = "safe-encode"))]
use crate::sink::PtrSink;
use crate::sink::Sink;
use crate::sink::SliceSink;
#[allow(unused_imports)]
use alloc::vec;

#[allow(unused_imports)]
use alloc::vec::Vec;

use super::hashtable::HashTable4K;
use super::hashtable::HashTable4KU16;
use super::{CompressError, WINDOW_SIZE};

/// Increase step size after `1 << INCREASE_STEPSIZE_BITSHIFT` non-matches.
const INCREASE_STEPSIZE_BITSHIFT: usize = 5;

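// A hedged illustration of the resulting search schedule: in the duplicate-search loop
// below, `step_size = non_match_count >> INCREASE_STEPSIZE_BITSHIFT` with `non_match_count`
// starting at `1 << 5 == 32`, so the first 32 misses advance the cursor 1 byte each, the
// next 32 misses 2 bytes each, and so on, accelerating through incompressible regions.
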
/// Read a 4-byte "batch" from some position.
///
/// This will read a native-endian 4-byte integer from some position.
#[inline]
#[cfg(not(feature = "safe-encode"))]
pub(super) fn get_batch(input: &[u8], n: usize) -> u32 {
    unsafe { read_u32_ptr(input.as_ptr().add(n)) }
}

#[inline]
#[cfg(feature = "safe-encode")]
pub(super) fn get_batch(input: &[u8], n: usize) -> u32 {
    u32::from_ne_bytes(input[n..n + 4].try_into().unwrap())
}

/// Read a usize-sized "batch" from some position.
///
/// This will read a native-endian usize from some position.
#[inline]
#[allow(dead_code)]
#[cfg(not(feature = "safe-encode"))]
pub(super) fn get_batch_arch(input: &[u8], n: usize) -> usize {
    unsafe { read_usize_ptr(input.as_ptr().add(n)) }
}

#[inline]
#[allow(dead_code)]
#[cfg(feature = "safe-encode")]
pub(super) fn get_batch_arch(input: &[u8], n: usize) -> usize {
    const USIZE_SIZE: usize = core::mem::size_of::<usize>();
    let arr: &[u8; USIZE_SIZE] = input[n..n + USIZE_SIZE].try_into().unwrap();
    usize::from_ne_bytes(*arr)
}

#[inline]
fn token_from_literal(lit_len: usize) -> u8 {
    if lit_len < 0xF {
        // Since we can fit the literals length into it, there is no need for saturation.
        (lit_len as u8) << 4
    } else {
        // We were unable to fit the literals into it, so we saturate to 0xF. We will later
        // write the extensional value.
        0xF0
    }
}

#[inline]
fn token_from_literal_and_match_length(lit_len: usize, duplicate_length: usize) -> u8 {
    let mut token = if lit_len < 0xF {
        // Since we can fit the literals length into it, there is no need for saturation.
        (lit_len as u8) << 4
    } else {
        // We were unable to fit the literals into it, so we saturate to 0xF. We will later
        // write the extensional value.
        0xF0
    };

    token |= if duplicate_length < 0xF {
        // We could fit it in.
        duplicate_length as u8
    } else {
        // We were unable to fit it in, so we default to 0xF, which will later be extended.
        0xF
    };

    token
}

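// A hedged worked illustration of the token layout produced above (high nibble:
// literal length, low nibble: extra match length beyond MINMATCH, both saturating at 0xF):
//   token_from_literal_and_match_length(3, 5)  == 0x35
//   token_from_literal_and_match_length(20, 5) == 0xF5, with 20 - 0xF written later
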
/// Counts the number of same bytes in two byte streams.
/// `input` is the complete input.
/// `cur` is the current position in the input; it will be incremented by the number of matched
/// bytes.
/// `source` is either the same as `input` or an external slice.
/// `candidate` is the candidate position in `source`.
///
/// The function ignores the last END_OFFSET bytes in input, as those should be literals.
#[inline]
#[cfg(feature = "safe-encode")]
fn count_same_bytes(input: &[u8], cur: &mut usize, source: &[u8], candidate: usize) -> usize {
    const USIZE_SIZE: usize = core::mem::size_of::<usize>();
    let cur_slice = &input[*cur..input.len() - END_OFFSET];
    let cand_slice = &source[candidate..];

    let mut num = 0;
    for (block1, block2) in cur_slice
        .chunks_exact(USIZE_SIZE)
        .zip(cand_slice.chunks_exact(USIZE_SIZE))
    {
        let input_block = usize::from_ne_bytes(block1.try_into().unwrap());
        let match_block = usize::from_ne_bytes(block2.try_into().unwrap());

        if input_block == match_block {
            num += USIZE_SIZE;
        } else {
            let diff = input_block ^ match_block;
            num += (diff.to_le().trailing_zeros() / 8) as usize;
            *cur += num;
            return num;
        }
    }

    // If we're here we may have 1 to 7 bytes left to check close to the end of the input
    // or source slices. Since this is a rare occurrence, we mark it cold to get
    // ~5% better performance.
    #[cold]
    fn count_same_bytes_tail(a: &[u8], b: &[u8], offset: usize) -> usize {
        a.iter()
            .zip(b)
            .skip(offset)
            .take_while(|(a, b)| a == b)
            .count()
    }
    num += count_same_bytes_tail(cur_slice, cand_slice, num);

    *cur += num;
    num
}

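// A hedged illustration of the XOR + trailing_zeros trick above, shown with 4-byte
// blocks for brevity on a little-endian machine:
//   bytes [0x11, 0x22, 0x33, 0x44] vs [0x11, 0x22, 0x33, 0x99]
//   read as integers: 0x44332211 ^ 0x99332211 == 0xDD000000
//   trailing_zeros() == 24, and 24 / 8 == 3 matching bytes.
// `to_le()` normalizes the byte order so the trailing zeros always correspond to the
// earliest bytes of the stream, also on big-endian targets.
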
/// Counts the number of same bytes in two byte streams.
/// `input` is the complete input.
/// `cur` is the current position in the input; it will be incremented by the number of matched
/// bytes.
/// `source` is either the same as `input` or an external slice.
/// `candidate` is the candidate position in `source`.
///
/// The function ignores the last END_OFFSET bytes in input, as those should be literals.
#[inline]
#[cfg(not(feature = "safe-encode"))]
fn count_same_bytes(input: &[u8], cur: &mut usize, source: &[u8], candidate: usize) -> usize {
    let max_input_match = input.len().saturating_sub(*cur + END_OFFSET);
    let max_candidate_match = source.len() - candidate;
    // Considering both limits, calc how far we may match in input.
    let input_end = *cur + max_input_match.min(max_candidate_match);

    let start = *cur;
    let mut source_ptr = unsafe { source.as_ptr().add(candidate) };

    // compare 4/8 byte blocks depending on the arch
    const STEP_SIZE: usize = core::mem::size_of::<usize>();
    while *cur + STEP_SIZE <= input_end {
        let diff = read_usize_ptr(unsafe { input.as_ptr().add(*cur) }) ^ read_usize_ptr(source_ptr);

        if diff == 0 {
            *cur += STEP_SIZE;
            unsafe {
                source_ptr = source_ptr.add(STEP_SIZE);
            }
        } else {
            *cur += (diff.to_le().trailing_zeros() / 8) as usize;
            return *cur - start;
        }
    }

    // compare 4 byte block
    #[cfg(target_pointer_width = "64")]
    {
        if input_end - *cur >= 4 {
            let diff = read_u32_ptr(unsafe { input.as_ptr().add(*cur) }) ^ read_u32_ptr(source_ptr);

            if diff == 0 {
                *cur += 4;
                unsafe {
                    source_ptr = source_ptr.add(4);
                }
            } else {
                *cur += (diff.to_le().trailing_zeros() / 8) as usize;
                return *cur - start;
            }
        }
    }

    // compare 2 byte block
    if input_end - *cur >= 2
        && unsafe { read_u16_ptr(input.as_ptr().add(*cur)) == read_u16_ptr(source_ptr) }
    {
        *cur += 2;
        unsafe {
            source_ptr = source_ptr.add(2);
        }
    }

    if *cur < input_end
        && unsafe { input.as_ptr().add(*cur).read() } == unsafe { source_ptr.read() }
    {
        *cur += 1;
    }

    *cur - start
}

/// Write an integer to the output.
///
/// Each additional byte represents a value from 0 to 255, which is added to the previous value
/// to produce a total length. When the byte value is 255, another byte must be read and added,
/// and so on. There can be any number of bytes of value 255 following the token.
#[inline]
pub(super) fn write_integer(output: &mut impl Sink, mut n: usize) {
    // Note: Since `n` is usually < 0xFF and writing multiple bytes to the output
    // requires 2 branches of bound check (due to the possibility of add overflows),
    // the simple byte-at-a-time implementation below is faster in most cases.
    while n >= 0xFF {
        n -= 0xFF;
        push_byte(output, 0xFF);
    }
    push_byte(output, n as u8);
}

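// A hedged worked example of the encoding above:
//   write_integer(_, 10)  pushes [0x0A]
//   write_integer(_, 255) pushes [0xFF, 0x00]
//   write_integer(_, 300) pushes [0xFF, 0x2D]  (255 + 45 == 300)
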
/// Handle the last bytes from the input as literals
#[cold]
fn handle_last_literals(output: &mut impl Sink, input: &[u8], start: usize) {
    let lit_len = input.len() - start;

    let token = token_from_literal(lit_len);
    push_byte(output, token);
    if lit_len >= 0xF {
        write_integer(output, lit_len - 0xF);
    }
    // Now, write the actual literals.
    output.extend_from_slice(&input[start..]);
}

/// Moves the cursors back as long as the bytes match, to find additional bytes in a duplicate
#[inline]
#[cfg(feature = "safe-encode")]
fn backtrack_match(
    input: &[u8],
    cur: &mut usize,
    literal_start: usize,
    source: &[u8],
    candidate: &mut usize,
) {
    // Note: Even if the iterator version of this loop has fewer branches inside the loop, it
    // has more branches before the loop. In practice that seems to make it slower than the
    // while version below.
    // TODO: It should be possible to remove all bounds checks, since we are walking backwards.
    while *candidate > 0 && *cur > literal_start && input[*cur - 1] == source[*candidate - 1] {
        *cur -= 1;
        *candidate -= 1;
    }
}

/// Moves the cursors back as long as the bytes match, to find additional bytes in a duplicate
#[inline]
#[cfg(not(feature = "safe-encode"))]
fn backtrack_match(
    input: &[u8],
    cur: &mut usize,
    literal_start: usize,
    source: &[u8],
    candidate: &mut usize,
) {
    while unsafe {
        *candidate > 0
            && *cur > literal_start
            && input.get_unchecked(*cur - 1) == source.get_unchecked(*candidate - 1)
    } {
        *cur -= 1;
        *candidate -= 1;
    }
}

/// Compress all bytes of `input[input_pos..]` into `output`.
///
/// Bytes in `input[..input_pos]` are treated as a preamble and can be used for lookback.
/// This part is known as the compressor "prefix".
/// Bytes in `ext_dict` logically precede the bytes in `input` and can also be used for lookback.
///
/// `input_stream_offset` is the logical position of the first byte of `input`. This allows the
/// same `dict` to be used for many calls to `compress_internal`, as we can "readdress" the first
/// byte of `input` to be something other than 0.
///
/// `dict` is the dictionary of previously encoded sequences.
///
/// This is used to find duplicates in the stream so they are not written multiple times.
///
/// Every four bytes are hashed, and in the resulting slot their position in the input buffer
/// is placed in the dict. This way we can easily look up candidates for back references.
///
/// Returns the number of bytes written (compressed) into `output`.
///
/// # Const parameters
/// `USE_DICT`: When `false`, disables usage of `ext_dict` (the function will panic if a
/// non-empty slice is passed). In other words, this generates more optimized code when an
/// external dictionary isn't used.
///
/// A similar const argument could be used to disable the prefix mode (e.g. `USE_PREFIX`),
/// which would impose `input_pos == 0 && input_stream_offset == 0`. Experiments didn't
/// show significant improvement though.
// Intentionally avoid inlining.
// Empirical tests revealed it to be rarely better but often significantly detrimental.
#[inline(never)]
pub(crate) fn compress_internal<T: HashTable, const USE_DICT: bool, S: Sink>(
    input: &[u8],
    input_pos: usize,
    output: &mut S,
    dict: &mut T,
    ext_dict: &[u8],
    input_stream_offset: usize,
) -> Result<usize, CompressError> {
    assert!(input_pos <= input.len());
    if USE_DICT {
        assert!(ext_dict.len() <= super::WINDOW_SIZE);
        assert!(ext_dict.len() <= input_stream_offset);
        // Check for overflow hazard when using ext_dict
        assert!(input_stream_offset
            .checked_add(input.len())
            .and_then(|i| i.checked_add(ext_dict.len()))
            .is_some_and(|i| i <= isize::MAX as usize));
    } else {
        assert!(ext_dict.is_empty());
    }
    if output.capacity() - output.pos() < get_maximum_output_size(input.len() - input_pos) {
        return Err(CompressError::OutputTooSmall);
    }

    let output_start_pos = output.pos();
    if input.len() - input_pos < LZ4_MIN_LENGTH {
        handle_last_literals(output, input, input_pos);
        return Ok(output.pos() - output_start_pos);
    }

    let ext_dict_stream_offset = input_stream_offset - ext_dict.len();
    let end_pos_check = input.len() - MFLIMIT;
    let mut literal_start = input_pos;
    let mut cur = input_pos;

    if cur == 0 && input_stream_offset == 0 {
        // According to the spec we can't start with a match,
        // except when referencing another block.
        let hash = T::get_hash_at(input, 0);
        dict.put_at(hash, 0);
        cur = 1;
    }

    loop {
        // Read the next block into two sections, the literals and the duplicates.
        let mut step_size;
        let mut candidate;
        let mut candidate_source;
        let mut offset;
        let mut non_match_count = 1 << INCREASE_STEPSIZE_BITSHIFT;
        // The number of bytes before our cursor, where the duplicate starts.
        let mut next_cur = cur;

        // In this loop we search for duplicates via the hashtable. 4 or 8 bytes are hashed and
        // compared.
        loop {
            step_size = non_match_count >> INCREASE_STEPSIZE_BITSHIFT;
            non_match_count += 1;

            cur = next_cur;
            next_cur += step_size;

            // Same as cur + MFLIMIT > input.len()
            if cur > end_pos_check {
                handle_last_literals(output, input, literal_start);
                return Ok(output.pos() - output_start_pos);
            }
            // Find a candidate in the dictionary with the hash of the current four bytes.
            // Unchecked is safe as long as the values from the hash function don't exceed the
            // size of the table. This is ensured by right shifting the hash values
            // (`dict_bitshift`) to fit them in the table.

            // [Bounds Check]: Can be elided due to `end_pos_check` above
            let hash = T::get_hash_at(input, cur);
            candidate = dict.get_at(hash);
            dict.put_at(hash, cur + input_stream_offset);

            // Sanity check: Matches can't be ahead of `cur`.
            debug_assert!(candidate <= input_stream_offset + cur);

            // The candidate must satisfy two requirements:
            // - It must not be merely a hash collision, i.e. the candidate must actually match
            //   what we search for.
            // - We can address up to a 16-bit offset, hence we are only able to address the
            //   candidate if its offset is less than or equal to 0xFFFF.
            if input_stream_offset + cur - candidate > MAX_DISTANCE {
                continue;
            }

            if candidate >= input_stream_offset {
                // match within input
                offset = (input_stream_offset + cur - candidate) as u16;
                candidate -= input_stream_offset;
                candidate_source = input;
            } else if USE_DICT {
                // Sanity check, which may fail if we lost history beyond MAX_DISTANCE
                debug_assert!(
                    candidate >= ext_dict_stream_offset,
                    "Lost history in ext dict mode"
                );
                // match within ext dict
                offset = (input_stream_offset + cur - candidate) as u16;
                candidate -= ext_dict_stream_offset;
                candidate_source = ext_dict;
            } else {
                // Match is not reachable anymore,
                // e.g. compressing an independent block frame w/o clearing
                // the match tables, only increasing input_stream_offset.
                // Sanity check
                debug_assert!(input_pos == 0, "Lost history in prefix mode");
                continue;
            }
            // [Bounds Check]: The candidate is coming from the hashmap. It can't be out of
            // bounds, but that is impossible to prove for the compiler, so the bounds checks
            // remain.
            let cand_bytes: u32 = get_batch(candidate_source, candidate);
            // [Bounds Check]: Should be able to be elided due to `end_pos_check`.
            let curr_bytes: u32 = get_batch(input, cur);

            if cand_bytes == curr_bytes {
                break;
            }
        }

        // Extend the match backwards if we can
        backtrack_match(
            input,
            &mut cur,
            literal_start,
            candidate_source,
            &mut candidate,
        );

        // The length (in bytes) of the literals section.
        let lit_len = cur - literal_start;

        // Generate the higher half of the token.
        cur += MINMATCH;
        candidate += MINMATCH;
        let duplicate_length = count_same_bytes(input, &mut cur, candidate_source, candidate);

        // Note: The `- 2` offset was copied from the reference implementation; it could be
        // arbitrary.
        let hash = T::get_hash_at(input, cur - 2);
        dict.put_at(hash, cur - 2 + input_stream_offset);

        let token = token_from_literal_and_match_length(lit_len, duplicate_length);

        // Push the token to the output stream.
        push_byte(output, token);
        // If we were unable to fit the literals length into the token, write the extensional
        // part.
        if lit_len >= 0xF {
            write_integer(output, lit_len - 0xF);
        }

        // Now, write the actual literals.
        //
        // The unsafe version copies blocks of 8 bytes, and therefore may copy up to 7 bytes
        // more than needed. This is safe, because the last MFLIMIT (12) bytes are handled in
        // handle_last_literals.
        copy_literals_wild(output, input, literal_start, lit_len);
        // write the offset in little endian
        push_u16(output, offset);

        // If we were unable to fit the duplicates length into the token, write the
        // extensional part.
        if duplicate_length >= 0xF {
            write_integer(output, duplicate_length - 0xF);
        }
        literal_start = cur;
    }
}

#[inline]
#[cfg(feature = "safe-encode")]
fn push_byte(output: &mut impl Sink, el: u8) {
    output.push(el);
}

#[inline]
#[cfg(not(feature = "safe-encode"))]
fn push_byte(output: &mut impl Sink, el: u8) {
    unsafe {
        core::ptr::write(output.pos_mut_ptr(), el);
        output.set_pos(output.pos() + 1);
    }
}

#[inline]
#[cfg(feature = "safe-encode")]
fn push_u16(output: &mut impl Sink, el: u16) {
    output.extend_from_slice(&el.to_le_bytes());
}

#[inline]
#[cfg(not(feature = "safe-encode"))]
fn push_u16(output: &mut impl Sink, el: u16) {
    unsafe {
        core::ptr::copy_nonoverlapping(el.to_le_bytes().as_ptr(), output.pos_mut_ptr(), 2);
        output.set_pos(output.pos() + 2);
    }
}

#[inline(always)] // (always) necessary, otherwise the compiler fails to inline it
#[cfg(feature = "safe-encode")]
fn copy_literals_wild(output: &mut impl Sink, input: &[u8], input_start: usize, len: usize) {
    output.extend_from_slice_wild(&input[input_start..input_start + len], len)
}

#[inline]
#[cfg(not(feature = "safe-encode"))]
fn copy_literals_wild(output: &mut impl Sink, input: &[u8], input_start: usize, len: usize) {
    debug_assert!(input_start + len / 8 * 8 + ((len % 8) != 0) as usize * 8 <= input.len());
    debug_assert!(output.pos() + len / 8 * 8 + ((len % 8) != 0) as usize * 8 <= output.capacity());
    unsafe {
        // Note: This used to be a wild copy loop of 8 bytes, but the compiler consistently
        // transformed it into a call to memcpy, which hurts performance significantly for
        // small copies, which are common.
        let start_ptr = input.as_ptr().add(input_start);
        match len {
            0..=8 => core::ptr::copy_nonoverlapping(start_ptr, output.pos_mut_ptr(), 8),
            9..=16 => core::ptr::copy_nonoverlapping(start_ptr, output.pos_mut_ptr(), 16),
            17..=24 => core::ptr::copy_nonoverlapping(start_ptr, output.pos_mut_ptr(), 24),
            _ => core::ptr::copy_nonoverlapping(start_ptr, output.pos_mut_ptr(), len),
        }
        output.set_pos(output.pos() + len);
    }
}

/// Compress all bytes of `input` into `output`.
/// The method chooses an appropriate hashtable to look up duplicates.
/// `output` should be preallocated with a capacity of `get_maximum_output_size`.
///
/// Returns the number of bytes written (compressed) into `output`.
#[inline]
pub(crate) fn compress_into_sink_with_dict<const USE_DICT: bool>(
    input: &[u8],
    output: &mut impl Sink,
    mut dict_data: &[u8],
) -> Result<usize, CompressError> {
    if dict_data.len() + input.len() < u16::MAX as usize {
        let mut dict = HashTable4KU16::new();
        init_dict(&mut dict, &mut dict_data);
        compress_internal::<_, USE_DICT, _>(input, 0, output, &mut dict, dict_data, dict_data.len())
    } else {
        let mut dict = HashTable4K::new();
        init_dict(&mut dict, &mut dict_data);
        compress_internal::<_, USE_DICT, _>(input, 0, output, &mut dict, dict_data, dict_data.len())
    }
}

#[inline]
fn init_dict<T: HashTable>(dict: &mut T, dict_data: &mut &[u8]) {
    if dict_data.len() > WINDOW_SIZE {
        *dict_data = &dict_data[dict_data.len() - WINDOW_SIZE..];
    }
    let mut i = 0usize;
    while i + core::mem::size_of::<usize>() <= dict_data.len() {
        let hash = T::get_hash_at(dict_data, i);
        dict.put_at(hash, i);
        // Note: The 3-byte step was copied from the reference implementation; it could be
        // arbitrary.
        i += 3;
    }
}

/// Returns the maximum output size of the compressed data.
/// Can be used to preallocate capacity on the output vector.
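///
/// A hedged worked example of the formula below (an illustration, not an API guarantee),
/// assuming the public re-export under `lz4_flex::block`:
/// ```
/// let max_size = lz4_flex::block::get_maximum_output_size(1000);
/// assert_eq!(max_size, 16 + 4 + 1000 * 110 / 100); // == 1120
/// ```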
#[inline]
pub const fn get_maximum_output_size(input_len: usize) -> usize {
    16 + 4 + (input_len * 110 / 100)
}

/// Compress all bytes of `input` into `output`.
/// The method chooses an appropriate hashtable to look up duplicates.
/// `output` should be preallocated with a capacity of `get_maximum_output_size`.
///
/// Returns the number of bytes written (compressed) into `output`.
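///
/// A minimal usage sketch, assuming the public re-exports under `lz4_flex::block`:
/// ```
/// let input = b"aaaaaaaaaaaaaaaaaaaabbbbb";
/// let mut buf = vec![0u8; lz4_flex::block::get_maximum_output_size(input.len())];
/// let len = lz4_flex::block::compress_into(input, &mut buf).unwrap();
/// let compressed = &buf[..len];
/// assert!(!compressed.is_empty());
/// ```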
#[inline]
pub fn compress_into(input: &[u8], output: &mut [u8]) -> Result<usize, CompressError> {
    compress_into_sink_with_dict::<false>(input, &mut SliceSink::new(output, 0), b"")
}

/// Compress all bytes of `input` into `output`.
/// The method chooses an appropriate hashtable to look up duplicates.
/// `output` should be preallocated with a capacity of `get_maximum_output_size`.
///
/// Returns the number of bytes written (compressed) into `output`.
#[inline]
pub fn compress_into_with_dict(
    input: &[u8],
    output: &mut [u8],
    dict_data: &[u8],
) -> Result<usize, CompressError> {
    compress_into_sink_with_dict::<true>(input, &mut SliceSink::new(output, 0), dict_data)
}

#[inline]
fn compress_into_vec_with_dict<const USE_DICT: bool>(
    input: &[u8],
    prepend_size: bool,
    mut dict_data: &[u8],
) -> Vec<u8> {
    let prepend_size_num_bytes = if prepend_size { 4 } else { 0 };
    let max_compressed_size = get_maximum_output_size(input.len()) + prepend_size_num_bytes;
    if dict_data.len() <= 3 {
        dict_data = b"";
    }
    #[cfg(feature = "safe-encode")]
    let mut compressed = {
        let mut compressed: Vec<u8> = vec![0u8; max_compressed_size];
        let out = if prepend_size {
            compressed[..4].copy_from_slice(&(input.len() as u32).to_le_bytes());
            &mut compressed[4..]
        } else {
            &mut compressed
        };
        let compressed_len =
            compress_into_sink_with_dict::<USE_DICT>(input, &mut SliceSink::new(out, 0), dict_data)
                .unwrap();

        compressed.truncate(prepend_size_num_bytes + compressed_len);
        compressed
    };
    #[cfg(not(feature = "safe-encode"))]
    let mut compressed = {
        let mut vec = Vec::with_capacity(max_compressed_size);
        let start_pos = if prepend_size {
            vec.extend_from_slice(&(input.len() as u32).to_le_bytes());
            4
        } else {
            0
        };
        let compressed_len = compress_into_sink_with_dict::<USE_DICT>(
            input,
            &mut PtrSink::from_vec(&mut vec, start_pos),
            dict_data,
        )
        .unwrap();
        unsafe {
            vec.set_len(prepend_size_num_bytes + compressed_len);
        }
        vec
    };

    compressed.shrink_to_fit();
    compressed
}

/// Compress all bytes of `input`. The uncompressed size will be prepended as a little-endian
/// u32. Can be used in conjunction with `decompress_size_prepended`.
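///
/// A round-trip sketch, assuming the public re-exports under `lz4_flex::block`:
/// ```
/// let compressed = lz4_flex::block::compress_prepend_size(b"hello hello hello");
/// let decompressed = lz4_flex::block::decompress_size_prepended(&compressed).unwrap();
/// assert_eq!(decompressed, b"hello hello hello");
/// ```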
#[inline]
pub fn compress_prepend_size(input: &[u8]) -> Vec<u8> {
    compress_into_vec_with_dict::<false>(input, true, b"")
}

/// Compress all bytes of `input`.
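///
/// A round-trip sketch, assuming the sibling `decompress`, which takes the (externally
/// stored) uncompressed size:
/// ```
/// let input = b"aaaaaaaaaaaaaaaaaaaaa";
/// let compressed = lz4_flex::block::compress(input);
/// let decompressed = lz4_flex::block::decompress(&compressed, input.len()).unwrap();
/// assert_eq!(decompressed, input);
/// ```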
#[inline]
pub fn compress(input: &[u8]) -> Vec<u8> {
    compress_into_vec_with_dict::<false>(input, false, b"")
}

/// Compress all bytes of `input` with an external dictionary.
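///
/// A round-trip sketch, assuming the sibling
/// `decompress_with_dict(input, uncompressed_size, ext_dict)` counterpart:
/// ```
/// let dict = b"aaaaaaaaaaaaaaaaaaaaa";
/// let input = b"aaaaaaaaaaaaaaaaaaaaa";
/// let compressed = lz4_flex::block::compress_with_dict(input, dict);
/// let decompressed =
///     lz4_flex::block::decompress_with_dict(&compressed, input.len(), dict).unwrap();
/// assert_eq!(decompressed, input);
/// ```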
#[inline]
pub fn compress_with_dict(input: &[u8], ext_dict: &[u8]) -> Vec<u8> {
    compress_into_vec_with_dict::<true>(input, false, ext_dict)
}

/// Compress all bytes of `input`. The uncompressed size will be prepended as a little-endian
/// u32. Can be used in conjunction with `decompress_size_prepended_with_dict`.
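///
/// A round-trip sketch mirroring `test_dict_size` below:
/// ```
/// let dict = b"aaaaaaaaaaaaaaaaaaaaa";
/// let compressed = lz4_flex::block::compress_prepend_size_with_dict(dict, dict);
/// let decompressed =
///     lz4_flex::block::decompress_size_prepended_with_dict(&compressed, dict).unwrap();
/// assert_eq!(decompressed, dict);
/// ```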
#[inline]
pub fn compress_prepend_size_with_dict(input: &[u8], ext_dict: &[u8]) -> Vec<u8> {
    compress_into_vec_with_dict::<true>(input, true, ext_dict)
}

#[inline]
#[cfg(not(feature = "safe-encode"))]
fn read_u16_ptr(input: *const u8) -> u16 {
    let mut num: u16 = 0;
    unsafe {
        core::ptr::copy_nonoverlapping(input, &mut num as *mut u16 as *mut u8, 2);
    }
    num
}

#[inline]
#[cfg(not(feature = "safe-encode"))]
fn read_u32_ptr(input: *const u8) -> u32 {
    let mut num: u32 = 0;
    unsafe {
        core::ptr::copy_nonoverlapping(input, &mut num as *mut u32 as *mut u8, 4);
    }
    num
}

#[inline]
#[cfg(not(feature = "safe-encode"))]
fn read_usize_ptr(input: *const u8) -> usize {
    let mut num: usize = 0;
    unsafe {
        core::ptr::copy_nonoverlapping(
            input,
            &mut num as *mut usize as *mut u8,
            core::mem::size_of::<usize>(),
        );
    }
    num
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_count_same_bytes() {
        // 8-byte aligned block; the trailing zeros and ones are padding, because
        // count_same_bytes ignores the last END_OFFSET bytes.
        let first: &[u8] = &[
            1, 2, 3, 4, 1, 2, 3, 4, 1, 2, 3, 4, 1, 2, 3, 4, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        ];
        let second: &[u8] = &[
            1, 2, 3, 4, 1, 2, 3, 4, 1, 2, 3, 4, 1, 2, 3, 4, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
        ];
        assert_eq!(count_same_bytes(first, &mut 0, second, 0), 16);

        // 4-byte aligned block
        let first: &[u8] = &[
            1, 2, 3, 4, 1, 2, 3, 4, 1, 2, 3, 4, 1, 2, 3, 4, 1, 2, 3, 4, 0, 0, 0, 0, 0, 0, 0, 0, 0,
            0, 0, 0,
        ];
        let second: &[u8] = &[
            1, 2, 3, 4, 1, 2, 3, 4, 1, 2, 3, 4, 1, 2, 3, 4, 1, 2, 3, 4, 1, 1, 1, 1, 1, 1, 1, 1, 1,
            1, 1, 1,
        ];
        assert_eq!(count_same_bytes(first, &mut 0, second, 0), 20);

        // 2-byte aligned block
        let first: &[u8] = &[
            1, 2, 3, 4, 1, 2, 3, 4, 1, 2, 3, 4, 1, 2, 3, 4, 1, 2, 3, 4, 3, 4, 0, 0, 0, 0, 0, 0, 0,
            0, 0, 0, 0, 0,
        ];
        let second: &[u8] = &[
            1, 2, 3, 4, 1, 2, 3, 4, 1, 2, 3, 4, 1, 2, 3, 4, 1, 2, 3, 4, 3, 4, 1, 1, 1, 1, 1, 1, 1,
            1, 1, 1, 1, 1,
        ];
        assert_eq!(count_same_bytes(first, &mut 0, second, 0), 22);

        // 1-byte aligned block
        let first: &[u8] = &[
            1, 2, 3, 4, 1, 2, 3, 4, 1, 2, 3, 4, 1, 2, 3, 4, 1, 2, 3, 4, 3, 4, 5, 0, 0, 0, 0, 0, 0,
            0, 0, 0, 0, 0, 0,
        ];
        let second: &[u8] = &[
            1, 2, 3, 4, 1, 2, 3, 4, 1, 2, 3, 4, 1, 2, 3, 4, 1, 2, 3, 4, 3, 4, 5, 1, 1, 1, 1, 1, 1,
            1, 1, 1, 1, 1, 1,
        ];
        assert_eq!(count_same_bytes(first, &mut 0, second, 0), 23);

        // 1-byte aligned block - last byte different
        let first: &[u8] = &[
            1, 2, 3, 4, 1, 2, 3, 4, 1, 2, 3, 4, 1, 2, 3, 4, 1, 2, 3, 4, 3, 4, 5, 0, 0, 0, 0, 0, 0,
            0, 0, 0, 0, 0, 0,
        ];
        let second: &[u8] = &[
            1, 2, 3, 4, 1, 2, 3, 4, 1, 2, 3, 4, 1, 2, 3, 4, 1, 2, 3, 4, 3, 4, 6, 1, 1, 1, 1, 1, 1,
            1, 1, 1, 1, 1, 1,
        ];
        assert_eq!(count_same_bytes(first, &mut 0, second, 0), 22);

        // 1-byte aligned block - earlier byte different
        let first: &[u8] = &[
            1, 2, 3, 4, 1, 2, 3, 4, 1, 2, 3, 4, 1, 2, 3, 4, 1, 2, 3, 4, 3, 9, 5, 0, 0, 0, 0, 0, 0,
            0, 0, 0, 0, 0, 0,
        ];
        let second: &[u8] = &[
            1, 2, 3, 4, 1, 2, 3, 4, 1, 2, 3, 4, 1, 2, 3, 4, 1, 2, 3, 4, 3, 4, 6, 1, 1, 1, 1, 1, 1,
            1, 1, 1, 1, 1, 1,
        ];
        assert_eq!(count_same_bytes(first, &mut 0, second, 0), 21);

        for diff_idx in 8..100 {
            let first: Vec<u8> = (0u8..255).cycle().take(100 + 12).collect();
            let mut second = first.clone();
            second[diff_idx] = 255;
            for start in 0..=diff_idx {
                let same_bytes = count_same_bytes(&first, &mut start.clone(), &second, start);
                assert_eq!(same_bytes, diff_idx - start);
            }
        }
    }

    #[test]
    fn test_bug() {
        let input: &[u8] = &[
            10, 12, 14, 16, 18, 10, 12, 14, 16, 18, 10, 12, 14, 16, 18, 10, 12, 14, 16, 18,
        ];
        let _out = compress(input);
    }

    #[test]
    fn test_dict() {
        let input: &[u8] = &[
            10, 12, 14, 16, 18, 10, 12, 14, 16, 18, 10, 12, 14, 16, 18, 10, 12, 14, 16, 18,
        ];
        let dict = input;
        let compressed = compress_with_dict(input, dict);
        assert_lt!(compressed.len(), compress(input).len());
        let mut uncompressed = vec![0u8; input.len()];
        let uncomp_size = crate::block::decompress::decompress_into_with_dict(
            &compressed,
            &mut uncompressed,
            dict,
        )
        .unwrap();
        uncompressed.truncate(uncomp_size);
        assert_eq!(input, uncompressed);
    }

    #[test]
    fn test_dict_no_panic() {
        let input: &[u8] = &[
            10, 12, 14, 16, 18, 10, 12, 14, 16, 18, 10, 12, 14, 16, 18, 10, 12, 14, 16, 18,
        ];
        let dict = &[10, 12, 14];
        let _compressed = compress_with_dict(input, dict);
    }

    #[test]
    fn test_dict_match_crossing() {
        let input: &[u8] = &[
            10, 12, 14, 16, 18, 10, 12, 14, 16, 18, 10, 12, 14, 16, 18, 10, 12, 14, 16, 18,
        ];
        let dict = input;
        let compressed = compress_with_dict(input, dict);
        assert_lt!(compressed.len(), compress(input).len());

        let mut uncompressed = vec![0u8; input.len() * 2];
        // Seed the start of the output with the second half of the dict, so decompression
        // can resolve matches that cross the dict/output boundary.
        let dict_cutoff = dict.len() / 2;
        let output_start = dict.len() - dict_cutoff;
        uncompressed[..output_start].copy_from_slice(&dict[dict_cutoff..]);
        let uncomp_len = {
            let mut sink = SliceSink::new(&mut uncompressed[..], output_start);
            crate::block::decompress::decompress_internal::<true, _>(
                &compressed,
                &mut sink,
                &dict[..dict_cutoff],
            )
            .unwrap()
        };
        assert_eq!(input.len(), uncomp_len);
        assert_eq!(
            input,
            &uncompressed[output_start..output_start + uncomp_len]
        );
    }

    #[test]
    fn test_conformant_last_block() {
        // From the spec:
        // The last match must start at least 12 bytes before the end of block.
        // The last match is part of the penultimate sequence. It is followed by the last
        // sequence, which contains only literals. Note that, as a consequence, an independent
        // block of fewer than 13 bytes cannot be compressed, because the match must copy
        // "something", so it needs at least one prior byte.
        // When a block can reference data from another block, it can start immediately with a
        // match and no literal, so a block of 12 bytes can be compressed.
        let aaas: &[u8] = b"aaaaaaaaaaaaaaa";

        // incompressible
        let out = compress(&aaas[..12]);
        assert_gt!(out.len(), 12);
        // compressible
        let out = compress(&aaas[..13]);
        assert_le!(out.len(), 13);
        let out = compress(&aaas[..14]);
        assert_le!(out.len(), 14);
        let out = compress(&aaas[..15]);
        assert_le!(out.len(), 15);

        // dict incompressible
        let out = compress_with_dict(&aaas[..11], aaas);
        assert_gt!(out.len(), 11);
        // compressible
        let out = compress_with_dict(&aaas[..12], aaas);
        // According to the spec this _could_ compress, but it doesn't in this lib
        // as it aborts compression for any input len < LZ4_MIN_LENGTH
        assert_gt!(out.len(), 12);
        let out = compress_with_dict(&aaas[..13], aaas);
        assert_le!(out.len(), 13);
        let out = compress_with_dict(&aaas[..14], aaas);
        assert_le!(out.len(), 14);
        let out = compress_with_dict(&aaas[..15], aaas);
        assert_le!(out.len(), 15);
    }

    #[test]
    fn test_dict_size() {
        let dict = vec![b'a'; 1024 * 1024];
        let input = &b"aaaaaaaaaaaaaaaaaaaaaaaaaaaaa"[..];
        let compressed = compress_prepend_size_with_dict(input, &dict);
        let decompressed =
            crate::block::decompress_size_prepended_with_dict(&compressed, &dict).unwrap();
        assert_eq!(decompressed, input);
    }
}
927}