Skip to main content

mz_pgcopy/
copy.rs

1// Copyright Materialize, Inc. and contributors. All rights reserved.
2//
3// Use of this software is governed by the Business Source License
4// included in the LICENSE file.
5//
6// As of the Change Date specified in that file, in accordance with
7// the Business Source License, use of this software will be governed
8// by the Apache License, Version 2.0.
9
10use std::borrow::Cow;
11use std::io;
12
13use bytes::BytesMut;
14use itertools::Itertools;
15use mz_repr::{
16    Datum, RelationDesc, Row, RowArena, RowRef, SharedRow, SqlColumnType, SqlRelationType,
17    SqlScalarType,
18};
19use proptest::prelude::{Arbitrary, any};
20use proptest::strategy::{BoxedStrategy, Strategy};
21use serde::{Deserialize, Serialize};
22
/// The two-byte sequence `\.` that the decoders in this file treat as the end
/// of the COPY data, and that the CSV encoder quotes when it is the sole value
/// of a single-column row.
23static END_OF_COPY_MARKER: &[u8] = b"\\.";
24
25fn encode_copy_row_binary(
26    row: &RowRef,
27    typ: &SqlRelationType,
28    out: &mut Vec<u8>,
29) -> Result<(), io::Error> {
30    const NULL_BYTES: [u8; 4] = (-1i32).to_be_bytes();
31
32    // 16-bit int of number of tuples.
33    let count = i16::try_from(typ.column_types.len()).map_err(|_| {
34        io::Error::new(
35            io::ErrorKind::InvalidData,
36            "column count does not fit into an i16",
37        )
38    })?;
39
40    out.extend(count.to_be_bytes());
41    let mut buf = BytesMut::new();
42    for (field, typ) in row
43        .iter()
44        .zip_eq(&typ.column_types)
45        .map(|(datum, typ)| (mz_pgrepr::Value::from_datum(datum, &typ.scalar_type), typ))
46    {
47        match field {
48            None => out.extend(NULL_BYTES),
49            Some(field) => {
50                buf.clear();
51                field.encode_binary(&mz_pgrepr::Type::from(&typ.scalar_type), &mut buf)?;
52                out.extend(
53                    i32::try_from(buf.len())
54                        .map_err(|_| {
55                            io::Error::new(
56                                io::ErrorKind::InvalidData,
57                                "field length does not fit into an i32",
58                            )
59                        })?
60                        .to_be_bytes(),
61                );
62                out.extend(&buf);
63            }
64        }
65    }
66    Ok(())
67}
68
69fn encode_copy_row_text(
70    CopyTextFormatParams { null, delimiter }: &CopyTextFormatParams,
71    row: &RowRef,
72    typ: &SqlRelationType,
73    out: &mut Vec<u8>,
74) -> Result<(), io::Error> {
75    let null = null.as_bytes();
76    let mut buf = BytesMut::new();
77    for (idx, field) in mz_pgrepr::values_from_row(row, typ).into_iter().enumerate() {
78        if idx > 0 {
79            out.push(*delimiter);
80        }
81        match field {
82            None => out.extend(null),
83            Some(field) => {
84                buf.clear();
85                field.encode_text(&mut buf);
86                for b in &buf {
87                    match b {
88                        b'\\' => out.extend(b"\\\\"),
89                        b'\n' => out.extend(b"\\n"),
90                        b'\r' => out.extend(b"\\r"),
91                        b'\t' => out.extend(b"\\t"),
92                        _ => out.push(*b),
93                    }
94                }
95            }
96        }
97    }
98    out.push(b'\n');
99    Ok(())
100}
101
102fn encode_copy_row_csv(
103    CopyCsvFormatParams {
104        delimiter: delim,
105        quote,
106        escape,
107        header: _,
108        null,
109    }: &CopyCsvFormatParams,
110    row: &RowRef,
111    typ: &SqlRelationType,
112    out: &mut Vec<u8>,
113) -> Result<(), io::Error> {
114    let null = null.as_bytes();
115    let is_special = |c: &u8| *c == *delim || *c == *quote || *c == b'\r' || *c == b'\n';
116    let mut buf = BytesMut::new();
117    for (idx, field) in mz_pgrepr::values_from_row(row, typ).into_iter().enumerate() {
118        if idx > 0 {
119            out.push(*delim);
120        }
121        match field {
122            None => out.extend(null),
123            Some(field) => {
124                buf.clear();
125                field.encode_text(&mut buf);
126                // A field needs quoting if:
127                //   * It is the only field and the value is exactly the end
128                //     of copy marker.
129                //   * The field contains a special character.
130                //   * The field is exactly the NULL sentinel.
131                if (typ.column_types.len() == 1 && buf == END_OF_COPY_MARKER)
132                    || buf.iter().any(is_special)
133                    || &*buf == null
134                {
135                    // Quote the value by wrapping it in the quote character and
136                    // emitting the escape character before any quote or escape
137                    // characters within.
138                    out.push(*quote);
139                    for b in &buf {
140                        if *b == *quote || *b == *escape {
141                            out.push(*escape);
142                        }
143                        out.push(*b);
144                    }
145                    out.push(*quote);
146                } else {
147                    // The value does not need quoting and can be emitted
148                    // directly.
149                    out.extend(&buf);
150                }
151            }
152        }
153    }
154    out.push(b'\n');
155    Ok(())
156}
157
158pub struct CopyTextFormatParser<'a> {
159    data: &'a [u8],
160    position: usize,
161    column_delimiter: u8,
162    null_string: &'a str,
163    buffer: Vec<u8>,
164}
165
166impl<'a> CopyTextFormatParser<'a> {
167    pub fn new(data: &'a [u8], column_delimiter: u8, null_string: &'a str) -> Self {
168        Self {
169            data,
170            position: 0,
171            column_delimiter,
172            null_string,
173            buffer: Vec::new(),
174        }
175    }
176
177    fn peek(&self) -> Option<u8> {
178        if self.position < self.data.len() {
179            Some(self.data[self.position])
180        } else {
181            None
182        }
183    }
184
185    fn consume_n(&mut self, n: usize) {
186        self.position = std::cmp::min(self.position + n, self.data.len());
187    }
188
189    pub fn is_eof(&self) -> bool {
190        self.peek().is_none() || self.is_end_of_copy_marker()
191    }
192
193    pub fn is_end_of_copy_marker(&self) -> bool {
194        self.check_bytes(END_OF_COPY_MARKER)
195    }
196
197    fn is_end_of_line(&self) -> bool {
198        match self.peek() {
199            Some(b'\n') | None => true,
200            _ => false,
201        }
202    }
203
204    pub fn expect_end_of_line(&mut self) -> Result<(), io::Error> {
205        if self.is_end_of_line() {
206            self.consume_n(1);
207            Ok(())
208        } else {
209            Err(io::Error::new(
210                io::ErrorKind::InvalidData,
211                "extra data after last expected column",
212            ))
213        }
214    }
215
216    fn is_column_delimiter(&self) -> bool {
217        self.check_bytes(&[self.column_delimiter])
218    }
219
220    pub fn expect_column_delimiter(&mut self) -> Result<(), io::Error> {
221        if self.consume_bytes(&[self.column_delimiter]) {
222            Ok(())
223        } else {
224            Err(io::Error::new(
225                io::ErrorKind::InvalidData,
226                "missing data for column",
227            ))
228        }
229    }
230
231    fn check_bytes(&self, bytes: &[u8]) -> bool {
232        self.data
233            .get(self.position..self.position + bytes.len())
234            .map_or(false, |d| d == bytes)
235    }
236
237    fn consume_bytes(&mut self, bytes: &[u8]) -> bool {
238        if self.check_bytes(bytes) {
239            self.consume_n(bytes.len());
240            true
241        } else {
242            false
243        }
244    }
245
246    fn consume_null_string(&mut self) -> bool {
247        if self.null_string.is_empty() {
248            // An empty NULL marker is supported. Look ahead to ensure that is followed by
249            // a column delimiter, an end of line or it is at the end of the data.
250            self.is_column_delimiter()
251                || self.is_end_of_line()
252                || self.is_end_of_copy_marker()
253                || self.is_eof()
254        } else {
255            self.consume_bytes(self.null_string.as_bytes())
256        }
257    }
258
259    pub fn consume_raw_value(&mut self) -> Result<Option<&[u8]>, io::Error> {
260        if self.consume_null_string() {
261            return Ok(None);
262        }
263
264        let mut start = self.position;
265
266        // buffer where unescaped data is accumulated
267        self.buffer.clear();
268
269        while !self.is_eof() && !self.is_end_of_copy_marker() {
270            if self.is_end_of_line() || self.is_column_delimiter() {
271                break;
272            }
273            match self.peek() {
274                Some(b'\\') => {
275                    // Add non-escaped data parsed so far
276                    self.buffer.extend(&self.data[start..self.position]);
277
278                    self.consume_n(1);
279                    match self.peek() {
280                        Some(b'b') => {
281                            self.consume_n(1);
282                            self.buffer.push(8);
283                        }
284                        Some(b'f') => {
285                            self.consume_n(1);
286                            self.buffer.push(12);
287                        }
288                        Some(b'n') => {
289                            self.consume_n(1);
290                            self.buffer.push(b'\n');
291                        }
292                        Some(b'r') => {
293                            self.consume_n(1);
294                            self.buffer.push(b'\r');
295                        }
296                        Some(b't') => {
297                            self.consume_n(1);
298                            self.buffer.push(b'\t');
299                        }
300                        Some(b'v') => {
301                            self.consume_n(1);
302                            self.buffer.push(11);
303                        }
304                        Some(b'x') => {
305                            self.consume_n(1);
306                            match self.peek() {
307                                Some(_c @ b'0'..=b'9')
308                                | Some(_c @ b'A'..=b'F')
309                                | Some(_c @ b'a'..=b'f') => {
310                                    let mut value: u8 = 0;
311                                    let decode_nibble = |b| match b {
312                                        Some(c @ b'a'..=b'f') => Some(c - b'a' + 10),
313                                        Some(c @ b'A'..=b'F') => Some(c - b'A' + 10),
314                                        Some(c @ b'0'..=b'9') => Some(c - b'0'),
315                                        _ => None,
316                                    };
317                                    for _ in 0..2 {
318                                        match decode_nibble(self.peek()) {
319                                            Some(c) => {
320                                                self.consume_n(1);
321                                                value = (value << 4) | c;
322                                            }
323                                            _ => break,
324                                        }
325                                    }
326                                    self.buffer.push(value);
327                                }
328                                _ => {
329                                    self.buffer.push(b'x');
330                                }
331                            }
332                        }
333                        Some(_c @ b'0'..=b'7') => {
334                            let mut value: u8 = 0;
335                            for _ in 0..3 {
336                                match self.peek() {
337                                    Some(c @ b'0'..=b'7') => {
338                                        self.consume_n(1);
339                                        value = (value << 3) | (c - b'0');
340                                    }
341                                    _ => break,
342                                }
343                            }
344                            self.buffer.push(value);
345                        }
346                        Some(c) => {
347                            self.consume_n(1);
348                            self.buffer.push(c);
349                        }
350                        None => {
351                            self.buffer.push(b'\\');
352                        }
353                    }
354
355                    start = self.position;
356                }
357                Some(_) => {
358                    self.consume_n(1);
359                }
360                None => {}
361            }
362        }
363
364        // Return a slice of the original buffer if no escaped characters where processed
365        if self.buffer.is_empty() {
366            Ok(Some(&self.data[start..self.position]))
367        } else {
368            // ... otherwise, add the remaining non-escaped data to the decoding buffer
369            // and return a pointer to it
370            self.buffer.extend(&self.data[start..self.position]);
371            Ok(Some(&self.buffer[..]))
372        }
373    }
374
375    /// Error if more than `num_columns` values in `parser`.
376    pub fn iter_raw(self, num_columns: usize) -> RawIterator<'a> {
377        RawIterator {
378            parser: self,
379            current_column: 0,
380            num_columns,
381            truncate: false,
382        }
383    }
384
385    /// Return no more than `num_columns` values from `parser`.
386    pub fn iter_raw_truncating(self, num_columns: usize) -> RawIterator<'a> {
387        RawIterator {
388            parser: self,
389            current_column: 0,
390            num_columns,
391            truncate: true,
392        }
393    }
394}
395
396pub struct RawIterator<'a> {
397    parser: CopyTextFormatParser<'a>,
398    current_column: usize,
399    num_columns: usize,
400    truncate: bool,
401}
402
403impl<'a> RawIterator<'a> {
404    pub fn next(&mut self) -> Option<Result<Option<&[u8]>, io::Error>> {
405        if self.current_column > self.num_columns {
406            return None;
407        }
408
409        if self.current_column == self.num_columns {
410            if !self.truncate {
411                if let Some(err) = self.parser.expect_end_of_line().err() {
412                    return Some(Err(err));
413                }
414            }
415
416            return None;
417        }
418
419        if self.current_column > 0 {
420            if let Some(err) = self.parser.expect_column_delimiter().err() {
421                return Some(Err(err));
422            }
423        }
424
425        self.current_column += 1;
426        Some(self.parser.consume_raw_value())
427    }
428}
429
430#[derive(Debug, Eq, PartialEq, Clone, Serialize, Deserialize)]
431pub enum CopyFormatParams<'a> {
432    Text(CopyTextFormatParams<'a>),
433    Csv(CopyCsvFormatParams<'a>),
434    Binary,
435    Parquet,
436}
437
438impl CopyFormatParams<'static> {
439    pub fn file_extension(&self) -> &str {
440        match self {
441            &CopyFormatParams::Text(_) => "txt",
442            &CopyFormatParams::Csv(_) => "csv",
443            &CopyFormatParams::Binary => "bin",
444            &CopyFormatParams::Parquet => "parquet",
445        }
446    }
447
448    pub fn requires_header(&self) -> bool {
449        match self {
450            CopyFormatParams::Text(_) => false,
451            CopyFormatParams::Csv(params) => params.header,
452            CopyFormatParams::Binary => false,
453            CopyFormatParams::Parquet => false,
454        }
455    }
456}
457
458/// Decodes the given bytes into `Row`-s based on the given `CopyFormatParams`.
459pub fn decode_copy_format<'a>(
460    data: &[u8],
461    column_types: &[mz_pgrepr::Type],
462    params: CopyFormatParams<'a>,
463) -> Result<Vec<Row>, io::Error> {
464    match params {
465        CopyFormatParams::Text(params) => decode_copy_format_text(data, column_types, params),
466        CopyFormatParams::Csv(params) => decode_copy_format_csv(data, column_types, params),
467        CopyFormatParams::Binary => Err(io::Error::new(
468            io::ErrorKind::Unsupported,
469            "cannot decode as binary format",
470        )),
471        CopyFormatParams::Parquet => {
472            // TODO(cf2): Support Parquet over STDIN.
473            Err(io::Error::new(io::ErrorKind::Unsupported, "parquet format"))
474        }
475    }
476}
477
478/// Encodes the given `Row` into bytes based on the given `CopyFormatParams`.
479pub fn encode_copy_format<'a>(
480    params: &CopyFormatParams<'a>,
481    row: &RowRef,
482    typ: &SqlRelationType,
483    out: &mut Vec<u8>,
484) -> Result<(), io::Error> {
485    match params {
486        CopyFormatParams::Text(params) => encode_copy_row_text(params, row, typ, out),
487        CopyFormatParams::Csv(params) => encode_copy_row_csv(params, row, typ, out),
488        CopyFormatParams::Binary => encode_copy_row_binary(row, typ, out),
489        CopyFormatParams::Parquet => {
490            // TODO(cf2): Support Parquet over STDIN.
491            Err(io::Error::new(io::ErrorKind::Unsupported, "parquet format"))
492        }
493    }
494}
495
496pub fn encode_copy_format_header<'a>(
497    params: &CopyFormatParams<'a>,
498    desc: &RelationDesc,
499    out: &mut Vec<u8>,
500) -> Result<(), io::Error> {
501    match params {
502        CopyFormatParams::Text(_) => Ok(()),
503        CopyFormatParams::Binary => Ok(()),
504        CopyFormatParams::Csv(params) => {
505            let mut header_row = Row::with_capacity(desc.arity());
506            header_row
507                .packer()
508                .extend(desc.iter_names().map(|s| Datum::from(s.as_str())));
509            let typ = SqlRelationType::new(vec![
510                SqlColumnType {
511                    scalar_type: SqlScalarType::String,
512                    nullable: false,
513                };
514                desc.arity()
515            ]);
516            encode_copy_row_csv(params, &header_row, &typ, out)
517        }
518        CopyFormatParams::Parquet => {
519            // TODO(cf2): Support Parquet over STDIN.
520            Err(io::Error::new(io::ErrorKind::Unsupported, "parquet format"))
521        }
522    }
523}
524
525#[derive(Debug, Eq, PartialEq, Clone, Serialize, Deserialize)]
526pub struct CopyTextFormatParams<'a> {
527    pub null: Cow<'a, str>,
528    pub delimiter: u8,
529}
530
531impl<'a> Default for CopyTextFormatParams<'a> {
532    fn default() -> Self {
533        CopyTextFormatParams {
534            delimiter: b'\t',
535            null: Cow::from("\\N"),
536        }
537    }
538}
539
540pub fn decode_copy_format_text(
541    data: &[u8],
542    column_types: &[mz_pgrepr::Type],
543    CopyTextFormatParams { null, delimiter }: CopyTextFormatParams,
544) -> Result<Vec<Row>, io::Error> {
545    let mut rows = Vec::new();
546
547    // TODO: pass the `CopyTextFormatParams` to the `new` method
548    let mut parser = CopyTextFormatParser::new(data, delimiter, &null);
549    while !parser.is_eof() && !parser.is_end_of_copy_marker() {
550        let mut row = Vec::new();
551        let buf = RowArena::new();
552        for (col, typ) in column_types.iter().enumerate() {
553            if col > 0 {
554                parser.expect_column_delimiter()?;
555            }
556            let raw_value = parser.consume_raw_value()?;
557            if let Some(raw_value) = raw_value {
558                match mz_pgrepr::Value::decode_text(typ, raw_value) {
559                    Ok(value) => {
560                        row.push(
561                            value
562                                .into_datum_decode_error(&buf, typ, "column")
563                                .map_err(|msg| io::Error::new(io::ErrorKind::InvalidData, msg))?,
564                        );
565                    }
566                    Err(err) => {
567                        let msg = format!("unable to decode column: {}", err);
568                        return Err(io::Error::new(io::ErrorKind::InvalidData, msg));
569                    }
570                }
571            } else {
572                row.push(Datum::Null);
573            }
574        }
575        parser.expect_end_of_line()?;
576        rows.push(Row::pack(row));
577    }
578    // Note that if there is any junk data after the end of copy marker, we drop
579    // it on the floor as PG does.
580    Ok(rows)
581}
582
583#[derive(Debug, Eq, PartialEq, Clone, Serialize, Deserialize)]
584pub struct CopyCsvFormatParams<'a> {
585    pub delimiter: u8,
586    pub quote: u8,
587    pub escape: u8,
588    pub header: bool,
589    pub null: Cow<'a, str>,
590}
591
592impl<'a> CopyCsvFormatParams<'a> {
593    pub fn to_owned(&self) -> CopyCsvFormatParams<'static> {
594        CopyCsvFormatParams {
595            delimiter: self.delimiter,
596            quote: self.quote,
597            escape: self.escape,
598            header: self.header,
599            null: Cow::Owned(self.null.to_string()),
600        }
601    }
602}
603
604impl Arbitrary for CopyCsvFormatParams<'static> {
605    type Parameters = ();
606    type Strategy = BoxedStrategy<Self>;
607
608    fn arbitrary_with(_args: Self::Parameters) -> Self::Strategy {
609        (
610            any::<u8>(),
611            any::<u8>(),
612            any::<u8>(),
613            any::<bool>(),
614            any::<String>(),
615        )
616            .prop_map(|(delimiter, diff, escape, header, null)| {
617                // Delimiter and Quote need to be different.
618                let diff = diff.saturating_sub(1).max(1);
619                let quote = delimiter.wrapping_add(diff);
620
621                Self::try_new(
622                    Some(delimiter),
623                    Some(quote),
624                    Some(escape),
625                    Some(header),
626                    Some(null),
627                )
628                .expect("delimiter and quote should be different")
629            })
630            .boxed()
631    }
632}
633
634impl<'a> Default for CopyCsvFormatParams<'a> {
635    fn default() -> Self {
636        CopyCsvFormatParams {
637            delimiter: b',',
638            quote: b'"',
639            escape: b'"',
640            header: false,
641            null: Cow::from(""),
642        }
643    }
644}
645
646impl<'a> CopyCsvFormatParams<'a> {
647    pub fn try_new(
648        delimiter: Option<u8>,
649        quote: Option<u8>,
650        escape: Option<u8>,
651        header: Option<bool>,
652        null: Option<String>,
653    ) -> Result<CopyCsvFormatParams<'a>, String> {
654        let mut params = CopyCsvFormatParams::default();
655
656        if let Some(delimiter) = delimiter {
657            params.delimiter = delimiter;
658        }
659        if let Some(quote) = quote {
660            params.quote = quote;
661            // escape defaults to the value provided for quote
662            params.escape = quote;
663        }
664        if let Some(escape) = escape {
665            params.escape = escape;
666        }
667        if let Some(header) = header {
668            params.header = header;
669        }
670        if let Some(null) = null {
671            params.null = Cow::from(null);
672        }
673
674        if params.quote == params.delimiter {
675            return Err("COPY delimiter and quote must be different".to_string());
676        }
677        Ok(params)
678    }
679}
680
681/// One field decoded out of a CSV record by [`decode_copy_format_csv`]:
682/// `start..end` indexes into the per-record `output` buffer (csv-core
683/// unquotes/unescapes into that buffer), and `quoted` records whether the
684/// field's first input byte was the quote character. The quote flag is what
685/// distinguishes a literal `""` (quoted empty string) from an unquoted empty
686/// field (the default NULL marker), and a quoted `"\."` data row from the bare
687/// `\.` end-of-copy marker.
688struct DecodedField {
689    start: usize,
690    end: usize,
691    quoted: bool,
692}
693
694pub fn decode_copy_format_csv(
695    data: &[u8],
696    column_types: &[mz_pgrepr::Type],
697    CopyCsvFormatParams {
698        delimiter,
699        quote,
700        escape,
701        null,
702        header,
703    }: CopyCsvFormatParams,
704) -> Result<Vec<Row>, io::Error> {
705    let (double_quote, escape) = if quote == escape {
706        (true, None)
707    } else {
708        (false, Some(escape))
709    };
710
711    let mut rdr = csv_core::ReaderBuilder::new()
712        .delimiter(delimiter)
713        .quote(quote)
714        .escape(escape)
715        .double_quote(double_quote)
716        .build();
717
718    let null_as_bytes = null.as_bytes();
719    let mut rows = Vec::new();
720    // We use csv-core (rather than the higher-level csv crate) so we can
721    // recover per-field "was this field quoted?" information by inspecting the
722    // first byte of each field's input. csv unquotes during parsing, which
723    // makes a quoted empty string indistinguishable from an unquoted empty
724    // field — and PostgreSQL COPY ... FORMAT CSV semantics need that
725    // distinction to honor the NULL marker.
726    let mut input = data;
727    let mut output = vec![0u8; data.len().max(1024)];
728    let mut out_pos = 0;
729    let mut fields: Vec<DecodedField> = Vec::new();
730    let mut field_start = 0;
731    let mut field_quoted: Option<bool> = None;
732    let mut skip_header = header;
733    // True at the start of a record (including the very first), where csv-core
734    // may have left an orphaned terminator byte; false for fields that follow a
735    // delimiter within a record.
736    let mut at_record_start = true;
737
738    loop {
739        if field_quoted.is_none() {
740            if at_record_start {
741                // csv-core's default terminator is CRLF, and it reports a
742                // record complete after consuming the `\r`, leaving the
743                // trailing `\n` as the first byte of the next record's input
744                // (and, after a worker chunk is split at a `\r` boundary, a
745                // chunk can likewise begin with that orphan `\n`). csv-core
746                // itself consumes and ignores those stray terminator bytes when
747                // parsing the field, but the quote probe below must skip them
748                // so it inspects the field's real first byte rather than an
749                // orphan — otherwise a quoted field on any non-first CRLF record
750                // is misclassified as unquoted. Only skip at a record boundary:
751                // a `\r`/`\n` after a delimiter legitimately terminates an empty
752                // trailing field and must not be swallowed.
753                while matches!(input.first(), Some(&b'\r') | Some(&b'\n')) {
754                    input = &input[1..];
755                }
756            }
757            field_quoted = Some(input.first() == Some(&quote));
758            field_start = out_pos;
759        }
760
761        let (result, nin, nout) = rdr.read_field(input, &mut output[out_pos..]);
762        input = &input[nin..];
763        out_pos += nout;
764
765        match result {
766            csv_core::ReadFieldResult::Field { record_end } => {
767                fields.push(DecodedField {
768                    start: field_start,
769                    end: out_pos,
770                    quoted: field_quoted.take().unwrap(),
771                });
772                // The next field begins a new record only if this field ended
773                // one; otherwise it follows a delimiter mid-record.
774                at_record_start = record_end;
775
776                if record_end {
777                    if skip_header {
778                        skip_header = false;
779                    } else if let [
780                        DecodedField {
781                            start,
782                            end,
783                            quoted: false,
784                        },
785                    ] = fields[..]
786                        && &output[start..end] == END_OF_COPY_MARKER
787                    {
788                        // Bare `\.` on its own line: end-of-copy marker. A
789                        // quoted `"\."` also decodes to `\.` but is data, so
790                        // only an unquoted match terminates the import.
791                        return Ok(rows);
792                    } else {
793                        match fields.len().cmp(&column_types.len()) {
794                            std::cmp::Ordering::Less => {
795                                return Err(io::Error::new(
796                                    io::ErrorKind::InvalidData,
797                                    "missing data for column",
798                                ));
799                            }
800                            std::cmp::Ordering::Greater => {
801                                return Err(io::Error::new(
802                                    io::ErrorKind::InvalidData,
803                                    "extra data after last expected column",
804                                ));
805                            }
806                            std::cmp::Ordering::Equal => {}
807                        }
808
809                        let mut row_builder = SharedRow::get();
810                        let mut row_packer = row_builder.packer();
811                        for (typ, field) in column_types.iter().zip_eq(fields.iter()) {
812                            let raw_value = &output[field.start..field.end];
813                            if !field.quoted && raw_value == null_as_bytes {
814                                row_packer.push(Datum::Null);
815                            } else {
816                                let s = match std::str::from_utf8(raw_value) {
817                                    Ok(s) => s,
818                                    Err(err) => {
819                                        let msg = format!("invalid utf8 data in column: {}", err);
820                                        return Err(io::Error::new(
821                                            io::ErrorKind::InvalidData,
822                                            msg,
823                                        ));
824                                    }
825                                };
826                                if let Err(err) =
827                                    mz_pgrepr::Value::decode_text_into_row(typ, s, &mut row_packer)
828                                {
829                                    let msg = format!("unable to decode column: {}", err);
830                                    return Err(io::Error::new(io::ErrorKind::InvalidData, msg));
831                                }
832                            }
833                        }
834                        rows.push(row_builder.clone());
835                    }
836                    fields.clear();
837                    out_pos = 0;
838                }
839            }
840            csv_core::ReadFieldResult::OutputFull => {
841                let new_len = output.len().saturating_mul(2).max(out_pos + 1);
842                output.resize(new_len, 0);
843            }
844            csv_core::ReadFieldResult::InputEmpty => {
845                // InputEmpty means csv-core consumed all our input mid-field
846                // and wants more. We have no more to give, so the next
847                // iteration calls it with the now-empty slice; per its
848                // documented termination protocol, that yields the partial
849                // field (if any) as a final `Field`, then `End`.
850            }
851            csv_core::ReadFieldResult::End => return Ok(rows),
852        }
853    }
854}
855
856#[cfg(test)]
857mod tests {
858    use mz_ore::collections::CollectionExt;
859    use mz_repr::SqlColumnType;
860    use proptest::prelude::*;
861
862    use super::*;
863
864    #[mz_ore::test]
865    fn test_copy_format_text_parser() {
866        let text = "\t\\nt e\t\\N\t\n\\x60\\xA\\x7D\\x4a\n\\44\\044\\123".as_bytes();
867        let mut parser = CopyTextFormatParser::new(text, b'\t', "\\N");
868        assert!(parser.is_column_delimiter());
869        parser
870            .expect_column_delimiter()
871            .expect("expected column delimiter");
872        assert_eq!(
873            parser
874                .consume_raw_value()
875                .expect("unexpected error")
876                .expect("unexpected empty result"),
877            "\nt e".as_bytes()
878        );
879        parser
880            .expect_column_delimiter()
881            .expect("expected column delimiter");
882        // null value
883        assert!(
884            parser
885                .consume_raw_value()
886                .expect("unexpected error")
887                .is_none()
888        );
889        parser
890            .expect_column_delimiter()
891            .expect("expected column delimiter");
892        assert!(parser.is_end_of_line());
893        parser.expect_end_of_line().expect("expected eol");
894        // hex value
895        assert_eq!(
896            parser
897                .consume_raw_value()
898                .expect("unexpected error")
899                .expect("unexpected empty result"),
900            "`\n}J".as_bytes()
901        );
902        parser.expect_end_of_line().expect("expected eol");
903        // octal value
904        assert_eq!(
905            parser
906                .consume_raw_value()
907                .expect("unexpected error")
908                .expect("unexpected empty result"),
909            "$$S".as_bytes()
910        );
911        assert!(parser.is_eof());
912    }
913
914    #[mz_ore::test]
915    fn test_copy_format_text_empty_null_string() {
916        let text = "\t\n10\t20\n30\t\n40\t".as_bytes();
917        let expect = vec![
918            vec![None, None],
919            vec![Some("10"), Some("20")],
920            vec![Some("30"), None],
921            vec![Some("40"), None],
922        ];
923        let mut parser = CopyTextFormatParser::new(text, b'\t', "");
924        for line in expect {
925            for (i, value) in line.iter().enumerate() {
926                if i > 0 {
927                    parser
928                        .expect_column_delimiter()
929                        .expect("expected column delimiter");
930                }
931                match value {
932                    Some(s) => {
933                        assert!(!parser.consume_null_string());
934                        assert_eq!(
935                            parser
936                                .consume_raw_value()
937                                .expect("unexpected error")
938                                .expect("unexpected empty result"),
939                            s.as_bytes()
940                        );
941                    }
942                    None => {
943                        assert!(parser.consume_null_string());
944                    }
945                }
946            }
947            parser.expect_end_of_line().expect("expected eol");
948        }
949    }
950
951    #[mz_ore::test]
952    fn test_copy_format_text_parser_escapes() {
953        struct TestCase {
954            input: &'static str,
955            expect: &'static [u8],
956        }
957        let tests = vec![
958            TestCase {
959                input: "simple",
960                expect: b"simple",
961            },
962            TestCase {
963                input: r#"new\nline"#,
964                expect: b"new\nline",
965            },
966            TestCase {
967                input: r#"\b\f\n\r\t\v\\"#,
968                expect: b"\x08\x0c\n\r\t\x0b\\",
969            },
970            TestCase {
971                input: r#"\0\12\123"#,
972                expect: &[0, 0o12, 0o123],
973            },
974            TestCase {
975                input: r#"\x1\xaf"#,
976                expect: &[0x01, 0xaf],
977            },
978            TestCase {
979                input: r#"T\n\07\xEV\x0fA\xb2C\1"#,
980                expect: b"T\n\x07\x0eV\x0fA\xb2C\x01",
981            },
982            TestCase {
983                input: r#"\\\""#,
984                expect: b"\\\"",
985            },
986            TestCase {
987                input: r#"\x"#,
988                expect: b"x",
989            },
990            TestCase {
991                input: r#"\xg"#,
992                expect: b"xg",
993            },
994            TestCase {
995                input: r#"\"#,
996                expect: b"\\",
997            },
998            TestCase {
999                input: r#"\8"#,
1000                expect: b"8",
1001            },
1002            TestCase {
1003                input: r#"\a"#,
1004                expect: b"a",
1005            },
1006            TestCase {
1007                input: r#"\x\xg\8\xH\x32\s\"#,
1008                expect: b"xxg8xH2s\\",
1009            },
1010        ];
1011
1012        for test in tests {
1013            let mut parser = CopyTextFormatParser::new(test.input.as_bytes(), b'\t', "\\N");
1014            assert_eq!(
1015                parser
1016                    .consume_raw_value()
1017                    .expect("unexpected error")
1018                    .expect("unexpected empty result"),
1019                test.expect,
1020                "input: {}, expect: {:?}",
1021                test.input,
1022                std::str::from_utf8(test.expect),
1023            );
1024            assert!(parser.is_eof());
1025        }
1026    }
1027
1028    #[mz_ore::test]
1029    fn test_copy_csv_format_params() {
1030        assert_eq!(
1031            CopyCsvFormatParams::try_new(Some(b't'), Some(b'q'), None, None, None),
1032            Ok(CopyCsvFormatParams {
1033                delimiter: b't',
1034                quote: b'q',
1035                escape: b'q',
1036                header: false,
1037                null: Cow::from(""),
1038            })
1039        );
1040
1041        assert_eq!(
1042            CopyCsvFormatParams::try_new(
1043                Some(b't'),
1044                Some(b'q'),
1045                Some(b'e'),
1046                Some(true),
1047                Some("null".to_string())
1048            ),
1049            Ok(CopyCsvFormatParams {
1050                delimiter: b't',
1051                quote: b'q',
1052                escape: b'e',
1053                header: true,
1054                null: Cow::from("null"),
1055            })
1056        );
1057
1058        assert_eq!(
1059            CopyCsvFormatParams::try_new(
1060                None,
1061                Some(b','),
1062                Some(b'e'),
1063                Some(true),
1064                Some("null".to_string())
1065            ),
1066            Err("COPY delimiter and quote must be different".to_string())
1067        );
1068    }
1069
1070    #[mz_ore::test]
1071    fn test_copy_csv_row() -> Result<(), io::Error> {
1072        let mut row = Row::default();
1073        let mut packer = row.packer();
1074        packer.push(Datum::from("1,2,\"3\""));
1075        packer.push(Datum::Null);
1076        packer.push(Datum::from(1000u64));
1077        packer.push(Datum::from("qe")); // overridden quote and escape character in test below
1078        packer.push(Datum::from(""));
1079
1080        let typ: SqlRelationType = SqlRelationType::new(vec![
1081            SqlColumnType {
1082                scalar_type: mz_repr::SqlScalarType::String,
1083                nullable: false,
1084            },
1085            SqlColumnType {
1086                scalar_type: mz_repr::SqlScalarType::String,
1087                nullable: true,
1088            },
1089            SqlColumnType {
1090                scalar_type: mz_repr::SqlScalarType::UInt64,
1091                nullable: false,
1092            },
1093            SqlColumnType {
1094                scalar_type: mz_repr::SqlScalarType::String,
1095                nullable: false,
1096            },
1097            SqlColumnType {
1098                scalar_type: mz_repr::SqlScalarType::String,
1099                nullable: false,
1100            },
1101        ]);
1102
1103        let mut out = Vec::new();
1104
1105        struct TestCase<'a> {
1106            params: CopyCsvFormatParams<'a>,
1107            expected: &'static [u8],
1108        }
1109
1110        let tests = [
1111            TestCase {
1112                params: CopyCsvFormatParams::default(),
1113                expected: b"\"1,2,\"\"3\"\"\",,1000,qe,\"\"\n",
1114            },
1115            TestCase {
1116                params: CopyCsvFormatParams {
1117                    null: Cow::from("NULL"),
1118                    quote: b'q',
1119                    escape: b'e',
1120                    ..Default::default()
1121                },
1122                expected: b"q1,2,\"3\"q,NULL,1000,qeqeeq,\n",
1123            },
1124        ];
1125
1126        for TestCase { params, expected } in tests {
1127            out.clear();
1128            let params = CopyFormatParams::Csv(params);
1129            let _ = encode_copy_format(&params, &row, &typ, &mut out);
1130            let output = std::str::from_utf8(&out);
1131            assert_eq!(output, std::str::from_utf8(expected));
1132        }
1133
1134        Ok(())
1135    }
1136
1137    #[mz_ore::test]
1138    fn test_decode_copy_format_csv_end_marker() {
1139        // Bare `\.` on its own line terminates the COPY. A quoted `"\."`
1140        // decodes to the same bytes but must be treated as data. This must
1141        // hold for every line ending (LF, CRLF, CR): csv-core places a CRLF
1142        // record boundary between `\r` and `\n`, so a naive raw-byte check is
1143        // fooled by the orphaned terminator bytes on CRLF/CR input.
1144        let column_types = vec![mz_pgrepr::Type::from(&mz_repr::SqlScalarType::String)];
1145
1146        let decode_strings = |input: &[u8]| -> Vec<String> {
1147            decode_copy_format_csv(input, &column_types, CopyCsvFormatParams::default())
1148                .expect("decode should succeed")
1149                .iter()
1150                .map(|r| match r.iter().next().unwrap() {
1151                    Datum::String(s) => s.to_owned(),
1152                    d => panic!("unexpected datum: {:?}", d),
1153                })
1154                .collect()
1155        };
1156
1157        for eol in [&b"\n"[..], b"\r\n", b"\r"] {
1158            let join = |lines: &[&str]| -> Vec<u8> {
1159                let mut out = Vec::new();
1160                for line in lines {
1161                    out.extend_from_slice(line.as_bytes());
1162                    out.extend_from_slice(eol);
1163                }
1164                out
1165            };
1166
1167            // Quoted "\." is data — all three rows are imported.
1168            assert_eq!(
1169                decode_strings(&join(&["before", "\"\\.\"", "after"])),
1170                vec!["before", "\\.", "after"],
1171                "quoted marker, eol={eol:?}"
1172            );
1173
1174            // Bare `\.` terminates the COPY; rows after it are dropped.
1175            assert_eq!(
1176                decode_strings(&join(&["first", "\\.", "ignored"])),
1177                vec!["first"],
1178                "bare marker, eol={eol:?}"
1179            );
1180        }
1181    }
1182
1183    #[mz_ore::test]
1184    fn test_decode_copy_format_csv_leading_orphan() {
1185        // Worker chunks after the first begin at a `\r` boundary, so for CRLF
1186        // input a chunk's bytes start with the orphan `\n` that csv-core left
1187        // behind. The decoder must still classify the first field's quote
1188        // state from its real first byte, not the stray `\n`. Regression test
1189        // for the per-chunk quote probe.
1190        let column_types = vec![
1191            mz_pgrepr::Type::from(&mz_repr::SqlScalarType::String),
1192            mz_pgrepr::Type::from(&mz_repr::SqlScalarType::String),
1193        ];
1194
1195        let rows = decode_copy_format_csv(
1196            b"\n\"\",x\r\ny,z\r\n",
1197            &column_types,
1198            CopyCsvFormatParams::default(),
1199        )
1200        .expect("decode should succeed");
1201        let got: Vec<Vec<Datum>> = rows.iter().map(|r| r.iter().collect()).collect();
1202        assert_eq!(got.len(), 2);
1203        // Quoted empty first field on the leading-orphan record must decode to
1204        // the empty string, not SQL NULL (the bug would misread it as
1205        // unquoted and match the default empty NULL marker).
1206        assert_eq!(got[0][0], Datum::String(""));
1207        assert_eq!(got[0][1], Datum::String("x"));
1208        assert_eq!(got[1][0], Datum::String("y"));
1209        assert_eq!(got[1][1], Datum::String("z"));
1210    }
1211
1212    #[mz_ore::test]
1213    fn test_decode_copy_format_csv_quoted_null() {
1214        // PG COPY ... FORMAT CSV distinguishes quoted vs unquoted NULL
1215        // markers: unquoted → SQL NULL, quoted → the literal string.
1216        let column_types = vec![
1217            mz_pgrepr::Type::from(&mz_repr::SqlScalarType::String),
1218            mz_pgrepr::Type::from(&mz_repr::SqlScalarType::String),
1219        ];
1220
1221        // Lines as (col_a, col_b) literals, joined per-eol below. The quoted
1222        // vs unquoted distinction must survive CRLF/CR line endings, where
1223        // csv-core leaves an orphaned terminator byte at the start of every
1224        // non-first record.
1225        let cases: &[(CopyCsvFormatParams, &[(&str, &str)], &[[Option<&str>; 2]])] = &[
1226            // Default params: NULL marker is empty string.
1227            (
1228                CopyCsvFormatParams::default(),
1229                &[("a", ""), ("b", "\"\""), ("\"\"", "c")],
1230                &[
1231                    [Some("a"), None],
1232                    [Some("b"), Some("")],
1233                    [Some(""), Some("c")],
1234                ],
1235            ),
1236            // Custom NULL marker "NULL".
1237            (
1238                CopyCsvFormatParams {
1239                    null: Cow::from("NULL"),
1240                    ..Default::default()
1241                },
1242                &[("a", "NULL"), ("b", "\"NULL\""), ("NULL", "c")],
1243                &[
1244                    [Some("a"), None],
1245                    [Some("b"), Some("NULL")],
1246                    [None, Some("c")],
1247                ],
1248            ),
1249        ];
1250
1251        for eol in [&b"\n"[..], b"\r\n", b"\r"] {
1252            for (params, lines, expected) in cases {
1253                let mut input = Vec::new();
1254                for (a, b) in *lines {
1255                    input.extend_from_slice(a.as_bytes());
1256                    input.push(b',');
1257                    input.extend_from_slice(b.as_bytes());
1258                    input.extend_from_slice(eol);
1259                }
1260                let rows = decode_copy_format_csv(&input, &column_types, params.clone())
1261                    .expect("decode should succeed");
1262                assert_eq!(rows.len(), expected.len(), "eol={eol:?}");
1263                for (row, want) in rows.iter().zip_eq(expected.iter()) {
1264                    let got: Vec<Datum> = row.iter().collect();
1265                    assert_eq!(got.len(), 2);
1266                    for (g, w) in got.iter().zip_eq(want.iter()) {
1267                        match (g, w) {
1268                            (Datum::Null, None) => {}
1269                            (Datum::String(s), Some(w)) => assert_eq!(s, w, "eol={eol:?}"),
1270                            _ => panic!("mismatch: got {g:?}, want {w:?}, eol={eol:?}"),
1271                        }
1272                    }
1273                }
1274            }
1275        }
1276    }
1277
    proptest! {
        // Property test: for arbitrary CSV parameters, encoding a one-column
        // row and decoding it again must give back the original row. Decode
        // errors are tolerated; only a successful decode is checked.
        #[mz_ore::test]
        #[cfg_attr(miri, ignore)]
        fn proptest_csv_roundtrips(copy_csv_params: CopyCsvFormatParams)  {
            // Given a SqlScalarType and a Datum, roundtrips the Datum through
            // the CSV COPY format as a single-column, nullable row.
            let try_roundtrip_datum = |scalar_type: &SqlScalarType, datum| {
                let row = Row::pack_slice(&[datum]);
                let typ = SqlRelationType::new(vec![
                    SqlColumnType {
                        scalar_type: scalar_type.clone(),
                        nullable: true,
                    }
                ]);

                let mut buf = Vec::new();
                let mut csv_params = copy_csv_params.clone();
                // TODO: Encoding never writes a header.
                csv_params.header = false;
                let params = CopyFormatParams::Csv(csv_params);

                // Roundtrip the Row through our CSV format.
                encode_copy_format(&params, &row, &typ, &mut buf)?;
                // The decoder takes pgrepr types, so convert the column's
                // scalar type first.
                let column_types = typ
                    .column_types
                    .iter()
                    .map(|x| &x.scalar_type)
                    .map(mz_pgrepr::Type::from)
                    .collect::<Vec<mz_pgrepr::Type>>();
                let result = decode_copy_format(&buf, &column_types, params);

                match result {
                    Ok(rows) => {
                        // Encoded bytes as text, used only in failure messages.
                        let out_str = std::str::from_utf8(&buf[..]);

                        // Exactly one row went in, so exactly one must come out.
                        prop_assert_eq!(
                            rows.len(),
                            1,
                            "unexpected number of rows! {:?}, csv string: {:?}", rows, out_str
                        );
                        let output = rows.into_element();

                        // The decoded row must equal the row that was encoded.
                        prop_assert_eq!(
                            row,
                            output,
                            "csv string: {:?}, scalar_type: {:?}", out_str, scalar_type
                        );
                    }
                    _ => {
                        // ignoring decoding failures
                    }
                }

                Ok(())
            };

            // Try roundtripping all of our interesting Datums.
            for scalar_type in SqlScalarType::enumerate() {
                for datum in scalar_type.interesting_datums() {
                    // Skip datums whose text encoding is exactly the
                    // configured null string.
                    // TODO: The decoder cannot differentiate between empty string and null.
                    if let Some(value) = mz_pgrepr::Value::from_datum(datum, scalar_type) {
                        let mut buf = bytes::BytesMut::new();
                        value.encode_text(&mut buf);

                        if let Ok(datum_str) = std::str::from_utf8(&buf[..]) {
                            if datum_str == copy_csv_params.null {
                                continue;
                            }
                        }
                    }

                    // Filter out the datum kinds that are skipped outright.
                    let updated_datum = match datum {
                        // TODO: Fix roundtrip decoding of these types.
                        Datum::Timestamp(_) | Datum::TimestampTz(_) | Datum::Null => {
                            continue;
                        }
                        Datum::String(s) => {
                            // Skip strings that, once trimmed, are empty or
                            // equal to the null string.
                            // TODO: The decoder cannot differentiate between empty string and null.
                            if s.trim() == copy_csv_params.null || s.trim().is_empty() {
                                continue;
                            } else {
                                Datum::String(s)
                            }
                        }
                        other => other,
                    };

                    let result = try_roundtrip_datum(scalar_type, updated_datum);
                    prop_assert!(result.is_ok(), "failure: {result:?}");
                }
            }
        }
    }
1370}