mz_pgcopy/
copy.rs

1// Copyright Materialize, Inc. and contributors. All rights reserved.
2//
3// Use of this software is governed by the Business Source License
4// included in the LICENSE file.
5//
6// As of the Change Date specified in that file, in accordance with
7// the Business Source License, use of this software will be governed
8// by the Apache License, Version 2.0.
9
10use std::borrow::Cow;
11use std::io;
12
13use bytes::BytesMut;
14use csv::{ByteRecord, ReaderBuilder};
15use itertools::Itertools;
16use mz_repr::{
17    Datum, RelationDesc, Row, RowArena, RowRef, SharedRow, SqlColumnType, SqlRelationType,
18    SqlScalarType,
19};
20use proptest::prelude::{Arbitrary, any};
21use proptest::strategy::{BoxedStrategy, Strategy};
22use serde::{Deserialize, Serialize};
23
/// The end-of-copy marker (`\.`); input at or after it is not parsed as rows.
static END_OF_COPY_MARKER: &[u8] = b"\\.";
25
26fn encode_copy_row_binary(
27    row: &RowRef,
28    typ: &SqlRelationType,
29    out: &mut Vec<u8>,
30) -> Result<(), io::Error> {
31    const NULL_BYTES: [u8; 4] = (-1i32).to_be_bytes();
32
33    // 16-bit int of number of tuples.
34    let count = i16::try_from(typ.column_types.len()).map_err(|_| {
35        io::Error::new(
36            io::ErrorKind::Other,
37            "column count does not fit into an i16",
38        )
39    })?;
40
41    out.extend(count.to_be_bytes());
42    let mut buf = BytesMut::new();
43    for (field, typ) in row
44        .iter()
45        .zip_eq(&typ.column_types)
46        .map(|(datum, typ)| (mz_pgrepr::Value::from_datum(datum, &typ.scalar_type), typ))
47    {
48        match field {
49            None => out.extend(NULL_BYTES),
50            Some(field) => {
51                buf.clear();
52                field.encode_binary(&mz_pgrepr::Type::from(&typ.scalar_type), &mut buf)?;
53                out.extend(
54                    i32::try_from(buf.len())
55                        .map_err(|_| {
56                            io::Error::new(
57                                io::ErrorKind::Other,
58                                "field length does not fit into an i32",
59                            )
60                        })?
61                        .to_be_bytes(),
62                );
63                out.extend(&buf);
64            }
65        }
66    }
67    Ok(())
68}
69
70fn encode_copy_row_text(
71    CopyTextFormatParams { null, delimiter }: &CopyTextFormatParams,
72    row: &RowRef,
73    typ: &SqlRelationType,
74    out: &mut Vec<u8>,
75) -> Result<(), io::Error> {
76    let null = null.as_bytes();
77    let mut buf = BytesMut::new();
78    for (idx, field) in mz_pgrepr::values_from_row(row, typ).into_iter().enumerate() {
79        if idx > 0 {
80            out.push(*delimiter);
81        }
82        match field {
83            None => out.extend(null),
84            Some(field) => {
85                buf.clear();
86                field.encode_text(&mut buf);
87                for b in &buf {
88                    match b {
89                        b'\\' => out.extend(b"\\\\"),
90                        b'\n' => out.extend(b"\\n"),
91                        b'\r' => out.extend(b"\\r"),
92                        b'\t' => out.extend(b"\\t"),
93                        _ => out.push(*b),
94                    }
95                }
96            }
97        }
98    }
99    out.push(b'\n');
100    Ok(())
101}
102
103fn encode_copy_row_csv(
104    CopyCsvFormatParams {
105        delimiter: delim,
106        quote,
107        escape,
108        header: _,
109        null,
110    }: &CopyCsvFormatParams,
111    row: &RowRef,
112    typ: &SqlRelationType,
113    out: &mut Vec<u8>,
114) -> Result<(), io::Error> {
115    let null = null.as_bytes();
116    let is_special = |c: &u8| *c == *delim || *c == *quote || *c == b'\r' || *c == b'\n';
117    let mut buf = BytesMut::new();
118    for (idx, field) in mz_pgrepr::values_from_row(row, typ).into_iter().enumerate() {
119        if idx > 0 {
120            out.push(*delim);
121        }
122        match field {
123            None => out.extend(null),
124            Some(field) => {
125                buf.clear();
126                field.encode_text(&mut buf);
127                // A field needs quoting if:
128                //   * It is the only field and the value is exactly the end
129                //     of copy marker.
130                //   * The field contains a special character.
131                //   * The field is exactly the NULL sentinel.
132                if (typ.column_types.len() == 1 && buf == END_OF_COPY_MARKER)
133                    || buf.iter().any(is_special)
134                    || &*buf == null
135                {
136                    // Quote the value by wrapping it in the quote character and
137                    // emitting the escape character before any quote or escape
138                    // characters within.
139                    out.push(*quote);
140                    for b in &buf {
141                        if *b == *quote || *b == *escape {
142                            out.push(*escape);
143                        }
144                        out.push(*b);
145                    }
146                    out.push(*quote);
147                } else {
148                    // The value does not need quoting and can be emitted
149                    // directly.
150                    out.extend(&buf);
151                }
152            }
153        }
154    }
155    out.push(b'\n');
156    Ok(())
157}
158
/// Incremental parser for data in the text `COPY` format.
///
/// The parser moves a byte cursor through `data`. Callers pull one field at a
/// time with `consume_raw_value` and check the separators between fields and
/// lines themselves with `expect_column_delimiter` / `expect_end_of_line`.
pub struct CopyTextFormatParser<'a> {
    /// The input being parsed.
    data: &'a [u8],
    /// Byte offset of the next unread byte in `data`.
    position: usize,
    /// Byte that separates the fields of a line.
    column_delimiter: u8,
    /// Raw (still escaped) spelling of a `NULL` field.
    null_string: &'a str,
    /// Scratch space where a field's bytes are assembled once an escape
    /// sequence in it has been decoded.
    buffer: Vec<u8>,
}
166
167impl<'a> CopyTextFormatParser<'a> {
168    pub fn new(data: &'a [u8], column_delimiter: u8, null_string: &'a str) -> Self {
169        Self {
170            data,
171            position: 0,
172            column_delimiter,
173            null_string,
174            buffer: Vec::new(),
175        }
176    }
177
178    fn peek(&self) -> Option<u8> {
179        if self.position < self.data.len() {
180            Some(self.data[self.position])
181        } else {
182            None
183        }
184    }
185
186    fn consume_n(&mut self, n: usize) {
187        self.position = std::cmp::min(self.position + n, self.data.len());
188    }
189
190    pub fn is_eof(&self) -> bool {
191        self.peek().is_none() || self.is_end_of_copy_marker()
192    }
193
194    pub fn is_end_of_copy_marker(&self) -> bool {
195        self.check_bytes(END_OF_COPY_MARKER)
196    }
197
198    fn is_end_of_line(&self) -> bool {
199        match self.peek() {
200            Some(b'\n') | None => true,
201            _ => false,
202        }
203    }
204
205    pub fn expect_end_of_line(&mut self) -> Result<(), io::Error> {
206        if self.is_end_of_line() {
207            self.consume_n(1);
208            Ok(())
209        } else {
210            Err(io::Error::new(
211                io::ErrorKind::InvalidData,
212                "extra data after last expected column",
213            ))
214        }
215    }
216
217    fn is_column_delimiter(&self) -> bool {
218        self.check_bytes(&[self.column_delimiter])
219    }
220
221    pub fn expect_column_delimiter(&mut self) -> Result<(), io::Error> {
222        if self.consume_bytes(&[self.column_delimiter]) {
223            Ok(())
224        } else {
225            Err(io::Error::new(
226                io::ErrorKind::InvalidData,
227                "missing data for column",
228            ))
229        }
230    }
231
232    fn check_bytes(&self, bytes: &[u8]) -> bool {
233        self.data
234            .get(self.position..self.position + bytes.len())
235            .map_or(false, |d| d == bytes)
236    }
237
238    fn consume_bytes(&mut self, bytes: &[u8]) -> bool {
239        if self.check_bytes(bytes) {
240            self.consume_n(bytes.len());
241            true
242        } else {
243            false
244        }
245    }
246
247    fn consume_null_string(&mut self) -> bool {
248        if self.null_string.is_empty() {
249            // An empty NULL marker is supported. Look ahead to ensure that is followed by
250            // a column delimiter, an end of line or it is at the end of the data.
251            self.is_column_delimiter()
252                || self.is_end_of_line()
253                || self.is_end_of_copy_marker()
254                || self.is_eof()
255        } else {
256            self.consume_bytes(self.null_string.as_bytes())
257        }
258    }
259
260    pub fn consume_raw_value(&mut self) -> Result<Option<&[u8]>, io::Error> {
261        if self.consume_null_string() {
262            return Ok(None);
263        }
264
265        let mut start = self.position;
266
267        // buffer where unescaped data is accumulated
268        self.buffer.clear();
269
270        while !self.is_eof() && !self.is_end_of_copy_marker() {
271            if self.is_end_of_line() || self.is_column_delimiter() {
272                break;
273            }
274            match self.peek() {
275                Some(b'\\') => {
276                    // Add non-escaped data parsed so far
277                    self.buffer.extend(&self.data[start..self.position]);
278
279                    self.consume_n(1);
280                    match self.peek() {
281                        Some(b'b') => {
282                            self.consume_n(1);
283                            self.buffer.push(8);
284                        }
285                        Some(b'f') => {
286                            self.consume_n(1);
287                            self.buffer.push(12);
288                        }
289                        Some(b'n') => {
290                            self.consume_n(1);
291                            self.buffer.push(b'\n');
292                        }
293                        Some(b'r') => {
294                            self.consume_n(1);
295                            self.buffer.push(b'\r');
296                        }
297                        Some(b't') => {
298                            self.consume_n(1);
299                            self.buffer.push(b'\t');
300                        }
301                        Some(b'v') => {
302                            self.consume_n(1);
303                            self.buffer.push(11);
304                        }
305                        Some(b'x') => {
306                            self.consume_n(1);
307                            match self.peek() {
308                                Some(_c @ b'0'..=b'9')
309                                | Some(_c @ b'A'..=b'F')
310                                | Some(_c @ b'a'..=b'f') => {
311                                    let mut value: u8 = 0;
312                                    let decode_nibble = |b| match b {
313                                        Some(c @ b'a'..=b'f') => Some(c - b'a' + 10),
314                                        Some(c @ b'A'..=b'F') => Some(c - b'A' + 10),
315                                        Some(c @ b'0'..=b'9') => Some(c - b'0'),
316                                        _ => None,
317                                    };
318                                    for _ in 0..2 {
319                                        match decode_nibble(self.peek()) {
320                                            Some(c) => {
321                                                self.consume_n(1);
322                                                value = value << 4 | c;
323                                            }
324                                            _ => break,
325                                        }
326                                    }
327                                    self.buffer.push(value);
328                                }
329                                _ => {
330                                    self.buffer.push(b'x');
331                                }
332                            }
333                        }
334                        Some(_c @ b'0'..=b'7') => {
335                            let mut value: u8 = 0;
336                            for _ in 0..3 {
337                                match self.peek() {
338                                    Some(c @ b'0'..=b'7') => {
339                                        self.consume_n(1);
340                                        value = value << 3 | (c - b'0');
341                                    }
342                                    _ => break,
343                                }
344                            }
345                            self.buffer.push(value);
346                        }
347                        Some(c) => {
348                            self.consume_n(1);
349                            self.buffer.push(c);
350                        }
351                        None => {
352                            self.buffer.push(b'\\');
353                        }
354                    }
355
356                    start = self.position;
357                }
358                Some(_) => {
359                    self.consume_n(1);
360                }
361                None => {}
362            }
363        }
364
365        // Return a slice of the original buffer if no escaped characters where processed
366        if self.buffer.is_empty() {
367            Ok(Some(&self.data[start..self.position]))
368        } else {
369            // ... otherwise, add the remaining non-escaped data to the decoding buffer
370            // and return a pointer to it
371            self.buffer.extend(&self.data[start..self.position]);
372            Ok(Some(&self.buffer[..]))
373        }
374    }
375
376    /// Error if more than `num_columns` values in `parser`.
377    pub fn iter_raw(self, num_columns: usize) -> RawIterator<'a> {
378        RawIterator {
379            parser: self,
380            current_column: 0,
381            num_columns,
382            truncate: false,
383        }
384    }
385
386    /// Return no more than `num_columns` values from `parser`.
387    pub fn iter_raw_truncating(self, num_columns: usize) -> RawIterator<'a> {
388        RawIterator {
389            parser: self,
390            current_column: 0,
391            num_columns,
392            truncate: true,
393        }
394    }
395}
396
/// Pull-style cursor over the raw values of a single line, created by
/// `CopyTextFormatParser::iter_raw` or `CopyTextFormatParser::iter_raw_truncating`.
pub struct RawIterator<'a> {
    /// Parser positioned within the line being read.
    parser: CopyTextFormatParser<'a>,
    /// Progress counter: the number of values yielded so far. It is compared
    /// with `num_columns` to decide when the line is finished.
    current_column: usize,
    /// Number of values expected on the line.
    num_columns: usize,
    /// When true, the line is not required to end after `num_columns` values.
    truncate: bool,
}
403
404impl<'a> RawIterator<'a> {
405    pub fn next(&mut self) -> Option<Result<Option<&[u8]>, io::Error>> {
406        if self.current_column > self.num_columns {
407            return None;
408        }
409
410        if self.current_column == self.num_columns {
411            if !self.truncate {
412                if let Some(err) = self.parser.expect_end_of_line().err() {
413                    return Some(Err(err));
414                }
415            }
416
417            return None;
418        }
419
420        if self.current_column > 0 {
421            if let Some(err) = self.parser.expect_column_delimiter().err() {
422                return Some(Err(err));
423            }
424        }
425
426        self.current_column += 1;
427        Some(self.parser.consume_raw_value())
428    }
429}
430
/// Format-specific options for a `COPY` operation.
#[derive(Debug, Eq, PartialEq, Clone, Serialize, Deserialize)]
pub enum CopyFormatParams<'a> {
    /// The text format.
    Text(CopyTextFormatParams<'a>),
    /// The CSV format.
    Csv(CopyCsvFormatParams<'a>),
    /// The binary format; `encode_copy_format` supports it but
    /// `decode_copy_format` rejects it.
    Binary,
    /// The Parquet format; both `encode_copy_format` and `decode_copy_format`
    /// return an `Unsupported` error for it.
    Parquet,
}
438
439impl CopyFormatParams<'static> {
440    pub fn file_extension(&self) -> &str {
441        match self {
442            &CopyFormatParams::Text(_) => "txt",
443            &CopyFormatParams::Csv(_) => "csv",
444            &CopyFormatParams::Binary => "bin",
445            &CopyFormatParams::Parquet => "parquet",
446        }
447    }
448
449    pub fn requires_header(&self) -> bool {
450        match self {
451            CopyFormatParams::Text(_) => false,
452            CopyFormatParams::Csv(params) => params.header,
453            CopyFormatParams::Binary => false,
454            CopyFormatParams::Parquet => false,
455        }
456    }
457}
458
459/// Decodes the given bytes into `Row`-s based on the given `CopyFormatParams`.
460pub fn decode_copy_format<'a>(
461    data: &[u8],
462    column_types: &[mz_pgrepr::Type],
463    params: CopyFormatParams<'a>,
464) -> Result<Vec<Row>, io::Error> {
465    match params {
466        CopyFormatParams::Text(params) => decode_copy_format_text(data, column_types, params),
467        CopyFormatParams::Csv(params) => decode_copy_format_csv(data, column_types, params),
468        CopyFormatParams::Binary => Err(io::Error::new(
469            io::ErrorKind::Unsupported,
470            "cannot decode as binary format",
471        )),
472        CopyFormatParams::Parquet => {
473            // TODO(cf2): Support Parquet over STDIN.
474            Err(io::Error::new(io::ErrorKind::Unsupported, "parquet format"))
475        }
476    }
477}
478
479/// Encodes the given `Row` into bytes based on the given `CopyFormatParams`.
480pub fn encode_copy_format<'a>(
481    params: &CopyFormatParams<'a>,
482    row: &RowRef,
483    typ: &SqlRelationType,
484    out: &mut Vec<u8>,
485) -> Result<(), io::Error> {
486    match params {
487        CopyFormatParams::Text(params) => encode_copy_row_text(params, row, typ, out),
488        CopyFormatParams::Csv(params) => encode_copy_row_csv(params, row, typ, out),
489        CopyFormatParams::Binary => encode_copy_row_binary(row, typ, out),
490        CopyFormatParams::Parquet => {
491            // TODO(cf2): Support Parquet over STDIN.
492            Err(io::Error::new(io::ErrorKind::Unsupported, "parquet format"))
493        }
494    }
495}
496
497pub fn encode_copy_format_header<'a>(
498    params: &CopyFormatParams<'a>,
499    desc: &RelationDesc,
500    out: &mut Vec<u8>,
501) -> Result<(), io::Error> {
502    match params {
503        CopyFormatParams::Text(_) => Ok(()),
504        CopyFormatParams::Binary => Ok(()),
505        CopyFormatParams::Csv(params) => {
506            let mut header_row = Row::with_capacity(desc.arity());
507            header_row
508                .packer()
509                .extend(desc.iter_names().map(|s| Datum::from(s.as_str())));
510            let typ = SqlRelationType::new(vec![
511                SqlColumnType {
512                    scalar_type: SqlScalarType::String,
513                    nullable: false,
514                };
515                desc.arity()
516            ]);
517            encode_copy_row_csv(params, &header_row, &typ, out)
518        }
519        CopyFormatParams::Parquet => {
520            // TODO(cf2): Support Parquet over STDIN.
521            Err(io::Error::new(io::ErrorKind::Unsupported, "parquet format"))
522        }
523    }
524}
525
/// Options for the text `COPY` format.
#[derive(Debug, Eq, PartialEq, Clone, Serialize, Deserialize)]
pub struct CopyTextFormatParams<'a> {
    /// String written for, and recognized as, a `NULL` field.
    pub null: Cow<'a, str>,
    /// Byte that separates the fields of a line.
    pub delimiter: u8,
}
531
532impl<'a> Default for CopyTextFormatParams<'a> {
533    fn default() -> Self {
534        CopyTextFormatParams {
535            delimiter: b'\t',
536            null: Cow::from("\\N"),
537        }
538    }
539}
540
541pub fn decode_copy_format_text(
542    data: &[u8],
543    column_types: &[mz_pgrepr::Type],
544    CopyTextFormatParams { null, delimiter }: CopyTextFormatParams,
545) -> Result<Vec<Row>, io::Error> {
546    let mut rows = Vec::new();
547
548    // TODO: pass the `CopyTextFormatParams` to the `new` method
549    let mut parser = CopyTextFormatParser::new(data, delimiter, &null);
550    while !parser.is_eof() && !parser.is_end_of_copy_marker() {
551        let mut row = Vec::new();
552        let buf = RowArena::new();
553        for (col, typ) in column_types.iter().enumerate() {
554            if col > 0 {
555                parser.expect_column_delimiter()?;
556            }
557            let raw_value = parser.consume_raw_value()?;
558            if let Some(raw_value) = raw_value {
559                match mz_pgrepr::Value::decode_text(typ, raw_value) {
560                    Ok(value) => row.push(value.into_datum(&buf, typ)),
561                    Err(err) => {
562                        let msg = format!("unable to decode column: {}", err);
563                        return Err(io::Error::new(io::ErrorKind::InvalidData, msg));
564                    }
565                }
566            } else {
567                row.push(Datum::Null);
568            }
569        }
570        parser.expect_end_of_line()?;
571        rows.push(Row::pack(row));
572    }
573    // Note that if there is any junk data after the end of copy marker, we drop
574    // it on the floor as PG does.
575    Ok(rows)
576}
577
/// Options for the CSV `COPY` format.
#[derive(Debug, Eq, PartialEq, Clone, Serialize, Deserialize)]
pub struct CopyCsvFormatParams<'a> {
    /// Byte that separates the fields of a record.
    pub delimiter: u8,
    /// Byte used to quote fields; `try_new` rejects a quote equal to `delimiter`.
    pub quote: u8,
    /// Byte written before a quote or escape byte inside a quoted field. When
    /// it equals `quote`, quotes are escaped by doubling.
    pub escape: u8,
    /// Whether a header line of column names is written on output and
    /// skipped on input.
    pub header: bool,
    /// String written for, and recognized as, a `NULL` field.
    pub null: Cow<'a, str>,
}
586
587impl<'a> CopyCsvFormatParams<'a> {
588    pub fn to_owned(&self) -> CopyCsvFormatParams<'static> {
589        CopyCsvFormatParams {
590            delimiter: self.delimiter,
591            quote: self.quote,
592            escape: self.escape,
593            header: self.header,
594            null: Cow::Owned(self.null.to_string()),
595        }
596    }
597}
598
599impl Arbitrary for CopyCsvFormatParams<'static> {
600    type Parameters = ();
601    type Strategy = BoxedStrategy<Self>;
602
603    fn arbitrary_with(_args: Self::Parameters) -> Self::Strategy {
604        (
605            any::<u8>(),
606            any::<u8>(),
607            any::<u8>(),
608            any::<bool>(),
609            any::<String>(),
610        )
611            .prop_map(|(delimiter, diff, escape, header, null)| {
612                // Delimiter and Quote need to be different.
613                let diff = diff.saturating_sub(1).max(1);
614                let quote = delimiter.wrapping_add(diff);
615
616                Self::try_new(
617                    Some(delimiter),
618                    Some(quote),
619                    Some(escape),
620                    Some(header),
621                    Some(null),
622                )
623                .expect("delimiter and quote should be different")
624            })
625            .boxed()
626    }
627}
628
629impl<'a> Default for CopyCsvFormatParams<'a> {
630    fn default() -> Self {
631        CopyCsvFormatParams {
632            delimiter: b',',
633            quote: b'"',
634            escape: b'"',
635            header: false,
636            null: Cow::from(""),
637        }
638    }
639}
640
641impl<'a> CopyCsvFormatParams<'a> {
642    pub fn try_new(
643        delimiter: Option<u8>,
644        quote: Option<u8>,
645        escape: Option<u8>,
646        header: Option<bool>,
647        null: Option<String>,
648    ) -> Result<CopyCsvFormatParams<'a>, String> {
649        let mut params = CopyCsvFormatParams::default();
650
651        if let Some(delimiter) = delimiter {
652            params.delimiter = delimiter;
653        }
654        if let Some(quote) = quote {
655            params.quote = quote;
656            // escape defaults to the value provided for quote
657            params.escape = quote;
658        }
659        if let Some(escape) = escape {
660            params.escape = escape;
661        }
662        if let Some(header) = header {
663            params.header = header;
664        }
665        if let Some(null) = null {
666            params.null = Cow::from(null);
667        }
668
669        if params.quote == params.delimiter {
670            return Err("COPY delimiter and quote must be different".to_string());
671        }
672        Ok(params)
673    }
674}
675
676pub fn decode_copy_format_csv(
677    data: &[u8],
678    column_types: &[mz_pgrepr::Type],
679    CopyCsvFormatParams {
680        delimiter,
681        quote,
682        escape,
683        null,
684        header,
685    }: CopyCsvFormatParams,
686) -> Result<Vec<Row>, io::Error> {
687    let mut rows = Vec::new();
688
689    let (double_quote, escape) = if quote == escape {
690        (true, None)
691    } else {
692        (false, Some(escape))
693    };
694
695    let mut rdr = ReaderBuilder::new()
696        .delimiter(delimiter)
697        .quote(quote)
698        .has_headers(header)
699        .double_quote(double_quote)
700        .escape(escape)
701        // Must be flexible to accept end of copy marker, which will always be 1
702        // field.
703        .flexible(true)
704        .from_reader(data);
705
706    let null_as_bytes = null.as_bytes();
707
708    let mut record = ByteRecord::new();
709
710    while rdr.read_byte_record(&mut record)? {
711        if record.len() == 1 && record.iter().next() == Some(END_OF_COPY_MARKER) {
712            break;
713        }
714
715        match record.len().cmp(&column_types.len()) {
716            std::cmp::Ordering::Less => Err(io::Error::new(
717                io::ErrorKind::InvalidData,
718                "missing data for column",
719            )),
720            std::cmp::Ordering::Greater => Err(io::Error::new(
721                io::ErrorKind::InvalidData,
722                "extra data after last expected column",
723            )),
724            std::cmp::Ordering::Equal => Ok(()),
725        }?;
726
727        let mut row_builder = SharedRow::get();
728        let mut row_packer = row_builder.packer();
729
730        for (typ, raw_value) in column_types.iter().zip_eq(record.iter()) {
731            if raw_value == null_as_bytes {
732                row_packer.push(Datum::Null);
733            } else {
734                let s = match std::str::from_utf8(raw_value) {
735                    Ok(s) => s,
736                    Err(err) => {
737                        let msg = format!("invalid utf8 data in column: {}", err);
738                        return Err(io::Error::new(io::ErrorKind::InvalidData, msg));
739                    }
740                };
741                match mz_pgrepr::Value::decode_text_into_row(typ, s, &mut row_packer) {
742                    Ok(()) => {}
743                    Err(err) => {
744                        let msg = format!("unable to decode column: {}", err);
745                        return Err(io::Error::new(io::ErrorKind::InvalidData, msg));
746                    }
747                }
748            }
749        }
750        rows.push(row_builder.clone());
751    }
752
753    Ok(rows)
754}
755
756#[cfg(test)]
757mod tests {
758    use mz_ore::collections::CollectionExt;
759    use mz_repr::SqlColumnType;
760    use proptest::prelude::*;
761
762    use super::*;
763
764    #[mz_ore::test]
765    fn test_copy_format_text_parser() {
766        let text = "\t\\nt e\t\\N\t\n\\x60\\xA\\x7D\\x4a\n\\44\\044\\123".as_bytes();
767        let mut parser = CopyTextFormatParser::new(text, b'\t', "\\N");
768        assert!(parser.is_column_delimiter());
769        parser
770            .expect_column_delimiter()
771            .expect("expected column delimiter");
772        assert_eq!(
773            parser
774                .consume_raw_value()
775                .expect("unexpected error")
776                .expect("unexpected empty result"),
777            "\nt e".as_bytes()
778        );
779        parser
780            .expect_column_delimiter()
781            .expect("expected column delimiter");
782        // null value
783        assert!(
784            parser
785                .consume_raw_value()
786                .expect("unexpected error")
787                .is_none()
788        );
789        parser
790            .expect_column_delimiter()
791            .expect("expected column delimiter");
792        assert!(parser.is_end_of_line());
793        parser.expect_end_of_line().expect("expected eol");
794        // hex value
795        assert_eq!(
796            parser
797                .consume_raw_value()
798                .expect("unexpected error")
799                .expect("unexpected empty result"),
800            "`\n}J".as_bytes()
801        );
802        parser.expect_end_of_line().expect("expected eol");
803        // octal value
804        assert_eq!(
805            parser
806                .consume_raw_value()
807                .expect("unexpected error")
808                .expect("unexpected empty result"),
809            "$$S".as_bytes()
810        );
811        assert!(parser.is_eof());
812    }
813
814    #[mz_ore::test]
815    fn test_copy_format_text_empty_null_string() {
816        let text = "\t\n10\t20\n30\t\n40\t".as_bytes();
817        let expect = vec![
818            vec![None, None],
819            vec![Some("10"), Some("20")],
820            vec![Some("30"), None],
821            vec![Some("40"), None],
822        ];
823        let mut parser = CopyTextFormatParser::new(text, b'\t', "");
824        for line in expect {
825            for (i, value) in line.iter().enumerate() {
826                if i > 0 {
827                    parser
828                        .expect_column_delimiter()
829                        .expect("expected column delimiter");
830                }
831                match value {
832                    Some(s) => {
833                        assert!(!parser.consume_null_string());
834                        assert_eq!(
835                            parser
836                                .consume_raw_value()
837                                .expect("unexpected error")
838                                .expect("unexpected empty result"),
839                            s.as_bytes()
840                        );
841                    }
842                    None => {
843                        assert!(parser.consume_null_string());
844                    }
845                }
846            }
847            parser.expect_end_of_line().expect("expected eol");
848        }
849    }
850
851    #[mz_ore::test]
852    fn test_copy_format_text_parser_escapes() {
853        struct TestCase {
854            input: &'static str,
855            expect: &'static [u8],
856        }
857        let tests = vec![
858            TestCase {
859                input: "simple",
860                expect: b"simple",
861            },
862            TestCase {
863                input: r#"new\nline"#,
864                expect: b"new\nline",
865            },
866            TestCase {
867                input: r#"\b\f\n\r\t\v\\"#,
868                expect: b"\x08\x0c\n\r\t\x0b\\",
869            },
870            TestCase {
871                input: r#"\0\12\123"#,
872                expect: &[0, 0o12, 0o123],
873            },
874            TestCase {
875                input: r#"\x1\xaf"#,
876                expect: &[0x01, 0xaf],
877            },
878            TestCase {
879                input: r#"T\n\07\xEV\x0fA\xb2C\1"#,
880                expect: b"T\n\x07\x0eV\x0fA\xb2C\x01",
881            },
882            TestCase {
883                input: r#"\\\""#,
884                expect: b"\\\"",
885            },
886            TestCase {
887                input: r#"\x"#,
888                expect: b"x",
889            },
890            TestCase {
891                input: r#"\xg"#,
892                expect: b"xg",
893            },
894            TestCase {
895                input: r#"\"#,
896                expect: b"\\",
897            },
898            TestCase {
899                input: r#"\8"#,
900                expect: b"8",
901            },
902            TestCase {
903                input: r#"\a"#,
904                expect: b"a",
905            },
906            TestCase {
907                input: r#"\x\xg\8\xH\x32\s\"#,
908                expect: b"xxg8xH2s\\",
909            },
910        ];
911
912        for test in tests {
913            let mut parser = CopyTextFormatParser::new(test.input.as_bytes(), b'\t', "\\N");
914            assert_eq!(
915                parser
916                    .consume_raw_value()
917                    .expect("unexpected error")
918                    .expect("unexpected empty result"),
919                test.expect,
920                "input: {}, expect: {:?}",
921                test.input,
922                std::str::from_utf8(test.expect),
923            );
924            assert!(parser.is_eof());
925        }
926    }
927
928    #[mz_ore::test]
929    fn test_copy_csv_format_params() {
930        assert_eq!(
931            CopyCsvFormatParams::try_new(Some(b't'), Some(b'q'), None, None, None),
932            Ok(CopyCsvFormatParams {
933                delimiter: b't',
934                quote: b'q',
935                escape: b'q',
936                header: false,
937                null: Cow::from(""),
938            })
939        );
940
941        assert_eq!(
942            CopyCsvFormatParams::try_new(
943                Some(b't'),
944                Some(b'q'),
945                Some(b'e'),
946                Some(true),
947                Some("null".to_string())
948            ),
949            Ok(CopyCsvFormatParams {
950                delimiter: b't',
951                quote: b'q',
952                escape: b'e',
953                header: true,
954                null: Cow::from("null"),
955            })
956        );
957
958        assert_eq!(
959            CopyCsvFormatParams::try_new(
960                None,
961                Some(b','),
962                Some(b'e'),
963                Some(true),
964                Some("null".to_string())
965            ),
966            Err("COPY delimiter and quote must be different".to_string())
967        );
968    }
969
970    #[mz_ore::test]
971    fn test_copy_csv_row() -> Result<(), io::Error> {
972        let mut row = Row::default();
973        let mut packer = row.packer();
974        packer.push(Datum::from("1,2,\"3\""));
975        packer.push(Datum::Null);
976        packer.push(Datum::from(1000u64));
977        packer.push(Datum::from("qe")); // overridden quote and escape character in test below
978        packer.push(Datum::from(""));
979
980        let typ: SqlRelationType = SqlRelationType::new(vec![
981            SqlColumnType {
982                scalar_type: mz_repr::SqlScalarType::String,
983                nullable: false,
984            },
985            SqlColumnType {
986                scalar_type: mz_repr::SqlScalarType::String,
987                nullable: true,
988            },
989            SqlColumnType {
990                scalar_type: mz_repr::SqlScalarType::UInt64,
991                nullable: false,
992            },
993            SqlColumnType {
994                scalar_type: mz_repr::SqlScalarType::String,
995                nullable: false,
996            },
997            SqlColumnType {
998                scalar_type: mz_repr::SqlScalarType::String,
999                nullable: false,
1000            },
1001        ]);
1002
1003        let mut out = Vec::new();
1004
1005        struct TestCase<'a> {
1006            params: CopyCsvFormatParams<'a>,
1007            expected: &'static [u8],
1008        }
1009
1010        let tests = [
1011            TestCase {
1012                params: CopyCsvFormatParams::default(),
1013                expected: b"\"1,2,\"\"3\"\"\",,1000,qe,\"\"\n",
1014            },
1015            TestCase {
1016                params: CopyCsvFormatParams {
1017                    null: Cow::from("NULL"),
1018                    quote: b'q',
1019                    escape: b'e',
1020                    ..Default::default()
1021                },
1022                expected: b"q1,2,\"3\"q,NULL,1000,qeqeeq,\n",
1023            },
1024        ];
1025
1026        for TestCase { params, expected } in tests {
1027            out.clear();
1028            let params = CopyFormatParams::Csv(params);
1029            let _ = encode_copy_format(&params, &row, &typ, &mut out);
1030            let output = std::str::from_utf8(&out);
1031            assert_eq!(output, std::str::from_utf8(expected));
1032        }
1033
1034        Ok(())
1035    }
1036
1037    proptest! {
1038        #[mz_ore::test]
1039        #[cfg_attr(miri, ignore)]
1040        fn proptest_csv_roundtrips(copy_csv_params: CopyCsvFormatParams)  {
1041            // Given a SqlScalarType and Datum roundtrips it through the CSV COPY format.
1042            let try_roundtrip_datum = |scalar_type: &SqlScalarType, datum| {
1043                let row = Row::pack_slice(&[datum]);
1044                let typ = SqlRelationType::new(vec![
1045                    SqlColumnType {
1046                        scalar_type: scalar_type.clone(),
1047                        nullable: true,
1048                    }
1049                ]);
1050
1051                let mut buf = Vec::new();
1052                let mut csv_params = copy_csv_params.clone();
1053                // TODO: Encoding never writes a header.
1054                csv_params.header = false;
1055                let params = CopyFormatParams::Csv(csv_params);
1056
1057                // Roundtrip the Row through our CSV format.
1058                encode_copy_format(&params, &row, &typ, &mut buf)?;
1059                let column_types = typ
1060                    .column_types
1061                    .iter()
1062                    .map(|x| &x.scalar_type)
1063                    .map(mz_pgrepr::Type::from)
1064                    .collect::<Vec<mz_pgrepr::Type>>();
1065                let result = decode_copy_format(&buf, &column_types, params);
1066
1067                match result {
1068                    Ok(rows) => {
1069                        let out_str = std::str::from_utf8(&buf[..]);
1070
1071                        prop_assert_eq!(
1072                            rows.len(),
1073                            1,
1074                            "unexpected number of rows! {:?}, csv string: {:?}", rows, out_str
1075                        );
1076                        let output = rows.into_element();
1077
1078                        prop_assert_eq!(
1079                            row,
1080                            output,
1081                            "csv string: {:?}, scalar_type: {:?}", out_str, scalar_type
1082                        );
1083                    }
1084                    _ => {
1085                        // ignoring decoding failures
1086                    }
1087                }
1088
1089                Ok(())
1090            };
1091
1092            // Try roundtripping all of our interesting Datums.
1093            for scalar_type in SqlScalarType::enumerate() {
1094                for datum in scalar_type.interesting_datums() {
1095                    // TODO: The decoder cannot differentiate between empty string and null.
1096                    if let Some(value) = mz_pgrepr::Value::from_datum(datum, scalar_type) {
1097                        let mut buf = bytes::BytesMut::new();
1098                        value.encode_text(&mut buf);
1099
1100                        if let Ok(datum_str) = std::str::from_utf8(&buf[..]) {
1101                            if datum_str == copy_csv_params.null {
1102                                continue;
1103                            }
1104                        }
1105                    }
1106
1107                    let updated_datum = match datum {
1108                        // TODO: Fix roundtrip decoding of these types.
1109                        Datum::Timestamp(_) | Datum::TimestampTz(_) | Datum::Null => {
1110                            continue;
1111                        }
1112                        Datum::String(s) => {
1113                            // TODO: The decoder cannot differentiate between empty string and null.
1114                            if s.trim() == copy_csv_params.null || s.trim().is_empty() {
1115                                continue;
1116                            } else {
1117                                Datum::String(s)
1118                            }
1119                        }
1120                        other => other,
1121                    };
1122
1123                    let result = try_roundtrip_datum(scalar_type, updated_datum);
1124                    prop_assert!(result.is_ok(), "failure: {result:?}");
1125                }
1126            }
1127        }
1128    }
1129}