// snap/decompress.rs
use std::ptr;
use crate::bytes;
use crate::error::{Error, Result};
use crate::tag;
use crate::MAX_INPUT_SIZE;
/// A lookup table for quickly computing the various attributes derived from a
/// tag byte.
///
/// This wraps the raw `tag::TAG_LOOKUP_TABLE` array in a `TagLookupTable` so
/// that its values are read through `TagLookupTable::entry`.
const TAG_LOOKUP_TABLE: TagLookupTable = TagLookupTable(tag::TAG_LOOKUP_TABLE);
/// `WORD_MASK` maps the size of an integer in bytes (`0..=4`) to the mask
/// that keeps only that many low-order bytes of a 32 bit integer.
///
/// This is used when an integer of 1-4 bytes must be read and at least 4
/// bytes are known to be readable from the buffer. In that case, a full 32
/// bit little endian integer can be loaded and then masked with
/// `WORD_MASK[n]` to keep only the `n` bytes that are needed. This in
/// particular saves a branch on `n`.
const WORD_MASK: [usize; 5] =
    [0x0, 0xFF, 0xFFFF, 0x00FF_FFFF, 0xFFFF_FFFF];
/// Reports the size (in bytes) that the given compressed bytes expand to.
///
/// `input` must be a sequence of bytes returned by a conforming Snappy
/// compressor. Only the header at the start of `input` is read here. Empty
/// input is reported as decompressing to zero bytes.
///
/// # Errors
///
/// This function returns an error in the following circumstances:
///
/// * An invalid Snappy header was seen.
/// * The total space required for decompression exceeds `2^32 - 1`.
pub fn decompress_len(input: &[u8]) -> Result<usize> {
    match input {
        [] => Ok(0),
        _ => Header::read(input).map(|hdr| hdr.decompress_len),
    }
}
/// A raw decoder for decompressing bytes in the Snappy format.
///
/// The Snappy frame format is not handled by this type: the given bytes are
/// simply decompressed as if they were returned from `Encoder`.
///
/// Unless you explicitly need the low-level control, you should use
/// [`read::FrameDecoder`](../read/struct.FrameDecoder.html)
/// instead, which decompresses the Snappy frame format.
#[derive(Debug, Default, Clone)]
pub struct Decoder {
    // Reserved so that fields can be added later. It carries no state.
    _dummy: (),
}
impl Decoder {
/// Return a new decoder that can be used for decompressing bytes.
pub fn new() -> Decoder {
Decoder { _dummy: () }
}
/// Decompresses all bytes in `input` into `output`.
///
/// `input` must be a sequence of bytes returned by a conforming Snappy
/// compressor.
///
/// The size of `output` must be large enough to hold all decompressed
/// bytes from the `input`. The size required can be queried with the
/// `decompress_len` function.
///
/// On success, this returns the number of bytes written to `output`.
///
/// # Errors
///
/// This method returns an error in the following circumstances:
///
/// * Invalid compressed Snappy data was seen.
/// * The total space required for decompression exceeds `2^32 - 1`.
/// * `output` has length less than `decompress_len(input)`.
pub fn decompress(
&mut self,
input: &[u8],
output: &mut [u8],
) -> Result<usize> {
if input.is_empty() {
return Err(Error::Empty);
}
let hdr = Header::read(input)?;
if hdr.decompress_len > output.len() {
return Err(Error::BufferTooSmall {
given: output.len() as u64,
min: hdr.decompress_len as u64,
});
}
let dst = &mut output[..hdr.decompress_len];
let mut dec =
Decompress { src: &input[hdr.len..], s: 0, dst: dst, d: 0 };
dec.decompress()?;
Ok(dec.dst.len())
}
/// Decompresses all bytes in `input` into a freshly allocated `Vec`.
///
/// This is just like the `decompress` method, except it allocates a `Vec`
/// with the right size for you. (This is intended to be a convenience
/// method.)
///
/// This method returns an error under the same circumstances that
/// `decompress` does.
pub fn decompress_vec(&mut self, input: &[u8]) -> Result<Vec<u8>> {
let mut buf = vec![0; decompress_len(input)?];
let n = self.decompress(input, &mut buf)?;
buf.truncate(n);
Ok(buf)
}
}
/// Decompress is the state of the Snappy decompressor.
///
/// It pairs the compressed input with the output buffer and keeps a cursor
/// into each of them.
struct Decompress<'s, 'd> {
/// The original compressed bytes not including the header.
src: &'s [u8],
/// The current position in the compressed bytes (an index into `src`).
s: usize,
/// The output buffer to write the decompressed bytes.
dst: &'d mut [u8],
/// The current position in the decompressed buffer (an index into `dst`).
d: usize,
}
impl<'s, 'd> Decompress<'s, 'd> {
/// Decompresses snappy compressed bytes in `src` to `dst`.
///
/// This assumes that the header has already been read and that `dst` is
/// big enough to store all decompressed bytes.
fn decompress(&mut self) -> Result<()> {
while self.s < self.src.len() {
let byte = self.src[self.s];
self.s += 1;
if byte & 0b000000_11 == 0 {
let len = (byte >> 2) as usize + 1;
self.read_literal(len)?;
} else {
self.read_copy(byte)?;
}
}
if self.d != self.dst.len() {
return Err(Error::HeaderMismatch {
expected_len: self.dst.len() as u64,
got_len: self.d as u64,
});
}
Ok(())
}
/// Decompresses a literal from `src` starting at `s` to `dst` starting at
/// `d`, advancing `self.s` and `self.d` past the bytes read and written. `s`
/// should point to the byte immediately following the literal tag byte.
///
/// `len` is the length of the literal if it's <=60. Otherwise, it's the
/// length tag, indicating the number of bytes needed to read a little
/// endian integer at `src[s..]`. i.e., `61 => 1 byte`, `62 => 2 bytes`,
/// `63 => 3 bytes` and `64 => 4 bytes`. The integer read this way is the
/// length of the literal minus one.
///
/// `len` must be <=64.
///
/// # Errors
///
/// Returns `Error::Literal` if a length tag is given but fewer than 4 bytes
/// remain in `src`, or if either `src` or `dst` has fewer bytes remaining
/// than the length of the literal.
#[inline(always)]
fn read_literal(&mut self, len: usize) -> Result<()> {
debug_assert!(len <= 64);
// Widen to u64 so that a length read from the input (at most 2^32 after
// the `+ 1` below) is representable whatever the width of usize is.
let mut len = len as u64;
// As an optimization for the common case, if the literal length is
// <=16 and we have enough room in both `src` and `dst`, copy the
// literal using unaligned loads and stores.
//
// We pick 16 bytes with the hope that it optimizes down to a 128 bit
// load/store.
if len <= 16
&& self.s + 16 <= self.src.len()
&& self.d + 16 <= self.dst.len()
{
unsafe {
// SAFETY: We know both src and dst have at least 16 bytes of
// wiggle room after s/d, even if `len` is <16, so the copy is
// safe.
let srcp = self.src.as_ptr().add(self.s);
let dstp = self.dst.as_mut_ptr().add(self.d);
// Hopefully uses SIMD registers for 128 bit load/store.
ptr::copy_nonoverlapping(srcp, dstp, 16);
}
// All 16 bytes were copied, but only `len` of them belong to the
// literal, so the cursors advance by `len` only.
self.d += len as usize;
self.s += len as usize;
return Ok(());
}
// When the length is bigger than 60, it indicates that we need to read
// an additional 1-4 bytes to get the real length of the literal.
if len >= 61 {
// If there aren't at least 4 bytes left to read then we know this
// is corrupt because the literal must have length >=61.
if self.s as u64 + 4 > self.src.len() as u64 {
return Err(Error::Literal {
len: 4,
src_len: (self.src.len() - self.s) as u64,
dst_len: (self.dst.len() - self.d) as u64,
});
}
// Since we know there are 4 bytes left to read, read a 32 bit LE
// integer and mask away the bits we don't need.
let byte_count = len as usize - 60;
len = bytes::read_u32_le(&self.src[self.s..]) as u64;
// The stored integer is the literal length minus one.
len = (len & (WORD_MASK[byte_count] as u64)) + 1;
// Only `byte_count` of the 4 bytes read were part of the length.
self.s += byte_count;
}
// If there's not enough buffer left to load or store this literal,
// then the input is corrupt.
//
// This is the check `s + len > src.len() || d + len > dst.len()`, written
// as a comparison against the bytes remaining in each buffer instead of
// adding the input-controlled `len` to a cursor.
if ((self.src.len() - self.s) as u64) < len
|| ((self.dst.len() - self.d) as u64) < len
{
return Err(Error::Literal {
len: len,
src_len: (self.src.len() - self.s) as u64,
dst_len: (self.dst.len() - self.d) as u64,
});
}
unsafe {
// SAFETY: We've already checked the bounds, so we know this copy
// is correct. In particular, `len` is no bigger than the number of
// bytes remaining in either buffer, so `len as usize` is lossless.
let srcp = self.src.as_ptr().add(self.s);
let dstp = self.dst.as_mut_ptr().add(self.d);
ptr::copy_nonoverlapping(srcp, dstp, len as usize);
}
self.s += len as usize;
self.d += len as usize;
Ok(())
}
/// Reads a copy from `src` and writes the decompressed bytes to `dst`. `s`
/// should point to the byte immediately following the copy tag byte.
///
/// On success, `self.s` has been advanced past the bytes that encode the
/// offset and `self.d` has been advanced past the copied bytes.
///
/// # Errors
///
/// * Any error from decoding the offset is returned as is.
/// * `Error::Offset` if the offset is zero or is bigger than the number of
///   bytes written to `dst` so far.
/// * `Error::CopyWrite` if the copy would run past the end of `dst`.
#[inline(always)]
fn read_copy(&mut self, tag_byte: u8) -> Result<()> {
// Find the copy offset and len, then advance the input past the copy.
// The rest of this function deals with reading/writing to output only.
let entry = TAG_LOOKUP_TABLE.entry(tag_byte);
let offset = entry.offset(self.src, self.s)?;
let len = entry.len();
self.s += entry.num_tag_bytes();
// What we really care about here is whether `offset == 0` or
// `d < offset`. To save an extra branch, use `d <= offset - 1` instead.
// If `offset` is `0`, then `offset.wrapping_sub(1)` will be usize::MAX
// which is also the max value of `d`, so the check always fires.
if self.d <= offset.wrapping_sub(1) {
return Err(Error::Offset {
offset: offset as u64,
dst_pos: self.d as u64,
});
}
// When all is said and done, dst is advanced to end.
let end = self.d + len;
// When the copy is small and the offset is at least 8 bytes away from
// `d`, then we can decompress the copy with two 64 bit unaligned
// loads/stores.
if offset >= 8 && len <= 16 && self.d + 16 <= self.dst.len() {
unsafe {
// SAFETY: We know dstp points to at least 16 bytes of memory
// from the condition above, and we also know that dstp is
// preceded by at least `offset` bytes from the
// `d <= offset - 1` check above.
//
// We also know that each 8 byte source range does not overlap
// the 8 byte range it is copied to, since `offset >= 8` from
// the check above, justifying the use of copy_nonoverlapping.
let dstp = self.dst.as_mut_ptr().add(self.d);
let srcp = dstp.sub(offset);
// We can't do a single 16 byte load/store because src/dst may
// overlap with each other. Namely, the second copy here may
// copy bytes written in the first copy!
ptr::copy_nonoverlapping(srcp, dstp, 8);
ptr::copy_nonoverlapping(srcp.add(8), dstp.add(8), 8);
}
// If we have some wiggle room, try to decompress the copy 16 bytes
// at a time with 128 bit unaligned loads/stores. Remember, we can't
// just do a memcpy because decompressing copies may require copying
// overlapping memory.
//
// We need the extra wiggle room to make effective use of 128 bit
// loads/stores. Even if the store ends up copying more data than we
// need, we're careful to advance `d` by the correct amount at the end.
} else if end + 24 <= self.dst.len() {
unsafe {
// SAFETY: We know that dstp is preceded by at least `offset`
// bytes from the `d <= offset - 1` check above.
//
// We don't know whether dstp overlaps with srcp, so we start
// by copying from srcp to dstp until they no longer overlap.
// srcp stays put in this first loop while dstp advances by the
// current distance, so the distance doubles on every iteration.
// The worst case is when dstp-srcp = 3 and copy length = 1. The
// first loop will issue these copy operations before stopping
// (positions are relative to the initial dstp):
//
// [-3, 12] -> [0, 15]
// [-3, 12] -> [3, 18]
// [-3, 12] -> [9, 24]
//
// But the copy had length 1, so it was only supposed to write
// to [0, 0]. But the last copy wrote to [9, 24], which is 24
// extra bytes in dst *beyond* the end of the copy, which is
// guaranteed by the conditional above.
let mut dstp = self.dst.as_mut_ptr().add(self.d);
let mut srcp = dstp.sub(offset);
loop {
debug_assert!(dstp >= srcp);
let diff = (dstp as usize) - (srcp as usize);
if diff >= 16 {
break;
}
// srcp and dstp can overlap, so use ptr::copy.
debug_assert!(self.d + 16 <= self.dst.len());
ptr::copy(srcp, dstp, 16);
self.d += diff as usize;
dstp = dstp.add(diff);
}
// From here on, srcp and dstp are at least 16 bytes apart, so
// each 16 byte chunk can be copied without overlap.
while self.d < end {
ptr::copy_nonoverlapping(srcp, dstp, 16);
srcp = srcp.add(16);
dstp = dstp.add(16);
self.d += 16;
}
// At this point, `d` is likely wrong. We correct it before
// returning. Its correct value is `end`.
}
} else {
// Neither fast path applies. This is the only branch that can be
// reached with a copy that does not fit in `dst`, since both
// conditions above imply `end <= dst.len()`.
if end > self.dst.len() {
return Err(Error::CopyWrite {
len: len as u64,
dst_len: (self.dst.len() - self.d) as u64,
});
}
// Finally, the slow byte-by-byte case, which should only be used
// for the last few bytes of decompression.
while self.d != end {
self.dst[self.d] = self.dst[self.d - offset];
self.d += 1;
}
}
// The fast paths may have written (and, in the 16 byte loop, advanced
// `d`) beyond the copy, so `d` is set to its correct value here.
self.d = end;
Ok(())
}
}
/// Header represents the single varint that starts every Snappy compressed
/// block.
///
/// The value of the varint is the length of the data once decompressed.
#[derive(Debug)]
struct Header {
/// The length of the header in bytes (i.e., the varint).
len: usize,
/// The length of the original decompressed input in bytes.
decompress_len: usize,
}
impl Header {
/// Reads the varint header from the given input.
///
/// If there was a problem reading the header then an error is returned.
/// If a header is returned then it is guaranteed to be valid.
#[inline(always)]
fn read(input: &[u8]) -> Result<Header> {
let (decompress_len, header_len) = bytes::read_varu64(input);
if header_len == 0 {
return Err(Error::Header);
}
if decompress_len > MAX_INPUT_SIZE {
return Err(Error::TooBig {
given: decompress_len as u64,
max: MAX_INPUT_SIZE,
});
}
Ok(Header { len: header_len, decompress_len: decompress_len as usize })
}
}
/// A lookup table for quickly computing the various attributes derived from
/// a tag byte. The attributes are most useful for the three "copy" tags
/// and include the length of the copy, part of the offset (for copy 1-byte
/// only) and the total number of bytes following the tag byte that encode
/// the other part of the offset (1 for copy 1, 2 for copy 2 and 4 for copy 4).
///
/// More specifically, the keys of the table are u8s and the values are u16s.
/// The bits of the values are laid out as follows:
///
/// xxaa abbb xxcc cccc
///
/// Where `a` is the number of bytes, `b` are the three bits of the offset
/// for copy 1 (the other 8 bits are in the byte following the tag byte; for
/// copy 2 and copy 4, `b = 0`), and `c` is the length of the copy (max of 64).
///
/// NOTE(review): a length of 64 needs seven bits, while the diagram above
/// shows six `c` bits; `TagEntry::len` reads the whole low byte. Confirm the
/// diagram against the table in the `tag` module.
///
/// We could pack this in fewer bits, but the position of the three `b` bits
/// lines up with the most significant three bits in the total offset for copy
/// 1, which avoids an extra shift instruction.
///
/// In sum, this table is useful because it reduces branches and various
/// arithmetic operations.
struct TagLookupTable([u16; 256]);
impl TagLookupTable {
    /// Fetch the table entry for the tag `byte`.
    ///
    /// The table has one value for each of the 256 possible bytes, so every
    /// `byte` has an entry.
    #[inline(always)]
    fn entry(&self, byte: u8) -> TagEntry {
        let raw: u16 = self.0[usize::from(byte)];
        TagEntry(usize::from(raw))
    }
}
/// Represents a single entry in the tag lookup table.
///
/// See the documentation in `TagLookupTable` for the bit layout.
///
/// The type is a `usize` for convenience: the table stores `u16` values,
/// which `TagLookupTable::entry` widens when it builds a `TagEntry`.
struct TagEntry(usize);
impl TagEntry {
/// Return the total number of bytes proceding this tag byte required to
/// encode the offset.
fn num_tag_bytes(&self) -> usize {
self.0 >> 11
}
/// Return the total copy length, capped at 255.
fn len(&self) -> usize {
self.0 & 0xFF
}
/// Return the copy offset corresponding to this copy operation. `s` should
/// point to the position just after the tag byte that this entry was read
/// from.
///
/// This requires reading from the compressed input since the offset is
/// encoded in bytes proceding the tag byte.
fn offset(&self, src: &[u8], s: usize) -> Result<usize> {
let num_tag_bytes = self.num_tag_bytes();
let trailer =
// It is critical for this case to come first, since it is the
// fast path. We really hope that this case gets branch
// predicted.
if s + 4 <= src.len() {
unsafe {
// SAFETY: The conditional above guarantees that
// src[s..s+4] is valid to read from.
let p = src.as_ptr().add(s);
// We use WORD_MASK here to mask out the bits we don't
// need. While we're guaranteed to read 4 valid bytes,
// not all of those bytes are necessarily part of the
// offset. This is the key optimization: we don't need to
// branch on num_tag_bytes.
bytes::loadu_u32_le(p) as usize & WORD_MASK[num_tag_bytes]
}
} else if num_tag_bytes == 1 {
if s >= src.len() {
return Err(Error::CopyRead {
len: 1,
src_len: (src.len() - s) as u64,
});
}
src[s] as usize
} else if num_tag_bytes == 2 {
if s + 1 >= src.len() {
return Err(Error::CopyRead {
len: 2,
src_len: (src.len() - s) as u64,
});
}
bytes::read_u16_le(&src[s..]) as usize
} else {
return Err(Error::CopyRead {
len: num_tag_bytes as u64,
src_len: (src.len() - s) as u64,
});
};
Ok((self.0 & 0b0000_0111_0000_0000) | trailer)
}
}