Skip to main content

mz_testdrive/action/
sql.rs

1// Copyright Materialize, Inc. and contributors. All rights reserved.
2//
3// Use of this software is governed by the Business Source License
4// included in the LICENSE file.
5//
6// As of the Change Date specified in that file, in accordance with
7// the Business Source License, use of this software will be governed
8// by the Apache License, Version 2.0.
9
10use std::ascii;
11use std::error::Error;
12use std::fmt::{self, Display, Formatter, Write as _};
13use std::time::SystemTime;
14
15use anyhow::{Context, bail};
16use itertools::Itertools;
17use md5::{Digest, Md5};
18use mz_ore::collections::CollectionExt;
19use mz_ore::retry::Retry;
20use mz_ore::str::StrExt;
21use mz_pgrepr::{Interval, Jsonb, Numeric, UInt2, UInt4, UInt8};
22use mz_postgres_util::query_prepared;
23use mz_repr::adt::range::Range;
24use mz_sql_parser::ast::{Raw, Statement};
25use postgres_array::Array;
26use regex::Regex;
27use tokio_postgres::error::DbError;
28use tokio_postgres::row::Row;
29use tokio_postgres::types::{FromSql, Type};
30
31use crate::action::{ControlFlow, Rewrite, State};
32use crate::parser::{FailSqlCommand, SqlCommand, SqlExpectedError, SqlOutput};
33
34pub async fn run_sql(mut cmd: SqlCommand, state: &mut State) -> Result<ControlFlow, anyhow::Error> {
35    use Statement::*;
36
37    state.rewrite_pos_start = cmd.expected_start;
38    state.rewrite_pos_end = cmd.expected_end;
39
40    let stmts = mz_sql_parser::parser::parse_statements(&cmd.query)
41        .with_context(|| format!("unable to parse SQL: {}", cmd.query))?;
42    if stmts.len() != 1 {
43        bail!("expected one statement, but got {}", stmts.len());
44    }
45    let stmt = stmts.into_element().ast;
46    if let SqlOutput::Full { expected_rows, .. } = &mut cmd.expected_output {
47        // TODO(benesch): one day we'll support SQL queries where order matters.
48        expected_rows.sort();
49    }
50
51    let should_retry = match &stmt {
52        // Do not retry FETCH statements as subsequent executions are likely
53        // to return an empty result. The original result would thus be lost.
54        Fetch(_) => false,
55        // EXPLAIN ... PLAN statements should always provide the expected result
56        // on the first try
57        ExplainPlan(_) => false,
58        // DDL statements should always provide the expected result on the first try
59        CreateConnection(_)
60        | CreateCluster(_)
61        | CreateClusterReplica(_)
62        | CreateDatabase(_)
63        | CreateSchema(_)
64        | CreateSource(_)
65        | CreateWebhookSource(_)
66        | CreateSink(_)
67        | CreateMaterializedView(_)
68        | CreateView(_)
69        | CreateTable(_)
70        | CreateTableFromSource(_)
71        | CreateIndex(_)
72        | CreateType(_)
73        | CreateRole(_)
74        | AlterObjectRename(_)
75        | AlterIndex(_)
76        | AlterSink(_)
77        | Discard(_)
78        | DropObjects(_)
79        | SetVariable(_) => false,
80        _ => true,
81    };
82
83    let query = &cmd.query;
84    print_query(query, Some(&stmt));
85    let expected_output = &cmd.expected_output;
86    state.error_line_count = 0;
87    state.error_string = "".to_string();
88    let (state, res) = match should_retry {
89        true => Retry::default()
90            .initial_backoff(state.initial_backoff)
91            .factor(state.backoff_factor)
92            .max_duration(state.timeout)
93            .max_tries(state.max_tries),
94        false => Retry::default().max_duration(state.timeout).max_tries(1),
95    }
96    .retry_async_with_state(state, |retry_state, state| async move {
97        let should_continue = retry_state.i + 1 < state.max_tries && should_retry;
98        let start = SystemTime::now();
99        match try_run_sql(state, query, expected_output, should_continue).await {
100            Ok(()) => {
101                let now = SystemTime::now();
102                let epoch = SystemTime::UNIX_EPOCH;
103                let ts = now.duration_since(epoch).unwrap().as_secs_f64();
104                let delay = now.duration_since(start).unwrap().as_secs_f64();
105                println!("rows match; continuing at ts {ts}, took {delay}s");
106                (state, Ok(()))
107            }
108            Err(e) => {
109                if let Some(backoff) = retry_state.next_backoff {
110                    if !backoff.is_zero()
111                        && (retry_state.i == 0 || !mz_ore::env::is_var_truthy("CI"))
112                    {
113                        let error_string = format!("{:?}", e);
114                        if error_string != state.error_string {
115                            // Remove old status lines so as not to spam the output
116                            for _ in 0..state.error_line_count {
117                                print!("\x1B[1A\x1B[2K");
118                            }
119                            state.error_line_count = 0;
120                            state.error_string = error_string;
121                            state.error_line_count = state.error_string.lines().count() + 1;
122                            if state.error_string.ends_with('\n') {
123                                print!("{}", state.error_string);
124                            } else {
125                                println!("{}", state.error_string);
126                                state.error_line_count += 1;
127                            }
128                            println!(
129                                "rows didn't match; sleeping to see if dataflow catches up 🕑 {:.0?}",
130                                retry_state.next_backoff.unwrap_or_default()
131                            );
132                        }
133                    }
134                } else {
135                    for _ in 0..state.error_line_count {
136                        print!("\x1B[1A\x1B[2K");
137                    }
138                    state.error_line_count = 0;
139                }
140                (state, Err(e))
141            }
142        }
143    })
144    .await;
145    if let Err(e) = res {
146        return Err(e);
147    }
148    if state.consistency_checks == super::consistency::Level::Statement {
149        super::consistency::run_consistency_checks(state).await?;
150    }
151
152    Ok(ControlFlow::Continue)
153}
154
/// Wraps `value` in double quotes and escapes it so that `parser::split_line`
/// reads it back unchanged.
///
/// Only the escapes `split_line` understands are emitted: `\\`, `\"`, `\n`,
/// `\t`, `\r` and `\0`. Every other character is copied verbatim.
///
/// `format!("{:?}", …)` is deliberately avoided. Rust's `Debug` output also
/// uses `\u{N}` escapes for non-printable Unicode, which the parser does not
/// decode, so the value would be silently corrupted on the next read.
fn escape_for_parser(value: &str) -> String {
    const QUOTE: char = '"';

    let mut escaped = String::with_capacity(value.len() + 2);
    escaped.push(QUOTE);
    for ch in value.chars() {
        let replacement = match ch {
            '\\' => Some("\\\\"),
            '"' => Some("\\\""),
            '\n' => Some("\\n"),
            '\t' => Some("\\t"),
            '\r' => Some("\\r"),
            '\0' => Some("\\0"),
            _ => None,
        };
        match replacement {
            Some(sequence) => escaped.push_str(sequence),
            None => escaped.push(ch),
        }
    }
    escaped.push(QUOTE);
    escaped
}
178
179fn rewrite_result(
180    state: &mut State,
181    columns: Vec<&str>,
182    content: Vec<Vec<String>>,
183) -> Result<(), anyhow::Error> {
184    let mut buf = String::new();
185    writeln!(buf, "{}", columns.join(" "))?;
186    writeln!(buf, "----")?;
187    for row in content {
188        let mut formatted_row = Vec::<String>::new();
189        for value in row {
190            if value.is_empty()
191                || value.contains(|x: char| char::is_ascii_whitespace(&x))
192                || value.contains('"')
193                || value.contains('\\')
194            {
195                formatted_row.push(escape_for_parser(&value));
196            } else {
197                formatted_row.push(value);
198            }
199        }
200        writeln!(buf, "{}", formatted_row.join(" "))?;
201    }
202    state.rewrites.push(Rewrite {
203        content: buf,
204        start: state.rewrite_pos_start,
205        end: state.rewrite_pos_end,
206    });
207
208    Ok(())
209}
210
211async fn try_run_sql(
212    state: &mut State,
213    query: &str,
214    expected_output: &SqlOutput,
215    should_retry: bool,
216) -> Result<(), anyhow::Error> {
217    let stmt = state
218        .materialize
219        .pgclient
220        .prepare(query)
221        .await
222        .context("preparing query failed")?;
223
224    let query_with_timeout = tokio::time::timeout(
225        state.timeout.clone(),
226        query_prepared(&state.materialize.pgclient, &stmt, &[]),
227    )
228    .await;
229
230    if query_with_timeout.is_err() {
231        bail!("query timed out\n")
232    }
233
234    let rows: Vec<_> = query_with_timeout
235        .unwrap()
236        .context("executing query failed")?
237        .into_iter()
238        .map(|row| decode_row(state, row))
239        .collect::<Result<_, _>>()?;
240
241    let (mut actual, raw_actual): (Vec<_>, Vec<_>) = rows.into_iter().unzip();
242
243    let raw_actual: Option<Vec<_>> = if raw_actual.iter().any(|r| r.is_some()) {
244        // TODO(guswynn): Note we don't sort the raw rows, because
245        // there is no easy way of ensuring they sort the same way as actual.
246        Some(
247            actual
248                .iter()
249                .zip_eq(raw_actual)
250                .map(|(actual, unreplaced)| match unreplaced {
251                    Some(raw_row) => raw_row,
252                    None => actual.clone(),
253                })
254                .collect(),
255        )
256    } else {
257        None
258    };
259
260    actual.sort();
261    let actual_columns: Vec<_> = stmt.columns().iter().map(|c| c.name()).collect();
262
263    match expected_output {
264        SqlOutput::Full {
265            expected_rows,
266            column_names,
267        } => {
268            if let Some(column_names) = column_names {
269                if actual_columns.iter().ne(column_names) {
270                    if state.rewrite_results && !should_retry {
271                        rewrite_result(state, actual_columns, actual)?;
272                        return Ok(());
273                    } else {
274                        bail!(
275                            "column name mismatch\nexpected: {:?}\nactual:   {:?}\n",
276                            column_names,
277                            actual_columns
278                        );
279                    }
280                }
281            }
282            if &actual == expected_rows {
283                Ok(())
284            } else if state.rewrite_results && !should_retry {
285                rewrite_result(state, actual_columns, actual)?;
286                Ok(())
287            } else {
288                let (mut left, mut right) = (0, 0);
289                let mut buf = String::new();
290                while let (Some(e), Some(a)) = (expected_rows.get(left), actual.get(right)) {
291                    match e.cmp(a) {
292                        std::cmp::Ordering::Less => {
293                            writeln!(buf, "- {}", TestdriveRow(e)).unwrap();
294                            left += 1;
295                        }
296                        std::cmp::Ordering::Equal => {
297                            left += 1;
298                            right += 1;
299                        }
300                        std::cmp::Ordering::Greater => {
301                            writeln!(buf, "+ {}", TestdriveRow(a)).unwrap();
302                            right += 1;
303                        }
304                    }
305                }
306                while let Some(e) = expected_rows.get(left) {
307                    writeln!(buf, "- {}", TestdriveRow(e)).unwrap();
308                    left += 1;
309                }
310                while let Some(a) = actual.get(right) {
311                    writeln!(buf, "+ {}", TestdriveRow(a)).unwrap();
312                    right += 1;
313                }
314                if state.rewrite_results && !should_retry {
315                    rewrite_result(state, actual_columns, actual)?;
316                    Ok(())
317                } else if let Some(raw_actual) = raw_actual {
318                    bail!(
319                        "non-matching rows: expected:\n{:?}\ngot:\n{:?}\ngot raw rows:\n{:?}\nPoor diff:\n{}",
320                        expected_rows,
321                        actual,
322                        raw_actual,
323                        buf,
324                    )
325                } else {
326                    bail!(
327                        "non-matching rows: expected:\n{:?}\ngot:\n{:?}\nPoor diff:\n{}",
328                        expected_rows,
329                        actual,
330                        buf
331                    )
332                }
333            }
334        }
335        SqlOutput::Hashed { num_values, md5 } => {
336            if &actual.len() != num_values {
337                bail!(
338                    "wrong row count: expected:\n{:?}\ngot:\n{:?}\n",
339                    num_values,
340                    actual.len(),
341                )
342            } else {
343                let mut hasher = Md5::new();
344                for row in &actual {
345                    for entry in row {
346                        hasher.update(entry);
347                    }
348                }
349                let actual = format!("{:x}", hasher.finalize());
350                if &actual != md5 {
351                    bail!("wrong hash value: expected:{:?} got:{:?}", md5, actual)
352                } else {
353                    Ok(())
354                }
355            }
356        }
357    }
358}
359
360#[derive(Clone)]
361enum ErrorMatcher {
362    Contains(String),
363    Exact(String),
364    Regex(Regex),
365    Timeout,
366}
367
368impl ErrorMatcher {
369    fn is_match(&self, err: &String) -> bool {
370        match self {
371            ErrorMatcher::Contains(s) => err.contains(s),
372            ErrorMatcher::Exact(s) => err == s,
373            ErrorMatcher::Regex(r) => r.is_match(err),
374            // Timeouts never match errors directly. If we are matching an error
375            // message, it means the query returned a result (i.e., an error
376            // result), which means the query did not time out as expected.
377            ErrorMatcher::Timeout => false,
378        }
379    }
380}
381
382impl fmt::Display for ErrorMatcher {
383    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
384        match self {
385            ErrorMatcher::Contains(s) => write!(f, "error containing {}", s.quoted()),
386            ErrorMatcher::Exact(s) => write!(f, "exact error {}", s.quoted()),
387            ErrorMatcher::Regex(s) => write!(f, "error matching regex {}", s.as_str().quoted()),
388            ErrorMatcher::Timeout => f.write_str("timeout"),
389        }
390    }
391}
392
393impl ErrorMatcher {
394    fn fmt_with_type(&self, type_: &str) -> String {
395        match self {
396            ErrorMatcher::Contains(s) => format!("{} containing {}", type_, s.quoted()),
397            ErrorMatcher::Exact(s) => format!("exact {} {}", type_, s.quoted()),
398            ErrorMatcher::Regex(s) => format!("{} matching regex {}", type_, s.as_str().quoted()),
399            ErrorMatcher::Timeout => "timeout".to_string(),
400        }
401    }
402}
403
404pub async fn run_fail_sql(
405    cmd: FailSqlCommand,
406    state: &mut State,
407) -> Result<ControlFlow, anyhow::Error> {
408    use Statement::{AlterSink, Commit, CreateConnection, Fetch, Rollback};
409
410    let stmts = mz_sql_parser::parser::parse_statements(&cmd.query)
411        .map_err(|e| format!("unable to parse SQL: {}: {}", cmd.query, e));
412
413    // Allow for statements that could not be parsed.
414    // This way such statements can be used for negative testing in .td files
415    let stmt = match stmts {
416        Ok(s) => {
417            if s.len() != 1 {
418                bail!("expected one statement, but got {}", s.len());
419            }
420            Some(s.into_element().ast)
421        }
422        Err(_) => None,
423    };
424
425    let expected_error = match cmd.expected_error {
426        SqlExpectedError::Contains(s) => ErrorMatcher::Contains(s),
427        SqlExpectedError::Exact(s) => ErrorMatcher::Exact(s),
428        SqlExpectedError::Regex(s) => ErrorMatcher::Regex(s.parse()?),
429        SqlExpectedError::Timeout => ErrorMatcher::Timeout,
430    };
431    let expected_detail = cmd.expected_detail.map(ErrorMatcher::Contains);
432    let expected_hint = cmd.expected_hint.map(ErrorMatcher::Contains);
433
434    let query = &cmd.query;
435    print_query(query, stmt.as_ref());
436
437    let should_retry = match &stmt {
438        // Do not retry statements that could not be parsed
439        None => false,
440        // Do not retry COMMIT and ROLLBACK. Once the transaction has errored out and has
441        // been aborted, retrying COMMIT or ROLLBACK will actually start succeeding, which
442        // causes testdrive to emit a confusing "query succeded but expected error" message.
443        Some(Commit(_)) | Some(Rollback(_)) => false,
444        // FETCH should not be retried because it consumes data on each response.
445        Some(Fetch(_)) => false,
446        Some(AlterSink(_)) => false,
447        Some(CreateConnection(_)) => false,
448        Some(_) => true,
449    };
450
451    state.error_line_count = 0;
452    state.error_string = "".to_string();
453    let res = match should_retry {
454        true => Retry::default()
455            .initial_backoff(state.initial_backoff)
456            .factor(state.backoff_factor)
457            .max_duration(state.timeout)
458            .max_tries(state.max_tries),
459        false => Retry::default().max_duration(state.timeout).max_tries(1),
460    }
461    .retry_async_with_state_canceling(state, |retry_state, state| {
462        let expected_error = expected_error.clone();
463        let expected_detail = expected_detail.clone();
464        let expected_hint = expected_hint.clone();
465        async move {
466            match try_run_fail_sql(
467                state,
468                query,
469                &expected_error,
470                expected_detail.as_ref(),
471                expected_hint.as_ref(),
472            )
473            .await
474            {
475                Ok(()) => {
476                    println!("query error matches; continuing");
477                    (state, Ok(()))
478                }
479                Err(e) => {
480                    if let Some(backoff) = retry_state.next_backoff {
481                        if !backoff.is_zero() && (retry_state.i == 0 || !mz_ore::env::is_var_truthy("CI")) {
482                            let error_string = format!("{:?}", e);
483                            if error_string != state.error_string {
484                                // Remove old status lines so as not to spam the output
485                                for _ in 0..state.error_line_count {
486                                    print!("\x1B[1A\x1B[2K");
487                                }
488                                state.error_line_count = 0;
489                                state.error_string = error_string;
490                                state.error_line_count = state.error_string.lines().count() + 1;
491                                if state.error_string.ends_with('\n') {
492                                    print!("{}", state.error_string);
493                                } else {
494                                    println!("{}", state.error_string);
495                                    state.error_line_count += 1;
496                                }
497                                println!("query error didn't match; sleeping to see if dataflow produces error shortly 🕑 {:.0?}", retry_state.next_backoff.unwrap_or_default());
498                            }
499                        }
500                    } else {
501                        for _ in 0..state.error_line_count {
502                            print!("\x1B[1A\x1B[2K");
503                        }
504                        state.error_line_count = 0;
505                    }
506                    (state, Err(e))
507                }
508            }
509        }})
510    .await;
511
512    // If a timeout was expected, check whether the retry operation timed
513    // out, which indicates that the test passed.
514    if let ErrorMatcher::Timeout = expected_error {
515        if let Err(e) = &res {
516            if e.is::<tokio::time::error::Elapsed>() {
517                println!("query timed out as expected");
518                return Ok(ControlFlow::Continue);
519            }
520        }
521    }
522
523    // Otherwise, return the error if any. Note that this is the error
524    // returned by the retry operation (e.g., "expected timeout, but query
525    // succeeded"), *not* an error returned from Materialize itself.
526    res?;
527    Ok(ControlFlow::Continue)
528}
529
530async fn try_run_fail_sql(
531    state: &State,
532    query: &str,
533    expected_error: &ErrorMatcher,
534    expected_detail: Option<&ErrorMatcher>,
535    expected_hint: Option<&ErrorMatcher>,
536) -> Result<(), anyhow::Error> {
537    // `query` is raw SQL from testdrive input and may include statements that
538    // cannot be represented as composable `Sql`.
539    #[allow(clippy::disallowed_methods)]
540    match state.materialize.pgclient.query(query, &[]).await {
541        Ok(_) => bail!("query succeeded, but expected {}", expected_error),
542        Err(err) => match err.source().and_then(|err| err.downcast_ref::<DbError>()) {
543            Some(err) => {
544                let mut err_string = err.message().to_string();
545                if let Some(regex) = &state.regex {
546                    err_string = regex
547                        .replace_all(&err_string, state.regex_replacement.as_str())
548                        .to_string();
549                }
550                if !expected_error.is_match(&err_string) {
551                    bail!("expected {}, got {}", expected_error, err_string.quoted());
552                }
553
554                let check_additional =
555                    |extra: Option<&str>, matcher: Option<&ErrorMatcher>, type_| {
556                        let extra = extra.map(|s| s.to_string());
557                        match (extra, matcher) {
558                            (Some(extra), Some(expected)) => {
559                                if !expected.is_match(&extra) {
560                                    bail!(
561                                        "expected {}, got {}",
562                                        expected.fmt_with_type(type_),
563                                        extra.quoted()
564                                    );
565                                }
566                            }
567                            (None, Some(expected)) => {
568                                bail!("expected {}, but found none", expected.fmt_with_type(type_));
569                            }
570                            _ => {}
571                        }
572                        Ok(())
573                    };
574
575                check_additional(err.detail(), expected_detail, "DETAIL")?;
576                check_additional(err.hint(), expected_hint, "HINT")?;
577
578                Ok(())
579            }
580            None => Err(err.into()),
581        },
582    }
583}
584
585pub fn print_query(query: &str, stmt: Option<&Statement<Raw>>) {
586    use Statement::*;
587    if let Some(CreateSecret(_)) = stmt {
588        println!(
589            "> CREATE SECRET [query truncated on purpose so as to not reveal the secret in the log]"
590        );
591    } else {
592        println!("> {}", query)
593    }
594}
595
596// Returns the row after regex replacments, and the before, if its different
597pub fn decode_row(
598    state: &State,
599    row: Row,
600) -> Result<(Vec<String>, Option<Vec<String>>), anyhow::Error> {
601    enum ArrayElement<T> {
602        Null,
603        NonNull(T),
604    }
605
606    impl<T> fmt::Display for ArrayElement<T>
607    where
608        T: fmt::Display,
609    {
610        fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
611            match self {
612                ArrayElement::Null => f.write_str("NULL"),
613                ArrayElement::NonNull(t) => t.fmt(f),
614            }
615        }
616    }
617
618    impl<'a, T> FromSql<'a> for ArrayElement<T>
619    where
620        T: FromSql<'a>,
621    {
622        fn from_sql(
623            ty: &Type,
624            raw: &'a [u8],
625        ) -> Result<ArrayElement<T>, Box<dyn Error + Sync + Send>> {
626            T::from_sql(ty, raw).map(ArrayElement::NonNull)
627        }
628
629        fn from_sql_null(_: &Type) -> Result<ArrayElement<T>, Box<dyn Error + Sync + Send>> {
630            Ok(ArrayElement::Null)
631        }
632
633        fn accepts(ty: &Type) -> bool {
634            T::accepts(ty)
635        }
636    }
637
638    /// This lets us:
639    /// - Continue using the default method of printing array elements while
640    /// preserving SQL-looking output w/ `dec::to_standard_notation_string`.
641    /// - Avoid upstreaming a complicated change to `rust-postgres-array`.
642    struct NumericStandardNotation(Numeric);
643
644    impl fmt::Display for NumericStandardNotation {
645        fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
646            write!(f, "{}", self.0.0.0.to_standard_notation_string())
647        }
648    }
649
650    impl<'a> FromSql<'a> for NumericStandardNotation {
651        fn from_sql(
652            ty: &Type,
653            raw: &'a [u8],
654        ) -> Result<NumericStandardNotation, Box<dyn Error + Sync + Send>> {
655            Ok(NumericStandardNotation(Numeric::from_sql(ty, raw)?))
656        }
657
658        fn from_sql_null(
659            ty: &Type,
660        ) -> Result<NumericStandardNotation, Box<dyn Error + Sync + Send>> {
661            Ok(NumericStandardNotation(Numeric::from_sql_null(ty)?))
662        }
663
664        fn accepts(ty: &Type) -> bool {
665            Numeric::accepts(ty)
666        }
667    }
668
669    let mut out = vec![];
670    let mut raw_out = vec![];
671    for (i, col) in row.columns().iter().enumerate() {
672        let ty = col.type_();
673        let mut value: String = match *ty {
674            Type::ACLITEM => row.get::<_, Option<AclItem>>(i).map(|x| x.0),
675            Type::BOOL => row.get::<_, Option<bool>>(i).map(|x| x.to_string()),
676            Type::BPCHAR | Type::TEXT | Type::VARCHAR => row.get::<_, Option<String>>(i),
677            Type::TEXT_ARRAY => row
678                .get::<_, Option<Array<ArrayElement<String>>>>(i)
679                .map(|a| a.to_string()),
680            Type::BYTEA => row.get::<_, Option<Vec<u8>>>(i).map(|x| {
681                let s = x.into_iter().map(ascii::escape_default).flatten().collect();
682                String::from_utf8(s).unwrap()
683            }),
684            Type::CHAR => row.get::<_, Option<i8>>(i).map(|x| x.to_string()),
685            Type::INT2 => row.get::<_, Option<i16>>(i).map(|x| x.to_string()),
686            Type::INT4 => row.get::<_, Option<i32>>(i).map(|x| x.to_string()),
687            Type::INT8 => row.get::<_, Option<i64>>(i).map(|x| x.to_string()),
688            Type::OID => row.get::<_, Option<u32>>(i).map(|x| x.to_string()),
689            Type::NUMERIC => row
690                .get::<_, Option<NumericStandardNotation>>(i)
691                .map(|x| x.to_string()),
692            Type::FLOAT4 => row.get::<_, Option<f32>>(i).map(|x| x.to_string()),
693            Type::FLOAT8 => row.get::<_, Option<f64>>(i).map(|x| x.to_string()),
694            Type::TIMESTAMP => row
695                .get::<_, Option<chrono::NaiveDateTime>>(i)
696                .map(|x| x.to_string()),
697            Type::TIMESTAMPTZ => row
698                .get::<_, Option<chrono::DateTime<chrono::Utc>>>(i)
699                .map(|x| x.to_string()),
700            Type::DATE => row
701                .get::<_, Option<chrono::NaiveDate>>(i)
702                .map(|x| x.to_string()),
703            Type::TIME => row
704                .get::<_, Option<chrono::NaiveTime>>(i)
705                .map(|x| x.to_string()),
706            Type::INTERVAL => row.get::<_, Option<Interval>>(i).map(|x| x.to_string()),
707            Type::JSONB => row.get::<_, Option<Jsonb>>(i).map(|v| v.0.to_string()),
708            Type::UUID => row.get::<_, Option<uuid::Uuid>>(i).map(|v| v.to_string()),
709            Type::BOOL_ARRAY => row
710                .get::<_, Option<Array<ArrayElement<bool>>>>(i)
711                .map(|a| a.to_string()),
712            Type::INT2_ARRAY | Type::INT2_VECTOR => row
713                .get::<_, Option<Array<ArrayElement<i16>>>>(i)
714                .map(|a| a.to_string()),
715            Type::INT4_ARRAY => row
716                .get::<_, Option<Array<ArrayElement<i32>>>>(i)
717                .map(|a| a.to_string()),
718            Type::INT8_ARRAY => row
719                .get::<_, Option<Array<ArrayElement<i64>>>>(i)
720                .map(|a| a.to_string()),
721            Type::OID_ARRAY => row
722                .get::<_, Option<Array<ArrayElement<u32>>>>(i)
723                .map(|x| x.to_string()),
724            Type::NUMERIC_ARRAY => row
725                .get::<_, Option<Array<ArrayElement<NumericStandardNotation>>>>(i)
726                .map(|x| x.to_string()),
727            Type::FLOAT4_ARRAY => row
728                .get::<_, Option<Array<ArrayElement<f32>>>>(i)
729                .map(|x| x.to_string()),
730            Type::FLOAT8_ARRAY => row
731                .get::<_, Option<Array<ArrayElement<f64>>>>(i)
732                .map(|x| x.to_string()),
733            Type::TIMESTAMP_ARRAY => row
734                .get::<_, Option<Array<ArrayElement<chrono::NaiveDateTime>>>>(i)
735                .map(|x| x.to_string()),
736            Type::TIMESTAMPTZ_ARRAY => row
737                .get::<_, Option<Array<ArrayElement<chrono::DateTime<chrono::Utc>>>>>(i)
738                .map(|x| x.to_string()),
739            Type::DATE_ARRAY => row
740                .get::<_, Option<Array<ArrayElement<chrono::NaiveDate>>>>(i)
741                .map(|x| x.to_string()),
742            Type::TIME_ARRAY => row
743                .get::<_, Option<Array<ArrayElement<chrono::NaiveTime>>>>(i)
744                .map(|x| x.to_string()),
745            Type::INTERVAL_ARRAY => row
746                .get::<_, Option<Array<ArrayElement<Interval>>>>(i)
747                .map(|x| x.to_string()),
748            Type::JSONB_ARRAY => row
749                .get::<_, Option<Array<ArrayElement<Jsonb>>>>(i)
750                .map(|v| v.to_string()),
751            Type::UUID_ARRAY => row
752                .get::<_, Option<Array<ArrayElement<uuid::Uuid>>>>(i)
753                .map(|v| v.to_string()),
754            Type::INT4_RANGE => row.get::<_, Option<Range<i32>>>(i).map(|v| v.to_string()),
755            Type::INT4_RANGE_ARRAY => row
756                .get::<_, Option<Array<ArrayElement<Range<i32>>>>>(i)
757                .map(|v| v.to_string()),
758            Type::INT8_RANGE => row.get::<_, Option<Range<i64>>>(i).map(|v| v.to_string()),
759            Type::INT8_RANGE_ARRAY => row
760                .get::<_, Option<Array<ArrayElement<Range<i64>>>>>(i)
761                .map(|v| v.to_string()),
762            Type::NUM_RANGE => row
763                .get::<_, Option<Range<NumericStandardNotation>>>(i)
764                .map(|v| v.to_string()),
765            Type::NUM_RANGE_ARRAY => row
766                .get::<_, Option<Array<ArrayElement<Range<NumericStandardNotation>>>>>(i)
767                .map(|v| v.to_string()),
768            Type::DATE_RANGE => row
769                .get::<_, Option<Range<chrono::NaiveDate>>>(i)
770                .map(|v| v.to_string()),
771            Type::DATE_RANGE_ARRAY => row
772                .get::<_, Option<Array<ArrayElement<Range<chrono::NaiveDate>>>>>(i)
773                .map(|v| v.to_string()),
774            Type::TS_RANGE => row
775                .get::<_, Option<Range<chrono::NaiveDateTime>>>(i)
776                .map(|v| v.to_string()),
777            Type::TS_RANGE_ARRAY => row
778                .get::<_, Option<Array<ArrayElement<Range<chrono::NaiveDateTime>>>>>(i)
779                .map(|v| v.to_string()),
780            Type::TSTZ_RANGE => row
781                .get::<_, Option<Range<chrono::DateTime<chrono::Utc>>>>(i)
782                .map(|v| v.to_string()),
783            Type::TSTZ_RANGE_ARRAY => row
784                .get::<_, Option<Array<ArrayElement<Range<chrono::DateTime<chrono::Utc>>>>>>(i)
785                .map(|v| v.to_string()),
786            _ => match ty.oid() {
787                mz_pgrepr::oid::TYPE_UINT2_OID => {
788                    row.get::<_, Option<UInt2>>(i).map(|x| x.0.to_string())
789                }
790                mz_pgrepr::oid::TYPE_UINT4_OID => {
791                    row.get::<_, Option<UInt4>>(i).map(|x| x.0.to_string())
792                }
793                mz_pgrepr::oid::TYPE_UINT8_OID => {
794                    row.get::<_, Option<UInt8>>(i).map(|x| x.0.to_string())
795                }
796                mz_pgrepr::oid::TYPE_MZ_TIMESTAMP_OID => {
797                    row.get::<_, Option<MzTimestamp>>(i).map(|x| x.0)
798                }
799                _ => bail!("unsupported SQL type in testdrive: {:?}", ty),
800            },
801        }
802        .unwrap_or_else(|| "<null>".into());
803
804        raw_out.push(value.clone());
805        if let Some(regex) = &state.regex {
806            value = regex
807                .replace_all(&value, state.regex_replacement.as_str())
808                .to_string();
809        }
810
811        out.push(value);
812    }
813    let raw_out = if out != raw_out { Some(raw_out) } else { None };
814    Ok((out, raw_out))
815}
816
817struct MzTimestamp(String);
818
819impl<'a> FromSql<'a> for MzTimestamp {
820    fn from_sql(_: &Type, raw: &'a [u8]) -> Result<MzTimestamp, Box<dyn Error + Sync + Send>> {
821        Ok(MzTimestamp(std::str::from_utf8(raw)?.to_string()))
822    }
823
824    fn accepts(ty: &Type) -> bool {
825        ty.oid() == mz_pgrepr::oid::TYPE_MZ_TIMESTAMP_OID
826    }
827}
828
829#[allow(dead_code)]
830struct MzAclItem(String);
831
832impl<'a> FromSql<'a> for MzAclItem {
833    fn from_sql(_ty: &Type, raw: &'a [u8]) -> Result<Self, Box<dyn Error + Sync + Send>> {
834        Ok(MzAclItem(std::str::from_utf8(raw)?.to_string()))
835    }
836
837    fn accepts(ty: &Type) -> bool {
838        ty.oid() == mz_pgrepr::oid::TYPE_MZ_ACL_ITEM_OID
839    }
840}
841
842struct AclItem(String);
843
844impl<'a> FromSql<'a> for AclItem {
845    fn from_sql(_ty: &Type, raw: &'a [u8]) -> Result<Self, Box<dyn Error + Sync + Send>> {
846        Ok(AclItem(std::str::from_utf8(raw)?.to_string()))
847    }
848
849    fn accepts(ty: &Type) -> bool {
850        ty.oid() == 1033
851    }
852}
853
854struct TestdriveRow<'a>(&'a Vec<String>);
855
856impl Display for TestdriveRow<'_> {
857    fn fmt(&self, f: &mut Formatter) -> fmt::Result {
858        let mut cols = Vec::<String>::new();
859
860        for col_str in &self.0[0..self.0.len()] {
861            if col_str.contains(' ') || col_str.contains('"') || col_str.is_empty() {
862                cols.push(escape_for_parser(col_str));
863            } else {
864                cols.push(col_str.to_string());
865            }
866        }
867
868        write!(f, "{}", cols.join(" "))
869    }
870}