1use std::fmt::Debug;
13use std::io::Write;
14use std::sync::Arc;
15
16use arrow::array::{Array, RecordBatch};
17use arrow::datatypes::{Fields, Schema as ArrowSchema};
18use parquet::arrow::ArrowWriter;
19use parquet::arrow::arrow_reader::{ParquetRecordBatchReader, ParquetRecordBatchReaderBuilder};
20use parquet::basic::Encoding;
21use parquet::file::properties::{EnabledStatistics, WriterProperties, WriterVersion};
22use parquet::file::reader::ChunkReader;
23use proptest::prelude::*;
24use proptest_derive::Arbitrary;
25
/// Options controlling how a set of arrays is encoded to parquet.
#[derive(Debug, Copy, Clone, Arbitrary)]
pub struct EncodingConfig {
    // Whether parquet dictionary encoding is enabled on the writer.
    pub use_dictionary: bool,
    // Which compression codec (and level, where applicable) to use.
    pub compression: CompressionFormat,
}
34
35impl Default for EncodingConfig {
36 fn default() -> Self {
37 EncodingConfig {
38 use_dictionary: false,
39 compression: CompressionFormat::default(),
40 }
41 }
42}
43
/// Supported parquet compression codecs, with per-codec level bounds
/// expressed as `CompressionLevel<MIN, MAX, DEFAULT>`.
#[derive(Default, Debug, Copy, Clone, PartialEq, Eq, Arbitrary)]
pub enum CompressionFormat {
    /// No compression; the default.
    #[default]
    None,
    Snappy,
    Lz4,
    /// Brotli with a level in `[1, 11]`, defaulting to 1.
    Brotli(CompressionLevel<1, 11, 1>),
    /// Zstd with a level in `[1, 22]`, defaulting to 1.
    Zstd(CompressionLevel<1, 22, 1>),
    /// Gzip with a level in `[1, 9]`, defaulting to 6.
    Gzip(CompressionLevel<1, 9, 6>),
}
61
62impl CompressionFormat {
63 pub fn from_str(s: &str) -> Self {
65 fn parse_level<const MIN: i32, const MAX: i32, const D: i32>(
66 name: &'static str,
67 val: &str,
68 ) -> CompressionLevel<MIN, MAX, D> {
69 match CompressionLevel::from_str(val) {
70 Ok(level) => level,
71 Err(err) => {
72 tracing::error!("invalid {name} compression level, err: {err}");
73 CompressionLevel::default()
74 }
75 }
76 }
77
78 match s.to_lowercase().as_str() {
79 "" => CompressionFormat::None,
80 "none" => CompressionFormat::None,
81 "snappy" => CompressionFormat::Snappy,
82 "lz4" => CompressionFormat::Lz4,
83 other => match other.split_once('-') {
84 Some(("brotli", level)) => CompressionFormat::Brotli(parse_level("brotli", level)),
85 Some(("zstd", level)) => CompressionFormat::Zstd(parse_level("zstd", level)),
86 Some(("gzip", level)) => CompressionFormat::Gzip(parse_level("gzip", level)),
87 _ => {
88 tracing::error!("unrecognized compression format {s}, returning None");
89 CompressionFormat::None
90 }
91 },
92 }
93 }
94}
95
96impl From<CompressionFormat> for parquet::basic::Compression {
97 fn from(value: CompressionFormat) -> Self {
98 match value {
99 CompressionFormat::None => parquet::basic::Compression::UNCOMPRESSED,
100 CompressionFormat::Lz4 => parquet::basic::Compression::LZ4_RAW,
101 CompressionFormat::Snappy => parquet::basic::Compression::SNAPPY,
102 CompressionFormat::Brotli(level) => {
103 let level: u32 = level.0.try_into().expect("known not negative");
104 let level = parquet::basic::BrotliLevel::try_new(level).expect("known valid");
105 parquet::basic::Compression::BROTLI(level)
106 }
107 CompressionFormat::Zstd(level) => {
108 let level = parquet::basic::ZstdLevel::try_new(level.0).expect("known valid");
109 parquet::basic::Compression::ZSTD(level)
110 }
111 CompressionFormat::Gzip(level) => {
112 let level: u32 = level.0.try_into().expect("known not negative");
113 let level = parquet::basic::GzipLevel::try_new(level).expect("known valid");
114 parquet::basic::Compression::GZIP(level)
115 }
116 }
117 }
118}
119
/// A compression level bounded at compile time to `[MIN, MAX]`, with a
/// per-codec default of `DEFAULT`.
#[derive(Debug, Copy, Clone, PartialEq, Eq)]
pub struct CompressionLevel<const MIN: i32, const MAX: i32, const DEFAULT: i32>(i32);

impl<const MIN: i32, const MAX: i32, const DEFAULT: i32> Default
    for CompressionLevel<MIN, MAX, DEFAULT>
{
    fn default() -> Self {
        CompressionLevel(DEFAULT)
    }
}

impl<const MIN: i32, const MAX: i32, const DEFAULT: i32> CompressionLevel<MIN, MAX, DEFAULT> {
    /// Returns a level for `val`, or `Err(val)` if it is outside `[MIN, MAX]`.
    pub const fn try_new(val: i32) -> Result<Self, i32> {
        if val >= MIN && val <= MAX {
            Ok(CompressionLevel(val))
        } else {
            Err(val)
        }
    }

    /// Parses a level from `s`.
    ///
    /// # Errors
    ///
    /// Returns a human-readable message if `s` is not an integer or is
    /// outside `[MIN, MAX]`.
    pub fn from_str(s: &str) -> Result<Self, String> {
        let val = s.parse::<i32>().map_err(|e| e.to_string())?;
        // Spell out the valid range in the error rather than just echoing
        // the rejected number back at the caller (which produced log lines
        // like "invalid gzip compression level, err: 10").
        Self::try_new(val)
            .map_err(|v| format!("level {v} out of range, expected [{MIN}, {MAX}]"))
    }
}
150
impl<const MIN: i32, const MAX: i32, const DEFAULT: i32> Arbitrary
    for CompressionLevel<MIN, MAX, DEFAULT>
{
    type Parameters = ();
    type Strategy = BoxedStrategy<CompressionLevel<MIN, MAX, DEFAULT>>;

    /// Generates any level within the valid `[MIN, MAX]` range.
    fn arbitrary_with(_args: Self::Parameters) -> Self::Strategy {
        // The braces force the const generics to parse as expressions in
        // range position.
        ({ MIN }..={ MAX }).prop_map(CompressionLevel).boxed()
    }
}
161
162pub fn encode_arrays<W: Write + Send>(
164 w: &mut W,
165 fields: Fields,
166 arrays: Vec<Arc<dyn Array>>,
167 config: &EncodingConfig,
168) -> Result<(), anyhow::Error> {
169 let schema = Arc::new(ArrowSchema::new(fields));
170 let props = WriterProperties::builder()
171 .set_dictionary_enabled(config.use_dictionary)
172 .set_encoding(Encoding::PLAIN)
173 .set_statistics_enabled(EnabledStatistics::None)
174 .set_compression(config.compression.into())
175 .set_writer_version(WriterVersion::PARQUET_2_0)
176 .set_data_page_size_limit(1024 * 1024)
177 .set_max_row_group_size(usize::MAX)
178 .build();
179 let mut writer = ArrowWriter::try_new(w, Arc::clone(&schema), Some(props))?;
180
181 let record_batch = RecordBatch::try_new(schema, arrays)?;
182
183 writer.write(&record_batch)?;
184 writer.flush()?;
185 writer.close()?;
186
187 Ok(())
188}
189
190pub fn decode_arrays<R: ChunkReader + 'static>(
192 r: R,
193) -> Result<ParquetRecordBatchReader, anyhow::Error> {
194 let builder = ParquetRecordBatchReaderBuilder::try_new(r)?;
195
196 let row_groups = builder.metadata().row_groups();
198 if row_groups.len() > 1 {
199 anyhow::bail!("found more than 1 RowGroup")
200 }
201 let num_rows = row_groups
202 .get(0)
203 .map(|g| g.num_rows())
204 .unwrap_or(1024)
205 .try_into()
206 .unwrap();
207 let builder = builder.with_batch_size(num_rows);
208
209 let reader = builder.build()?;
210 Ok(reader)
211}
212
#[cfg(test)]
mod tests {
    use super::*;

    #[mz_ore::test]
    fn smoketest_compression_level_parsing() {
        // Asserts that `input` parses to `expected`, attributing failures to
        // the calling line.
        #[track_caller]
        fn check(input: &str, expected: CompressionFormat) {
            assert_eq!(CompressionFormat::from_str(input), expected);
        }

        // Formats without levels, including mixed-case spellings.
        check("", CompressionFormat::None);
        check("none", CompressionFormat::None);
        check("snappy", CompressionFormat::Snappy);
        check("lz4", CompressionFormat::Lz4);
        check("lZ4", CompressionFormat::Lz4);

        // In-range levels parse exactly, case-insensitively.
        check("gzip-1", CompressionFormat::Gzip(CompressionLevel(1)));
        check("GZIp-6", CompressionFormat::Gzip(CompressionLevel(6)));
        check("gzip-9", CompressionFormat::Gzip(CompressionLevel(9)));
        check("brotli-1", CompressionFormat::Brotli(CompressionLevel(1)));
        check("BROtli-8", CompressionFormat::Brotli(CompressionLevel(8)));
        check("brotli-11", CompressionFormat::Brotli(CompressionLevel(11)));
        check("zstd-1", CompressionFormat::Zstd(CompressionLevel(1)));
        check("zstD-10", CompressionFormat::Zstd(CompressionLevel(10)));
        check("zstd-22", CompressionFormat::Zstd(CompressionLevel(22)));

        // Unknown formats fall back to `None`.
        check("foo", CompressionFormat::None);

        // Out-of-range levels fall back to the codec's default level.
        check("gzip-0", CompressionFormat::Gzip(Default::default()));
        check("gzip-10", CompressionFormat::Gzip(Default::default()));
        check("brotli-0", CompressionFormat::Brotli(Default::default()));
        check("brotli-12", CompressionFormat::Brotli(Default::default()));
        check("zstd-0", CompressionFormat::Zstd(Default::default()));
        check("zstd-23", CompressionFormat::Zstd(Default::default()));
    }
}