mz_storage/decode/
csv.rs

1// Copyright Materialize, Inc. and contributors. All rights reserved.
2//
3// Use of this software is governed by the Business Source License
4// included in the LICENSE file.
5//
6// As of the Change Date specified in that file, in accordance with
7// the Business Source License, use of this software will be governed
8// by the Apache License, Version 2.0.
9
10use itertools::Itertools;
11use mz_repr::{Datum, Row};
12use mz_storage_types::errors::DecodeErrorKind;
13use mz_storage_types::sources::encoding::CsvEncoding;
14
/// Incremental state for decoding CSV records into [`Row`]s.
///
/// A record may be split across several input chunks; whatever has been parsed of it so far
/// stays in `output`/`ends` between calls to `CsvDecoderState::decode`.
#[derive(Debug)]
pub struct CsvDecoderState {
    /// Whether the next complete record is a header row to be validated and not emitted.
    next_row_is_header: bool,
    /// Expected header names; `None` means no record is ever treated as a header.
    header_names: Option<Vec<String>>,
    /// Number of fields every record must have.
    n_cols: usize,
    /// Scratch buffer holding the field bytes of the record currently being read.
    output: Vec<u8>,
    /// Number of bytes of `output` filled for the current record.
    output_cursor: usize,
    /// Field end offsets into `output`. `ends[0]` is never written and stays 0, so field `i`
    /// spans `ends[i]..ends[i + 1]`.
    ends: Vec<usize>,
    /// Next free slot in `ends`; starts at 1 because of the leading 0 entry.
    ends_cursor: usize,
    /// Streaming `csv_core` parser, built with the format's delimiter.
    csv_reader: csv_core::Reader,
    /// Reusable row buffer that decoded fields are packed into.
    row_buf: Row,
    /// Number of records that failed to decode.
    events_error: usize,
    /// Number of records decoded successfully (header rows included).
    events_success: usize,
}
29
30impl CsvDecoderState {
31    fn total_events(&self) -> usize {
32        self.events_error + self.events_success
33    }
34
35    pub fn new(format: CsvEncoding) -> Self {
36        let CsvEncoding { columns, delimiter } = format;
37        let n_cols = columns.arity();
38
39        let header_names = columns.into_header_names();
40        Self {
41            next_row_is_header: header_names.is_some(),
42            header_names,
43            n_cols,
44            output: vec![0],
45            output_cursor: 0,
46            ends: vec![0],
47            ends_cursor: 1,
48            csv_reader: csv_core::ReaderBuilder::new().delimiter(delimiter).build(),
49            row_buf: Row::default(),
50            events_error: 0,
51            events_success: 0,
52        }
53    }
54
55    pub fn reset_for_new_object(&mut self) {
56        if self.header_names.is_some() {
57            self.next_row_is_header = true;
58        }
59    }
60
61    pub fn decode(&mut self, chunk: &mut &[u8]) -> Result<Option<Row>, DecodeErrorKind> {
62        loop {
63            let (result, n_input, n_output, n_ends) = self.csv_reader.read_record(
64                *chunk,
65                &mut self.output[self.output_cursor..],
66                &mut self.ends[self.ends_cursor..],
67            );
68            self.output_cursor += n_output;
69            *chunk = &(*chunk)[n_input..];
70            self.ends_cursor += n_ends;
71            match result {
72                // Error cases
73                csv_core::ReadRecordResult::InputEmpty => break Ok(None),
74                csv_core::ReadRecordResult::OutputFull => {
75                    let length = self.output.len();
76                    self.output.extend(std::iter::repeat(0).take(length));
77                }
78                csv_core::ReadRecordResult::OutputEndsFull => {
79                    let length = self.ends.len();
80                    self.ends.extend(std::iter::repeat(0).take(length));
81                }
82                // Success cases
83                csv_core::ReadRecordResult::Record | csv_core::ReadRecordResult::End => {
84                    let result = {
85                        let ends_valid = self.ends_cursor - 1;
86                        if ends_valid == 0 {
87                            break Ok(None);
88                        }
89                        if ends_valid != self.n_cols {
90                            self.events_error += 1;
91                            Err(DecodeErrorKind::Text(
92                                format!(
93                                    "CSV error at record number {}: expected {} columns, got {}.",
94                                    self.total_events(),
95                                    self.n_cols,
96                                    ends_valid
97                                )
98                                .into(),
99                            ))
100                        } else {
101                            match std::str::from_utf8(&self.output[0..self.output_cursor]) {
102                                Ok(output) => {
103                                    self.events_success += 1;
104                                    let mut row_packer = self.row_buf.packer();
105                                    row_packer.extend((0..self.n_cols).map(|i| {
106                                        Datum::String(&output[self.ends[i]..self.ends[i + 1]])
107                                    }));
108                                    self.output_cursor = 0;
109                                    self.ends_cursor = 1;
110                                    Ok(Some(self.row_buf.clone()))
111                                }
112                                Err(e) => {
113                                    self.events_error += 1;
114                                    Err(DecodeErrorKind::Text(
115                                        format!(
116                                            "CSV error at record number {}: invalid UTF-8 ({})",
117                                            self.total_events(),
118                                            e
119                                        )
120                                        .into(),
121                                    ))
122                                }
123                            }
124                        }
125                    };
126
127                    // skip header rows, do not send them into dataflow
128                    if self.next_row_is_header {
129                        self.next_row_is_header = false;
130
131                        if let Ok(Some(row)) = &result {
132                            let mismatched = row
133                                .iter()
134                                .zip_eq(self.header_names.iter().flatten())
135                                .enumerate()
136                                .find(|(_, (actual, expected))| actual.unwrap_str() != &**expected);
137                            if let Some((i, (actual, expected))) = mismatched {
138                                break Err(DecodeErrorKind::Text(
139                                    format!(
140                                        "source file contains incorrect columns '{:?}', \
141                                     first mismatched column at index {} expected={} actual={}",
142                                        row,
143                                        i + 1,
144                                        expected,
145                                        actual
146                                    )
147                                    .into(),
148                                ));
149                            }
150                        }
151                        if chunk.is_empty() {
152                            break Ok(None);
153                        } else if result.is_err() {
154                            break result;
155                        }
156                    } else {
157                        break result;
158                    }
159                }
160            }
161        }
162    }
163}