1use std::borrow::Cow;
11use std::io;
12
13use bytes::BytesMut;
14use itertools::Itertools;
15use mz_repr::{
16 Datum, RelationDesc, Row, RowArena, RowRef, SharedRow, SqlColumnType, SqlRelationType,
17 SqlScalarType,
18};
19use proptest::prelude::{Arbitrary, any};
20use proptest::strategy::{BoxedStrategy, Strategy};
21use serde::{Deserialize, Serialize};
22
/// The `\.` sequence that marks the end of `COPY` data. The text parser stops
/// at it, and CSV decoding stops at a record consisting solely of it unquoted.
static END_OF_COPY_MARKER: &[u8] = b"\\.";
24
25fn encode_copy_row_binary(
26 row: &RowRef,
27 typ: &SqlRelationType,
28 out: &mut Vec<u8>,
29) -> Result<(), io::Error> {
30 const NULL_BYTES: [u8; 4] = (-1i32).to_be_bytes();
31
32 let count = i16::try_from(typ.column_types.len()).map_err(|_| {
34 io::Error::new(
35 io::ErrorKind::InvalidData,
36 "column count does not fit into an i16",
37 )
38 })?;
39
40 out.extend(count.to_be_bytes());
41 let mut buf = BytesMut::new();
42 for (field, typ) in row
43 .iter()
44 .zip_eq(&typ.column_types)
45 .map(|(datum, typ)| (mz_pgrepr::Value::from_datum(datum, &typ.scalar_type), typ))
46 {
47 match field {
48 None => out.extend(NULL_BYTES),
49 Some(field) => {
50 buf.clear();
51 field.encode_binary(&mz_pgrepr::Type::from(&typ.scalar_type), &mut buf)?;
52 out.extend(
53 i32::try_from(buf.len())
54 .map_err(|_| {
55 io::Error::new(
56 io::ErrorKind::InvalidData,
57 "field length does not fit into an i32",
58 )
59 })?
60 .to_be_bytes(),
61 );
62 out.extend(&buf);
63 }
64 }
65 }
66 Ok(())
67}
68
69fn encode_copy_row_text(
70 CopyTextFormatParams { null, delimiter }: &CopyTextFormatParams,
71 row: &RowRef,
72 typ: &SqlRelationType,
73 out: &mut Vec<u8>,
74) -> Result<(), io::Error> {
75 let null = null.as_bytes();
76 let mut buf = BytesMut::new();
77 for (idx, field) in mz_pgrepr::values_from_row(row, typ).into_iter().enumerate() {
78 if idx > 0 {
79 out.push(*delimiter);
80 }
81 match field {
82 None => out.extend(null),
83 Some(field) => {
84 buf.clear();
85 field.encode_text(&mut buf);
86 for b in &buf {
87 match b {
88 b'\\' => out.extend(b"\\\\"),
89 b'\n' => out.extend(b"\\n"),
90 b'\r' => out.extend(b"\\r"),
91 b'\t' => out.extend(b"\\t"),
92 _ => out.push(*b),
93 }
94 }
95 }
96 }
97 }
98 out.push(b'\n');
99 Ok(())
100}
101
102fn encode_copy_row_csv(
103 CopyCsvFormatParams {
104 delimiter: delim,
105 quote,
106 escape,
107 header: _,
108 null,
109 }: &CopyCsvFormatParams,
110 row: &RowRef,
111 typ: &SqlRelationType,
112 out: &mut Vec<u8>,
113) -> Result<(), io::Error> {
114 let null = null.as_bytes();
115 let is_special = |c: &u8| *c == *delim || *c == *quote || *c == b'\r' || *c == b'\n';
116 let mut buf = BytesMut::new();
117 for (idx, field) in mz_pgrepr::values_from_row(row, typ).into_iter().enumerate() {
118 if idx > 0 {
119 out.push(*delim);
120 }
121 match field {
122 None => out.extend(null),
123 Some(field) => {
124 buf.clear();
125 field.encode_text(&mut buf);
126 if (typ.column_types.len() == 1 && buf == END_OF_COPY_MARKER)
132 || buf.iter().any(is_special)
133 || &*buf == null
134 {
135 out.push(*quote);
139 for b in &buf {
140 if *b == *quote || *b == *escape {
141 out.push(*escape);
142 }
143 out.push(*b);
144 }
145 out.push(*quote);
146 } else {
147 out.extend(&buf);
150 }
151 }
152 }
153 }
154 out.push(b'\n');
155 Ok(())
156}
157
/// Parser for `COPY` text-format data held entirely in memory.
///
/// It walks `data` byte by byte and provides primitives to read one column
/// value at a time. It can also check for column delimiters, line ends, and
/// the `\.` end-of-copy marker.
pub struct CopyTextFormatParser<'a> {
    /// The complete input being parsed.
    data: &'a [u8],
    /// Index into `data` of the next unconsumed byte (never past `data.len()`).
    position: usize,
    /// Byte that separates columns within a line.
    column_delimiter: u8,
    /// Raw text that denotes a SQL `NULL` value.
    null_string: &'a str,
    /// Scratch space holding a value's bytes after backslash escapes have been
    /// decoded; reused across values.
    buffer: Vec<u8>,
}
165
166impl<'a> CopyTextFormatParser<'a> {
167 pub fn new(data: &'a [u8], column_delimiter: u8, null_string: &'a str) -> Self {
168 Self {
169 data,
170 position: 0,
171 column_delimiter,
172 null_string,
173 buffer: Vec::new(),
174 }
175 }
176
177 fn peek(&self) -> Option<u8> {
178 if self.position < self.data.len() {
179 Some(self.data[self.position])
180 } else {
181 None
182 }
183 }
184
185 fn consume_n(&mut self, n: usize) {
186 self.position = std::cmp::min(self.position + n, self.data.len());
187 }
188
189 pub fn is_eof(&self) -> bool {
190 self.peek().is_none() || self.is_end_of_copy_marker()
191 }
192
193 pub fn is_end_of_copy_marker(&self) -> bool {
194 self.check_bytes(END_OF_COPY_MARKER)
195 }
196
197 fn is_end_of_line(&self) -> bool {
198 match self.peek() {
199 Some(b'\n') | None => true,
200 _ => false,
201 }
202 }
203
204 pub fn expect_end_of_line(&mut self) -> Result<(), io::Error> {
205 if self.is_end_of_line() {
206 self.consume_n(1);
207 Ok(())
208 } else {
209 Err(io::Error::new(
210 io::ErrorKind::InvalidData,
211 "extra data after last expected column",
212 ))
213 }
214 }
215
216 fn is_column_delimiter(&self) -> bool {
217 self.check_bytes(&[self.column_delimiter])
218 }
219
220 pub fn expect_column_delimiter(&mut self) -> Result<(), io::Error> {
221 if self.consume_bytes(&[self.column_delimiter]) {
222 Ok(())
223 } else {
224 Err(io::Error::new(
225 io::ErrorKind::InvalidData,
226 "missing data for column",
227 ))
228 }
229 }
230
231 fn check_bytes(&self, bytes: &[u8]) -> bool {
232 self.data
233 .get(self.position..self.position + bytes.len())
234 .map_or(false, |d| d == bytes)
235 }
236
237 fn consume_bytes(&mut self, bytes: &[u8]) -> bool {
238 if self.check_bytes(bytes) {
239 self.consume_n(bytes.len());
240 true
241 } else {
242 false
243 }
244 }
245
246 fn consume_null_string(&mut self) -> bool {
247 if self.null_string.is_empty() {
248 self.is_column_delimiter()
251 || self.is_end_of_line()
252 || self.is_end_of_copy_marker()
253 || self.is_eof()
254 } else {
255 self.consume_bytes(self.null_string.as_bytes())
256 }
257 }
258
259 pub fn consume_raw_value(&mut self) -> Result<Option<&[u8]>, io::Error> {
260 if self.consume_null_string() {
261 return Ok(None);
262 }
263
264 let mut start = self.position;
265
266 self.buffer.clear();
268
269 while !self.is_eof() && !self.is_end_of_copy_marker() {
270 if self.is_end_of_line() || self.is_column_delimiter() {
271 break;
272 }
273 match self.peek() {
274 Some(b'\\') => {
275 self.buffer.extend(&self.data[start..self.position]);
277
278 self.consume_n(1);
279 match self.peek() {
280 Some(b'b') => {
281 self.consume_n(1);
282 self.buffer.push(8);
283 }
284 Some(b'f') => {
285 self.consume_n(1);
286 self.buffer.push(12);
287 }
288 Some(b'n') => {
289 self.consume_n(1);
290 self.buffer.push(b'\n');
291 }
292 Some(b'r') => {
293 self.consume_n(1);
294 self.buffer.push(b'\r');
295 }
296 Some(b't') => {
297 self.consume_n(1);
298 self.buffer.push(b'\t');
299 }
300 Some(b'v') => {
301 self.consume_n(1);
302 self.buffer.push(11);
303 }
304 Some(b'x') => {
305 self.consume_n(1);
306 match self.peek() {
307 Some(_c @ b'0'..=b'9')
308 | Some(_c @ b'A'..=b'F')
309 | Some(_c @ b'a'..=b'f') => {
310 let mut value: u8 = 0;
311 let decode_nibble = |b| match b {
312 Some(c @ b'a'..=b'f') => Some(c - b'a' + 10),
313 Some(c @ b'A'..=b'F') => Some(c - b'A' + 10),
314 Some(c @ b'0'..=b'9') => Some(c - b'0'),
315 _ => None,
316 };
317 for _ in 0..2 {
318 match decode_nibble(self.peek()) {
319 Some(c) => {
320 self.consume_n(1);
321 value = (value << 4) | c;
322 }
323 _ => break,
324 }
325 }
326 self.buffer.push(value);
327 }
328 _ => {
329 self.buffer.push(b'x');
330 }
331 }
332 }
333 Some(_c @ b'0'..=b'7') => {
334 let mut value: u8 = 0;
335 for _ in 0..3 {
336 match self.peek() {
337 Some(c @ b'0'..=b'7') => {
338 self.consume_n(1);
339 value = (value << 3) | (c - b'0');
340 }
341 _ => break,
342 }
343 }
344 self.buffer.push(value);
345 }
346 Some(c) => {
347 self.consume_n(1);
348 self.buffer.push(c);
349 }
350 None => {
351 self.buffer.push(b'\\');
352 }
353 }
354
355 start = self.position;
356 }
357 Some(_) => {
358 self.consume_n(1);
359 }
360 None => {}
361 }
362 }
363
364 if self.buffer.is_empty() {
366 Ok(Some(&self.data[start..self.position]))
367 } else {
368 self.buffer.extend(&self.data[start..self.position]);
371 Ok(Some(&self.buffer[..]))
372 }
373 }
374
375 pub fn iter_raw(self, num_columns: usize) -> RawIterator<'a> {
377 RawIterator {
378 parser: self,
379 current_column: 0,
380 num_columns,
381 truncate: false,
382 }
383 }
384
385 pub fn iter_raw_truncating(self, num_columns: usize) -> RawIterator<'a> {
387 RawIterator {
388 parser: self,
389 current_column: 0,
390 num_columns,
391 truncate: true,
392 }
393 }
394}
395
/// Reader over the raw column values of a single line, created by
/// [`CopyTextFormatParser::iter_raw`] or
/// [`CopyTextFormatParser::iter_raw_truncating`].
pub struct RawIterator<'a> {
    /// Parser positioned within the line being read.
    parser: CopyTextFormatParser<'a>,
    /// Number of column values handed out so far.
    current_column: usize,
    /// Number of column values expected on the line.
    num_columns: usize,
    /// When `true`, data after the last expected column is not checked for.
    truncate: bool,
}
402
403impl<'a> RawIterator<'a> {
404 pub fn next(&mut self) -> Option<Result<Option<&[u8]>, io::Error>> {
405 if self.current_column > self.num_columns {
406 return None;
407 }
408
409 if self.current_column == self.num_columns {
410 if !self.truncate {
411 if let Some(err) = self.parser.expect_end_of_line().err() {
412 return Some(Err(err));
413 }
414 }
415
416 return None;
417 }
418
419 if self.current_column > 0 {
420 if let Some(err) = self.parser.expect_column_delimiter().err() {
421 return Some(Err(err));
422 }
423 }
424
425 self.current_column += 1;
426 Some(self.parser.consume_raw_value())
427 }
428}
429
/// Format options of a `COPY` statement, one variant per data format.
#[derive(Debug, Eq, PartialEq, Clone, Serialize, Deserialize)]
pub enum CopyFormatParams<'a> {
    /// Text format and its options.
    Text(CopyTextFormatParams<'a>),
    /// CSV format and its options.
    Csv(CopyCsvFormatParams<'a>),
    /// Binary format; rows can be encoded but `decode_copy_format` rejects it.
    Binary,
    /// Parquet; the encode/decode functions in this module reject it.
    Parquet,
}
437
438impl CopyFormatParams<'static> {
439 pub fn file_extension(&self) -> &str {
440 match self {
441 &CopyFormatParams::Text(_) => "txt",
442 &CopyFormatParams::Csv(_) => "csv",
443 &CopyFormatParams::Binary => "bin",
444 &CopyFormatParams::Parquet => "parquet",
445 }
446 }
447
448 pub fn requires_header(&self) -> bool {
449 match self {
450 CopyFormatParams::Text(_) => false,
451 CopyFormatParams::Csv(params) => params.header,
452 CopyFormatParams::Binary => false,
453 CopyFormatParams::Parquet => false,
454 }
455 }
456}
457
458pub fn decode_copy_format<'a>(
460 data: &[u8],
461 column_types: &[mz_pgrepr::Type],
462 params: CopyFormatParams<'a>,
463) -> Result<Vec<Row>, io::Error> {
464 match params {
465 CopyFormatParams::Text(params) => decode_copy_format_text(data, column_types, params),
466 CopyFormatParams::Csv(params) => decode_copy_format_csv(data, column_types, params),
467 CopyFormatParams::Binary => Err(io::Error::new(
468 io::ErrorKind::Unsupported,
469 "cannot decode as binary format",
470 )),
471 CopyFormatParams::Parquet => {
472 Err(io::Error::new(io::ErrorKind::Unsupported, "parquet format"))
474 }
475 }
476}
477
478pub fn encode_copy_format<'a>(
480 params: &CopyFormatParams<'a>,
481 row: &RowRef,
482 typ: &SqlRelationType,
483 out: &mut Vec<u8>,
484) -> Result<(), io::Error> {
485 match params {
486 CopyFormatParams::Text(params) => encode_copy_row_text(params, row, typ, out),
487 CopyFormatParams::Csv(params) => encode_copy_row_csv(params, row, typ, out),
488 CopyFormatParams::Binary => encode_copy_row_binary(row, typ, out),
489 CopyFormatParams::Parquet => {
490 Err(io::Error::new(io::ErrorKind::Unsupported, "parquet format"))
492 }
493 }
494}
495
496pub fn encode_copy_format_header<'a>(
497 params: &CopyFormatParams<'a>,
498 desc: &RelationDesc,
499 out: &mut Vec<u8>,
500) -> Result<(), io::Error> {
501 match params {
502 CopyFormatParams::Text(_) => Ok(()),
503 CopyFormatParams::Binary => Ok(()),
504 CopyFormatParams::Csv(params) => {
505 let mut header_row = Row::with_capacity(desc.arity());
506 header_row
507 .packer()
508 .extend(desc.iter_names().map(|s| Datum::from(s.as_str())));
509 let typ = SqlRelationType::new(vec![
510 SqlColumnType {
511 scalar_type: SqlScalarType::String,
512 nullable: false,
513 };
514 desc.arity()
515 ]);
516 encode_copy_row_csv(params, &header_row, &typ, out)
517 }
518 CopyFormatParams::Parquet => {
519 Err(io::Error::new(io::ErrorKind::Unsupported, "parquet format"))
521 }
522 }
523}
524
/// Options for the `COPY` text format.
#[derive(Debug, Eq, PartialEq, Clone, Serialize, Deserialize)]
pub struct CopyTextFormatParams<'a> {
    /// Text written for, and recognized as, a SQL `NULL` value.
    pub null: Cow<'a, str>,
    /// Byte that separates columns within a line.
    pub delimiter: u8,
}
530
531impl<'a> Default for CopyTextFormatParams<'a> {
532 fn default() -> Self {
533 CopyTextFormatParams {
534 delimiter: b'\t',
535 null: Cow::from("\\N"),
536 }
537 }
538}
539
540pub fn decode_copy_format_text(
541 data: &[u8],
542 column_types: &[mz_pgrepr::Type],
543 CopyTextFormatParams { null, delimiter }: CopyTextFormatParams,
544) -> Result<Vec<Row>, io::Error> {
545 let mut rows = Vec::new();
546
547 let mut parser = CopyTextFormatParser::new(data, delimiter, &null);
549 while !parser.is_eof() && !parser.is_end_of_copy_marker() {
550 let mut row = Vec::new();
551 let buf = RowArena::new();
552 for (col, typ) in column_types.iter().enumerate() {
553 if col > 0 {
554 parser.expect_column_delimiter()?;
555 }
556 let raw_value = parser.consume_raw_value()?;
557 if let Some(raw_value) = raw_value {
558 match mz_pgrepr::Value::decode_text(typ, raw_value) {
559 Ok(value) => {
560 row.push(
561 value
562 .into_datum_decode_error(&buf, typ, "column")
563 .map_err(|msg| io::Error::new(io::ErrorKind::InvalidData, msg))?,
564 );
565 }
566 Err(err) => {
567 let msg = format!("unable to decode column: {}", err);
568 return Err(io::Error::new(io::ErrorKind::InvalidData, msg));
569 }
570 }
571 } else {
572 row.push(Datum::Null);
573 }
574 }
575 parser.expect_end_of_line()?;
576 rows.push(Row::pack(row));
577 }
578 Ok(rows)
581}
582
/// Options for the `COPY` CSV format.
#[derive(Debug, Eq, PartialEq, Clone, Serialize, Deserialize)]
pub struct CopyCsvFormatParams<'a> {
    /// Byte that separates fields within a record.
    pub delimiter: u8,
    /// Byte that surrounds quoted fields.
    pub quote: u8,
    /// Byte that escapes the quote (and itself) inside quoted fields.
    pub escape: u8,
    /// Whether the data has a header row (skipped when decoding).
    pub header: bool,
    /// Unquoted text that denotes a SQL `NULL`; a quoted match is not `NULL`.
    pub null: Cow<'a, str>,
}
591
592impl<'a> CopyCsvFormatParams<'a> {
593 pub fn to_owned(&self) -> CopyCsvFormatParams<'static> {
594 CopyCsvFormatParams {
595 delimiter: self.delimiter,
596 quote: self.quote,
597 escape: self.escape,
598 header: self.header,
599 null: Cow::Owned(self.null.to_string()),
600 }
601 }
602}
603
604impl Arbitrary for CopyCsvFormatParams<'static> {
605 type Parameters = ();
606 type Strategy = BoxedStrategy<Self>;
607
608 fn arbitrary_with(_args: Self::Parameters) -> Self::Strategy {
609 (
610 any::<u8>(),
611 any::<u8>(),
612 any::<u8>(),
613 any::<bool>(),
614 any::<String>(),
615 )
616 .prop_map(|(delimiter, diff, escape, header, null)| {
617 let diff = diff.saturating_sub(1).max(1);
619 let quote = delimiter.wrapping_add(diff);
620
621 Self::try_new(
622 Some(delimiter),
623 Some(quote),
624 Some(escape),
625 Some(header),
626 Some(null),
627 )
628 .expect("delimiter and quote should be different")
629 })
630 .boxed()
631 }
632}
633
634impl<'a> Default for CopyCsvFormatParams<'a> {
635 fn default() -> Self {
636 CopyCsvFormatParams {
637 delimiter: b',',
638 quote: b'"',
639 escape: b'"',
640 header: false,
641 null: Cow::from(""),
642 }
643 }
644}
645
646impl<'a> CopyCsvFormatParams<'a> {
647 pub fn try_new(
648 delimiter: Option<u8>,
649 quote: Option<u8>,
650 escape: Option<u8>,
651 header: Option<bool>,
652 null: Option<String>,
653 ) -> Result<CopyCsvFormatParams<'a>, String> {
654 let mut params = CopyCsvFormatParams::default();
655
656 if let Some(delimiter) = delimiter {
657 params.delimiter = delimiter;
658 }
659 if let Some(quote) = quote {
660 params.quote = quote;
661 params.escape = quote;
663 }
664 if let Some(escape) = escape {
665 params.escape = escape;
666 }
667 if let Some(header) = header {
668 params.header = header;
669 }
670 if let Some(null) = null {
671 params.null = Cow::from(null);
672 }
673
674 if params.quote == params.delimiter {
675 return Err("COPY delimiter and quote must be different".to_string());
676 }
677 Ok(params)
678 }
679}
680
/// Location of one decoded CSV field within the `output` buffer of
/// `decode_copy_format_csv`.
struct DecodedField {
    /// Start offset (inclusive) of the field's bytes in `output`.
    start: usize,
    /// End offset (exclusive) of the field's bytes in `output`.
    end: usize,
    /// Whether the field began with the quote character in the input. Quoted
    /// fields are never treated as `NULL` or as the end-of-copy marker.
    quoted: bool,
}
693
694pub fn decode_copy_format_csv(
695 data: &[u8],
696 column_types: &[mz_pgrepr::Type],
697 CopyCsvFormatParams {
698 delimiter,
699 quote,
700 escape,
701 null,
702 header,
703 }: CopyCsvFormatParams,
704) -> Result<Vec<Row>, io::Error> {
705 let (double_quote, escape) = if quote == escape {
706 (true, None)
707 } else {
708 (false, Some(escape))
709 };
710
711 let mut rdr = csv_core::ReaderBuilder::new()
712 .delimiter(delimiter)
713 .quote(quote)
714 .escape(escape)
715 .double_quote(double_quote)
716 .build();
717
718 let null_as_bytes = null.as_bytes();
719 let mut rows = Vec::new();
720 let mut input = data;
727 let mut output = vec![0u8; data.len().max(1024)];
728 let mut out_pos = 0;
729 let mut fields: Vec<DecodedField> = Vec::new();
730 let mut field_start = 0;
731 let mut field_quoted: Option<bool> = None;
732 let mut skip_header = header;
733 let mut at_record_start = true;
737
738 loop {
739 if field_quoted.is_none() {
740 if at_record_start {
741 while matches!(input.first(), Some(&b'\r') | Some(&b'\n')) {
754 input = &input[1..];
755 }
756 }
757 field_quoted = Some(input.first() == Some("e));
758 field_start = out_pos;
759 }
760
761 let (result, nin, nout) = rdr.read_field(input, &mut output[out_pos..]);
762 input = &input[nin..];
763 out_pos += nout;
764
765 match result {
766 csv_core::ReadFieldResult::Field { record_end } => {
767 fields.push(DecodedField {
768 start: field_start,
769 end: out_pos,
770 quoted: field_quoted.take().unwrap(),
771 });
772 at_record_start = record_end;
775
776 if record_end {
777 if skip_header {
778 skip_header = false;
779 } else if let [
780 DecodedField {
781 start,
782 end,
783 quoted: false,
784 },
785 ] = fields[..]
786 && &output[start..end] == END_OF_COPY_MARKER
787 {
788 return Ok(rows);
792 } else {
793 match fields.len().cmp(&column_types.len()) {
794 std::cmp::Ordering::Less => {
795 return Err(io::Error::new(
796 io::ErrorKind::InvalidData,
797 "missing data for column",
798 ));
799 }
800 std::cmp::Ordering::Greater => {
801 return Err(io::Error::new(
802 io::ErrorKind::InvalidData,
803 "extra data after last expected column",
804 ));
805 }
806 std::cmp::Ordering::Equal => {}
807 }
808
809 let mut row_builder = SharedRow::get();
810 let mut row_packer = row_builder.packer();
811 for (typ, field) in column_types.iter().zip_eq(fields.iter()) {
812 let raw_value = &output[field.start..field.end];
813 if !field.quoted && raw_value == null_as_bytes {
814 row_packer.push(Datum::Null);
815 } else {
816 let s = match std::str::from_utf8(raw_value) {
817 Ok(s) => s,
818 Err(err) => {
819 let msg = format!("invalid utf8 data in column: {}", err);
820 return Err(io::Error::new(
821 io::ErrorKind::InvalidData,
822 msg,
823 ));
824 }
825 };
826 if let Err(err) =
827 mz_pgrepr::Value::decode_text_into_row(typ, s, &mut row_packer)
828 {
829 let msg = format!("unable to decode column: {}", err);
830 return Err(io::Error::new(io::ErrorKind::InvalidData, msg));
831 }
832 }
833 }
834 rows.push(row_builder.clone());
835 }
836 fields.clear();
837 out_pos = 0;
838 }
839 }
840 csv_core::ReadFieldResult::OutputFull => {
841 let new_len = output.len().saturating_mul(2).max(out_pos + 1);
842 output.resize(new_len, 0);
843 }
844 csv_core::ReadFieldResult::InputEmpty => {
845 }
851 csv_core::ReadFieldResult::End => return Ok(rows),
852 }
853 }
854}
855
#[cfg(test)]
mod tests {
    use mz_ore::collections::CollectionExt;
    use mz_repr::SqlColumnType;
    use proptest::prelude::*;

    use super::*;

    #[mz_ore::test]
    fn test_copy_format_text_parser() {
        let text = "\t\\nt e\t\\N\t\n\\x60\\xA\\x7D\\x4a\n\\44\\044\\123".as_bytes();
        let mut parser = CopyTextFormatParser::new(text, b'\t', "\\N");
        assert!(parser.is_column_delimiter());
        parser
            .expect_column_delimiter()
            .expect("expected column delimiter");
        assert_eq!(
            parser
                .consume_raw_value()
                .expect("unexpected error")
                .expect("unexpected empty result"),
            "\nt e".as_bytes()
        );
        parser
            .expect_column_delimiter()
            .expect("expected column delimiter");
        // The `\N` field is the null string.
        assert!(
            parser
                .consume_raw_value()
                .expect("unexpected error")
                .is_none()
        );
        parser
            .expect_column_delimiter()
            .expect("expected column delimiter");
        assert!(parser.is_end_of_line());
        parser.expect_end_of_line().expect("expected eol");
        // Hex escapes with one and two digits.
        assert_eq!(
            parser
                .consume_raw_value()
                .expect("unexpected error")
                .expect("unexpected empty result"),
            "`\n}J".as_bytes()
        );
        parser.expect_end_of_line().expect("expected eol");
        // Octal escapes with two and three digits.
        assert_eq!(
            parser
                .consume_raw_value()
                .expect("unexpected error")
                .expect("unexpected empty result"),
            "$$S".as_bytes()
        );
        assert!(parser.is_eof());
    }

    #[mz_ore::test]
    fn test_copy_format_text_empty_null_string() {
        let text = "\t\n10\t20\n30\t\n40\t".as_bytes();
        let expect = vec![
            vec![None, None],
            vec![Some("10"), Some("20")],
            vec![Some("30"), None],
            vec![Some("40"), None],
        ];
        let mut parser = CopyTextFormatParser::new(text, b'\t', "");
        for line in expect {
            for (i, value) in line.iter().enumerate() {
                if i > 0 {
                    parser
                        .expect_column_delimiter()
                        .expect("expected column delimiter");
                }
                match value {
                    Some(s) => {
                        assert!(!parser.consume_null_string());
                        assert_eq!(
                            parser
                                .consume_raw_value()
                                .expect("unexpected error")
                                .expect("unexpected empty result"),
                            s.as_bytes()
                        );
                    }
                    None => {
                        assert!(parser.consume_null_string());
                    }
                }
            }
            parser.expect_end_of_line().expect("expected eol");
        }
    }

    #[mz_ore::test]
    fn test_copy_format_text_parser_escapes() {
        struct TestCase {
            input: &'static str,
            expect: &'static [u8],
        }
        let tests = vec![
            TestCase {
                input: "simple",
                expect: b"simple",
            },
            TestCase {
                input: r#"new\nline"#,
                expect: b"new\nline",
            },
            TestCase {
                input: r#"\b\f\n\r\t\v\\"#,
                expect: b"\x08\x0c\n\r\t\x0b\\",
            },
            TestCase {
                input: r#"\0\12\123"#,
                expect: &[0, 0o12, 0o123],
            },
            TestCase {
                input: r#"\x1\xaf"#,
                expect: &[0x01, 0xaf],
            },
            TestCase {
                input: r#"T\n\07\xEV\x0fA\xb2C\1"#,
                expect: b"T\n\x07\x0eV\x0fA\xb2C\x01",
            },
            TestCase {
                input: r#"\\\""#,
                expect: b"\\\"",
            },
            TestCase {
                input: r#"\x"#,
                expect: b"x",
            },
            TestCase {
                input: r#"\xg"#,
                expect: b"xg",
            },
            TestCase {
                input: r#"\"#,
                expect: b"\\",
            },
            TestCase {
                input: r#"\8"#,
                expect: b"8",
            },
            TestCase {
                input: r#"\a"#,
                expect: b"a",
            },
            TestCase {
                input: r#"\x\xg\8\xH\x32\s\"#,
                expect: b"xxg8xH2s\\",
            },
        ];

        for test in tests {
            let mut parser = CopyTextFormatParser::new(test.input.as_bytes(), b'\t', "\\N");
            assert_eq!(
                parser
                    .consume_raw_value()
                    .expect("unexpected error")
                    .expect("unexpected empty result"),
                test.expect,
                "input: {}, expect: {:?}",
                test.input,
                std::str::from_utf8(test.expect),
            );
            assert!(parser.is_eof());
        }
    }

    #[mz_ore::test]
    fn test_copy_csv_format_params() {
        assert_eq!(
            CopyCsvFormatParams::try_new(Some(b't'), Some(b'q'), None, None, None),
            Ok(CopyCsvFormatParams {
                delimiter: b't',
                quote: b'q',
                escape: b'q',
                header: false,
                null: Cow::from(""),
            })
        );

        assert_eq!(
            CopyCsvFormatParams::try_new(
                Some(b't'),
                Some(b'q'),
                Some(b'e'),
                Some(true),
                Some("null".to_string())
            ),
            Ok(CopyCsvFormatParams {
                delimiter: b't',
                quote: b'q',
                escape: b'e',
                header: true,
                null: Cow::from("null"),
            })
        );

        assert_eq!(
            CopyCsvFormatParams::try_new(
                None,
                Some(b','),
                Some(b'e'),
                Some(true),
                Some("null".to_string())
            ),
            Err("COPY delimiter and quote must be different".to_string())
        );
    }

    #[mz_ore::test]
    fn test_copy_csv_row() -> Result<(), io::Error> {
        let mut row = Row::default();
        let mut packer = row.packer();
        packer.push(Datum::from("1,2,\"3\""));
        packer.push(Datum::Null);
        packer.push(Datum::from(1000u64));
        packer.push(Datum::from("qe"));
        packer.push(Datum::from(""));

        let typ: SqlRelationType = SqlRelationType::new(vec![
            SqlColumnType {
                scalar_type: mz_repr::SqlScalarType::String,
                nullable: false,
            },
            SqlColumnType {
                scalar_type: mz_repr::SqlScalarType::String,
                nullable: true,
            },
            SqlColumnType {
                scalar_type: mz_repr::SqlScalarType::UInt64,
                nullable: false,
            },
            SqlColumnType {
                scalar_type: mz_repr::SqlScalarType::String,
                nullable: false,
            },
            SqlColumnType {
                scalar_type: mz_repr::SqlScalarType::String,
                nullable: false,
            },
        ]);

        let mut out = Vec::new();

        struct TestCase<'a> {
            params: CopyCsvFormatParams<'a>,
            expected: &'static [u8],
        }

        let tests = [
            TestCase {
                params: CopyCsvFormatParams::default(),
                expected: b"\"1,2,\"\"3\"\"\",,1000,qe,\"\"\n",
            },
            TestCase {
                params: CopyCsvFormatParams {
                    null: Cow::from("NULL"),
                    quote: b'q',
                    escape: b'e',
                    ..Default::default()
                },
                expected: b"q1,2,\"3\"q,NULL,1000,qeqeeq,\n",
            },
        ];

        for TestCase { params, expected } in tests {
            out.clear();
            let params = CopyFormatParams::Csv(params);
            // Propagate encoding failures instead of silently comparing
            // whatever partial output was produced.
            encode_copy_format(&params, &row, &typ, &mut out)?;
            let output = std::str::from_utf8(&out);
            assert_eq!(output, std::str::from_utf8(expected));
        }

        Ok(())
    }

    #[mz_ore::test]
    fn test_decode_copy_format_csv_end_marker() {
        let column_types = vec![mz_pgrepr::Type::from(&mz_repr::SqlScalarType::String)];

        let decode_strings = |input: &[u8]| -> Vec<String> {
            decode_copy_format_csv(input, &column_types, CopyCsvFormatParams::default())
                .expect("decode should succeed")
                .iter()
                .map(|r| match r.iter().next().unwrap() {
                    Datum::String(s) => s.to_owned(),
                    d => panic!("unexpected datum: {:?}", d),
                })
                .collect()
        };

        for eol in [&b"\n"[..], b"\r\n", b"\r"] {
            let join = |lines: &[&str]| -> Vec<u8> {
                let mut out = Vec::new();
                for line in lines {
                    out.extend_from_slice(line.as_bytes());
                    out.extend_from_slice(eol);
                }
                out
            };

            // A quoted `\.` is ordinary data.
            assert_eq!(
                decode_strings(&join(&["before", "\"\\.\"", "after"])),
                vec!["before", "\\.", "after"],
                "quoted marker, eol={eol:?}"
            );

            // A bare `\.` ends the data; anything after it is ignored.
            assert_eq!(
                decode_strings(&join(&["first", "\\.", "ignored"])),
                vec!["first"],
                "bare marker, eol={eol:?}"
            );
        }
    }

    #[mz_ore::test]
    fn test_decode_copy_format_csv_leading_orphan() {
        let column_types = vec![
            mz_pgrepr::Type::from(&mz_repr::SqlScalarType::String),
            mz_pgrepr::Type::from(&mz_repr::SqlScalarType::String),
        ];

        let rows = decode_copy_format_csv(
            b"\n\"\",x\r\ny,z\r\n",
            &column_types,
            CopyCsvFormatParams::default(),
        )
        .expect("decode should succeed");
        let got: Vec<Vec<Datum>> = rows.iter().map(|r| r.iter().collect()).collect();
        assert_eq!(got.len(), 2);
        assert_eq!(got[0][0], Datum::String(""));
        assert_eq!(got[0][1], Datum::String("x"));
        assert_eq!(got[1][0], Datum::String("y"));
        assert_eq!(got[1][1], Datum::String("z"));
    }

    #[mz_ore::test]
    fn test_decode_copy_format_csv_quoted_null() {
        let column_types = vec![
            mz_pgrepr::Type::from(&mz_repr::SqlScalarType::String),
            mz_pgrepr::Type::from(&mz_repr::SqlScalarType::String),
        ];

        let cases: &[(CopyCsvFormatParams, &[(&str, &str)], &[[Option<&str>; 2]])] = &[
            (
                // Default null string (empty): an unquoted empty field is NULL,
                // a quoted one is the empty string.
                CopyCsvFormatParams::default(),
                &[("a", ""), ("b", "\"\""), ("\"\"", "c")],
                &[
                    [Some("a"), None],
                    [Some("b"), Some("")],
                    [Some(""), Some("c")],
                ],
            ),
            (
                // Custom null string: only the unquoted form is NULL.
                CopyCsvFormatParams {
                    null: Cow::from("NULL"),
                    ..Default::default()
                },
                &[("a", "NULL"), ("b", "\"NULL\""), ("NULL", "c")],
                &[
                    [Some("a"), None],
                    [Some("b"), Some("NULL")],
                    [None, Some("c")],
                ],
            ),
        ];

        for eol in [&b"\n"[..], b"\r\n", b"\r"] {
            for (params, lines, expected) in cases {
                let mut input = Vec::new();
                for (a, b) in *lines {
                    input.extend_from_slice(a.as_bytes());
                    input.push(b',');
                    input.extend_from_slice(b.as_bytes());
                    input.extend_from_slice(eol);
                }
                let rows = decode_copy_format_csv(&input, &column_types, params.clone())
                    .expect("decode should succeed");
                assert_eq!(rows.len(), expected.len(), "eol={eol:?}");
                for (row, want) in rows.iter().zip_eq(expected.iter()) {
                    let got: Vec<Datum> = row.iter().collect();
                    assert_eq!(got.len(), 2);
                    for (g, w) in got.iter().zip_eq(want.iter()) {
                        match (g, w) {
                            (Datum::Null, None) => {}
                            (Datum::String(s), Some(w)) => assert_eq!(s, w, "eol={eol:?}"),
                            _ => panic!("mismatch: got {g:?}, want {w:?}, eol={eol:?}"),
                        }
                    }
                }
            }
        }
    }

    proptest! {
        #[mz_ore::test]
        #[cfg_attr(miri, ignore)]
        fn proptest_csv_roundtrips(copy_csv_params: CopyCsvFormatParams) {
            let try_roundtrip_datum = |scalar_type: &SqlScalarType, datum| {
                let row = Row::pack_slice(&[datum]);
                let typ = SqlRelationType::new(vec![
                    SqlColumnType {
                        scalar_type: scalar_type.clone(),
                        nullable: true,
                    }
                ]);

                let mut buf = Vec::new();
                let mut csv_params = copy_csv_params.clone();
                csv_params.header = false;
                let params = CopyFormatParams::Csv(csv_params);

                encode_copy_format(&params, &row, &typ, &mut buf)?;

                let column_types = typ
                    .column_types
                    .iter()
                    .map(|x| &x.scalar_type)
                    .map(mz_pgrepr::Type::from)
                    .collect::<Vec<mz_pgrepr::Type>>();
                let result = decode_copy_format(&buf, &column_types, params);

                match result {
                    Ok(rows) => {
                        let out_str = std::str::from_utf8(&buf[..]);

                        prop_assert_eq!(
                            rows.len(),
                            1,
                            "unexpected number of rows! {:?}, csv string: {:?}", rows, out_str
                        );
                        let output = rows.into_element();

                        prop_assert_eq!(
                            row,
                            output,
                            "csv string: {:?}, scalar_type: {:?}", out_str, scalar_type
                        );
                    }
                    _ => {
                        // NOTE(review): decode errors are tolerated here; only
                        // successful decodes are checked for a faithful roundtrip.
                    }
                }

                Ok(())
            };

            for scalar_type in SqlScalarType::enumerate() {
                for datum in scalar_type.interesting_datums() {
                    // Skip values whose text form collides with the null string.
                    if let Some(value) = mz_pgrepr::Value::from_datum(datum, scalar_type) {
                        let mut buf = bytes::BytesMut::new();
                        value.encode_text(&mut buf);

                        if let Ok(datum_str) = std::str::from_utf8(&buf[..]) {
                            if datum_str == copy_csv_params.null {
                                continue;
                            }
                        }
                    }

                    let updated_datum = match datum {
                        Datum::Timestamp(_) | Datum::TimestampTz(_) | Datum::Null => {
                            continue;
                        }
                        Datum::String(s) => {
                            if s.trim() == copy_csv_params.null || s.trim().is_empty() {
                                continue;
                            } else {
                                Datum::String(s)
                            }
                        }
                        other => other,
                    };

                    let result = try_roundtrip_datum(scalar_type, updated_datum);
                    prop_assert!(result.is_ok(), "failure: {result:?}");
                }
            }
        }
    }
}