use std::fmt::Debug;
use std::io::Write;
use std::sync::Arc;

use arrow::array::{Array, RecordBatch};
use arrow::datatypes::{Fields, Schema as ArrowSchema};
use parquet::arrow::ArrowWriter;
use parquet::arrow::arrow_reader::{ParquetRecordBatchReader, ParquetRecordBatchReaderBuilder};
use parquet::basic::Encoding;
use parquet::file::properties::{EnabledStatistics, WriterProperties, WriterVersion};
use parquet::file::reader::ChunkReader;
use proptest::prelude::*;
use proptest_derive::Arbitrary;

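/// Configuration for encoding Arrow arrays as Parquet: whether dictionary
/// encoding is enabled and which compression format (and level) to use.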
#[derive(Debug, Copy, Clone, Arbitrary)]
pub struct EncodingConfig {
    pub use_dictionary: bool,
    pub compression: CompressionFormat,
}

impl Default for EncodingConfig {
    fn default() -> Self {
        EncodingConfig {
            use_dictionary: false,
            compression: CompressionFormat::default(),
        }
    }
}

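/// Supported compression formats for Parquet data. The levelled formats carry
/// a [`CompressionLevel`] constrained to the range the underlying codec accepts.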
#[derive(Default, Debug, Copy, Clone, PartialEq, Eq, Arbitrary)]
pub enum CompressionFormat {
    #[default]
    None,
    Snappy,
    Lz4,
    Brotli(CompressionLevel<1, 11, 1>),
    Zstd(CompressionLevel<1, 22, 1>),
    Gzip(CompressionLevel<1, 9, 6>),
}

impl CompressionFormat {
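    /// Parses a compression format from a string like `"snappy"` or `"zstd-3"`
    /// (case-insensitive). Unrecognized formats log an error and fall back to
    /// [`CompressionFormat::None`]; invalid levels fall back to that format's
    /// default level.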
    pub fn from_str(s: &str) -> Self {
        fn parse_level<const MIN: i32, const MAX: i32, const D: i32>(
            name: &'static str,
            val: &str,
        ) -> CompressionLevel<MIN, MAX, D> {
            match CompressionLevel::from_str(val) {
                Ok(level) => level,
                Err(err) => {
                    tracing::error!("invalid {name} compression level, err: {err}");
                    CompressionLevel::default()
                }
            }
        }

        match s.to_lowercase().as_str() {
            "" => CompressionFormat::None,
            "none" => CompressionFormat::None,
            "snappy" => CompressionFormat::Snappy,
            "lz4" => CompressionFormat::Lz4,
            other => match other.split_once('-') {
                Some(("brotli", level)) => CompressionFormat::Brotli(parse_level("brotli", level)),
                Some(("zstd", level)) => CompressionFormat::Zstd(parse_level("zstd", level)),
                Some(("gzip", level)) => CompressionFormat::Gzip(parse_level("gzip", level)),
                _ => {
                    tracing::error!("unrecognized compression format {s}, returning None");
                    CompressionFormat::None
                }
            },
        }
    }
}

impl From<CompressionFormat> for parquet::basic::Compression {
    fn from(value: CompressionFormat) -> Self {
        match value {
            CompressionFormat::None => parquet::basic::Compression::UNCOMPRESSED,
            CompressionFormat::Lz4 => parquet::basic::Compression::LZ4_RAW,
            CompressionFormat::Snappy => parquet::basic::Compression::SNAPPY,
            CompressionFormat::Brotli(level) => {
                let level: u32 = level.0.try_into().expect("known not negative");
                let level = parquet::basic::BrotliLevel::try_new(level).expect("known valid");
                parquet::basic::Compression::BROTLI(level)
            }
            CompressionFormat::Zstd(level) => {
                let level = parquet::basic::ZstdLevel::try_new(level.0).expect("known valid");
                parquet::basic::Compression::ZSTD(level)
            }
            CompressionFormat::Gzip(level) => {
                let level: u32 = level.0.try_into().expect("known not negative");
                let level = parquet::basic::GzipLevel::try_new(level).expect("known valid");
                parquet::basic::Compression::GZIP(level)
            }
        }
    }
}

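/// A compression level validated to lie within the inclusive range `[MIN, MAX]`,
/// with `DEFAULT` as the [`Default::default`] value.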
#[derive(Debug, Copy, Clone, PartialEq, Eq)]
pub struct CompressionLevel<const MIN: i32, const MAX: i32, const DEFAULT: i32>(i32);

impl<const MIN: i32, const MAX: i32, const DEFAULT: i32> Default
    for CompressionLevel<MIN, MAX, DEFAULT>
{
    fn default() -> Self {
        CompressionLevel(DEFAULT)
    }
}

impl<const MIN: i32, const MAX: i32, const DEFAULT: i32> CompressionLevel<MIN, MAX, DEFAULT> {
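    /// Constructs a level, returning the rejected value as the error if it is
    /// outside the inclusive range `[MIN, MAX]`.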
    pub const fn try_new(val: i32) -> Result<Self, i32> {
        if val < MIN {
            Err(val)
        } else if val > MAX {
            Err(val)
        } else {
            Ok(CompressionLevel(val))
        }
    }

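    /// Parses a level from a string, returning a descriptive error if the value
    /// is not an integer or is out of range.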
    pub fn from_str(s: &str) -> Result<Self, String> {
        let val = s.parse::<i32>().map_err(|e| e.to_string())?;
        Self::try_new(val).map_err(|e| e.to_string())
    }
}

impl<const MIN: i32, const MAX: i32, const DEFAULT: i32> Arbitrary
    for CompressionLevel<MIN, MAX, DEFAULT>
{
    type Parameters = ();
    type Strategy = BoxedStrategy<CompressionLevel<MIN, MAX, DEFAULT>>;

    fn arbitrary_with(_args: Self::Parameters) -> Self::Strategy {
        ({ MIN }..={ MAX }).prop_map(CompressionLevel).boxed()
    }
}

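/// Encodes the provided arrays as a single [`RecordBatch`] in a single Parquet
/// row group, writing the result to `w`. Page statistics are disabled and the
/// dictionary and compression settings come from the [`EncodingConfig`].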
pub fn encode_arrays<W: Write + Send>(
    w: &mut W,
    fields: Fields,
    arrays: Vec<Arc<dyn Array>>,
    config: &EncodingConfig,
) -> Result<(), anyhow::Error> {
    let schema = Arc::new(ArrowSchema::new(fields));
    let props = WriterProperties::builder()
        .set_dictionary_enabled(config.use_dictionary)
        .set_encoding(Encoding::PLAIN)
        .set_statistics_enabled(EnabledStatistics::None)
        .set_compression(config.compression.into())
        .set_writer_version(WriterVersion::PARQUET_2_0)
        .set_data_page_size_limit(1024 * 1024)
        .set_max_row_group_size(usize::MAX)
        .build();
    let mut writer = ArrowWriter::try_new(w, Arc::clone(&schema), Some(props))?;

    let record_batch = RecordBatch::try_new(schema, arrays)?;

    writer.write(&record_batch)?;
    writer.flush()?;
    writer.close()?;

    Ok(())
}

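/// Reads back a Parquet file written by [`encode_arrays`], returning a reader
/// that yields the whole file as one [`RecordBatch`]. Errors if the file
/// contains more than one row group.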
pub fn decode_arrays<R: ChunkReader + 'static>(
    r: R,
) -> Result<ParquetRecordBatchReader, anyhow::Error> {
    let builder = ParquetRecordBatchReaderBuilder::try_new(r)?;

    let row_groups = builder.metadata().row_groups();
    if row_groups.len() > 1 {
        anyhow::bail!("found more than 1 RowGroup")
    }
    let num_rows = row_groups
        .get(0)
        .map(|g| g.num_rows())
        .unwrap_or(1024)
        .try_into()
        .unwrap();
    let builder = builder.with_batch_size(num_rows);

    let reader = builder.build()?;
    Ok(reader)
}

#[cfg(test)]
mod tests {
    use super::*;

    #[mz_ore::test]
    fn smoketest_compression_level_parsing() {
        let cases = &[
            ("", CompressionFormat::None),
            ("none", CompressionFormat::None),
            ("snappy", CompressionFormat::Snappy),
            ("lz4", CompressionFormat::Lz4),
            ("lZ4", CompressionFormat::Lz4),
            ("gzip-1", CompressionFormat::Gzip(CompressionLevel(1))),
            ("GZIp-6", CompressionFormat::Gzip(CompressionLevel(6))),
            ("gzip-9", CompressionFormat::Gzip(CompressionLevel(9))),
            ("brotli-1", CompressionFormat::Brotli(CompressionLevel(1))),
            ("BROtli-8", CompressionFormat::Brotli(CompressionLevel(8))),
            ("brotli-11", CompressionFormat::Brotli(CompressionLevel(11))),
            ("zstd-1", CompressionFormat::Zstd(CompressionLevel(1))),
            ("zstD-10", CompressionFormat::Zstd(CompressionLevel(10))),
            ("zstd-22", CompressionFormat::Zstd(CompressionLevel(22))),
            ("foo", CompressionFormat::None),
            // Out-of-range levels fall back to the format's default level.
            ("gzip-0", CompressionFormat::Gzip(Default::default())),
            ("gzip-10", CompressionFormat::Gzip(Default::default())),
            ("brotli-0", CompressionFormat::Brotli(Default::default())),
            ("brotli-12", CompressionFormat::Brotli(Default::default())),
            ("zstd-0", CompressionFormat::Zstd(Default::default())),
            ("zstd-23", CompressionFormat::Zstd(Default::default())),
        ];
        for (s, val) in cases {
            assert_eq!(CompressionFormat::from_str(s), *val);
        }
    }
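
    /// An illustrative roundtrip sketch of [`encode_arrays`]/[`decode_arrays`]:
    /// encode a single `Int64` column with the default config and read it back.
    /// This assumes the `bytes` crate is available as a dependency, since
    /// `bytes::Bytes` implements parquet's [`ChunkReader`] for in-memory data.
    #[mz_ore::test]
    fn smoketest_encode_decode_roundtrip() {
        use arrow::array::Int64Array;
        use arrow::datatypes::{DataType, Field};

        let fields = Fields::from(vec![Field::new("x", DataType::Int64, false)]);
        let array: Arc<dyn Array> = Arc::new(Int64Array::from(vec![1i64, 2, 3]));

        // Encode into an in-memory buffer with the default (uncompressed,
        // non-dictionary) configuration.
        let mut buf = Vec::new();
        encode_arrays(&mut buf, fields, vec![array], &EncodingConfig::default())
            .expect("encoding should succeed");

        // `encode_arrays` writes a single row group, so everything comes back
        // as one RecordBatch.
        let reader = decode_arrays(bytes::Bytes::from(buf)).expect("valid parquet");
        let batches = reader
            .collect::<Result<Vec<_>, _>>()
            .expect("decoding should succeed");
        assert_eq!(batches.len(), 1);
        assert_eq!(batches[0].num_rows(), 3);
    }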
}