mz_persist_types/
parquet.rs

1// Copyright Materialize, Inc. and contributors. All rights reserved.
2//
3// Use of this software is governed by the Business Source License
4// included in the LICENSE file.
5//
6// As of the Change Date specified in that file, in accordance with
7// the Business Source License, use of this software will be governed
8// by the Apache License, Version 2.0.
9
10//! Parquet serialization and deserialization for persist data.
11
12use std::fmt::Debug;
13use std::io::Write;
14use std::sync::Arc;
15
16use arrow::array::{Array, RecordBatch};
17use arrow::datatypes::{Fields, Schema as ArrowSchema};
18use parquet::arrow::ArrowWriter;
19use parquet::arrow::arrow_reader::{ParquetRecordBatchReader, ParquetRecordBatchReaderBuilder};
20use parquet::basic::Encoding;
21use parquet::file::properties::{EnabledStatistics, WriterProperties, WriterVersion};
22use parquet::file::reader::ChunkReader;
23use proptest::prelude::*;
24use proptest_derive::Arbitrary;
25
/// Configuration for encoding columnar data.
///
/// See the [`Default`] impl for the values used when no explicit
/// configuration is provided.
#[derive(Debug, Copy, Clone, Arbitrary)]
pub struct EncodingConfig {
    /// Enable dictionary encoding for Parquet data.
    pub use_dictionary: bool,
    /// Compression format for Parquet data.
    pub compression: CompressionFormat,
}
34
35impl Default for EncodingConfig {
36    fn default() -> Self {
37        EncodingConfig {
38            use_dictionary: false,
39            compression: CompressionFormat::default(),
40        }
41    }
42}
43
/// Compression format to apply to columnar data.
#[derive(Default, Debug, Copy, Clone, PartialEq, Eq, Arbitrary)]
pub enum CompressionFormat {
    /// No compression.
    #[default]
    None,
    /// Snappy compression.
    Snappy,
    /// LZ4 compression (encoded as Parquet's `LZ4_RAW`).
    Lz4,
    /// Brotli compression with a level in `1..=11`, defaulting to 1.
    Brotli(CompressionLevel<1, 11, 1>),
    /// Zstd compression with a level in `1..=22`, defaulting to 1.
    Zstd(CompressionLevel<1, 22, 1>),
    /// Gzip compression with a level in `1..=9`, defaulting to 6.
    Gzip(CompressionLevel<1, 9, 6>),
}
61
62impl CompressionFormat {
63    /// Parse a [`CompressionFormat`] from a string, falling back to defaults if the string is not valid.
64    pub fn from_str(s: &str) -> Self {
65        fn parse_level<const MIN: i32, const MAX: i32, const D: i32>(
66            name: &'static str,
67            val: &str,
68        ) -> CompressionLevel<MIN, MAX, D> {
69            match CompressionLevel::from_str(val) {
70                Ok(level) => level,
71                Err(err) => {
72                    tracing::error!("invalid {name} compression level, err: {err}");
73                    CompressionLevel::default()
74                }
75            }
76        }
77
78        match s.to_lowercase().as_str() {
79            "" => CompressionFormat::None,
80            "none" => CompressionFormat::None,
81            "snappy" => CompressionFormat::Snappy,
82            "lz4" => CompressionFormat::Lz4,
83            other => match other.split_once('-') {
84                Some(("brotli", level)) => CompressionFormat::Brotli(parse_level("brotli", level)),
85                Some(("zstd", level)) => CompressionFormat::Zstd(parse_level("zstd", level)),
86                Some(("gzip", level)) => CompressionFormat::Gzip(parse_level("gzip", level)),
87                _ => {
88                    tracing::error!("unrecognized compression format {s}, returning None");
89                    CompressionFormat::None
90                }
91            },
92        }
93    }
94}
95
96impl From<CompressionFormat> for parquet::basic::Compression {
97    fn from(value: CompressionFormat) -> Self {
98        match value {
99            CompressionFormat::None => parquet::basic::Compression::UNCOMPRESSED,
100            CompressionFormat::Lz4 => parquet::basic::Compression::LZ4_RAW,
101            CompressionFormat::Snappy => parquet::basic::Compression::SNAPPY,
102            CompressionFormat::Brotli(level) => {
103                let level: u32 = level.0.try_into().expect("known not negative");
104                let level = parquet::basic::BrotliLevel::try_new(level).expect("known valid");
105                parquet::basic::Compression::BROTLI(level)
106            }
107            CompressionFormat::Zstd(level) => {
108                let level = parquet::basic::ZstdLevel::try_new(level.0).expect("known valid");
109                parquet::basic::Compression::ZSTD(level)
110            }
111            CompressionFormat::Gzip(level) => {
112                let level: u32 = level.0.try_into().expect("known not negative");
113                let level = parquet::basic::GzipLevel::try_new(level).expect("known valid");
114                parquet::basic::Compression::GZIP(level)
115            }
116        }
117    }
118}
119
/// Level of compression for columnar data.
///
/// The const generics bound the valid range to `MIN..=MAX`, with `DEFAULT`
/// used when no level is specified (or an invalid one is supplied upstream).
#[derive(Debug, Copy, Clone, PartialEq, Eq)]
pub struct CompressionLevel<const MIN: i32, const MAX: i32, const DEFAULT: i32>(i32);

impl<const MIN: i32, const MAX: i32, const DEFAULT: i32> Default
    for CompressionLevel<MIN, MAX, DEFAULT>
{
    fn default() -> Self {
        Self(DEFAULT)
    }
}

impl<const MIN: i32, const MAX: i32, const DEFAULT: i32> CompressionLevel<MIN, MAX, DEFAULT> {
    /// Try creating a [`CompressionLevel`] from the provided value, returning an error if it is
    /// outside the `MIN` and `MAX` bounds.
    pub const fn try_new(val: i32) -> Result<Self, i32> {
        if val < MIN || val > MAX {
            Err(val)
        } else {
            Ok(CompressionLevel(val))
        }
    }

    /// Parse a [`CompressionLevel`] from the provided string, returning an error if the string is
    /// not valid.
    pub fn from_str(s: &str) -> Result<Self, String> {
        let val = s.parse::<i32>().map_err(|e| e.to_string())?;
        Self::try_new(val).map_err(|bad| bad.to_string())
    }
}
152
153impl<const MIN: i32, const MAX: i32, const DEFAULT: i32> Arbitrary
154    for CompressionLevel<MIN, MAX, DEFAULT>
155{
156    type Parameters = ();
157    type Strategy = BoxedStrategy<CompressionLevel<MIN, MAX, DEFAULT>>;
158
159    fn arbitrary_with(_args: Self::Parameters) -> Self::Strategy {
160        ({ MIN }..={ MAX }).prop_map(CompressionLevel).boxed()
161    }
162}
163
164/// Encodes a set of [`Array`]s into Parquet.
165pub fn encode_arrays<W: Write + Send>(
166    w: &mut W,
167    fields: Fields,
168    arrays: Vec<Arc<dyn Array>>,
169    config: &EncodingConfig,
170) -> Result<(), anyhow::Error> {
171    let schema = Arc::new(ArrowSchema::new(fields));
172    let props = WriterProperties::builder()
173        .set_dictionary_enabled(config.use_dictionary)
174        .set_encoding(Encoding::PLAIN)
175        .set_statistics_enabled(EnabledStatistics::None)
176        .set_compression(config.compression.into())
177        .set_writer_version(WriterVersion::PARQUET_2_0)
178        .set_data_page_size_limit(1024 * 1024)
179        .set_max_row_group_size(usize::MAX)
180        .build();
181    let mut writer = ArrowWriter::try_new(w, Arc::clone(&schema), Some(props))?;
182
183    let record_batch = RecordBatch::try_new(schema, arrays)?;
184
185    writer.write(&record_batch)?;
186    writer.flush()?;
187    writer.close()?;
188
189    Ok(())
190}
191
192/// Decodes a [`RecordBatch`] from the provided reader.
193pub fn decode_arrays<R: ChunkReader + 'static>(
194    r: R,
195) -> Result<ParquetRecordBatchReader, anyhow::Error> {
196    let builder = ParquetRecordBatchReaderBuilder::try_new(r)?;
197
198    // To match arrow2, we default the batch size to the number of rows in the RowGroup.
199    let row_groups = builder.metadata().row_groups();
200    if row_groups.len() > 1 {
201        anyhow::bail!("found more than 1 RowGroup")
202    }
203    let num_rows = row_groups
204        .get(0)
205        .map(|g| g.num_rows())
206        .unwrap_or(1024)
207        .try_into()
208        .unwrap();
209    let builder = builder.with_batch_size(num_rows);
210
211    let reader = builder.build()?;
212    Ok(reader)
213}
214
#[cfg(test)]
mod tests {
    use super::*;

    #[mz_ore::test]
    fn smoketest_compression_level_parsing() {
        // Pairs of (input, expected parse result). Parsing is case-insensitive,
        // and out-of-range levels fall back to each format's default level.
        let cases: &[(&str, CompressionFormat)] = &[
            ("", CompressionFormat::None),
            ("none", CompressionFormat::None),
            ("snappy", CompressionFormat::Snappy),
            ("lz4", CompressionFormat::Lz4),
            ("lZ4", CompressionFormat::Lz4),
            ("gzip-1", CompressionFormat::Gzip(CompressionLevel(1))),
            ("GZIp-6", CompressionFormat::Gzip(CompressionLevel(6))),
            ("gzip-9", CompressionFormat::Gzip(CompressionLevel(9))),
            ("brotli-1", CompressionFormat::Brotli(CompressionLevel(1))),
            ("BROtli-8", CompressionFormat::Brotli(CompressionLevel(8))),
            ("brotli-11", CompressionFormat::Brotli(CompressionLevel(11))),
            ("zstd-1", CompressionFormat::Zstd(CompressionLevel(1))),
            ("zstD-10", CompressionFormat::Zstd(CompressionLevel(10))),
            ("zstd-22", CompressionFormat::Zstd(CompressionLevel(22))),
            ("foo", CompressionFormat::None),
            // Invalid values that fallback to the default values.
            ("gzip-0", CompressionFormat::Gzip(Default::default())),
            ("gzip-10", CompressionFormat::Gzip(Default::default())),
            ("brotli-0", CompressionFormat::Brotli(Default::default())),
            ("brotli-12", CompressionFormat::Brotli(Default::default())),
            ("zstd-0", CompressionFormat::Zstd(Default::default())),
            ("zstd-23", CompressionFormat::Zstd(Default::default())),
        ];
        for (input, expected) in cases {
            assert_eq!(CompressionFormat::from_str(input), *expected);
        }
    }
}