mz_pgcopy/
copy.rs

1// Copyright Materialize, Inc. and contributors. All rights reserved.
2//
3// Use of this software is governed by the Business Source License
4// included in the LICENSE file.
5//
6// As of the Change Date specified in that file, in accordance with
7// the Business Source License, use of this software will be governed
8// by the Apache License, Version 2.0.
9
10use std::borrow::Cow;
11use std::io;
12
13use bytes::BytesMut;
14use csv::{ByteRecord, ReaderBuilder};
15use mz_proto::{ProtoType, RustType, TryFromProtoError};
16use mz_repr::{
17    Datum, RelationDesc, Row, RowArena, RowRef, SharedRow, SqlColumnType, SqlRelationType,
18    SqlScalarType,
19};
20use proptest::prelude::{Arbitrary, Just, any};
21use proptest::strategy::{BoxedStrategy, Strategy, Union};
22use serde::Deserialize;
23use serde::Serialize;
24
/// The two bytes `\.`, which the decoders in this file treat as the end of
/// the COPY data and which the CSV encoder quotes when it would otherwise be
/// the sole value on a line.
static END_OF_COPY_MARKER: &[u8] = b"\\.";

// Textually includes `mz_pgcopy.copy.rs` from the build's `OUT_DIR`.
// NOTE(review): presumably this is the generated code that defines the
// `Proto*` types used below; confirm against the crate's build script.
include!(concat!(env!("OUT_DIR"), "/mz_pgcopy.copy.rs"));
28
29fn encode_copy_row_binary(
30    row: &RowRef,
31    typ: &SqlRelationType,
32    out: &mut Vec<u8>,
33) -> Result<(), io::Error> {
34    const NULL_BYTES: [u8; 4] = (-1i32).to_be_bytes();
35
36    // 16-bit int of number of tuples.
37    let count = i16::try_from(typ.column_types.len()).map_err(|_| {
38        io::Error::new(
39            io::ErrorKind::Other,
40            "column count does not fit into an i16",
41        )
42    })?;
43
44    out.extend(count.to_be_bytes());
45    let mut buf = BytesMut::new();
46    for (field, typ) in row
47        .iter()
48        .zip(&typ.column_types)
49        .map(|(datum, typ)| (mz_pgrepr::Value::from_datum(datum, &typ.scalar_type), typ))
50    {
51        match field {
52            None => out.extend(NULL_BYTES),
53            Some(field) => {
54                buf.clear();
55                field.encode_binary(&mz_pgrepr::Type::from(&typ.scalar_type), &mut buf)?;
56                out.extend(
57                    i32::try_from(buf.len())
58                        .map_err(|_| {
59                            io::Error::new(
60                                io::ErrorKind::Other,
61                                "field length does not fit into an i32",
62                            )
63                        })?
64                        .to_be_bytes(),
65                );
66                out.extend(&buf);
67            }
68        }
69    }
70    Ok(())
71}
72
73fn encode_copy_row_text(
74    CopyTextFormatParams { null, delimiter }: &CopyTextFormatParams,
75    row: &RowRef,
76    typ: &SqlRelationType,
77    out: &mut Vec<u8>,
78) -> Result<(), io::Error> {
79    let null = null.as_bytes();
80    let mut buf = BytesMut::new();
81    for (idx, field) in mz_pgrepr::values_from_row(row, typ).into_iter().enumerate() {
82        if idx > 0 {
83            out.push(*delimiter);
84        }
85        match field {
86            None => out.extend(null),
87            Some(field) => {
88                buf.clear();
89                field.encode_text(&mut buf);
90                for b in &buf {
91                    match b {
92                        b'\\' => out.extend(b"\\\\"),
93                        b'\n' => out.extend(b"\\n"),
94                        b'\r' => out.extend(b"\\r"),
95                        b'\t' => out.extend(b"\\t"),
96                        _ => out.push(*b),
97                    }
98                }
99            }
100        }
101    }
102    out.push(b'\n');
103    Ok(())
104}
105
106fn encode_copy_row_csv(
107    CopyCsvFormatParams {
108        delimiter: delim,
109        quote,
110        escape,
111        header: _,
112        null,
113    }: &CopyCsvFormatParams,
114    row: &RowRef,
115    typ: &SqlRelationType,
116    out: &mut Vec<u8>,
117) -> Result<(), io::Error> {
118    let null = null.as_bytes();
119    let is_special = |c: &u8| *c == *delim || *c == *quote || *c == b'\r' || *c == b'\n';
120    let mut buf = BytesMut::new();
121    for (idx, field) in mz_pgrepr::values_from_row(row, typ).into_iter().enumerate() {
122        if idx > 0 {
123            out.push(*delim);
124        }
125        match field {
126            None => out.extend(null),
127            Some(field) => {
128                buf.clear();
129                field.encode_text(&mut buf);
130                // A field needs quoting if:
131                //   * It is the only field and the value is exactly the end
132                //     of copy marker.
133                //   * The field contains a special character.
134                //   * The field is exactly the NULL sentinel.
135                if (typ.column_types.len() == 1 && buf == END_OF_COPY_MARKER)
136                    || buf.iter().any(is_special)
137                    || &*buf == null
138                {
139                    // Quote the value by wrapping it in the quote character and
140                    // emitting the escape character before any quote or escape
141                    // characters within.
142                    out.push(*quote);
143                    for b in &buf {
144                        if *b == *quote || *b == *escape {
145                            out.push(*escape);
146                        }
147                        out.push(*b);
148                    }
149                    out.push(*quote);
150                } else {
151                    // The value does not need quoting and can be emitted
152                    // directly.
153                    out.extend(&buf);
154                }
155            }
156        }
157    }
158    out.push(b'\n');
159    Ok(())
160}
161
162pub struct CopyTextFormatParser<'a> {
163    data: &'a [u8],
164    position: usize,
165    column_delimiter: u8,
166    null_string: &'a str,
167    buffer: Vec<u8>,
168}
169
170impl<'a> CopyTextFormatParser<'a> {
171    pub fn new(data: &'a [u8], column_delimiter: u8, null_string: &'a str) -> Self {
172        Self {
173            data,
174            position: 0,
175            column_delimiter,
176            null_string,
177            buffer: Vec::new(),
178        }
179    }
180
181    fn peek(&self) -> Option<u8> {
182        if self.position < self.data.len() {
183            Some(self.data[self.position])
184        } else {
185            None
186        }
187    }
188
189    fn consume_n(&mut self, n: usize) {
190        self.position = std::cmp::min(self.position + n, self.data.len());
191    }
192
193    pub fn is_eof(&self) -> bool {
194        self.peek().is_none() || self.is_end_of_copy_marker()
195    }
196
197    pub fn is_end_of_copy_marker(&self) -> bool {
198        self.check_bytes(END_OF_COPY_MARKER)
199    }
200
201    fn is_end_of_line(&self) -> bool {
202        match self.peek() {
203            Some(b'\n') | None => true,
204            _ => false,
205        }
206    }
207
208    pub fn expect_end_of_line(&mut self) -> Result<(), io::Error> {
209        if self.is_end_of_line() {
210            self.consume_n(1);
211            Ok(())
212        } else {
213            Err(io::Error::new(
214                io::ErrorKind::InvalidData,
215                "extra data after last expected column",
216            ))
217        }
218    }
219
220    fn is_column_delimiter(&self) -> bool {
221        self.check_bytes(&[self.column_delimiter])
222    }
223
224    pub fn expect_column_delimiter(&mut self) -> Result<(), io::Error> {
225        if self.consume_bytes(&[self.column_delimiter]) {
226            Ok(())
227        } else {
228            Err(io::Error::new(
229                io::ErrorKind::InvalidData,
230                "missing data for column",
231            ))
232        }
233    }
234
235    fn check_bytes(&self, bytes: &[u8]) -> bool {
236        let remaining_bytes = self.data.len() - self.position;
237        remaining_bytes >= bytes.len()
238            && self.data[self.position..]
239                .iter()
240                .zip(bytes.iter())
241                .all(|(x, y)| x == y)
242    }
243
244    fn consume_bytes(&mut self, bytes: &[u8]) -> bool {
245        if self.check_bytes(bytes) {
246            self.consume_n(bytes.len());
247            true
248        } else {
249            false
250        }
251    }
252
253    fn consume_null_string(&mut self) -> bool {
254        if self.null_string.is_empty() {
255            // An empty NULL marker is supported. Look ahead to ensure that is followed by
256            // a column delimiter, an end of line or it is at the end of the data.
257            self.is_column_delimiter()
258                || self.is_end_of_line()
259                || self.is_end_of_copy_marker()
260                || self.is_eof()
261        } else {
262            self.consume_bytes(self.null_string.as_bytes())
263        }
264    }
265
266    pub fn consume_raw_value(&mut self) -> Result<Option<&[u8]>, io::Error> {
267        if self.consume_null_string() {
268            return Ok(None);
269        }
270
271        let mut start = self.position;
272
273        // buffer where unescaped data is accumulated
274        self.buffer.clear();
275
276        while !self.is_eof() && !self.is_end_of_copy_marker() {
277            if self.is_end_of_line() || self.is_column_delimiter() {
278                break;
279            }
280            match self.peek() {
281                Some(b'\\') => {
282                    // Add non-escaped data parsed so far
283                    self.buffer.extend(&self.data[start..self.position]);
284
285                    self.consume_n(1);
286                    match self.peek() {
287                        Some(b'b') => {
288                            self.consume_n(1);
289                            self.buffer.push(8);
290                        }
291                        Some(b'f') => {
292                            self.consume_n(1);
293                            self.buffer.push(12);
294                        }
295                        Some(b'n') => {
296                            self.consume_n(1);
297                            self.buffer.push(b'\n');
298                        }
299                        Some(b'r') => {
300                            self.consume_n(1);
301                            self.buffer.push(b'\r');
302                        }
303                        Some(b't') => {
304                            self.consume_n(1);
305                            self.buffer.push(b'\t');
306                        }
307                        Some(b'v') => {
308                            self.consume_n(1);
309                            self.buffer.push(11);
310                        }
311                        Some(b'x') => {
312                            self.consume_n(1);
313                            match self.peek() {
314                                Some(_c @ b'0'..=b'9')
315                                | Some(_c @ b'A'..=b'F')
316                                | Some(_c @ b'a'..=b'f') => {
317                                    let mut value: u8 = 0;
318                                    let decode_nibble = |b| match b {
319                                        Some(c @ b'a'..=b'f') => Some(c - b'a' + 10),
320                                        Some(c @ b'A'..=b'F') => Some(c - b'A' + 10),
321                                        Some(c @ b'0'..=b'9') => Some(c - b'0'),
322                                        _ => None,
323                                    };
324                                    for _ in 0..2 {
325                                        match decode_nibble(self.peek()) {
326                                            Some(c) => {
327                                                self.consume_n(1);
328                                                value = value << 4 | c;
329                                            }
330                                            _ => break,
331                                        }
332                                    }
333                                    self.buffer.push(value);
334                                }
335                                _ => {
336                                    self.buffer.push(b'x');
337                                }
338                            }
339                        }
340                        Some(_c @ b'0'..=b'7') => {
341                            let mut value: u8 = 0;
342                            for _ in 0..3 {
343                                match self.peek() {
344                                    Some(c @ b'0'..=b'7') => {
345                                        self.consume_n(1);
346                                        value = value << 3 | (c - b'0');
347                                    }
348                                    _ => break,
349                                }
350                            }
351                            self.buffer.push(value);
352                        }
353                        Some(c) => {
354                            self.consume_n(1);
355                            self.buffer.push(c);
356                        }
357                        None => {
358                            self.buffer.push(b'\\');
359                        }
360                    }
361
362                    start = self.position;
363                }
364                Some(_) => {
365                    self.consume_n(1);
366                }
367                None => {}
368            }
369        }
370
371        // Return a slice of the original buffer if no escaped characters where processed
372        if self.buffer.is_empty() {
373            Ok(Some(&self.data[start..self.position]))
374        } else {
375            // ... otherwise, add the remaining non-escaped data to the decoding buffer
376            // and return a pointer to it
377            self.buffer.extend(&self.data[start..self.position]);
378            Ok(Some(&self.buffer[..]))
379        }
380    }
381
382    /// Error if more than `num_columns` values in `parser`.
383    pub fn iter_raw(self, num_columns: usize) -> RawIterator<'a> {
384        RawIterator {
385            parser: self,
386            current_column: 0,
387            num_columns,
388            truncate: false,
389        }
390    }
391
392    /// Return no more than `num_columns` values from `parser`.
393    pub fn iter_raw_truncating(self, num_columns: usize) -> RawIterator<'a> {
394        RawIterator {
395            parser: self,
396            current_column: 0,
397            num_columns,
398            truncate: true,
399        }
400    }
401}
402
403pub struct RawIterator<'a> {
404    parser: CopyTextFormatParser<'a>,
405    current_column: usize,
406    num_columns: usize,
407    truncate: bool,
408}
409
410impl<'a> RawIterator<'a> {
411    pub fn next(&mut self) -> Option<Result<Option<&[u8]>, io::Error>> {
412        if self.current_column > self.num_columns {
413            return None;
414        }
415
416        if self.current_column == self.num_columns {
417            if !self.truncate {
418                if let Some(err) = self.parser.expect_end_of_line().err() {
419                    return Some(Err(err));
420                }
421            }
422
423            return None;
424        }
425
426        if self.current_column > 0 {
427            if let Some(err) = self.parser.expect_column_delimiter().err() {
428                return Some(Err(err));
429            }
430        }
431
432        self.current_column += 1;
433        Some(self.parser.consume_raw_value())
434    }
435}
436
/// The format of COPY data, together with the options of that format.
#[derive(Debug, Eq, PartialEq, Clone, Serialize, Deserialize)]
pub enum CopyFormatParams<'a> {
    /// Text format with the given options.
    Text(CopyTextFormatParams<'a>),
    /// CSV format with the given options.
    Csv(CopyCsvFormatParams<'a>),
    /// Binary format; carries no options.
    Binary,
    /// Parquet format; carries no options.
    Parquet,
}
444
445impl RustType<ProtoCopyFormatParams> for CopyFormatParams<'static> {
446    fn into_proto(&self) -> ProtoCopyFormatParams {
447        use proto_copy_format_params::Kind;
448        ProtoCopyFormatParams {
449            kind: Some(match self {
450                Self::Text(f) => Kind::Text(f.into_proto()),
451                Self::Csv(f) => Kind::Csv(f.into_proto()),
452                Self::Binary => Kind::Binary(()),
453                Self::Parquet => Kind::Parquet(ProtoCopyParquetFormatParams::default()),
454            }),
455        }
456    }
457
458    fn from_proto(proto: ProtoCopyFormatParams) -> Result<Self, TryFromProtoError> {
459        use proto_copy_format_params::Kind;
460        match proto.kind {
461            Some(Kind::Text(f)) => Ok(Self::Text(f.into_rust()?)),
462            Some(Kind::Csv(f)) => Ok(Self::Csv(f.into_rust()?)),
463            Some(Kind::Binary(())) => Ok(Self::Binary),
464            Some(Kind::Parquet(ProtoCopyParquetFormatParams {})) => Ok(Self::Parquet),
465            None => Err(TryFromProtoError::missing_field(
466                "ProtoCopyFormatParams::kind",
467            )),
468        }
469    }
470}
471
472impl Arbitrary for CopyFormatParams<'static> {
473    type Parameters = ();
474    type Strategy = Union<BoxedStrategy<Self>>;
475
476    fn arbitrary_with(_args: Self::Parameters) -> Self::Strategy {
477        Union::new(vec![
478            any::<CopyTextFormatParams>().prop_map(Self::Text).boxed(),
479            any::<CopyCsvFormatParams>().prop_map(Self::Csv).boxed(),
480            Just(Self::Binary).boxed(),
481        ])
482    }
483}
484
485impl CopyFormatParams<'static> {
486    pub fn file_extension(&self) -> &str {
487        match self {
488            &CopyFormatParams::Text(_) => "txt",
489            &CopyFormatParams::Csv(_) => "csv",
490            &CopyFormatParams::Binary => "bin",
491            &CopyFormatParams::Parquet => "parquet",
492        }
493    }
494
495    pub fn requires_header(&self) -> bool {
496        match self {
497            CopyFormatParams::Text(_) => false,
498            CopyFormatParams::Csv(params) => params.header,
499            CopyFormatParams::Binary => false,
500            CopyFormatParams::Parquet => false,
501        }
502    }
503}
504
505/// Decodes the given bytes into `Row`-s based on the given `CopyFormatParams`.
506pub fn decode_copy_format<'a>(
507    data: &[u8],
508    column_types: &[mz_pgrepr::Type],
509    params: CopyFormatParams<'a>,
510) -> Result<Vec<Row>, io::Error> {
511    match params {
512        CopyFormatParams::Text(params) => decode_copy_format_text(data, column_types, params),
513        CopyFormatParams::Csv(params) => decode_copy_format_csv(data, column_types, params),
514        CopyFormatParams::Binary => Err(io::Error::new(
515            io::ErrorKind::Unsupported,
516            "cannot decode as binary format",
517        )),
518        CopyFormatParams::Parquet => {
519            // TODO(cf2): Support Parquet over STDIN.
520            Err(io::Error::new(io::ErrorKind::Unsupported, "parquet format"))
521        }
522    }
523}
524
525/// Encodes the given `Row` into bytes based on the given `CopyFormatParams`.
526pub fn encode_copy_format<'a>(
527    params: &CopyFormatParams<'a>,
528    row: &RowRef,
529    typ: &SqlRelationType,
530    out: &mut Vec<u8>,
531) -> Result<(), io::Error> {
532    match params {
533        CopyFormatParams::Text(params) => encode_copy_row_text(params, row, typ, out),
534        CopyFormatParams::Csv(params) => encode_copy_row_csv(params, row, typ, out),
535        CopyFormatParams::Binary => encode_copy_row_binary(row, typ, out),
536        CopyFormatParams::Parquet => {
537            // TODO(cf2): Support Parquet over STDIN.
538            Err(io::Error::new(io::ErrorKind::Unsupported, "parquet format"))
539        }
540    }
541}
542
543pub fn encode_copy_format_header<'a>(
544    params: &CopyFormatParams<'a>,
545    desc: &RelationDesc,
546    out: &mut Vec<u8>,
547) -> Result<(), io::Error> {
548    match params {
549        CopyFormatParams::Text(_) => Ok(()),
550        CopyFormatParams::Binary => Ok(()),
551        CopyFormatParams::Csv(params) => {
552            let mut header_row = Row::with_capacity(desc.arity());
553            header_row
554                .packer()
555                .extend(desc.iter_names().map(|s| Datum::from(s.as_str())));
556            let typ = SqlRelationType::new(vec![
557                SqlColumnType {
558                    scalar_type: SqlScalarType::String,
559                    nullable: false,
560                };
561                desc.arity()
562            ]);
563            encode_copy_row_csv(params, &header_row, &typ, out)
564        }
565        CopyFormatParams::Parquet => {
566            // TODO(cf2): Support Parquet over STDIN.
567            Err(io::Error::new(io::ErrorKind::Unsupported, "parquet format"))
568        }
569    }
570}
571
/// Options of the COPY text format.
#[derive(Debug, Eq, PartialEq, Clone, Serialize, Deserialize)]
pub struct CopyTextFormatParams<'a> {
    /// The string that stands for a NULL value.
    pub null: Cow<'a, str>,
    /// The byte that separates columns within a line.
    pub delimiter: u8,
}
577
578impl<'a> Default for CopyTextFormatParams<'a> {
579    fn default() -> Self {
580        CopyTextFormatParams {
581            delimiter: b'\t',
582            null: Cow::from("\\N"),
583        }
584    }
585}
586
587impl RustType<ProtoCopyTextFormatParams> for CopyTextFormatParams<'static> {
588    fn into_proto(&self) -> ProtoCopyTextFormatParams {
589        ProtoCopyTextFormatParams {
590            null: self.null.into_proto(),
591            delimiter: self.delimiter.into_proto(),
592        }
593    }
594
595    fn from_proto(proto: ProtoCopyTextFormatParams) -> Result<Self, TryFromProtoError> {
596        Ok(Self {
597            null: Cow::Owned(proto.null.into_rust()?),
598            delimiter: proto.delimiter.into_rust()?,
599        })
600    }
601}
602
603impl Arbitrary for CopyTextFormatParams<'static> {
604    type Parameters = ();
605    type Strategy = BoxedStrategy<Self>;
606
607    fn arbitrary_with(_args: Self::Parameters) -> Self::Strategy {
608        (any::<String>(), any::<u8>())
609            .prop_map(|(null, delimiter)| Self {
610                null: Cow::Owned(null),
611                delimiter,
612            })
613            .boxed()
614    }
615}
616
617pub fn decode_copy_format_text(
618    data: &[u8],
619    column_types: &[mz_pgrepr::Type],
620    CopyTextFormatParams { null, delimiter }: CopyTextFormatParams,
621) -> Result<Vec<Row>, io::Error> {
622    let mut rows = Vec::new();
623
624    // TODO: pass the `CopyTextFormatParams` to the `new` method
625    let mut parser = CopyTextFormatParser::new(data, delimiter, &null);
626    while !parser.is_eof() && !parser.is_end_of_copy_marker() {
627        let mut row = Vec::new();
628        let buf = RowArena::new();
629        for (col, typ) in column_types.iter().enumerate() {
630            if col > 0 {
631                parser.expect_column_delimiter()?;
632            }
633            let raw_value = parser.consume_raw_value()?;
634            if let Some(raw_value) = raw_value {
635                match mz_pgrepr::Value::decode_text(typ, raw_value) {
636                    Ok(value) => row.push(value.into_datum(&buf, typ)),
637                    Err(err) => {
638                        let msg = format!("unable to decode column: {}", err);
639                        return Err(io::Error::new(io::ErrorKind::InvalidData, msg));
640                    }
641                }
642            } else {
643                row.push(Datum::Null);
644            }
645        }
646        parser.expect_end_of_line()?;
647        rows.push(Row::pack(row));
648    }
649    // Note that if there is any junk data after the end of copy marker, we drop
650    // it on the floor as PG does.
651    Ok(rows)
652}
653
/// Options of the COPY CSV format.
#[derive(Debug, Eq, PartialEq, Clone, Serialize, Deserialize)]
pub struct CopyCsvFormatParams<'a> {
    /// The byte that separates columns within a line.
    pub delimiter: u8,
    /// The byte that opens and closes a quoted value.
    pub quote: u8,
    /// The byte written before a quote or escape byte inside a quoted value.
    pub escape: u8,
    /// Whether the data has a header line.
    pub header: bool,
    /// The string that stands for a NULL value.
    pub null: Cow<'a, str>,
}
662
663impl<'a> CopyCsvFormatParams<'a> {
664    pub fn to_owned(&self) -> CopyCsvFormatParams<'static> {
665        CopyCsvFormatParams {
666            delimiter: self.delimiter,
667            quote: self.quote,
668            escape: self.escape,
669            header: self.header,
670            null: Cow::Owned(self.null.to_string()),
671        }
672    }
673}
674
675impl RustType<ProtoCopyCsvFormatParams> for CopyCsvFormatParams<'static> {
676    fn into_proto(&self) -> ProtoCopyCsvFormatParams {
677        ProtoCopyCsvFormatParams {
678            delimiter: self.delimiter.into(),
679            quote: self.quote.into(),
680            escape: self.escape.into(),
681            header: self.header,
682            null: self.null.into_proto(),
683        }
684    }
685
686    fn from_proto(proto: ProtoCopyCsvFormatParams) -> Result<Self, TryFromProtoError> {
687        Ok(Self {
688            delimiter: proto.delimiter.into_rust()?,
689            quote: proto.quote.into_rust()?,
690            escape: proto.escape.into_rust()?,
691            header: proto.header,
692            null: Cow::Owned(proto.null.into_rust()?),
693        })
694    }
695}
696
697impl Arbitrary for CopyCsvFormatParams<'static> {
698    type Parameters = ();
699    type Strategy = BoxedStrategy<Self>;
700
701    fn arbitrary_with(_args: Self::Parameters) -> Self::Strategy {
702        (
703            any::<u8>(),
704            any::<u8>(),
705            any::<u8>(),
706            any::<bool>(),
707            any::<String>(),
708        )
709            .prop_map(|(delimiter, diff, escape, header, null)| {
710                // Delimiter and Quote need to be different.
711                let diff = diff.saturating_sub(1).max(1);
712                let quote = delimiter.wrapping_add(diff);
713
714                Self::try_new(
715                    Some(delimiter),
716                    Some(quote),
717                    Some(escape),
718                    Some(header),
719                    Some(null),
720                )
721                .expect("delimiter and quote should be different")
722            })
723            .boxed()
724    }
725}
726
727impl<'a> Default for CopyCsvFormatParams<'a> {
728    fn default() -> Self {
729        CopyCsvFormatParams {
730            delimiter: b',',
731            quote: b'"',
732            escape: b'"',
733            header: false,
734            null: Cow::from(""),
735        }
736    }
737}
738
739impl<'a> CopyCsvFormatParams<'a> {
740    pub fn try_new(
741        delimiter: Option<u8>,
742        quote: Option<u8>,
743        escape: Option<u8>,
744        header: Option<bool>,
745        null: Option<String>,
746    ) -> Result<CopyCsvFormatParams<'a>, String> {
747        let mut params = CopyCsvFormatParams::default();
748
749        if let Some(delimiter) = delimiter {
750            params.delimiter = delimiter;
751        }
752        if let Some(quote) = quote {
753            params.quote = quote;
754            // escape defaults to the value provided for quote
755            params.escape = quote;
756        }
757        if let Some(escape) = escape {
758            params.escape = escape;
759        }
760        if let Some(header) = header {
761            params.header = header;
762        }
763        if let Some(null) = null {
764            params.null = Cow::from(null);
765        }
766
767        if params.quote == params.delimiter {
768            return Err("COPY delimiter and quote must be different".to_string());
769        }
770        Ok(params)
771    }
772}
773
774pub fn decode_copy_format_csv(
775    data: &[u8],
776    column_types: &[mz_pgrepr::Type],
777    CopyCsvFormatParams {
778        delimiter,
779        quote,
780        escape,
781        null,
782        header,
783    }: CopyCsvFormatParams,
784) -> Result<Vec<Row>, io::Error> {
785    let mut rows = Vec::new();
786
787    let (double_quote, escape) = if quote == escape {
788        (true, None)
789    } else {
790        (false, Some(escape))
791    };
792
793    let mut rdr = ReaderBuilder::new()
794        .delimiter(delimiter)
795        .quote(quote)
796        .has_headers(header)
797        .double_quote(double_quote)
798        .escape(escape)
799        // Must be flexible to accept end of copy marker, which will always be 1
800        // field.
801        .flexible(true)
802        .from_reader(data);
803
804    let null_as_bytes = null.as_bytes();
805
806    let mut record = ByteRecord::new();
807
808    while rdr.read_byte_record(&mut record)? {
809        if record.len() == 1 && record.iter().next() == Some(END_OF_COPY_MARKER) {
810            break;
811        }
812
813        match record.len().cmp(&column_types.len()) {
814            std::cmp::Ordering::Less => Err(io::Error::new(
815                io::ErrorKind::InvalidData,
816                "missing data for column",
817            )),
818            std::cmp::Ordering::Greater => Err(io::Error::new(
819                io::ErrorKind::InvalidData,
820                "extra data after last expected column",
821            )),
822            std::cmp::Ordering::Equal => Ok(()),
823        }?;
824
825        let mut row_builder = SharedRow::get();
826        let mut row_packer = row_builder.packer();
827
828        for (typ, raw_value) in column_types.iter().zip(record.iter()) {
829            if raw_value == null_as_bytes {
830                row_packer.push(Datum::Null);
831            } else {
832                let s = match std::str::from_utf8(raw_value) {
833                    Ok(s) => s,
834                    Err(err) => {
835                        let msg = format!("invalid utf8 data in column: {}", err);
836                        return Err(io::Error::new(io::ErrorKind::InvalidData, msg));
837                    }
838                };
839                match mz_pgrepr::Value::decode_text_into_row(typ, s, &mut row_packer) {
840                    Ok(()) => {}
841                    Err(err) => {
842                        let msg = format!("unable to decode column: {}", err);
843                        return Err(io::Error::new(io::ErrorKind::InvalidData, msg));
844                    }
845                }
846            }
847        }
848        rows.push(row_builder.clone());
849    }
850
851    Ok(rows)
852}
853
854#[cfg(test)]
855mod tests {
856    use mz_ore::collections::CollectionExt;
857    use mz_repr::{SqlColumnType, SqlScalarType};
858    use proptest::prelude::*;
859
860    use super::*;
861
862    #[mz_ore::test]
863    fn test_copy_format_text_parser() {
864        let text = "\t\\nt e\t\\N\t\n\\x60\\xA\\x7D\\x4a\n\\44\\044\\123".as_bytes();
865        let mut parser = CopyTextFormatParser::new(text, b'\t', "\\N");
866        assert!(parser.is_column_delimiter());
867        parser
868            .expect_column_delimiter()
869            .expect("expected column delimiter");
870        assert_eq!(
871            parser
872                .consume_raw_value()
873                .expect("unexpected error")
874                .expect("unexpected empty result"),
875            "\nt e".as_bytes()
876        );
877        parser
878            .expect_column_delimiter()
879            .expect("expected column delimiter");
880        // null value
881        assert!(
882            parser
883                .consume_raw_value()
884                .expect("unexpected error")
885                .is_none()
886        );
887        parser
888            .expect_column_delimiter()
889            .expect("expected column delimiter");
890        assert!(parser.is_end_of_line());
891        parser.expect_end_of_line().expect("expected eol");
892        // hex value
893        assert_eq!(
894            parser
895                .consume_raw_value()
896                .expect("unexpected error")
897                .expect("unexpected empty result"),
898            "`\n}J".as_bytes()
899        );
900        parser.expect_end_of_line().expect("expected eol");
901        // octal value
902        assert_eq!(
903            parser
904                .consume_raw_value()
905                .expect("unexpected error")
906                .expect("unexpected empty result"),
907            "$$S".as_bytes()
908        );
909        assert!(parser.is_eof());
910    }
911
912    #[mz_ore::test]
913    fn test_copy_format_text_empty_null_string() {
914        let text = "\t\n10\t20\n30\t\n40\t".as_bytes();
915        let expect = vec![
916            vec![None, None],
917            vec![Some("10"), Some("20")],
918            vec![Some("30"), None],
919            vec![Some("40"), None],
920        ];
921        let mut parser = CopyTextFormatParser::new(text, b'\t', "");
922        for line in expect {
923            for (i, value) in line.iter().enumerate() {
924                if i > 0 {
925                    parser
926                        .expect_column_delimiter()
927                        .expect("expected column delimiter");
928                }
929                match value {
930                    Some(s) => {
931                        assert!(!parser.consume_null_string());
932                        assert_eq!(
933                            parser
934                                .consume_raw_value()
935                                .expect("unexpected error")
936                                .expect("unexpected empty result"),
937                            s.as_bytes()
938                        );
939                    }
940                    None => {
941                        assert!(parser.consume_null_string());
942                    }
943                }
944            }
945            parser.expect_end_of_line().expect("expected eol");
946        }
947    }
948
949    #[mz_ore::test]
950    fn test_copy_format_text_parser_escapes() {
951        struct TestCase {
952            input: &'static str,
953            expect: &'static [u8],
954        }
955        let tests = vec![
956            TestCase {
957                input: "simple",
958                expect: b"simple",
959            },
960            TestCase {
961                input: r#"new\nline"#,
962                expect: b"new\nline",
963            },
964            TestCase {
965                input: r#"\b\f\n\r\t\v\\"#,
966                expect: b"\x08\x0c\n\r\t\x0b\\",
967            },
968            TestCase {
969                input: r#"\0\12\123"#,
970                expect: &[0, 0o12, 0o123],
971            },
972            TestCase {
973                input: r#"\x1\xaf"#,
974                expect: &[0x01, 0xaf],
975            },
976            TestCase {
977                input: r#"T\n\07\xEV\x0fA\xb2C\1"#,
978                expect: b"T\n\x07\x0eV\x0fA\xb2C\x01",
979            },
980            TestCase {
981                input: r#"\\\""#,
982                expect: b"\\\"",
983            },
984            TestCase {
985                input: r#"\x"#,
986                expect: b"x",
987            },
988            TestCase {
989                input: r#"\xg"#,
990                expect: b"xg",
991            },
992            TestCase {
993                input: r#"\"#,
994                expect: b"\\",
995            },
996            TestCase {
997                input: r#"\8"#,
998                expect: b"8",
999            },
1000            TestCase {
1001                input: r#"\a"#,
1002                expect: b"a",
1003            },
1004            TestCase {
1005                input: r#"\x\xg\8\xH\x32\s\"#,
1006                expect: b"xxg8xH2s\\",
1007            },
1008        ];
1009
1010        for test in tests {
1011            let mut parser = CopyTextFormatParser::new(test.input.as_bytes(), b'\t', "\\N");
1012            assert_eq!(
1013                parser
1014                    .consume_raw_value()
1015                    .expect("unexpected error")
1016                    .expect("unexpected empty result"),
1017                test.expect,
1018                "input: {}, expect: {:?}",
1019                test.input,
1020                std::str::from_utf8(test.expect),
1021            );
1022            assert!(parser.is_eof());
1023        }
1024    }
1025
1026    #[mz_ore::test]
1027    fn test_copy_csv_format_params() {
1028        assert_eq!(
1029            CopyCsvFormatParams::try_new(Some(b't'), Some(b'q'), None, None, None),
1030            Ok(CopyCsvFormatParams {
1031                delimiter: b't',
1032                quote: b'q',
1033                escape: b'q',
1034                header: false,
1035                null: Cow::from(""),
1036            })
1037        );
1038
1039        assert_eq!(
1040            CopyCsvFormatParams::try_new(
1041                Some(b't'),
1042                Some(b'q'),
1043                Some(b'e'),
1044                Some(true),
1045                Some("null".to_string())
1046            ),
1047            Ok(CopyCsvFormatParams {
1048                delimiter: b't',
1049                quote: b'q',
1050                escape: b'e',
1051                header: true,
1052                null: Cow::from("null"),
1053            })
1054        );
1055
1056        assert_eq!(
1057            CopyCsvFormatParams::try_new(
1058                None,
1059                Some(b','),
1060                Some(b'e'),
1061                Some(true),
1062                Some("null".to_string())
1063            ),
1064            Err("COPY delimiter and quote must be different".to_string())
1065        );
1066    }
1067
1068    #[mz_ore::test]
1069    fn test_copy_csv_row() -> Result<(), io::Error> {
1070        let mut row = Row::default();
1071        let mut packer = row.packer();
1072        packer.push(Datum::from("1,2,\"3\""));
1073        packer.push(Datum::Null);
1074        packer.push(Datum::from(1000u64));
1075        packer.push(Datum::from("qe")); // overridden quote and escape character in test below
1076        packer.push(Datum::from(""));
1077
1078        let typ: SqlRelationType = SqlRelationType::new(vec![
1079            SqlColumnType {
1080                scalar_type: mz_repr::SqlScalarType::String,
1081                nullable: false,
1082            },
1083            SqlColumnType {
1084                scalar_type: mz_repr::SqlScalarType::String,
1085                nullable: true,
1086            },
1087            SqlColumnType {
1088                scalar_type: mz_repr::SqlScalarType::UInt64,
1089                nullable: false,
1090            },
1091            SqlColumnType {
1092                scalar_type: mz_repr::SqlScalarType::String,
1093                nullable: false,
1094            },
1095            SqlColumnType {
1096                scalar_type: mz_repr::SqlScalarType::String,
1097                nullable: false,
1098            },
1099        ]);
1100
1101        let mut out = Vec::new();
1102
1103        struct TestCase<'a> {
1104            params: CopyCsvFormatParams<'a>,
1105            expected: &'static [u8],
1106        }
1107
1108        let tests = [
1109            TestCase {
1110                params: CopyCsvFormatParams::default(),
1111                expected: b"\"1,2,\"\"3\"\"\",,1000,qe,\"\"\n",
1112            },
1113            TestCase {
1114                params: CopyCsvFormatParams {
1115                    null: Cow::from("NULL"),
1116                    quote: b'q',
1117                    escape: b'e',
1118                    ..Default::default()
1119                },
1120                expected: b"q1,2,\"3\"q,NULL,1000,qeqeeq,\n",
1121            },
1122        ];
1123
1124        for TestCase { params, expected } in tests {
1125            out.clear();
1126            let params = CopyFormatParams::Csv(params);
1127            let _ = encode_copy_format(&params, &row, &typ, &mut out);
1128            let output = std::str::from_utf8(&out);
1129            assert_eq!(output, std::str::from_utf8(expected));
1130        }
1131
1132        Ok(())
1133    }
1134
1135    proptest! {
1136        #[mz_ore::test]
1137        #[cfg_attr(miri, ignore)]
1138        fn proptest_csv_roundtrips(copy_csv_params: CopyCsvFormatParams)  {
1139            // Given a SqlScalarType and Datum roundtrips it through the CSV COPY format.
1140            let try_roundtrip_datum = |scalar_type: &SqlScalarType, datum| {
1141                let row = Row::pack_slice(&[datum]);
1142                let typ = SqlRelationType::new(vec![
1143                    SqlColumnType {
1144                        scalar_type: scalar_type.clone(),
1145                        nullable: true,
1146                    }
1147                ]);
1148
1149                let mut buf = Vec::new();
1150                let mut csv_params = copy_csv_params.clone();
1151                // TODO: Encoding never writes a header.
1152                csv_params.header = false;
1153                let params = CopyFormatParams::Csv(csv_params);
1154
1155                // Roundtrip the Row through our CSV format.
1156                encode_copy_format(&params, &row, &typ, &mut buf)?;
1157                let column_types = typ
1158                    .column_types
1159                    .iter()
1160                    .map(|x| &x.scalar_type)
1161                    .map(mz_pgrepr::Type::from)
1162                    .collect::<Vec<mz_pgrepr::Type>>();
1163                let result = decode_copy_format(&buf, &column_types, params);
1164
1165                match result {
1166                    Ok(rows) => {
1167                        let out_str = std::str::from_utf8(&buf[..]);
1168
1169                        prop_assert_eq!(
1170                            rows.len(),
1171                            1,
1172                            "unexpected number of rows! {:?}, csv string: {:?}", rows, out_str
1173                        );
1174                        let output = rows.into_element();
1175
1176                        prop_assert_eq!(
1177                            row,
1178                            output,
1179                            "csv string: {:?}, scalar_type: {:?}", out_str, scalar_type
1180                        );
1181                    }
1182                    _ => {
1183                        // ignoring decoding failures
1184                    }
1185                }
1186
1187                Ok(())
1188            };
1189
1190            // Try roundtripping all of our interesting Datums.
1191            for scalar_type in SqlScalarType::enumerate() {
1192                for datum in scalar_type.interesting_datums() {
1193                    // TODO: The decoder cannot differentiate between empty string and null.
1194                    if let Some(value) = mz_pgrepr::Value::from_datum(datum, scalar_type) {
1195                        let mut buf = bytes::BytesMut::new();
1196                        value.encode_text(&mut buf);
1197
1198                        if let Ok(datum_str) = std::str::from_utf8(&buf[..]) {
1199                            if datum_str == copy_csv_params.null {
1200                                continue;
1201                            }
1202                        }
1203                    }
1204
1205                    let updated_datum = match datum {
1206                        // TODO: Fix roundtrip decoding of these types.
1207                        Datum::Timestamp(_) | Datum::TimestampTz(_) | Datum::Null => {
1208                            continue;
1209                        }
1210                        Datum::String(s) => {
1211                            // TODO: The decoder cannot differentiate between empty string and null.
1212                            if s.trim() == copy_csv_params.null || s.trim().is_empty() {
1213                                continue;
1214                            } else {
1215                                Datum::String(s)
1216                            }
1217                        }
1218                        other => other,
1219                    };
1220
1221                    let result = try_roundtrip_datum(scalar_type, updated_datum);
1222                    prop_assert!(result.is_ok(), "failure: {result:?}");
1223                }
1224            }
1225        }
1226    }
1227}