1#[cfg(feature = "with-alloc")]
5use crate::alloc::boxed::Box;
6use core::{cmp, mem};
7
8use crate::inflate::core::{decompress, inflate_flags, DecompressorOxide, TINFL_LZ_DICT_SIZE};
9use crate::inflate::TINFLStatus;
10use crate::{DataFormat, MZError, MZFlush, MZResult, MZStatus, StreamResult};
11
12pub trait ResetPolicy {
14 fn reset(&self, state: &mut InflateState);
16}
17
18pub struct MinReset;
22
23impl ResetPolicy for MinReset {
24 fn reset(&self, state: &mut InflateState) {
25 state.decompressor().init();
26 state.dict_ofs = 0;
27 state.dict_avail = 0;
28 state.first_call = true;
29 state.has_flushed = false;
30 state.last_status = TINFLStatus::NeedsMoreInput;
31 }
32}
33
34pub struct ZeroReset;
36
37impl ResetPolicy for ZeroReset {
38 #[inline]
39 fn reset(&self, state: &mut InflateState) {
40 MinReset.reset(state);
41 state.dict = [0; TINFL_LZ_DICT_SIZE];
42 }
43}
44
45pub struct FullReset(pub DataFormat);
49
50impl ResetPolicy for FullReset {
51 #[inline]
52 fn reset(&self, state: &mut InflateState) {
53 ZeroReset.reset(state);
54 state.data_format = self.0;
55 }
56}
57
58#[derive(Clone)]
61pub struct InflateState {
62 decomp: DecompressorOxide,
64
65 dict: [u8; TINFL_LZ_DICT_SIZE],
71 dict_ofs: usize,
73 dict_avail: usize,
75
76 first_call: bool,
77 has_flushed: bool,
78
79 data_format: DataFormat,
82 last_status: TINFLStatus,
83}
84
85impl Default for InflateState {
86 fn default() -> Self {
87 InflateState {
88 decomp: DecompressorOxide::default(),
89 dict: [0; TINFL_LZ_DICT_SIZE],
90 dict_ofs: 0,
91 dict_avail: 0,
92 first_call: true,
93 has_flushed: false,
94 data_format: DataFormat::Raw,
95 last_status: TINFLStatus::NeedsMoreInput,
96 }
97 }
98}
99impl InflateState {
100 pub fn new(data_format: DataFormat) -> InflateState {
109 InflateState {
110 data_format,
111 ..Default::default()
112 }
113 }
114
115 #[cfg(feature = "with-alloc")]
121 pub fn new_boxed(data_format: DataFormat) -> Box<InflateState> {
122 let mut b: Box<InflateState> = Box::default();
123 b.data_format = data_format;
124 b
125 }
126
127 pub fn decompressor(&mut self) -> &mut DecompressorOxide {
129 &mut self.decomp
130 }
131
132 pub const fn last_status(&self) -> TINFLStatus {
134 self.last_status
135 }
136
137 #[cfg(feature = "with-alloc")]
143 pub fn new_boxed_with_window_bits(window_bits: i32) -> Box<InflateState> {
144 let mut b: Box<InflateState> = Box::default();
145 b.data_format = DataFormat::from_window_bits(window_bits);
146 b
147 }
148
149 #[inline]
150 pub fn reset(&mut self, data_format: DataFormat) {
153 self.reset_as(FullReset(data_format));
154 }
155
156 #[inline]
157 pub fn reset_as<T: ResetPolicy>(&mut self, policy: T) {
159 policy.reset(self)
160 }
161}
162
163pub fn inflate(
187 state: &mut InflateState,
188 input: &[u8],
189 output: &mut [u8],
190 flush: MZFlush,
191) -> StreamResult {
192 let mut bytes_consumed = 0;
193 let mut bytes_written = 0;
194 let mut next_in = input;
195 let mut next_out = output;
196
197 if flush == MZFlush::Full {
198 return StreamResult::error(MZError::Stream);
199 }
200
201 let mut decomp_flags = if state.data_format == DataFormat::Zlib {
202 inflate_flags::TINFL_FLAG_COMPUTE_ADLER32
203 } else {
204 inflate_flags::TINFL_FLAG_IGNORE_ADLER32
205 };
206
207 if (state.data_format == DataFormat::Zlib)
208 | (state.data_format == DataFormat::ZLibIgnoreChecksum)
209 {
210 decomp_flags |= inflate_flags::TINFL_FLAG_PARSE_ZLIB_HEADER;
211 }
212
213 let first_call = state.first_call;
214 state.first_call = false;
215 if state.last_status == TINFLStatus::FailedCannotMakeProgress {
216 return StreamResult::error(MZError::Buf);
217 }
218 if (state.last_status as i32) < 0 {
219 return StreamResult::error(MZError::Data);
220 }
221
222 if state.has_flushed && (flush != MZFlush::Finish) {
223 return StreamResult::error(MZError::Stream);
224 }
225 state.has_flushed |= flush == MZFlush::Finish;
226
227 if (flush == MZFlush::Finish) && first_call {
228 decomp_flags |= inflate_flags::TINFL_FLAG_USING_NON_WRAPPING_OUTPUT_BUF;
229
230 let status = decompress(&mut state.decomp, next_in, next_out, 0, decomp_flags);
234 let in_bytes = status.1;
235 let out_bytes = status.2;
236 let status = status.0;
237
238 state.last_status = status;
239
240 bytes_consumed += in_bytes;
241 bytes_written += out_bytes;
242
243 let ret_status = {
244 if status == TINFLStatus::FailedCannotMakeProgress {
245 Err(MZError::Buf)
246 } else if (status as i32) < 0 {
247 Err(MZError::Data)
248 } else if status != TINFLStatus::Done {
249 state.last_status = TINFLStatus::Failed;
250 Err(MZError::Buf)
251 } else {
252 Ok(MZStatus::StreamEnd)
253 }
254 };
255 return StreamResult {
256 bytes_consumed,
257 bytes_written,
258 status: ret_status,
259 };
260 }
261
262 if flush != MZFlush::Finish {
263 decomp_flags |= inflate_flags::TINFL_FLAG_HAS_MORE_INPUT;
264 }
265
266 if state.dict_avail != 0 {
267 bytes_written += push_dict_out(state, &mut next_out);
268 return StreamResult {
269 bytes_consumed,
270 bytes_written,
271 status: Ok(
272 if (state.last_status == TINFLStatus::Done) && (state.dict_avail == 0) {
273 MZStatus::StreamEnd
274 } else {
275 MZStatus::Ok
276 },
277 ),
278 };
279 }
280
281 let status = inflate_loop(
282 state,
283 &mut next_in,
284 &mut next_out,
285 &mut bytes_consumed,
286 &mut bytes_written,
287 decomp_flags,
288 flush,
289 );
290 StreamResult {
291 bytes_consumed,
292 bytes_written,
293 status,
294 }
295}
296
297fn inflate_loop(
298 state: &mut InflateState,
299 next_in: &mut &[u8],
300 next_out: &mut &mut [u8],
301 total_in: &mut usize,
302 total_out: &mut usize,
303 decomp_flags: u32,
304 flush: MZFlush,
305) -> MZResult {
306 let orig_in_len = next_in.len();
307 loop {
308 let status = decompress(
309 &mut state.decomp,
310 next_in,
311 &mut state.dict,
312 state.dict_ofs,
313 decomp_flags,
314 );
315
316 let in_bytes = status.1;
317 let out_bytes = status.2;
318 let status = status.0;
319
320 state.last_status = status;
321
322 *next_in = &next_in[in_bytes..];
323 *total_in += in_bytes;
324
325 state.dict_avail = out_bytes;
326 *total_out += push_dict_out(state, next_out);
327
328 if status == TINFLStatus::FailedCannotMakeProgress {
330 return Err(MZError::Buf);
331 }
332 else if (status as i32) < 0 {
334 return Err(MZError::Data);
335 }
336
337 if (status == TINFLStatus::NeedsMoreInput) && orig_in_len == 0 {
340 return Err(MZError::Buf);
341 }
342
343 if flush == MZFlush::Finish {
344 if status == TINFLStatus::Done {
345 return if state.dict_avail != 0 {
348 Err(MZError::Buf)
349 } else {
350 Ok(MZStatus::StreamEnd)
351 };
352 } else if next_out.is_empty() {
354 return Err(MZError::Buf);
355 }
356 } else {
357 let empty_buf = next_in.is_empty() || next_out.is_empty();
359 if (status == TINFLStatus::Done) || empty_buf || (state.dict_avail != 0) {
360 return if (status == TINFLStatus::Done) && (state.dict_avail == 0) {
361 Ok(MZStatus::StreamEnd)
363 } else {
364 Ok(MZStatus::Ok)
366 };
367 }
368 }
369 }
370}
371
372fn push_dict_out(state: &mut InflateState, next_out: &mut &mut [u8]) -> usize {
373 let n = cmp::min(state.dict_avail, next_out.len());
374 (next_out[..n]).copy_from_slice(&state.dict[state.dict_ofs..state.dict_ofs + n]);
375 *next_out = &mut mem::take(next_out)[n..];
376 state.dict_avail -= n;
377 state.dict_ofs = (state.dict_ofs + (n)) & (TINFL_LZ_DICT_SIZE - 1);
378 n
379}
380
#[cfg(all(test, feature = "with-alloc"))]
mod test {
    use super::{inflate, InflateState};
    use crate::{DataFormat, MZFlush, MZStatus};
    use alloc::vec;

    /// zlib-wrapped stream whose decompressed contents are [`DECODED`].
    const ENCODED: [u8; 20] = [
        120, 156, 243, 72, 205, 201, 201, 215, 81, 168, 202, 201, 76, 82, 4, 0, 27, 101, 4, 19,
    ];

    /// Expected output of inflating [`ENCODED`].
    const DECODED: &[u8] = b"Hello, zlib!";

    /// Clear `out`, inflate all of [`ENCODED`] with a single `Finish` call, and
    /// check that the whole stream was consumed and decoded.
    fn inflate_in_one_call(state: &mut InflateState, out: &mut [u8]) {
        out.fill(0);
        let res = inflate(state, &ENCODED, out, MZFlush::Finish);
        let status = res.status.expect("Failed to decompress!");
        assert_eq!(status, MZStatus::StreamEnd);
        assert_eq!(&out[..res.bytes_written], DECODED);
        assert_eq!(res.bytes_consumed, ENCODED.len());
    }

    #[test]
    fn test_state() {
        let mut out = vec![0u8; 50];
        let mut state = InflateState::new_boxed(DataFormat::Zlib);
        inflate_in_one_call(&mut state, &mut out);

        // Each reset policy must leave the state able to decode a fresh stream.
        state.reset_as(super::ZeroReset);
        inflate_in_one_call(&mut state, &mut out);

        state.reset_as(super::MinReset);
        inflate_in_one_call(&mut state, &mut out);
        assert_eq!(state.decompressor().adler32(), Some(459605011));

        // With the checksum ignored, `adler32()` reports 1, while
        // `adler32_header()` still reports the value checked above.
        state = InflateState::new_boxed(DataFormat::ZLibIgnoreChecksum);
        inflate_in_one_call(&mut state, &mut out);
        assert_eq!(state.decompressor().adler32(), Some(1));
        assert_eq!(state.decompressor().adler32_header(), Some(459605011));
    }

    #[test]
    fn test_partial_continue() {
        let mut out = vec![0u8; 50];
        let mut state = InflateState::new_boxed(DataFormat::Zlib);
        let mut part_in = 0;
        let mut part_out = 0;

        // Offer the stream incrementally: each call sees the not-yet-consumed
        // bytes up to index `i`, and never asks to finish.
        for i in 1..=ENCODED.len() {
            let res = inflate(
                &mut state,
                &ENCODED[part_in..i],
                &mut out[part_out..],
                MZFlush::None,
            );
            let status = res.status.expect("Failed to decompress!");
            let expected = if i == ENCODED.len() {
                MZStatus::StreamEnd
            } else {
                MZStatus::Ok
            };
            assert_eq!(status, expected);
            part_out += res.bytes_written;
            part_in += res.bytes_consumed;
        }

        assert_eq!(&out[..part_out], DECODED);
        assert_eq!(part_in, ENCODED.len());
        assert_eq!(state.decompressor().adler32(), Some(459605011));
    }

    #[test]
    fn test_rewind_and_resume() {
        let mut out = vec![0u8; 50];
        let mut state = InflateState::new_boxed(DataFormat::Zlib);
        let res1 = inflate(&mut state, &ENCODED[..10], &mut out, MZFlush::None);
        let status = res1.status.expect("Failed to decompress!");
        assert_eq!(status, MZStatus::Ok);

        // A cloned state must be able to carry on independently of the original.
        let mut resume = state.clone();
        drop(state);

        let res2 = inflate(
            &mut resume,
            &ENCODED[res1.bytes_consumed..],
            &mut out[res1.bytes_written..],
            MZFlush::Finish,
        );
        let status = res2.status.expect("Failed to decompress!");
        assert_eq!(status, MZStatus::StreamEnd);

        let total_written = res1.bytes_written + res2.bytes_written;
        assert_eq!(res1.bytes_consumed + res2.bytes_consumed, ENCODED.len());
        assert_eq!(total_written, DECODED.len());
        assert_eq!(&out[..total_written], DECODED);
        assert_eq!(resume.decompressor().adler32(), Some(459605011));
    }
}