Skip to main content

mz_persist_types/
parquet.rs

1// Copyright Materialize, Inc. and contributors. All rights reserved.
2//
3// Use of this software is governed by the Business Source License
4// included in the LICENSE file.
5//
6// As of the Change Date specified in that file, in accordance with
7// the Business Source License, use of this software will be governed
8// by the Apache License, Version 2.0.
9
10//! Parquet serialization and deserialization for persist data.
11
12use std::fmt::Debug;
13use std::io::Write;
14use std::sync::Arc;
15
16use arrow::array::{Array, RecordBatch};
17use arrow::datatypes::{Fields, Schema as ArrowSchema};
18use parquet::arrow::ArrowWriter;
19use parquet::arrow::arrow_reader::{ParquetRecordBatchReader, ParquetRecordBatchReaderBuilder};
20use parquet::basic::Encoding;
21use parquet::file::properties::{EnabledStatistics, WriterProperties, WriterVersion};
22use parquet::file::reader::ChunkReader;
23use proptest::prelude::*;
24use proptest_derive::Arbitrary;
25
/// Configuration for encoding columnar data.
///
/// Passed to [`encode_arrays`] to control how the emitted Parquet is laid
/// out: whether columns are dictionary encoded and which compression codec
/// (if any) is applied.
#[derive(Debug, Copy, Clone, Arbitrary)]
pub struct EncodingConfig {
    /// Enable dictionary encoding for Parquet data.
    pub use_dictionary: bool,
    /// Compression format for Parquet data.
    pub compression: CompressionFormat,
}
34
35impl Default for EncodingConfig {
36    fn default() -> Self {
37        EncodingConfig {
38            use_dictionary: false,
39            compression: CompressionFormat::default(),
40        }
41    }
42}
43
/// Compression format to apply to columnar data.
///
/// The leveled variants (`Brotli`, `Zstd`, `Gzip`) carry a
/// [`CompressionLevel`] whose const parameters bound the level to the range
/// listed on each variant, with the final parameter as the default.
#[derive(Default, Debug, Copy, Clone, PartialEq, Eq, Arbitrary)]
pub enum CompressionFormat {
    /// No compression.
    #[default]
    None,
    /// Snappy compression; no configurable level.
    Snappy,
    /// LZ4 compression; no configurable level.
    Lz4,
    /// Brotli compression; levels 1..=11, default 1.
    Brotli(CompressionLevel<1, 11, 1>),
    /// Zstd compression; levels 1..=22, default 1.
    Zstd(CompressionLevel<1, 22, 1>),
    /// Gzip compression; levels 1..=9, default 6.
    Gzip(CompressionLevel<1, 9, 6>),
}
61
62impl CompressionFormat {
63    /// Parse a [`CompressionFormat`] from a string, falling back to defaults if the string is not valid.
64    pub fn from_str(s: &str) -> Self {
65        fn parse_level<const MIN: i32, const MAX: i32, const D: i32>(
66            name: &'static str,
67            val: &str,
68        ) -> CompressionLevel<MIN, MAX, D> {
69            match CompressionLevel::from_str(val) {
70                Ok(level) => level,
71                Err(err) => {
72                    tracing::error!("invalid {name} compression level, err: {err}");
73                    CompressionLevel::default()
74                }
75            }
76        }
77
78        match s.to_lowercase().as_str() {
79            "" => CompressionFormat::None,
80            "none" => CompressionFormat::None,
81            "snappy" => CompressionFormat::Snappy,
82            "lz4" => CompressionFormat::Lz4,
83            other => match other.split_once('-') {
84                Some(("brotli", level)) => CompressionFormat::Brotli(parse_level("brotli", level)),
85                Some(("zstd", level)) => CompressionFormat::Zstd(parse_level("zstd", level)),
86                Some(("gzip", level)) => CompressionFormat::Gzip(parse_level("gzip", level)),
87                _ => {
88                    tracing::error!("unrecognized compression format {s}, returning None");
89                    CompressionFormat::None
90                }
91            },
92        }
93    }
94}
95
96impl From<CompressionFormat> for parquet::basic::Compression {
97    fn from(value: CompressionFormat) -> Self {
98        match value {
99            CompressionFormat::None => parquet::basic::Compression::UNCOMPRESSED,
100            CompressionFormat::Lz4 => parquet::basic::Compression::LZ4_RAW,
101            CompressionFormat::Snappy => parquet::basic::Compression::SNAPPY,
102            CompressionFormat::Brotli(level) => {
103                let level: u32 = level.0.try_into().expect("known not negative");
104                let level = parquet::basic::BrotliLevel::try_new(level).expect("known valid");
105                parquet::basic::Compression::BROTLI(level)
106            }
107            CompressionFormat::Zstd(level) => {
108                let level = parquet::basic::ZstdLevel::try_new(level.0).expect("known valid");
109                parquet::basic::Compression::ZSTD(level)
110            }
111            CompressionFormat::Gzip(level) => {
112                let level: u32 = level.0.try_into().expect("known not negative");
113                let level = parquet::basic::GzipLevel::try_new(level).expect("known valid");
114                parquet::basic::Compression::GZIP(level)
115            }
116        }
117    }
118}
119
/// Level of compression for columnar data.
///
/// The const parameters bound the wrapped value: `MIN..=MAX` is the valid
/// (inclusive) range and `DEFAULT` is the level used when none is specified.
/// The inner `i32` is private so a `CompressionLevel` can only be constructed
/// in-range, via [`CompressionLevel::try_new`] or [`CompressionLevel::from_str`].
#[derive(Debug, Copy, Clone, PartialEq, Eq)]
pub struct CompressionLevel<const MIN: i32, const MAX: i32, const DEFAULT: i32>(i32);

impl<const MIN: i32, const MAX: i32, const DEFAULT: i32> Default
    for CompressionLevel<MIN, MAX, DEFAULT>
{
    fn default() -> Self {
        CompressionLevel(DEFAULT)
    }
}

impl<const MIN: i32, const MAX: i32, const DEFAULT: i32> CompressionLevel<MIN, MAX, DEFAULT> {
    /// Try creating a [`CompressionLevel`] from the provided value, returning
    /// the rejected value as an error if it is outside the `MIN` and `MAX`
    /// bounds.
    pub const fn try_new(val: i32) -> Result<Self, i32> {
        if val >= MIN && val <= MAX {
            Ok(CompressionLevel(val))
        } else {
            Err(val)
        }
    }

    /// Parse a [`CompressionLevel`] from the provided string, returning an
    /// error if the string is not a valid integer or the value is outside the
    /// `MIN` and `MAX` bounds.
    pub fn from_str(s: &str) -> Result<Self, String> {
        let val = s.parse::<i32>().map_err(|e| e.to_string())?;
        // Include the valid range in the error so downstream log messages are
        // actionable, rather than echoing back just the rejected number.
        Self::try_new(val).map_err(|val| format!("{val} outside of range {MIN}..={MAX}"))
    }
}
150
// Manual `Arbitrary` impl: a derive cannot express the MIN..=MAX invariant,
// so we sample only in-range values and wrap them in the newtype.
impl<const MIN: i32, const MAX: i32, const DEFAULT: i32> Arbitrary
    for CompressionLevel<MIN, MAX, DEFAULT>
{
    type Parameters = ();
    type Strategy = BoxedStrategy<CompressionLevel<MIN, MAX, DEFAULT>>;

    fn arbitrary_with(_args: Self::Parameters) -> Self::Strategy {
        // Sample uniformly from the full inclusive range of valid levels;
        // every generated value satisfies the type's bounds by construction.
        ({ MIN }..={ MAX }).prop_map(CompressionLevel).boxed()
    }
}
161
162/// Encodes a set of [`Array`]s into Parquet.
163pub fn encode_arrays<W: Write + Send>(
164    w: &mut W,
165    fields: Fields,
166    arrays: Vec<Arc<dyn Array>>,
167    config: &EncodingConfig,
168) -> Result<(), anyhow::Error> {
169    let schema = Arc::new(ArrowSchema::new(fields));
170    let props = WriterProperties::builder()
171        .set_dictionary_enabled(config.use_dictionary)
172        .set_encoding(Encoding::PLAIN)
173        .set_statistics_enabled(EnabledStatistics::None)
174        .set_compression(config.compression.into())
175        .set_writer_version(WriterVersion::PARQUET_2_0)
176        .set_data_page_size_limit(1024 * 1024)
177        .set_max_row_group_size(usize::MAX)
178        .build();
179    let mut writer = ArrowWriter::try_new(w, Arc::clone(&schema), Some(props))?;
180
181    let record_batch = RecordBatch::try_new(schema, arrays)?;
182
183    writer.write(&record_batch)?;
184    writer.flush()?;
185    writer.close()?;
186
187    Ok(())
188}
189
190/// Decodes a [`RecordBatch`] from the provided reader.
191pub fn decode_arrays<R: ChunkReader + 'static>(
192    r: R,
193) -> Result<ParquetRecordBatchReader, anyhow::Error> {
194    let builder = ParquetRecordBatchReaderBuilder::try_new(r)?;
195
196    // To match arrow2, we default the batch size to the number of rows in the RowGroup.
197    let row_groups = builder.metadata().row_groups();
198    if row_groups.len() > 1 {
199        anyhow::bail!("found more than 1 RowGroup")
200    }
201    let num_rows = row_groups
202        .get(0)
203        .map(|g| g.num_rows())
204        .unwrap_or(1024)
205        .try_into()
206        .unwrap();
207    let builder = builder.with_batch_size(num_rows);
208
209    let reader = builder.build()?;
210    Ok(reader)
211}
212
#[cfg(test)]
mod tests {
    use super::*;

    #[mz_ore::test]
    fn smoketest_compression_level_parsing() {
        // (input, expected) pairs: parsing is case-insensitive, and invalid
        // levels fall back to the codec's default while unknown formats fall
        // back to `None`.
        let cases = [
            ("", CompressionFormat::None),
            ("none", CompressionFormat::None),
            ("snappy", CompressionFormat::Snappy),
            ("lz4", CompressionFormat::Lz4),
            ("lZ4", CompressionFormat::Lz4),
            ("gzip-1", CompressionFormat::Gzip(CompressionLevel(1))),
            ("GZIp-6", CompressionFormat::Gzip(CompressionLevel(6))),
            ("gzip-9", CompressionFormat::Gzip(CompressionLevel(9))),
            ("brotli-1", CompressionFormat::Brotli(CompressionLevel(1))),
            ("BROtli-8", CompressionFormat::Brotli(CompressionLevel(8))),
            ("brotli-11", CompressionFormat::Brotli(CompressionLevel(11))),
            ("zstd-1", CompressionFormat::Zstd(CompressionLevel(1))),
            ("zstD-10", CompressionFormat::Zstd(CompressionLevel(10))),
            ("zstd-22", CompressionFormat::Zstd(CompressionLevel(22))),
            ("foo", CompressionFormat::None),
            // Invalid values that fallback to the default values.
            ("gzip-0", CompressionFormat::Gzip(Default::default())),
            ("gzip-10", CompressionFormat::Gzip(Default::default())),
            ("brotli-0", CompressionFormat::Brotli(Default::default())),
            ("brotli-12", CompressionFormat::Brotli(Default::default())),
            ("zstd-0", CompressionFormat::Zstd(Default::default())),
            ("zstd-23", CompressionFormat::Zstd(Default::default())),
        ];
        for (input, expected) in cases {
            assert_eq!(CompressionFormat::from_str(input), expected);
        }
    }
}
247}