mz_pgcopy/
copy.rs

1// Copyright Materialize, Inc. and contributors. All rights reserved.
2//
3// Use of this software is governed by the Business Source License
4// included in the LICENSE file.
5//
6// As of the Change Date specified in that file, in accordance with
7// the Business Source License, use of this software will be governed
8// by the Apache License, Version 2.0.
9
10use std::borrow::Cow;
11use std::io;
12
13use bytes::BytesMut;
14use csv::{ByteRecord, ReaderBuilder};
15use mz_proto::{ProtoType, RustType, TryFromProtoError};
16use mz_repr::{
17    ColumnType, Datum, RelationDesc, RelationType, Row, RowArena, RowRef, ScalarType, SharedRow,
18};
19use proptest::prelude::{Arbitrary, Just, any};
20use proptest::strategy::{BoxedStrategy, Strategy, Union};
21use serde::Deserialize;
22use serde::Serialize;
23
/// The end-of-copy marker: a backslash followed by a period (`\.`).
static END_OF_COPY_MARKER: &[u8] = b"\\.";

// Protobuf message types used below (e.g. `ProtoCopyFormatParams`), presumably
// generated into `OUT_DIR` by this crate's build script.
include!(concat!(env!("OUT_DIR"), "/mz_pgcopy.copy.rs"));
27
28fn encode_copy_row_binary(
29    row: &RowRef,
30    typ: &RelationType,
31    out: &mut Vec<u8>,
32) -> Result<(), io::Error> {
33    const NULL_BYTES: [u8; 4] = (-1i32).to_be_bytes();
34
35    // 16-bit int of number of tuples.
36    let count = i16::try_from(typ.column_types.len()).map_err(|_| {
37        io::Error::new(
38            io::ErrorKind::Other,
39            "column count does not fit into an i16",
40        )
41    })?;
42
43    out.extend(count.to_be_bytes());
44    let mut buf = BytesMut::new();
45    for (field, typ) in row
46        .iter()
47        .zip(&typ.column_types)
48        .map(|(datum, typ)| (mz_pgrepr::Value::from_datum(datum, &typ.scalar_type), typ))
49    {
50        match field {
51            None => out.extend(NULL_BYTES),
52            Some(field) => {
53                buf.clear();
54                field.encode_binary(&mz_pgrepr::Type::from(&typ.scalar_type), &mut buf)?;
55                out.extend(
56                    i32::try_from(buf.len())
57                        .map_err(|_| {
58                            io::Error::new(
59                                io::ErrorKind::Other,
60                                "field length does not fit into an i32",
61                            )
62                        })?
63                        .to_be_bytes(),
64                );
65                out.extend(&buf);
66            }
67        }
68    }
69    Ok(())
70}
71
72fn encode_copy_row_text(
73    CopyTextFormatParams { null, delimiter }: &CopyTextFormatParams,
74    row: &RowRef,
75    typ: &RelationType,
76    out: &mut Vec<u8>,
77) -> Result<(), io::Error> {
78    let null = null.as_bytes();
79    let mut buf = BytesMut::new();
80    for (idx, field) in mz_pgrepr::values_from_row(row, typ).into_iter().enumerate() {
81        if idx > 0 {
82            out.push(*delimiter);
83        }
84        match field {
85            None => out.extend(null),
86            Some(field) => {
87                buf.clear();
88                field.encode_text(&mut buf);
89                for b in &buf {
90                    match b {
91                        b'\\' => out.extend(b"\\\\"),
92                        b'\n' => out.extend(b"\\n"),
93                        b'\r' => out.extend(b"\\r"),
94                        b'\t' => out.extend(b"\\t"),
95                        _ => out.push(*b),
96                    }
97                }
98            }
99        }
100    }
101    out.push(b'\n');
102    Ok(())
103}
104
105fn encode_copy_row_csv(
106    CopyCsvFormatParams {
107        delimiter: delim,
108        quote,
109        escape,
110        header: _,
111        null,
112    }: &CopyCsvFormatParams,
113    row: &RowRef,
114    typ: &RelationType,
115    out: &mut Vec<u8>,
116) -> Result<(), io::Error> {
117    let null = null.as_bytes();
118    let is_special = |c: &u8| *c == *delim || *c == *quote || *c == b'\r' || *c == b'\n';
119    let mut buf = BytesMut::new();
120    for (idx, field) in mz_pgrepr::values_from_row(row, typ).into_iter().enumerate() {
121        if idx > 0 {
122            out.push(*delim);
123        }
124        match field {
125            None => out.extend(null),
126            Some(field) => {
127                buf.clear();
128                field.encode_text(&mut buf);
129                // A field needs quoting if:
130                //   * It is the only field and the value is exactly the end
131                //     of copy marker.
132                //   * The field contains a special character.
133                //   * The field is exactly the NULL sentinel.
134                if (typ.column_types.len() == 1 && buf == END_OF_COPY_MARKER)
135                    || buf.iter().any(is_special)
136                    || &*buf == null
137                {
138                    // Quote the value by wrapping it in the quote character and
139                    // emitting the escape character before any quote or escape
140                    // characters within.
141                    out.push(*quote);
142                    for b in &buf {
143                        if *b == *quote || *b == *escape {
144                            out.push(*escape);
145                        }
146                        out.push(*b);
147                    }
148                    out.push(*quote);
149                } else {
150                    // The value does not need quoting and can be emitted
151                    // directly.
152                    out.extend(&buf);
153                }
154            }
155        }
156    }
157    out.push(b'\n');
158    Ok(())
159}
160
161pub struct CopyTextFormatParser<'a> {
162    data: &'a [u8],
163    position: usize,
164    column_delimiter: u8,
165    null_string: &'a str,
166    buffer: Vec<u8>,
167}
168
169impl<'a> CopyTextFormatParser<'a> {
170    pub fn new(data: &'a [u8], column_delimiter: u8, null_string: &'a str) -> Self {
171        Self {
172            data,
173            position: 0,
174            column_delimiter,
175            null_string,
176            buffer: Vec::new(),
177        }
178    }
179
180    fn peek(&self) -> Option<u8> {
181        if self.position < self.data.len() {
182            Some(self.data[self.position])
183        } else {
184            None
185        }
186    }
187
188    fn consume_n(&mut self, n: usize) {
189        self.position = std::cmp::min(self.position + n, self.data.len());
190    }
191
192    pub fn is_eof(&self) -> bool {
193        self.peek().is_none() || self.is_end_of_copy_marker()
194    }
195
196    pub fn is_end_of_copy_marker(&self) -> bool {
197        self.check_bytes(END_OF_COPY_MARKER)
198    }
199
200    fn is_end_of_line(&self) -> bool {
201        match self.peek() {
202            Some(b'\n') | None => true,
203            _ => false,
204        }
205    }
206
207    pub fn expect_end_of_line(&mut self) -> Result<(), io::Error> {
208        if self.is_end_of_line() {
209            self.consume_n(1);
210            Ok(())
211        } else {
212            Err(io::Error::new(
213                io::ErrorKind::InvalidData,
214                "extra data after last expected column",
215            ))
216        }
217    }
218
219    fn is_column_delimiter(&self) -> bool {
220        self.check_bytes(&[self.column_delimiter])
221    }
222
223    pub fn expect_column_delimiter(&mut self) -> Result<(), io::Error> {
224        if self.consume_bytes(&[self.column_delimiter]) {
225            Ok(())
226        } else {
227            Err(io::Error::new(
228                io::ErrorKind::InvalidData,
229                "missing data for column",
230            ))
231        }
232    }
233
234    fn check_bytes(&self, bytes: &[u8]) -> bool {
235        let remaining_bytes = self.data.len() - self.position;
236        remaining_bytes >= bytes.len()
237            && self.data[self.position..]
238                .iter()
239                .zip(bytes.iter())
240                .all(|(x, y)| x == y)
241    }
242
243    fn consume_bytes(&mut self, bytes: &[u8]) -> bool {
244        if self.check_bytes(bytes) {
245            self.consume_n(bytes.len());
246            true
247        } else {
248            false
249        }
250    }
251
252    fn consume_null_string(&mut self) -> bool {
253        if self.null_string.is_empty() {
254            // An empty NULL marker is supported. Look ahead to ensure that is followed by
255            // a column delimiter, an end of line or it is at the end of the data.
256            self.is_column_delimiter()
257                || self.is_end_of_line()
258                || self.is_end_of_copy_marker()
259                || self.is_eof()
260        } else {
261            self.consume_bytes(self.null_string.as_bytes())
262        }
263    }
264
265    pub fn consume_raw_value(&mut self) -> Result<Option<&[u8]>, io::Error> {
266        if self.consume_null_string() {
267            return Ok(None);
268        }
269
270        let mut start = self.position;
271
272        // buffer where unescaped data is accumulated
273        self.buffer.clear();
274
275        while !self.is_eof() && !self.is_end_of_copy_marker() {
276            if self.is_end_of_line() || self.is_column_delimiter() {
277                break;
278            }
279            match self.peek() {
280                Some(b'\\') => {
281                    // Add non-escaped data parsed so far
282                    self.buffer.extend(&self.data[start..self.position]);
283
284                    self.consume_n(1);
285                    match self.peek() {
286                        Some(b'b') => {
287                            self.consume_n(1);
288                            self.buffer.push(8);
289                        }
290                        Some(b'f') => {
291                            self.consume_n(1);
292                            self.buffer.push(12);
293                        }
294                        Some(b'n') => {
295                            self.consume_n(1);
296                            self.buffer.push(b'\n');
297                        }
298                        Some(b'r') => {
299                            self.consume_n(1);
300                            self.buffer.push(b'\r');
301                        }
302                        Some(b't') => {
303                            self.consume_n(1);
304                            self.buffer.push(b'\t');
305                        }
306                        Some(b'v') => {
307                            self.consume_n(1);
308                            self.buffer.push(11);
309                        }
310                        Some(b'x') => {
311                            self.consume_n(1);
312                            match self.peek() {
313                                Some(_c @ b'0'..=b'9')
314                                | Some(_c @ b'A'..=b'F')
315                                | Some(_c @ b'a'..=b'f') => {
316                                    let mut value: u8 = 0;
317                                    let decode_nibble = |b| match b {
318                                        Some(c @ b'a'..=b'f') => Some(c - b'a' + 10),
319                                        Some(c @ b'A'..=b'F') => Some(c - b'A' + 10),
320                                        Some(c @ b'0'..=b'9') => Some(c - b'0'),
321                                        _ => None,
322                                    };
323                                    for _ in 0..2 {
324                                        match decode_nibble(self.peek()) {
325                                            Some(c) => {
326                                                self.consume_n(1);
327                                                value = value << 4 | c;
328                                            }
329                                            _ => break,
330                                        }
331                                    }
332                                    self.buffer.push(value);
333                                }
334                                _ => {
335                                    self.buffer.push(b'x');
336                                }
337                            }
338                        }
339                        Some(_c @ b'0'..=b'7') => {
340                            let mut value: u8 = 0;
341                            for _ in 0..3 {
342                                match self.peek() {
343                                    Some(c @ b'0'..=b'7') => {
344                                        self.consume_n(1);
345                                        value = value << 3 | (c - b'0');
346                                    }
347                                    _ => break,
348                                }
349                            }
350                            self.buffer.push(value);
351                        }
352                        Some(c) => {
353                            self.consume_n(1);
354                            self.buffer.push(c);
355                        }
356                        None => {
357                            self.buffer.push(b'\\');
358                        }
359                    }
360
361                    start = self.position;
362                }
363                Some(_) => {
364                    self.consume_n(1);
365                }
366                None => {}
367            }
368        }
369
370        // Return a slice of the original buffer if no escaped characters where processed
371        if self.buffer.is_empty() {
372            Ok(Some(&self.data[start..self.position]))
373        } else {
374            // ... otherwise, add the remaining non-escaped data to the decoding buffer
375            // and return a pointer to it
376            self.buffer.extend(&self.data[start..self.position]);
377            Ok(Some(&self.buffer[..]))
378        }
379    }
380
381    /// Error if more than `num_columns` values in `parser`.
382    pub fn iter_raw(self, num_columns: usize) -> RawIterator<'a> {
383        RawIterator {
384            parser: self,
385            current_column: 0,
386            num_columns,
387            truncate: false,
388        }
389    }
390
391    /// Return no more than `num_columns` values from `parser`.
392    pub fn iter_raw_truncating(self, num_columns: usize) -> RawIterator<'a> {
393        RawIterator {
394            parser: self,
395            current_column: 0,
396            num_columns,
397            truncate: true,
398        }
399    }
400}
401
402pub struct RawIterator<'a> {
403    parser: CopyTextFormatParser<'a>,
404    current_column: usize,
405    num_columns: usize,
406    truncate: bool,
407}
408
409impl<'a> RawIterator<'a> {
410    pub fn next(&mut self) -> Option<Result<Option<&[u8]>, io::Error>> {
411        if self.current_column > self.num_columns {
412            return None;
413        }
414
415        if self.current_column == self.num_columns {
416            if !self.truncate {
417                if let Some(err) = self.parser.expect_end_of_line().err() {
418                    return Some(Err(err));
419                }
420            }
421
422            return None;
423        }
424
425        if self.current_column > 0 {
426            if let Some(err) = self.parser.expect_column_delimiter().err() {
427                return Some(Err(err));
428            }
429        }
430
431        self.current_column += 1;
432        Some(self.parser.consume_raw_value())
433    }
434}
435
/// Data format, with its format-specific options, used to encode and decode
/// `COPY` data.
#[derive(Debug, Eq, PartialEq, Clone, Serialize, Deserialize)]
pub enum CopyFormatParams<'a> {
    /// PostgreSQL text format.
    Text(CopyTextFormatParams<'a>),
    /// CSV format.
    Csv(CopyCsvFormatParams<'a>),
    /// PostgreSQL binary format. Only encoding is supported in this module.
    Binary,
    /// Parquet format. The encode/decode functions in this module return
    /// `Unsupported` for it.
    Parquet,
}
443
444impl RustType<ProtoCopyFormatParams> for CopyFormatParams<'static> {
445    fn into_proto(&self) -> ProtoCopyFormatParams {
446        use proto_copy_format_params::Kind;
447        ProtoCopyFormatParams {
448            kind: Some(match self {
449                Self::Text(f) => Kind::Text(f.into_proto()),
450                Self::Csv(f) => Kind::Csv(f.into_proto()),
451                Self::Binary => Kind::Binary(()),
452                Self::Parquet => Kind::Parquet(ProtoCopyParquetFormatParams::default()),
453            }),
454        }
455    }
456
457    fn from_proto(proto: ProtoCopyFormatParams) -> Result<Self, TryFromProtoError> {
458        use proto_copy_format_params::Kind;
459        match proto.kind {
460            Some(Kind::Text(f)) => Ok(Self::Text(f.into_rust()?)),
461            Some(Kind::Csv(f)) => Ok(Self::Csv(f.into_rust()?)),
462            Some(Kind::Binary(())) => Ok(Self::Binary),
463            Some(Kind::Parquet(ProtoCopyParquetFormatParams {})) => Ok(Self::Parquet),
464            None => Err(TryFromProtoError::missing_field(
465                "ProtoCopyFormatParams::kind",
466            )),
467        }
468    }
469}
470
471impl Arbitrary for CopyFormatParams<'static> {
472    type Parameters = ();
473    type Strategy = Union<BoxedStrategy<Self>>;
474
475    fn arbitrary_with(_args: Self::Parameters) -> Self::Strategy {
476        Union::new(vec![
477            any::<CopyTextFormatParams>().prop_map(Self::Text).boxed(),
478            any::<CopyCsvFormatParams>().prop_map(Self::Csv).boxed(),
479            Just(Self::Binary).boxed(),
480        ])
481    }
482}
483
484impl CopyFormatParams<'static> {
485    pub fn file_extension(&self) -> &str {
486        match self {
487            &CopyFormatParams::Text(_) => "txt",
488            &CopyFormatParams::Csv(_) => "csv",
489            &CopyFormatParams::Binary => "bin",
490            &CopyFormatParams::Parquet => "parquet",
491        }
492    }
493
494    pub fn requires_header(&self) -> bool {
495        match self {
496            CopyFormatParams::Text(_) => false,
497            CopyFormatParams::Csv(params) => params.header,
498            CopyFormatParams::Binary => false,
499            CopyFormatParams::Parquet => false,
500        }
501    }
502}
503
504/// Decodes the given bytes into `Row`-s based on the given `CopyFormatParams`.
505pub fn decode_copy_format<'a>(
506    data: &[u8],
507    column_types: &[mz_pgrepr::Type],
508    params: CopyFormatParams<'a>,
509) -> Result<Vec<Row>, io::Error> {
510    match params {
511        CopyFormatParams::Text(params) => decode_copy_format_text(data, column_types, params),
512        CopyFormatParams::Csv(params) => decode_copy_format_csv(data, column_types, params),
513        CopyFormatParams::Binary => Err(io::Error::new(
514            io::ErrorKind::Unsupported,
515            "cannot decode as binary format",
516        )),
517        CopyFormatParams::Parquet => {
518            // TODO(cf2): Support Parquet over STDIN.
519            Err(io::Error::new(io::ErrorKind::Unsupported, "parquet format"))
520        }
521    }
522}
523
524/// Encodes the given `Row` into bytes based on the given `CopyFormatParams`.
525pub fn encode_copy_format<'a>(
526    params: &CopyFormatParams<'a>,
527    row: &RowRef,
528    typ: &RelationType,
529    out: &mut Vec<u8>,
530) -> Result<(), io::Error> {
531    match params {
532        CopyFormatParams::Text(params) => encode_copy_row_text(params, row, typ, out),
533        CopyFormatParams::Csv(params) => encode_copy_row_csv(params, row, typ, out),
534        CopyFormatParams::Binary => encode_copy_row_binary(row, typ, out),
535        CopyFormatParams::Parquet => {
536            // TODO(cf2): Support Parquet over STDIN.
537            Err(io::Error::new(io::ErrorKind::Unsupported, "parquet format"))
538        }
539    }
540}
541
542pub fn encode_copy_format_header<'a>(
543    params: &CopyFormatParams<'a>,
544    desc: &RelationDesc,
545    out: &mut Vec<u8>,
546) -> Result<(), io::Error> {
547    match params {
548        CopyFormatParams::Text(_) => Ok(()),
549        CopyFormatParams::Binary => Ok(()),
550        CopyFormatParams::Csv(params) => {
551            let mut header_row = Row::with_capacity(desc.arity());
552            header_row
553                .packer()
554                .extend(desc.iter_names().map(|s| Datum::from(s.as_str())));
555            let typ = RelationType::new(vec![
556                ColumnType {
557                    scalar_type: ScalarType::String,
558                    nullable: false,
559                };
560                desc.arity()
561            ]);
562            encode_copy_row_csv(params, &header_row, &typ, out)
563        }
564        CopyFormatParams::Parquet => {
565            // TODO(cf2): Support Parquet over STDIN.
566            Err(io::Error::new(io::ErrorKind::Unsupported, "parquet format"))
567        }
568    }
569}
570
/// Options for the text `COPY` format.
#[derive(Debug, Eq, PartialEq, Clone, Serialize, Deserialize)]
pub struct CopyTextFormatParams<'a> {
    /// String denoting a NULL value (default `\N`).
    pub null: Cow<'a, str>,
    /// Byte separating columns (default tab).
    pub delimiter: u8,
}
576
577impl<'a> Default for CopyTextFormatParams<'a> {
578    fn default() -> Self {
579        CopyTextFormatParams {
580            delimiter: b'\t',
581            null: Cow::from("\\N"),
582        }
583    }
584}
585
586impl RustType<ProtoCopyTextFormatParams> for CopyTextFormatParams<'static> {
587    fn into_proto(&self) -> ProtoCopyTextFormatParams {
588        ProtoCopyTextFormatParams {
589            null: self.null.into_proto(),
590            delimiter: self.delimiter.into_proto(),
591        }
592    }
593
594    fn from_proto(proto: ProtoCopyTextFormatParams) -> Result<Self, TryFromProtoError> {
595        Ok(Self {
596            null: Cow::Owned(proto.null.into_rust()?),
597            delimiter: proto.delimiter.into_rust()?,
598        })
599    }
600}
601
602impl Arbitrary for CopyTextFormatParams<'static> {
603    type Parameters = ();
604    type Strategy = BoxedStrategy<Self>;
605
606    fn arbitrary_with(_args: Self::Parameters) -> Self::Strategy {
607        (any::<String>(), any::<u8>())
608            .prop_map(|(null, delimiter)| Self {
609                null: Cow::Owned(null),
610                delimiter,
611            })
612            .boxed()
613    }
614}
615
616pub fn decode_copy_format_text(
617    data: &[u8],
618    column_types: &[mz_pgrepr::Type],
619    CopyTextFormatParams { null, delimiter }: CopyTextFormatParams,
620) -> Result<Vec<Row>, io::Error> {
621    let mut rows = Vec::new();
622
623    // TODO: pass the `CopyTextFormatParams` to the `new` method
624    let mut parser = CopyTextFormatParser::new(data, delimiter, &null);
625    while !parser.is_eof() && !parser.is_end_of_copy_marker() {
626        let mut row = Vec::new();
627        let buf = RowArena::new();
628        for (col, typ) in column_types.iter().enumerate() {
629            if col > 0 {
630                parser.expect_column_delimiter()?;
631            }
632            let raw_value = parser.consume_raw_value()?;
633            if let Some(raw_value) = raw_value {
634                match mz_pgrepr::Value::decode_text(typ, raw_value) {
635                    Ok(value) => row.push(value.into_datum(&buf, typ)),
636                    Err(err) => {
637                        let msg = format!("unable to decode column: {}", err);
638                        return Err(io::Error::new(io::ErrorKind::InvalidData, msg));
639                    }
640                }
641            } else {
642                row.push(Datum::Null);
643            }
644        }
645        parser.expect_end_of_line()?;
646        rows.push(Row::pack(row));
647    }
648    // Note that if there is any junk data after the end of copy marker, we drop
649    // it on the floor as PG does.
650    Ok(rows)
651}
652
/// Options for the CSV `COPY` format.
#[derive(Debug, Eq, PartialEq, Clone, Serialize, Deserialize)]
pub struct CopyCsvFormatParams<'a> {
    /// Byte separating fields (default `,`).
    pub delimiter: u8,
    /// Byte used to quote fields (default `"`). `try_new` rejects a quote
    /// equal to `delimiter`.
    pub quote: u8,
    /// Byte written before quote and escape bytes inside quoted fields
    /// (default `"`; when equal to `quote`, quotes are doubled).
    pub escape: u8,
    /// Whether the data starts with a header line of column names.
    pub header: bool,
    /// String denoting a NULL value (default: the empty string).
    pub null: Cow<'a, str>,
}
661
662impl<'a> CopyCsvFormatParams<'a> {
663    pub fn to_owned(&self) -> CopyCsvFormatParams<'static> {
664        CopyCsvFormatParams {
665            delimiter: self.delimiter,
666            quote: self.quote,
667            escape: self.escape,
668            header: self.header,
669            null: Cow::Owned(self.null.to_string()),
670        }
671    }
672}
673
674impl RustType<ProtoCopyCsvFormatParams> for CopyCsvFormatParams<'static> {
675    fn into_proto(&self) -> ProtoCopyCsvFormatParams {
676        ProtoCopyCsvFormatParams {
677            delimiter: self.delimiter.into(),
678            quote: self.quote.into(),
679            escape: self.escape.into(),
680            header: self.header,
681            null: self.null.into_proto(),
682        }
683    }
684
685    fn from_proto(proto: ProtoCopyCsvFormatParams) -> Result<Self, TryFromProtoError> {
686        Ok(Self {
687            delimiter: proto.delimiter.into_rust()?,
688            quote: proto.quote.into_rust()?,
689            escape: proto.escape.into_rust()?,
690            header: proto.header,
691            null: Cow::Owned(proto.null.into_rust()?),
692        })
693    }
694}
695
696impl Arbitrary for CopyCsvFormatParams<'static> {
697    type Parameters = ();
698    type Strategy = BoxedStrategy<Self>;
699
700    fn arbitrary_with(_args: Self::Parameters) -> Self::Strategy {
701        (
702            any::<u8>(),
703            any::<u8>(),
704            any::<u8>(),
705            any::<bool>(),
706            any::<String>(),
707        )
708            .prop_map(|(delimiter, diff, escape, header, null)| {
709                // Delimiter and Quote need to be different.
710                let diff = diff.saturating_sub(1).max(1);
711                let quote = delimiter.wrapping_add(diff);
712
713                Self::try_new(
714                    Some(delimiter),
715                    Some(quote),
716                    Some(escape),
717                    Some(header),
718                    Some(null),
719                )
720                .expect("delimiter and quote should be different")
721            })
722            .boxed()
723    }
724}
725
726impl<'a> Default for CopyCsvFormatParams<'a> {
727    fn default() -> Self {
728        CopyCsvFormatParams {
729            delimiter: b',',
730            quote: b'"',
731            escape: b'"',
732            header: false,
733            null: Cow::from(""),
734        }
735    }
736}
737
738impl<'a> CopyCsvFormatParams<'a> {
739    pub fn try_new(
740        delimiter: Option<u8>,
741        quote: Option<u8>,
742        escape: Option<u8>,
743        header: Option<bool>,
744        null: Option<String>,
745    ) -> Result<CopyCsvFormatParams<'a>, String> {
746        let mut params = CopyCsvFormatParams::default();
747
748        if let Some(delimiter) = delimiter {
749            params.delimiter = delimiter;
750        }
751        if let Some(quote) = quote {
752            params.quote = quote;
753            // escape defaults to the value provided for quote
754            params.escape = quote;
755        }
756        if let Some(escape) = escape {
757            params.escape = escape;
758        }
759        if let Some(header) = header {
760            params.header = header;
761        }
762        if let Some(null) = null {
763            params.null = Cow::from(null);
764        }
765
766        if params.quote == params.delimiter {
767            return Err("COPY delimiter and quote must be different".to_string());
768        }
769        Ok(params)
770    }
771}
772
773pub fn decode_copy_format_csv(
774    data: &[u8],
775    column_types: &[mz_pgrepr::Type],
776    CopyCsvFormatParams {
777        delimiter,
778        quote,
779        escape,
780        null,
781        header,
782    }: CopyCsvFormatParams,
783) -> Result<Vec<Row>, io::Error> {
784    let mut rows = Vec::new();
785
786    let (double_quote, escape) = if quote == escape {
787        (true, None)
788    } else {
789        (false, Some(escape))
790    };
791
792    let mut rdr = ReaderBuilder::new()
793        .delimiter(delimiter)
794        .quote(quote)
795        .has_headers(header)
796        .double_quote(double_quote)
797        .escape(escape)
798        // Must be flexible to accept end of copy marker, which will always be 1
799        // field.
800        .flexible(true)
801        .from_reader(data);
802
803    let null_as_bytes = null.as_bytes();
804
805    let mut record = ByteRecord::new();
806
807    while rdr.read_byte_record(&mut record)? {
808        if record.len() == 1 && record.iter().next() == Some(END_OF_COPY_MARKER) {
809            break;
810        }
811
812        match record.len().cmp(&column_types.len()) {
813            std::cmp::Ordering::Less => Err(io::Error::new(
814                io::ErrorKind::InvalidData,
815                "missing data for column",
816            )),
817            std::cmp::Ordering::Greater => Err(io::Error::new(
818                io::ErrorKind::InvalidData,
819                "extra data after last expected column",
820            )),
821            std::cmp::Ordering::Equal => Ok(()),
822        }?;
823
824        let binding = SharedRow::get();
825        let mut row_builder = binding.borrow_mut();
826        let mut row_packer = row_builder.packer();
827
828        for (typ, raw_value) in column_types.iter().zip(record.iter()) {
829            if raw_value == null_as_bytes {
830                row_packer.push(Datum::Null);
831            } else {
832                let s = match std::str::from_utf8(raw_value) {
833                    Ok(s) => s,
834                    Err(err) => {
835                        let msg = format!("invalid utf8 data in column: {}", err);
836                        return Err(io::Error::new(io::ErrorKind::InvalidData, msg));
837                    }
838                };
839                match mz_pgrepr::Value::decode_text_into_row(typ, s, &mut row_packer) {
840                    Ok(()) => {}
841                    Err(err) => {
842                        let msg = format!("unable to decode column: {}", err);
843                        return Err(io::Error::new(io::ErrorKind::InvalidData, msg));
844                    }
845                }
846            }
847        }
848        rows.push(row_builder.clone());
849    }
850
851    Ok(rows)
852}
853
#[cfg(test)]
mod tests {
    use mz_ore::collections::CollectionExt;
    use mz_repr::{ColumnType, ScalarType};
    use proptest::prelude::*;

    use super::*;

    // Walks a `CopyTextFormatParser` (tab delimiter, `\N` null marker) over a
    // hand-written input, checking column delimiters, end-of-line handling,
    // null detection, and decoding of `\n`, hex (`\xHH`) and octal (`\NNN`)
    // escapes.
    #[mz_ore::test]
    fn test_copy_format_text_parser() {
        let text = "\t\\nt e\t\\N\t\n\\x60\\xA\\x7D\\x4a\n\\44\\044\\123".as_bytes();
        let mut parser = CopyTextFormatParser::new(text, b'\t', "\\N");
        assert!(parser.is_column_delimiter());
        parser
            .expect_column_delimiter()
            .expect("expected column delimiter");
        assert_eq!(
            parser
                .consume_raw_value()
                .expect("unexpected error")
                .expect("unexpected empty result"),
            "\nt e".as_bytes()
        );
        parser
            .expect_column_delimiter()
            .expect("expected column delimiter");
        // null value
        assert!(
            parser
                .consume_raw_value()
                .expect("unexpected error")
                .is_none()
        );
        parser
            .expect_column_delimiter()
            .expect("expected column delimiter");
        assert!(parser.is_end_of_line());
        parser.expect_end_of_line().expect("expected eol");
        // hex value
        assert_eq!(
            parser
                .consume_raw_value()
                .expect("unexpected error")
                .expect("unexpected empty result"),
            "`\n}J".as_bytes()
        );
        parser.expect_end_of_line().expect("expected eol");
        // octal value
        assert_eq!(
            parser
                .consume_raw_value()
                .expect("unexpected error")
                .expect("unexpected empty result"),
            "$$S".as_bytes()
        );
        assert!(parser.is_eof());
    }

    // With an empty null string, an empty field must be reported as NULL by
    // `consume_null_string`, while non-empty fields are read as raw values.
    #[mz_ore::test]
    fn test_copy_format_text_empty_null_string() {
        let text = "\t\n10\t20\n30\t\n40\t".as_bytes();
        let expect = vec![
            vec![None, None],
            vec![Some("10"), Some("20")],
            vec![Some("30"), None],
            vec![Some("40"), None],
        ];
        let mut parser = CopyTextFormatParser::new(text, b'\t', "");
        for line in expect {
            for (i, value) in line.iter().enumerate() {
                if i > 0 {
                    parser
                        .expect_column_delimiter()
                        .expect("expected column delimiter");
                }
                match value {
                    Some(s) => {
                        assert!(!parser.consume_null_string());
                        assert_eq!(
                            parser
                                .consume_raw_value()
                                .expect("unexpected error")
                                .expect("unexpected empty result"),
                            s.as_bytes()
                        );
                    }
                    None => {
                        assert!(parser.consume_null_string());
                    }
                }
            }
            parser.expect_end_of_line().expect("expected eol");
        }
    }

    // Table-driven check of backslash-escape decoding for a single raw value.
    // This includes malformed escapes (`\x` with no hex digit, `\8`, `\a`,
    // and a trailing lone `\`), which the expectations below keep as
    // literal characters.
    #[mz_ore::test]
    fn test_copy_format_text_parser_escapes() {
        struct TestCase {
            input: &'static str,
            expect: &'static [u8],
        }
        let tests = vec![
            TestCase {
                input: "simple",
                expect: b"simple",
            },
            TestCase {
                input: r#"new\nline"#,
                expect: b"new\nline",
            },
            TestCase {
                input: r#"\b\f\n\r\t\v\\"#,
                expect: b"\x08\x0c\n\r\t\x0b\\",
            },
            TestCase {
                input: r#"\0\12\123"#,
                expect: &[0, 0o12, 0o123],
            },
            TestCase {
                input: r#"\x1\xaf"#,
                expect: &[0x01, 0xaf],
            },
            TestCase {
                input: r#"T\n\07\xEV\x0fA\xb2C\1"#,
                expect: b"T\n\x07\x0eV\x0fA\xb2C\x01",
            },
            TestCase {
                input: r#"\\\""#,
                expect: b"\\\"",
            },
            TestCase {
                input: r#"\x"#,
                expect: b"x",
            },
            TestCase {
                input: r#"\xg"#,
                expect: b"xg",
            },
            TestCase {
                input: r#"\"#,
                expect: b"\\",
            },
            TestCase {
                input: r#"\8"#,
                expect: b"8",
            },
            TestCase {
                input: r#"\a"#,
                expect: b"a",
            },
            TestCase {
                input: r#"\x\xg\8\xH\x32\s\"#,
                expect: b"xxg8xH2s\\",
            },
        ];

        for test in tests {
            let mut parser = CopyTextFormatParser::new(test.input.as_bytes(), b'\t', "\\N");
            assert_eq!(
                parser
                    .consume_raw_value()
                    .expect("unexpected error")
                    .expect("unexpected empty result"),
                test.expect,
                "input: {}, expect: {:?}",
                test.input,
                std::str::from_utf8(test.expect),
            );
            assert!(parser.is_eof());
        }
    }

    // Checks how `CopyCsvFormatParams::try_new` fills in defaults. An unset
    // escape falls back to the quote character, an unset null becomes "", and
    // an unset header becomes `false`. It also checks that a quote equal to
    // the (default) delimiter is rejected.
    #[mz_ore::test]
    fn test_copy_csv_format_params() {
        assert_eq!(
            CopyCsvFormatParams::try_new(Some(b't'), Some(b'q'), None, None, None),
            Ok(CopyCsvFormatParams {
                delimiter: b't',
                quote: b'q',
                escape: b'q',
                header: false,
                null: Cow::from(""),
            })
        );

        assert_eq!(
            CopyCsvFormatParams::try_new(
                Some(b't'),
                Some(b'q'),
                Some(b'e'),
                Some(true),
                Some("null".to_string())
            ),
            Ok(CopyCsvFormatParams {
                delimiter: b't',
                quote: b'q',
                escape: b'e',
                header: true,
                null: Cow::from("null"),
            })
        );

        assert_eq!(
            CopyCsvFormatParams::try_new(
                None,
                Some(b','),
                Some(b'e'),
                Some(true),
                Some("null".to_string())
            ),
            Err("COPY delimiter and quote must be different".to_string())
        );
    }

    // Encodes one row through `encode_copy_format` in CSV mode and compares
    // the exact output bytes. It runs once with default params and once with
    // a custom quote, escape and null string.
    #[mz_ore::test]
    fn test_copy_csv_row() -> Result<(), io::Error> {
        let mut row = Row::default();
        let mut packer = row.packer();
        packer.push(Datum::from("1,2,\"3\""));
        packer.push(Datum::Null);
        packer.push(Datum::from(1000u64));
        packer.push(Datum::from("qe")); // overridden quote and escape character in test below
        packer.push(Datum::from(""));

        let typ: RelationType = RelationType::new(vec![
            ColumnType {
                scalar_type: mz_repr::ScalarType::String,
                nullable: false,
            },
            ColumnType {
                scalar_type: mz_repr::ScalarType::String,
                nullable: true,
            },
            ColumnType {
                scalar_type: mz_repr::ScalarType::UInt64,
                nullable: false,
            },
            ColumnType {
                scalar_type: mz_repr::ScalarType::String,
                nullable: false,
            },
            ColumnType {
                scalar_type: mz_repr::ScalarType::String,
                nullable: false,
            },
        ]);

        let mut out = Vec::new();

        struct TestCase<'a> {
            params: CopyCsvFormatParams<'a>,
            expected: &'static [u8],
        }

        let tests = [
            TestCase {
                params: CopyCsvFormatParams::default(),
                expected: b"\"1,2,\"\"3\"\"\",,1000,qe,\"\"\n",
            },
            TestCase {
                params: CopyCsvFormatParams {
                    null: Cow::from("NULL"),
                    quote: b'q',
                    escape: b'e',
                    ..Default::default()
                },
                expected: b"q1,2,\"3\"q,NULL,1000,qeqeeq,\n",
            },
        ];

        for TestCase { params, expected } in tests {
            out.clear();
            let params = CopyFormatParams::Csv(params);
            // Propagate encoding failures instead of discarding them. Otherwise
            // an error would only surface, if at all, as a confusing output
            // mismatch below.
            encode_copy_format(&params, &row, &typ, &mut out)?;
            let output = std::str::from_utf8(&out);
            assert_eq!(output, std::str::from_utf8(expected));
        }

        Ok(())
    }

    proptest! {
        // For arbitrary CSV params, every "interesting" datum of every scalar
        // type (minus the known-unsupported cases skipped below) must survive an
        // encode/decode roundtrip unchanged. Decoding failures are tolerated.
        #[mz_ore::test]
        #[cfg_attr(miri, ignore)]
        fn proptest_csv_roundtrips(copy_csv_params: CopyCsvFormatParams)  {
            // Given a ScalarType and a Datum, roundtrips the Datum through the CSV COPY format.
            let try_roundtrip_datum = |scalar_type: &ScalarType, datum| {
                let row = Row::pack_slice(&[datum]);
                let typ = RelationType::new(vec![
                    ColumnType {
                        scalar_type: scalar_type.clone(),
                        nullable: true,
                    }
                ]);

                let mut buf = Vec::new();
                let mut csv_params = copy_csv_params.clone();
                // TODO: Encoding never writes a header.
                csv_params.header = false;
                let params = CopyFormatParams::Csv(csv_params);

                // Roundtrip the Row through our CSV format.
                encode_copy_format(&params, &row, &typ, &mut buf)?;
                let column_types = typ
                    .column_types
                    .iter()
                    .map(|x| &x.scalar_type)
                    .map(mz_pgrepr::Type::from)
                    .collect::<Vec<mz_pgrepr::Type>>();
                let result = decode_copy_format(&buf, &column_types, params);

                match result {
                    Ok(rows) => {
                        let out_str = std::str::from_utf8(&buf[..]);

                        prop_assert_eq!(
                            rows.len(),
                            1,
                            "unexpected number of rows! {:?}, csv string: {:?}", rows, out_str
                        );
                        let output = rows.into_element();

                        prop_assert_eq!(
                            row,
                            output,
                            "csv string: {:?}, scalar_type: {:?}", out_str, scalar_type
                        );
                    }
                    _ => {
                        // ignoring decoding failures
                    }
                }

                Ok(())
            };

            // Try roundtripping all of our interesting Datums.
            for scalar_type in ScalarType::enumerate() {
                for datum in scalar_type.interesting_datums() {
                    // Skip datums whose text encoding equals the configured NULL
                    // string. The decoder cannot tell such a value apart from
                    // NULL, so it would not roundtrip.
                    if let Some(value) = mz_pgrepr::Value::from_datum(datum, scalar_type) {
                        let mut buf = bytes::BytesMut::new();
                        value.encode_text(&mut buf);

                        if let Ok(datum_str) = std::str::from_utf8(&buf[..]) {
                            if datum_str == copy_csv_params.null {
                                continue;
                            }
                        }
                    }

                    let updated_datum = match datum {
                        // TODO: Fix roundtrip decoding of these types.
                        Datum::Timestamp(_) | Datum::TimestampTz(_) | Datum::Null => {
                            continue;
                        }
                        Datum::String(s) => {
                            // TODO: The decoder cannot differentiate between empty string and null.
                            if s.trim() == copy_csv_params.null || s.trim().is_empty() {
                                continue;
                            } else {
                                Datum::String(s)
                            }
                        }
                        other => other,
                    };

                    let result = try_roundtrip_datum(scalar_type, updated_datum);
                    prop_assert!(result.is_ok(), "failure: {result:?}");
                }
            }
        }
    }
}