bzip2/
write.rs

1//! Writer-based compression/decompression streams
2
3use std::io;
4use std::io::prelude::*;
5
6#[cfg(feature = "tokio")]
7use futures::Poll;
8#[cfg(feature = "tokio")]
9use tokio_io::{AsyncRead, AsyncWrite};
10
11use {Action, Compress, Compression, Decompress, Status};
12
13/// A compression stream which will have uncompressed data written to it and
14/// will write compressed data to an output stream.
15pub struct BzEncoder<W: Write> {
16    data: Compress,
17    obj: Option<W>,
18    buf: Vec<u8>,
19    done: bool,
20}
21
22/// A compression stream which will have compressed data written to it and
23/// will write uncompressed data to an output stream.
24pub struct BzDecoder<W: Write> {
25    data: Decompress,
26    obj: Option<W>,
27    buf: Vec<u8>,
28    done: bool,
29}
30
31impl<W: Write> BzEncoder<W> {
32    /// Create a new compression stream which will compress at the given level
33    /// to write compress output to the give output stream.
34    pub fn new(obj: W, level: Compression) -> BzEncoder<W> {
35        BzEncoder {
36            data: Compress::new(level, 30),
37            obj: Some(obj),
38            buf: Vec::with_capacity(32 * 1024),
39            done: false,
40        }
41    }
42
43    fn dump(&mut self) -> io::Result<()> {
44        while self.buf.len() > 0 {
45            let n = match self.obj.as_mut().unwrap().write(&self.buf) {
46                Ok(n) => n,
47                Err(ref err) if err.kind() == io::ErrorKind::Interrupted => continue,
48                Err(err) => return Err(err),
49            };
50            self.buf.drain(..n);
51        }
52        Ok(())
53    }
54
55    /// Acquires a reference to the underlying writer.
56    pub fn get_ref(&self) -> &W {
57        self.obj.as_ref().unwrap()
58    }
59
60    /// Acquires a mutable reference to the underlying writer.
61    ///
62    /// Note that mutating the output/input state of the stream may corrupt this
63    /// object, so care must be taken when using this method.
64    pub fn get_mut(&mut self) -> &mut W {
65        self.obj.as_mut().unwrap()
66    }
67
68    /// Attempt to finish this output stream, writing out final chunks of data.
69    ///
70    /// Note that this function can only be used once data has finished being
71    /// written to the output stream. After this function is called then further
72    /// calls to `write` may result in a panic.
73    ///
74    /// # Panics
75    ///
76    /// Attempts to write data to this stream may result in a panic after this
77    /// function is called.
78    pub fn try_finish(&mut self) -> io::Result<()> {
79        while !self.done {
80            self.dump()?;
81            let res = self.data.compress_vec(&[], &mut self.buf, Action::Finish);
82            if res == Ok(Status::StreamEnd) {
83                self.done = true;
84                break;
85            }
86        }
87        self.dump()
88    }
89
90    /// Consumes this encoder, flushing the output stream.
91    ///
92    /// This will flush the underlying data stream and then return the contained
93    /// writer if the flush succeeded.
94    ///
95    /// Note that this function may not be suitable to call in a situation where
96    /// the underlying stream is an asynchronous I/O stream. To finish a stream
97    /// the `try_finish` (or `shutdown`) method should be used instead. To
98    /// re-acquire ownership of a stream it is safe to call this method after
99    /// `try_finish` or `shutdown` has returned `Ok`.
100    pub fn finish(mut self) -> io::Result<W> {
101        self.try_finish()?;
102        Ok(self.obj.take().unwrap())
103    }
104
105    /// Returns the number of bytes produced by the compressor
106    ///
107    /// Note that, due to buffering, this only bears any relation to
108    /// `total_in()` after a call to `flush()`.  At that point,
109    /// `total_out() / total_in()` is the compression ratio.
110    pub fn total_out(&self) -> u64 {
111        self.data.total_out()
112    }
113
114    /// Returns the number of bytes consumed by the compressor
115    /// (e.g. the number of bytes written to this stream.)
116    pub fn total_in(&self) -> u64 {
117        self.data.total_in()
118    }
119}
120
121impl<W: Write> Write for BzEncoder<W> {
122    fn write(&mut self, data: &[u8]) -> io::Result<usize> {
123        loop {
124            self.dump()?;
125
126            let total_in = self.total_in();
127            self.data
128                .compress_vec(data, &mut self.buf, Action::Run)
129                .unwrap();
130            let written = (self.total_in() - total_in) as usize;
131
132            if written > 0 || data.len() == 0 {
133                return Ok(written);
134            }
135        }
136    }
137
138    fn flush(&mut self) -> io::Result<()> {
139        loop {
140            self.dump()?;
141            let before = self.total_out();
142            self.data
143                .compress_vec(&[], &mut self.buf, Action::Flush)
144                .unwrap();
145
146            if before == self.total_out() {
147                break;
148            }
149        }
150        self.obj.as_mut().unwrap().flush()
151    }
152}
153
#[cfg(feature = "tokio")]
impl<W> AsyncWrite for BzEncoder<W>
where
    W: AsyncWrite,
{
    /// Finish the bzip2 stream, then shut down the wrapped writer.
    ///
    /// `try_nb!` is expected to turn a `WouldBlock` from `try_finish` into
    /// `NotReady` (tokio-io convention), so this can be polled again.
    fn shutdown(&mut self) -> Poll<(), io::Error> {
        try_nb!(self.try_finish());
        AsyncWrite::shutdown(self.get_mut())
    }
}
161
162impl<W: Read + Write> Read for BzEncoder<W> {
163    fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
164        self.get_mut().read(buf)
165    }
166}
167
/// Marker impl: async reads use the pass-through `Read` impl of this type.
#[cfg(feature = "tokio")]
impl<W> AsyncRead for BzEncoder<W> where W: AsyncRead + AsyncWrite {}
170
171impl<W: Write> Drop for BzEncoder<W> {
172    fn drop(&mut self) {
173        if self.obj.is_some() {
174            let _ = self.try_finish();
175        }
176    }
177}
178
179impl<W: Write> BzDecoder<W> {
180    /// Create a new decoding stream which will decompress all data written
181    /// to it into `obj`.
182    pub fn new(obj: W) -> BzDecoder<W> {
183        BzDecoder {
184            data: Decompress::new(false),
185            obj: Some(obj),
186            buf: Vec::with_capacity(32 * 1024),
187            done: false,
188        }
189    }
190
191    /// Acquires a reference to the underlying writer.
192    pub fn get_ref(&self) -> &W {
193        self.obj.as_ref().unwrap()
194    }
195
196    /// Acquires a mutable reference to the underlying writer.
197    ///
198    /// Note that mutating the output/input state of the stream may corrupt this
199    /// object, so care must be taken when using this method.
200    pub fn get_mut(&mut self) -> &mut W {
201        self.obj.as_mut().unwrap()
202    }
203
204    fn dump(&mut self) -> io::Result<()> {
205        while self.buf.len() > 0 {
206            let n = match self.obj.as_mut().unwrap().write(&self.buf) {
207                Ok(n) => n,
208                Err(ref err) if err.kind() == io::ErrorKind::Interrupted => continue,
209                Err(err) => return Err(err),
210            };
211            self.buf.drain(..n);
212        }
213        Ok(())
214    }
215
216    /// Attempt to finish this output stream, writing out final chunks of data.
217    ///
218    /// Note that this function can only be used once data has finished being
219    /// written to the output stream. After this function is called then further
220    /// calls to `write` may result in a panic.
221    ///
222    /// # Panics
223    ///
224    /// Attempts to write data to this stream may result in a panic after this
225    /// function is called.
226    pub fn try_finish(&mut self) -> io::Result<()> {
227        while !self.done {
228            self.write(&[])?;
229        }
230        self.dump()
231    }
232
233    /// Unwrap the underlying writer, finishing the compression stream.
234    ///
235    /// Note that this function may not be suitable to call in a situation where
236    /// the underlying stream is an asynchronous I/O stream. To finish a stream
237    /// the `try_finish` (or `shutdown`) method should be used instead. To
238    /// re-acquire ownership of a stream it is safe to call this method after
239    /// `try_finish` or `shutdown` has returned `Ok`.
240    pub fn finish(&mut self) -> io::Result<W> {
241        self.try_finish()?;
242        Ok(self.obj.take().unwrap())
243    }
244
245    /// Returns the number of bytes produced by the decompressor
246    ///
247    /// Note that, due to buffering, this only bears any relation to
248    /// `total_in()` after a call to `flush()`.  At that point,
249    /// `total_in() / total_out()` is the compression ratio.
250    pub fn total_out(&self) -> u64 {
251        self.data.total_out()
252    }
253
254    /// Returns the number of bytes consumed by the decompressor
255    /// (e.g. the number of bytes written to this stream.)
256    pub fn total_in(&self) -> u64 {
257        self.data.total_in()
258    }
259}
260
261impl<W: Write> Write for BzDecoder<W> {
262    fn write(&mut self, data: &[u8]) -> io::Result<usize> {
263        if self.done {
264            return Ok(0);
265        }
266        loop {
267            self.dump()?;
268
269            let before = self.total_in();
270            let res = self.data.decompress_vec(data, &mut self.buf);
271            let written = (self.total_in() - before) as usize;
272
273            let res = res.map_err(|e| io::Error::new(io::ErrorKind::InvalidInput, e))?;
274
275            if res == Status::StreamEnd {
276                self.done = true;
277            }
278            if written > 0 || data.len() == 0 || self.done {
279                return Ok(written);
280            }
281        }
282    }
283
284    fn flush(&mut self) -> io::Result<()> {
285        self.dump()?;
286        self.obj.as_mut().unwrap().flush()
287    }
288}
289
#[cfg(feature = "tokio")]
impl<W> AsyncWrite for BzDecoder<W>
where
    W: AsyncWrite,
{
    /// Finish the decompression stream, then shut down the wrapped writer.
    ///
    /// `try_nb!` is expected to turn a `WouldBlock` from `try_finish` into
    /// `NotReady` (tokio-io convention), so this can be polled again.
    fn shutdown(&mut self) -> Poll<(), io::Error> {
        try_nb!(self.try_finish());
        AsyncWrite::shutdown(self.get_mut())
    }
}
297
298impl<W: Read + Write> Read for BzDecoder<W> {
299    fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
300        self.get_mut().read(buf)
301    }
302}
303
/// Marker impl: async reads use the pass-through `Read` impl of this type.
#[cfg(feature = "tokio")]
impl<W> AsyncRead for BzDecoder<W> where W: AsyncRead + AsyncWrite {}
306
307impl<W: Write> Drop for BzDecoder<W> {
308    fn drop(&mut self) {
309        if self.obj.is_some() {
310            let _ = self.try_finish();
311        }
312    }
313}
314
#[cfg(test)]
mod tests {
    use super::{BzDecoder, BzEncoder};
    use partial_io::{GenInterrupted, PartialWithErrors, PartialWrite};
    use std::io::prelude::*;
    use std::iter::repeat;

    /// A short prefix followed by a large repetitive payload survives an
    /// encoder -> decoder round trip byte-for-byte.
    #[test]
    fn smoke() {
        let decoder = BzDecoder::new(Vec::new());
        let mut encoder = BzEncoder::new(decoder, ::Compression::default());
        encoder.write_all(b"12834").unwrap();
        let payload: String = repeat("12345").take(100000).collect();
        encoder.write_all(payload.as_bytes()).unwrap();

        let decoded = encoder.finish().unwrap().finish().unwrap();

        assert_eq!(&decoded[0..5], b"12834");
        assert_eq!(decoded.len(), 500005);
        let expected = format!("12834{}", payload);
        assert!(expected.as_bytes() == &*decoded);
    }

    /// An empty write still yields a valid (empty) stream.
    #[test]
    fn write_empty() {
        let decoder = BzDecoder::new(Vec::new());
        let mut encoder = BzEncoder::new(decoder, ::Compression::default());
        encoder.write(b"").unwrap();

        let decoded = encoder.finish().unwrap().finish().unwrap();

        assert_eq!(&decoded[..], b"");
    }

    /// Arbitrary input round-trips unchanged.
    #[test]
    fn qc() {
        fn roundtrips(input: Vec<u8>) -> bool {
            let decoder = BzDecoder::new(Vec::new());
            let mut encoder = BzEncoder::new(decoder, ::Compression::default());
            encoder.write_all(&input).unwrap();
            let output = encoder.finish().unwrap().finish().unwrap();
            input == output
        }

        ::quickcheck::quickcheck(roundtrips as fn(_) -> _);
    }

    /// Arbitrary input round-trips unchanged even when both the encoder's and
    /// the decoder's sinks accept short writes and inject `Interrupted`.
    #[test]
    fn qc_partial() {
        fn roundtrips(
            input: Vec<u8>,
            encode_ops: PartialWithErrors<GenInterrupted>,
            decode_ops: PartialWithErrors<GenInterrupted>,
        ) -> bool {
            let decoder = BzDecoder::new(PartialWrite::new(Vec::new(), decode_ops));
            let encoder_sink = PartialWrite::new(decoder, encode_ops);
            let mut encoder = BzEncoder::new(encoder_sink, ::Compression::default());
            encoder.write_all(&input).unwrap();
            let output = encoder
                .finish()
                .unwrap()
                .into_inner()
                .finish()
                .unwrap()
                .into_inner();
            input == output
        }

        quickcheck6::quickcheck(roundtrips as fn(_, _, _) -> _);
    }
}