1use std::borrow::Cow;
11use std::io;
12
13use bytes::BytesMut;
14use csv::{ByteRecord, ReaderBuilder};
15use itertools::Itertools;
16use mz_repr::{
17 Datum, RelationDesc, Row, RowArena, RowRef, SharedRow, SqlColumnType, SqlRelationType,
18 SqlScalarType,
19};
20use proptest::prelude::{Arbitrary, any};
21use proptest::strategy::{BoxedStrategy, Strategy};
22use serde::{Deserialize, Serialize};
23
/// The two-byte sequence `\.` that marks the end of `COPY` data.
static END_OF_COPY_MARKER: &[u8] = &[b'\\', b'.'];
25
26fn encode_copy_row_binary(
27 row: &RowRef,
28 typ: &SqlRelationType,
29 out: &mut Vec<u8>,
30) -> Result<(), io::Error> {
31 const NULL_BYTES: [u8; 4] = (-1i32).to_be_bytes();
32
33 let count = i16::try_from(typ.column_types.len()).map_err(|_| {
35 io::Error::new(
36 io::ErrorKind::InvalidData,
37 "column count does not fit into an i16",
38 )
39 })?;
40
41 out.extend(count.to_be_bytes());
42 let mut buf = BytesMut::new();
43 for (field, typ) in row
44 .iter()
45 .zip_eq(&typ.column_types)
46 .map(|(datum, typ)| (mz_pgrepr::Value::from_datum(datum, &typ.scalar_type), typ))
47 {
48 match field {
49 None => out.extend(NULL_BYTES),
50 Some(field) => {
51 buf.clear();
52 field.encode_binary(&mz_pgrepr::Type::from(&typ.scalar_type), &mut buf)?;
53 out.extend(
54 i32::try_from(buf.len())
55 .map_err(|_| {
56 io::Error::new(
57 io::ErrorKind::InvalidData,
58 "field length does not fit into an i32",
59 )
60 })?
61 .to_be_bytes(),
62 );
63 out.extend(&buf);
64 }
65 }
66 }
67 Ok(())
68}
69
70fn encode_copy_row_text(
71 CopyTextFormatParams { null, delimiter }: &CopyTextFormatParams,
72 row: &RowRef,
73 typ: &SqlRelationType,
74 out: &mut Vec<u8>,
75) -> Result<(), io::Error> {
76 let null = null.as_bytes();
77 let mut buf = BytesMut::new();
78 for (idx, field) in mz_pgrepr::values_from_row(row, typ).into_iter().enumerate() {
79 if idx > 0 {
80 out.push(*delimiter);
81 }
82 match field {
83 None => out.extend(null),
84 Some(field) => {
85 buf.clear();
86 field.encode_text(&mut buf);
87 for b in &buf {
88 match b {
89 b'\\' => out.extend(b"\\\\"),
90 b'\n' => out.extend(b"\\n"),
91 b'\r' => out.extend(b"\\r"),
92 b'\t' => out.extend(b"\\t"),
93 _ => out.push(*b),
94 }
95 }
96 }
97 }
98 }
99 out.push(b'\n');
100 Ok(())
101}
102
103fn encode_copy_row_csv(
104 CopyCsvFormatParams {
105 delimiter: delim,
106 quote,
107 escape,
108 header: _,
109 null,
110 }: &CopyCsvFormatParams,
111 row: &RowRef,
112 typ: &SqlRelationType,
113 out: &mut Vec<u8>,
114) -> Result<(), io::Error> {
115 let null = null.as_bytes();
116 let is_special = |c: &u8| *c == *delim || *c == *quote || *c == b'\r' || *c == b'\n';
117 let mut buf = BytesMut::new();
118 for (idx, field) in mz_pgrepr::values_from_row(row, typ).into_iter().enumerate() {
119 if idx > 0 {
120 out.push(*delim);
121 }
122 match field {
123 None => out.extend(null),
124 Some(field) => {
125 buf.clear();
126 field.encode_text(&mut buf);
127 if (typ.column_types.len() == 1 && buf == END_OF_COPY_MARKER)
133 || buf.iter().any(is_special)
134 || &*buf == null
135 {
136 out.push(*quote);
140 for b in &buf {
141 if *b == *quote || *b == *escape {
142 out.push(*escape);
143 }
144 out.push(*b);
145 }
146 out.push(*quote);
147 } else {
148 out.extend(&buf);
151 }
152 }
153 }
154 }
155 out.push(b'\n');
156 Ok(())
157}
158
159pub struct CopyTextFormatParser<'a> {
160 data: &'a [u8],
161 position: usize,
162 column_delimiter: u8,
163 null_string: &'a str,
164 buffer: Vec<u8>,
165}
166
167impl<'a> CopyTextFormatParser<'a> {
168 pub fn new(data: &'a [u8], column_delimiter: u8, null_string: &'a str) -> Self {
169 Self {
170 data,
171 position: 0,
172 column_delimiter,
173 null_string,
174 buffer: Vec::new(),
175 }
176 }
177
178 fn peek(&self) -> Option<u8> {
179 if self.position < self.data.len() {
180 Some(self.data[self.position])
181 } else {
182 None
183 }
184 }
185
186 fn consume_n(&mut self, n: usize) {
187 self.position = std::cmp::min(self.position + n, self.data.len());
188 }
189
190 pub fn is_eof(&self) -> bool {
191 self.peek().is_none() || self.is_end_of_copy_marker()
192 }
193
194 pub fn is_end_of_copy_marker(&self) -> bool {
195 self.check_bytes(END_OF_COPY_MARKER)
196 }
197
198 fn is_end_of_line(&self) -> bool {
199 match self.peek() {
200 Some(b'\n') | None => true,
201 _ => false,
202 }
203 }
204
205 pub fn expect_end_of_line(&mut self) -> Result<(), io::Error> {
206 if self.is_end_of_line() {
207 self.consume_n(1);
208 Ok(())
209 } else {
210 Err(io::Error::new(
211 io::ErrorKind::InvalidData,
212 "extra data after last expected column",
213 ))
214 }
215 }
216
217 fn is_column_delimiter(&self) -> bool {
218 self.check_bytes(&[self.column_delimiter])
219 }
220
221 pub fn expect_column_delimiter(&mut self) -> Result<(), io::Error> {
222 if self.consume_bytes(&[self.column_delimiter]) {
223 Ok(())
224 } else {
225 Err(io::Error::new(
226 io::ErrorKind::InvalidData,
227 "missing data for column",
228 ))
229 }
230 }
231
232 fn check_bytes(&self, bytes: &[u8]) -> bool {
233 self.data
234 .get(self.position..self.position + bytes.len())
235 .map_or(false, |d| d == bytes)
236 }
237
238 fn consume_bytes(&mut self, bytes: &[u8]) -> bool {
239 if self.check_bytes(bytes) {
240 self.consume_n(bytes.len());
241 true
242 } else {
243 false
244 }
245 }
246
247 fn consume_null_string(&mut self) -> bool {
248 if self.null_string.is_empty() {
249 self.is_column_delimiter()
252 || self.is_end_of_line()
253 || self.is_end_of_copy_marker()
254 || self.is_eof()
255 } else {
256 self.consume_bytes(self.null_string.as_bytes())
257 }
258 }
259
260 pub fn consume_raw_value(&mut self) -> Result<Option<&[u8]>, io::Error> {
261 if self.consume_null_string() {
262 return Ok(None);
263 }
264
265 let mut start = self.position;
266
267 self.buffer.clear();
269
270 while !self.is_eof() && !self.is_end_of_copy_marker() {
271 if self.is_end_of_line() || self.is_column_delimiter() {
272 break;
273 }
274 match self.peek() {
275 Some(b'\\') => {
276 self.buffer.extend(&self.data[start..self.position]);
278
279 self.consume_n(1);
280 match self.peek() {
281 Some(b'b') => {
282 self.consume_n(1);
283 self.buffer.push(8);
284 }
285 Some(b'f') => {
286 self.consume_n(1);
287 self.buffer.push(12);
288 }
289 Some(b'n') => {
290 self.consume_n(1);
291 self.buffer.push(b'\n');
292 }
293 Some(b'r') => {
294 self.consume_n(1);
295 self.buffer.push(b'\r');
296 }
297 Some(b't') => {
298 self.consume_n(1);
299 self.buffer.push(b'\t');
300 }
301 Some(b'v') => {
302 self.consume_n(1);
303 self.buffer.push(11);
304 }
305 Some(b'x') => {
306 self.consume_n(1);
307 match self.peek() {
308 Some(_c @ b'0'..=b'9')
309 | Some(_c @ b'A'..=b'F')
310 | Some(_c @ b'a'..=b'f') => {
311 let mut value: u8 = 0;
312 let decode_nibble = |b| match b {
313 Some(c @ b'a'..=b'f') => Some(c - b'a' + 10),
314 Some(c @ b'A'..=b'F') => Some(c - b'A' + 10),
315 Some(c @ b'0'..=b'9') => Some(c - b'0'),
316 _ => None,
317 };
318 for _ in 0..2 {
319 match decode_nibble(self.peek()) {
320 Some(c) => {
321 self.consume_n(1);
322 value = value << 4 | c;
323 }
324 _ => break,
325 }
326 }
327 self.buffer.push(value);
328 }
329 _ => {
330 self.buffer.push(b'x');
331 }
332 }
333 }
334 Some(_c @ b'0'..=b'7') => {
335 let mut value: u8 = 0;
336 for _ in 0..3 {
337 match self.peek() {
338 Some(c @ b'0'..=b'7') => {
339 self.consume_n(1);
340 value = value << 3 | (c - b'0');
341 }
342 _ => break,
343 }
344 }
345 self.buffer.push(value);
346 }
347 Some(c) => {
348 self.consume_n(1);
349 self.buffer.push(c);
350 }
351 None => {
352 self.buffer.push(b'\\');
353 }
354 }
355
356 start = self.position;
357 }
358 Some(_) => {
359 self.consume_n(1);
360 }
361 None => {}
362 }
363 }
364
365 if self.buffer.is_empty() {
367 Ok(Some(&self.data[start..self.position]))
368 } else {
369 self.buffer.extend(&self.data[start..self.position]);
372 Ok(Some(&self.buffer[..]))
373 }
374 }
375
376 pub fn iter_raw(self, num_columns: usize) -> RawIterator<'a> {
378 RawIterator {
379 parser: self,
380 current_column: 0,
381 num_columns,
382 truncate: false,
383 }
384 }
385
386 pub fn iter_raw_truncating(self, num_columns: usize) -> RawIterator<'a> {
388 RawIterator {
389 parser: self,
390 current_column: 0,
391 num_columns,
392 truncate: true,
393 }
394 }
395}
396
397pub struct RawIterator<'a> {
398 parser: CopyTextFormatParser<'a>,
399 current_column: usize,
400 num_columns: usize,
401 truncate: bool,
402}
403
404impl<'a> RawIterator<'a> {
405 pub fn next(&mut self) -> Option<Result<Option<&[u8]>, io::Error>> {
406 if self.current_column > self.num_columns {
407 return None;
408 }
409
410 if self.current_column == self.num_columns {
411 if !self.truncate {
412 if let Some(err) = self.parser.expect_end_of_line().err() {
413 return Some(Err(err));
414 }
415 }
416
417 return None;
418 }
419
420 if self.current_column > 0 {
421 if let Some(err) = self.parser.expect_column_delimiter().err() {
422 return Some(Err(err));
423 }
424 }
425
426 self.current_column += 1;
427 Some(self.parser.consume_raw_value())
428 }
429}
430
/// Selects the data format, and its options, for a `COPY` operation.
#[derive(Debug, Eq, PartialEq, Clone, Serialize, Deserialize)]
pub enum CopyFormatParams<'a> {
    /// PostgreSQL text format.
    Text(CopyTextFormatParams<'a>),
    /// CSV format.
    Csv(CopyCsvFormatParams<'a>),
    /// PostgreSQL binary format. `decode_copy_format` rejects this variant as
    /// unsupported; only encoding is implemented here.
    Binary,
    /// Parquet. The row encode/decode functions in this file reject this
    /// variant as unsupported.
    Parquet,
}
438
439impl CopyFormatParams<'static> {
440 pub fn file_extension(&self) -> &str {
441 match self {
442 &CopyFormatParams::Text(_) => "txt",
443 &CopyFormatParams::Csv(_) => "csv",
444 &CopyFormatParams::Binary => "bin",
445 &CopyFormatParams::Parquet => "parquet",
446 }
447 }
448
449 pub fn requires_header(&self) -> bool {
450 match self {
451 CopyFormatParams::Text(_) => false,
452 CopyFormatParams::Csv(params) => params.header,
453 CopyFormatParams::Binary => false,
454 CopyFormatParams::Parquet => false,
455 }
456 }
457}
458
459pub fn decode_copy_format<'a>(
461 data: &[u8],
462 column_types: &[mz_pgrepr::Type],
463 params: CopyFormatParams<'a>,
464) -> Result<Vec<Row>, io::Error> {
465 match params {
466 CopyFormatParams::Text(params) => decode_copy_format_text(data, column_types, params),
467 CopyFormatParams::Csv(params) => decode_copy_format_csv(data, column_types, params),
468 CopyFormatParams::Binary => Err(io::Error::new(
469 io::ErrorKind::Unsupported,
470 "cannot decode as binary format",
471 )),
472 CopyFormatParams::Parquet => {
473 Err(io::Error::new(io::ErrorKind::Unsupported, "parquet format"))
475 }
476 }
477}
478
479pub fn encode_copy_format<'a>(
481 params: &CopyFormatParams<'a>,
482 row: &RowRef,
483 typ: &SqlRelationType,
484 out: &mut Vec<u8>,
485) -> Result<(), io::Error> {
486 match params {
487 CopyFormatParams::Text(params) => encode_copy_row_text(params, row, typ, out),
488 CopyFormatParams::Csv(params) => encode_copy_row_csv(params, row, typ, out),
489 CopyFormatParams::Binary => encode_copy_row_binary(row, typ, out),
490 CopyFormatParams::Parquet => {
491 Err(io::Error::new(io::ErrorKind::Unsupported, "parquet format"))
493 }
494 }
495}
496
497pub fn encode_copy_format_header<'a>(
498 params: &CopyFormatParams<'a>,
499 desc: &RelationDesc,
500 out: &mut Vec<u8>,
501) -> Result<(), io::Error> {
502 match params {
503 CopyFormatParams::Text(_) => Ok(()),
504 CopyFormatParams::Binary => Ok(()),
505 CopyFormatParams::Csv(params) => {
506 let mut header_row = Row::with_capacity(desc.arity());
507 header_row
508 .packer()
509 .extend(desc.iter_names().map(|s| Datum::from(s.as_str())));
510 let typ = SqlRelationType::new(vec![
511 SqlColumnType {
512 scalar_type: SqlScalarType::String,
513 nullable: false,
514 };
515 desc.arity()
516 ]);
517 encode_copy_row_csv(params, &header_row, &typ, out)
518 }
519 CopyFormatParams::Parquet => {
520 Err(io::Error::new(io::ErrorKind::Unsupported, "parquet format"))
522 }
523 }
524}
525
/// Options for the PostgreSQL text `COPY` format.
#[derive(Debug, Eq, PartialEq, Clone, Serialize, Deserialize)]
pub struct CopyTextFormatParams<'a> {
    /// String written for SQL `NULL`, and treated as `NULL` when a field
    /// matches it on decode.
    pub null: Cow<'a, str>,
    /// Byte separating columns within a row.
    pub delimiter: u8,
}
531
532impl<'a> Default for CopyTextFormatParams<'a> {
533 fn default() -> Self {
534 CopyTextFormatParams {
535 delimiter: b'\t',
536 null: Cow::from("\\N"),
537 }
538 }
539}
540
541pub fn decode_copy_format_text(
542 data: &[u8],
543 column_types: &[mz_pgrepr::Type],
544 CopyTextFormatParams { null, delimiter }: CopyTextFormatParams,
545) -> Result<Vec<Row>, io::Error> {
546 let mut rows = Vec::new();
547
548 let mut parser = CopyTextFormatParser::new(data, delimiter, &null);
550 while !parser.is_eof() && !parser.is_end_of_copy_marker() {
551 let mut row = Vec::new();
552 let buf = RowArena::new();
553 for (col, typ) in column_types.iter().enumerate() {
554 if col > 0 {
555 parser.expect_column_delimiter()?;
556 }
557 let raw_value = parser.consume_raw_value()?;
558 if let Some(raw_value) = raw_value {
559 match mz_pgrepr::Value::decode_text(typ, raw_value) {
560 Ok(value) => {
561 row.push(
562 value
563 .into_datum_decode_error(&buf, typ, "column")
564 .map_err(|msg| io::Error::new(io::ErrorKind::InvalidData, msg))?,
565 );
566 }
567 Err(err) => {
568 let msg = format!("unable to decode column: {}", err);
569 return Err(io::Error::new(io::ErrorKind::InvalidData, msg));
570 }
571 }
572 } else {
573 row.push(Datum::Null);
574 }
575 }
576 parser.expect_end_of_line()?;
577 rows.push(Row::pack(row));
578 }
579 Ok(rows)
582}
583
/// Options for the CSV `COPY` format.
#[derive(Debug, Eq, PartialEq, Clone, Serialize, Deserialize)]
pub struct CopyCsvFormatParams<'a> {
    /// Byte separating fields within a record.
    pub delimiter: u8,
    /// Byte used to quote fields.
    pub quote: u8,
    /// Byte that escapes quote and escape bytes inside a quoted field. When
    /// equal to `quote`, quotes are escaped by doubling.
    pub escape: u8,
    /// Whether the data has (or should have) a header row.
    pub header: bool,
    /// String written for SQL `NULL`, and treated as `NULL` when a field
    /// matches it on decode.
    pub null: Cow<'a, str>,
}
592
593impl<'a> CopyCsvFormatParams<'a> {
594 pub fn to_owned(&self) -> CopyCsvFormatParams<'static> {
595 CopyCsvFormatParams {
596 delimiter: self.delimiter,
597 quote: self.quote,
598 escape: self.escape,
599 header: self.header,
600 null: Cow::Owned(self.null.to_string()),
601 }
602 }
603}
604
605impl Arbitrary for CopyCsvFormatParams<'static> {
606 type Parameters = ();
607 type Strategy = BoxedStrategy<Self>;
608
609 fn arbitrary_with(_args: Self::Parameters) -> Self::Strategy {
610 (
611 any::<u8>(),
612 any::<u8>(),
613 any::<u8>(),
614 any::<bool>(),
615 any::<String>(),
616 )
617 .prop_map(|(delimiter, diff, escape, header, null)| {
618 let diff = diff.saturating_sub(1).max(1);
620 let quote = delimiter.wrapping_add(diff);
621
622 Self::try_new(
623 Some(delimiter),
624 Some(quote),
625 Some(escape),
626 Some(header),
627 Some(null),
628 )
629 .expect("delimiter and quote should be different")
630 })
631 .boxed()
632 }
633}
634
635impl<'a> Default for CopyCsvFormatParams<'a> {
636 fn default() -> Self {
637 CopyCsvFormatParams {
638 delimiter: b',',
639 quote: b'"',
640 escape: b'"',
641 header: false,
642 null: Cow::from(""),
643 }
644 }
645}
646
647impl<'a> CopyCsvFormatParams<'a> {
648 pub fn try_new(
649 delimiter: Option<u8>,
650 quote: Option<u8>,
651 escape: Option<u8>,
652 header: Option<bool>,
653 null: Option<String>,
654 ) -> Result<CopyCsvFormatParams<'a>, String> {
655 let mut params = CopyCsvFormatParams::default();
656
657 if let Some(delimiter) = delimiter {
658 params.delimiter = delimiter;
659 }
660 if let Some(quote) = quote {
661 params.quote = quote;
662 params.escape = quote;
664 }
665 if let Some(escape) = escape {
666 params.escape = escape;
667 }
668 if let Some(header) = header {
669 params.header = header;
670 }
671 if let Some(null) = null {
672 params.null = Cow::from(null);
673 }
674
675 if params.quote == params.delimiter {
676 return Err("COPY delimiter and quote must be different".to_string());
677 }
678 Ok(params)
679 }
680}
681
/// Decodes CSV-format `COPY` data into rows.
///
/// `data` is read with a `csv` reader configured from the given parameters.
/// When `header` is set, the first record is treated as a header and not
/// returned. A record consisting of the single field `\.` ends decoding.
/// Every other record must have exactly `column_types.len()` fields. A field
/// equal to `null` decodes as `NULL`; any other field must be UTF-8 text that
/// decodes as its column type.
///
/// # Errors
///
/// Returns [`io::ErrorKind::InvalidData`] for a record with the wrong number
/// of fields, a field that is not valid UTF-8, or a field that fails to
/// decode. Errors from the CSV reader are propagated.
///
/// NOTE(review): the `csv` reader hands back field contents with quotes
/// removed, so this function cannot tell quoted fields from unquoted ones. As
/// a result, a quoted value equal to `null` decodes as `NULL`, and a quoted
/// single-column `\.` ends decoding. `encode_copy_row_csv` quotes exactly
/// those values to keep them distinct, so they do not round-trip.
pub fn decode_copy_format_csv(
    data: &[u8],
    column_types: &[mz_pgrepr::Type],
    CopyCsvFormatParams {
        delimiter,
        quote,
        escape,
        null,
        header,
    }: CopyCsvFormatParams,
) -> Result<Vec<Row>, io::Error> {
    let mut rows = Vec::new();

    // When escape equals quote, quotes inside quoted fields are escaped by
    // doubling; otherwise the distinct escape byte is used.
    let (double_quote, escape) = if quote == escape {
        (true, None)
    } else {
        (false, Some(escape))
    };

    let mut rdr = ReaderBuilder::new()
        .delimiter(delimiter)
        .quote(quote)
        .has_headers(header)
        .double_quote(double_quote)
        .escape(escape)
        // Record lengths are checked below, with the same error messages the
        // text-format parser uses.
        .flexible(true)
        .from_reader(data);

    let null_as_bytes = null.as_bytes();

    // Reused across iterations to avoid a per-record allocation.
    let mut record = ByteRecord::new();

    while rdr.read_byte_record(&mut record)? {
        // A lone `\.` field is the end-of-data marker.
        if record.len() == 1 && record.iter().next() == Some(END_OF_COPY_MARKER) {
            break;
        }

        match record.len().cmp(&column_types.len()) {
            std::cmp::Ordering::Less => Err(io::Error::new(
                io::ErrorKind::InvalidData,
                "missing data for column",
            )),
            std::cmp::Ordering::Greater => Err(io::Error::new(
                io::ErrorKind::InvalidData,
                "extra data after last expected column",
            )),
            std::cmp::Ordering::Equal => Ok(()),
        }?;

        // The row is packed into the shared row buffer and then cloned out.
        // NOTE(review): presumably `packer()` clears the buffer's previous
        // contents; confirm against `SharedRow`.
        let mut row_builder = SharedRow::get();
        let mut row_packer = row_builder.packer();

        for (typ, raw_value) in column_types.iter().zip_eq(record.iter()) {
            if raw_value == null_as_bytes {
                row_packer.push(Datum::Null);
            } else {
                let s = match std::str::from_utf8(raw_value) {
                    Ok(s) => s,
                    Err(err) => {
                        let msg = format!("invalid utf8 data in column: {}", err);
                        return Err(io::Error::new(io::ErrorKind::InvalidData, msg));
                    }
                };
                match mz_pgrepr::Value::decode_text_into_row(typ, s, &mut row_packer) {
                    Ok(()) => {}
                    Err(err) => {
                        let msg = format!("unable to decode column: {}", err);
                        return Err(io::Error::new(io::ErrorKind::InvalidData, msg));
                    }
                }
            }
        }
        rows.push(row_builder.clone());
    }

    Ok(rows)
}
761
#[cfg(test)]
mod tests {
    use mz_ore::collections::CollectionExt;
    use mz_repr::SqlColumnType;
    use proptest::prelude::*;

    use super::*;

    /// Walks the text parser field by field over input that mixes delimiters,
    /// `\N` nulls, and named, hex, and octal escapes.
    #[mz_ore::test]
    fn test_copy_format_text_parser() {
        let text = "\t\\nt e\t\\N\t\n\\x60\\xA\\x7D\\x4a\n\\44\\044\\123".as_bytes();
        let mut parser = CopyTextFormatParser::new(text, b'\t', "\\N");
        assert!(parser.is_column_delimiter());
        parser
            .expect_column_delimiter()
            .expect("expected column delimiter");
        assert_eq!(
            parser
                .consume_raw_value()
                .expect("unexpected error")
                .expect("unexpected empty result"),
            "\nt e".as_bytes()
        );
        parser
            .expect_column_delimiter()
            .expect("expected column delimiter");
        assert!(
            parser
                .consume_raw_value()
                .expect("unexpected error")
                .is_none()
        );
        parser
            .expect_column_delimiter()
            .expect("expected column delimiter");
        assert!(parser.is_end_of_line());
        parser.expect_end_of_line().expect("expected eol");
        assert_eq!(
            parser
                .consume_raw_value()
                .expect("unexpected error")
                .expect("unexpected empty result"),
            "`\n}J".as_bytes()
        );
        parser.expect_end_of_line().expect("expected eol");
        assert_eq!(
            parser
                .consume_raw_value()
                .expect("unexpected error")
                .expect("unexpected empty result"),
            "$$S".as_bytes()
        );
        assert!(parser.is_eof());
    }

    /// With an empty null string, an empty field (before a delimiter, at end
    /// of line, or at end of input) is `NULL`.
    #[mz_ore::test]
    fn test_copy_format_text_empty_null_string() {
        let text = "\t\n10\t20\n30\t\n40\t".as_bytes();
        let expect = vec![
            vec![None, None],
            vec![Some("10"), Some("20")],
            vec![Some("30"), None],
            vec![Some("40"), None],
        ];
        let mut parser = CopyTextFormatParser::new(text, b'\t', "");
        for line in expect {
            for (i, value) in line.iter().enumerate() {
                if i > 0 {
                    parser
                        .expect_column_delimiter()
                        .expect("expected column delimiter");
                }
                match value {
                    Some(s) => {
                        assert!(!parser.consume_null_string());
                        assert_eq!(
                            parser
                                .consume_raw_value()
                                .expect("unexpected error")
                                .expect("unexpected empty result"),
                            s.as_bytes()
                        );
                    }
                    None => {
                        assert!(parser.consume_null_string());
                    }
                }
            }
            parser.expect_end_of_line().expect("expected eol");
        }
    }

    /// Each input is decoded as a single text-format field.
    #[mz_ore::test]
    fn test_copy_format_text_parser_escapes() {
        struct TestCase {
            input: &'static str,
            expect: &'static [u8],
        }
        let tests = vec![
            TestCase {
                input: "simple",
                expect: b"simple",
            },
            TestCase {
                input: r#"new\nline"#,
                expect: b"new\nline",
            },
            TestCase {
                input: r#"\b\f\n\r\t\v\\"#,
                expect: b"\x08\x0c\n\r\t\x0b\\",
            },
            TestCase {
                input: r#"\0\12\123"#,
                expect: &[0, 0o12, 0o123],
            },
            TestCase {
                input: r#"\x1\xaf"#,
                expect: &[0x01, 0xaf],
            },
            TestCase {
                input: r#"T\n\07\xEV\x0fA\xb2C\1"#,
                expect: b"T\n\x07\x0eV\x0fA\xb2C\x01",
            },
            TestCase {
                input: r#"\\\""#,
                expect: b"\\\"",
            },
            TestCase {
                input: r#"\x"#,
                expect: b"x",
            },
            TestCase {
                input: r#"\xg"#,
                expect: b"xg",
            },
            TestCase {
                input: r#"\"#,
                expect: b"\\",
            },
            TestCase {
                input: r#"\8"#,
                expect: b"8",
            },
            TestCase {
                input: r#"\a"#,
                expect: b"a",
            },
            TestCase {
                input: r#"\x\xg\8\xH\x32\s\"#,
                expect: b"xxg8xH2s\\",
            },
        ];

        for test in tests {
            let mut parser = CopyTextFormatParser::new(test.input.as_bytes(), b'\t', "\\N");
            assert_eq!(
                parser
                    .consume_raw_value()
                    .expect("unexpected error")
                    .expect("unexpected empty result"),
                test.expect,
                "input: {}, expect: {:?}",
                test.input,
                std::str::from_utf8(test.expect),
            );
            assert!(parser.is_eof());
        }
    }

    #[mz_ore::test]
    fn test_copy_csv_format_params() {
        assert_eq!(
            CopyCsvFormatParams::try_new(Some(b't'), Some(b'q'), None, None, None),
            Ok(CopyCsvFormatParams {
                delimiter: b't',
                quote: b'q',
                escape: b'q',
                header: false,
                null: Cow::from(""),
            })
        );

        assert_eq!(
            CopyCsvFormatParams::try_new(
                Some(b't'),
                Some(b'q'),
                Some(b'e'),
                Some(true),
                Some("null".to_string())
            ),
            Ok(CopyCsvFormatParams {
                delimiter: b't',
                quote: b'q',
                escape: b'e',
                header: true,
                null: Cow::from("null"),
            })
        );

        assert_eq!(
            CopyCsvFormatParams::try_new(
                None,
                Some(b','),
                Some(b'e'),
                Some(true),
                Some("null".to_string())
            ),
            Err("COPY delimiter and quote must be different".to_string())
        );
    }

    #[mz_ore::test]
    fn test_copy_csv_row() -> Result<(), io::Error> {
        let mut row = Row::default();
        let mut packer = row.packer();
        packer.push(Datum::from("1,2,\"3\""));
        packer.push(Datum::Null);
        packer.push(Datum::from(1000u64));
        packer.push(Datum::from("qe"));
        packer.push(Datum::from(""));

        let typ: SqlRelationType = SqlRelationType::new(vec![
            SqlColumnType {
                scalar_type: mz_repr::SqlScalarType::String,
                nullable: false,
            },
            SqlColumnType {
                scalar_type: mz_repr::SqlScalarType::String,
                nullable: true,
            },
            SqlColumnType {
                scalar_type: mz_repr::SqlScalarType::UInt64,
                nullable: false,
            },
            SqlColumnType {
                scalar_type: mz_repr::SqlScalarType::String,
                nullable: false,
            },
            SqlColumnType {
                scalar_type: mz_repr::SqlScalarType::String,
                nullable: false,
            },
        ]);

        let mut out = Vec::new();

        struct TestCase<'a> {
            params: CopyCsvFormatParams<'a>,
            expected: &'static [u8],
        }

        let tests = [
            TestCase {
                params: CopyCsvFormatParams::default(),
                expected: b"\"1,2,\"\"3\"\"\",,1000,qe,\"\"\n",
            },
            TestCase {
                params: CopyCsvFormatParams {
                    null: Cow::from("NULL"),
                    quote: b'q',
                    escape: b'e',
                    ..Default::default()
                },
                expected: b"q1,2,\"3\"q,NULL,1000,qeqeeq,\n",
            },
        ];

        for TestCase { params, expected } in tests {
            out.clear();
            let params = CopyFormatParams::Csv(params);
            // Propagate encoding failures instead of silently comparing a
            // partially written buffer.
            encode_copy_format(&params, &row, &typ, &mut out)?;
            let output = std::str::from_utf8(&out);
            assert_eq!(output, std::str::from_utf8(expected));
        }

        Ok(())
    }

    proptest! {
        #[mz_ore::test]
        #[cfg_attr(miri, ignore)]
        fn proptest_csv_roundtrips(copy_csv_params: CopyCsvFormatParams) {
            let try_roundtrip_datum = |scalar_type: &SqlScalarType, datum| {
                let row = Row::pack_slice(&[datum]);
                let typ = SqlRelationType::new(vec![
                    SqlColumnType {
                        scalar_type: scalar_type.clone(),
                        nullable: true,
                    }
                ]);

                let mut buf = Vec::new();
                let mut csv_params = copy_csv_params.clone();
                // No header is encoded, so the decoder must not skip the only
                // record as one.
                csv_params.header = false;
                let params = CopyFormatParams::Csv(csv_params);

                encode_copy_format(&params, &row, &typ, &mut buf)?;

                let column_types = typ
                    .column_types
                    .iter()
                    .map(|x| &x.scalar_type)
                    .map(mz_pgrepr::Type::from)
                    .collect::<Vec<mz_pgrepr::Type>>();
                let result = decode_copy_format(&buf, &column_types, params);

                match result {
                    Ok(rows) => {
                        let out_str = std::str::from_utf8(&buf[..]);

                        prop_assert_eq!(
                            rows.len(),
                            1,
                            "unexpected number of rows! {:?}, csv string: {:?}", rows, out_str
                        );
                        let output = rows.into_element();

                        prop_assert_eq!(
                            row,
                            output,
                            "csv string: {:?}, scalar_type: {:?}", out_str, scalar_type
                        );
                    }
                    // Decode failures are tolerated here; only rows that do
                    // decode must round-trip exactly.
                    _ => {}
                }

                Ok(())
            };

            for scalar_type in SqlScalarType::enumerate() {
                for datum in scalar_type.interesting_datums() {
                    // Values whose text form equals the null string decode as
                    // NULL, so they cannot round-trip; skip them.
                    if let Some(value) = mz_pgrepr::Value::from_datum(datum, scalar_type) {
                        let mut buf = bytes::BytesMut::new();
                        value.encode_text(&mut buf);

                        if let Ok(datum_str) = std::str::from_utf8(&buf[..]) {
                            if datum_str == copy_csv_params.null {
                                continue;
                            }
                        }
                    }

                    let updated_datum = match datum {
                        Datum::Timestamp(_) | Datum::TimestampTz(_) | Datum::Null => {
                            continue;
                        }
                        Datum::String(s) => {
                            if s.trim() == copy_csv_params.null || s.trim().is_empty() {
                                continue;
                            } else {
                                Datum::String(s)
                            }
                        }
                        other => other,
                    };

                    let result = try_roundtrip_datum(scalar_type, updated_datum);
                    prop_assert!(result.is_ok(), "failure: {result:?}");
                }
            }
        }
    }
}