1use std::borrow::Cow;
11use std::io;
12
13use bytes::BytesMut;
14use csv::{ByteRecord, ReaderBuilder};
15use mz_proto::{ProtoType, RustType, TryFromProtoError};
16use mz_repr::{
17 ColumnType, Datum, RelationDesc, RelationType, Row, RowArena, RowRef, ScalarType, SharedRow,
18};
19use proptest::prelude::{Arbitrary, Just, any};
20use proptest::strategy::{BoxedStrategy, Strategy, Union};
21use serde::Deserialize;
22use serde::Serialize;
23
/// The `\.` sequence that marks the end of COPY data in the text and CSV formats.
static END_OF_COPY_MARKER: &[u8] = b"\\.";

// Generated protobuf definitions (`ProtoCopyFormatParams` and friends) written by the build
// script into `OUT_DIR`.
include!(concat!(env!("OUT_DIR"), "/mz_pgcopy.copy.rs"));
27
28fn encode_copy_row_binary(
29 row: &RowRef,
30 typ: &RelationType,
31 out: &mut Vec<u8>,
32) -> Result<(), io::Error> {
33 const NULL_BYTES: [u8; 4] = (-1i32).to_be_bytes();
34
35 let count = i16::try_from(typ.column_types.len()).map_err(|_| {
37 io::Error::new(
38 io::ErrorKind::Other,
39 "column count does not fit into an i16",
40 )
41 })?;
42
43 out.extend(count.to_be_bytes());
44 let mut buf = BytesMut::new();
45 for (field, typ) in row
46 .iter()
47 .zip(&typ.column_types)
48 .map(|(datum, typ)| (mz_pgrepr::Value::from_datum(datum, &typ.scalar_type), typ))
49 {
50 match field {
51 None => out.extend(NULL_BYTES),
52 Some(field) => {
53 buf.clear();
54 field.encode_binary(&mz_pgrepr::Type::from(&typ.scalar_type), &mut buf)?;
55 out.extend(
56 i32::try_from(buf.len())
57 .map_err(|_| {
58 io::Error::new(
59 io::ErrorKind::Other,
60 "field length does not fit into an i32",
61 )
62 })?
63 .to_be_bytes(),
64 );
65 out.extend(&buf);
66 }
67 }
68 }
69 Ok(())
70}
71
72fn encode_copy_row_text(
73 CopyTextFormatParams { null, delimiter }: &CopyTextFormatParams,
74 row: &RowRef,
75 typ: &RelationType,
76 out: &mut Vec<u8>,
77) -> Result<(), io::Error> {
78 let null = null.as_bytes();
79 let mut buf = BytesMut::new();
80 for (idx, field) in mz_pgrepr::values_from_row(row, typ).into_iter().enumerate() {
81 if idx > 0 {
82 out.push(*delimiter);
83 }
84 match field {
85 None => out.extend(null),
86 Some(field) => {
87 buf.clear();
88 field.encode_text(&mut buf);
89 for b in &buf {
90 match b {
91 b'\\' => out.extend(b"\\\\"),
92 b'\n' => out.extend(b"\\n"),
93 b'\r' => out.extend(b"\\r"),
94 b'\t' => out.extend(b"\\t"),
95 _ => out.push(*b),
96 }
97 }
98 }
99 }
100 }
101 out.push(b'\n');
102 Ok(())
103}
104
105fn encode_copy_row_csv(
106 CopyCsvFormatParams {
107 delimiter: delim,
108 quote,
109 escape,
110 header: _,
111 null,
112 }: &CopyCsvFormatParams,
113 row: &RowRef,
114 typ: &RelationType,
115 out: &mut Vec<u8>,
116) -> Result<(), io::Error> {
117 let null = null.as_bytes();
118 let is_special = |c: &u8| *c == *delim || *c == *quote || *c == b'\r' || *c == b'\n';
119 let mut buf = BytesMut::new();
120 for (idx, field) in mz_pgrepr::values_from_row(row, typ).into_iter().enumerate() {
121 if idx > 0 {
122 out.push(*delim);
123 }
124 match field {
125 None => out.extend(null),
126 Some(field) => {
127 buf.clear();
128 field.encode_text(&mut buf);
129 if (typ.column_types.len() == 1 && buf == END_OF_COPY_MARKER)
135 || buf.iter().any(is_special)
136 || &*buf == null
137 {
138 out.push(*quote);
142 for b in &buf {
143 if *b == *quote || *b == *escape {
144 out.push(*escape);
145 }
146 out.push(*b);
147 }
148 out.push(*quote);
149 } else {
150 out.extend(&buf);
153 }
154 }
155 }
156 }
157 out.push(b'\n');
158 Ok(())
159}
160
161pub struct CopyTextFormatParser<'a> {
162 data: &'a [u8],
163 position: usize,
164 column_delimiter: u8,
165 null_string: &'a str,
166 buffer: Vec<u8>,
167}
168
169impl<'a> CopyTextFormatParser<'a> {
170 pub fn new(data: &'a [u8], column_delimiter: u8, null_string: &'a str) -> Self {
171 Self {
172 data,
173 position: 0,
174 column_delimiter,
175 null_string,
176 buffer: Vec::new(),
177 }
178 }
179
180 fn peek(&self) -> Option<u8> {
181 if self.position < self.data.len() {
182 Some(self.data[self.position])
183 } else {
184 None
185 }
186 }
187
188 fn consume_n(&mut self, n: usize) {
189 self.position = std::cmp::min(self.position + n, self.data.len());
190 }
191
192 pub fn is_eof(&self) -> bool {
193 self.peek().is_none() || self.is_end_of_copy_marker()
194 }
195
196 pub fn is_end_of_copy_marker(&self) -> bool {
197 self.check_bytes(END_OF_COPY_MARKER)
198 }
199
200 fn is_end_of_line(&self) -> bool {
201 match self.peek() {
202 Some(b'\n') | None => true,
203 _ => false,
204 }
205 }
206
207 pub fn expect_end_of_line(&mut self) -> Result<(), io::Error> {
208 if self.is_end_of_line() {
209 self.consume_n(1);
210 Ok(())
211 } else {
212 Err(io::Error::new(
213 io::ErrorKind::InvalidData,
214 "extra data after last expected column",
215 ))
216 }
217 }
218
219 fn is_column_delimiter(&self) -> bool {
220 self.check_bytes(&[self.column_delimiter])
221 }
222
223 pub fn expect_column_delimiter(&mut self) -> Result<(), io::Error> {
224 if self.consume_bytes(&[self.column_delimiter]) {
225 Ok(())
226 } else {
227 Err(io::Error::new(
228 io::ErrorKind::InvalidData,
229 "missing data for column",
230 ))
231 }
232 }
233
234 fn check_bytes(&self, bytes: &[u8]) -> bool {
235 let remaining_bytes = self.data.len() - self.position;
236 remaining_bytes >= bytes.len()
237 && self.data[self.position..]
238 .iter()
239 .zip(bytes.iter())
240 .all(|(x, y)| x == y)
241 }
242
243 fn consume_bytes(&mut self, bytes: &[u8]) -> bool {
244 if self.check_bytes(bytes) {
245 self.consume_n(bytes.len());
246 true
247 } else {
248 false
249 }
250 }
251
252 fn consume_null_string(&mut self) -> bool {
253 if self.null_string.is_empty() {
254 self.is_column_delimiter()
257 || self.is_end_of_line()
258 || self.is_end_of_copy_marker()
259 || self.is_eof()
260 } else {
261 self.consume_bytes(self.null_string.as_bytes())
262 }
263 }
264
265 pub fn consume_raw_value(&mut self) -> Result<Option<&[u8]>, io::Error> {
266 if self.consume_null_string() {
267 return Ok(None);
268 }
269
270 let mut start = self.position;
271
272 self.buffer.clear();
274
275 while !self.is_eof() && !self.is_end_of_copy_marker() {
276 if self.is_end_of_line() || self.is_column_delimiter() {
277 break;
278 }
279 match self.peek() {
280 Some(b'\\') => {
281 self.buffer.extend(&self.data[start..self.position]);
283
284 self.consume_n(1);
285 match self.peek() {
286 Some(b'b') => {
287 self.consume_n(1);
288 self.buffer.push(8);
289 }
290 Some(b'f') => {
291 self.consume_n(1);
292 self.buffer.push(12);
293 }
294 Some(b'n') => {
295 self.consume_n(1);
296 self.buffer.push(b'\n');
297 }
298 Some(b'r') => {
299 self.consume_n(1);
300 self.buffer.push(b'\r');
301 }
302 Some(b't') => {
303 self.consume_n(1);
304 self.buffer.push(b'\t');
305 }
306 Some(b'v') => {
307 self.consume_n(1);
308 self.buffer.push(11);
309 }
310 Some(b'x') => {
311 self.consume_n(1);
312 match self.peek() {
313 Some(_c @ b'0'..=b'9')
314 | Some(_c @ b'A'..=b'F')
315 | Some(_c @ b'a'..=b'f') => {
316 let mut value: u8 = 0;
317 let decode_nibble = |b| match b {
318 Some(c @ b'a'..=b'f') => Some(c - b'a' + 10),
319 Some(c @ b'A'..=b'F') => Some(c - b'A' + 10),
320 Some(c @ b'0'..=b'9') => Some(c - b'0'),
321 _ => None,
322 };
323 for _ in 0..2 {
324 match decode_nibble(self.peek()) {
325 Some(c) => {
326 self.consume_n(1);
327 value = value << 4 | c;
328 }
329 _ => break,
330 }
331 }
332 self.buffer.push(value);
333 }
334 _ => {
335 self.buffer.push(b'x');
336 }
337 }
338 }
339 Some(_c @ b'0'..=b'7') => {
340 let mut value: u8 = 0;
341 for _ in 0..3 {
342 match self.peek() {
343 Some(c @ b'0'..=b'7') => {
344 self.consume_n(1);
345 value = value << 3 | (c - b'0');
346 }
347 _ => break,
348 }
349 }
350 self.buffer.push(value);
351 }
352 Some(c) => {
353 self.consume_n(1);
354 self.buffer.push(c);
355 }
356 None => {
357 self.buffer.push(b'\\');
358 }
359 }
360
361 start = self.position;
362 }
363 Some(_) => {
364 self.consume_n(1);
365 }
366 None => {}
367 }
368 }
369
370 if self.buffer.is_empty() {
372 Ok(Some(&self.data[start..self.position]))
373 } else {
374 self.buffer.extend(&self.data[start..self.position]);
377 Ok(Some(&self.buffer[..]))
378 }
379 }
380
381 pub fn iter_raw(self, num_columns: usize) -> RawIterator<'a> {
383 RawIterator {
384 parser: self,
385 current_column: 0,
386 num_columns,
387 truncate: false,
388 }
389 }
390
391 pub fn iter_raw_truncating(self, num_columns: usize) -> RawIterator<'a> {
393 RawIterator {
394 parser: self,
395 current_column: 0,
396 num_columns,
397 truncate: true,
398 }
399 }
400}
401
402pub struct RawIterator<'a> {
403 parser: CopyTextFormatParser<'a>,
404 current_column: usize,
405 num_columns: usize,
406 truncate: bool,
407}
408
409impl<'a> RawIterator<'a> {
410 pub fn next(&mut self) -> Option<Result<Option<&[u8]>, io::Error>> {
411 if self.current_column > self.num_columns {
412 return None;
413 }
414
415 if self.current_column == self.num_columns {
416 if !self.truncate {
417 if let Some(err) = self.parser.expect_end_of_line().err() {
418 return Some(Err(err));
419 }
420 }
421
422 return None;
423 }
424
425 if self.current_column > 0 {
426 if let Some(err) = self.parser.expect_column_delimiter().err() {
427 return Some(Err(err));
428 }
429 }
430
431 self.current_column += 1;
432 Some(self.parser.consume_raw_value())
433 }
434}
435
/// Format-specific options for `COPY`, one variant per data format.
#[derive(Debug, Eq, PartialEq, Clone, Serialize, Deserialize)]
pub enum CopyFormatParams<'a> {
    /// Postgres text format.
    Text(CopyTextFormatParams<'a>),
    /// CSV format.
    Csv(CopyCsvFormatParams<'a>),
    /// Postgres binary format; `encode_copy_format` supports it, `decode_copy_format` does not.
    Binary,
    /// Parquet; rejected as unsupported by the encode/decode entry points in this module.
    Parquet,
}
443
444impl RustType<ProtoCopyFormatParams> for CopyFormatParams<'static> {
445 fn into_proto(&self) -> ProtoCopyFormatParams {
446 use proto_copy_format_params::Kind;
447 ProtoCopyFormatParams {
448 kind: Some(match self {
449 Self::Text(f) => Kind::Text(f.into_proto()),
450 Self::Csv(f) => Kind::Csv(f.into_proto()),
451 Self::Binary => Kind::Binary(()),
452 Self::Parquet => Kind::Parquet(ProtoCopyParquetFormatParams::default()),
453 }),
454 }
455 }
456
457 fn from_proto(proto: ProtoCopyFormatParams) -> Result<Self, TryFromProtoError> {
458 use proto_copy_format_params::Kind;
459 match proto.kind {
460 Some(Kind::Text(f)) => Ok(Self::Text(f.into_rust()?)),
461 Some(Kind::Csv(f)) => Ok(Self::Csv(f.into_rust()?)),
462 Some(Kind::Binary(())) => Ok(Self::Binary),
463 Some(Kind::Parquet(ProtoCopyParquetFormatParams {})) => Ok(Self::Parquet),
464 None => Err(TryFromProtoError::missing_field(
465 "ProtoCopyFormatParams::kind",
466 )),
467 }
468 }
469}
470
471impl Arbitrary for CopyFormatParams<'static> {
472 type Parameters = ();
473 type Strategy = Union<BoxedStrategy<Self>>;
474
475 fn arbitrary_with(_args: Self::Parameters) -> Self::Strategy {
476 Union::new(vec![
477 any::<CopyTextFormatParams>().prop_map(Self::Text).boxed(),
478 any::<CopyCsvFormatParams>().prop_map(Self::Csv).boxed(),
479 Just(Self::Binary).boxed(),
480 ])
481 }
482}
483
484impl CopyFormatParams<'static> {
485 pub fn file_extension(&self) -> &str {
486 match self {
487 &CopyFormatParams::Text(_) => "txt",
488 &CopyFormatParams::Csv(_) => "csv",
489 &CopyFormatParams::Binary => "bin",
490 &CopyFormatParams::Parquet => "parquet",
491 }
492 }
493
494 pub fn requires_header(&self) -> bool {
495 match self {
496 CopyFormatParams::Text(_) => false,
497 CopyFormatParams::Csv(params) => params.header,
498 CopyFormatParams::Binary => false,
499 CopyFormatParams::Parquet => false,
500 }
501 }
502}
503
504pub fn decode_copy_format<'a>(
506 data: &[u8],
507 column_types: &[mz_pgrepr::Type],
508 params: CopyFormatParams<'a>,
509) -> Result<Vec<Row>, io::Error> {
510 match params {
511 CopyFormatParams::Text(params) => decode_copy_format_text(data, column_types, params),
512 CopyFormatParams::Csv(params) => decode_copy_format_csv(data, column_types, params),
513 CopyFormatParams::Binary => Err(io::Error::new(
514 io::ErrorKind::Unsupported,
515 "cannot decode as binary format",
516 )),
517 CopyFormatParams::Parquet => {
518 Err(io::Error::new(io::ErrorKind::Unsupported, "parquet format"))
520 }
521 }
522}
523
524pub fn encode_copy_format<'a>(
526 params: &CopyFormatParams<'a>,
527 row: &RowRef,
528 typ: &RelationType,
529 out: &mut Vec<u8>,
530) -> Result<(), io::Error> {
531 match params {
532 CopyFormatParams::Text(params) => encode_copy_row_text(params, row, typ, out),
533 CopyFormatParams::Csv(params) => encode_copy_row_csv(params, row, typ, out),
534 CopyFormatParams::Binary => encode_copy_row_binary(row, typ, out),
535 CopyFormatParams::Parquet => {
536 Err(io::Error::new(io::ErrorKind::Unsupported, "parquet format"))
538 }
539 }
540}
541
542pub fn encode_copy_format_header<'a>(
543 params: &CopyFormatParams<'a>,
544 desc: &RelationDesc,
545 out: &mut Vec<u8>,
546) -> Result<(), io::Error> {
547 match params {
548 CopyFormatParams::Text(_) => Ok(()),
549 CopyFormatParams::Binary => Ok(()),
550 CopyFormatParams::Csv(params) => {
551 let mut header_row = Row::with_capacity(desc.arity());
552 header_row
553 .packer()
554 .extend(desc.iter_names().map(|s| Datum::from(s.as_str())));
555 let typ = RelationType::new(vec![
556 ColumnType {
557 scalar_type: ScalarType::String,
558 nullable: false,
559 };
560 desc.arity()
561 ]);
562 encode_copy_row_csv(params, &header_row, &typ, out)
563 }
564 CopyFormatParams::Parquet => {
565 Err(io::Error::new(io::ErrorKind::Unsupported, "parquet format"))
567 }
568 }
569}
570
/// Options for the Postgres COPY text format.
#[derive(Debug, Eq, PartialEq, Clone, Serialize, Deserialize)]
pub struct CopyTextFormatParams<'a> {
    /// Text written for, and recognized as, SQL NULL (defaults to `\N`).
    pub null: Cow<'a, str>,
    /// Byte separating columns (defaults to a tab).
    pub delimiter: u8,
}
576
577impl<'a> Default for CopyTextFormatParams<'a> {
578 fn default() -> Self {
579 CopyTextFormatParams {
580 delimiter: b'\t',
581 null: Cow::from("\\N"),
582 }
583 }
584}
585
586impl RustType<ProtoCopyTextFormatParams> for CopyTextFormatParams<'static> {
587 fn into_proto(&self) -> ProtoCopyTextFormatParams {
588 ProtoCopyTextFormatParams {
589 null: self.null.into_proto(),
590 delimiter: self.delimiter.into_proto(),
591 }
592 }
593
594 fn from_proto(proto: ProtoCopyTextFormatParams) -> Result<Self, TryFromProtoError> {
595 Ok(Self {
596 null: Cow::Owned(proto.null.into_rust()?),
597 delimiter: proto.delimiter.into_rust()?,
598 })
599 }
600}
601
602impl Arbitrary for CopyTextFormatParams<'static> {
603 type Parameters = ();
604 type Strategy = BoxedStrategy<Self>;
605
606 fn arbitrary_with(_args: Self::Parameters) -> Self::Strategy {
607 (any::<String>(), any::<u8>())
608 .prop_map(|(null, delimiter)| Self {
609 null: Cow::Owned(null),
610 delimiter,
611 })
612 .boxed()
613 }
614}
615
616pub fn decode_copy_format_text(
617 data: &[u8],
618 column_types: &[mz_pgrepr::Type],
619 CopyTextFormatParams { null, delimiter }: CopyTextFormatParams,
620) -> Result<Vec<Row>, io::Error> {
621 let mut rows = Vec::new();
622
623 let mut parser = CopyTextFormatParser::new(data, delimiter, &null);
625 while !parser.is_eof() && !parser.is_end_of_copy_marker() {
626 let mut row = Vec::new();
627 let buf = RowArena::new();
628 for (col, typ) in column_types.iter().enumerate() {
629 if col > 0 {
630 parser.expect_column_delimiter()?;
631 }
632 let raw_value = parser.consume_raw_value()?;
633 if let Some(raw_value) = raw_value {
634 match mz_pgrepr::Value::decode_text(typ, raw_value) {
635 Ok(value) => row.push(value.into_datum(&buf, typ)),
636 Err(err) => {
637 let msg = format!("unable to decode column: {}", err);
638 return Err(io::Error::new(io::ErrorKind::InvalidData, msg));
639 }
640 }
641 } else {
642 row.push(Datum::Null);
643 }
644 }
645 parser.expect_end_of_line()?;
646 rows.push(Row::pack(row));
647 }
648 Ok(rows)
651}
652
/// Options for COPY in CSV format.
#[derive(Debug, Eq, PartialEq, Clone, Serialize, Deserialize)]
pub struct CopyCsvFormatParams<'a> {
    /// Byte separating fields (defaults to `,`).
    pub delimiter: u8,
    /// Byte used to quote fields (defaults to `"`); `try_new` rejects a quote equal to the
    /// delimiter.
    pub quote: u8,
    /// Byte written before `quote` or `escape` bytes inside a quoted field (defaults to the
    /// quote byte).
    pub escape: u8,
    /// Whether the data starts with a header line of column names.
    pub header: bool,
    /// Text representing SQL NULL (defaults to the empty string); non-NULL values equal to it
    /// are quoted on output.
    pub null: Cow<'a, str>,
}
661
662impl<'a> CopyCsvFormatParams<'a> {
663 pub fn to_owned(&self) -> CopyCsvFormatParams<'static> {
664 CopyCsvFormatParams {
665 delimiter: self.delimiter,
666 quote: self.quote,
667 escape: self.escape,
668 header: self.header,
669 null: Cow::Owned(self.null.to_string()),
670 }
671 }
672}
673
674impl RustType<ProtoCopyCsvFormatParams> for CopyCsvFormatParams<'static> {
675 fn into_proto(&self) -> ProtoCopyCsvFormatParams {
676 ProtoCopyCsvFormatParams {
677 delimiter: self.delimiter.into(),
678 quote: self.quote.into(),
679 escape: self.escape.into(),
680 header: self.header,
681 null: self.null.into_proto(),
682 }
683 }
684
685 fn from_proto(proto: ProtoCopyCsvFormatParams) -> Result<Self, TryFromProtoError> {
686 Ok(Self {
687 delimiter: proto.delimiter.into_rust()?,
688 quote: proto.quote.into_rust()?,
689 escape: proto.escape.into_rust()?,
690 header: proto.header,
691 null: Cow::Owned(proto.null.into_rust()?),
692 })
693 }
694}
695
696impl Arbitrary for CopyCsvFormatParams<'static> {
697 type Parameters = ();
698 type Strategy = BoxedStrategy<Self>;
699
700 fn arbitrary_with(_args: Self::Parameters) -> Self::Strategy {
701 (
702 any::<u8>(),
703 any::<u8>(),
704 any::<u8>(),
705 any::<bool>(),
706 any::<String>(),
707 )
708 .prop_map(|(delimiter, diff, escape, header, null)| {
709 let diff = diff.saturating_sub(1).max(1);
711 let quote = delimiter.wrapping_add(diff);
712
713 Self::try_new(
714 Some(delimiter),
715 Some(quote),
716 Some(escape),
717 Some(header),
718 Some(null),
719 )
720 .expect("delimiter and quote should be different")
721 })
722 .boxed()
723 }
724}
725
726impl<'a> Default for CopyCsvFormatParams<'a> {
727 fn default() -> Self {
728 CopyCsvFormatParams {
729 delimiter: b',',
730 quote: b'"',
731 escape: b'"',
732 header: false,
733 null: Cow::from(""),
734 }
735 }
736}
737
738impl<'a> CopyCsvFormatParams<'a> {
739 pub fn try_new(
740 delimiter: Option<u8>,
741 quote: Option<u8>,
742 escape: Option<u8>,
743 header: Option<bool>,
744 null: Option<String>,
745 ) -> Result<CopyCsvFormatParams<'a>, String> {
746 let mut params = CopyCsvFormatParams::default();
747
748 if let Some(delimiter) = delimiter {
749 params.delimiter = delimiter;
750 }
751 if let Some(quote) = quote {
752 params.quote = quote;
753 params.escape = quote;
755 }
756 if let Some(escape) = escape {
757 params.escape = escape;
758 }
759 if let Some(header) = header {
760 params.header = header;
761 }
762 if let Some(null) = null {
763 params.null = Cow::from(null);
764 }
765
766 if params.quote == params.delimiter {
767 return Err("COPY delimiter and quote must be different".to_string());
768 }
769 Ok(params)
770 }
771}
772
773pub fn decode_copy_format_csv(
774 data: &[u8],
775 column_types: &[mz_pgrepr::Type],
776 CopyCsvFormatParams {
777 delimiter,
778 quote,
779 escape,
780 null,
781 header,
782 }: CopyCsvFormatParams,
783) -> Result<Vec<Row>, io::Error> {
784 let mut rows = Vec::new();
785
786 let (double_quote, escape) = if quote == escape {
787 (true, None)
788 } else {
789 (false, Some(escape))
790 };
791
792 let mut rdr = ReaderBuilder::new()
793 .delimiter(delimiter)
794 .quote(quote)
795 .has_headers(header)
796 .double_quote(double_quote)
797 .escape(escape)
798 .flexible(true)
801 .from_reader(data);
802
803 let null_as_bytes = null.as_bytes();
804
805 let mut record = ByteRecord::new();
806
807 while rdr.read_byte_record(&mut record)? {
808 if record.len() == 1 && record.iter().next() == Some(END_OF_COPY_MARKER) {
809 break;
810 }
811
812 match record.len().cmp(&column_types.len()) {
813 std::cmp::Ordering::Less => Err(io::Error::new(
814 io::ErrorKind::InvalidData,
815 "missing data for column",
816 )),
817 std::cmp::Ordering::Greater => Err(io::Error::new(
818 io::ErrorKind::InvalidData,
819 "extra data after last expected column",
820 )),
821 std::cmp::Ordering::Equal => Ok(()),
822 }?;
823
824 let mut row_builder = SharedRow::get();
825 let mut row_packer = row_builder.packer();
826
827 for (typ, raw_value) in column_types.iter().zip(record.iter()) {
828 if raw_value == null_as_bytes {
829 row_packer.push(Datum::Null);
830 } else {
831 let s = match std::str::from_utf8(raw_value) {
832 Ok(s) => s,
833 Err(err) => {
834 let msg = format!("invalid utf8 data in column: {}", err);
835 return Err(io::Error::new(io::ErrorKind::InvalidData, msg));
836 }
837 };
838 match mz_pgrepr::Value::decode_text_into_row(typ, s, &mut row_packer) {
839 Ok(()) => {}
840 Err(err) => {
841 let msg = format!("unable to decode column: {}", err);
842 return Err(io::Error::new(io::ErrorKind::InvalidData, msg));
843 }
844 }
845 }
846 }
847 rows.push(row_builder.clone());
848 }
849
850 Ok(rows)
851}
852
#[cfg(test)]
mod tests {
    use mz_ore::collections::CollectionExt;
    use mz_repr::{ColumnType, ScalarType};
    use proptest::prelude::*;

    use super::*;

    /// Walks the text-format parser by hand through escapes, NULL markers and line ends.
    #[mz_ore::test]
    fn test_copy_format_text_parser() {
        let text = "\t\\nt e\t\\N\t\n\\x60\\xA\\x7D\\x4a\n\\44\\044\\123".as_bytes();
        let mut parser = CopyTextFormatParser::new(text, b'\t', "\\N");
        assert!(parser.is_column_delimiter());
        parser
            .expect_column_delimiter()
            .expect("expected column delimiter");
        assert_eq!(
            parser
                .consume_raw_value()
                .expect("unexpected error")
                .expect("unexpected empty result"),
            "\nt e".as_bytes()
        );
        parser
            .expect_column_delimiter()
            .expect("expected column delimiter");
        // The `\N` column is NULL.
        assert!(
            parser
                .consume_raw_value()
                .expect("unexpected error")
                .is_none()
        );
        parser
            .expect_column_delimiter()
            .expect("expected column delimiter");
        assert!(parser.is_end_of_line());
        parser.expect_end_of_line().expect("expected eol");
        // Hex escapes.
        assert_eq!(
            parser
                .consume_raw_value()
                .expect("unexpected error")
                .expect("unexpected empty result"),
            "`\n}J".as_bytes()
        );
        parser.expect_end_of_line().expect("expected eol");
        // Octal escapes.
        assert_eq!(
            parser
                .consume_raw_value()
                .expect("unexpected error")
                .expect("unexpected empty result"),
            "$$S".as_bytes()
        );
        assert!(parser.is_eof());
    }

    /// With an empty NULL marker, empty fields are NULL and non-empty ones are data.
    #[mz_ore::test]
    fn test_copy_format_text_empty_null_string() {
        let text = "\t\n10\t20\n30\t\n40\t".as_bytes();
        let expect = vec![
            vec![None, None],
            vec![Some("10"), Some("20")],
            vec![Some("30"), None],
            vec![Some("40"), None],
        ];
        let mut parser = CopyTextFormatParser::new(text, b'\t', "");
        for line in expect {
            for (i, value) in line.iter().enumerate() {
                if i > 0 {
                    parser
                        .expect_column_delimiter()
                        .expect("expected column delimiter");
                }
                match value {
                    Some(s) => {
                        assert!(!parser.consume_null_string());
                        assert_eq!(
                            parser
                                .consume_raw_value()
                                .expect("unexpected error")
                                .expect("unexpected empty result"),
                            s.as_bytes()
                        );
                    }
                    None => {
                        assert!(parser.consume_null_string());
                    }
                }
            }
            parser.expect_end_of_line().expect("expected eol");
        }
    }

    /// Decoding of every supported backslash escape, including malformed ones.
    #[mz_ore::test]
    fn test_copy_format_text_parser_escapes() {
        struct TestCase {
            input: &'static str,
            expect: &'static [u8],
        }
        let tests = vec![
            TestCase {
                input: "simple",
                expect: b"simple",
            },
            TestCase {
                input: r#"new\nline"#,
                expect: b"new\nline",
            },
            TestCase {
                input: r#"\b\f\n\r\t\v\\"#,
                expect: b"\x08\x0c\n\r\t\x0b\\",
            },
            TestCase {
                input: r#"\0\12\123"#,
                expect: &[0, 0o12, 0o123],
            },
            TestCase {
                input: r#"\x1\xaf"#,
                expect: &[0x01, 0xaf],
            },
            TestCase {
                input: r#"T\n\07\xEV\x0fA\xb2C\1"#,
                expect: b"T\n\x07\x0eV\x0fA\xb2C\x01",
            },
            TestCase {
                input: r#"\\\""#,
                expect: b"\\\"",
            },
            TestCase {
                input: r#"\x"#,
                expect: b"x",
            },
            TestCase {
                input: r#"\xg"#,
                expect: b"xg",
            },
            TestCase {
                input: r#"\"#,
                expect: b"\\",
            },
            TestCase {
                input: r#"\8"#,
                expect: b"8",
            },
            TestCase {
                input: r#"\a"#,
                expect: b"a",
            },
            TestCase {
                input: r#"\x\xg\8\xH\x32\s\"#,
                expect: b"xxg8xH2s\\",
            },
        ];

        for test in tests {
            let mut parser = CopyTextFormatParser::new(test.input.as_bytes(), b'\t', "\\N");
            assert_eq!(
                parser
                    .consume_raw_value()
                    .expect("unexpected error")
                    .expect("unexpected empty result"),
                test.expect,
                "input: {}, expect: {:?}",
                test.input,
                std::str::from_utf8(test.expect),
            );
            assert!(parser.is_eof());
        }
    }

    /// `try_new` fills defaults, derives escape from quote, and rejects quote == delimiter.
    #[mz_ore::test]
    fn test_copy_csv_format_params() {
        assert_eq!(
            CopyCsvFormatParams::try_new(Some(b't'), Some(b'q'), None, None, None),
            Ok(CopyCsvFormatParams {
                delimiter: b't',
                quote: b'q',
                escape: b'q',
                header: false,
                null: Cow::from(""),
            })
        );

        assert_eq!(
            CopyCsvFormatParams::try_new(
                Some(b't'),
                Some(b'q'),
                Some(b'e'),
                Some(true),
                Some("null".to_string())
            ),
            Ok(CopyCsvFormatParams {
                delimiter: b't',
                quote: b'q',
                escape: b'e',
                header: true,
                null: Cow::from("null"),
            })
        );

        assert_eq!(
            CopyCsvFormatParams::try_new(
                None,
                Some(b','),
                Some(b'e'),
                Some(true),
                Some("null".to_string())
            ),
            Err("COPY delimiter and quote must be different".to_string())
        );
    }

    /// CSV encoding of one row under default and custom quoting options.
    #[mz_ore::test]
    fn test_copy_csv_row() -> Result<(), io::Error> {
        let mut row = Row::default();
        let mut packer = row.packer();
        packer.push(Datum::from("1,2,\"3\""));
        packer.push(Datum::Null);
        packer.push(Datum::from(1000u64));
        packer.push(Datum::from("qe"));
        packer.push(Datum::from(""));

        let typ: RelationType = RelationType::new(vec![
            ColumnType {
                scalar_type: mz_repr::ScalarType::String,
                nullable: false,
            },
            ColumnType {
                scalar_type: mz_repr::ScalarType::String,
                nullable: true,
            },
            ColumnType {
                scalar_type: mz_repr::ScalarType::UInt64,
                nullable: false,
            },
            ColumnType {
                scalar_type: mz_repr::ScalarType::String,
                nullable: false,
            },
            ColumnType {
                scalar_type: mz_repr::ScalarType::String,
                nullable: false,
            },
        ]);

        let mut out = Vec::new();

        struct TestCase<'a> {
            params: CopyCsvFormatParams<'a>,
            expected: &'static [u8],
        }

        let tests = [
            TestCase {
                params: CopyCsvFormatParams::default(),
                expected: b"\"1,2,\"\"3\"\"\",,1000,qe,\"\"\n",
            },
            TestCase {
                params: CopyCsvFormatParams {
                    null: Cow::from("NULL"),
                    quote: b'q',
                    escape: b'e',
                    ..Default::default()
                },
                expected: b"q1,2,\"3\"q,NULL,1000,qeqeeq,\n",
            },
        ];

        for TestCase { params, expected } in tests {
            out.clear();
            let params = CopyFormatParams::Csv(params);
            encode_copy_format(&params, &row, &typ, &mut out)?;
            let output = std::str::from_utf8(&out);
            assert_eq!(output, std::str::from_utf8(expected));
        }

        Ok(())
    }

    proptest! {
        #[mz_ore::test]
        #[cfg_attr(miri, ignore)]
        fn proptest_csv_roundtrips(copy_csv_params: CopyCsvFormatParams) {
            // Encodes a single-column row as CSV with the generated params and decodes it back.
            let try_roundtrip_datum = |scalar_type: &ScalarType, datum| {
                let row = Row::pack_slice(&[datum]);
                let typ = RelationType::new(vec![
                    ColumnType {
                        scalar_type: scalar_type.clone(),
                        nullable: true,
                    }
                ]);

                let mut buf = Vec::new();
                let mut csv_params = copy_csv_params.clone();
                // Only data lines are round-tripped here, so no header is written.
                csv_params.header = false;
                let params = CopyFormatParams::Csv(csv_params);

                encode_copy_format(&params, &row, &typ, &mut buf)?;

                let column_types = typ
                    .column_types
                    .iter()
                    .map(|x| &x.scalar_type)
                    .map(mz_pgrepr::Type::from)
                    .collect::<Vec<mz_pgrepr::Type>>();
                let result = decode_copy_format(&buf, &column_types, params);

                match result {
                    Ok(rows) => {
                        let out_str = std::str::from_utf8(&buf[..]);

                        prop_assert_eq!(
                            rows.len(),
                            1,
                            "unexpected number of rows! {:?}, csv string: {:?}", rows, out_str
                        );
                        let output = rows.into_element();

                        prop_assert_eq!(
                            row,
                            output,
                            "csv string: {:?}, scalar_type: {:?}", out_str, scalar_type
                        );
                    }
                    // NOTE(review): decode failures are ignored, so only values the decoder
                    // accepts are checked by this property.
                    _ => {
                    }
                }

                Ok(())
            };

            for scalar_type in ScalarType::enumerate() {
                for datum in scalar_type.interesting_datums() {
                    // Skip values whose text form equals the NULL string; they would come back
                    // as NULL.
                    if let Some(value) = mz_pgrepr::Value::from_datum(datum, scalar_type) {
                        let mut buf = bytes::BytesMut::new();
                        value.encode_text(&mut buf);

                        if let Ok(datum_str) = std::str::from_utf8(&buf[..]) {
                            if datum_str == copy_csv_params.null {
                                continue;
                            }
                        }
                    }

                    let updated_datum = match datum {
                        Datum::Timestamp(_) | Datum::TimestampTz(_) | Datum::Null => {
                            continue;
                        }
                        Datum::String(s) => {
                            // Presumably skipped because the CSV decoder cannot tell these apart
                            // from NULL once quotes are stripped — TODO confirm.
                            if s.trim() == copy_csv_params.null || s.trim().is_empty() {
                                continue;
                            } else {
                                Datum::String(s)
                            }
                        }
                        other => other,
                    };

                    let result = try_roundtrip_datum(scalar_type, updated_datum);
                    prop_assert!(result.is_ok(), "failure: {result:?}");
                }
            }
        }
    }
}