1use std::borrow::Cow;
11use std::io;
12
13use bytes::BytesMut;
14use csv::{ByteRecord, ReaderBuilder};
15use mz_proto::{ProtoType, RustType, TryFromProtoError};
16use mz_repr::{
17 ColumnType, Datum, RelationDesc, RelationType, Row, RowArena, RowRef, ScalarType, SharedRow,
18};
19use proptest::prelude::{Arbitrary, Just, any};
20use proptest::strategy::{BoxedStrategy, Strategy, Union};
21use serde::Deserialize;
22use serde::Serialize;
23
/// End-of-copy marker (`\.`). The text and CSV decoders in this module stop
/// reading rows when they reach it, and the CSV encoder quotes a lone `\.`
/// value in a single-column relation so it is not mistaken for it.
static END_OF_COPY_MARKER: &[u8] = b"\\.";

// Code generated into `OUT_DIR` at build time; presumably the `Proto*` message
// types used by the `RustType` impls in this file (TODO confirm against build.rs).
include!(concat!(env!("OUT_DIR"), "/mz_pgcopy.copy.rs"));
27
28fn encode_copy_row_binary(
29 row: &RowRef,
30 typ: &RelationType,
31 out: &mut Vec<u8>,
32) -> Result<(), io::Error> {
33 const NULL_BYTES: [u8; 4] = (-1i32).to_be_bytes();
34
35 let count = i16::try_from(typ.column_types.len()).map_err(|_| {
37 io::Error::new(
38 io::ErrorKind::Other,
39 "column count does not fit into an i16",
40 )
41 })?;
42
43 out.extend(count.to_be_bytes());
44 let mut buf = BytesMut::new();
45 for (field, typ) in row
46 .iter()
47 .zip(&typ.column_types)
48 .map(|(datum, typ)| (mz_pgrepr::Value::from_datum(datum, &typ.scalar_type), typ))
49 {
50 match field {
51 None => out.extend(NULL_BYTES),
52 Some(field) => {
53 buf.clear();
54 field.encode_binary(&mz_pgrepr::Type::from(&typ.scalar_type), &mut buf)?;
55 out.extend(
56 i32::try_from(buf.len())
57 .map_err(|_| {
58 io::Error::new(
59 io::ErrorKind::Other,
60 "field length does not fit into an i32",
61 )
62 })?
63 .to_be_bytes(),
64 );
65 out.extend(&buf);
66 }
67 }
68 }
69 Ok(())
70}
71
72fn encode_copy_row_text(
73 CopyTextFormatParams { null, delimiter }: &CopyTextFormatParams,
74 row: &RowRef,
75 typ: &RelationType,
76 out: &mut Vec<u8>,
77) -> Result<(), io::Error> {
78 let null = null.as_bytes();
79 let mut buf = BytesMut::new();
80 for (idx, field) in mz_pgrepr::values_from_row(row, typ).into_iter().enumerate() {
81 if idx > 0 {
82 out.push(*delimiter);
83 }
84 match field {
85 None => out.extend(null),
86 Some(field) => {
87 buf.clear();
88 field.encode_text(&mut buf);
89 for b in &buf {
90 match b {
91 b'\\' => out.extend(b"\\\\"),
92 b'\n' => out.extend(b"\\n"),
93 b'\r' => out.extend(b"\\r"),
94 b'\t' => out.extend(b"\\t"),
95 _ => out.push(*b),
96 }
97 }
98 }
99 }
100 }
101 out.push(b'\n');
102 Ok(())
103}
104
105fn encode_copy_row_csv(
106 CopyCsvFormatParams {
107 delimiter: delim,
108 quote,
109 escape,
110 header: _,
111 null,
112 }: &CopyCsvFormatParams,
113 row: &RowRef,
114 typ: &RelationType,
115 out: &mut Vec<u8>,
116) -> Result<(), io::Error> {
117 let null = null.as_bytes();
118 let is_special = |c: &u8| *c == *delim || *c == *quote || *c == b'\r' || *c == b'\n';
119 let mut buf = BytesMut::new();
120 for (idx, field) in mz_pgrepr::values_from_row(row, typ).into_iter().enumerate() {
121 if idx > 0 {
122 out.push(*delim);
123 }
124 match field {
125 None => out.extend(null),
126 Some(field) => {
127 buf.clear();
128 field.encode_text(&mut buf);
129 if (typ.column_types.len() == 1 && buf == END_OF_COPY_MARKER)
135 || buf.iter().any(is_special)
136 || &*buf == null
137 {
138 out.push(*quote);
142 for b in &buf {
143 if *b == *quote || *b == *escape {
144 out.push(*escape);
145 }
146 out.push(*b);
147 }
148 out.push(*quote);
149 } else {
150 out.extend(&buf);
153 }
154 }
155 }
156 }
157 out.push(b'\n');
158 Ok(())
159}
160
161pub struct CopyTextFormatParser<'a> {
162 data: &'a [u8],
163 position: usize,
164 column_delimiter: u8,
165 null_string: &'a str,
166 buffer: Vec<u8>,
167}
168
169impl<'a> CopyTextFormatParser<'a> {
170 pub fn new(data: &'a [u8], column_delimiter: u8, null_string: &'a str) -> Self {
171 Self {
172 data,
173 position: 0,
174 column_delimiter,
175 null_string,
176 buffer: Vec::new(),
177 }
178 }
179
180 fn peek(&self) -> Option<u8> {
181 if self.position < self.data.len() {
182 Some(self.data[self.position])
183 } else {
184 None
185 }
186 }
187
188 fn consume_n(&mut self, n: usize) {
189 self.position = std::cmp::min(self.position + n, self.data.len());
190 }
191
192 pub fn is_eof(&self) -> bool {
193 self.peek().is_none() || self.is_end_of_copy_marker()
194 }
195
196 pub fn is_end_of_copy_marker(&self) -> bool {
197 self.check_bytes(END_OF_COPY_MARKER)
198 }
199
200 fn is_end_of_line(&self) -> bool {
201 match self.peek() {
202 Some(b'\n') | None => true,
203 _ => false,
204 }
205 }
206
207 pub fn expect_end_of_line(&mut self) -> Result<(), io::Error> {
208 if self.is_end_of_line() {
209 self.consume_n(1);
210 Ok(())
211 } else {
212 Err(io::Error::new(
213 io::ErrorKind::InvalidData,
214 "extra data after last expected column",
215 ))
216 }
217 }
218
219 fn is_column_delimiter(&self) -> bool {
220 self.check_bytes(&[self.column_delimiter])
221 }
222
223 pub fn expect_column_delimiter(&mut self) -> Result<(), io::Error> {
224 if self.consume_bytes(&[self.column_delimiter]) {
225 Ok(())
226 } else {
227 Err(io::Error::new(
228 io::ErrorKind::InvalidData,
229 "missing data for column",
230 ))
231 }
232 }
233
234 fn check_bytes(&self, bytes: &[u8]) -> bool {
235 let remaining_bytes = self.data.len() - self.position;
236 remaining_bytes >= bytes.len()
237 && self.data[self.position..]
238 .iter()
239 .zip(bytes.iter())
240 .all(|(x, y)| x == y)
241 }
242
243 fn consume_bytes(&mut self, bytes: &[u8]) -> bool {
244 if self.check_bytes(bytes) {
245 self.consume_n(bytes.len());
246 true
247 } else {
248 false
249 }
250 }
251
252 fn consume_null_string(&mut self) -> bool {
253 if self.null_string.is_empty() {
254 self.is_column_delimiter()
257 || self.is_end_of_line()
258 || self.is_end_of_copy_marker()
259 || self.is_eof()
260 } else {
261 self.consume_bytes(self.null_string.as_bytes())
262 }
263 }
264
265 pub fn consume_raw_value(&mut self) -> Result<Option<&[u8]>, io::Error> {
266 if self.consume_null_string() {
267 return Ok(None);
268 }
269
270 let mut start = self.position;
271
272 self.buffer.clear();
274
275 while !self.is_eof() && !self.is_end_of_copy_marker() {
276 if self.is_end_of_line() || self.is_column_delimiter() {
277 break;
278 }
279 match self.peek() {
280 Some(b'\\') => {
281 self.buffer.extend(&self.data[start..self.position]);
283
284 self.consume_n(1);
285 match self.peek() {
286 Some(b'b') => {
287 self.consume_n(1);
288 self.buffer.push(8);
289 }
290 Some(b'f') => {
291 self.consume_n(1);
292 self.buffer.push(12);
293 }
294 Some(b'n') => {
295 self.consume_n(1);
296 self.buffer.push(b'\n');
297 }
298 Some(b'r') => {
299 self.consume_n(1);
300 self.buffer.push(b'\r');
301 }
302 Some(b't') => {
303 self.consume_n(1);
304 self.buffer.push(b'\t');
305 }
306 Some(b'v') => {
307 self.consume_n(1);
308 self.buffer.push(11);
309 }
310 Some(b'x') => {
311 self.consume_n(1);
312 match self.peek() {
313 Some(_c @ b'0'..=b'9')
314 | Some(_c @ b'A'..=b'F')
315 | Some(_c @ b'a'..=b'f') => {
316 let mut value: u8 = 0;
317 let decode_nibble = |b| match b {
318 Some(c @ b'a'..=b'f') => Some(c - b'a' + 10),
319 Some(c @ b'A'..=b'F') => Some(c - b'A' + 10),
320 Some(c @ b'0'..=b'9') => Some(c - b'0'),
321 _ => None,
322 };
323 for _ in 0..2 {
324 match decode_nibble(self.peek()) {
325 Some(c) => {
326 self.consume_n(1);
327 value = value << 4 | c;
328 }
329 _ => break,
330 }
331 }
332 self.buffer.push(value);
333 }
334 _ => {
335 self.buffer.push(b'x');
336 }
337 }
338 }
339 Some(_c @ b'0'..=b'7') => {
340 let mut value: u8 = 0;
341 for _ in 0..3 {
342 match self.peek() {
343 Some(c @ b'0'..=b'7') => {
344 self.consume_n(1);
345 value = value << 3 | (c - b'0');
346 }
347 _ => break,
348 }
349 }
350 self.buffer.push(value);
351 }
352 Some(c) => {
353 self.consume_n(1);
354 self.buffer.push(c);
355 }
356 None => {
357 self.buffer.push(b'\\');
358 }
359 }
360
361 start = self.position;
362 }
363 Some(_) => {
364 self.consume_n(1);
365 }
366 None => {}
367 }
368 }
369
370 if self.buffer.is_empty() {
372 Ok(Some(&self.data[start..self.position]))
373 } else {
374 self.buffer.extend(&self.data[start..self.position]);
377 Ok(Some(&self.buffer[..]))
378 }
379 }
380
381 pub fn iter_raw(self, num_columns: usize) -> RawIterator<'a> {
383 RawIterator {
384 parser: self,
385 current_column: 0,
386 num_columns,
387 truncate: false,
388 }
389 }
390
391 pub fn iter_raw_truncating(self, num_columns: usize) -> RawIterator<'a> {
393 RawIterator {
394 parser: self,
395 current_column: 0,
396 num_columns,
397 truncate: true,
398 }
399 }
400}
401
402pub struct RawIterator<'a> {
403 parser: CopyTextFormatParser<'a>,
404 current_column: usize,
405 num_columns: usize,
406 truncate: bool,
407}
408
409impl<'a> RawIterator<'a> {
410 pub fn next(&mut self) -> Option<Result<Option<&[u8]>, io::Error>> {
411 if self.current_column > self.num_columns {
412 return None;
413 }
414
415 if self.current_column == self.num_columns {
416 if !self.truncate {
417 if let Some(err) = self.parser.expect_end_of_line().err() {
418 return Some(Err(err));
419 }
420 }
421
422 return None;
423 }
424
425 if self.current_column > 0 {
426 if let Some(err) = self.parser.expect_column_delimiter().err() {
427 return Some(Err(err));
428 }
429 }
430
431 self.current_column += 1;
432 Some(self.parser.consume_raw_value())
433 }
434}
435
/// Format and format-specific parameters for a `COPY` operation.
#[derive(Debug, Eq, PartialEq, Clone, Serialize, Deserialize)]
pub enum CopyFormatParams<'a> {
    /// PostgreSQL's text format.
    Text(CopyTextFormatParams<'a>),
    /// CSV format.
    Csv(CopyCsvFormatParams<'a>),
    /// PostgreSQL's binary format; this module only encodes it (decoding
    /// returns `Unsupported`).
    Binary,
    /// Parquet; both encoding and decoding in this module return `Unsupported`.
    Parquet,
}
443
444impl RustType<ProtoCopyFormatParams> for CopyFormatParams<'static> {
445 fn into_proto(&self) -> ProtoCopyFormatParams {
446 use proto_copy_format_params::Kind;
447 ProtoCopyFormatParams {
448 kind: Some(match self {
449 Self::Text(f) => Kind::Text(f.into_proto()),
450 Self::Csv(f) => Kind::Csv(f.into_proto()),
451 Self::Binary => Kind::Binary(()),
452 Self::Parquet => Kind::Parquet(ProtoCopyParquetFormatParams::default()),
453 }),
454 }
455 }
456
457 fn from_proto(proto: ProtoCopyFormatParams) -> Result<Self, TryFromProtoError> {
458 use proto_copy_format_params::Kind;
459 match proto.kind {
460 Some(Kind::Text(f)) => Ok(Self::Text(f.into_rust()?)),
461 Some(Kind::Csv(f)) => Ok(Self::Csv(f.into_rust()?)),
462 Some(Kind::Binary(())) => Ok(Self::Binary),
463 Some(Kind::Parquet(ProtoCopyParquetFormatParams {})) => Ok(Self::Parquet),
464 None => Err(TryFromProtoError::missing_field(
465 "ProtoCopyFormatParams::kind",
466 )),
467 }
468 }
469}
470
471impl Arbitrary for CopyFormatParams<'static> {
472 type Parameters = ();
473 type Strategy = Union<BoxedStrategy<Self>>;
474
475 fn arbitrary_with(_args: Self::Parameters) -> Self::Strategy {
476 Union::new(vec![
477 any::<CopyTextFormatParams>().prop_map(Self::Text).boxed(),
478 any::<CopyCsvFormatParams>().prop_map(Self::Csv).boxed(),
479 Just(Self::Binary).boxed(),
480 ])
481 }
482}
483
484impl CopyFormatParams<'static> {
485 pub fn file_extension(&self) -> &str {
486 match self {
487 &CopyFormatParams::Text(_) => "txt",
488 &CopyFormatParams::Csv(_) => "csv",
489 &CopyFormatParams::Binary => "bin",
490 &CopyFormatParams::Parquet => "parquet",
491 }
492 }
493
494 pub fn requires_header(&self) -> bool {
495 match self {
496 CopyFormatParams::Text(_) => false,
497 CopyFormatParams::Csv(params) => params.header,
498 CopyFormatParams::Binary => false,
499 CopyFormatParams::Parquet => false,
500 }
501 }
502}
503
504pub fn decode_copy_format<'a>(
506 data: &[u8],
507 column_types: &[mz_pgrepr::Type],
508 params: CopyFormatParams<'a>,
509) -> Result<Vec<Row>, io::Error> {
510 match params {
511 CopyFormatParams::Text(params) => decode_copy_format_text(data, column_types, params),
512 CopyFormatParams::Csv(params) => decode_copy_format_csv(data, column_types, params),
513 CopyFormatParams::Binary => Err(io::Error::new(
514 io::ErrorKind::Unsupported,
515 "cannot decode as binary format",
516 )),
517 CopyFormatParams::Parquet => {
518 Err(io::Error::new(io::ErrorKind::Unsupported, "parquet format"))
520 }
521 }
522}
523
524pub fn encode_copy_format<'a>(
526 params: &CopyFormatParams<'a>,
527 row: &RowRef,
528 typ: &RelationType,
529 out: &mut Vec<u8>,
530) -> Result<(), io::Error> {
531 match params {
532 CopyFormatParams::Text(params) => encode_copy_row_text(params, row, typ, out),
533 CopyFormatParams::Csv(params) => encode_copy_row_csv(params, row, typ, out),
534 CopyFormatParams::Binary => encode_copy_row_binary(row, typ, out),
535 CopyFormatParams::Parquet => {
536 Err(io::Error::new(io::ErrorKind::Unsupported, "parquet format"))
538 }
539 }
540}
541
542pub fn encode_copy_format_header<'a>(
543 params: &CopyFormatParams<'a>,
544 desc: &RelationDesc,
545 out: &mut Vec<u8>,
546) -> Result<(), io::Error> {
547 match params {
548 CopyFormatParams::Text(_) => Ok(()),
549 CopyFormatParams::Binary => Ok(()),
550 CopyFormatParams::Csv(params) => {
551 let mut header_row = Row::with_capacity(desc.arity());
552 header_row
553 .packer()
554 .extend(desc.iter_names().map(|s| Datum::from(s.as_str())));
555 let typ = RelationType::new(vec![
556 ColumnType {
557 scalar_type: ScalarType::String,
558 nullable: false,
559 };
560 desc.arity()
561 ]);
562 encode_copy_row_csv(params, &header_row, &typ, out)
563 }
564 CopyFormatParams::Parquet => {
565 Err(io::Error::new(io::ErrorKind::Unsupported, "parquet format"))
567 }
568 }
569}
570
/// Parameters for PostgreSQL's text `COPY` format.
#[derive(Debug, Eq, PartialEq, Clone, Serialize, Deserialize)]
pub struct CopyTextFormatParams<'a> {
    /// String written for, and recognized as, SQL `NULL`.
    pub null: Cow<'a, str>,
    /// Byte separating columns within a row.
    pub delimiter: u8,
}
576
577impl<'a> Default for CopyTextFormatParams<'a> {
578 fn default() -> Self {
579 CopyTextFormatParams {
580 delimiter: b'\t',
581 null: Cow::from("\\N"),
582 }
583 }
584}
585
586impl RustType<ProtoCopyTextFormatParams> for CopyTextFormatParams<'static> {
587 fn into_proto(&self) -> ProtoCopyTextFormatParams {
588 ProtoCopyTextFormatParams {
589 null: self.null.into_proto(),
590 delimiter: self.delimiter.into_proto(),
591 }
592 }
593
594 fn from_proto(proto: ProtoCopyTextFormatParams) -> Result<Self, TryFromProtoError> {
595 Ok(Self {
596 null: Cow::Owned(proto.null.into_rust()?),
597 delimiter: proto.delimiter.into_rust()?,
598 })
599 }
600}
601
602impl Arbitrary for CopyTextFormatParams<'static> {
603 type Parameters = ();
604 type Strategy = BoxedStrategy<Self>;
605
606 fn arbitrary_with(_args: Self::Parameters) -> Self::Strategy {
607 (any::<String>(), any::<u8>())
608 .prop_map(|(null, delimiter)| Self {
609 null: Cow::Owned(null),
610 delimiter,
611 })
612 .boxed()
613 }
614}
615
616pub fn decode_copy_format_text(
617 data: &[u8],
618 column_types: &[mz_pgrepr::Type],
619 CopyTextFormatParams { null, delimiter }: CopyTextFormatParams,
620) -> Result<Vec<Row>, io::Error> {
621 let mut rows = Vec::new();
622
623 let mut parser = CopyTextFormatParser::new(data, delimiter, &null);
625 while !parser.is_eof() && !parser.is_end_of_copy_marker() {
626 let mut row = Vec::new();
627 let buf = RowArena::new();
628 for (col, typ) in column_types.iter().enumerate() {
629 if col > 0 {
630 parser.expect_column_delimiter()?;
631 }
632 let raw_value = parser.consume_raw_value()?;
633 if let Some(raw_value) = raw_value {
634 match mz_pgrepr::Value::decode_text(typ, raw_value) {
635 Ok(value) => row.push(value.into_datum(&buf, typ)),
636 Err(err) => {
637 let msg = format!("unable to decode column: {}", err);
638 return Err(io::Error::new(io::ErrorKind::InvalidData, msg));
639 }
640 }
641 } else {
642 row.push(Datum::Null);
643 }
644 }
645 parser.expect_end_of_line()?;
646 rows.push(Row::pack(row));
647 }
648 Ok(rows)
651}
652
/// Parameters for the CSV `COPY` format.
#[derive(Debug, Eq, PartialEq, Clone, Serialize, Deserialize)]
pub struct CopyCsvFormatParams<'a> {
    /// Byte separating columns within a row.
    pub delimiter: u8,
    /// Byte used to quote values; `try_new` rejects a quote equal to `delimiter`.
    pub quote: u8,
    /// Byte placed before a quote or escape byte inside a quoted value. When
    /// equal to `quote`, the decoder treats embedded quotes as doubled.
    pub escape: u8,
    /// Whether a header row of column names is present.
    pub header: bool,
    /// String written for, and recognized as, SQL `NULL`.
    pub null: Cow<'a, str>,
}
661
662impl<'a> CopyCsvFormatParams<'a> {
663 pub fn to_owned(&self) -> CopyCsvFormatParams<'static> {
664 CopyCsvFormatParams {
665 delimiter: self.delimiter,
666 quote: self.quote,
667 escape: self.escape,
668 header: self.header,
669 null: Cow::Owned(self.null.to_string()),
670 }
671 }
672}
673
674impl RustType<ProtoCopyCsvFormatParams> for CopyCsvFormatParams<'static> {
675 fn into_proto(&self) -> ProtoCopyCsvFormatParams {
676 ProtoCopyCsvFormatParams {
677 delimiter: self.delimiter.into(),
678 quote: self.quote.into(),
679 escape: self.escape.into(),
680 header: self.header,
681 null: self.null.into_proto(),
682 }
683 }
684
685 fn from_proto(proto: ProtoCopyCsvFormatParams) -> Result<Self, TryFromProtoError> {
686 Ok(Self {
687 delimiter: proto.delimiter.into_rust()?,
688 quote: proto.quote.into_rust()?,
689 escape: proto.escape.into_rust()?,
690 header: proto.header,
691 null: Cow::Owned(proto.null.into_rust()?),
692 })
693 }
694}
695
696impl Arbitrary for CopyCsvFormatParams<'static> {
697 type Parameters = ();
698 type Strategy = BoxedStrategy<Self>;
699
700 fn arbitrary_with(_args: Self::Parameters) -> Self::Strategy {
701 (
702 any::<u8>(),
703 any::<u8>(),
704 any::<u8>(),
705 any::<bool>(),
706 any::<String>(),
707 )
708 .prop_map(|(delimiter, diff, escape, header, null)| {
709 let diff = diff.saturating_sub(1).max(1);
711 let quote = delimiter.wrapping_add(diff);
712
713 Self::try_new(
714 Some(delimiter),
715 Some(quote),
716 Some(escape),
717 Some(header),
718 Some(null),
719 )
720 .expect("delimiter and quote should be different")
721 })
722 .boxed()
723 }
724}
725
726impl<'a> Default for CopyCsvFormatParams<'a> {
727 fn default() -> Self {
728 CopyCsvFormatParams {
729 delimiter: b',',
730 quote: b'"',
731 escape: b'"',
732 header: false,
733 null: Cow::from(""),
734 }
735 }
736}
737
738impl<'a> CopyCsvFormatParams<'a> {
739 pub fn try_new(
740 delimiter: Option<u8>,
741 quote: Option<u8>,
742 escape: Option<u8>,
743 header: Option<bool>,
744 null: Option<String>,
745 ) -> Result<CopyCsvFormatParams<'a>, String> {
746 let mut params = CopyCsvFormatParams::default();
747
748 if let Some(delimiter) = delimiter {
749 params.delimiter = delimiter;
750 }
751 if let Some(quote) = quote {
752 params.quote = quote;
753 params.escape = quote;
755 }
756 if let Some(escape) = escape {
757 params.escape = escape;
758 }
759 if let Some(header) = header {
760 params.header = header;
761 }
762 if let Some(null) = null {
763 params.null = Cow::from(null);
764 }
765
766 if params.quote == params.delimiter {
767 return Err("COPY delimiter and quote must be different".to_string());
768 }
769 Ok(params)
770 }
771}
772
773pub fn decode_copy_format_csv(
774 data: &[u8],
775 column_types: &[mz_pgrepr::Type],
776 CopyCsvFormatParams {
777 delimiter,
778 quote,
779 escape,
780 null,
781 header,
782 }: CopyCsvFormatParams,
783) -> Result<Vec<Row>, io::Error> {
784 let mut rows = Vec::new();
785
786 let (double_quote, escape) = if quote == escape {
787 (true, None)
788 } else {
789 (false, Some(escape))
790 };
791
792 let mut rdr = ReaderBuilder::new()
793 .delimiter(delimiter)
794 .quote(quote)
795 .has_headers(header)
796 .double_quote(double_quote)
797 .escape(escape)
798 .flexible(true)
801 .from_reader(data);
802
803 let null_as_bytes = null.as_bytes();
804
805 let mut record = ByteRecord::new();
806
807 while rdr.read_byte_record(&mut record)? {
808 if record.len() == 1 && record.iter().next() == Some(END_OF_COPY_MARKER) {
809 break;
810 }
811
812 match record.len().cmp(&column_types.len()) {
813 std::cmp::Ordering::Less => Err(io::Error::new(
814 io::ErrorKind::InvalidData,
815 "missing data for column",
816 )),
817 std::cmp::Ordering::Greater => Err(io::Error::new(
818 io::ErrorKind::InvalidData,
819 "extra data after last expected column",
820 )),
821 std::cmp::Ordering::Equal => Ok(()),
822 }?;
823
824 let binding = SharedRow::get();
825 let mut row_builder = binding.borrow_mut();
826 let mut row_packer = row_builder.packer();
827
828 for (typ, raw_value) in column_types.iter().zip(record.iter()) {
829 if raw_value == null_as_bytes {
830 row_packer.push(Datum::Null);
831 } else {
832 let s = match std::str::from_utf8(raw_value) {
833 Ok(s) => s,
834 Err(err) => {
835 let msg = format!("invalid utf8 data in column: {}", err);
836 return Err(io::Error::new(io::ErrorKind::InvalidData, msg));
837 }
838 };
839 match mz_pgrepr::Value::decode_text_into_row(typ, s, &mut row_packer) {
840 Ok(()) => {}
841 Err(err) => {
842 let msg = format!("unable to decode column: {}", err);
843 return Err(io::Error::new(io::ErrorKind::InvalidData, msg));
844 }
845 }
846 }
847 }
848 rows.push(row_builder.clone());
849 }
850
851 Ok(rows)
852}
853
#[cfg(test)]
mod tests {
    use mz_ore::collections::CollectionExt;
    use mz_repr::{ColumnType, ScalarType};
    use proptest::prelude::*;

    use super::*;

    #[mz_ore::test]
    fn test_copy_format_text_parser() {
        // Row 1: an empty field, "\nt e", NULL, an empty field.
        // Row 2: hex escapes. Row 3: octal escapes.
        let text = "\t\\nt e\t\\N\t\n\\x60\\xA\\x7D\\x4a\n\\44\\044\\123".as_bytes();
        let mut parser = CopyTextFormatParser::new(text, b'\t', "\\N");
        assert!(parser.is_column_delimiter());
        parser
            .expect_column_delimiter()
            .expect("expected column delimiter");
        assert_eq!(
            parser
                .consume_raw_value()
                .expect("unexpected error")
                .expect("unexpected empty result"),
            "\nt e".as_bytes()
        );
        parser
            .expect_column_delimiter()
            .expect("expected column delimiter");
        // `\N` is the configured NULL string.
        assert!(
            parser
                .consume_raw_value()
                .expect("unexpected error")
                .is_none()
        );
        parser
            .expect_column_delimiter()
            .expect("expected column delimiter");
        assert!(parser.is_end_of_line());
        parser.expect_end_of_line().expect("expected eol");
        // `\xA` is a one-digit hex escape.
        assert_eq!(
            parser
                .consume_raw_value()
                .expect("unexpected error")
                .expect("unexpected empty result"),
            "`\n}J".as_bytes()
        );
        parser.expect_end_of_line().expect("expected eol");
        // Two- and three-digit octal escapes.
        assert_eq!(
            parser
                .consume_raw_value()
                .expect("unexpected error")
                .expect("unexpected empty result"),
            "$$S".as_bytes()
        );
        assert!(parser.is_eof());
    }

    #[mz_ore::test]
    fn test_copy_format_text_empty_null_string() {
        // With an empty NULL string, every empty field is NULL.
        let text = "\t\n10\t20\n30\t\n40\t".as_bytes();
        let expect = vec![
            vec![None, None],
            vec![Some("10"), Some("20")],
            vec![Some("30"), None],
            vec![Some("40"), None],
        ];
        let mut parser = CopyTextFormatParser::new(text, b'\t', "");
        for line in expect {
            for (i, value) in line.iter().enumerate() {
                if i > 0 {
                    parser
                        .expect_column_delimiter()
                        .expect("expected column delimiter");
                }
                match value {
                    Some(s) => {
                        assert!(!parser.consume_null_string());
                        assert_eq!(
                            parser
                                .consume_raw_value()
                                .expect("unexpected error")
                                .expect("unexpected empty result"),
                            s.as_bytes()
                        );
                    }
                    None => {
                        assert!(parser.consume_null_string());
                    }
                }
            }
            parser.expect_end_of_line().expect("expected eol");
        }
    }

    #[mz_ore::test]
    fn test_copy_format_text_parser_escapes() {
        struct TestCase {
            input: &'static str,
            expect: &'static [u8],
        }
        let tests = vec![
            TestCase {
                input: "simple",
                expect: b"simple",
            },
            TestCase {
                input: r#"new\nline"#,
                expect: b"new\nline",
            },
            TestCase {
                input: r#"\b\f\n\r\t\v\\"#,
                expect: b"\x08\x0c\n\r\t\x0b\\",
            },
            TestCase {
                input: r#"\0\12\123"#,
                expect: &[0, 0o12, 0o123],
            },
            TestCase {
                input: r#"\x1\xaf"#,
                expect: &[0x01, 0xaf],
            },
            TestCase {
                input: r#"T\n\07\xEV\x0fA\xb2C\1"#,
                expect: b"T\n\x07\x0eV\x0fA\xb2C\x01",
            },
            TestCase {
                input: r#"\\\""#,
                expect: b"\\\"",
            },
            TestCase {
                input: r#"\x"#,
                expect: b"x",
            },
            TestCase {
                input: r#"\xg"#,
                expect: b"xg",
            },
            TestCase {
                input: r#"\"#,
                expect: b"\\",
            },
            TestCase {
                input: r#"\8"#,
                expect: b"8",
            },
            TestCase {
                input: r#"\a"#,
                expect: b"a",
            },
            TestCase {
                input: r#"\x\xg\8\xH\x32\s\"#,
                expect: b"xxg8xH2s\\",
            },
        ];

        for test in tests {
            let mut parser = CopyTextFormatParser::new(test.input.as_bytes(), b'\t', "\\N");
            assert_eq!(
                parser
                    .consume_raw_value()
                    .expect("unexpected error")
                    .expect("unexpected empty result"),
                test.expect,
                "input: {}, expect: {:?}",
                test.input,
                std::str::from_utf8(test.expect),
            );
            assert!(parser.is_eof());
        }
    }

    #[mz_ore::test]
    fn test_copy_csv_format_params() {
        // An omitted escape follows the quote.
        assert_eq!(
            CopyCsvFormatParams::try_new(Some(b't'), Some(b'q'), None, None, None),
            Ok(CopyCsvFormatParams {
                delimiter: b't',
                quote: b'q',
                escape: b'q',
                header: false,
                null: Cow::from(""),
            })
        );

        assert_eq!(
            CopyCsvFormatParams::try_new(
                Some(b't'),
                Some(b'q'),
                Some(b'e'),
                Some(true),
                Some("null".to_string())
            ),
            Ok(CopyCsvFormatParams {
                delimiter: b't',
                quote: b'q',
                escape: b'e',
                header: true,
                null: Cow::from("null"),
            })
        );

        // The default delimiter is ',', so a ',' quote is rejected.
        assert_eq!(
            CopyCsvFormatParams::try_new(
                None,
                Some(b','),
                Some(b'e'),
                Some(true),
                Some("null".to_string())
            ),
            Err("COPY delimiter and quote must be different".to_string())
        );
    }

    #[mz_ore::test]
    fn test_copy_csv_row() -> Result<(), io::Error> {
        let mut row = Row::default();
        let mut packer = row.packer();
        packer.push(Datum::from("1,2,\"3\""));
        packer.push(Datum::Null);
        packer.push(Datum::from(1000u64));
        packer.push(Datum::from("qe"));
        packer.push(Datum::from(""));

        let typ: RelationType = RelationType::new(vec![
            ColumnType {
                scalar_type: mz_repr::ScalarType::String,
                nullable: false,
            },
            ColumnType {
                scalar_type: mz_repr::ScalarType::String,
                nullable: true,
            },
            ColumnType {
                scalar_type: mz_repr::ScalarType::UInt64,
                nullable: false,
            },
            ColumnType {
                scalar_type: mz_repr::ScalarType::String,
                nullable: false,
            },
            ColumnType {
                scalar_type: mz_repr::ScalarType::String,
                nullable: false,
            },
        ]);

        let mut out = Vec::new();

        struct TestCase<'a> {
            params: CopyCsvFormatParams<'a>,
            expected: &'static [u8],
        }

        let tests = [
            TestCase {
                params: CopyCsvFormatParams::default(),
                expected: b"\"1,2,\"\"3\"\"\",,1000,qe,\"\"\n",
            },
            TestCase {
                params: CopyCsvFormatParams {
                    null: Cow::from("NULL"),
                    quote: b'q',
                    escape: b'e',
                    ..Default::default()
                },
                expected: b"q1,2,\"3\"q,NULL,1000,qeqeeq,\n",
            },
        ];

        for TestCase { params, expected } in tests {
            out.clear();
            let params = CopyFormatParams::Csv(params);
            encode_copy_format(&params, &row, &typ, &mut out)?;
            let output = std::str::from_utf8(&out);
            assert_eq!(output, std::str::from_utf8(expected));
        }

        Ok(())
    }

    proptest! {
        #[mz_ore::test]
        #[cfg_attr(miri, ignore)]
        fn proptest_csv_roundtrips(copy_csv_params: CopyCsvFormatParams) {
            // Encodes a single-column row holding `datum`, decodes it again,
            // and checks that the same row comes back.
            let try_roundtrip_datum = |scalar_type: &ScalarType, datum| {
                let row = Row::pack_slice(&[datum]);
                let typ = RelationType::new(vec![
                    ColumnType {
                        scalar_type: scalar_type.clone(),
                        nullable: true,
                    }
                ]);

                let mut buf = Vec::new();
                let mut csv_params = copy_csv_params.clone();
                // A header row would be skipped on decode; keep it out.
                csv_params.header = false;
                let params = CopyFormatParams::Csv(csv_params);

                encode_copy_format(&params, &row, &typ, &mut buf)?;

                let column_types = typ
                    .column_types
                    .iter()
                    .map(|x| &x.scalar_type)
                    .map(mz_pgrepr::Type::from)
                    .collect::<Vec<mz_pgrepr::Type>>();
                let result = decode_copy_format(&buf, &column_types, params);

                match result {
                    Ok(rows) => {
                        let out_str = std::str::from_utf8(&buf[..]);

                        prop_assert_eq!(
                            rows.len(),
                            1,
                            "unexpected number of rows! {:?}, csv string: {:?}", rows, out_str
                        );
                        let output = rows.into_element();

                        prop_assert_eq!(
                            row,
                            output,
                            "csv string: {:?}, scalar_type: {:?}", out_str, scalar_type
                        );
                    }
                    _ => {
                        // Decode failures are tolerated; only successful
                        // decodes must round-trip.
                    }
                }

                Ok(())
            };

            for scalar_type in ScalarType::enumerate() {
                for datum in scalar_type.interesting_datums() {
                    // Values whose text form equals the NULL string decode as
                    // NULL, so they cannot round-trip.
                    if let Some(value) = mz_pgrepr::Value::from_datum(datum, scalar_type) {
                        let mut buf = bytes::BytesMut::new();
                        value.encode_text(&mut buf);

                        if let Ok(datum_str) = std::str::from_utf8(&buf[..]) {
                            if datum_str == copy_csv_params.null {
                                continue;
                            }
                        }
                    }

                    let updated_datum = match datum {
                        Datum::Timestamp(_) | Datum::TimestampTz(_) | Datum::Null => {
                            continue;
                        }
                        Datum::String(s) => {
                            if s.trim() == copy_csv_params.null || s.trim().is_empty() {
                                continue;
                            } else {
                                Datum::String(s)
                            }
                        }
                        other => other,
                    };

                    let result = try_roundtrip_datum(scalar_type, updated_datum);
                    prop_assert!(result.is_ok(), "failure: {result:?}");
                }
            }
        }
    }
}