mz_storage/decode/
csv.rs

1// Copyright Materialize, Inc. and contributors. All rights reserved.
2//
3// Use of this software is governed by the Business Source License
4// included in the LICENSE file.
5//
6// As of the Change Date specified in that file, in accordance with
7// the Business Source License, use of this software will be governed
8// by the Apache License, Version 2.0.
9
10use mz_repr::{Datum, Row};
11use mz_storage_types::errors::DecodeErrorKind;
12use mz_storage_types::sources::encoding::CsvEncoding;
13
14#[derive(Debug)]
15pub struct CsvDecoderState {
16    next_row_is_header: bool,
17    header_names: Option<Vec<String>>,
18    n_cols: usize,
19    output: Vec<u8>,
20    output_cursor: usize,
21    ends: Vec<usize>,
22    ends_cursor: usize,
23    csv_reader: csv_core::Reader,
24    row_buf: Row,
25    events_error: usize,
26    events_success: usize,
27}
28
29impl CsvDecoderState {
30    fn total_events(&self) -> usize {
31        self.events_error + self.events_success
32    }
33
34    pub fn new(format: CsvEncoding) -> Self {
35        let CsvEncoding { columns, delimiter } = format;
36        let n_cols = columns.arity();
37
38        let header_names = columns.into_header_names();
39        Self {
40            next_row_is_header: header_names.is_some(),
41            header_names,
42            n_cols,
43            output: vec![0],
44            output_cursor: 0,
45            ends: vec![0],
46            ends_cursor: 1,
47            csv_reader: csv_core::ReaderBuilder::new().delimiter(delimiter).build(),
48            row_buf: Row::default(),
49            events_error: 0,
50            events_success: 0,
51        }
52    }
53
54    pub fn reset_for_new_object(&mut self) {
55        if self.header_names.is_some() {
56            self.next_row_is_header = true;
57        }
58    }
59
60    pub fn decode(&mut self, chunk: &mut &[u8]) -> Result<Option<Row>, DecodeErrorKind> {
61        loop {
62            let (result, n_input, n_output, n_ends) = self.csv_reader.read_record(
63                *chunk,
64                &mut self.output[self.output_cursor..],
65                &mut self.ends[self.ends_cursor..],
66            );
67            self.output_cursor += n_output;
68            *chunk = &(*chunk)[n_input..];
69            self.ends_cursor += n_ends;
70            match result {
71                // Error cases
72                csv_core::ReadRecordResult::InputEmpty => break Ok(None),
73                csv_core::ReadRecordResult::OutputFull => {
74                    let length = self.output.len();
75                    self.output.extend(std::iter::repeat(0).take(length));
76                }
77                csv_core::ReadRecordResult::OutputEndsFull => {
78                    let length = self.ends.len();
79                    self.ends.extend(std::iter::repeat(0).take(length));
80                }
81                // Success cases
82                csv_core::ReadRecordResult::Record | csv_core::ReadRecordResult::End => {
83                    let result = {
84                        let ends_valid = self.ends_cursor - 1;
85                        if ends_valid == 0 {
86                            break Ok(None);
87                        }
88                        if ends_valid != self.n_cols {
89                            self.events_error += 1;
90                            Err(DecodeErrorKind::Text(
91                                format!(
92                                    "CSV error at record number {}: expected {} columns, got {}.",
93                                    self.total_events(),
94                                    self.n_cols,
95                                    ends_valid
96                                )
97                                .into(),
98                            ))
99                        } else {
100                            match std::str::from_utf8(&self.output[0..self.output_cursor]) {
101                                Ok(output) => {
102                                    self.events_success += 1;
103                                    let mut row_packer = self.row_buf.packer();
104                                    row_packer.extend((0..self.n_cols).map(|i| {
105                                        Datum::String(&output[self.ends[i]..self.ends[i + 1]])
106                                    }));
107                                    self.output_cursor = 0;
108                                    self.ends_cursor = 1;
109                                    Ok(Some(self.row_buf.clone()))
110                                }
111                                Err(e) => {
112                                    self.events_error += 1;
113                                    Err(DecodeErrorKind::Text(
114                                        format!(
115                                            "CSV error at record number {}: invalid UTF-8 ({})",
116                                            self.total_events(),
117                                            e
118                                        )
119                                        .into(),
120                                    ))
121                                }
122                            }
123                        }
124                    };
125
126                    // skip header rows, do not send them into dataflow
127                    if self.next_row_is_header {
128                        self.next_row_is_header = false;
129
130                        if let Ok(Some(row)) = &result {
131                            let mismatched = row
132                                .iter()
133                                .zip(self.header_names.iter().flatten())
134                                .enumerate()
135                                .find(|(_, (actual, expected))| actual.unwrap_str() != &**expected);
136                            if let Some((i, (actual, expected))) = mismatched {
137                                break Err(DecodeErrorKind::Text(
138                                    format!(
139                                        "source file contains incorrect columns '{:?}', \
140                                     first mismatched column at index {} expected={} actual={}",
141                                        row,
142                                        i + 1,
143                                        expected,
144                                        actual
145                                    )
146                                    .into(),
147                                ));
148                            }
149                        }
150                        if chunk.is_empty() {
151                            break Ok(None);
152                        } else if result.is_err() {
153                            break result;
154                        }
155                    } else {
156                        break result;
157                    }
158                }
159            }
160        }
161    }
162}