1use std::io;
4use std::io::prelude::*;
5
6use crc32fast::Hasher;
7
8pub struct Crc32Reader<R> {
10 inner: R,
11 hasher: Hasher,
12 check: u32,
13 enabled: bool,
16}
17
18impl<R> Crc32Reader<R> {
19 pub(crate) fn new(inner: R, checksum: u32, ae2_encrypted: bool) -> Crc32Reader<R> {
22 Crc32Reader {
23 inner,
24 hasher: Hasher::new(),
25 check: checksum,
26 enabled: !ae2_encrypted,
27 }
28 }
29
30 fn check_matches(&self) -> bool {
31 self.check == self.hasher.clone().finalize()
32 }
33
34 pub fn into_inner(self) -> R {
35 self.inner
36 }
37}
38
/// Builds the error reported when the computed CRC-32 does not match the
/// expected one. Marked `#[cold]` because it is only reached on the failure
/// path.
#[cold]
fn invalid_checksum() -> io::Error {
    let kind = io::ErrorKind::InvalidData;
    let message = "Invalid checksum";
    io::Error::new(kind, message)
}
43
44impl<R: Read> Read for Crc32Reader<R> {
45 fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
46 let count = self.inner.read(buf)?;
47
48 if self.enabled {
49 if count == 0 && !buf.is_empty() && !self.check_matches() {
50 return Err(invalid_checksum());
51 }
52 self.hasher.update(&buf[..count]);
53 }
54 Ok(count)
55 }
56
57 fn read_to_end(&mut self, buf: &mut Vec<u8>) -> io::Result<usize> {
58 let start = buf.len();
59 let n = self.inner.read_to_end(buf)?;
60
61 if self.enabled {
62 self.hasher.update(&buf[start..]);
63 if !self.check_matches() {
64 return Err(invalid_checksum());
65 }
66 }
67
68 Ok(n)
69 }
70
71 fn read_to_string(&mut self, buf: &mut String) -> io::Result<usize> {
72 let start = buf.len();
73 let n = self.inner.read_to_string(buf)?;
74
75 if self.enabled {
76 self.hasher.update(&buf.as_bytes()[start..]);
77 if !self.check_matches() {
78 return Err(invalid_checksum());
79 }
80 }
81
82 Ok(n)
83 }
84}
85
#[cfg(test)]
mod test {
    use super::*;

    /// An empty stream passes with an expected checksum of 0 and fails with
    /// any other expected value.
    #[test]
    fn test_empty_reader() {
        let data: &[u8] = b"";
        let mut buf = [0; 1];

        let mut matching = Crc32Reader::new(data, 0, false);
        let read = matching.read(&mut buf).unwrap();
        assert_eq!(read, 0);

        let mut mismatching = Crc32Reader::new(data, 1, false);
        let err = mismatching.read(&mut buf).unwrap_err();
        assert!(err.to_string().contains("Invalid checksum"));
    }

    /// Reading one byte at a time yields every byte and then a clean,
    /// repeatable end-of-stream.
    #[test]
    fn test_byte_by_byte() {
        let data: &[u8] = b"1234";
        let mut buf = [0; 1];

        let mut reader = Crc32Reader::new(data, 0x9be3e0a3, false);
        for _ in 0..4 {
            assert_eq!(reader.read(&mut buf).unwrap(), 1);
        }
        // Reading past the end keeps returning 0 without an error.
        for _ in 0..2 {
            assert_eq!(reader.read(&mut buf).unwrap(), 0);
        }
    }

    /// A read into a zero-length buffer must not be mistaken for the end of
    /// the stream.
    #[test]
    fn test_zero_read() {
        let data: &[u8] = b"1234";
        let mut buf = [0; 5];

        let mut reader = Crc32Reader::new(data, 0x9be3e0a3, false);
        let empty_read = reader.read(&mut buf[..0]).unwrap();
        assert_eq!(empty_read, 0);
        let full_read = reader.read(&mut buf).unwrap();
        assert_eq!(full_read, 4);
    }
}