1use std::io;
4use std::io::prelude::*;
5
6#[cfg(feature = "tokio")]
7use futures::Poll;
8#[cfg(feature = "tokio")]
9use tokio_io::{AsyncRead, AsyncWrite};
10
11use {Action, Compress, Compression, Decompress, Status};
12
13pub struct BzEncoder<W: Write> {
16 data: Compress,
17 obj: Option<W>,
18 buf: Vec<u8>,
19 done: bool,
20}
21
22pub struct BzDecoder<W: Write> {
25 data: Decompress,
26 obj: Option<W>,
27 buf: Vec<u8>,
28 done: bool,
29}
30
31impl<W: Write> BzEncoder<W> {
32 pub fn new(obj: W, level: Compression) -> BzEncoder<W> {
35 BzEncoder {
36 data: Compress::new(level, 30),
37 obj: Some(obj),
38 buf: Vec::with_capacity(32 * 1024),
39 done: false,
40 }
41 }
42
43 fn dump(&mut self) -> io::Result<()> {
44 while self.buf.len() > 0 {
45 let n = match self.obj.as_mut().unwrap().write(&self.buf) {
46 Ok(n) => n,
47 Err(ref err) if err.kind() == io::ErrorKind::Interrupted => continue,
48 Err(err) => return Err(err),
49 };
50 self.buf.drain(..n);
51 }
52 Ok(())
53 }
54
55 pub fn get_ref(&self) -> &W {
57 self.obj.as_ref().unwrap()
58 }
59
60 pub fn get_mut(&mut self) -> &mut W {
65 self.obj.as_mut().unwrap()
66 }
67
68 pub fn try_finish(&mut self) -> io::Result<()> {
79 while !self.done {
80 self.dump()?;
81 let res = self.data.compress_vec(&[], &mut self.buf, Action::Finish);
82 if res == Ok(Status::StreamEnd) {
83 self.done = true;
84 break;
85 }
86 }
87 self.dump()
88 }
89
90 pub fn finish(mut self) -> io::Result<W> {
101 self.try_finish()?;
102 Ok(self.obj.take().unwrap())
103 }
104
105 pub fn total_out(&self) -> u64 {
111 self.data.total_out()
112 }
113
114 pub fn total_in(&self) -> u64 {
117 self.data.total_in()
118 }
119}
120
121impl<W: Write> Write for BzEncoder<W> {
122 fn write(&mut self, data: &[u8]) -> io::Result<usize> {
123 loop {
124 self.dump()?;
125
126 let total_in = self.total_in();
127 self.data
128 .compress_vec(data, &mut self.buf, Action::Run)
129 .unwrap();
130 let written = (self.total_in() - total_in) as usize;
131
132 if written > 0 || data.len() == 0 {
133 return Ok(written);
134 }
135 }
136 }
137
138 fn flush(&mut self) -> io::Result<()> {
139 loop {
140 self.dump()?;
141 let before = self.total_out();
142 self.data
143 .compress_vec(&[], &mut self.buf, Action::Flush)
144 .unwrap();
145
146 if before == self.total_out() {
147 break;
148 }
149 }
150 self.obj.as_mut().unwrap().flush()
151 }
152}
153
#[cfg(feature = "tokio")]
impl<W: AsyncWrite> AsyncWrite for BzEncoder<W> {
    /// Finishes the compressed stream, then shuts down the underlying writer.
    ///
    /// NOTE(review): `try_nb!` (tokio-io) presumably turns a `WouldBlock`
    /// error from `try_finish` into `Async::NotReady` — confirm in its docs.
    fn shutdown(&mut self) -> Poll<(), io::Error> {
        try_nb!(self.try_finish());
        <W as AsyncWrite>::shutdown(self.get_mut())
    }
}
161
162impl<W: Read + Write> Read for BzEncoder<W> {
163 fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
164 self.get_mut().read(buf)
165 }
166}
167
// Marker impl: reads are delegated through the `Read` impl above.
#[cfg(feature = "tokio")]
impl<W> AsyncRead for BzEncoder<W> where W: AsyncRead + AsyncWrite {}
170
171impl<W: Write> Drop for BzEncoder<W> {
172 fn drop(&mut self) {
173 if self.obj.is_some() {
174 let _ = self.try_finish();
175 }
176 }
177}
178
179impl<W: Write> BzDecoder<W> {
180 pub fn new(obj: W) -> BzDecoder<W> {
183 BzDecoder {
184 data: Decompress::new(false),
185 obj: Some(obj),
186 buf: Vec::with_capacity(32 * 1024),
187 done: false,
188 }
189 }
190
191 pub fn get_ref(&self) -> &W {
193 self.obj.as_ref().unwrap()
194 }
195
196 pub fn get_mut(&mut self) -> &mut W {
201 self.obj.as_mut().unwrap()
202 }
203
204 fn dump(&mut self) -> io::Result<()> {
205 while self.buf.len() > 0 {
206 let n = match self.obj.as_mut().unwrap().write(&self.buf) {
207 Ok(n) => n,
208 Err(ref err) if err.kind() == io::ErrorKind::Interrupted => continue,
209 Err(err) => return Err(err),
210 };
211 self.buf.drain(..n);
212 }
213 Ok(())
214 }
215
216 pub fn try_finish(&mut self) -> io::Result<()> {
227 while !self.done {
228 self.write(&[])?;
229 }
230 self.dump()
231 }
232
233 pub fn finish(&mut self) -> io::Result<W> {
241 self.try_finish()?;
242 Ok(self.obj.take().unwrap())
243 }
244
245 pub fn total_out(&self) -> u64 {
251 self.data.total_out()
252 }
253
254 pub fn total_in(&self) -> u64 {
257 self.data.total_in()
258 }
259}
260
261impl<W: Write> Write for BzDecoder<W> {
262 fn write(&mut self, data: &[u8]) -> io::Result<usize> {
263 if self.done {
264 return Ok(0);
265 }
266 loop {
267 self.dump()?;
268
269 let before = self.total_in();
270 let res = self.data.decompress_vec(data, &mut self.buf);
271 let written = (self.total_in() - before) as usize;
272
273 let res = res.map_err(|e| io::Error::new(io::ErrorKind::InvalidInput, e))?;
274
275 if res == Status::StreamEnd {
276 self.done = true;
277 }
278 if written > 0 || data.len() == 0 || self.done {
279 return Ok(written);
280 }
281 }
282 }
283
284 fn flush(&mut self) -> io::Result<()> {
285 self.dump()?;
286 self.obj.as_mut().unwrap().flush()
287 }
288}
289
#[cfg(feature = "tokio")]
impl<W: AsyncWrite> AsyncWrite for BzDecoder<W> {
    /// Finishes decompression, then shuts down the underlying writer.
    ///
    /// NOTE(review): `try_nb!` (tokio-io) presumably turns a `WouldBlock`
    /// error from `try_finish` into `Async::NotReady` — confirm in its docs.
    fn shutdown(&mut self) -> Poll<(), io::Error> {
        try_nb!(self.try_finish());
        <W as AsyncWrite>::shutdown(self.get_mut())
    }
}
297
298impl<W: Read + Write> Read for BzDecoder<W> {
299 fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
300 self.get_mut().read(buf)
301 }
302}
303
// Marker impl: reads are delegated through the `Read` impl above.
#[cfg(feature = "tokio")]
impl<W> AsyncRead for BzDecoder<W> where W: AsyncRead + AsyncWrite {}
306
307impl<W: Write> Drop for BzDecoder<W> {
308 fn drop(&mut self) {
309 if self.obj.is_some() {
310 let _ = self.try_finish();
311 }
312 }
313}
314
#[cfg(test)]
mod tests {
    use super::{BzDecoder, BzEncoder};
    use partial_io::{GenInterrupted, PartialWithErrors, PartialWrite};
    use std::io::prelude::*;
    use std::iter::repeat;

    // Every test chains an encoder into a decoder, so the bytes written to
    // the encoder must come back out of the decoder's `Vec` sink unchanged.

    #[test]
    fn smoke() {
        let decoder = BzDecoder::new(Vec::new());
        let mut encoder = BzEncoder::new(decoder, ::Compression::default());
        encoder.write_all(b"12834").unwrap();
        let body: String = repeat("12345").take(100000).collect();
        encoder.write_all(body.as_bytes()).unwrap();

        let output = encoder.finish().unwrap().finish().unwrap();
        assert_eq!(&output[0..5], b"12834");
        assert_eq!(output.len(), 500005);
        let expected = format!("12834{}", body);
        assert!(expected.as_bytes() == &output[..]);
    }

    #[test]
    fn write_empty() {
        let mut encoder = BzEncoder::new(BzDecoder::new(Vec::new()), ::Compression::default());
        encoder.write(b"").unwrap();
        let output = encoder.finish().unwrap().finish().unwrap();
        assert!(output.is_empty());
    }

    #[test]
    fn qc() {
        fn round_trips(input: Vec<u8>) -> bool {
            let mut encoder =
                BzEncoder::new(BzDecoder::new(Vec::new()), ::Compression::default());
            encoder.write_all(&input).unwrap();
            let output = encoder.finish().unwrap().finish().unwrap();
            output == input
        }

        ::quickcheck::quickcheck(round_trips as fn(_) -> _);
    }

    #[test]
    fn qc_partial() {
        fn round_trips_partially(
            input: Vec<u8>,
            encode_ops: PartialWithErrors<GenInterrupted>,
            decode_ops: PartialWithErrors<GenInterrupted>,
        ) -> bool {
            let decoder = BzDecoder::new(PartialWrite::new(Vec::new(), decode_ops));
            let mut encoder =
                BzEncoder::new(PartialWrite::new(decoder, encode_ops), ::Compression::default());
            encoder.write_all(&input).unwrap();

            let mut decoder = encoder.finish().unwrap().into_inner();
            let output = decoder.finish().unwrap().into_inner();
            output == input
        }

        quickcheck6::quickcheck(round_trips_partially as fn(_, _, _) -> _);
    }
}