1use itertools::Itertools;
11use mz_repr::{Datum, Row};
12use mz_storage_types::errors::DecodeErrorKind;
13use mz_storage_types::sources::encoding::CsvEncoding;
14
15#[derive(Debug)]
16pub struct CsvDecoderState {
17 next_row_is_header: bool,
18 header_names: Option<Vec<String>>,
19 n_cols: usize,
20 output: Vec<u8>,
21 output_cursor: usize,
22 ends: Vec<usize>,
23 ends_cursor: usize,
24 csv_reader: csv_core::Reader,
25 row_buf: Row,
26 events_error: usize,
27 events_success: usize,
28}
29
30impl CsvDecoderState {
31 fn total_events(&self) -> usize {
32 self.events_error + self.events_success
33 }
34
35 pub fn new(format: CsvEncoding) -> Self {
36 let CsvEncoding { columns, delimiter } = format;
37 let n_cols = columns.arity();
38
39 let header_names = columns.into_header_names();
40 Self {
41 next_row_is_header: header_names.is_some(),
42 header_names,
43 n_cols,
44 output: vec![0],
45 output_cursor: 0,
46 ends: vec![0],
47 ends_cursor: 1,
48 csv_reader: csv_core::ReaderBuilder::new().delimiter(delimiter).build(),
49 row_buf: Row::default(),
50 events_error: 0,
51 events_success: 0,
52 }
53 }
54
55 pub fn reset_for_new_object(&mut self) {
56 if self.header_names.is_some() {
57 self.next_row_is_header = true;
58 }
59 }
60
61 pub fn decode(&mut self, chunk: &mut &[u8]) -> Result<Option<Row>, DecodeErrorKind> {
62 loop {
63 let (result, n_input, n_output, n_ends) = self.csv_reader.read_record(
64 *chunk,
65 &mut self.output[self.output_cursor..],
66 &mut self.ends[self.ends_cursor..],
67 );
68 self.output_cursor += n_output;
69 *chunk = &(*chunk)[n_input..];
70 self.ends_cursor += n_ends;
71 match result {
72 csv_core::ReadRecordResult::InputEmpty => break Ok(None),
74 csv_core::ReadRecordResult::OutputFull => {
75 let length = self.output.len();
76 self.output.extend(std::iter::repeat(0).take(length));
77 }
78 csv_core::ReadRecordResult::OutputEndsFull => {
79 let length = self.ends.len();
80 self.ends.extend(std::iter::repeat(0).take(length));
81 }
82 csv_core::ReadRecordResult::Record | csv_core::ReadRecordResult::End => {
84 let result = {
85 let ends_valid = self.ends_cursor - 1;
86 if ends_valid == 0 {
87 break Ok(None);
88 }
89 if ends_valid != self.n_cols {
90 self.events_error += 1;
91 Err(DecodeErrorKind::Text(
92 format!(
93 "CSV error at record number {}: expected {} columns, got {}.",
94 self.total_events(),
95 self.n_cols,
96 ends_valid
97 )
98 .into(),
99 ))
100 } else {
101 match std::str::from_utf8(&self.output[0..self.output_cursor]) {
102 Ok(output) => {
103 self.events_success += 1;
104 let mut row_packer = self.row_buf.packer();
105 row_packer.extend((0..self.n_cols).map(|i| {
106 Datum::String(&output[self.ends[i]..self.ends[i + 1]])
107 }));
108 self.output_cursor = 0;
109 self.ends_cursor = 1;
110 Ok(Some(self.row_buf.clone()))
111 }
112 Err(e) => {
113 self.events_error += 1;
114 Err(DecodeErrorKind::Text(
115 format!(
116 "CSV error at record number {}: invalid UTF-8 ({})",
117 self.total_events(),
118 e
119 )
120 .into(),
121 ))
122 }
123 }
124 }
125 };
126
127 if self.next_row_is_header {
129 self.next_row_is_header = false;
130
131 if let Ok(Some(row)) = &result {
132 let mismatched = row
133 .iter()
134 .zip_eq(self.header_names.iter().flatten())
135 .enumerate()
136 .find(|(_, (actual, expected))| actual.unwrap_str() != &**expected);
137 if let Some((i, (actual, expected))) = mismatched {
138 break Err(DecodeErrorKind::Text(
139 format!(
140 "source file contains incorrect columns '{:?}', \
141 first mismatched column at index {} expected={} actual={}",
142 row,
143 i + 1,
144 expected,
145 actual
146 )
147 .into(),
148 ));
149 }
150 }
151 if chunk.is_empty() {
152 break Ok(None);
153 } else if result.is_err() {
154 break result;
155 }
156 } else {
157 break result;
158 }
159 }
160 }
161 }
162 }
163}