1use std::io::{Read, Write};
26use std::str::FromStr;
27
28use anyhow::Error;
29use flate2::Compression;
30use flate2::read::DeflateDecoder;
31use flate2::write::DeflateEncoder;
32
33use crate::error::{DecodeError, Error as AvroError};
34use crate::types::{ToAvro, Value};
35
/// The compression codec applied to a block of encoded data.
///
/// Every variant is a plain unit variant, so equality is total and the type
/// derives `Eq` in addition to `PartialEq`.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum Codec {
    /// Pass-through codec: `compress` and `decompress` leave the bytes untouched.
    Null,
    /// Codec backed by `flate2`'s `DeflateEncoder` / `DeflateDecoder`.
    Deflate,
    /// Codec backed by the `snap` crate's raw format, with a 4-byte big-endian
    /// CRC32 of the uncompressed bytes appended to each compressed block.
    ///
    /// Only available when the `snappy` feature is enabled.
    #[cfg(feature = "snappy")]
    Snappy,
}
51
52impl ToAvro for Codec {
53 fn avro(self) -> Value {
54 Value::Bytes(
55 match self {
56 Codec::Null => "null",
57 Codec::Deflate => "deflate",
58 #[cfg(feature = "snappy")]
59 Codec::Snappy => "snappy",
60 }
61 .to_owned()
62 .into_bytes(),
63 )
64 }
65}
66
67impl FromStr for Codec {
68 type Err = AvroError;
69
70 fn from_str(s: &str) -> Result<Self, Self::Err> {
71 match s {
72 "null" => Ok(Codec::Null),
73 "deflate" => Ok(Codec::Deflate),
74 #[cfg(feature = "snappy")]
75 "snappy" => Ok(Codec::Snappy),
76 other => Err(DecodeError::UnrecognizedCodec(other.to_string()).into()),
77 }
78 }
79}
80
81impl Codec {
82 pub fn compress(self, stream: &mut Vec<u8>) -> Result<(), Error> {
84 match self {
85 Codec::Null => (),
86 Codec::Deflate => {
87 let mut encoder = DeflateEncoder::new(Vec::new(), Compression::default());
88 encoder.write_all(stream)?;
89 *stream = encoder.finish()?;
90 }
91 #[cfg(feature = "snappy")]
92 Codec::Snappy => {
93 use byteorder::ByteOrder;
94
95 let mut encoded: Vec<u8> = vec![0; snap::raw::max_compress_len(stream.len())];
96 let compressed_size =
97 snap::raw::Encoder::new().compress(&stream[..], &mut encoded[..])?;
98
99 let crc = {
100 let mut hasher = crc32fast::Hasher::new();
101 hasher.update(stream);
102 hasher.finalize()
103 };
104 byteorder::BigEndian::write_u32(&mut encoded[compressed_size..], crc);
105 encoded.truncate(compressed_size + 4);
106
107 *stream = encoded;
108 }
109 };
110
111 Ok(())
112 }
113
114 pub fn decompress(self, stream: &mut Vec<u8>) -> Result<(), AvroError> {
116 match self {
117 Codec::Null => (),
118 Codec::Deflate => {
119 let mut decoded = Vec::new();
120 {
121 let mut decoder = DeflateDecoder::new(&**stream);
122 decoder.read_to_end(&mut decoded)?;
123 }
124 *stream = decoded;
125 }
126 #[cfg(feature = "snappy")]
127 Codec::Snappy => {
128 use byteorder::ByteOrder;
129
130 let decompressed_size = snap::raw::decompress_len(&stream[..stream.len() - 4])
131 .map_err(std::io::Error::from)?;
132 let mut decoded = vec![0; decompressed_size];
133 snap::raw::Decoder::new()
134 .decompress(&stream[..stream.len() - 4], &mut decoded[..])
135 .map_err(std::io::Error::from)?;
136
137 let expected_crc = byteorder::BigEndian::read_u32(&stream[stream.len() - 4..]);
138 let actual_crc = {
139 let mut hasher = crc32fast::Hasher::new();
140 hasher.update(&decoded);
141 hasher.finalize()
142 };
143
144 if expected_crc != actual_crc {
145 return Err(DecodeError::BadSnappyChecksum {
146 expected: expected_crc,
147 actual: actual_crc,
148 }
149 .into());
150 }
151 *stream = decoded;
152 }
153 };
154
155 Ok(())
156 }
157}
158
#[cfg(test)]
mod tests {
    use super::*;

    /// Sample payload shared by every round-trip test below. It repeats the
    /// same phrase three times, and the compressing codecs are asserted to
    /// shrink it.
    static INPUT: &[u8] = b"theanswertolifetheuniverseandeverythingis42theanswertolifetheuniverseandeverythingis4theanswertolifetheuniverseandeverythingis2";

    /// The null codec must hand the bytes back unchanged in both directions.
    #[mz_ore::test]
    fn null_compress_and_decompress() {
        let codec = Codec::Null;
        let mut buf = Vec::from(INPUT);

        codec.compress(&mut buf).unwrap();
        assert_eq!(INPUT, &buf[..]);

        codec.decompress(&mut buf).unwrap();
        assert_eq!(INPUT, &buf[..]);
    }

    /// Deflate must shrink the sample payload and restore it exactly.
    // NOTE(review): skipped under Miri, presumably because the deflate backend
    // cannot run there - confirm.
    #[mz_ore::test]
    #[cfg_attr(miri, ignore)]
    fn deflate_compress_and_decompress() {
        let codec = Codec::Deflate;
        let mut buf = Vec::from(INPUT);

        codec.compress(&mut buf).unwrap();
        assert_ne!(INPUT, &buf[..]);
        assert!(INPUT.len() > buf.len());

        codec.decompress(&mut buf).unwrap();
        assert_eq!(INPUT, &buf[..]);
    }

    /// Snappy must shrink the sample payload and restore it exactly.
    #[cfg(feature = "snappy")]
    #[mz_ore::test]
    fn snappy_compress_and_decompress() {
        let codec = Codec::Snappy;
        let mut buf = Vec::from(INPUT);

        codec.compress(&mut buf).unwrap();
        assert_ne!(INPUT, &buf[..]);
        assert!(INPUT.len() > buf.len());

        codec.decompress(&mut buf).unwrap();
        assert_eq!(INPUT, &buf[..]);
    }
}