1use std::borrow::Cow;
11use std::io;
12
13use bytes::BytesMut;
14use csv::{ByteRecord, ReaderBuilder};
15use itertools::Itertools;
16use mz_repr::{
17 Datum, RelationDesc, Row, RowArena, RowRef, SharedRow, SqlColumnType, SqlRelationType,
18 SqlScalarType,
19};
20use proptest::prelude::{Arbitrary, any};
21use proptest::strategy::{BoxedStrategy, Strategy};
22use serde::{Deserialize, Serialize};
23
/// The end-of-copy marker: a backslash followed by a period (`\.`).
///
/// The text parser in this file stops reading when it sees these bytes, the
/// CSV decoder stops at a record consisting of exactly this one field, and the
/// CSV encoder quotes a lone field equal to it so it is not mistaken for the
/// marker.
static END_OF_COPY_MARKER: &[u8] = b"\\.";
25
26fn encode_copy_row_binary(
27 row: &RowRef,
28 typ: &SqlRelationType,
29 out: &mut Vec<u8>,
30) -> Result<(), io::Error> {
31 const NULL_BYTES: [u8; 4] = (-1i32).to_be_bytes();
32
33 let count = i16::try_from(typ.column_types.len()).map_err(|_| {
35 io::Error::new(
36 io::ErrorKind::Other,
37 "column count does not fit into an i16",
38 )
39 })?;
40
41 out.extend(count.to_be_bytes());
42 let mut buf = BytesMut::new();
43 for (field, typ) in row
44 .iter()
45 .zip_eq(&typ.column_types)
46 .map(|(datum, typ)| (mz_pgrepr::Value::from_datum(datum, &typ.scalar_type), typ))
47 {
48 match field {
49 None => out.extend(NULL_BYTES),
50 Some(field) => {
51 buf.clear();
52 field.encode_binary(&mz_pgrepr::Type::from(&typ.scalar_type), &mut buf)?;
53 out.extend(
54 i32::try_from(buf.len())
55 .map_err(|_| {
56 io::Error::new(
57 io::ErrorKind::Other,
58 "field length does not fit into an i32",
59 )
60 })?
61 .to_be_bytes(),
62 );
63 out.extend(&buf);
64 }
65 }
66 }
67 Ok(())
68}
69
70fn encode_copy_row_text(
71 CopyTextFormatParams { null, delimiter }: &CopyTextFormatParams,
72 row: &RowRef,
73 typ: &SqlRelationType,
74 out: &mut Vec<u8>,
75) -> Result<(), io::Error> {
76 let null = null.as_bytes();
77 let mut buf = BytesMut::new();
78 for (idx, field) in mz_pgrepr::values_from_row(row, typ).into_iter().enumerate() {
79 if idx > 0 {
80 out.push(*delimiter);
81 }
82 match field {
83 None => out.extend(null),
84 Some(field) => {
85 buf.clear();
86 field.encode_text(&mut buf);
87 for b in &buf {
88 match b {
89 b'\\' => out.extend(b"\\\\"),
90 b'\n' => out.extend(b"\\n"),
91 b'\r' => out.extend(b"\\r"),
92 b'\t' => out.extend(b"\\t"),
93 _ => out.push(*b),
94 }
95 }
96 }
97 }
98 }
99 out.push(b'\n');
100 Ok(())
101}
102
103fn encode_copy_row_csv(
104 CopyCsvFormatParams {
105 delimiter: delim,
106 quote,
107 escape,
108 header: _,
109 null,
110 }: &CopyCsvFormatParams,
111 row: &RowRef,
112 typ: &SqlRelationType,
113 out: &mut Vec<u8>,
114) -> Result<(), io::Error> {
115 let null = null.as_bytes();
116 let is_special = |c: &u8| *c == *delim || *c == *quote || *c == b'\r' || *c == b'\n';
117 let mut buf = BytesMut::new();
118 for (idx, field) in mz_pgrepr::values_from_row(row, typ).into_iter().enumerate() {
119 if idx > 0 {
120 out.push(*delim);
121 }
122 match field {
123 None => out.extend(null),
124 Some(field) => {
125 buf.clear();
126 field.encode_text(&mut buf);
127 if (typ.column_types.len() == 1 && buf == END_OF_COPY_MARKER)
133 || buf.iter().any(is_special)
134 || &*buf == null
135 {
136 out.push(*quote);
140 for b in &buf {
141 if *b == *quote || *b == *escape {
142 out.push(*escape);
143 }
144 out.push(*b);
145 }
146 out.push(*quote);
147 } else {
148 out.extend(&buf);
151 }
152 }
153 }
154 }
155 out.push(b'\n');
156 Ok(())
157}
158
159pub struct CopyTextFormatParser<'a> {
160 data: &'a [u8],
161 position: usize,
162 column_delimiter: u8,
163 null_string: &'a str,
164 buffer: Vec<u8>,
165}
166
167impl<'a> CopyTextFormatParser<'a> {
168 pub fn new(data: &'a [u8], column_delimiter: u8, null_string: &'a str) -> Self {
169 Self {
170 data,
171 position: 0,
172 column_delimiter,
173 null_string,
174 buffer: Vec::new(),
175 }
176 }
177
178 fn peek(&self) -> Option<u8> {
179 if self.position < self.data.len() {
180 Some(self.data[self.position])
181 } else {
182 None
183 }
184 }
185
186 fn consume_n(&mut self, n: usize) {
187 self.position = std::cmp::min(self.position + n, self.data.len());
188 }
189
190 pub fn is_eof(&self) -> bool {
191 self.peek().is_none() || self.is_end_of_copy_marker()
192 }
193
194 pub fn is_end_of_copy_marker(&self) -> bool {
195 self.check_bytes(END_OF_COPY_MARKER)
196 }
197
198 fn is_end_of_line(&self) -> bool {
199 match self.peek() {
200 Some(b'\n') | None => true,
201 _ => false,
202 }
203 }
204
205 pub fn expect_end_of_line(&mut self) -> Result<(), io::Error> {
206 if self.is_end_of_line() {
207 self.consume_n(1);
208 Ok(())
209 } else {
210 Err(io::Error::new(
211 io::ErrorKind::InvalidData,
212 "extra data after last expected column",
213 ))
214 }
215 }
216
217 fn is_column_delimiter(&self) -> bool {
218 self.check_bytes(&[self.column_delimiter])
219 }
220
221 pub fn expect_column_delimiter(&mut self) -> Result<(), io::Error> {
222 if self.consume_bytes(&[self.column_delimiter]) {
223 Ok(())
224 } else {
225 Err(io::Error::new(
226 io::ErrorKind::InvalidData,
227 "missing data for column",
228 ))
229 }
230 }
231
232 fn check_bytes(&self, bytes: &[u8]) -> bool {
233 self.data
234 .get(self.position..self.position + bytes.len())
235 .map_or(false, |d| d == bytes)
236 }
237
238 fn consume_bytes(&mut self, bytes: &[u8]) -> bool {
239 if self.check_bytes(bytes) {
240 self.consume_n(bytes.len());
241 true
242 } else {
243 false
244 }
245 }
246
247 fn consume_null_string(&mut self) -> bool {
248 if self.null_string.is_empty() {
249 self.is_column_delimiter()
252 || self.is_end_of_line()
253 || self.is_end_of_copy_marker()
254 || self.is_eof()
255 } else {
256 self.consume_bytes(self.null_string.as_bytes())
257 }
258 }
259
260 pub fn consume_raw_value(&mut self) -> Result<Option<&[u8]>, io::Error> {
261 if self.consume_null_string() {
262 return Ok(None);
263 }
264
265 let mut start = self.position;
266
267 self.buffer.clear();
269
270 while !self.is_eof() && !self.is_end_of_copy_marker() {
271 if self.is_end_of_line() || self.is_column_delimiter() {
272 break;
273 }
274 match self.peek() {
275 Some(b'\\') => {
276 self.buffer.extend(&self.data[start..self.position]);
278
279 self.consume_n(1);
280 match self.peek() {
281 Some(b'b') => {
282 self.consume_n(1);
283 self.buffer.push(8);
284 }
285 Some(b'f') => {
286 self.consume_n(1);
287 self.buffer.push(12);
288 }
289 Some(b'n') => {
290 self.consume_n(1);
291 self.buffer.push(b'\n');
292 }
293 Some(b'r') => {
294 self.consume_n(1);
295 self.buffer.push(b'\r');
296 }
297 Some(b't') => {
298 self.consume_n(1);
299 self.buffer.push(b'\t');
300 }
301 Some(b'v') => {
302 self.consume_n(1);
303 self.buffer.push(11);
304 }
305 Some(b'x') => {
306 self.consume_n(1);
307 match self.peek() {
308 Some(_c @ b'0'..=b'9')
309 | Some(_c @ b'A'..=b'F')
310 | Some(_c @ b'a'..=b'f') => {
311 let mut value: u8 = 0;
312 let decode_nibble = |b| match b {
313 Some(c @ b'a'..=b'f') => Some(c - b'a' + 10),
314 Some(c @ b'A'..=b'F') => Some(c - b'A' + 10),
315 Some(c @ b'0'..=b'9') => Some(c - b'0'),
316 _ => None,
317 };
318 for _ in 0..2 {
319 match decode_nibble(self.peek()) {
320 Some(c) => {
321 self.consume_n(1);
322 value = value << 4 | c;
323 }
324 _ => break,
325 }
326 }
327 self.buffer.push(value);
328 }
329 _ => {
330 self.buffer.push(b'x');
331 }
332 }
333 }
334 Some(_c @ b'0'..=b'7') => {
335 let mut value: u8 = 0;
336 for _ in 0..3 {
337 match self.peek() {
338 Some(c @ b'0'..=b'7') => {
339 self.consume_n(1);
340 value = value << 3 | (c - b'0');
341 }
342 _ => break,
343 }
344 }
345 self.buffer.push(value);
346 }
347 Some(c) => {
348 self.consume_n(1);
349 self.buffer.push(c);
350 }
351 None => {
352 self.buffer.push(b'\\');
353 }
354 }
355
356 start = self.position;
357 }
358 Some(_) => {
359 self.consume_n(1);
360 }
361 None => {}
362 }
363 }
364
365 if self.buffer.is_empty() {
367 Ok(Some(&self.data[start..self.position]))
368 } else {
369 self.buffer.extend(&self.data[start..self.position]);
372 Ok(Some(&self.buffer[..]))
373 }
374 }
375
376 pub fn iter_raw(self, num_columns: usize) -> RawIterator<'a> {
378 RawIterator {
379 parser: self,
380 current_column: 0,
381 num_columns,
382 truncate: false,
383 }
384 }
385
386 pub fn iter_raw_truncating(self, num_columns: usize) -> RawIterator<'a> {
388 RawIterator {
389 parser: self,
390 current_column: 0,
391 num_columns,
392 truncate: true,
393 }
394 }
395}
396
397pub struct RawIterator<'a> {
398 parser: CopyTextFormatParser<'a>,
399 current_column: usize,
400 num_columns: usize,
401 truncate: bool,
402}
403
404impl<'a> RawIterator<'a> {
405 pub fn next(&mut self) -> Option<Result<Option<&[u8]>, io::Error>> {
406 if self.current_column > self.num_columns {
407 return None;
408 }
409
410 if self.current_column == self.num_columns {
411 if !self.truncate {
412 if let Some(err) = self.parser.expect_end_of_line().err() {
413 return Some(Err(err));
414 }
415 }
416
417 return None;
418 }
419
420 if self.current_column > 0 {
421 if let Some(err) = self.parser.expect_column_delimiter().err() {
422 return Some(Err(err));
423 }
424 }
425
426 self.current_column += 1;
427 Some(self.parser.consume_raw_value())
428 }
429}
430
/// The format of COPY data together with that format's settings.
///
/// In this file, encoding is implemented for `Text`, `Csv` and `Binary`, and
/// decoding for `Text` and `Csv`; the remaining combinations return an
/// `Unsupported` error.
#[derive(Debug, Eq, PartialEq, Clone, Serialize, Deserialize)]
pub enum CopyFormatParams<'a> {
    /// Text format with its NULL string and delimiter.
    Text(CopyTextFormatParams<'a>),
    /// CSV format with its delimiter, quote, escape, header and NULL settings.
    Csv(CopyCsvFormatParams<'a>),
    /// Binary format; it has no settings.
    Binary,
    /// Parquet format; it has no settings here.
    Parquet,
}
438
439impl CopyFormatParams<'static> {
440 pub fn file_extension(&self) -> &str {
441 match self {
442 &CopyFormatParams::Text(_) => "txt",
443 &CopyFormatParams::Csv(_) => "csv",
444 &CopyFormatParams::Binary => "bin",
445 &CopyFormatParams::Parquet => "parquet",
446 }
447 }
448
449 pub fn requires_header(&self) -> bool {
450 match self {
451 CopyFormatParams::Text(_) => false,
452 CopyFormatParams::Csv(params) => params.header,
453 CopyFormatParams::Binary => false,
454 CopyFormatParams::Parquet => false,
455 }
456 }
457}
458
459pub fn decode_copy_format<'a>(
461 data: &[u8],
462 column_types: &[mz_pgrepr::Type],
463 params: CopyFormatParams<'a>,
464) -> Result<Vec<Row>, io::Error> {
465 match params {
466 CopyFormatParams::Text(params) => decode_copy_format_text(data, column_types, params),
467 CopyFormatParams::Csv(params) => decode_copy_format_csv(data, column_types, params),
468 CopyFormatParams::Binary => Err(io::Error::new(
469 io::ErrorKind::Unsupported,
470 "cannot decode as binary format",
471 )),
472 CopyFormatParams::Parquet => {
473 Err(io::Error::new(io::ErrorKind::Unsupported, "parquet format"))
475 }
476 }
477}
478
479pub fn encode_copy_format<'a>(
481 params: &CopyFormatParams<'a>,
482 row: &RowRef,
483 typ: &SqlRelationType,
484 out: &mut Vec<u8>,
485) -> Result<(), io::Error> {
486 match params {
487 CopyFormatParams::Text(params) => encode_copy_row_text(params, row, typ, out),
488 CopyFormatParams::Csv(params) => encode_copy_row_csv(params, row, typ, out),
489 CopyFormatParams::Binary => encode_copy_row_binary(row, typ, out),
490 CopyFormatParams::Parquet => {
491 Err(io::Error::new(io::ErrorKind::Unsupported, "parquet format"))
493 }
494 }
495}
496
497pub fn encode_copy_format_header<'a>(
498 params: &CopyFormatParams<'a>,
499 desc: &RelationDesc,
500 out: &mut Vec<u8>,
501) -> Result<(), io::Error> {
502 match params {
503 CopyFormatParams::Text(_) => Ok(()),
504 CopyFormatParams::Binary => Ok(()),
505 CopyFormatParams::Csv(params) => {
506 let mut header_row = Row::with_capacity(desc.arity());
507 header_row
508 .packer()
509 .extend(desc.iter_names().map(|s| Datum::from(s.as_str())));
510 let typ = SqlRelationType::new(vec![
511 SqlColumnType {
512 scalar_type: SqlScalarType::String,
513 nullable: false,
514 };
515 desc.arity()
516 ]);
517 encode_copy_row_csv(params, &header_row, &typ, out)
518 }
519 CopyFormatParams::Parquet => {
520 Err(io::Error::new(io::ErrorKind::Unsupported, "parquet format"))
522 }
523 }
524}
525
/// Settings for the COPY text format.
#[derive(Debug, Eq, PartialEq, Clone, Serialize, Deserialize)]
pub struct CopyTextFormatParams<'a> {
    /// The string that stands for a NULL field.
    pub null: Cow<'a, str>,
    /// The byte that separates columns.
    pub delimiter: u8,
}
531
532impl<'a> Default for CopyTextFormatParams<'a> {
533 fn default() -> Self {
534 CopyTextFormatParams {
535 delimiter: b'\t',
536 null: Cow::from("\\N"),
537 }
538 }
539}
540
541pub fn decode_copy_format_text(
542 data: &[u8],
543 column_types: &[mz_pgrepr::Type],
544 CopyTextFormatParams { null, delimiter }: CopyTextFormatParams,
545) -> Result<Vec<Row>, io::Error> {
546 let mut rows = Vec::new();
547
548 let mut parser = CopyTextFormatParser::new(data, delimiter, &null);
550 while !parser.is_eof() && !parser.is_end_of_copy_marker() {
551 let mut row = Vec::new();
552 let buf = RowArena::new();
553 for (col, typ) in column_types.iter().enumerate() {
554 if col > 0 {
555 parser.expect_column_delimiter()?;
556 }
557 let raw_value = parser.consume_raw_value()?;
558 if let Some(raw_value) = raw_value {
559 match mz_pgrepr::Value::decode_text(typ, raw_value) {
560 Ok(value) => row.push(value.into_datum(&buf, typ)),
561 Err(err) => {
562 let msg = format!("unable to decode column: {}", err);
563 return Err(io::Error::new(io::ErrorKind::InvalidData, msg));
564 }
565 }
566 } else {
567 row.push(Datum::Null);
568 }
569 }
570 parser.expect_end_of_line()?;
571 rows.push(Row::pack(row));
572 }
573 Ok(rows)
576}
577
/// Settings for the COPY CSV format.
#[derive(Debug, Eq, PartialEq, Clone, Serialize, Deserialize)]
pub struct CopyCsvFormatParams<'a> {
    /// The byte that separates columns.
    pub delimiter: u8,
    /// The byte that opens and closes a quoted field.
    pub quote: u8,
    /// The byte written before a quote or escape byte inside a quoted field.
    pub escape: u8,
    /// Whether the data has a header line.
    pub header: bool,
    /// The unquoted string that stands for a NULL field.
    pub null: Cow<'a, str>,
}
586
587impl<'a> CopyCsvFormatParams<'a> {
588 pub fn to_owned(&self) -> CopyCsvFormatParams<'static> {
589 CopyCsvFormatParams {
590 delimiter: self.delimiter,
591 quote: self.quote,
592 escape: self.escape,
593 header: self.header,
594 null: Cow::Owned(self.null.to_string()),
595 }
596 }
597}
598
599impl Arbitrary for CopyCsvFormatParams<'static> {
600 type Parameters = ();
601 type Strategy = BoxedStrategy<Self>;
602
603 fn arbitrary_with(_args: Self::Parameters) -> Self::Strategy {
604 (
605 any::<u8>(),
606 any::<u8>(),
607 any::<u8>(),
608 any::<bool>(),
609 any::<String>(),
610 )
611 .prop_map(|(delimiter, diff, escape, header, null)| {
612 let diff = diff.saturating_sub(1).max(1);
614 let quote = delimiter.wrapping_add(diff);
615
616 Self::try_new(
617 Some(delimiter),
618 Some(quote),
619 Some(escape),
620 Some(header),
621 Some(null),
622 )
623 .expect("delimiter and quote should be different")
624 })
625 .boxed()
626 }
627}
628
629impl<'a> Default for CopyCsvFormatParams<'a> {
630 fn default() -> Self {
631 CopyCsvFormatParams {
632 delimiter: b',',
633 quote: b'"',
634 escape: b'"',
635 header: false,
636 null: Cow::from(""),
637 }
638 }
639}
640
641impl<'a> CopyCsvFormatParams<'a> {
642 pub fn try_new(
643 delimiter: Option<u8>,
644 quote: Option<u8>,
645 escape: Option<u8>,
646 header: Option<bool>,
647 null: Option<String>,
648 ) -> Result<CopyCsvFormatParams<'a>, String> {
649 let mut params = CopyCsvFormatParams::default();
650
651 if let Some(delimiter) = delimiter {
652 params.delimiter = delimiter;
653 }
654 if let Some(quote) = quote {
655 params.quote = quote;
656 params.escape = quote;
658 }
659 if let Some(escape) = escape {
660 params.escape = escape;
661 }
662 if let Some(header) = header {
663 params.header = header;
664 }
665 if let Some(null) = null {
666 params.null = Cow::from(null);
667 }
668
669 if params.quote == params.delimiter {
670 return Err("COPY delimiter and quote must be different".to_string());
671 }
672 Ok(params)
673 }
674}
675
676pub fn decode_copy_format_csv(
677 data: &[u8],
678 column_types: &[mz_pgrepr::Type],
679 CopyCsvFormatParams {
680 delimiter,
681 quote,
682 escape,
683 null,
684 header,
685 }: CopyCsvFormatParams,
686) -> Result<Vec<Row>, io::Error> {
687 let mut rows = Vec::new();
688
689 let (double_quote, escape) = if quote == escape {
690 (true, None)
691 } else {
692 (false, Some(escape))
693 };
694
695 let mut rdr = ReaderBuilder::new()
696 .delimiter(delimiter)
697 .quote(quote)
698 .has_headers(header)
699 .double_quote(double_quote)
700 .escape(escape)
701 .flexible(true)
704 .from_reader(data);
705
706 let null_as_bytes = null.as_bytes();
707
708 let mut record = ByteRecord::new();
709
710 while rdr.read_byte_record(&mut record)? {
711 if record.len() == 1 && record.iter().next() == Some(END_OF_COPY_MARKER) {
712 break;
713 }
714
715 match record.len().cmp(&column_types.len()) {
716 std::cmp::Ordering::Less => Err(io::Error::new(
717 io::ErrorKind::InvalidData,
718 "missing data for column",
719 )),
720 std::cmp::Ordering::Greater => Err(io::Error::new(
721 io::ErrorKind::InvalidData,
722 "extra data after last expected column",
723 )),
724 std::cmp::Ordering::Equal => Ok(()),
725 }?;
726
727 let mut row_builder = SharedRow::get();
728 let mut row_packer = row_builder.packer();
729
730 for (typ, raw_value) in column_types.iter().zip_eq(record.iter()) {
731 if raw_value == null_as_bytes {
732 row_packer.push(Datum::Null);
733 } else {
734 let s = match std::str::from_utf8(raw_value) {
735 Ok(s) => s,
736 Err(err) => {
737 let msg = format!("invalid utf8 data in column: {}", err);
738 return Err(io::Error::new(io::ErrorKind::InvalidData, msg));
739 }
740 };
741 match mz_pgrepr::Value::decode_text_into_row(typ, s, &mut row_packer) {
742 Ok(()) => {}
743 Err(err) => {
744 let msg = format!("unable to decode column: {}", err);
745 return Err(io::Error::new(io::ErrorKind::InvalidData, msg));
746 }
747 }
748 }
749 }
750 rows.push(row_builder.clone());
751 }
752
753 Ok(rows)
754}
755
#[cfg(test)]
mod tests {
    use mz_ore::collections::CollectionExt;
    use mz_repr::SqlColumnType;
    use proptest::prelude::*;

    use super::*;

    /// Walks the text parser through three lines covering a leading empty
    /// column, a NULL marker, and hexadecimal and octal escapes.
    #[mz_ore::test]
    fn test_copy_format_text_parser() {
        let text = "\t\\nt e\t\\N\t\n\\x60\\xA\\x7D\\x4a\n\\44\\044\\123".as_bytes();
        let mut parser = CopyTextFormatParser::new(text, b'\t', "\\N");
        assert!(parser.is_column_delimiter());
        parser
            .expect_column_delimiter()
            .expect("expected column delimiter");
        assert_eq!(
            parser
                .consume_raw_value()
                .expect("unexpected error")
                .expect("unexpected empty result"),
            "\nt e".as_bytes()
        );
        parser
            .expect_column_delimiter()
            .expect("expected column delimiter");
        // The NULL marker yields no value.
        assert!(
            parser
                .consume_raw_value()
                .expect("unexpected error")
                .is_none()
        );
        parser
            .expect_column_delimiter()
            .expect("expected column delimiter");
        assert!(parser.is_end_of_line());
        parser.expect_end_of_line().expect("expected eol");
        // Hexadecimal escapes, including a single-digit one.
        assert_eq!(
            parser
                .consume_raw_value()
                .expect("unexpected error")
                .expect("unexpected empty result"),
            "`\n}J".as_bytes()
        );
        parser.expect_end_of_line().expect("expected eol");
        // Octal escapes of two and three digits.
        assert_eq!(
            parser
                .consume_raw_value()
                .expect("unexpected error")
                .expect("unexpected empty result"),
            "$$S".as_bytes()
        );
        assert!(parser.is_eof());
    }

    /// With an empty NULL marker, empty fields are NULL and non-empty fields
    /// are values.
    #[mz_ore::test]
    fn test_copy_format_text_empty_null_string() {
        let text = "\t\n10\t20\n30\t\n40\t".as_bytes();
        let expect = vec![
            vec![None, None],
            vec![Some("10"), Some("20")],
            vec![Some("30"), None],
            vec![Some("40"), None],
        ];
        let mut parser = CopyTextFormatParser::new(text, b'\t', "");
        for line in expect {
            for (i, value) in line.iter().enumerate() {
                if i > 0 {
                    parser
                        .expect_column_delimiter()
                        .expect("expected column delimiter");
                }
                match value {
                    Some(s) => {
                        assert!(!parser.consume_null_string());
                        assert_eq!(
                            parser
                                .consume_raw_value()
                                .expect("unexpected error")
                                .expect("unexpected empty result"),
                            s.as_bytes()
                        );
                    }
                    None => {
                        assert!(parser.consume_null_string());
                    }
                }
            }
            parser.expect_end_of_line().expect("expected eol");
        }
    }

    /// Table-driven check of every backslash escape the text parser decodes,
    /// including malformed and truncated sequences.
    #[mz_ore::test]
    fn test_copy_format_text_parser_escapes() {
        struct TestCase {
            input: &'static str,
            expect: &'static [u8],
        }
        let tests = vec![
            TestCase {
                input: "simple",
                expect: b"simple",
            },
            TestCase {
                input: r#"new\nline"#,
                expect: b"new\nline",
            },
            TestCase {
                input: r#"\b\f\n\r\t\v\\"#,
                expect: b"\x08\x0c\n\r\t\x0b\\",
            },
            TestCase {
                input: r#"\0\12\123"#,
                expect: &[0, 0o12, 0o123],
            },
            TestCase {
                input: r#"\x1\xaf"#,
                expect: &[0x01, 0xaf],
            },
            TestCase {
                input: r#"T\n\07\xEV\x0fA\xb2C\1"#,
                expect: b"T\n\x07\x0eV\x0fA\xb2C\x01",
            },
            TestCase {
                input: r#"\\\""#,
                expect: b"\\\"",
            },
            TestCase {
                input: r#"\x"#,
                expect: b"x",
            },
            TestCase {
                input: r#"\xg"#,
                expect: b"xg",
            },
            TestCase {
                input: r#"\"#,
                expect: b"\\",
            },
            TestCase {
                input: r#"\8"#,
                expect: b"8",
            },
            TestCase {
                input: r#"\a"#,
                expect: b"a",
            },
            TestCase {
                input: r#"\x\xg\8\xH\x32\s\"#,
                expect: b"xxg8xH2s\\",
            },
        ];

        for test in tests {
            let mut parser = CopyTextFormatParser::new(test.input.as_bytes(), b'\t', "\\N");
            assert_eq!(
                parser
                    .consume_raw_value()
                    .expect("unexpected error")
                    .expect("unexpected empty result"),
                test.expect,
                "input: {}, expect: {:?}",
                test.input,
                std::str::from_utf8(test.expect),
            );
            assert!(parser.is_eof());
        }
    }

    /// `try_new` applies overrides, lets the escape follow the quote, and
    /// rejects a quote equal to the delimiter.
    #[mz_ore::test]
    fn test_copy_csv_format_params() {
        assert_eq!(
            CopyCsvFormatParams::try_new(Some(b't'), Some(b'q'), None, None, None),
            Ok(CopyCsvFormatParams {
                delimiter: b't',
                quote: b'q',
                escape: b'q',
                header: false,
                null: Cow::from(""),
            })
        );

        assert_eq!(
            CopyCsvFormatParams::try_new(
                Some(b't'),
                Some(b'q'),
                Some(b'e'),
                Some(true),
                Some("null".to_string())
            ),
            Ok(CopyCsvFormatParams {
                delimiter: b't',
                quote: b'q',
                escape: b'e',
                header: true,
                null: Cow::from("null"),
            })
        );

        assert_eq!(
            CopyCsvFormatParams::try_new(
                None,
                Some(b','),
                Some(b'e'),
                Some(true),
                Some("null".to_string())
            ),
            Err("COPY delimiter and quote must be different".to_string())
        );
    }

    /// Encodes one row under two CSV configurations and compares the bytes.
    #[mz_ore::test]
    fn test_copy_csv_row() -> Result<(), io::Error> {
        let mut row = Row::default();
        let mut packer = row.packer();
        packer.push(Datum::from("1,2,\"3\""));
        packer.push(Datum::Null);
        packer.push(Datum::from(1000u64));
        // Made of the quote and escape bytes used by the second test case.
        packer.push(Datum::from("qe"));
        packer.push(Datum::from(""));

        let typ: SqlRelationType = SqlRelationType::new(vec![
            SqlColumnType {
                scalar_type: mz_repr::SqlScalarType::String,
                nullable: false,
            },
            SqlColumnType {
                scalar_type: mz_repr::SqlScalarType::String,
                nullable: true,
            },
            SqlColumnType {
                scalar_type: mz_repr::SqlScalarType::UInt64,
                nullable: false,
            },
            SqlColumnType {
                scalar_type: mz_repr::SqlScalarType::String,
                nullable: false,
            },
            SqlColumnType {
                scalar_type: mz_repr::SqlScalarType::String,
                nullable: false,
            },
        ]);

        let mut out = Vec::new();

        struct TestCase<'a> {
            params: CopyCsvFormatParams<'a>,
            expected: &'static [u8],
        }

        let tests = [
            TestCase {
                params: CopyCsvFormatParams::default(),
                expected: b"\"1,2,\"\"3\"\"\",,1000,qe,\"\"\n",
            },
            TestCase {
                params: CopyCsvFormatParams {
                    null: Cow::from("NULL"),
                    quote: b'q',
                    escape: b'e',
                    ..Default::default()
                },
                expected: b"q1,2,\"3\"q,NULL,1000,qeqeeq,\n",
            },
        ];

        for TestCase { params, expected } in tests {
            out.clear();
            let params = CopyFormatParams::Csv(params);
            // Propagate encoder failures instead of comparing against
            // whatever happened to be written before the error.
            encode_copy_format(&params, &row, &typ, &mut out)?;
            let output = std::str::from_utf8(&out);
            assert_eq!(output, std::str::from_utf8(expected));
        }

        Ok(())
    }

    proptest! {
        /// Encodes single-datum rows as CSV under arbitrary settings and
        /// checks that whatever decodes successfully decodes to the same row.
        #[mz_ore::test]
        #[cfg_attr(miri, ignore)]
        fn proptest_csv_roundtrips(copy_csv_params: CopyCsvFormatParams) {
            let try_roundtrip_datum = |scalar_type: &SqlScalarType, datum| {
                let row = Row::pack_slice(&[datum]);
                let typ = SqlRelationType::new(vec![
                    SqlColumnType {
                        scalar_type: scalar_type.clone(),
                        nullable: true,
                    }
                ]);

                let mut buf = Vec::new();
                let mut csv_params = copy_csv_params.clone();
                // Only a single data row is encoded, so no header is used.
                csv_params.header = false;
                let params = CopyFormatParams::Csv(csv_params);

                encode_copy_format(&params, &row, &typ, &mut buf)?;
                let column_types = typ
                    .column_types
                    .iter()
                    .map(|x| &x.scalar_type)
                    .map(mz_pgrepr::Type::from)
                    .collect::<Vec<mz_pgrepr::Type>>();
                let result = decode_copy_format(&buf, &column_types, params);

                match result {
                    Ok(rows) => {
                        let out_str = std::str::from_utf8(&buf[..]);

                        prop_assert_eq!(
                            rows.len(),
                            1,
                            "unexpected number of rows! {:?}, csv string: {:?}", rows, out_str
                        );
                        let output = rows.into_element();

                        prop_assert_eq!(
                            row,
                            output,
                            "csv string: {:?}, scalar_type: {:?}", out_str, scalar_type
                        );
                    }
                    // A decode failure is deliberately not a test failure.
                    _ => {
                    }
                }

                Ok(())
            };

            for scalar_type in SqlScalarType::enumerate() {
                for datum in scalar_type.interesting_datums() {
                    // A value whose text equals the NULL string is skipped.
                    if let Some(value) = mz_pgrepr::Value::from_datum(datum, scalar_type) {
                        let mut buf = bytes::BytesMut::new();
                        value.encode_text(&mut buf);

                        if let Ok(datum_str) = std::str::from_utf8(&buf[..]) {
                            if datum_str == copy_csv_params.null {
                                continue;
                            }
                        }
                    }

                    let updated_datum = match datum {
                        Datum::Timestamp(_) | Datum::TimestampTz(_) | Datum::Null => {
                            continue;
                        }
                        Datum::String(s) => {
                            if s.trim() == copy_csv_params.null || s.trim().is_empty() {
                                continue;
                            } else {
                                Datum::String(s)
                            }
                        }
                        other => other,
                    };

                    let result = try_roundtrip_datum(scalar_type, updated_datum);
                    prop_assert!(result.is_ok(), "failure: {result:?}");
                }
            }
        }
    }
}
1129}