mz_storage_operators/oneshot_source/csv.rs

// Copyright Materialize, Inc. and contributors. All rights reserved.
//
// Use of this software is governed by the Business Source License
// included in the LICENSE file.
//
// As of the Change Date specified in that file, in accordance with
// the Business Source License, use of this software will be governed
// by the Apache License, Version 2.0.

//! CSV to Row Decoder.

use std::fmt::Debug;
use std::io;
use std::pin::Pin;
use std::sync::Arc;

use futures::TryStreamExt;
use futures::stream::{BoxStream, StreamExt};
use mz_pgcopy::CopyCsvFormatParams;
use mz_repr::{Datum, RelationDesc, Row, RowArena};
use serde::{Deserialize, Serialize};
use smallvec::{SmallVec, smallvec};
use tokio_util::io::StreamReader;

use crate::oneshot_source::{
    Encoding, OneshotFormat, OneshotObject, OneshotSource, StorageErrorX, StorageErrorXKind,
};

#[derive(Debug, Clone)]
pub struct CsvDecoder {
    /// Properties of the CSV reader.
    params: CopyCsvFormatParams<'static>,
    /// Column types of the table we're copying into.
    column_types: Arc<[mz_pgrepr::Type]>,
}

impl CsvDecoder {
    pub fn new(params: CopyCsvFormatParams<'static>, desc: &RelationDesc) -> Self {
        let column_types = desc
            .iter_types()
            .map(|x| &x.scalar_type)
            .map(mz_pgrepr::Type::from)
            .collect();
        CsvDecoder {
            params,
            column_types,
        }
    }
}
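
// Note: the `mz_pgrepr::Type`s computed in `new` drive the text decoding in
// `decode_chunk` below, where each raw CSV field is parsed with
// `mz_pgrepr::Value::decode_text` against its corresponding column type.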

/// Instructions on how to parse a single CSV file.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct CsvWorkRequest<O, C> {
    object: O,
    checksum: C,
    encodings: SmallVec<[Encoding; 1]>,
}

/// A single decoded CSV record: the raw bytes of every field laid out
/// contiguously, plus the byte range of each individual field.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct CsvRecord {
    bytes: Vec<u8>,
    ranges: Vec<std::ops::Range<usize>>,
}
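
// For example, given the CSV line `a,bb,ccc`, the underlying `csv_async`
// byte record stores its fields contiguously without delimiters, so the
// resulting chunk looks roughly like:
//
//     CsvRecord {
//         bytes: b"abbccc".to_vec(),
//         ranges: vec![0..1, 1..3, 3..6],
//     }
//
// i.e. `bytes[ranges[i]]` recovers the raw text of column `i`.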

impl OneshotFormat for CsvDecoder {
    type WorkRequest<S>
        = CsvWorkRequest<S::Object, S::Checksum>
    where
        S: OneshotSource;
    type RecordChunk = CsvRecord;

    async fn split_work<S: OneshotSource + Send>(
        &self,
        _source: S,
        object: S::Object,
        checksum: S::Checksum,
    ) -> Result<Vec<Self::WorkRequest<S>>, StorageErrorX> {
        // Decoding a CSV in parallel is hard.
        //
        // TODO(cf3): If necessary, we can get a 2x speedup by parsing a CSV
        // from the start and end in parallel, and meeting in the middle.
        //
        // See <https://badrish.net/papers/dp-sigmod19.pdf> for general parallelization strategies.

        // TODO(cf1): Check the encodings from the object to determine
        // what decompression to apply. Also support the user manually
        // specifying certain encodings.

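        // Sniff the compression scheme from the file extension: e.g.
        // `events.csv.gz` maps to `Gzip`, while a bare `events.csv` falls
        // through to no decompression at all.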
        let encodings = if object.name().ends_with(".gz") {
            smallvec![Encoding::Gzip]
        } else if object.name().ends_with(".bz2") {
            smallvec![Encoding::Bzip2]
        } else if object.name().ends_with(".xz") {
            smallvec![Encoding::Xz]
        } else if object.name().ends_with(".zst") {
            smallvec![Encoding::Zstd]
        } else {
            smallvec![]
        };

        let request = CsvWorkRequest {
            object,
            checksum,
            encodings,
        };
        Ok(vec![request])
    }

    fn fetch_work<'a, S: OneshotSource + Sync + 'static>(
        &'a self,
        source: &'a S,
        request: Self::WorkRequest<S>,
    ) -> BoxStream<'a, Result<Self::RecordChunk, StorageErrorX>> {
        let CsvWorkRequest {
            object,
            checksum,
            encodings,
        } = request;

        // Wrap our `Stream<Bytes>` into a type that implements `tokio::io::AsyncRead`.
        let raw_byte_stream = source
            .get(object, checksum, None)
            .map_err(|e| io::Error::new(io::ErrorKind::Interrupted, format!("{e:?}")));
        let stream_reader = StreamReader::new(raw_byte_stream);
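        // N.B. `StreamReader` requires the stream's error type to convert
        // into `io::Error`, hence the `map_err` just above.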

        // TODO(cf3): Support multiple encodings.
        assert!(encodings.len() <= 1, "TODO support multiple encodings");

        // Decompress the byte stream, if necessary.
        let reader: Pin<Box<dyn tokio::io::AsyncRead + Send>> = if let Some(encoding) =
            encodings.into_iter().next()
        {
            tracing::info!(?encoding, "decompressing byte stream");
            match encoding {
                Encoding::Bzip2 => {
                    let decoder = async_compression::tokio::bufread::BzDecoder::new(stream_reader);
                    Box::pin(decoder)
                }
                Encoding::Gzip => {
                    let decoder =
                        async_compression::tokio::bufread::GzipDecoder::new(stream_reader);
                    Box::pin(decoder)
                }
                Encoding::Xz => {
                    let decoder = async_compression::tokio::bufread::XzDecoder::new(stream_reader);
                    Box::pin(decoder)
                }
                Encoding::Zstd => {
                    let decoder =
                        async_compression::tokio::bufread::ZstdDecoder::new(stream_reader);
                    Box::pin(decoder)
                }
            }
        } else {
            Box::pin(stream_reader)
        };

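        // Per CSV convention, when the quote and escape characters coincide
        // (e.g. both `"`), a quote inside a field is escaped by doubling it
        // (`""`); otherwise an explicit escape character is in play and
        // double-quoting is disabled.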
        let (double_quote, escape) = if self.params.quote == self.params.escape {
            (true, None)
        } else {
            (false, Some(self.params.escape))
        };

        // Configure our CSV reader.
        let reader = csv_async::AsyncReaderBuilder::new()
            .delimiter(self.params.delimiter)
            .quote(self.params.quote)
            .has_headers(self.params.header)
            .double_quote(double_quote)
            .escape(escape)
            // Be maximally permissive. If a record has the wrong number of
            // columns, Row decoding will error, if we care.
            .flexible(true)
            .create_reader(reader);

        // Return a stream of records. Note that operating on byte records
        // (rather than string records) skips per-field UTF-8 validation.
        reader
            .into_byte_records()
            .map_ok(|record| {
                let bytes = record.as_slice().to_vec();
                let ranges = (0..record.len())
                    .map(|idx| record.range(idx).expect("known to exist"))
                    .collect();
                CsvRecord { bytes, ranges }
            })
            .map_err(|err| StorageErrorXKind::from(err).with_context("csv decoding"))
            .boxed()
    }

    fn decode_chunk(
        &self,
        chunk: Self::RecordChunk,
        rows: &mut Vec<Row>,
    ) -> Result<usize, StorageErrorX> {
        let CsvRecord { bytes, ranges } = chunk;

        // Make sure the CSV record has the correct number of columns.
        if self.column_types.len() != ranges.len() {
            let msg = format!(
                "wrong number of columns, desc: {} record: {}",
                self.column_types.len(),
                ranges.len()
            );
            return Err(StorageErrorXKind::invalid_record_batch(msg).into());
        }

        // Map each field's byte range back to its raw bytes; an out-of-bounds
        // range indicates a programming error upstream.
        let str_slices = ranges.into_iter().map(|range| {
            bytes
                .get(range)
                .ok_or_else(|| StorageErrorXKind::programming_error("invalid byte range"))
        });

        // Decode a Row from the CSV record.
        let mut row = Row::default();
        let mut packer = row.packer();
        let arena = RowArena::new();

        for (typ, maybe_raw_value) in self.column_types.iter().zip(str_slices) {
            let raw_value = maybe_raw_value?;

            // A raw value equal to the configured NULL sentinel becomes a
            // NULL Datum; everything else is parsed as its column type.
            if raw_value == self.params.null.as_bytes() {
                packer.push(Datum::Null);
            } else {
                let value = mz_pgrepr::Value::decode_text(typ, raw_value).map_err(|err| {
                    StorageErrorXKind::invalid_record_batch(err.to_string())
                        .with_context("decode_text")
                })?;
                packer.push(value.into_datum(&arena, typ));
            }
        }

        rows.push(row);

        Ok(1)
    }
}
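
// A rough sketch of how `decode_chunk` consumes a record, assuming a table
// with an `Int4` column followed by a `Text` column (the construction of
// `params` and `desc` is elided, since those APIs live elsewhere):
//
//     let decoder = CsvDecoder::new(params, &desc);
//     // Fields "1" and "hello", stored contiguously with per-field ranges.
//     let chunk = CsvRecord {
//         bytes: b"1hello".to_vec(),
//         ranges: vec![0..1, 1..6],
//     };
//     let mut rows = Vec::new();
//     let n = decoder.decode_chunk(chunk, &mut rows)?;
//     assert_eq!(n, 1); // `rows[0]` packs Datums [Int32(1), String("hello")]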