1use std::io;
8
9pub use zstd_safe::{CParameter, DParameter, InBuffer, OutBuffer, WriteBuf};
10
11use crate::dict::{DecoderDictionary, EncoderDictionary};
12use crate::map_error_code;
13
14pub trait Operation {
18 fn run<C: WriteBuf + ?Sized>(
25 &mut self,
26 input: &mut InBuffer<'_>,
27 output: &mut OutBuffer<'_, C>,
28 ) -> io::Result<usize>;
29
30 fn run_on_buffers(
35 &mut self,
36 input: &[u8],
37 output: &mut [u8],
38 ) -> io::Result<Status> {
39 let mut input = InBuffer::around(input);
40 let mut output = OutBuffer::around(output);
41
42 let remaining = self.run(&mut input, &mut output)?;
43
44 Ok(Status {
45 remaining,
46 bytes_read: input.pos(),
47 bytes_written: output.pos(),
48 })
49 }
50
51 fn flush<C: WriteBuf + ?Sized>(
56 &mut self,
57 output: &mut OutBuffer<'_, C>,
58 ) -> io::Result<usize> {
59 let _ = output;
60 Ok(0)
61 }
62
63 fn reinit(&mut self) -> io::Result<()> {
67 Ok(())
68 }
69
70 fn finish<C: WriteBuf + ?Sized>(
77 &mut self,
78 output: &mut OutBuffer<'_, C>,
79 finished_frame: bool,
80 ) -> io::Result<usize> {
81 let _ = output;
82 let _ = finished_frame;
83 Ok(0)
84 }
85}
86
87pub struct NoOp;
89
90impl Operation for NoOp {
91 fn run<C: WriteBuf + ?Sized>(
92 &mut self,
93 input: &mut InBuffer<'_>,
94 output: &mut OutBuffer<'_, C>,
95 ) -> io::Result<usize> {
96 let src = &input.src[input.pos..];
98 let output_pos = output.pos();
100 let dst = unsafe { output.as_mut_ptr().add(output_pos) };
101
102 let len = usize::min(src.len(), output.capacity() - output_pos);
104 let src = &src[..len];
105
106 unsafe { std::ptr::copy_nonoverlapping(src.as_ptr(), dst, len) };
110 input.set_pos(input.pos() + len);
111 unsafe { output.set_pos(output_pos + len) };
112
113 Ok(0)
114 }
115}
116
/// Outcome of a single [`Operation::run_on_buffers`] call.
///
/// A plain value type: it derives `Debug` so it can be logged or inspected,
/// and `Copy`/`Eq` so callers can store and compare results.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct Status {
    /// The value returned by [`Operation::run`]. For `Encoder` and `Decoder`
    /// this is the hint returned by zstd's streaming call; `NoOp` reports 0.
    pub remaining: usize,

    /// Number of bytes consumed from the input slice.
    pub bytes_read: usize,

    /// Number of bytes written into the output slice.
    pub bytes_written: usize,
}
132
/// Streaming decompressor implementing [`Operation`].
///
/// The `'a` lifetime bounds whatever the decoder borrows: a caller-provided
/// context (`with_context`), a prepared dictionary
/// (`with_prepared_dictionary`) or a reference prefix (`with_ref_prefix`).
pub struct Decoder<'a> {
    // Either a context created and owned by this decoder, or a mutable borrow
    // of one supplied by the caller.
    context: MaybeOwnedDCtx<'a>,
}
137
138impl Decoder<'static> {
139 pub fn new() -> io::Result<Self> {
141 Self::with_dictionary(&[])
142 }
143
144 pub fn with_dictionary(dictionary: &[u8]) -> io::Result<Self> {
146 let mut context = zstd_safe::DCtx::create();
147 context.init().map_err(map_error_code)?;
148 context
149 .load_dictionary(dictionary)
150 .map_err(map_error_code)?;
151 Ok(Decoder {
152 context: MaybeOwnedDCtx::Owned(context),
153 })
154 }
155}
156
157impl<'a> Decoder<'a> {
158 pub fn with_context(context: &'a mut zstd_safe::DCtx<'static>) -> Self {
160 Self {
161 context: MaybeOwnedDCtx::Borrowed(context),
162 }
163 }
164
165 pub fn with_prepared_dictionary<'b>(
167 dictionary: &DecoderDictionary<'b>,
168 ) -> io::Result<Self>
169 where
170 'b: 'a,
171 {
172 let mut context = zstd_safe::DCtx::create();
173 context
174 .ref_ddict(dictionary.as_ddict())
175 .map_err(map_error_code)?;
176 Ok(Decoder {
177 context: MaybeOwnedDCtx::Owned(context),
178 })
179 }
180
181 pub fn with_ref_prefix<'b>(ref_prefix: &'b [u8]) -> io::Result<Self>
183 where
184 'b: 'a,
185 {
186 let mut context = zstd_safe::DCtx::create();
187 context.ref_prefix(ref_prefix).map_err(map_error_code)?;
188 Ok(Decoder {
189 context: MaybeOwnedDCtx::Owned(context),
190 })
191 }
192
193 pub fn set_parameter(&mut self, parameter: DParameter) -> io::Result<()> {
195 match &mut self.context {
196 MaybeOwnedDCtx::Owned(x) => x.set_parameter(parameter),
197 MaybeOwnedDCtx::Borrowed(x) => x.set_parameter(parameter),
198 }
199 .map_err(map_error_code)?;
200 Ok(())
201 }
202}
203
204impl Operation for Decoder<'_> {
205 fn run<C: WriteBuf + ?Sized>(
206 &mut self,
207 input: &mut InBuffer<'_>,
208 output: &mut OutBuffer<'_, C>,
209 ) -> io::Result<usize> {
210 match &mut self.context {
211 MaybeOwnedDCtx::Owned(x) => x.decompress_stream(output, input),
212 MaybeOwnedDCtx::Borrowed(x) => x.decompress_stream(output, input),
213 }
214 .map_err(map_error_code)
215 }
216
217 fn flush<C: WriteBuf + ?Sized>(
218 &mut self,
219 output: &mut OutBuffer<'_, C>,
220 ) -> io::Result<usize> {
221 self.run(&mut InBuffer::around(&[]), output)?;
223
224 if output.pos() < output.capacity() {
226 Ok(0)
228 } else {
229 Ok(1)
231 }
232 }
233
234 fn reinit(&mut self) -> io::Result<()> {
235 match &mut self.context {
236 MaybeOwnedDCtx::Owned(x) => {
237 x.reset(zstd_safe::ResetDirective::SessionOnly)
238 }
239 MaybeOwnedDCtx::Borrowed(x) => {
240 x.reset(zstd_safe::ResetDirective::SessionOnly)
241 }
242 }
243 .map_err(map_error_code)?;
244 Ok(())
245 }
246
247 fn finish<C: WriteBuf + ?Sized>(
248 &mut self,
249 _output: &mut OutBuffer<'_, C>,
250 finished_frame: bool,
251 ) -> io::Result<usize> {
252 if finished_frame {
253 Ok(0)
254 } else {
255 Err(io::Error::new(
256 io::ErrorKind::UnexpectedEof,
257 "incomplete frame",
258 ))
259 }
260 }
261}
262
/// Streaming compressor implementing [`Operation`].
///
/// The `'a` lifetime bounds whatever the encoder borrows: a caller-provided
/// context (`with_context`), a prepared dictionary
/// (`with_prepared_dictionary`) or a reference prefix (`with_ref_prefix`).
pub struct Encoder<'a> {
    // Either a context created and owned by this encoder, or a mutable borrow
    // of one supplied by the caller.
    context: MaybeOwnedCCtx<'a>,
}
267
268impl Encoder<'static> {
269 pub fn new(level: i32) -> io::Result<Self> {
271 Self::with_dictionary(level, &[])
272 }
273
274 pub fn with_dictionary(level: i32, dictionary: &[u8]) -> io::Result<Self> {
276 let mut context = zstd_safe::CCtx::create();
277
278 context
279 .set_parameter(CParameter::CompressionLevel(level))
280 .map_err(map_error_code)?;
281
282 context
283 .load_dictionary(dictionary)
284 .map_err(map_error_code)?;
285
286 Ok(Encoder {
287 context: MaybeOwnedCCtx::Owned(context),
288 })
289 }
290}
291
292impl<'a> Encoder<'a> {
293 pub fn with_context(context: &'a mut zstd_safe::CCtx<'static>) -> Self {
295 Self {
296 context: MaybeOwnedCCtx::Borrowed(context),
297 }
298 }
299
300 pub fn with_prepared_dictionary<'b>(
302 dictionary: &EncoderDictionary<'b>,
303 ) -> io::Result<Self>
304 where
305 'b: 'a,
306 {
307 let mut context = zstd_safe::CCtx::create();
308 context
309 .ref_cdict(dictionary.as_cdict())
310 .map_err(map_error_code)?;
311 Ok(Encoder {
312 context: MaybeOwnedCCtx::Owned(context),
313 })
314 }
315
316 pub fn with_ref_prefix<'b>(
318 level: i32,
319 ref_prefix: &'b [u8],
320 ) -> io::Result<Self>
321 where
322 'b: 'a,
323 {
324 let mut context = zstd_safe::CCtx::create();
325
326 context
327 .set_parameter(CParameter::CompressionLevel(level))
328 .map_err(map_error_code)?;
329
330 context.ref_prefix(ref_prefix).map_err(map_error_code)?;
331
332 Ok(Encoder {
333 context: MaybeOwnedCCtx::Owned(context),
334 })
335 }
336
337 pub fn set_parameter(&mut self, parameter: CParameter) -> io::Result<()> {
339 match &mut self.context {
340 MaybeOwnedCCtx::Owned(x) => x.set_parameter(parameter),
341 MaybeOwnedCCtx::Borrowed(x) => x.set_parameter(parameter),
342 }
343 .map_err(map_error_code)?;
344 Ok(())
345 }
346
347 pub fn set_pledged_src_size(
356 &mut self,
357 pledged_src_size: Option<u64>,
358 ) -> io::Result<()> {
359 match &mut self.context {
360 MaybeOwnedCCtx::Owned(x) => {
361 x.set_pledged_src_size(pledged_src_size)
362 }
363 MaybeOwnedCCtx::Borrowed(x) => {
364 x.set_pledged_src_size(pledged_src_size)
365 }
366 }
367 .map_err(map_error_code)?;
368 Ok(())
369 }
370}
371
372impl<'a> Operation for Encoder<'a> {
373 fn run<C: WriteBuf + ?Sized>(
374 &mut self,
375 input: &mut InBuffer<'_>,
376 output: &mut OutBuffer<'_, C>,
377 ) -> io::Result<usize> {
378 match &mut self.context {
379 MaybeOwnedCCtx::Owned(x) => x.compress_stream(output, input),
380 MaybeOwnedCCtx::Borrowed(x) => x.compress_stream(output, input),
381 }
382 .map_err(map_error_code)
383 }
384
385 fn flush<C: WriteBuf + ?Sized>(
386 &mut self,
387 output: &mut OutBuffer<'_, C>,
388 ) -> io::Result<usize> {
389 match &mut self.context {
390 MaybeOwnedCCtx::Owned(x) => x.flush_stream(output),
391 MaybeOwnedCCtx::Borrowed(x) => x.flush_stream(output),
392 }
393 .map_err(map_error_code)
394 }
395
396 fn finish<C: WriteBuf + ?Sized>(
397 &mut self,
398 output: &mut OutBuffer<'_, C>,
399 _finished_frame: bool,
400 ) -> io::Result<usize> {
401 match &mut self.context {
402 MaybeOwnedCCtx::Owned(x) => x.end_stream(output),
403 MaybeOwnedCCtx::Borrowed(x) => x.end_stream(output),
404 }
405 .map_err(map_error_code)
406 }
407
408 fn reinit(&mut self) -> io::Result<()> {
409 match &mut self.context {
410 MaybeOwnedCCtx::Owned(x) => {
411 x.reset(zstd_safe::ResetDirective::SessionOnly)
412 }
413 MaybeOwnedCCtx::Borrowed(x) => {
414 x.reset(zstd_safe::ResetDirective::SessionOnly)
415 }
416 }
417 .map_err(map_error_code)?;
418 Ok(())
419 }
420}
421
/// A compression context that is either owned by an [`Encoder`] or borrowed
/// from the caller.
///
/// `&mut T` is invariant in `T`, so a borrowed `CCtx<'static>` cannot be
/// viewed as the owned variant's `CCtx<'a>`. This is presumably why every
/// `Encoder` method matches on both variants instead of using one accessor.
enum MaybeOwnedCCtx<'a> {
    // Context created by an `Encoder` constructor. `'a` bounds any dictionary
    // or prefix it references.
    Owned(zstd_safe::CCtx<'a>),
    // Context supplied by the caller through `Encoder::with_context`.
    Borrowed(&'a mut zstd_safe::CCtx<'static>),
}
426
/// A decompression context that is either owned by a [`Decoder`] or borrowed
/// from the caller.
///
/// `&mut T` is invariant in `T`, so a borrowed `DCtx<'static>` cannot be
/// viewed as the owned variant's `DCtx<'a>`. This is presumably why every
/// `Decoder` method matches on both variants instead of using one accessor.
enum MaybeOwnedDCtx<'a> {
    // Context created by a `Decoder` constructor. `'a` bounds any dictionary
    // or prefix it references.
    Owned(zstd_safe::DCtx<'a>),
    // Context supplied by the caller through `Decoder::with_context`.
    Borrowed(&'a mut zstd_safe::DCtx<'static>),
}
431
#[cfg(test)]
mod tests {
    // NOTE(review): gated on `arrays`, presumably because `OutBuffer::around`
    // only accepts a `&mut [u8; N]` when that feature is enabled. Confirm.
    #[cfg(feature = "arrays")]
    #[test]
    fn test_cycle() {
        use super::{Decoder, Encoder, InBuffer, Operation, OutBuffer};

        let mut encoder = Encoder::new(1).unwrap();
        let mut decoder = Decoder::new().unwrap();

        let mut input = InBuffer::around(b"AbcdefAbcdefabcdef");

        let mut output = [0u8; 128];
        let mut output = OutBuffer::around(&mut output);

        // Feed the whole input to the encoder.
        loop {
            encoder.run(&mut input, &mut output).unwrap();

            if input.pos == input.src.len() {
                break;
            }
        }
        // `finish` reports how much data is still pending. The frame for this
        // tiny input comfortably fits in 128 bytes, so nothing may remain.
        // Ignoring this value would hide a truncated frame until decoding.
        let pending = encoder.finish(&mut output, true).unwrap();
        assert_eq!(pending, 0, "encoder did not finish writing the frame");

        let initial_data = input.src;

        let mut input = InBuffer::around(output.as_slice());
        let mut output = [0u8; 128];
        let mut output = OutBuffer::around(&mut output);

        // Decode until every compressed byte is consumed, keeping the hint
        // returned by the final step.
        let last_hint = loop {
            let hint = decoder.run(&mut input, &mut output).unwrap();

            if input.pos == input.src.len() {
                break hint;
            }
        };
        // zstd's streaming decoder returns 0 once a frame is fully decoded and
        // flushed.
        assert_eq!(last_hint, 0, "decoder did not reach the end of the frame");

        assert_eq!(initial_data, output.as_slice());
    }
}