Skip to main content

mz_pgcopy/
copy.rs

1// Copyright Materialize, Inc. and contributors. All rights reserved.
2//
3// Use of this software is governed by the Business Source License
4// included in the LICENSE file.
5//
6// As of the Change Date specified in that file, in accordance with
7// the Business Source License, use of this software will be governed
8// by the Apache License, Version 2.0.
9
10use std::borrow::Cow;
11use std::io;
12
13use bytes::BytesMut;
14use csv::{ByteRecord, ReaderBuilder};
15use itertools::Itertools;
16use mz_repr::{
17    Datum, RelationDesc, Row, RowArena, RowRef, SharedRow, SqlColumnType, SqlRelationType,
18    SqlScalarType,
19};
20use proptest::prelude::{Arbitrary, any};
21use proptest::strategy::{BoxedStrategy, Strategy};
22use serde::{Deserialize, Serialize};
23
/// The two bytes `\.` that mark the end of COPY data. The encoders and
/// decoders in this file compare whole fields or the parser's current
/// position against it.
static END_OF_COPY_MARKER: &[u8] = br"\.";
25
26fn encode_copy_row_binary(
27    row: &RowRef,
28    typ: &SqlRelationType,
29    out: &mut Vec<u8>,
30) -> Result<(), io::Error> {
31    const NULL_BYTES: [u8; 4] = (-1i32).to_be_bytes();
32
33    // 16-bit int of number of tuples.
34    let count = i16::try_from(typ.column_types.len()).map_err(|_| {
35        io::Error::new(
36            io::ErrorKind::InvalidData,
37            "column count does not fit into an i16",
38        )
39    })?;
40
41    out.extend(count.to_be_bytes());
42    let mut buf = BytesMut::new();
43    for (field, typ) in row
44        .iter()
45        .zip_eq(&typ.column_types)
46        .map(|(datum, typ)| (mz_pgrepr::Value::from_datum(datum, &typ.scalar_type), typ))
47    {
48        match field {
49            None => out.extend(NULL_BYTES),
50            Some(field) => {
51                buf.clear();
52                field.encode_binary(&mz_pgrepr::Type::from(&typ.scalar_type), &mut buf)?;
53                out.extend(
54                    i32::try_from(buf.len())
55                        .map_err(|_| {
56                            io::Error::new(
57                                io::ErrorKind::InvalidData,
58                                "field length does not fit into an i32",
59                            )
60                        })?
61                        .to_be_bytes(),
62                );
63                out.extend(&buf);
64            }
65        }
66    }
67    Ok(())
68}
69
70fn encode_copy_row_text(
71    CopyTextFormatParams { null, delimiter }: &CopyTextFormatParams,
72    row: &RowRef,
73    typ: &SqlRelationType,
74    out: &mut Vec<u8>,
75) -> Result<(), io::Error> {
76    let null = null.as_bytes();
77    let mut buf = BytesMut::new();
78    for (idx, field) in mz_pgrepr::values_from_row(row, typ).into_iter().enumerate() {
79        if idx > 0 {
80            out.push(*delimiter);
81        }
82        match field {
83            None => out.extend(null),
84            Some(field) => {
85                buf.clear();
86                field.encode_text(&mut buf);
87                for b in &buf {
88                    match b {
89                        b'\\' => out.extend(b"\\\\"),
90                        b'\n' => out.extend(b"\\n"),
91                        b'\r' => out.extend(b"\\r"),
92                        b'\t' => out.extend(b"\\t"),
93                        _ => out.push(*b),
94                    }
95                }
96            }
97        }
98    }
99    out.push(b'\n');
100    Ok(())
101}
102
103fn encode_copy_row_csv(
104    CopyCsvFormatParams {
105        delimiter: delim,
106        quote,
107        escape,
108        header: _,
109        null,
110    }: &CopyCsvFormatParams,
111    row: &RowRef,
112    typ: &SqlRelationType,
113    out: &mut Vec<u8>,
114) -> Result<(), io::Error> {
115    let null = null.as_bytes();
116    let is_special = |c: &u8| *c == *delim || *c == *quote || *c == b'\r' || *c == b'\n';
117    let mut buf = BytesMut::new();
118    for (idx, field) in mz_pgrepr::values_from_row(row, typ).into_iter().enumerate() {
119        if idx > 0 {
120            out.push(*delim);
121        }
122        match field {
123            None => out.extend(null),
124            Some(field) => {
125                buf.clear();
126                field.encode_text(&mut buf);
127                // A field needs quoting if:
128                //   * It is the only field and the value is exactly the end
129                //     of copy marker.
130                //   * The field contains a special character.
131                //   * The field is exactly the NULL sentinel.
132                if (typ.column_types.len() == 1 && buf == END_OF_COPY_MARKER)
133                    || buf.iter().any(is_special)
134                    || &*buf == null
135                {
136                    // Quote the value by wrapping it in the quote character and
137                    // emitting the escape character before any quote or escape
138                    // characters within.
139                    out.push(*quote);
140                    for b in &buf {
141                        if *b == *quote || *b == *escape {
142                            out.push(*escape);
143                        }
144                        out.push(*b);
145                    }
146                    out.push(*quote);
147                } else {
148                    // The value does not need quoting and can be emitted
149                    // directly.
150                    out.extend(&buf);
151                }
152            }
153        }
154    }
155    out.push(b'\n');
156    Ok(())
157}
158
159pub struct CopyTextFormatParser<'a> {
160    data: &'a [u8],
161    position: usize,
162    column_delimiter: u8,
163    null_string: &'a str,
164    buffer: Vec<u8>,
165}
166
167impl<'a> CopyTextFormatParser<'a> {
168    pub fn new(data: &'a [u8], column_delimiter: u8, null_string: &'a str) -> Self {
169        Self {
170            data,
171            position: 0,
172            column_delimiter,
173            null_string,
174            buffer: Vec::new(),
175        }
176    }
177
178    fn peek(&self) -> Option<u8> {
179        if self.position < self.data.len() {
180            Some(self.data[self.position])
181        } else {
182            None
183        }
184    }
185
186    fn consume_n(&mut self, n: usize) {
187        self.position = std::cmp::min(self.position + n, self.data.len());
188    }
189
190    pub fn is_eof(&self) -> bool {
191        self.peek().is_none() || self.is_end_of_copy_marker()
192    }
193
194    pub fn is_end_of_copy_marker(&self) -> bool {
195        self.check_bytes(END_OF_COPY_MARKER)
196    }
197
198    fn is_end_of_line(&self) -> bool {
199        match self.peek() {
200            Some(b'\n') | None => true,
201            _ => false,
202        }
203    }
204
205    pub fn expect_end_of_line(&mut self) -> Result<(), io::Error> {
206        if self.is_end_of_line() {
207            self.consume_n(1);
208            Ok(())
209        } else {
210            Err(io::Error::new(
211                io::ErrorKind::InvalidData,
212                "extra data after last expected column",
213            ))
214        }
215    }
216
217    fn is_column_delimiter(&self) -> bool {
218        self.check_bytes(&[self.column_delimiter])
219    }
220
221    pub fn expect_column_delimiter(&mut self) -> Result<(), io::Error> {
222        if self.consume_bytes(&[self.column_delimiter]) {
223            Ok(())
224        } else {
225            Err(io::Error::new(
226                io::ErrorKind::InvalidData,
227                "missing data for column",
228            ))
229        }
230    }
231
232    fn check_bytes(&self, bytes: &[u8]) -> bool {
233        self.data
234            .get(self.position..self.position + bytes.len())
235            .map_or(false, |d| d == bytes)
236    }
237
238    fn consume_bytes(&mut self, bytes: &[u8]) -> bool {
239        if self.check_bytes(bytes) {
240            self.consume_n(bytes.len());
241            true
242        } else {
243            false
244        }
245    }
246
247    fn consume_null_string(&mut self) -> bool {
248        if self.null_string.is_empty() {
249            // An empty NULL marker is supported. Look ahead to ensure that is followed by
250            // a column delimiter, an end of line or it is at the end of the data.
251            self.is_column_delimiter()
252                || self.is_end_of_line()
253                || self.is_end_of_copy_marker()
254                || self.is_eof()
255        } else {
256            self.consume_bytes(self.null_string.as_bytes())
257        }
258    }
259
260    pub fn consume_raw_value(&mut self) -> Result<Option<&[u8]>, io::Error> {
261        if self.consume_null_string() {
262            return Ok(None);
263        }
264
265        let mut start = self.position;
266
267        // buffer where unescaped data is accumulated
268        self.buffer.clear();
269
270        while !self.is_eof() && !self.is_end_of_copy_marker() {
271            if self.is_end_of_line() || self.is_column_delimiter() {
272                break;
273            }
274            match self.peek() {
275                Some(b'\\') => {
276                    // Add non-escaped data parsed so far
277                    self.buffer.extend(&self.data[start..self.position]);
278
279                    self.consume_n(1);
280                    match self.peek() {
281                        Some(b'b') => {
282                            self.consume_n(1);
283                            self.buffer.push(8);
284                        }
285                        Some(b'f') => {
286                            self.consume_n(1);
287                            self.buffer.push(12);
288                        }
289                        Some(b'n') => {
290                            self.consume_n(1);
291                            self.buffer.push(b'\n');
292                        }
293                        Some(b'r') => {
294                            self.consume_n(1);
295                            self.buffer.push(b'\r');
296                        }
297                        Some(b't') => {
298                            self.consume_n(1);
299                            self.buffer.push(b'\t');
300                        }
301                        Some(b'v') => {
302                            self.consume_n(1);
303                            self.buffer.push(11);
304                        }
305                        Some(b'x') => {
306                            self.consume_n(1);
307                            match self.peek() {
308                                Some(_c @ b'0'..=b'9')
309                                | Some(_c @ b'A'..=b'F')
310                                | Some(_c @ b'a'..=b'f') => {
311                                    let mut value: u8 = 0;
312                                    let decode_nibble = |b| match b {
313                                        Some(c @ b'a'..=b'f') => Some(c - b'a' + 10),
314                                        Some(c @ b'A'..=b'F') => Some(c - b'A' + 10),
315                                        Some(c @ b'0'..=b'9') => Some(c - b'0'),
316                                        _ => None,
317                                    };
318                                    for _ in 0..2 {
319                                        match decode_nibble(self.peek()) {
320                                            Some(c) => {
321                                                self.consume_n(1);
322                                                value = value << 4 | c;
323                                            }
324                                            _ => break,
325                                        }
326                                    }
327                                    self.buffer.push(value);
328                                }
329                                _ => {
330                                    self.buffer.push(b'x');
331                                }
332                            }
333                        }
334                        Some(_c @ b'0'..=b'7') => {
335                            let mut value: u8 = 0;
336                            for _ in 0..3 {
337                                match self.peek() {
338                                    Some(c @ b'0'..=b'7') => {
339                                        self.consume_n(1);
340                                        value = value << 3 | (c - b'0');
341                                    }
342                                    _ => break,
343                                }
344                            }
345                            self.buffer.push(value);
346                        }
347                        Some(c) => {
348                            self.consume_n(1);
349                            self.buffer.push(c);
350                        }
351                        None => {
352                            self.buffer.push(b'\\');
353                        }
354                    }
355
356                    start = self.position;
357                }
358                Some(_) => {
359                    self.consume_n(1);
360                }
361                None => {}
362            }
363        }
364
365        // Return a slice of the original buffer if no escaped characters where processed
366        if self.buffer.is_empty() {
367            Ok(Some(&self.data[start..self.position]))
368        } else {
369            // ... otherwise, add the remaining non-escaped data to the decoding buffer
370            // and return a pointer to it
371            self.buffer.extend(&self.data[start..self.position]);
372            Ok(Some(&self.buffer[..]))
373        }
374    }
375
376    /// Error if more than `num_columns` values in `parser`.
377    pub fn iter_raw(self, num_columns: usize) -> RawIterator<'a> {
378        RawIterator {
379            parser: self,
380            current_column: 0,
381            num_columns,
382            truncate: false,
383        }
384    }
385
386    /// Return no more than `num_columns` values from `parser`.
387    pub fn iter_raw_truncating(self, num_columns: usize) -> RawIterator<'a> {
388        RawIterator {
389            parser: self,
390            current_column: 0,
391            num_columns,
392            truncate: true,
393        }
394    }
395}
396
397pub struct RawIterator<'a> {
398    parser: CopyTextFormatParser<'a>,
399    current_column: usize,
400    num_columns: usize,
401    truncate: bool,
402}
403
404impl<'a> RawIterator<'a> {
405    pub fn next(&mut self) -> Option<Result<Option<&[u8]>, io::Error>> {
406        if self.current_column > self.num_columns {
407            return None;
408        }
409
410        if self.current_column == self.num_columns {
411            if !self.truncate {
412                if let Some(err) = self.parser.expect_end_of_line().err() {
413                    return Some(Err(err));
414                }
415            }
416
417            return None;
418        }
419
420        if self.current_column > 0 {
421            if let Some(err) = self.parser.expect_column_delimiter().err() {
422                return Some(Err(err));
423            }
424        }
425
426        self.current_column += 1;
427        Some(self.parser.consume_raw_value())
428    }
429}
430
431#[derive(Debug, Eq, PartialEq, Clone, Serialize, Deserialize)]
432pub enum CopyFormatParams<'a> {
433    Text(CopyTextFormatParams<'a>),
434    Csv(CopyCsvFormatParams<'a>),
435    Binary,
436    Parquet,
437}
438
439impl CopyFormatParams<'static> {
440    pub fn file_extension(&self) -> &str {
441        match self {
442            &CopyFormatParams::Text(_) => "txt",
443            &CopyFormatParams::Csv(_) => "csv",
444            &CopyFormatParams::Binary => "bin",
445            &CopyFormatParams::Parquet => "parquet",
446        }
447    }
448
449    pub fn requires_header(&self) -> bool {
450        match self {
451            CopyFormatParams::Text(_) => false,
452            CopyFormatParams::Csv(params) => params.header,
453            CopyFormatParams::Binary => false,
454            CopyFormatParams::Parquet => false,
455        }
456    }
457}
458
459/// Decodes the given bytes into `Row`-s based on the given `CopyFormatParams`.
460pub fn decode_copy_format<'a>(
461    data: &[u8],
462    column_types: &[mz_pgrepr::Type],
463    params: CopyFormatParams<'a>,
464) -> Result<Vec<Row>, io::Error> {
465    match params {
466        CopyFormatParams::Text(params) => decode_copy_format_text(data, column_types, params),
467        CopyFormatParams::Csv(params) => decode_copy_format_csv(data, column_types, params),
468        CopyFormatParams::Binary => Err(io::Error::new(
469            io::ErrorKind::Unsupported,
470            "cannot decode as binary format",
471        )),
472        CopyFormatParams::Parquet => {
473            // TODO(cf2): Support Parquet over STDIN.
474            Err(io::Error::new(io::ErrorKind::Unsupported, "parquet format"))
475        }
476    }
477}
478
479/// Encodes the given `Row` into bytes based on the given `CopyFormatParams`.
480pub fn encode_copy_format<'a>(
481    params: &CopyFormatParams<'a>,
482    row: &RowRef,
483    typ: &SqlRelationType,
484    out: &mut Vec<u8>,
485) -> Result<(), io::Error> {
486    match params {
487        CopyFormatParams::Text(params) => encode_copy_row_text(params, row, typ, out),
488        CopyFormatParams::Csv(params) => encode_copy_row_csv(params, row, typ, out),
489        CopyFormatParams::Binary => encode_copy_row_binary(row, typ, out),
490        CopyFormatParams::Parquet => {
491            // TODO(cf2): Support Parquet over STDIN.
492            Err(io::Error::new(io::ErrorKind::Unsupported, "parquet format"))
493        }
494    }
495}
496
497pub fn encode_copy_format_header<'a>(
498    params: &CopyFormatParams<'a>,
499    desc: &RelationDesc,
500    out: &mut Vec<u8>,
501) -> Result<(), io::Error> {
502    match params {
503        CopyFormatParams::Text(_) => Ok(()),
504        CopyFormatParams::Binary => Ok(()),
505        CopyFormatParams::Csv(params) => {
506            let mut header_row = Row::with_capacity(desc.arity());
507            header_row
508                .packer()
509                .extend(desc.iter_names().map(|s| Datum::from(s.as_str())));
510            let typ = SqlRelationType::new(vec![
511                SqlColumnType {
512                    scalar_type: SqlScalarType::String,
513                    nullable: false,
514                };
515                desc.arity()
516            ]);
517            encode_copy_row_csv(params, &header_row, &typ, out)
518        }
519        CopyFormatParams::Parquet => {
520            // TODO(cf2): Support Parquet over STDIN.
521            Err(io::Error::new(io::ErrorKind::Unsupported, "parquet format"))
522        }
523    }
524}
525
526#[derive(Debug, Eq, PartialEq, Clone, Serialize, Deserialize)]
527pub struct CopyTextFormatParams<'a> {
528    pub null: Cow<'a, str>,
529    pub delimiter: u8,
530}
531
532impl<'a> Default for CopyTextFormatParams<'a> {
533    fn default() -> Self {
534        CopyTextFormatParams {
535            delimiter: b'\t',
536            null: Cow::from("\\N"),
537        }
538    }
539}
540
541pub fn decode_copy_format_text(
542    data: &[u8],
543    column_types: &[mz_pgrepr::Type],
544    CopyTextFormatParams { null, delimiter }: CopyTextFormatParams,
545) -> Result<Vec<Row>, io::Error> {
546    let mut rows = Vec::new();
547
548    // TODO: pass the `CopyTextFormatParams` to the `new` method
549    let mut parser = CopyTextFormatParser::new(data, delimiter, &null);
550    while !parser.is_eof() && !parser.is_end_of_copy_marker() {
551        let mut row = Vec::new();
552        let buf = RowArena::new();
553        for (col, typ) in column_types.iter().enumerate() {
554            if col > 0 {
555                parser.expect_column_delimiter()?;
556            }
557            let raw_value = parser.consume_raw_value()?;
558            if let Some(raw_value) = raw_value {
559                match mz_pgrepr::Value::decode_text(typ, raw_value) {
560                    Ok(value) => {
561                        row.push(
562                            value
563                                .into_datum_decode_error(&buf, typ, "column")
564                                .map_err(|msg| io::Error::new(io::ErrorKind::InvalidData, msg))?,
565                        );
566                    }
567                    Err(err) => {
568                        let msg = format!("unable to decode column: {}", err);
569                        return Err(io::Error::new(io::ErrorKind::InvalidData, msg));
570                    }
571                }
572            } else {
573                row.push(Datum::Null);
574            }
575        }
576        parser.expect_end_of_line()?;
577        rows.push(Row::pack(row));
578    }
579    // Note that if there is any junk data after the end of copy marker, we drop
580    // it on the floor as PG does.
581    Ok(rows)
582}
583
584#[derive(Debug, Eq, PartialEq, Clone, Serialize, Deserialize)]
585pub struct CopyCsvFormatParams<'a> {
586    pub delimiter: u8,
587    pub quote: u8,
588    pub escape: u8,
589    pub header: bool,
590    pub null: Cow<'a, str>,
591}
592
593impl<'a> CopyCsvFormatParams<'a> {
594    pub fn to_owned(&self) -> CopyCsvFormatParams<'static> {
595        CopyCsvFormatParams {
596            delimiter: self.delimiter,
597            quote: self.quote,
598            escape: self.escape,
599            header: self.header,
600            null: Cow::Owned(self.null.to_string()),
601        }
602    }
603}
604
605impl Arbitrary for CopyCsvFormatParams<'static> {
606    type Parameters = ();
607    type Strategy = BoxedStrategy<Self>;
608
609    fn arbitrary_with(_args: Self::Parameters) -> Self::Strategy {
610        (
611            any::<u8>(),
612            any::<u8>(),
613            any::<u8>(),
614            any::<bool>(),
615            any::<String>(),
616        )
617            .prop_map(|(delimiter, diff, escape, header, null)| {
618                // Delimiter and Quote need to be different.
619                let diff = diff.saturating_sub(1).max(1);
620                let quote = delimiter.wrapping_add(diff);
621
622                Self::try_new(
623                    Some(delimiter),
624                    Some(quote),
625                    Some(escape),
626                    Some(header),
627                    Some(null),
628                )
629                .expect("delimiter and quote should be different")
630            })
631            .boxed()
632    }
633}
634
635impl<'a> Default for CopyCsvFormatParams<'a> {
636    fn default() -> Self {
637        CopyCsvFormatParams {
638            delimiter: b',',
639            quote: b'"',
640            escape: b'"',
641            header: false,
642            null: Cow::from(""),
643        }
644    }
645}
646
647impl<'a> CopyCsvFormatParams<'a> {
648    pub fn try_new(
649        delimiter: Option<u8>,
650        quote: Option<u8>,
651        escape: Option<u8>,
652        header: Option<bool>,
653        null: Option<String>,
654    ) -> Result<CopyCsvFormatParams<'a>, String> {
655        let mut params = CopyCsvFormatParams::default();
656
657        if let Some(delimiter) = delimiter {
658            params.delimiter = delimiter;
659        }
660        if let Some(quote) = quote {
661            params.quote = quote;
662            // escape defaults to the value provided for quote
663            params.escape = quote;
664        }
665        if let Some(escape) = escape {
666            params.escape = escape;
667        }
668        if let Some(header) = header {
669            params.header = header;
670        }
671        if let Some(null) = null {
672            params.null = Cow::from(null);
673        }
674
675        if params.quote == params.delimiter {
676            return Err("COPY delimiter and quote must be different".to_string());
677        }
678        Ok(params)
679    }
680}
681
682pub fn decode_copy_format_csv(
683    data: &[u8],
684    column_types: &[mz_pgrepr::Type],
685    CopyCsvFormatParams {
686        delimiter,
687        quote,
688        escape,
689        null,
690        header,
691    }: CopyCsvFormatParams,
692) -> Result<Vec<Row>, io::Error> {
693    let mut rows = Vec::new();
694
695    let (double_quote, escape) = if quote == escape {
696        (true, None)
697    } else {
698        (false, Some(escape))
699    };
700
701    let mut rdr = ReaderBuilder::new()
702        .delimiter(delimiter)
703        .quote(quote)
704        .has_headers(header)
705        .double_quote(double_quote)
706        .escape(escape)
707        // Must be flexible to accept end of copy marker, which will always be 1
708        // field.
709        .flexible(true)
710        .from_reader(data);
711
712    let null_as_bytes = null.as_bytes();
713
714    let mut record = ByteRecord::new();
715
716    while rdr.read_byte_record(&mut record)? {
717        if record.len() == 1 && record.iter().next() == Some(END_OF_COPY_MARKER) {
718            break;
719        }
720
721        match record.len().cmp(&column_types.len()) {
722            std::cmp::Ordering::Less => Err(io::Error::new(
723                io::ErrorKind::InvalidData,
724                "missing data for column",
725            )),
726            std::cmp::Ordering::Greater => Err(io::Error::new(
727                io::ErrorKind::InvalidData,
728                "extra data after last expected column",
729            )),
730            std::cmp::Ordering::Equal => Ok(()),
731        }?;
732
733        let mut row_builder = SharedRow::get();
734        let mut row_packer = row_builder.packer();
735
736        for (typ, raw_value) in column_types.iter().zip_eq(record.iter()) {
737            if raw_value == null_as_bytes {
738                row_packer.push(Datum::Null);
739            } else {
740                let s = match std::str::from_utf8(raw_value) {
741                    Ok(s) => s,
742                    Err(err) => {
743                        let msg = format!("invalid utf8 data in column: {}", err);
744                        return Err(io::Error::new(io::ErrorKind::InvalidData, msg));
745                    }
746                };
747                match mz_pgrepr::Value::decode_text_into_row(typ, s, &mut row_packer) {
748                    Ok(()) => {}
749                    Err(err) => {
750                        let msg = format!("unable to decode column: {}", err);
751                        return Err(io::Error::new(io::ErrorKind::InvalidData, msg));
752                    }
753                }
754            }
755        }
756        rows.push(row_builder.clone());
757    }
758
759    Ok(rows)
760}
761
762#[cfg(test)]
763mod tests {
764    use mz_ore::collections::CollectionExt;
765    use mz_repr::SqlColumnType;
766    use proptest::prelude::*;
767
768    use super::*;
769
770    #[mz_ore::test]
771    fn test_copy_format_text_parser() {
772        let text = "\t\\nt e\t\\N\t\n\\x60\\xA\\x7D\\x4a\n\\44\\044\\123".as_bytes();
773        let mut parser = CopyTextFormatParser::new(text, b'\t', "\\N");
774        assert!(parser.is_column_delimiter());
775        parser
776            .expect_column_delimiter()
777            .expect("expected column delimiter");
778        assert_eq!(
779            parser
780                .consume_raw_value()
781                .expect("unexpected error")
782                .expect("unexpected empty result"),
783            "\nt e".as_bytes()
784        );
785        parser
786            .expect_column_delimiter()
787            .expect("expected column delimiter");
788        // null value
789        assert!(
790            parser
791                .consume_raw_value()
792                .expect("unexpected error")
793                .is_none()
794        );
795        parser
796            .expect_column_delimiter()
797            .expect("expected column delimiter");
798        assert!(parser.is_end_of_line());
799        parser.expect_end_of_line().expect("expected eol");
800        // hex value
801        assert_eq!(
802            parser
803                .consume_raw_value()
804                .expect("unexpected error")
805                .expect("unexpected empty result"),
806            "`\n}J".as_bytes()
807        );
808        parser.expect_end_of_line().expect("expected eol");
809        // octal value
810        assert_eq!(
811            parser
812                .consume_raw_value()
813                .expect("unexpected error")
814                .expect("unexpected empty result"),
815            "$$S".as_bytes()
816        );
817        assert!(parser.is_eof());
818    }
819
820    #[mz_ore::test]
821    fn test_copy_format_text_empty_null_string() {
822        let text = "\t\n10\t20\n30\t\n40\t".as_bytes();
823        let expect = vec![
824            vec![None, None],
825            vec![Some("10"), Some("20")],
826            vec![Some("30"), None],
827            vec![Some("40"), None],
828        ];
829        let mut parser = CopyTextFormatParser::new(text, b'\t', "");
830        for line in expect {
831            for (i, value) in line.iter().enumerate() {
832                if i > 0 {
833                    parser
834                        .expect_column_delimiter()
835                        .expect("expected column delimiter");
836                }
837                match value {
838                    Some(s) => {
839                        assert!(!parser.consume_null_string());
840                        assert_eq!(
841                            parser
842                                .consume_raw_value()
843                                .expect("unexpected error")
844                                .expect("unexpected empty result"),
845                            s.as_bytes()
846                        );
847                    }
848                    None => {
849                        assert!(parser.consume_null_string());
850                    }
851                }
852            }
853            parser.expect_end_of_line().expect("expected eol");
854        }
855    }
856
857    #[mz_ore::test]
858    fn test_copy_format_text_parser_escapes() {
859        struct TestCase {
860            input: &'static str,
861            expect: &'static [u8],
862        }
863        let tests = vec![
864            TestCase {
865                input: "simple",
866                expect: b"simple",
867            },
868            TestCase {
869                input: r#"new\nline"#,
870                expect: b"new\nline",
871            },
872            TestCase {
873                input: r#"\b\f\n\r\t\v\\"#,
874                expect: b"\x08\x0c\n\r\t\x0b\\",
875            },
876            TestCase {
877                input: r#"\0\12\123"#,
878                expect: &[0, 0o12, 0o123],
879            },
880            TestCase {
881                input: r#"\x1\xaf"#,
882                expect: &[0x01, 0xaf],
883            },
884            TestCase {
885                input: r#"T\n\07\xEV\x0fA\xb2C\1"#,
886                expect: b"T\n\x07\x0eV\x0fA\xb2C\x01",
887            },
888            TestCase {
889                input: r#"\\\""#,
890                expect: b"\\\"",
891            },
892            TestCase {
893                input: r#"\x"#,
894                expect: b"x",
895            },
896            TestCase {
897                input: r#"\xg"#,
898                expect: b"xg",
899            },
900            TestCase {
901                input: r#"\"#,
902                expect: b"\\",
903            },
904            TestCase {
905                input: r#"\8"#,
906                expect: b"8",
907            },
908            TestCase {
909                input: r#"\a"#,
910                expect: b"a",
911            },
912            TestCase {
913                input: r#"\x\xg\8\xH\x32\s\"#,
914                expect: b"xxg8xH2s\\",
915            },
916        ];
917
918        for test in tests {
919            let mut parser = CopyTextFormatParser::new(test.input.as_bytes(), b'\t', "\\N");
920            assert_eq!(
921                parser
922                    .consume_raw_value()
923                    .expect("unexpected error")
924                    .expect("unexpected empty result"),
925                test.expect,
926                "input: {}, expect: {:?}",
927                test.input,
928                std::str::from_utf8(test.expect),
929            );
930            assert!(parser.is_eof());
931        }
932    }
933
934    #[mz_ore::test]
935    fn test_copy_csv_format_params() {
936        assert_eq!(
937            CopyCsvFormatParams::try_new(Some(b't'), Some(b'q'), None, None, None),
938            Ok(CopyCsvFormatParams {
939                delimiter: b't',
940                quote: b'q',
941                escape: b'q',
942                header: false,
943                null: Cow::from(""),
944            })
945        );
946
947        assert_eq!(
948            CopyCsvFormatParams::try_new(
949                Some(b't'),
950                Some(b'q'),
951                Some(b'e'),
952                Some(true),
953                Some("null".to_string())
954            ),
955            Ok(CopyCsvFormatParams {
956                delimiter: b't',
957                quote: b'q',
958                escape: b'e',
959                header: true,
960                null: Cow::from("null"),
961            })
962        );
963
964        assert_eq!(
965            CopyCsvFormatParams::try_new(
966                None,
967                Some(b','),
968                Some(b'e'),
969                Some(true),
970                Some("null".to_string())
971            ),
972            Err("COPY delimiter and quote must be different".to_string())
973        );
974    }
975
976    #[mz_ore::test]
977    fn test_copy_csv_row() -> Result<(), io::Error> {
978        let mut row = Row::default();
979        let mut packer = row.packer();
980        packer.push(Datum::from("1,2,\"3\""));
981        packer.push(Datum::Null);
982        packer.push(Datum::from(1000u64));
983        packer.push(Datum::from("qe")); // overridden quote and escape character in test below
984        packer.push(Datum::from(""));
985
986        let typ: SqlRelationType = SqlRelationType::new(vec![
987            SqlColumnType {
988                scalar_type: mz_repr::SqlScalarType::String,
989                nullable: false,
990            },
991            SqlColumnType {
992                scalar_type: mz_repr::SqlScalarType::String,
993                nullable: true,
994            },
995            SqlColumnType {
996                scalar_type: mz_repr::SqlScalarType::UInt64,
997                nullable: false,
998            },
999            SqlColumnType {
1000                scalar_type: mz_repr::SqlScalarType::String,
1001                nullable: false,
1002            },
1003            SqlColumnType {
1004                scalar_type: mz_repr::SqlScalarType::String,
1005                nullable: false,
1006            },
1007        ]);
1008
1009        let mut out = Vec::new();
1010
1011        struct TestCase<'a> {
1012            params: CopyCsvFormatParams<'a>,
1013            expected: &'static [u8],
1014        }
1015
1016        let tests = [
1017            TestCase {
1018                params: CopyCsvFormatParams::default(),
1019                expected: b"\"1,2,\"\"3\"\"\",,1000,qe,\"\"\n",
1020            },
1021            TestCase {
1022                params: CopyCsvFormatParams {
1023                    null: Cow::from("NULL"),
1024                    quote: b'q',
1025                    escape: b'e',
1026                    ..Default::default()
1027                },
1028                expected: b"q1,2,\"3\"q,NULL,1000,qeqeeq,\n",
1029            },
1030        ];
1031
1032        for TestCase { params, expected } in tests {
1033            out.clear();
1034            let params = CopyFormatParams::Csv(params);
1035            let _ = encode_copy_format(&params, &row, &typ, &mut out);
1036            let output = std::str::from_utf8(&out);
1037            assert_eq!(output, std::str::from_utf8(expected));
1038        }
1039
1040        Ok(())
1041    }
1042
1043    proptest! {
1044        #[mz_ore::test]
1045        #[cfg_attr(miri, ignore)]
1046        fn proptest_csv_roundtrips(copy_csv_params: CopyCsvFormatParams)  {
1047            // Given a SqlScalarType and Datum roundtrips it through the CSV COPY format.
1048            let try_roundtrip_datum = |scalar_type: &SqlScalarType, datum| {
1049                let row = Row::pack_slice(&[datum]);
1050                let typ = SqlRelationType::new(vec![
1051                    SqlColumnType {
1052                        scalar_type: scalar_type.clone(),
1053                        nullable: true,
1054                    }
1055                ]);
1056
1057                let mut buf = Vec::new();
1058                let mut csv_params = copy_csv_params.clone();
1059                // TODO: Encoding never writes a header.
1060                csv_params.header = false;
1061                let params = CopyFormatParams::Csv(csv_params);
1062
1063                // Roundtrip the Row through our CSV format.
1064                encode_copy_format(&params, &row, &typ, &mut buf)?;
1065                let column_types = typ
1066                    .column_types
1067                    .iter()
1068                    .map(|x| &x.scalar_type)
1069                    .map(mz_pgrepr::Type::from)
1070                    .collect::<Vec<mz_pgrepr::Type>>();
1071                let result = decode_copy_format(&buf, &column_types, params);
1072
1073                match result {
1074                    Ok(rows) => {
1075                        let out_str = std::str::from_utf8(&buf[..]);
1076
1077                        prop_assert_eq!(
1078                            rows.len(),
1079                            1,
1080                            "unexpected number of rows! {:?}, csv string: {:?}", rows, out_str
1081                        );
1082                        let output = rows.into_element();
1083
1084                        prop_assert_eq!(
1085                            row,
1086                            output,
1087                            "csv string: {:?}, scalar_type: {:?}", out_str, scalar_type
1088                        );
1089                    }
1090                    _ => {
1091                        // ignoring decoding failures
1092                    }
1093                }
1094
1095                Ok(())
1096            };
1097
1098            // Try roundtripping all of our interesting Datums.
1099            for scalar_type in SqlScalarType::enumerate() {
1100                for datum in scalar_type.interesting_datums() {
1101                    // TODO: The decoder cannot differentiate between empty string and null.
1102                    if let Some(value) = mz_pgrepr::Value::from_datum(datum, scalar_type) {
1103                        let mut buf = bytes::BytesMut::new();
1104                        value.encode_text(&mut buf);
1105
1106                        if let Ok(datum_str) = std::str::from_utf8(&buf[..]) {
1107                            if datum_str == copy_csv_params.null {
1108                                continue;
1109                            }
1110                        }
1111                    }
1112
1113                    let updated_datum = match datum {
1114                        // TODO: Fix roundtrip decoding of these types.
1115                        Datum::Timestamp(_) | Datum::TimestampTz(_) | Datum::Null => {
1116                            continue;
1117                        }
1118                        Datum::String(s) => {
1119                            // TODO: The decoder cannot differentiate between empty string and null.
1120                            if s.trim() == copy_csv_params.null || s.trim().is_empty() {
1121                                continue;
1122                            } else {
1123                                Datum::String(s)
1124                            }
1125                        }
1126                        other => other,
1127                    };
1128
1129                    let result = try_roundtrip_datum(scalar_type, updated_datum);
1130                    prop_assert!(result.is_ok(), "failure: {result:?}");
1131                }
1132            }
1133        }
1134    }
1135}