mz_storage_operators/oneshot_source/csv.rs

// Copyright Materialize, Inc. and contributors. All rights reserved.
//
// Use of this software is governed by the Business Source License
// included in the LICENSE file.
//
// As of the Change Date specified in that file, in accordance with
// the Business Source License, use of this software will be governed
// by the Apache License, Version 2.0.

//! CSV to Row Decoder.
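//!
//! A rough sketch of how the pieces in this module fit together, based on
//! the `OneshotFormat` methods implemented below (illustrative only, hence
//! `ignore`d; `source`, `object`, and `checksum` are assumed to come from a
//! `OneshotSource` implementation):
//!
//! ```ignore
//! use futures::TryStreamExt;
//!
//! let decoder = CsvDecoder::new(params, &relation_desc);
//! // One CSV file becomes a single work request (CSV is hard to split).
//! let requests = decoder.split_work(source.clone(), object, checksum).await?;
//! let mut rows = Vec::new();
//! for request in requests {
//!     let mut records = decoder.fetch_work(&source, request);
//!     while let Some(chunk) = records.try_next().await? {
//!         decoder.decode_chunk(chunk, &mut rows)?;
//!     }
//! }
//! ```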

use std::fmt::Debug;
use std::io;
use std::pin::Pin;
use std::sync::Arc;

use futures::TryStreamExt;
use futures::stream::{BoxStream, StreamExt};
use itertools::Itertools;
use mz_pgcopy::CopyCsvFormatParams;
use mz_repr::{Datum, RelationDesc, Row, RowArena};
use serde::{Deserialize, Serialize};
use smallvec::{SmallVec, smallvec};
use tokio_util::io::StreamReader;

use crate::oneshot_source::{
    Encoding, OneshotFormat, OneshotObject, OneshotSource, StorageErrorX, StorageErrorXKind,
};

#[derive(Debug, Clone)]
pub struct CsvDecoder {
    /// Properties of the CSV reader.
    params: CopyCsvFormatParams<'static>,
    /// Column types of the table we're copying into.
    column_types: Arc<[mz_pgrepr::Type]>,
}

impl CsvDecoder {
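    /// Creates a `CsvDecoder` from COPY-style CSV options and the
    /// description of the destination relation.
    ///
    /// A hypothetical construction (illustrative only, hence `ignore`d; the
    /// `CopyCsvFormatParams` field names are inferred from how they are used
    /// later in this module, not from its definition in `mz_pgcopy`):
    ///
    /// ```ignore
    /// let params = CopyCsvFormatParams {
    ///     delimiter: b',',
    ///     quote: b'"',
    ///     escape: b'"',
    ///     header: false,
    ///     null: "".into(),
    /// };
    /// let decoder = CsvDecoder::new(params, &relation_desc);
    /// ```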
    pub fn new(params: CopyCsvFormatParams<'static>, desc: &RelationDesc) -> Self {
        let column_types = desc
            .iter_types()
            .map(|x| &x.scalar_type)
            .map(mz_pgrepr::Type::from)
            .collect();
        CsvDecoder {
            params,
            column_types,
        }
    }
}

/// Instructions on how to parse a single CSV file.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct CsvWorkRequest<O, C> {
    object: O,
    checksum: C,
    encodings: SmallVec<[Encoding; 1]>,
}

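/// A single CSV record, flattened for serialization.
///
/// `bytes` holds all of the record's fields concatenated together (no
/// delimiters or quoting), and `ranges[i]` is the sub-range of `bytes` for
/// field `i`. For example (illustrative), the record `foo,bar` would be
/// stored as `bytes = b"foobar"` with `ranges = [0..3, 3..6]`.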
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct CsvRecord {
    bytes: Vec<u8>,
    ranges: Vec<std::ops::Range<usize>>,
}

impl OneshotFormat for CsvDecoder {
    type WorkRequest<S>
        = CsvWorkRequest<S::Object, S::Checksum>
    where
        S: OneshotSource;
    type RecordChunk = CsvRecord;

    async fn split_work<S: OneshotSource + Send>(
        &self,
        _source: S,
        object: S::Object,
        checksum: S::Checksum,
    ) -> Result<Vec<Self::WorkRequest<S>>, StorageErrorX> {
        // Decoding a CSV in parallel is hard.
        //
        // TODO(cf3): If necessary, we can get a 2x speedup by parsing a CSV
        // from the start and end in parallel, and meeting in the middle.
        //
        // See <https://badrish.net/papers/dp-sigmod19.pdf> for general parallelization strategies.

        // TODO(cf1): Check the encodings from the object to determine
        // what decompression to apply. Also support the user manually
        // specifying certain encodings.

        // For now, guess the compression encoding from the object's file
        // extension, e.g. an object named `data.csv.gz` will be gzip
        // decompressed in `fetch_work`.
        let encodings = if object.name().ends_with(".gz") {
            smallvec![Encoding::Gzip]
        } else if object.name().ends_with(".bz2") {
            smallvec![Encoding::Bzip2]
        } else if object.name().ends_with(".xz") {
            smallvec![Encoding::Xz]
        } else if object.name().ends_with(".zst") {
            smallvec![Encoding::Zstd]
        } else {
            smallvec![]
        };

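        // CSV decoding isn't split across workers (see the TODO above), so
        // the whole object becomes a single work request.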
        let request = CsvWorkRequest {
            object,
            checksum,
            encodings,
        };
        Ok(vec![request])
    }

    fn fetch_work<'a, S: OneshotSource + Sync + 'static>(
        &'a self,
        source: &'a S,
        request: Self::WorkRequest<S>,
    ) -> BoxStream<'a, Result<Self::RecordChunk, StorageErrorX>> {
        let CsvWorkRequest {
            object,
            checksum,
            encodings,
        } = request;

        // Wrap our `Stream<Bytes>` into a type that implements `tokio::io::AsyncRead`.
        let raw_byte_stream = source
            .get(object, checksum, None)
            .map_err(|e| io::Error::new(io::ErrorKind::Interrupted, format!("{e:?}")));
        let stream_reader = StreamReader::new(raw_byte_stream);

        // TODO(cf3): Support multiple encodings.
        assert!(encodings.len() <= 1, "TODO support multiple encodings");

        // Decompress the byte stream, if necessary.
        let reader: Pin<Box<dyn tokio::io::AsyncRead + Send>> = if let Some(encoding) =
            encodings.into_iter().next()
        {
            tracing::info!(?encoding, "decompressing byte stream");
            match encoding {
                Encoding::Bzip2 => {
                    let decoder = async_compression::tokio::bufread::BzDecoder::new(stream_reader);
                    Box::pin(decoder)
                }
                Encoding::Gzip => {
                    let decoder =
                        async_compression::tokio::bufread::GzipDecoder::new(stream_reader);
                    Box::pin(decoder)
                }
                Encoding::Xz => {
                    let decoder = async_compression::tokio::bufread::XzDecoder::new(stream_reader);
                    Box::pin(decoder)
                }
                Encoding::Zstd => {
                    let decoder =
                        async_compression::tokio::bufread::ZstdDecoder::new(stream_reader);
                    Box::pin(decoder)
                }
            }
        } else {
            Box::pin(stream_reader)
        };

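        // If the quote and escape characters match (the Postgres COPY CSV
        // default, where both are `"`), quotes inside a field are escaped by
        // doubling them (`""`); otherwise we disable doubling and use the
        // distinct escape character.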
        let (double_quote, escape) = if self.params.quote == self.params.escape {
            (true, None)
        } else {
            (false, Some(self.params.escape))
        };

        // Configure our CSV reader.
        let reader = csv_async::AsyncReaderBuilder::new()
            .delimiter(self.params.delimiter)
            .quote(self.params.quote)
            .has_headers(self.params.header)
            .double_quote(double_quote)
            .escape(escape)
            // Be maximally permissive. If there is a record with the wrong
            // number of columns, `Row` decoding will error, if we care.
            .flexible(true)
            .create_reader(reader);

        // Return a stream of records. We copy out each record's flattened
        // bytes and per-field ranges so the resulting `CsvRecord` owns its
        // data and can be serialized.
        reader
            .into_byte_records()
            .map_ok(|record| {
                let bytes = record.as_slice().to_vec();
                let ranges = (0..record.len())
                    .map(|idx| record.range(idx).expect("known to exist"))
                    .collect();
                CsvRecord { bytes, ranges }
            })
            .map_err(|err| StorageErrorXKind::from(err).with_context("csv decoding"))
            .boxed()
    }

    fn decode_chunk(
        &self,
        chunk: Self::RecordChunk,
        rows: &mut Vec<Row>,
    ) -> Result<usize, StorageErrorX> {
        let CsvRecord { bytes, ranges } = chunk;

        // Make sure the CSV record has the correct number of columns.
        if self.column_types.len() != ranges.len() {
            let msg = format!(
                "wrong number of columns, desc: {} record: {}",
                self.column_types.len(),
                ranges.len()
            );
            return Err(StorageErrorXKind::invalid_record_batch(msg).into());
        }

        // Map each field's range back to its raw bytes.
        let str_slices = ranges.into_iter().map(|range| {
            bytes
                .get(range)
                .ok_or_else(|| StorageErrorXKind::programming_error("invalid byte range"))
        });

        // Decode a Row from the CSV record.
        let mut row = Row::default();
        let mut packer = row.packer();
        let arena = RowArena::new();

        for (typ, maybe_raw_value) in self.column_types.iter().zip_eq(str_slices) {
            let raw_value = maybe_raw_value?;

            if raw_value == self.params.null.as_bytes() {
                packer.push(Datum::Null);
            } else {
                let value = mz_pgrepr::Value::decode_text(typ, raw_value).map_err(|err| {
                    StorageErrorXKind::invalid_record_batch(err.to_string())
                        .with_context("decode_text")
                })?;
                packer.push(value.into_datum(&arena, typ));
            }
        }

        rows.push(row);

        Ok(1)
    }
}
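
// A minimal sketch of what `decode_chunk` consumes (illustrative only;
// building a real `CsvDecoder` requires `CopyCsvFormatParams` and a
// `RelationDesc`, which are elided here):
//
//     // A record with two fields, e.g. parsed from the line `42,hello`.
//     let chunk = CsvRecord {
//         bytes: b"42hello".to_vec(),
//         ranges: vec![0..2, 2..7],
//     };
//     let mut rows = Vec::new();
//     let count = decoder.decode_chunk(chunk, &mut rows)?;
//     assert_eq!(count, 1); // one `Row` decoded and pushed onto `rows`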