1use std::borrow::Cow;
11use std::io;
12
13use bytes::BytesMut;
14use csv::{ByteRecord, ReaderBuilder};
15use mz_proto::{ProtoType, RustType, TryFromProtoError};
16use mz_repr::{
17 Datum, RelationDesc, Row, RowArena, RowRef, SharedRow, SqlColumnType, SqlRelationType,
18 SqlScalarType,
19};
20use proptest::prelude::{Arbitrary, Just, any};
21use proptest::strategy::{BoxedStrategy, Strategy, Union};
22use serde::Deserialize;
23use serde::Serialize;
24
/// The byte sequence `\.` that marks the end of COPY data. The text parser treats it as
/// end of input, and the CSV codec special-cases a lone field equal to it.
static END_OF_COPY_MARKER: &[u8] = b"\\.";

// Splices in a Rust source file that the build places in `OUT_DIR`.
// NOTE(review): presumably the generated `Proto*` message types used by the `RustType`
// impls in this file — confirm against the build script.
include!(concat!(env!("OUT_DIR"), "/mz_pgcopy.copy.rs"));
28
29fn encode_copy_row_binary(
30 row: &RowRef,
31 typ: &SqlRelationType,
32 out: &mut Vec<u8>,
33) -> Result<(), io::Error> {
34 const NULL_BYTES: [u8; 4] = (-1i32).to_be_bytes();
35
36 let count = i16::try_from(typ.column_types.len()).map_err(|_| {
38 io::Error::new(
39 io::ErrorKind::Other,
40 "column count does not fit into an i16",
41 )
42 })?;
43
44 out.extend(count.to_be_bytes());
45 let mut buf = BytesMut::new();
46 for (field, typ) in row
47 .iter()
48 .zip(&typ.column_types)
49 .map(|(datum, typ)| (mz_pgrepr::Value::from_datum(datum, &typ.scalar_type), typ))
50 {
51 match field {
52 None => out.extend(NULL_BYTES),
53 Some(field) => {
54 buf.clear();
55 field.encode_binary(&mz_pgrepr::Type::from(&typ.scalar_type), &mut buf)?;
56 out.extend(
57 i32::try_from(buf.len())
58 .map_err(|_| {
59 io::Error::new(
60 io::ErrorKind::Other,
61 "field length does not fit into an i32",
62 )
63 })?
64 .to_be_bytes(),
65 );
66 out.extend(&buf);
67 }
68 }
69 }
70 Ok(())
71}
72
73fn encode_copy_row_text(
74 CopyTextFormatParams { null, delimiter }: &CopyTextFormatParams,
75 row: &RowRef,
76 typ: &SqlRelationType,
77 out: &mut Vec<u8>,
78) -> Result<(), io::Error> {
79 let null = null.as_bytes();
80 let mut buf = BytesMut::new();
81 for (idx, field) in mz_pgrepr::values_from_row(row, typ).into_iter().enumerate() {
82 if idx > 0 {
83 out.push(*delimiter);
84 }
85 match field {
86 None => out.extend(null),
87 Some(field) => {
88 buf.clear();
89 field.encode_text(&mut buf);
90 for b in &buf {
91 match b {
92 b'\\' => out.extend(b"\\\\"),
93 b'\n' => out.extend(b"\\n"),
94 b'\r' => out.extend(b"\\r"),
95 b'\t' => out.extend(b"\\t"),
96 _ => out.push(*b),
97 }
98 }
99 }
100 }
101 }
102 out.push(b'\n');
103 Ok(())
104}
105
106fn encode_copy_row_csv(
107 CopyCsvFormatParams {
108 delimiter: delim,
109 quote,
110 escape,
111 header: _,
112 null,
113 }: &CopyCsvFormatParams,
114 row: &RowRef,
115 typ: &SqlRelationType,
116 out: &mut Vec<u8>,
117) -> Result<(), io::Error> {
118 let null = null.as_bytes();
119 let is_special = |c: &u8| *c == *delim || *c == *quote || *c == b'\r' || *c == b'\n';
120 let mut buf = BytesMut::new();
121 for (idx, field) in mz_pgrepr::values_from_row(row, typ).into_iter().enumerate() {
122 if idx > 0 {
123 out.push(*delim);
124 }
125 match field {
126 None => out.extend(null),
127 Some(field) => {
128 buf.clear();
129 field.encode_text(&mut buf);
130 if (typ.column_types.len() == 1 && buf == END_OF_COPY_MARKER)
136 || buf.iter().any(is_special)
137 || &*buf == null
138 {
139 out.push(*quote);
143 for b in &buf {
144 if *b == *quote || *b == *escape {
145 out.push(*escape);
146 }
147 out.push(*b);
148 }
149 out.push(*quote);
150 } else {
151 out.extend(&buf);
154 }
155 }
156 }
157 }
158 out.push(b'\n');
159 Ok(())
160}
161
/// A cursor-based parser over a byte buffer holding data in the text COPY format.
pub struct CopyTextFormatParser<'a> {
    /// The complete input being parsed.
    data: &'a [u8],
    /// Index into `data` of the next unread byte.
    position: usize,
    /// The byte that separates columns within a line.
    column_delimiter: u8,
    /// The string that represents a null value in the input.
    null_string: &'a str,
    /// Scratch space holding the decoded bytes of a value that contains escape sequences.
    buffer: Vec<u8>,
}
169
170impl<'a> CopyTextFormatParser<'a> {
171 pub fn new(data: &'a [u8], column_delimiter: u8, null_string: &'a str) -> Self {
172 Self {
173 data,
174 position: 0,
175 column_delimiter,
176 null_string,
177 buffer: Vec::new(),
178 }
179 }
180
181 fn peek(&self) -> Option<u8> {
182 if self.position < self.data.len() {
183 Some(self.data[self.position])
184 } else {
185 None
186 }
187 }
188
189 fn consume_n(&mut self, n: usize) {
190 self.position = std::cmp::min(self.position + n, self.data.len());
191 }
192
193 pub fn is_eof(&self) -> bool {
194 self.peek().is_none() || self.is_end_of_copy_marker()
195 }
196
197 pub fn is_end_of_copy_marker(&self) -> bool {
198 self.check_bytes(END_OF_COPY_MARKER)
199 }
200
201 fn is_end_of_line(&self) -> bool {
202 match self.peek() {
203 Some(b'\n') | None => true,
204 _ => false,
205 }
206 }
207
208 pub fn expect_end_of_line(&mut self) -> Result<(), io::Error> {
209 if self.is_end_of_line() {
210 self.consume_n(1);
211 Ok(())
212 } else {
213 Err(io::Error::new(
214 io::ErrorKind::InvalidData,
215 "extra data after last expected column",
216 ))
217 }
218 }
219
220 fn is_column_delimiter(&self) -> bool {
221 self.check_bytes(&[self.column_delimiter])
222 }
223
224 pub fn expect_column_delimiter(&mut self) -> Result<(), io::Error> {
225 if self.consume_bytes(&[self.column_delimiter]) {
226 Ok(())
227 } else {
228 Err(io::Error::new(
229 io::ErrorKind::InvalidData,
230 "missing data for column",
231 ))
232 }
233 }
234
235 fn check_bytes(&self, bytes: &[u8]) -> bool {
236 let remaining_bytes = self.data.len() - self.position;
237 remaining_bytes >= bytes.len()
238 && self.data[self.position..]
239 .iter()
240 .zip(bytes.iter())
241 .all(|(x, y)| x == y)
242 }
243
244 fn consume_bytes(&mut self, bytes: &[u8]) -> bool {
245 if self.check_bytes(bytes) {
246 self.consume_n(bytes.len());
247 true
248 } else {
249 false
250 }
251 }
252
253 fn consume_null_string(&mut self) -> bool {
254 if self.null_string.is_empty() {
255 self.is_column_delimiter()
258 || self.is_end_of_line()
259 || self.is_end_of_copy_marker()
260 || self.is_eof()
261 } else {
262 self.consume_bytes(self.null_string.as_bytes())
263 }
264 }
265
266 pub fn consume_raw_value(&mut self) -> Result<Option<&[u8]>, io::Error> {
267 if self.consume_null_string() {
268 return Ok(None);
269 }
270
271 let mut start = self.position;
272
273 self.buffer.clear();
275
276 while !self.is_eof() && !self.is_end_of_copy_marker() {
277 if self.is_end_of_line() || self.is_column_delimiter() {
278 break;
279 }
280 match self.peek() {
281 Some(b'\\') => {
282 self.buffer.extend(&self.data[start..self.position]);
284
285 self.consume_n(1);
286 match self.peek() {
287 Some(b'b') => {
288 self.consume_n(1);
289 self.buffer.push(8);
290 }
291 Some(b'f') => {
292 self.consume_n(1);
293 self.buffer.push(12);
294 }
295 Some(b'n') => {
296 self.consume_n(1);
297 self.buffer.push(b'\n');
298 }
299 Some(b'r') => {
300 self.consume_n(1);
301 self.buffer.push(b'\r');
302 }
303 Some(b't') => {
304 self.consume_n(1);
305 self.buffer.push(b'\t');
306 }
307 Some(b'v') => {
308 self.consume_n(1);
309 self.buffer.push(11);
310 }
311 Some(b'x') => {
312 self.consume_n(1);
313 match self.peek() {
314 Some(_c @ b'0'..=b'9')
315 | Some(_c @ b'A'..=b'F')
316 | Some(_c @ b'a'..=b'f') => {
317 let mut value: u8 = 0;
318 let decode_nibble = |b| match b {
319 Some(c @ b'a'..=b'f') => Some(c - b'a' + 10),
320 Some(c @ b'A'..=b'F') => Some(c - b'A' + 10),
321 Some(c @ b'0'..=b'9') => Some(c - b'0'),
322 _ => None,
323 };
324 for _ in 0..2 {
325 match decode_nibble(self.peek()) {
326 Some(c) => {
327 self.consume_n(1);
328 value = value << 4 | c;
329 }
330 _ => break,
331 }
332 }
333 self.buffer.push(value);
334 }
335 _ => {
336 self.buffer.push(b'x');
337 }
338 }
339 }
340 Some(_c @ b'0'..=b'7') => {
341 let mut value: u8 = 0;
342 for _ in 0..3 {
343 match self.peek() {
344 Some(c @ b'0'..=b'7') => {
345 self.consume_n(1);
346 value = value << 3 | (c - b'0');
347 }
348 _ => break,
349 }
350 }
351 self.buffer.push(value);
352 }
353 Some(c) => {
354 self.consume_n(1);
355 self.buffer.push(c);
356 }
357 None => {
358 self.buffer.push(b'\\');
359 }
360 }
361
362 start = self.position;
363 }
364 Some(_) => {
365 self.consume_n(1);
366 }
367 None => {}
368 }
369 }
370
371 if self.buffer.is_empty() {
373 Ok(Some(&self.data[start..self.position]))
374 } else {
375 self.buffer.extend(&self.data[start..self.position]);
378 Ok(Some(&self.buffer[..]))
379 }
380 }
381
382 pub fn iter_raw(self, num_columns: usize) -> RawIterator<'a> {
384 RawIterator {
385 parser: self,
386 current_column: 0,
387 num_columns,
388 truncate: false,
389 }
390 }
391
392 pub fn iter_raw_truncating(self, num_columns: usize) -> RawIterator<'a> {
394 RawIterator {
395 parser: self,
396 current_column: 0,
397 num_columns,
398 truncate: true,
399 }
400 }
401}
402
/// Iterates over the raw column values of a single line of text COPY data.
///
/// Created by `CopyTextFormatParser::iter_raw` and
/// `CopyTextFormatParser::iter_raw_truncating`.
pub struct RawIterator<'a> {
    /// The parser the values are read from.
    parser: CopyTextFormatParser<'a>,
    /// Number of columns handed out so far.
    current_column: usize,
    /// Number of columns expected on the line.
    num_columns: usize,
    /// When true, the end of the line is not checked after the last column.
    truncate: bool,
}
409
410impl<'a> RawIterator<'a> {
411 pub fn next(&mut self) -> Option<Result<Option<&[u8]>, io::Error>> {
412 if self.current_column > self.num_columns {
413 return None;
414 }
415
416 if self.current_column == self.num_columns {
417 if !self.truncate {
418 if let Some(err) = self.parser.expect_end_of_line().err() {
419 return Some(Err(err));
420 }
421 }
422
423 return None;
424 }
425
426 if self.current_column > 0 {
427 if let Some(err) = self.parser.expect_column_delimiter().err() {
428 return Some(Err(err));
429 }
430 }
431
432 self.current_column += 1;
433 Some(self.parser.consume_raw_value())
434 }
435}
436
/// The data format of a COPY operation, together with its format-specific options.
#[derive(Debug, Eq, PartialEq, Clone, Serialize, Deserialize)]
pub enum CopyFormatParams<'a> {
    /// Delimited text with backslash escape sequences.
    Text(CopyTextFormatParams<'a>),
    /// Comma-separated values.
    Csv(CopyCsvFormatParams<'a>),
    /// Binary format. Rows can be encoded in it; decoding is rejected as unsupported.
    Binary,
    /// Parquet format. The row encode and decode functions in this file reject it as
    /// unsupported.
    Parquet,
}
444
445impl RustType<ProtoCopyFormatParams> for CopyFormatParams<'static> {
446 fn into_proto(&self) -> ProtoCopyFormatParams {
447 use proto_copy_format_params::Kind;
448 ProtoCopyFormatParams {
449 kind: Some(match self {
450 Self::Text(f) => Kind::Text(f.into_proto()),
451 Self::Csv(f) => Kind::Csv(f.into_proto()),
452 Self::Binary => Kind::Binary(()),
453 Self::Parquet => Kind::Parquet(ProtoCopyParquetFormatParams::default()),
454 }),
455 }
456 }
457
458 fn from_proto(proto: ProtoCopyFormatParams) -> Result<Self, TryFromProtoError> {
459 use proto_copy_format_params::Kind;
460 match proto.kind {
461 Some(Kind::Text(f)) => Ok(Self::Text(f.into_rust()?)),
462 Some(Kind::Csv(f)) => Ok(Self::Csv(f.into_rust()?)),
463 Some(Kind::Binary(())) => Ok(Self::Binary),
464 Some(Kind::Parquet(ProtoCopyParquetFormatParams {})) => Ok(Self::Parquet),
465 None => Err(TryFromProtoError::missing_field(
466 "ProtoCopyFormatParams::kind",
467 )),
468 }
469 }
470}
471
472impl Arbitrary for CopyFormatParams<'static> {
473 type Parameters = ();
474 type Strategy = Union<BoxedStrategy<Self>>;
475
476 fn arbitrary_with(_args: Self::Parameters) -> Self::Strategy {
477 Union::new(vec![
478 any::<CopyTextFormatParams>().prop_map(Self::Text).boxed(),
479 any::<CopyCsvFormatParams>().prop_map(Self::Csv).boxed(),
480 Just(Self::Binary).boxed(),
481 ])
482 }
483}
484
485impl CopyFormatParams<'static> {
486 pub fn file_extension(&self) -> &str {
487 match self {
488 &CopyFormatParams::Text(_) => "txt",
489 &CopyFormatParams::Csv(_) => "csv",
490 &CopyFormatParams::Binary => "bin",
491 &CopyFormatParams::Parquet => "parquet",
492 }
493 }
494
495 pub fn requires_header(&self) -> bool {
496 match self {
497 CopyFormatParams::Text(_) => false,
498 CopyFormatParams::Csv(params) => params.header,
499 CopyFormatParams::Binary => false,
500 CopyFormatParams::Parquet => false,
501 }
502 }
503}
504
505pub fn decode_copy_format<'a>(
507 data: &[u8],
508 column_types: &[mz_pgrepr::Type],
509 params: CopyFormatParams<'a>,
510) -> Result<Vec<Row>, io::Error> {
511 match params {
512 CopyFormatParams::Text(params) => decode_copy_format_text(data, column_types, params),
513 CopyFormatParams::Csv(params) => decode_copy_format_csv(data, column_types, params),
514 CopyFormatParams::Binary => Err(io::Error::new(
515 io::ErrorKind::Unsupported,
516 "cannot decode as binary format",
517 )),
518 CopyFormatParams::Parquet => {
519 Err(io::Error::new(io::ErrorKind::Unsupported, "parquet format"))
521 }
522 }
523}
524
525pub fn encode_copy_format<'a>(
527 params: &CopyFormatParams<'a>,
528 row: &RowRef,
529 typ: &SqlRelationType,
530 out: &mut Vec<u8>,
531) -> Result<(), io::Error> {
532 match params {
533 CopyFormatParams::Text(params) => encode_copy_row_text(params, row, typ, out),
534 CopyFormatParams::Csv(params) => encode_copy_row_csv(params, row, typ, out),
535 CopyFormatParams::Binary => encode_copy_row_binary(row, typ, out),
536 CopyFormatParams::Parquet => {
537 Err(io::Error::new(io::ErrorKind::Unsupported, "parquet format"))
539 }
540 }
541}
542
543pub fn encode_copy_format_header<'a>(
544 params: &CopyFormatParams<'a>,
545 desc: &RelationDesc,
546 out: &mut Vec<u8>,
547) -> Result<(), io::Error> {
548 match params {
549 CopyFormatParams::Text(_) => Ok(()),
550 CopyFormatParams::Binary => Ok(()),
551 CopyFormatParams::Csv(params) => {
552 let mut header_row = Row::with_capacity(desc.arity());
553 header_row
554 .packer()
555 .extend(desc.iter_names().map(|s| Datum::from(s.as_str())));
556 let typ = SqlRelationType::new(vec![
557 SqlColumnType {
558 scalar_type: SqlScalarType::String,
559 nullable: false,
560 };
561 desc.arity()
562 ]);
563 encode_copy_row_csv(params, &header_row, &typ, out)
564 }
565 CopyFormatParams::Parquet => {
566 Err(io::Error::new(io::ErrorKind::Unsupported, "parquet format"))
568 }
569 }
570}
571
/// Options of the text COPY format.
#[derive(Debug, Eq, PartialEq, Clone, Serialize, Deserialize)]
pub struct CopyTextFormatParams<'a> {
    /// The string that represents a null value.
    pub null: Cow<'a, str>,
    /// The byte that separates columns within a line.
    pub delimiter: u8,
}
577
578impl<'a> Default for CopyTextFormatParams<'a> {
579 fn default() -> Self {
580 CopyTextFormatParams {
581 delimiter: b'\t',
582 null: Cow::from("\\N"),
583 }
584 }
585}
586
587impl RustType<ProtoCopyTextFormatParams> for CopyTextFormatParams<'static> {
588 fn into_proto(&self) -> ProtoCopyTextFormatParams {
589 ProtoCopyTextFormatParams {
590 null: self.null.into_proto(),
591 delimiter: self.delimiter.into_proto(),
592 }
593 }
594
595 fn from_proto(proto: ProtoCopyTextFormatParams) -> Result<Self, TryFromProtoError> {
596 Ok(Self {
597 null: Cow::Owned(proto.null.into_rust()?),
598 delimiter: proto.delimiter.into_rust()?,
599 })
600 }
601}
602
603impl Arbitrary for CopyTextFormatParams<'static> {
604 type Parameters = ();
605 type Strategy = BoxedStrategy<Self>;
606
607 fn arbitrary_with(_args: Self::Parameters) -> Self::Strategy {
608 (any::<String>(), any::<u8>())
609 .prop_map(|(null, delimiter)| Self {
610 null: Cow::Owned(null),
611 delimiter,
612 })
613 .boxed()
614 }
615}
616
617pub fn decode_copy_format_text(
618 data: &[u8],
619 column_types: &[mz_pgrepr::Type],
620 CopyTextFormatParams { null, delimiter }: CopyTextFormatParams,
621) -> Result<Vec<Row>, io::Error> {
622 let mut rows = Vec::new();
623
624 let mut parser = CopyTextFormatParser::new(data, delimiter, &null);
626 while !parser.is_eof() && !parser.is_end_of_copy_marker() {
627 let mut row = Vec::new();
628 let buf = RowArena::new();
629 for (col, typ) in column_types.iter().enumerate() {
630 if col > 0 {
631 parser.expect_column_delimiter()?;
632 }
633 let raw_value = parser.consume_raw_value()?;
634 if let Some(raw_value) = raw_value {
635 match mz_pgrepr::Value::decode_text(typ, raw_value) {
636 Ok(value) => row.push(value.into_datum(&buf, typ)),
637 Err(err) => {
638 let msg = format!("unable to decode column: {}", err);
639 return Err(io::Error::new(io::ErrorKind::InvalidData, msg));
640 }
641 }
642 } else {
643 row.push(Datum::Null);
644 }
645 }
646 parser.expect_end_of_line()?;
647 rows.push(Row::pack(row));
648 }
649 Ok(rows)
652}
653
/// Options of the CSV COPY format.
#[derive(Debug, Eq, PartialEq, Clone, Serialize, Deserialize)]
pub struct CopyCsvFormatParams<'a> {
    /// The byte that separates columns within a line.
    pub delimiter: u8,
    /// The byte that encloses quoted values.
    pub quote: u8,
    /// The byte that precedes a quote or escape byte inside a quoted value.
    pub escape: u8,
    /// Whether the data has a header line.
    pub header: bool,
    /// The string that represents a null value.
    pub null: Cow<'a, str>,
}
662
663impl<'a> CopyCsvFormatParams<'a> {
664 pub fn to_owned(&self) -> CopyCsvFormatParams<'static> {
665 CopyCsvFormatParams {
666 delimiter: self.delimiter,
667 quote: self.quote,
668 escape: self.escape,
669 header: self.header,
670 null: Cow::Owned(self.null.to_string()),
671 }
672 }
673}
674
675impl RustType<ProtoCopyCsvFormatParams> for CopyCsvFormatParams<'static> {
676 fn into_proto(&self) -> ProtoCopyCsvFormatParams {
677 ProtoCopyCsvFormatParams {
678 delimiter: self.delimiter.into(),
679 quote: self.quote.into(),
680 escape: self.escape.into(),
681 header: self.header,
682 null: self.null.into_proto(),
683 }
684 }
685
686 fn from_proto(proto: ProtoCopyCsvFormatParams) -> Result<Self, TryFromProtoError> {
687 Ok(Self {
688 delimiter: proto.delimiter.into_rust()?,
689 quote: proto.quote.into_rust()?,
690 escape: proto.escape.into_rust()?,
691 header: proto.header,
692 null: Cow::Owned(proto.null.into_rust()?),
693 })
694 }
695}
696
697impl Arbitrary for CopyCsvFormatParams<'static> {
698 type Parameters = ();
699 type Strategy = BoxedStrategy<Self>;
700
701 fn arbitrary_with(_args: Self::Parameters) -> Self::Strategy {
702 (
703 any::<u8>(),
704 any::<u8>(),
705 any::<u8>(),
706 any::<bool>(),
707 any::<String>(),
708 )
709 .prop_map(|(delimiter, diff, escape, header, null)| {
710 let diff = diff.saturating_sub(1).max(1);
712 let quote = delimiter.wrapping_add(diff);
713
714 Self::try_new(
715 Some(delimiter),
716 Some(quote),
717 Some(escape),
718 Some(header),
719 Some(null),
720 )
721 .expect("delimiter and quote should be different")
722 })
723 .boxed()
724 }
725}
726
727impl<'a> Default for CopyCsvFormatParams<'a> {
728 fn default() -> Self {
729 CopyCsvFormatParams {
730 delimiter: b',',
731 quote: b'"',
732 escape: b'"',
733 header: false,
734 null: Cow::from(""),
735 }
736 }
737}
738
739impl<'a> CopyCsvFormatParams<'a> {
740 pub fn try_new(
741 delimiter: Option<u8>,
742 quote: Option<u8>,
743 escape: Option<u8>,
744 header: Option<bool>,
745 null: Option<String>,
746 ) -> Result<CopyCsvFormatParams<'a>, String> {
747 let mut params = CopyCsvFormatParams::default();
748
749 if let Some(delimiter) = delimiter {
750 params.delimiter = delimiter;
751 }
752 if let Some(quote) = quote {
753 params.quote = quote;
754 params.escape = quote;
756 }
757 if let Some(escape) = escape {
758 params.escape = escape;
759 }
760 if let Some(header) = header {
761 params.header = header;
762 }
763 if let Some(null) = null {
764 params.null = Cow::from(null);
765 }
766
767 if params.quote == params.delimiter {
768 return Err("COPY delimiter and quote must be different".to_string());
769 }
770 Ok(params)
771 }
772}
773
774pub fn decode_copy_format_csv(
775 data: &[u8],
776 column_types: &[mz_pgrepr::Type],
777 CopyCsvFormatParams {
778 delimiter,
779 quote,
780 escape,
781 null,
782 header,
783 }: CopyCsvFormatParams,
784) -> Result<Vec<Row>, io::Error> {
785 let mut rows = Vec::new();
786
787 let (double_quote, escape) = if quote == escape {
788 (true, None)
789 } else {
790 (false, Some(escape))
791 };
792
793 let mut rdr = ReaderBuilder::new()
794 .delimiter(delimiter)
795 .quote(quote)
796 .has_headers(header)
797 .double_quote(double_quote)
798 .escape(escape)
799 .flexible(true)
802 .from_reader(data);
803
804 let null_as_bytes = null.as_bytes();
805
806 let mut record = ByteRecord::new();
807
808 while rdr.read_byte_record(&mut record)? {
809 if record.len() == 1 && record.iter().next() == Some(END_OF_COPY_MARKER) {
810 break;
811 }
812
813 match record.len().cmp(&column_types.len()) {
814 std::cmp::Ordering::Less => Err(io::Error::new(
815 io::ErrorKind::InvalidData,
816 "missing data for column",
817 )),
818 std::cmp::Ordering::Greater => Err(io::Error::new(
819 io::ErrorKind::InvalidData,
820 "extra data after last expected column",
821 )),
822 std::cmp::Ordering::Equal => Ok(()),
823 }?;
824
825 let mut row_builder = SharedRow::get();
826 let mut row_packer = row_builder.packer();
827
828 for (typ, raw_value) in column_types.iter().zip(record.iter()) {
829 if raw_value == null_as_bytes {
830 row_packer.push(Datum::Null);
831 } else {
832 let s = match std::str::from_utf8(raw_value) {
833 Ok(s) => s,
834 Err(err) => {
835 let msg = format!("invalid utf8 data in column: {}", err);
836 return Err(io::Error::new(io::ErrorKind::InvalidData, msg));
837 }
838 };
839 match mz_pgrepr::Value::decode_text_into_row(typ, s, &mut row_packer) {
840 Ok(()) => {}
841 Err(err) => {
842 let msg = format!("unable to decode column: {}", err);
843 return Err(io::Error::new(io::ErrorKind::InvalidData, msg));
844 }
845 }
846 }
847 }
848 rows.push(row_builder.clone());
849 }
850
851 Ok(rows)
852}
853
#[cfg(test)]
mod tests {
    use mz_ore::collections::CollectionExt;
    use mz_repr::{SqlColumnType, SqlScalarType};
    use proptest::prelude::*;

    use super::*;

    /// Walks the text parser through three lines that exercise delimiters, the null
    /// marker, and hexadecimal and octal escape sequences.
    #[mz_ore::test]
    fn test_copy_format_text_parser() {
        let text = "\t\\nt e\t\\N\t\n\\x60\\xA\\x7D\\x4a\n\\44\\044\\123".as_bytes();
        let mut parser = CopyTextFormatParser::new(text, b'\t', "\\N");
        assert!(parser.is_column_delimiter());
        parser
            .expect_column_delimiter()
            .expect("expected column delimiter");
        assert_eq!(
            parser
                .consume_raw_value()
                .expect("unexpected error")
                .expect("unexpected empty result"),
            "\nt e".as_bytes()
        );
        parser
            .expect_column_delimiter()
            .expect("expected column delimiter");
        // The null marker.
        assert!(
            parser
                .consume_raw_value()
                .expect("unexpected error")
                .is_none()
        );
        parser
            .expect_column_delimiter()
            .expect("expected column delimiter");
        assert!(parser.is_end_of_line());
        parser.expect_end_of_line().expect("expected eol");
        // Hexadecimal escapes.
        assert_eq!(
            parser
                .consume_raw_value()
                .expect("unexpected error")
                .expect("unexpected empty result"),
            "`\n}J".as_bytes()
        );
        parser.expect_end_of_line().expect("expected eol");
        // Octal escapes.
        assert_eq!(
            parser
                .consume_raw_value()
                .expect("unexpected error")
                .expect("unexpected empty result"),
            "$$S".as_bytes()
        );
        assert!(parser.is_eof());
    }

    /// With an empty null string, an empty field is a null value.
    #[mz_ore::test]
    fn test_copy_format_text_empty_null_string() {
        let text = "\t\n10\t20\n30\t\n40\t".as_bytes();
        let expect = vec![
            vec![None, None],
            vec![Some("10"), Some("20")],
            vec![Some("30"), None],
            vec![Some("40"), None],
        ];
        let mut parser = CopyTextFormatParser::new(text, b'\t', "");
        for line in expect {
            for (i, value) in line.iter().enumerate() {
                if i > 0 {
                    parser
                        .expect_column_delimiter()
                        .expect("expected column delimiter");
                }
                match value {
                    Some(s) => {
                        assert!(!parser.consume_null_string());
                        assert_eq!(
                            parser
                                .consume_raw_value()
                                .expect("unexpected error")
                                .expect("unexpected empty result"),
                            s.as_bytes()
                        );
                    }
                    None => {
                        assert!(parser.consume_null_string());
                    }
                }
            }
            parser.expect_end_of_line().expect("expected eol");
        }
    }

    /// Checks how each kind of backslash escape sequence is decoded, including the
    /// malformed ones.
    #[mz_ore::test]
    fn test_copy_format_text_parser_escapes() {
        struct TestCase {
            input: &'static str,
            expect: &'static [u8],
        }
        let tests = vec![
            TestCase {
                input: "simple",
                expect: b"simple",
            },
            TestCase {
                input: r#"new\nline"#,
                expect: b"new\nline",
            },
            TestCase {
                input: r#"\b\f\n\r\t\v\\"#,
                expect: b"\x08\x0c\n\r\t\x0b\\",
            },
            TestCase {
                input: r#"\0\12\123"#,
                expect: &[0, 0o12, 0o123],
            },
            TestCase {
                input: r#"\x1\xaf"#,
                expect: &[0x01, 0xaf],
            },
            TestCase {
                input: r#"T\n\07\xEV\x0fA\xb2C\1"#,
                expect: b"T\n\x07\x0eV\x0fA\xb2C\x01",
            },
            TestCase {
                input: r#"\\\""#,
                expect: b"\\\"",
            },
            TestCase {
                input: r#"\x"#,
                expect: b"x",
            },
            TestCase {
                input: r#"\xg"#,
                expect: b"xg",
            },
            TestCase {
                input: r#"\"#,
                expect: b"\\",
            },
            TestCase {
                input: r#"\8"#,
                expect: b"8",
            },
            TestCase {
                input: r#"\a"#,
                expect: b"a",
            },
            TestCase {
                input: r#"\x\xg\8\xH\x32\s\"#,
                expect: b"xxg8xH2s\\",
            },
        ];

        for test in tests {
            let mut parser = CopyTextFormatParser::new(test.input.as_bytes(), b'\t', "\\N");
            assert_eq!(
                parser
                    .consume_raw_value()
                    .expect("unexpected error")
                    .expect("unexpected empty result"),
                test.expect,
                "input: {}, expect: {:?}",
                test.input,
                std::str::from_utf8(test.expect),
            );
            assert!(parser.is_eof());
        }
    }

    /// Checks the defaulting and validation rules of `CopyCsvFormatParams::try_new`.
    #[mz_ore::test]
    fn test_copy_csv_format_params() {
        // Without an explicit escape, the escape follows the quote.
        assert_eq!(
            CopyCsvFormatParams::try_new(Some(b't'), Some(b'q'), None, None, None),
            Ok(CopyCsvFormatParams {
                delimiter: b't',
                quote: b'q',
                escape: b'q',
                header: false,
                null: Cow::from(""),
            })
        );

        assert_eq!(
            CopyCsvFormatParams::try_new(
                Some(b't'),
                Some(b'q'),
                Some(b'e'),
                Some(true),
                Some("null".to_string())
            ),
            Ok(CopyCsvFormatParams {
                delimiter: b't',
                quote: b'q',
                escape: b'e',
                header: true,
                null: Cow::from("null"),
            })
        );

        // The quote collides with the default delimiter.
        assert_eq!(
            CopyCsvFormatParams::try_new(
                None,
                Some(b','),
                Some(b'e'),
                Some(true),
                Some("null".to_string())
            ),
            Err("COPY delimiter and quote must be different".to_string())
        );
    }

    /// Encodes one row as CSV with the default options and with custom quote, escape and
    /// null settings.
    #[mz_ore::test]
    fn test_copy_csv_row() -> Result<(), io::Error> {
        let mut row = Row::default();
        let mut packer = row.packer();
        packer.push(Datum::from("1,2,\"3\""));
        packer.push(Datum::Null);
        packer.push(Datum::from(1000u64));
        packer.push(Datum::from("qe"));
        packer.push(Datum::from(""));

        let typ: SqlRelationType = SqlRelationType::new(vec![
            SqlColumnType {
                scalar_type: mz_repr::SqlScalarType::String,
                nullable: false,
            },
            SqlColumnType {
                scalar_type: mz_repr::SqlScalarType::String,
                nullable: true,
            },
            SqlColumnType {
                scalar_type: mz_repr::SqlScalarType::UInt64,
                nullable: false,
            },
            SqlColumnType {
                scalar_type: mz_repr::SqlScalarType::String,
                nullable: false,
            },
            SqlColumnType {
                scalar_type: mz_repr::SqlScalarType::String,
                nullable: false,
            },
        ]);

        let mut out = Vec::new();

        struct TestCase<'a> {
            params: CopyCsvFormatParams<'a>,
            expected: &'static [u8],
        }

        let tests = [
            TestCase {
                params: CopyCsvFormatParams::default(),
                expected: b"\"1,2,\"\"3\"\"\",,1000,qe,\"\"\n",
            },
            TestCase {
                params: CopyCsvFormatParams {
                    null: Cow::from("NULL"),
                    quote: b'q',
                    escape: b'e',
                    ..Default::default()
                },
                expected: b"q1,2,\"3\"q,NULL,1000,qeqeeq,\n",
            },
        ];

        for TestCase { params, expected } in tests {
            out.clear();
            let params = CopyFormatParams::Csv(params);
            // Propagate encoding failures instead of comparing against stale output.
            encode_copy_format(&params, &row, &typ, &mut out)?;
            let output = std::str::from_utf8(&out);
            assert_eq!(output, std::str::from_utf8(expected));
        }

        Ok(())
    }

    proptest! {
        /// Encodes interesting datums of every scalar type as CSV with arbitrary CSV
        /// options and checks that decoding, when it succeeds, yields the original row.
        #[mz_ore::test]
        #[cfg_attr(miri, ignore)]
        fn proptest_csv_roundtrips(copy_csv_params: CopyCsvFormatParams) {
            let try_roundtrip_datum = |scalar_type: &SqlScalarType, datum| {
                let row = Row::pack_slice(&[datum]);
                let typ = SqlRelationType::new(vec![
                    SqlColumnType {
                        scalar_type: scalar_type.clone(),
                        nullable: true,
                    }
                ]);

                let mut buf = Vec::new();
                let mut csv_params = copy_csv_params.clone();
                // Only a data row is encoded, so the decoder must not treat it as a
                // header line.
                csv_params.header = false;
                let params = CopyFormatParams::Csv(csv_params);

                encode_copy_format(&params, &row, &typ, &mut buf)?;
                let column_types = typ
                    .column_types
                    .iter()
                    .map(|x| &x.scalar_type)
                    .map(mz_pgrepr::Type::from)
                    .collect::<Vec<mz_pgrepr::Type>>();
                let result = decode_copy_format(&buf, &column_types, params);

                match result {
                    Ok(rows) => {
                        let out_str = std::str::from_utf8(&buf[..]);

                        prop_assert_eq!(
                            rows.len(),
                            1,
                            "unexpected number of rows! {:?}, csv string: {:?}", rows, out_str
                        );
                        let output = rows.into_element();

                        prop_assert_eq!(
                            row,
                            output,
                            "csv string: {:?}, scalar_type: {:?}", out_str, scalar_type
                        );
                    }
                    // Decoding failures are tolerated here; only successful decodes are
                    // compared against the input.
                    _ => {
                    }
                }

                Ok(())
            };

            for scalar_type in SqlScalarType::enumerate() {
                for datum in scalar_type.interesting_datums() {
                    // A value whose text form equals the null string cannot round-trip.
                    if let Some(value) = mz_pgrepr::Value::from_datum(datum, scalar_type) {
                        let mut buf = bytes::BytesMut::new();
                        value.encode_text(&mut buf);

                        if let Ok(datum_str) = std::str::from_utf8(&buf[..]) {
                            if datum_str == copy_csv_params.null {
                                continue;
                            }
                        }
                    }

                    let updated_datum = match datum {
                        // Timestamps and nulls are excluded from the round trip.
                        Datum::Timestamp(_) | Datum::TimestampTz(_) | Datum::Null => {
                            continue;
                        }
                        Datum::String(s) => {
                            // Strings that trim down to the null string or to nothing
                            // are excluded as well.
                            if s.trim() == copy_csv_params.null || s.trim().is_empty() {
                                continue;
                            } else {
                                Datum::String(s)
                            }
                        }
                        other => other,
                    };

                    let result = try_roundtrip_datum(scalar_type, updated_datum);
                    prop_assert!(result.is_ok(), "failure: {result:?}");
                }
            }
        }
    }
}