1use std::io;
4use std::io::prelude::*;
5
6#[cfg(feature = "tokio")]
7use futures::Poll;
8#[cfg(feature = "tokio")]
9use tokio_io::{AsyncRead, AsyncWrite};
10
11use {Action, Compress, Compression, Decompress, Status};
12
/// A bz2 encoder (compressor) layered over a buffered reader.
///
/// Reading from this type pulls data from the wrapped reader, feeds it to the
/// compressor, and hands back the compressed bytes.
pub struct BzEncoder<R> {
    // The wrapped object that supplies the input to be compressed.
    obj: R,
    // Compression state that turns input bytes into output bytes.
    data: Compress,
    // Set once the compressor reports `Status::StreamEnd`; every later
    // `read` call then returns `Ok(0)` without touching `obj`.
    done: bool,
}
22
/// A bz2 decoder (decompressor) layered over a buffered reader.
///
/// Reading from this type pulls compressed data from the wrapped reader and
/// hands back the decompressed bytes.
pub struct BzDecoder<R> {
    // The wrapped object that supplies the compressed input.
    obj: R,
    // Decompression state; replaced with a fresh instance when a new stream
    // starts in multi-stream mode.
    data: Decompress,
    // Set once the decompressor reports `Status::StreamEnd`.
    done: bool,
    // When true, reaching the end of one stream does not end reading: if the
    // wrapped reader still has input, decoding restarts on the next stream.
    multi: bool,
}
33
impl<R: BufRead> BzEncoder<R> {
    /// Creates a new encoder that compresses the data read from `r`.
    ///
    /// `level` is forwarded unchanged to `Compress::new`.
    pub fn new(r: R, level: Compression) -> BzEncoder<R> {
        BzEncoder {
            obj: r,
            // NOTE(review): the literal `30` is presumably bzip2's work
            // factor — confirm against `Compress::new`.
            data: Compress::new(level, 30),
            done: false,
        }
    }
}
45
impl<R> BzEncoder<R> {
    /// Returns a shared reference to the wrapped object.
    pub fn get_ref(&self) -> &R {
        &self.obj
    }

    /// Returns a mutable reference to the wrapped object.
    ///
    /// Anything read from or changed on the wrapped object through this
    /// reference bypasses the encoder.
    pub fn get_mut(&mut self) -> &mut R {
        &mut self.obj
    }

    /// Consumes the encoder and returns the wrapped object.
    pub fn into_inner(self) -> R {
        self.obj
    }

    /// Returns the compressor's running output counter, i.e. the number of
    /// compressed bytes it has produced so far.
    pub fn total_out(&self) -> u64 {
        self.data.total_out()
    }

    /// Returns the compressor's running input counter, i.e. the number of
    /// input bytes it has consumed so far.
    pub fn total_in(&self) -> u64 {
        self.data.total_in()
    }
}
84
85impl<R: BufRead> Read for BzEncoder<R> {
86 fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
87 if self.done {
88 return Ok(0);
89 }
90 loop {
91 let (read, consumed, eof, ret);
92 {
93 let input = self.obj.fill_buf()?;
94 eof = input.is_empty();
95 let before_out = self.data.total_out();
96 let before_in = self.data.total_in();
97 let action = if eof { Action::Finish } else { Action::Run };
98 ret = self.data.compress(input, buf, action);
99 read = (self.data.total_out() - before_out) as usize;
100 consumed = (self.data.total_in() - before_in) as usize;
101 }
102 self.obj.consume(consumed);
103
104 let ret = ret.unwrap();
107
108 if read == 0 && !eof && buf.len() > 0 {
112 continue;
113 }
114 if ret == Status::StreamEnd {
115 self.done = true;
116 }
117 return Ok(read);
118 }
119 }
120}
121
// Only compiled with the `tokio` feature. The impl body is empty, so the
// encoder relies entirely on `AsyncRead`'s provided methods.
#[cfg(feature = "tokio")]
impl<R: AsyncRead + BufRead> AsyncRead for BzEncoder<R> {}
124
// Pass-through `Write` implementation: bytes go straight to the wrapped
// object and are NOT compressed by this impl.
impl<W: Write> Write for BzEncoder<W> {
    /// Forwards `buf` unchanged to the wrapped object's `write`.
    fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
        self.get_mut().write(buf)
    }

    /// Forwards to the wrapped object's `flush`.
    fn flush(&mut self) -> io::Result<()> {
        self.get_mut().flush()
    }
}
134
// Only compiled with the `tokio` feature; shutdown is delegated to the
// wrapped object.
#[cfg(feature = "tokio")]
impl<R: AsyncWrite> AsyncWrite for BzEncoder<R> {
    /// Forwards to the wrapped object's `shutdown`.
    fn shutdown(&mut self) -> Poll<(), io::Error> {
        self.get_mut().shutdown()
    }
}
141
impl<R: BufRead> BzDecoder<R> {
    /// Creates a new decoder that decompresses the data read from `r`.
    ///
    /// The decoder starts in single-stream mode (`multi` is `false`), so
    /// reading stops after the first stream ends.
    pub fn new(r: R) -> BzDecoder<R> {
        BzDecoder {
            obj: r,
            // NOTE(review): the `false` argument is presumably bzip2's
            // "small" (low-memory) switch — confirm against `Decompress::new`.
            data: Decompress::new(false),
            done: false,
            multi: false,
        }
    }

    // Builder-style switch for multi-stream mode: stores `flag` in `multi`
    // and hands the decoder back. Private; used by `MultiBzDecoder::new`.
    fn multi(mut self, flag: bool) -> BzDecoder<R> {
        self.multi = flag;
        self
    }
}
159
impl<R> BzDecoder<R> {
    /// Returns a shared reference to the wrapped object.
    pub fn get_ref(&self) -> &R {
        &self.obj
    }

    /// Returns a mutable reference to the wrapped object.
    ///
    /// Anything read from or changed on the wrapped object through this
    /// reference bypasses the decoder.
    pub fn get_mut(&mut self) -> &mut R {
        &mut self.obj
    }

    /// Consumes the decoder and returns the wrapped object.
    pub fn into_inner(self) -> R {
        self.obj
    }

    /// Returns the decompressor's running input counter, i.e. the number of
    /// compressed bytes it has consumed so far.
    ///
    /// NOTE(review): in multi-stream mode `read` replaces the decompressor
    /// when a new stream begins, so this presumably restarts per stream —
    /// confirm.
    pub fn total_in(&self) -> u64 {
        self.data.total_in()
    }

    /// Returns the decompressor's running output counter, i.e. the number of
    /// decompressed bytes it has produced so far.
    ///
    /// NOTE(review): in multi-stream mode `read` replaces the decompressor
    /// when a new stream begins, so this presumably restarts per stream —
    /// confirm.
    pub fn total_out(&self) -> u64 {
        self.data.total_out()
    }
}
192
193impl<R: BufRead> Read for BzDecoder<R> {
194 fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
195 loop {
196 if self.done && !self.multi {
197 return Ok(0);
198 }
199 let (read, consumed, remaining, ret);
200 {
201 let input = self.obj.fill_buf()?;
202 if self.done {
203 assert!(self.multi);
204 if input.is_empty() {
205 return Ok(0);
207 } else {
208 self.data = Decompress::new(false);
210 self.done = false;
211 }
212 }
213 let before_out = self.data.total_out();
214 let before_in = self.data.total_in();
215 ret = self.data.decompress(input, buf);
216 read = (self.data.total_out() - before_out) as usize;
217 consumed = (self.data.total_in() - before_in) as usize;
218 remaining = input.len() - consumed;
219 }
220 self.obj.consume(consumed);
221
222 let ret = ret.map_err(|e| io::Error::new(io::ErrorKind::InvalidInput, e))?;
223 if ret == Status::StreamEnd {
224 self.done = true;
225 } else if consumed == 0 && remaining == 0 && read == 0 {
226 return Err(io::Error::new(
227 io::ErrorKind::UnexpectedEof,
228 "decompression not finished but EOF reached",
229 ));
230 }
231
232 if read > 0 || buf.len() == 0 {
233 return Ok(read);
234 }
235 }
236 }
237}
238
// Only compiled with the `tokio` feature. The impl body is empty, so the
// decoder relies entirely on `AsyncRead`'s provided methods.
#[cfg(feature = "tokio")]
impl<R: AsyncRead + BufRead> AsyncRead for BzDecoder<R> {}
241
// Pass-through `Write` implementation: bytes go straight to the wrapped
// object and are NOT decompressed by this impl.
impl<W: Write> Write for BzDecoder<W> {
    /// Forwards `buf` unchanged to the wrapped object's `write`.
    fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
        self.get_mut().write(buf)
    }

    /// Forwards to the wrapped object's `flush`.
    fn flush(&mut self) -> io::Result<()> {
        self.get_mut().flush()
    }
}
251
// Only compiled with the `tokio` feature; shutdown is delegated to the
// wrapped object.
#[cfg(feature = "tokio")]
impl<R: AsyncWrite> AsyncWrite for BzDecoder<R> {
    /// Forwards to the wrapped object's `shutdown`.
    fn shutdown(&mut self) -> Poll<(), io::Error> {
        self.get_mut().shutdown()
    }
}
258
/// A bz2 decoder that keeps decoding across consecutive streams.
///
/// This is a thin wrapper around a `BzDecoder` that has multi-stream mode
/// switched on, so reading continues with the next stream until the wrapped
/// reader runs out of input.
pub struct MultiBzDecoder<R>(BzDecoder<R>);
264
impl<R: BufRead> MultiBzDecoder<R> {
    /// Creates a new multi-stream decoder that decompresses the data read
    /// from `r`.
    pub fn new(r: R) -> MultiBzDecoder<R> {
        // Same as a plain `BzDecoder`, but with multi-stream mode enabled.
        MultiBzDecoder(BzDecoder::new(r).multi(true))
    }
}
272
impl<R> MultiBzDecoder<R> {
    /// Returns a shared reference to the wrapped object.
    pub fn get_ref(&self) -> &R {
        self.0.get_ref()
    }

    /// Returns a mutable reference to the wrapped object.
    ///
    /// Anything read from or changed on the wrapped object through this
    /// reference bypasses the decoder.
    pub fn get_mut(&mut self) -> &mut R {
        self.0.get_mut()
    }

    /// Consumes the decoder and returns the wrapped object.
    pub fn into_inner(self) -> R {
        self.0.into_inner()
    }
}
292
impl<R: BufRead> Read for MultiBzDecoder<R> {
    /// Delegates to the inner `BzDecoder`, which was built with multi-stream
    /// mode enabled; returns whatever that call returns.
    fn read(&mut self, into: &mut [u8]) -> io::Result<usize> {
        self.0.read(into)
    }
}
298
// Only compiled with the `tokio` feature. The impl body is empty, so the
// decoder relies entirely on `AsyncRead`'s provided methods.
#[cfg(feature = "tokio")]
impl<R: AsyncRead + BufRead> AsyncRead for MultiBzDecoder<R> {}
301
302impl<R: BufRead + Write> Write for MultiBzDecoder<R> {
303 fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
304 self.get_mut().write(buf)
305 }
306
307 fn flush(&mut self) -> io::Result<()> {
308 self.get_mut().flush()
309 }
310}
311
// Only compiled with the `tokio` feature; shutdown is delegated to the
// wrapped object.
#[cfg(feature = "tokio")]
impl<R: AsyncWrite + BufRead> AsyncWrite for MultiBzDecoder<R> {
    /// Forwards to the wrapped object's `shutdown`.
    fn shutdown(&mut self) -> Poll<(), io::Error> {
        self.get_mut().shutdown()
    }
}
318
#[cfg(test)]
mod tests {
    use super::MultiBzDecoder;
    use std::io::{BufReader, Read};

    // Decodes the `bug_61.bz2` fixture through a `MultiBzDecoder` fed by a
    // `BufReader` with an 8192-byte buffer, and checks that the output
    // matches `bug_61.raw` in both length and content.
    // NOTE(review): presumably a regression test for issue #61 — confirm.
    #[test]
    fn bug_61() {
        let compressed_bytes = include_bytes!("../tests/bug_61.bz2");
        let uncompressed_bytes = include_bytes!("../tests/bug_61.raw");
        let reader = BufReader::with_capacity(8192, compressed_bytes.as_ref());

        let mut d = MultiBzDecoder::new(reader);
        let mut data = Vec::new();

        // `read_to_end` reports how many bytes were appended to `data`.
        assert_eq!(d.read_to_end(&mut data).unwrap(), uncompressed_bytes.len());
        assert_eq!(data, uncompressed_bytes);
    }
}