1use std::{
2 fmt,
3 hash::Hasher,
4 io::{self, BufRead, ErrorKind},
5 mem::size_of,
6};
7use twox_hash::XxHash32;
8
9use super::header::{
10 BlockInfo, BlockMode, FrameInfo, LZ4F_LEGACY_MAGIC_NUMBER, MAGIC_NUMBER_SIZE,
11 MAX_FRAME_INFO_SIZE, MIN_FRAME_INFO_SIZE,
12};
13use super::Error;
14use crate::{
15 block::WINDOW_SIZE,
16 sink::{vec_sink_for_decompression, SliceSink},
17};
18
19pub struct FrameDecoder<R: io::Read> {
49 r: R,
51 current_frame_info: Option<FrameInfo>,
55 content_hasher: XxHash32,
57 content_len: u64,
59 src: Vec<u8>,
61 dst: Vec<u8>,
64 ext_dict_offset: usize,
67 ext_dict_len: usize,
68 dst_start: usize,
70 dst_end: usize,
72}
73
74impl<R: io::Read> FrameDecoder<R> {
75 pub fn new(rdr: R) -> FrameDecoder<R> {
77 FrameDecoder {
78 r: rdr,
79 src: Default::default(),
80 dst: Default::default(),
81 ext_dict_offset: 0,
82 ext_dict_len: 0,
83 dst_start: 0,
84 dst_end: 0,
85 current_frame_info: None,
86 content_hasher: XxHash32::with_seed(0),
87 content_len: 0,
88 }
89 }
90
91 pub fn get_ref(&self) -> &R {
93 &self.r
94 }
95
96 pub fn get_mut(&mut self) -> &mut R {
101 &mut self.r
102 }
103
104 pub fn into_inner(self) -> R {
106 self.r
107 }
108
109 fn read_frame_info(&mut self) -> Result<usize, io::Error> {
110 let mut buffer = [0u8; MAX_FRAME_INFO_SIZE];
111
112 match self.r.read(&mut buffer[..MAGIC_NUMBER_SIZE])? {
113 0 => return Ok(0),
114 MAGIC_NUMBER_SIZE => (),
115 read => self.r.read_exact(&mut buffer[read..MAGIC_NUMBER_SIZE])?,
116 }
117
118 if u32::from_le_bytes(buffer[0..MAGIC_NUMBER_SIZE].try_into().unwrap())
119 != LZ4F_LEGACY_MAGIC_NUMBER
120 {
121 match self
122 .r
123 .read(&mut buffer[MAGIC_NUMBER_SIZE..MIN_FRAME_INFO_SIZE])?
124 {
125 0 => return Ok(0),
126 MIN_FRAME_INFO_SIZE => (),
127 read => self
128 .r
129 .read_exact(&mut buffer[MAGIC_NUMBER_SIZE + read..MIN_FRAME_INFO_SIZE])?,
130 }
131 }
132 let required = FrameInfo::read_size(&buffer[..MIN_FRAME_INFO_SIZE])?;
133 if required != MIN_FRAME_INFO_SIZE && required != MAGIC_NUMBER_SIZE {
134 self.r
135 .read_exact(&mut buffer[MIN_FRAME_INFO_SIZE..required])?;
136 }
137
138 let frame_info = FrameInfo::read(&buffer[..required])?;
139 if frame_info.dict_id.is_some() {
140 return Err(Error::DictionaryNotSupported.into());
142 }
143
144 let max_block_size = frame_info.block_size.get_size();
145 let dst_size = if frame_info.block_mode == BlockMode::Linked {
146 max_block_size * 2 + WINDOW_SIZE
154 } else {
155 max_block_size
156 };
157 self.src.clear();
158 self.dst.clear();
159 self.src.reserve_exact(max_block_size);
160 self.dst.reserve_exact(dst_size);
161 self.current_frame_info = Some(frame_info);
162 self.content_hasher = XxHash32::with_seed(0);
163 self.content_len = 0;
164 self.ext_dict_len = 0;
165 self.dst_start = 0;
166 self.dst_end = 0;
167 Ok(required)
168 }
169
170 #[inline]
171 fn read_checksum(r: &mut R) -> Result<u32, io::Error> {
172 let mut checksum_buffer = [0u8; size_of::<u32>()];
173 r.read_exact(&mut checksum_buffer[..])?;
174 let checksum = u32::from_le_bytes(checksum_buffer);
175 Ok(checksum)
176 }
177
178 #[inline]
179 fn check_block_checksum(data: &[u8], expected_checksum: u32) -> Result<(), io::Error> {
180 let mut block_hasher = XxHash32::with_seed(0);
181 block_hasher.write(data);
182 let calc_checksum = block_hasher.finish() as u32;
183 if calc_checksum != expected_checksum {
184 return Err(Error::BlockChecksumError.into());
185 }
186 Ok(())
187 }
188
189 fn read_block(&mut self) -> io::Result<usize> {
190 debug_assert_eq!(self.dst_start, self.dst_end);
191 let frame_info = self.current_frame_info.as_ref().unwrap();
192
193 let max_block_size = frame_info.block_size.get_size();
195 if frame_info.block_mode == BlockMode::Linked {
196 debug_assert_eq!(self.dst.capacity(), max_block_size * 2 + WINDOW_SIZE);
202 if self.dst_start + max_block_size > self.dst.capacity() {
203 debug_assert!(self.dst_start >= max_block_size + WINDOW_SIZE);
206 self.ext_dict_offset = self.dst_start - WINDOW_SIZE;
207 self.ext_dict_len = WINDOW_SIZE;
208 self.dst_start = 0;
210 self.dst_end = 0;
211 } else if self.dst_start + self.ext_dict_len > WINDOW_SIZE {
212 let delta = self
217 .ext_dict_len
218 .min(self.dst_start + self.ext_dict_len - WINDOW_SIZE);
219 self.ext_dict_offset += delta;
220 self.ext_dict_len -= delta;
221 debug_assert!(self.dst_start + self.ext_dict_len >= WINDOW_SIZE)
222 }
223 } else {
224 debug_assert_eq!(self.ext_dict_len, 0);
225 debug_assert_eq!(self.dst.capacity(), max_block_size);
226 self.dst_start = 0;
227 self.dst_end = 0;
228 }
229
230 let block_info = {
232 let mut buffer = [0u8; 4];
233 if let Err(err) = self.r.read_exact(&mut buffer) {
234 if err.kind() == ErrorKind::UnexpectedEof {
235 return Ok(0);
236 } else {
237 return Err(err);
238 }
239 }
240 BlockInfo::read(&buffer)?
241 };
242 match block_info {
243 BlockInfo::Uncompressed(len) => {
244 let len = len as usize;
245 if len > max_block_size {
246 return Err(Error::BlockTooBig.into());
247 }
248 self.r.read_exact(vec_resize_and_get_mut(
251 &mut self.dst,
252 self.dst_start,
253 self.dst_start + len,
254 ))?;
255 if frame_info.block_checksums {
256 let expected_checksum = Self::read_checksum(&mut self.r)?;
257 Self::check_block_checksum(
258 &self.dst[self.dst_start..self.dst_start + len],
259 expected_checksum,
260 )?;
261 }
262
263 self.dst_end += len;
264 self.content_len += len as u64;
265 }
266 BlockInfo::Compressed(len) => {
267 let len = len as usize;
268 if len > max_block_size {
269 return Err(Error::BlockTooBig.into());
270 }
271 self.r
274 .read_exact(vec_resize_and_get_mut(&mut self.src, 0, len))?;
275 if frame_info.block_checksums {
276 let expected_checksum = Self::read_checksum(&mut self.r)?;
277 Self::check_block_checksum(&self.src[..len], expected_checksum)?;
278 }
279
280 let with_dict_mode =
281 frame_info.block_mode == BlockMode::Linked && self.ext_dict_len != 0;
282 let decomp_size = if with_dict_mode {
283 debug_assert!(self.dst_start + max_block_size <= self.ext_dict_offset);
284 let (head, tail) = self.dst.split_at_mut(self.ext_dict_offset);
285 let ext_dict = &tail[..self.ext_dict_len];
286
287 debug_assert!(head.len() - self.dst_start >= max_block_size);
288 crate::block::decompress::decompress_internal::<true, _>(
289 &self.src[..len],
290 &mut SliceSink::new(head, self.dst_start),
291 ext_dict,
292 )
293 } else {
294 debug_assert!(self.dst.capacity() - self.dst_start >= max_block_size);
296 crate::block::decompress::decompress_internal::<false, _>(
297 &self.src[..len],
298 &mut vec_sink_for_decompression(
299 &mut self.dst,
300 0,
301 self.dst_start,
302 self.dst_start + max_block_size,
303 ),
304 b"",
305 )
306 }
307 .map_err(Error::DecompressionError)?;
308
309 self.dst_end += decomp_size;
310 self.content_len += decomp_size as u64;
311 }
312
313 BlockInfo::EndMark => {
314 if let Some(expected) = frame_info.content_size {
315 if self.content_len != expected {
316 return Err(Error::ContentLengthError {
317 expected,
318 actual: self.content_len,
319 }
320 .into());
321 }
322 }
323 if frame_info.content_checksum {
324 let expected_checksum = Self::read_checksum(&mut self.r)?;
325 let calc_checksum = self.content_hasher.finish() as u32;
326 if calc_checksum != expected_checksum {
327 return Err(Error::ContentChecksumError.into());
328 }
329 }
330 self.current_frame_info = None;
331 return Ok(0);
332 }
333 }
334
335 if frame_info.content_checksum {
337 self.content_hasher
338 .write(&self.dst[self.dst_start..self.dst_end]);
339 }
340
341 Ok(self.dst_end - self.dst_start)
342 }
343
344 fn read_more(&mut self) -> io::Result<usize> {
345 if self.current_frame_info.is_none() && self.read_frame_info()? == 0 {
346 return Ok(0);
347 }
348 self.read_block()
349 }
350}
351
352impl<R: io::Read> io::Read for FrameDecoder<R> {
353 fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
354 loop {
355 if self.dst_start < self.dst_end {
357 let read_len = std::cmp::min(self.dst_end - self.dst_start, buf.len());
358 let dst_read_end = self.dst_start + read_len;
359 buf[..read_len].copy_from_slice(&self.dst[self.dst_start..dst_read_end]);
360 self.dst_start = dst_read_end;
361 return Ok(read_len);
362 }
363 if self.read_more()? == 0 {
364 return Ok(0);
365 }
366 }
367 }
368
369 fn read_to_string(&mut self, buf: &mut String) -> io::Result<usize> {
370 let mut written = 0;
371 loop {
372 match self.fill_buf() {
373 Ok([]) => return Ok(written),
374 Ok(b) => {
375 let s = std::str::from_utf8(b).map_err(|_| {
376 io::Error::new(
377 io::ErrorKind::InvalidData,
378 "stream did not contain valid UTF-8",
379 )
380 })?;
381 buf.push_str(s);
382 let len = s.len();
383 self.consume(len);
384 written += len;
385 }
386 Err(ref e) if e.kind() == io::ErrorKind::Interrupted => continue,
387 Err(e) => return Err(e),
388 }
389 }
390 }
391
392 fn read_to_end(&mut self, buf: &mut Vec<u8>) -> io::Result<usize> {
393 let mut written = 0;
394 loop {
395 match self.fill_buf() {
396 Ok([]) => return Ok(written),
397 Ok(b) => {
398 buf.extend_from_slice(b);
399 let len = b.len();
400 self.consume(len);
401 written += len;
402 }
403 Err(ref e) if e.kind() == io::ErrorKind::Interrupted => continue,
404 Err(e) => return Err(e),
405 }
406 }
407 }
408}
409
410impl<R: io::Read> io::BufRead for FrameDecoder<R> {
411 fn fill_buf(&mut self) -> io::Result<&[u8]> {
412 if self.dst_start == self.dst_end {
413 self.read_more()?;
414 }
415 Ok(&self.dst[self.dst_start..self.dst_end])
416 }
417
418 fn consume(&mut self, amt: usize) {
419 assert!(amt <= self.dst_end - self.dst_start);
420 self.dst_start += amt;
421 }
422}
423
424impl<R: fmt::Debug + io::Read> fmt::Debug for FrameDecoder<R> {
425 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
426 f.debug_struct("FrameDecoder")
427 .field("r", &self.r)
428 .field("content_hasher", &self.content_hasher)
429 .field("content_len", &self.content_len)
430 .field("src", &"[...]")
431 .field("dst", &"[...]")
432 .field("dst_start", &self.dst_start)
433 .field("dst_end", &self.dst_end)
434 .field("ext_dict_offset", &self.ext_dict_offset)
435 .field("ext_dict_len", &self.ext_dict_len)
436 .field("current_frame_info", &self.current_frame_info)
437 .finish()
438 }
439}
440
/// Returns `v[start..end]` as a mutable slice, first zero-extending `v` when
/// it is shorter than `end`. `v` is never shrunk.
///
/// # Panics
/// If `start > end`.
#[inline]
fn vec_resize_and_get_mut(v: &mut Vec<u8>, start: usize, end: usize) -> &mut [u8] {
    // Resizing to the current length is a no-op, so this only ever grows.
    let needed = end.max(v.len());
    v.resize(needed, 0);
    &mut v[start..end]
}