snap/
compress.rs

1use std::fmt;
2use std::ops::{Deref, DerefMut};
3use std::ptr;
4
5use crate::bytes;
6use crate::error::{Error, Result};
7use crate::{MAX_BLOCK_SIZE, MAX_INPUT_SIZE};
8
/// Number of slots in the largest hash table used to remember where 4-byte
/// sequences were last seen within a block.
const MAX_TABLE_SIZE: usize = 1 << 14;

/// Number of slots in the inline (non-heap) hash table held by `Encoder`.
/// Small blocks use this table so they avoid the cost of clearing a big one.
const SMALL_TABLE_SIZE: usize = 1 << 10;

/// How many bytes at the end of each block are excluded from the match
/// search and left for the trailing literal. This slack is what lets the
/// compressor use wide, unchecked loads near the end of the search region.
const INPUT_MARGIN: usize = 16 - 1;

/// Blocks shorter than this are emitted as a single literal without any
/// attempt to find matches.
const MIN_NON_LITERAL_BLOCK_SIZE: usize = INPUT_MARGIN + 1 + 1;
25
/// The two-bit element tags that are ORed into the low bits of every Snappy
/// element header byte.
enum Tag {
    /// A run of uncompressed bytes.
    Literal = 0,
    /// A copy with a 2-byte header (length 4..=11, 11-bit offset).
    Copy1 = 1,
    /// A copy with a 3-byte header (length 1..=64, 16-bit offset).
    Copy2 = 2,
    // Never emitted by the compressor (blocks are at most 64KB, so a 16-bit
    // offset always suffices), and decompression avoids case analysis on the
    // copy kind, so this variant is never referenced. It only documents the
    // format.
    #[allow(dead_code)]
    Copy4 = 3,
}
37
38/// Returns the maximum compressed size given the uncompressed size.
39///
40/// If the uncompressed size exceeds the maximum allowable size then this
41/// returns 0.
42pub fn max_compress_len(input_len: usize) -> usize {
43    let input_len = input_len as u64;
44    if input_len > MAX_INPUT_SIZE {
45        return 0;
46    }
47    let max = 32 + input_len + (input_len / 6);
48    if max > MAX_INPUT_SIZE {
49        0
50    } else {
51        max as usize
52    }
53}
54
55/// Encoder is a raw encoder for compressing bytes in the Snappy format.
56///
57/// Thie encoder does not use the Snappy frame format and simply compresses the
58/// given bytes in one big Snappy block (that is, it has a single header).
59///
60/// Unless you explicitly need the low-level control, you should use
61/// [`read::FrameEncoder`](../read/struct.FrameEncoder.html)
62/// or
63/// [`write::FrameEncoder`](../write/struct.FrameEncoder.html)
64/// instead, which compresses to the Snappy frame format.
65///
66/// It is beneficial to reuse an Encoder when possible.
67pub struct Encoder {
68    small: [u16; SMALL_TABLE_SIZE],
69    big: Vec<u16>,
70}
71
72impl fmt::Debug for Encoder {
73    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
74        write!(f, "Encoder(...)")
75    }
76}
77
78impl Encoder {
79    /// Return a new encoder that can be used for compressing bytes.
80    pub fn new() -> Encoder {
81        Encoder { small: [0; SMALL_TABLE_SIZE], big: vec![] }
82    }
83
84    /// Compresses all bytes in `input` into `output`.
85    ///
86    /// `input` can be any arbitrary sequence of bytes.
87    ///
88    /// `output` must be large enough to hold the maximum possible compressed
89    /// size of `input`, which can be computed using `max_compress_len`.
90    ///
91    /// On success, this returns the number of bytes written to `output`.
92    ///
93    /// # Errors
94    ///
95    /// This method returns an error in the following circumstances:
96    ///
97    /// * The total number of bytes to compress exceeds `2^32 - 1`.
98    /// * `output` has length less than `max_compress_len(input.len())`.
99    pub fn compress(
100        &mut self,
101        mut input: &[u8],
102        output: &mut [u8],
103    ) -> Result<usize> {
104        match max_compress_len(input.len()) {
105            0 => {
106                return Err(Error::TooBig {
107                    given: input.len() as u64,
108                    max: MAX_INPUT_SIZE,
109                });
110            }
111            min if output.len() < min => {
112                return Err(Error::BufferTooSmall {
113                    given: output.len() as u64,
114                    min: min as u64,
115                });
116            }
117            _ => {}
118        }
119        // Handle an edge case specially.
120        if input.is_empty() {
121            // Encodes a varint of 0, denoting the total size of uncompressed
122            // bytes.
123            output[0] = 0;
124            return Ok(1);
125        }
126        // Write the Snappy header, which is just the total number of
127        // uncompressed bytes.
128        let mut d = bytes::write_varu64(output, input.len() as u64);
129        while !input.is_empty() {
130            // Find the next block.
131            let mut src = input;
132            if src.len() > MAX_BLOCK_SIZE {
133                src = &src[..MAX_BLOCK_SIZE as usize];
134            }
135            input = &input[src.len()..];
136
137            // If the block is smallish, then don't waste time on it and just
138            // emit a literal.
139            let mut block = Block::new(src, output, d);
140            if block.src.len() < MIN_NON_LITERAL_BLOCK_SIZE {
141                let lit_end = block.src.len();
142                unsafe {
143                    // SAFETY: next_emit is zero (in bounds) and the end is
144                    // the length of the block (in bounds).
145                    block.emit_literal(lit_end);
146                }
147            } else {
148                let table = self.block_table(block.src.len());
149                block.compress(table);
150            }
151            d = block.d;
152        }
153        Ok(d)
154    }
155
156    /// Compresses all bytes in `input` into a freshly allocated `Vec`.
157    ///
158    /// This is just like the `compress` method, except it allocates a `Vec`
159    /// with the right size for you. (This is intended to be a convenience
160    /// method.)
161    ///
162    /// This method returns an error under the same circumstances that
163    /// `compress` does.
164    pub fn compress_vec(&mut self, input: &[u8]) -> Result<Vec<u8>> {
165        let mut buf = vec![0; max_compress_len(input.len())];
166        let n = self.compress(input, &mut buf)?;
167        buf.truncate(n);
168        Ok(buf)
169    }
170}
171
/// State for compressing a single block of input into an output buffer.
struct Block<'s, 'd> {
    /// The uncompressed bytes of this block.
    src: &'s [u8],
    /// Output buffer shared by all blocks of one `compress` call.
    dst: &'d mut [u8],
    /// Current read position in `src`.
    s: usize,
    /// Position in `src` past which the match search stops.
    s_limit: usize,
    /// Current write position in `dst`.
    d: usize,
    /// Start of the bytes in `src` not yet emitted as a literal or copy.
    next_emit: usize,
}
180
181impl<'s, 'd> Block<'s, 'd> {
182    #[inline(always)]
183    fn new(src: &'s [u8], dst: &'d mut [u8], d: usize) -> Block<'s, 'd> {
184        Block {
185            src: src,
186            s: 0,
187            s_limit: src.len(),
188            dst: dst,
189            d: d,
190            next_emit: 0,
191        }
192    }
193
194    #[inline(always)]
195    fn compress(&mut self, mut table: BlockTable<'_>) {
196        debug_assert!(!table.is_empty());
197        debug_assert!(self.src.len() >= MIN_NON_LITERAL_BLOCK_SIZE);
198
199        self.s += 1;
200        self.s_limit -= INPUT_MARGIN;
201        let mut next_hash =
202            table.hash(bytes::read_u32_le(&self.src[self.s..]));
203        loop {
204            let mut skip = 32;
205            let mut candidate;
206            let mut s_next = self.s;
207            loop {
208                self.s = s_next;
209                let bytes_between_hash_lookups = skip >> 5;
210                s_next = self.s + bytes_between_hash_lookups;
211                skip += bytes_between_hash_lookups;
212                if s_next > self.s_limit {
213                    return self.done();
214                }
215                unsafe {
216                    // SAFETY: next_hash is always computed by table.hash
217                    // which is guaranteed to be in bounds.
218                    candidate = *table.get_unchecked(next_hash) as usize;
219                    *table.get_unchecked_mut(next_hash) = self.s as u16;
220
221                    let srcp = self.src.as_ptr();
222                    // SAFETY: s_next is guaranteed to be less than s_limit by
223                    // the conditional above, which implies s_next is in
224                    // bounds.
225                    let x = bytes::loadu_u32_le(srcp.add(s_next));
226                    next_hash = table.hash(x);
227                    // SAFETY: self.s is always less than s_next, so it is also
228                    // in bounds by the argument above.
229                    //
230                    // candidate is extracted from table, which is only ever
231                    // set to valid positions in the block and is therefore
232                    // also in bounds.
233                    //
234                    // We only need to compare y/z for equality, so we don't
235                    // need to both with endianness. cur corresponds to the
236                    // bytes at the current position and cand corresponds to
237                    // a potential match. If they're equal, we declare victory
238                    // and move below to try and extend the match.
239                    let cur = bytes::loadu_u32_ne(srcp.add(self.s));
240                    let cand = bytes::loadu_u32_ne(srcp.add(candidate));
241                    if cur == cand {
242                        break;
243                    }
244                }
245            }
246            // While the above found a candidate for compression, before we
247            // emit a copy operation for it, we need to make sure that we emit
248            // any bytes between the last copy operation and this one as a
249            // literal.
250            let lit_end = self.s;
251            unsafe {
252                // SAFETY: next_emit is set to a previous value of self.s,
253                // which is guaranteed to be less than s_limit (in bounds).
254                // lit_end is set to the current value of self.s, also
255                // guaranteed to be less than s_limit (in bounds).
256                self.emit_literal(lit_end);
257            }
258            loop {
259                // Look for more matching bytes starting at the position of
260                // the candidate and the current src position. We increment
261                // self.s and candidate by 4 since we already know the first 4
262                // bytes match.
263                let base = self.s;
264                self.s += 4;
265                unsafe {
266                    // SAFETY: candidate is always set to a value from our
267                    // hash table, which only contains positions in self.src
268                    // that have been seen for this block that occurred before
269                    // self.s.
270                    self.extend_match(candidate + 4);
271                }
272                let (offset, len) = (base - candidate, self.s - base);
273                self.emit_copy(offset, len);
274                self.next_emit = self.s;
275                if self.s >= self.s_limit {
276                    return self.done();
277                }
278                // Update the hash table with the byte sequences
279                // self.src[self.s - 1..self.s + 3] and
280                // self.src[self.s..self.s + 4]. Instead of reading 4 bytes
281                // twice, we read 8 bytes once.
282                //
283                // If we happen to get a hit on self.src[self.s..self.s + 4],
284                // then continue this loop and extend the match.
285                unsafe {
286                    let srcp = self.src.as_ptr();
287                    // SAFETY: self.s can never exceed s_limit given by the
288                    // conditional above and self.s is guaranteed to be
289                    // non-zero and is therefore in bounds.
290                    let x = bytes::loadu_u64_le(srcp.add(self.s - 1));
291                    // The lower 4 bytes of x correspond to
292                    // self.src[self.s - 1..self.s + 3].
293                    let prev_hash = table.hash(x as u32);
294                    // SAFETY: Hash values are guaranteed to be in bounds.
295                    *table.get_unchecked_mut(prev_hash) = (self.s - 1) as u16;
296                    // The lower 4 bytes of x>>8 correspond to
297                    // self.src[self.s..self.s + 4].
298                    let cur_hash = table.hash((x >> 8) as u32);
299                    // SAFETY: Hash values are guaranteed to be in bounds.
300                    candidate = *table.get_unchecked(cur_hash) as usize;
301                    *table.get_unchecked_mut(cur_hash) = self.s as u16;
302
303                    // SAFETY: candidate is set from table, which always
304                    // contains valid positions in the current block.
305                    let y = bytes::loadu_u32_le(srcp.add(candidate));
306                    if (x >> 8) as u32 != y {
307                        // If we didn't get a hit, update the next hash
308                        // and move on. Our initial 8 byte read continues to
309                        // pay off.
310                        next_hash = table.hash((x >> 16) as u32);
311                        self.s += 1;
312                        break;
313                    }
314                }
315            }
316        }
317    }
318
319    /// Emits one or more copy operations with the given offset and length.
320    /// offset must be in the range [1, 65535] and len must be in the range
321    /// [4, 65535].
322    #[inline(always)]
323    fn emit_copy(&mut self, offset: usize, mut len: usize) {
324        debug_assert!(1 <= offset && offset <= 65535);
325        // Copy operations only allow lengths up to 64, but we'll allow bigger
326        // lengths and emit as many operations as we need.
327        //
328        // N.B. Since our block size is 64KB, we never actually emit a copy 4
329        // operation.
330        debug_assert!(4 <= len && len <= 65535);
331
332        // Emit copy 2 operations until we don't have to.
333        // We check on 68 here and emit a shorter copy than 64 below because
334        // it is cheaper to, e.g., encode a length 67 copy as a length 60
335        // copy 2 followed by a length 7 copy 1 than to encode it as a length
336        // 64 copy 2 followed by a length 3 copy 2. They key here is that a
337        // copy 1 operation requires at least length 4 which forces a length 3
338        // copy to use a copy 2 operation.
339        while len >= 68 {
340            self.emit_copy2(offset, 64);
341            len -= 64;
342        }
343        if len > 64 {
344            self.emit_copy2(offset, 60);
345            len -= 60;
346        }
347        // If we can squeeze the last copy into a copy 1 operation, do it.
348        if len <= 11 && offset <= 2047 {
349            self.dst[self.d] = (((offset >> 8) as u8) << 5)
350                | (((len - 4) as u8) << 2)
351                | (Tag::Copy1 as u8);
352            self.dst[self.d + 1] = offset as u8;
353            self.d += 2;
354        } else {
355            self.emit_copy2(offset, len);
356        }
357    }
358
359    /// Emits a "copy 2" operation with the given offset and length. The
360    /// offset and length must be valid for a copy 2 operation. i.e., offset
361    /// must be in the range [1, 65535] and len must be in the range [1, 64].
362    #[inline(always)]
363    fn emit_copy2(&mut self, offset: usize, len: usize) {
364        debug_assert!(1 <= offset && offset <= 65535);
365        debug_assert!(1 <= len && len <= 64);
366        self.dst[self.d] = (((len - 1) as u8) << 2) | (Tag::Copy2 as u8);
367        bytes::write_u16_le(offset as u16, &mut self.dst[self.d + 1..]);
368        self.d += 3;
369    }
370
371    /// Attempts to extend a match from the current position in self.src with
372    /// the candidate position given.
373    ///
374    /// This method uses unaligned loads and elides bounds checks, so the
375    /// caller must guarantee that cand points to a valid location in self.src
376    /// and is less than the current position in src.
377    #[inline(always)]
378    unsafe fn extend_match(&mut self, mut cand: usize) {
379        debug_assert!(cand < self.s);
380        while self.s + 8 <= self.src.len() {
381            let srcp = self.src.as_ptr();
382            // SAFETY: The loop invariant guarantees that there is at least
383            // 8 bytes to read at self.src + self.s. Since cand must be
384            // guaranteed by the caller to be valid and less than self.s, it
385            // also has enough room to read 8 bytes.
386            //
387            // TODO(ag): Despite my best efforts, I couldn't get this to
388            // autovectorize with 128-bit loads. The logic after the loads
389            // appears to be a little too clever...
390            let x = bytes::loadu_u64_ne(srcp.add(self.s));
391            let y = bytes::loadu_u64_ne(srcp.add(cand));
392            if x == y {
393                // If all 8 bytes are equal, move on...
394                self.s += 8;
395                cand += 8;
396            } else {
397                // Otherwise, find the last byte that was equal. We can do
398                // this efficiently by interpreted x/y as little endian
399                // numbers, which lets us use the number of trailing zeroes
400                // as a proxy for the number of equivalent bits (after an XOR).
401                let z = x.to_le() ^ y.to_le();
402                self.s += z.trailing_zeros() as usize / 8;
403                return;
404            }
405        }
406        // When we have fewer than 8 bytes left in the block, fall back to the
407        // slow loop.
408        while self.s < self.src.len() && self.src[self.s] == self.src[cand] {
409            self.s += 1;
410            cand += 1;
411        }
412    }
413
414    /// Executes any cleanup when the current block has finished compressing.
415    /// In particular, it emits any leftover bytes as a literal.
416    #[inline(always)]
417    fn done(&mut self) {
418        if self.next_emit < self.src.len() {
419            let lit_end = self.src.len();
420            unsafe {
421                // SAFETY: Both next_emit and lit_end are trivially in bounds
422                // given the conditional and definition above.
423                self.emit_literal(lit_end);
424            }
425        }
426    }
427
428    /// Emits a literal from self.src[self.next_emit..lit_end].
429    ///
430    /// This uses unaligned loads and elides bounds checks, so the caller must
431    /// guarantee that self.src[self.next_emit..lit_end] is valid.
432    #[inline(always)]
433    unsafe fn emit_literal(&mut self, lit_end: usize) {
434        let lit_start = self.next_emit;
435        let len = lit_end - lit_start;
436        let n = len.checked_sub(1).unwrap();
437        if n <= 59 {
438            self.dst[self.d] = ((n as u8) << 2) | (Tag::Literal as u8);
439            self.d += 1;
440            if len <= 16 && lit_start + 16 <= self.src.len() {
441                // SAFETY: lit_start is equivalent to self.next_emit, which is
442                // only set to self.s immediately following a copy emit. The
443                // conditional above also ensures that there is at least 16
444                // bytes of room in both src and dst.
445                //
446                // dst is big enough because the buffer is guaranteed to
447                // be big enough to hold biggest possible compressed size plus
448                // an extra 32 bytes, which exceeds the 16 byte copy here.
449                let srcp = self.src.as_ptr().add(lit_start);
450                let dstp = self.dst.as_mut_ptr().add(self.d);
451                ptr::copy_nonoverlapping(srcp, dstp, 16);
452                self.d += len;
453                return;
454            }
455        } else if n < 256 {
456            self.dst[self.d] = (60 << 2) | (Tag::Literal as u8);
457            self.dst[self.d + 1] = n as u8;
458            self.d += 2;
459        } else {
460            self.dst[self.d] = (61 << 2) | (Tag::Literal as u8);
461            bytes::write_u16_le(n as u16, &mut self.dst[self.d + 1..]);
462            self.d += 3;
463        }
464        // SAFETY: lit_start is equivalent to self.next_emit, which is only set
465        // to self.s immediately following a copy, which implies that it always
466        // points to valid bytes in self.src.
467        //
468        // We can't guarantee that there are at least len bytes though, which
469        // must be guaranteed by the caller and is why this method is unsafe.
470        let srcp = self.src.as_ptr().add(lit_start);
471        let dstp = self.dst.as_mut_ptr().add(self.d);
472        ptr::copy_nonoverlapping(srcp, dstp, len);
473        self.d += len;
474    }
475}
476
/// Hash table mapping 4-byte sequences to the position of their most recent
/// occurrence within the current block. Compression uses it to find match
/// candidates quickly.
///
/// `hash` is exposed separately so callers control exactly how often a hash
/// is computed.
struct BlockTable<'a> {
    /// Slots holding block positions, borrowed from an `Encoder`.
    table: &'a mut [u16],
    /// Right shift applied to the 32-bit hash product so that the result
    /// is always less than `table.len()`.
    shift: u32,
}
489
490impl Encoder {
491    fn block_table(&mut self, block_size: usize) -> BlockTable<'_> {
492        let mut shift: u32 = 32 - 8;
493        let mut table_size = 256;
494        while table_size < MAX_TABLE_SIZE && table_size < block_size {
495            shift -= 1;
496            table_size *= 2;
497        }
498        // If our block size is small, then use a small stack allocated table
499        // instead of putting a bigger one on the heap. This particular
500        // optimization is important if the caller is using Snappy to compress
501        // many small blocks. (The memset savings alone is considerable.)
502        let table: &mut [u16] = if table_size <= SMALL_TABLE_SIZE {
503            &mut self.small[0..table_size]
504        } else {
505            if self.big.is_empty() {
506                // Interestingly, using `self.big.resize` here led to some
507                // very weird code getting generated that led to a large
508                // slow down. Forcing the issue with a new vec seems to
509                // fix it. ---AG
510                self.big = vec![0; MAX_TABLE_SIZE];
511            }
512            &mut self.big[0..table_size]
513        };
514        for x in &mut *table {
515            *x = 0;
516        }
517        BlockTable { table: table, shift: shift }
518    }
519}
520
521impl<'a> BlockTable<'a> {
522    #[inline(always)]
523    fn hash(&self, x: u32) -> usize {
524        (x.wrapping_mul(0x1E35A7BD) >> self.shift) as usize
525    }
526}
527
528impl<'a> Deref for BlockTable<'a> {
529    type Target = [u16];
530    fn deref(&self) -> &[u16] {
531        self.table
532    }
533}
534
535impl<'a> DerefMut for BlockTable<'a> {
536    fn deref_mut(&mut self) -> &mut [u16] {
537        self.table
538    }
539}