mz_storage_operators/oneshot_source/csv.rs
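//! A CSV implementation of [`OneshotFormat`]: fetch raw (and possibly
//! compressed) bytes from a [`OneshotSource`] and decode them into [`Row`]s.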
use std::fmt::Debug;
use std::io;
use std::pin::Pin;
use std::sync::Arc;

use futures::TryStreamExt;
use futures::stream::{BoxStream, StreamExt};
use itertools::Itertools;
use mz_pgcopy::CopyCsvFormatParams;
use mz_repr::{Datum, RelationDesc, Row, RowArena};
use serde::{Deserialize, Serialize};
use smallvec::{SmallVec, smallvec};
use tokio_util::io::StreamReader;

use crate::oneshot_source::{
    Encoding, OneshotFormat, OneshotObject, OneshotSource, StorageErrorX, StorageErrorXKind,
};

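/// A [`OneshotFormat`] that decodes CSV data into [`Row`]s.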
#[derive(Debug, Clone)]
pub struct CsvDecoder {
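    /// CSV formatting options, e.g. the delimiter, quote, and escape characters.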
    params: CopyCsvFormatParams<'static>,
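    /// Types of the columns we decode into, used to parse each field.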
    column_types: Arc<[mz_pgrepr::Type]>,
}

impl CsvDecoder {
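    /// Creates a [`CsvDecoder`] from CSV formatting options and the
    /// [`RelationDesc`] of the destination relation.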
    pub fn new(params: CopyCsvFormatParams<'static>, desc: &RelationDesc) -> Self {
        let column_types = desc
            .iter_types()
            .map(|x| &x.scalar_type)
            .map(mz_pgrepr::Type::from)
            .collect();
        CsvDecoder {
            params,
            column_types,
        }
    }
}

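/// A single unit of work: an object to fetch from a [`OneshotSource`] and the
/// encodings with which its bytes should be decoded.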
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct CsvWorkRequest<O, C> {
    object: O,
    checksum: C,
    encodings: SmallVec<[Encoding; 1]>,
}

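/// A single decoded CSV record.
///
/// Follows the layout of [`csv_async::ByteRecord`]: `bytes` is every field of
/// the record concatenated together, and `ranges[i]` is the sub-range of
/// `bytes` covering field `i`. For example, the record `a,bb` is stored as
/// `bytes = b"abb"` with `ranges = [0..1, 1..3]`.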
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct CsvRecord {
    bytes: Vec<u8>,
    ranges: Vec<std::ops::Range<usize>>,
}

impl OneshotFormat for CsvDecoder {
    type WorkRequest<S>
        = CsvWorkRequest<S::Object, S::Checksum>
    where
        S: OneshotSource;
    type RecordChunk = CsvRecord;

    async fn split_work<S: OneshotSource + Send>(
        &self,
        _source: S,
        object: S::Object,
        checksum: S::Checksum,
    ) -> Result<Vec<Self::WorkRequest<S>>, StorageErrorX> {
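        // Infer the object's compression scheme from its file extension; an
        // empty list of encodings means the bytes are treated as uncompressed.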
        let encodings = if object.name().ends_with(".gz") {
            smallvec![Encoding::Gzip]
        } else if object.name().ends_with(".bz2") {
            smallvec![Encoding::Bzip2]
        } else if object.name().ends_with(".xz") {
            smallvec![Encoding::Xz]
        } else if object.name().ends_with(".zst") {
            smallvec![Encoding::Zstd]
        } else {
            smallvec![]
        };

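        // Process the entire object as a single unit of work; CSV objects are
        // not split into smaller requests.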
        let request = CsvWorkRequest {
            object,
            checksum,
            encodings,
        };
        Ok(vec![request])
    }

    fn fetch_work<'a, S: OneshotSource + Sync + 'static>(
        &'a self,
        source: &'a S,
        request: Self::WorkRequest<S>,
    ) -> BoxStream<'a, Result<Self::RecordChunk, StorageErrorX>> {
        let CsvWorkRequest {
            object,
            checksum,
            encodings,
        } = request;

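        // Fetch the object's bytes, mapping any source error onto an
        // `io::Error` so the stream can be wrapped in a `StreamReader` that
        // implements `tokio::io::AsyncRead`.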
        let raw_byte_stream = source
            .get(object, checksum, None)
            .map_err(|e| io::Error::new(io::ErrorKind::Interrupted, format!("{e:?}")));
        let stream_reader = StreamReader::new(raw_byte_stream);

        assert!(encodings.len() <= 1, "TODO support multiple encodings");

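        // If the object is compressed, layer the matching decompression
        // decoder on top of the byte stream.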
        let reader: Pin<Box<dyn tokio::io::AsyncRead + Send>> = if let Some(encoding) =
            encodings.into_iter().next()
        {
            tracing::info!(?encoding, "decompressing byte stream");
            match encoding {
                Encoding::Bzip2 => {
                    let decoder = async_compression::tokio::bufread::BzDecoder::new(stream_reader);
                    Box::pin(decoder)
                }
                Encoding::Gzip => {
                    let decoder =
                        async_compression::tokio::bufread::GzipDecoder::new(stream_reader);
                    Box::pin(decoder)
                }
                Encoding::Xz => {
                    let decoder = async_compression::tokio::bufread::XzDecoder::new(stream_reader);
                    Box::pin(decoder)
                }
                Encoding::Zstd => {
                    let decoder =
                        async_compression::tokio::bufread::ZstdDecoder::new(stream_reader);
                    Box::pin(decoder)
                }
            }
        } else {
            Box::pin(stream_reader)
        };

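        // A quote character equal to the escape character means quotes are
        // escaped CSV-style, by doubling them.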
        let (double_quote, escape) = if self.params.quote == self.params.escape {
            (true, None)
        } else {
            (false, Some(self.params.escape))
        };

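        // Build an async CSV reader over the (possibly decompressed) bytes.
        // `flexible(true)` accepts records with varying numbers of fields;
        // the expected column count is enforced later, in `decode_chunk`.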
        let reader = csv_async::AsyncReaderBuilder::new()
            .delimiter(self.params.delimiter)
            .quote(self.params.quote)
            .has_headers(self.params.header)
            .double_quote(double_quote)
            .escape(escape)
            .flexible(true)
            .create_reader(reader);

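        // Stream back raw records, remembering the byte range of each field
        // so `decode_chunk` can slice them out without re-parsing.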
        reader
            .into_byte_records()
            .map_ok(|record| {
                let bytes = record.as_slice().to_vec();
                let ranges = (0..record.len())
                    .map(|idx| record.range(idx).expect("known to exist"))
                    .collect();
                CsvRecord { bytes, ranges }
            })
            .map_err(|err| StorageErrorXKind::from(err).with_context("csv decoding"))
            .boxed()
    }

    fn decode_chunk(
        &self,
        chunk: Self::RecordChunk,
        rows: &mut Vec<Row>,
    ) -> Result<usize, StorageErrorX> {
        let CsvRecord { bytes, ranges } = chunk;

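        // The reader in `fetch_work` is `flexible`, so enforce the expected
        // number of columns here.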
        if self.column_types.len() != ranges.len() {
            let msg = format!(
                "wrong number of columns, desc: {} record: {}",
                self.column_types.len(),
                ranges.len()
            );
            return Err(StorageErrorXKind::invalid_record_batch(msg).into());
        }

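        // Lazily slice each field back out of the concatenated byte buffer.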
        let str_slices = ranges.into_iter().map(|range| {
            bytes
                .get(range)
                .ok_or_else(|| StorageErrorXKind::programming_error("invalid byte range"))
        });

        let mut row = Row::default();
        let mut packer = row.packer();
        let arena = RowArena::new();

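        // Fields equal to the NULL sentinel become `Datum::Null`; all other
        // fields are decoded from the Postgres text representation of their
        // column type.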
        for (typ, maybe_raw_value) in self.column_types.iter().zip_eq(str_slices) {
            let raw_value = maybe_raw_value?;

            if raw_value == self.params.null.as_bytes() {
                packer.push(Datum::Null);
            } else {
                let value = mz_pgrepr::Value::decode_text(typ, raw_value).map_err(|err| {
                    StorageErrorXKind::invalid_record_batch(err.to_string())
                        .with_context("decode_text")
                })?;
                packer.push(value.into_datum(&arena, typ));
            }
        }

        rows.push(row);

        Ok(1)
    }
}