Skip to main content

mz_testdrive/action/
sql.rs

1// Copyright Materialize, Inc. and contributors. All rights reserved.
2//
3// Use of this software is governed by the Business Source License
4// included in the LICENSE file.
5//
6// As of the Change Date specified in that file, in accordance with
7// the Business Source License, use of this software will be governed
8// by the Apache License, Version 2.0.
9
10use std::ascii;
11use std::error::Error;
12use std::fmt::{self, Display, Formatter, Write as _};
13use std::time::SystemTime;
14
15use anyhow::{Context, bail};
16use itertools::Itertools;
17use md5::{Digest, Md5};
18use mz_ore::collections::CollectionExt;
19use mz_ore::retry::Retry;
20use mz_ore::str::StrExt;
21use mz_pgrepr::{Interval, Jsonb, Numeric, UInt2, UInt4, UInt8};
22use mz_repr::adt::range::Range;
23use mz_sql_parser::ast::{Raw, Statement};
24use postgres_array::Array;
25use regex::Regex;
26use tokio_postgres::error::DbError;
27use tokio_postgres::row::Row;
28use tokio_postgres::types::{FromSql, Type};
29
30use crate::action::{ControlFlow, Rewrite, State};
31use crate::parser::{FailSqlCommand, SqlCommand, SqlExpectedError, SqlOutput};
32
33pub async fn run_sql(mut cmd: SqlCommand, state: &mut State) -> Result<ControlFlow, anyhow::Error> {
34    use Statement::*;
35
36    state.rewrite_pos_start = cmd.expected_start;
37    state.rewrite_pos_end = cmd.expected_end;
38
39    let stmts = mz_sql_parser::parser::parse_statements(&cmd.query)
40        .with_context(|| format!("unable to parse SQL: {}", cmd.query))?;
41    if stmts.len() != 1 {
42        bail!("expected one statement, but got {}", stmts.len());
43    }
44    let stmt = stmts.into_element().ast;
45    if let SqlOutput::Full { expected_rows, .. } = &mut cmd.expected_output {
46        // TODO(benesch): one day we'll support SQL queries where order matters.
47        expected_rows.sort();
48    }
49
50    let should_retry = match &stmt {
51        // Do not retry FETCH statements as subsequent executions are likely
52        // to return an empty result. The original result would thus be lost.
53        Fetch(_) => false,
54        // EXPLAIN ... PLAN statements should always provide the expected result
55        // on the first try
56        ExplainPlan(_) => false,
57        // DDL statements should always provide the expected result on the first try
58        CreateConnection(_)
59        | CreateCluster(_)
60        | CreateClusterReplica(_)
61        | CreateDatabase(_)
62        | CreateSchema(_)
63        | CreateSource(_)
64        | CreateWebhookSource(_)
65        | CreateSink(_)
66        | CreateMaterializedView(_)
67        | CreateView(_)
68        | CreateTable(_)
69        | CreateTableFromSource(_)
70        | CreateIndex(_)
71        | CreateType(_)
72        | CreateRole(_)
73        | AlterObjectRename(_)
74        | AlterIndex(_)
75        | AlterSink(_)
76        | Discard(_)
77        | DropObjects(_)
78        | SetVariable(_) => false,
79        _ => true,
80    };
81
82    let query = &cmd.query;
83    print_query(query, Some(&stmt));
84    let expected_output = &cmd.expected_output;
85    state.error_line_count = 0;
86    state.error_string = "".to_string();
87    let (state, res) = match should_retry {
88        true => Retry::default()
89            .initial_backoff(state.initial_backoff)
90            .factor(state.backoff_factor)
91            .max_duration(state.timeout)
92            .max_tries(state.max_tries),
93        false => Retry::default().max_duration(state.timeout).max_tries(1),
94    }
95    .retry_async_with_state(state, |retry_state, state| async move {
96        let should_continue = retry_state.i + 1 < state.max_tries && should_retry;
97        let start = SystemTime::now();
98        match try_run_sql(state, query, expected_output, should_continue).await {
99            Ok(()) => {
100                let now = SystemTime::now();
101                let epoch = SystemTime::UNIX_EPOCH;
102                let ts = now.duration_since(epoch).unwrap().as_secs_f64();
103                let delay = now.duration_since(start).unwrap().as_secs_f64();
104                println!("rows match; continuing at ts {ts}, took {delay}s");
105                (state, Ok(()))
106            }
107            Err(e) => {
108                if let Some(backoff) = retry_state.next_backoff {
109                    if !backoff.is_zero()
110                        && (retry_state.i == 0 || !mz_ore::env::is_var_truthy("CI"))
111                    {
112                        let error_string = format!("{:?}", e);
113                        if error_string != state.error_string {
114                            // Remove old status lines so as not to spam the output
115                            for _ in 0..state.error_line_count {
116                                print!("\x1B[1A\x1B[2K");
117                            }
118                            state.error_line_count = 0;
119                            state.error_string = error_string;
120                            state.error_line_count = state.error_string.lines().count() + 1;
121                            if state.error_string.ends_with('\n') {
122                                print!("{}", state.error_string);
123                            } else {
124                                println!("{}", state.error_string);
125                                state.error_line_count += 1;
126                            }
127                            println!(
128                                "rows didn't match; sleeping to see if dataflow catches up 🕑 {:.0?}",
129                                retry_state.next_backoff.unwrap_or_default()
130                            );
131                        }
132                    }
133                } else {
134                    for _ in 0..state.error_line_count {
135                        print!("\x1B[1A\x1B[2K");
136                    }
137                    state.error_line_count = 0;
138                }
139                (state, Err(e))
140            }
141        }
142    })
143    .await;
144    if let Err(e) = res {
145        return Err(e);
146    }
147    if state.consistency_checks == super::consistency::Level::Statement {
148        super::consistency::run_consistency_checks(state).await?;
149    }
150
151    Ok(ControlFlow::Continue)
152}
153
/// Wraps `value` in double quotes, escaping it with exactly the sequences
/// that `parser::split_line` recognises (`\\`, `\"`, `\n`, `\t`, `\r`, `\0`).
///
/// `format!("{:?}", …)` is deliberately not used: Rust's Debug impl also
/// emits `\u{N}` for non-printable Unicode, which the parser does not
/// understand and would silently corrupt on the next read. Every character
/// without a dedicated escape is therefore copied through unchanged.
fn escape_for_parser(value: &str) -> String {
    // Two extra bytes for the surrounding quotes.
    let mut quoted = String::with_capacity(value.len() + 2);
    quoted.push('"');
    for ch in value.chars() {
        let escape = match ch {
            '\\' => Some("\\\\"),
            '"' => Some("\\\""),
            '\n' => Some("\\n"),
            '\t' => Some("\\t"),
            '\r' => Some("\\r"),
            '\0' => Some("\\0"),
            _ => None,
        };
        match escape {
            Some(sequence) => quoted.push_str(sequence),
            None => quoted.push(ch),
        }
    }
    quoted.push('"');
    quoted
}
177
178fn rewrite_result(
179    state: &mut State,
180    columns: Vec<&str>,
181    content: Vec<Vec<String>>,
182) -> Result<(), anyhow::Error> {
183    let mut buf = String::new();
184    writeln!(buf, "{}", columns.join(" "))?;
185    writeln!(buf, "----")?;
186    for row in content {
187        let mut formatted_row = Vec::<String>::new();
188        for value in row {
189            if value.is_empty()
190                || value.contains(|x: char| char::is_ascii_whitespace(&x))
191                || value.contains('"')
192                || value.contains('\\')
193            {
194                formatted_row.push(escape_for_parser(&value));
195            } else {
196                formatted_row.push(value);
197            }
198        }
199        writeln!(buf, "{}", formatted_row.join(" "))?;
200    }
201    state.rewrites.push(Rewrite {
202        content: buf,
203        start: state.rewrite_pos_start,
204        end: state.rewrite_pos_end,
205    });
206
207    Ok(())
208}
209
210async fn try_run_sql(
211    state: &mut State,
212    query: &str,
213    expected_output: &SqlOutput,
214    should_retry: bool,
215) -> Result<(), anyhow::Error> {
216    let stmt = state
217        .materialize
218        .pgclient
219        .prepare(query)
220        .await
221        .context("preparing query failed")?;
222
223    let query_with_timeout = tokio::time::timeout(
224        state.timeout.clone(),
225        state.materialize.pgclient.query(&stmt, &[]),
226    )
227    .await;
228
229    if query_with_timeout.is_err() {
230        bail!("query timed out\n")
231    }
232
233    let rows: Vec<_> = query_with_timeout
234        .unwrap()
235        .context("executing query failed")?
236        .into_iter()
237        .map(|row| decode_row(state, row))
238        .collect::<Result<_, _>>()?;
239
240    let (mut actual, raw_actual): (Vec<_>, Vec<_>) = rows.into_iter().unzip();
241
242    let raw_actual: Option<Vec<_>> = if raw_actual.iter().any(|r| r.is_some()) {
243        // TODO(guswynn): Note we don't sort the raw rows, because
244        // there is no easy way of ensuring they sort the same way as actual.
245        Some(
246            actual
247                .iter()
248                .zip_eq(raw_actual)
249                .map(|(actual, unreplaced)| match unreplaced {
250                    Some(raw_row) => raw_row,
251                    None => actual.clone(),
252                })
253                .collect(),
254        )
255    } else {
256        None
257    };
258
259    actual.sort();
260    let actual_columns: Vec<_> = stmt.columns().iter().map(|c| c.name()).collect();
261
262    match expected_output {
263        SqlOutput::Full {
264            expected_rows,
265            column_names,
266        } => {
267            if let Some(column_names) = column_names {
268                if actual_columns.iter().ne(column_names) {
269                    if state.rewrite_results && !should_retry {
270                        rewrite_result(state, actual_columns, actual)?;
271                        return Ok(());
272                    } else {
273                        bail!(
274                            "column name mismatch\nexpected: {:?}\nactual:   {:?}\n",
275                            column_names,
276                            actual_columns
277                        );
278                    }
279                }
280            }
281            if &actual == expected_rows {
282                Ok(())
283            } else if state.rewrite_results && !should_retry {
284                rewrite_result(state, actual_columns, actual)?;
285                Ok(())
286            } else {
287                let (mut left, mut right) = (0, 0);
288                let mut buf = String::new();
289                while let (Some(e), Some(a)) = (expected_rows.get(left), actual.get(right)) {
290                    match e.cmp(a) {
291                        std::cmp::Ordering::Less => {
292                            writeln!(buf, "- {}", TestdriveRow(e)).unwrap();
293                            left += 1;
294                        }
295                        std::cmp::Ordering::Equal => {
296                            left += 1;
297                            right += 1;
298                        }
299                        std::cmp::Ordering::Greater => {
300                            writeln!(buf, "+ {}", TestdriveRow(a)).unwrap();
301                            right += 1;
302                        }
303                    }
304                }
305                while let Some(e) = expected_rows.get(left) {
306                    writeln!(buf, "- {}", TestdriveRow(e)).unwrap();
307                    left += 1;
308                }
309                while let Some(a) = actual.get(right) {
310                    writeln!(buf, "+ {}", TestdriveRow(a)).unwrap();
311                    right += 1;
312                }
313                if state.rewrite_results && !should_retry {
314                    rewrite_result(state, actual_columns, actual)?;
315                    Ok(())
316                } else if let Some(raw_actual) = raw_actual {
317                    bail!(
318                        "non-matching rows: expected:\n{:?}\ngot:\n{:?}\ngot raw rows:\n{:?}\nPoor diff:\n{}",
319                        expected_rows,
320                        actual,
321                        raw_actual,
322                        buf,
323                    )
324                } else {
325                    bail!(
326                        "non-matching rows: expected:\n{:?}\ngot:\n{:?}\nPoor diff:\n{}",
327                        expected_rows,
328                        actual,
329                        buf
330                    )
331                }
332            }
333        }
334        SqlOutput::Hashed { num_values, md5 } => {
335            if &actual.len() != num_values {
336                bail!(
337                    "wrong row count: expected:\n{:?}\ngot:\n{:?}\n",
338                    num_values,
339                    actual.len(),
340                )
341            } else {
342                let mut hasher = Md5::new();
343                for row in &actual {
344                    for entry in row {
345                        hasher.update(entry);
346                    }
347                }
348                let actual = format!("{:x}", hasher.finalize());
349                if &actual != md5 {
350                    bail!("wrong hash value: expected:{:?} got:{:?}", md5, actual)
351                } else {
352                    Ok(())
353                }
354            }
355        }
356    }
357}
358
359#[derive(Clone)]
360enum ErrorMatcher {
361    Contains(String),
362    Exact(String),
363    Regex(Regex),
364    Timeout,
365}
366
367impl ErrorMatcher {
368    fn is_match(&self, err: &String) -> bool {
369        match self {
370            ErrorMatcher::Contains(s) => err.contains(s),
371            ErrorMatcher::Exact(s) => err == s,
372            ErrorMatcher::Regex(r) => r.is_match(err),
373            // Timeouts never match errors directly. If we are matching an error
374            // message, it means the query returned a result (i.e., an error
375            // result), which means the query did not time out as expected.
376            ErrorMatcher::Timeout => false,
377        }
378    }
379}
380
381impl fmt::Display for ErrorMatcher {
382    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
383        match self {
384            ErrorMatcher::Contains(s) => write!(f, "error containing {}", s.quoted()),
385            ErrorMatcher::Exact(s) => write!(f, "exact error {}", s.quoted()),
386            ErrorMatcher::Regex(s) => write!(f, "error matching regex {}", s.as_str().quoted()),
387            ErrorMatcher::Timeout => f.write_str("timeout"),
388        }
389    }
390}
391
392impl ErrorMatcher {
393    fn fmt_with_type(&self, type_: &str) -> String {
394        match self {
395            ErrorMatcher::Contains(s) => format!("{} containing {}", type_, s.quoted()),
396            ErrorMatcher::Exact(s) => format!("exact {} {}", type_, s.quoted()),
397            ErrorMatcher::Regex(s) => format!("{} matching regex {}", type_, s.as_str().quoted()),
398            ErrorMatcher::Timeout => "timeout".to_string(),
399        }
400    }
401}
402
403pub async fn run_fail_sql(
404    cmd: FailSqlCommand,
405    state: &mut State,
406) -> Result<ControlFlow, anyhow::Error> {
407    use Statement::{AlterSink, Commit, CreateConnection, Fetch, Rollback};
408
409    let stmts = mz_sql_parser::parser::parse_statements(&cmd.query)
410        .map_err(|e| format!("unable to parse SQL: {}: {}", cmd.query, e));
411
412    // Allow for statements that could not be parsed.
413    // This way such statements can be used for negative testing in .td files
414    let stmt = match stmts {
415        Ok(s) => {
416            if s.len() != 1 {
417                bail!("expected one statement, but got {}", s.len());
418            }
419            Some(s.into_element().ast)
420        }
421        Err(_) => None,
422    };
423
424    let expected_error = match cmd.expected_error {
425        SqlExpectedError::Contains(s) => ErrorMatcher::Contains(s),
426        SqlExpectedError::Exact(s) => ErrorMatcher::Exact(s),
427        SqlExpectedError::Regex(s) => ErrorMatcher::Regex(s.parse()?),
428        SqlExpectedError::Timeout => ErrorMatcher::Timeout,
429    };
430    let expected_detail = cmd.expected_detail.map(ErrorMatcher::Contains);
431    let expected_hint = cmd.expected_hint.map(ErrorMatcher::Contains);
432
433    let query = &cmd.query;
434    print_query(query, stmt.as_ref());
435
436    let should_retry = match &stmt {
437        // Do not retry statements that could not be parsed
438        None => false,
439        // Do not retry COMMIT and ROLLBACK. Once the transaction has errored out and has
440        // been aborted, retrying COMMIT or ROLLBACK will actually start succeeding, which
441        // causes testdrive to emit a confusing "query succeded but expected error" message.
442        Some(Commit(_)) | Some(Rollback(_)) => false,
443        // FETCH should not be retried because it consumes data on each response.
444        Some(Fetch(_)) => false,
445        Some(AlterSink(_)) => false,
446        Some(CreateConnection(_)) => false,
447        Some(_) => true,
448    };
449
450    state.error_line_count = 0;
451    state.error_string = "".to_string();
452    let res = match should_retry {
453        true => Retry::default()
454            .initial_backoff(state.initial_backoff)
455            .factor(state.backoff_factor)
456            .max_duration(state.timeout)
457            .max_tries(state.max_tries),
458        false => Retry::default().max_duration(state.timeout).max_tries(1),
459    }
460    .retry_async_with_state_canceling(state, |retry_state, state| {
461        let expected_error = expected_error.clone();
462        let expected_detail = expected_detail.clone();
463        let expected_hint = expected_hint.clone();
464        async move {
465            match try_run_fail_sql(
466                state,
467                query,
468                &expected_error,
469                expected_detail.as_ref(),
470                expected_hint.as_ref(),
471            )
472            .await
473            {
474                Ok(()) => {
475                    println!("query error matches; continuing");
476                    (state, Ok(()))
477                }
478                Err(e) => {
479                    if let Some(backoff) = retry_state.next_backoff {
480                        if !backoff.is_zero() && (retry_state.i == 0 || !mz_ore::env::is_var_truthy("CI")) {
481                            let error_string = format!("{:?}", e);
482                            if error_string != state.error_string {
483                                // Remove old status lines so as not to spam the output
484                                for _ in 0..state.error_line_count {
485                                    print!("\x1B[1A\x1B[2K");
486                                }
487                                state.error_line_count = 0;
488                                state.error_string = error_string;
489                                state.error_line_count = state.error_string.lines().count() + 1;
490                                if state.error_string.ends_with('\n') {
491                                    print!("{}", state.error_string);
492                                } else {
493                                    println!("{}", state.error_string);
494                                    state.error_line_count += 1;
495                                }
496                                println!("query error didn't match; sleeping to see if dataflow produces error shortly 🕑 {:.0?}", retry_state.next_backoff.unwrap_or_default());
497                            }
498                        }
499                    } else {
500                        for _ in 0..state.error_line_count {
501                            print!("\x1B[1A\x1B[2K");
502                        }
503                        state.error_line_count = 0;
504                    }
505                    (state, Err(e))
506                }
507            }
508        }})
509    .await;
510
511    // If a timeout was expected, check whether the retry operation timed
512    // out, which indicates that the test passed.
513    if let ErrorMatcher::Timeout = expected_error {
514        if let Err(e) = &res {
515            if e.is::<tokio::time::error::Elapsed>() {
516                println!("query timed out as expected");
517                return Ok(ControlFlow::Continue);
518            }
519        }
520    }
521
522    // Otherwise, return the error if any. Note that this is the error
523    // returned by the retry operation (e.g., "expected timeout, but query
524    // succeeded"), *not* an error returned from Materialize itself.
525    res?;
526    Ok(ControlFlow::Continue)
527}
528
529async fn try_run_fail_sql(
530    state: &State,
531    query: &str,
532    expected_error: &ErrorMatcher,
533    expected_detail: Option<&ErrorMatcher>,
534    expected_hint: Option<&ErrorMatcher>,
535) -> Result<(), anyhow::Error> {
536    match state.materialize.pgclient.query(query, &[]).await {
537        Ok(_) => bail!("query succeeded, but expected {}", expected_error),
538        Err(err) => match err.source().and_then(|err| err.downcast_ref::<DbError>()) {
539            Some(err) => {
540                let mut err_string = err.message().to_string();
541                if let Some(regex) = &state.regex {
542                    err_string = regex
543                        .replace_all(&err_string, state.regex_replacement.as_str())
544                        .to_string();
545                }
546                if !expected_error.is_match(&err_string) {
547                    bail!("expected {}, got {}", expected_error, err_string.quoted());
548                }
549
550                let check_additional =
551                    |extra: Option<&str>, matcher: Option<&ErrorMatcher>, type_| {
552                        let extra = extra.map(|s| s.to_string());
553                        match (extra, matcher) {
554                            (Some(extra), Some(expected)) => {
555                                if !expected.is_match(&extra) {
556                                    bail!(
557                                        "expected {}, got {}",
558                                        expected.fmt_with_type(type_),
559                                        extra.quoted()
560                                    );
561                                }
562                            }
563                            (None, Some(expected)) => {
564                                bail!("expected {}, but found none", expected.fmt_with_type(type_));
565                            }
566                            _ => {}
567                        }
568                        Ok(())
569                    };
570
571                check_additional(err.detail(), expected_detail, "DETAIL")?;
572                check_additional(err.hint(), expected_hint, "HINT")?;
573
574                Ok(())
575            }
576            None => Err(err.into()),
577        },
578    }
579}
580
581pub fn print_query(query: &str, stmt: Option<&Statement<Raw>>) {
582    use Statement::*;
583    if let Some(CreateSecret(_)) = stmt {
584        println!(
585            "> CREATE SECRET [query truncated on purpose so as to not reveal the secret in the log]"
586        );
587    } else {
588        println!("> {}", query)
589    }
590}
591
592// Returns the row after regex replacments, and the before, if its different
593pub fn decode_row(
594    state: &State,
595    row: Row,
596) -> Result<(Vec<String>, Option<Vec<String>>), anyhow::Error> {
597    enum ArrayElement<T> {
598        Null,
599        NonNull(T),
600    }
601
602    impl<T> fmt::Display for ArrayElement<T>
603    where
604        T: fmt::Display,
605    {
606        fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
607            match self {
608                ArrayElement::Null => f.write_str("NULL"),
609                ArrayElement::NonNull(t) => t.fmt(f),
610            }
611        }
612    }
613
614    impl<'a, T> FromSql<'a> for ArrayElement<T>
615    where
616        T: FromSql<'a>,
617    {
618        fn from_sql(
619            ty: &Type,
620            raw: &'a [u8],
621        ) -> Result<ArrayElement<T>, Box<dyn Error + Sync + Send>> {
622            T::from_sql(ty, raw).map(ArrayElement::NonNull)
623        }
624
625        fn from_sql_null(_: &Type) -> Result<ArrayElement<T>, Box<dyn Error + Sync + Send>> {
626            Ok(ArrayElement::Null)
627        }
628
629        fn accepts(ty: &Type) -> bool {
630            T::accepts(ty)
631        }
632    }
633
634    /// This lets us:
635    /// - Continue using the default method of printing array elements while
636    /// preserving SQL-looking output w/ `dec::to_standard_notation_string`.
637    /// - Avoid upstreaming a complicated change to `rust-postgres-array`.
638    struct NumericStandardNotation(Numeric);
639
640    impl fmt::Display for NumericStandardNotation {
641        fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
642            write!(f, "{}", self.0.0.0.to_standard_notation_string())
643        }
644    }
645
646    impl<'a> FromSql<'a> for NumericStandardNotation {
647        fn from_sql(
648            ty: &Type,
649            raw: &'a [u8],
650        ) -> Result<NumericStandardNotation, Box<dyn Error + Sync + Send>> {
651            Ok(NumericStandardNotation(Numeric::from_sql(ty, raw)?))
652        }
653
654        fn from_sql_null(
655            ty: &Type,
656        ) -> Result<NumericStandardNotation, Box<dyn Error + Sync + Send>> {
657            Ok(NumericStandardNotation(Numeric::from_sql_null(ty)?))
658        }
659
660        fn accepts(ty: &Type) -> bool {
661            Numeric::accepts(ty)
662        }
663    }
664
665    let mut out = vec![];
666    let mut raw_out = vec![];
667    for (i, col) in row.columns().iter().enumerate() {
668        let ty = col.type_();
669        let mut value: String = match *ty {
670            Type::ACLITEM => row.get::<_, Option<AclItem>>(i).map(|x| x.0),
671            Type::BOOL => row.get::<_, Option<bool>>(i).map(|x| x.to_string()),
672            Type::BPCHAR | Type::TEXT | Type::VARCHAR => row.get::<_, Option<String>>(i),
673            Type::TEXT_ARRAY => row
674                .get::<_, Option<Array<ArrayElement<String>>>>(i)
675                .map(|a| a.to_string()),
676            Type::BYTEA => row.get::<_, Option<Vec<u8>>>(i).map(|x| {
677                let s = x.into_iter().map(ascii::escape_default).flatten().collect();
678                String::from_utf8(s).unwrap()
679            }),
680            Type::CHAR => row.get::<_, Option<i8>>(i).map(|x| x.to_string()),
681            Type::INT2 => row.get::<_, Option<i16>>(i).map(|x| x.to_string()),
682            Type::INT4 => row.get::<_, Option<i32>>(i).map(|x| x.to_string()),
683            Type::INT8 => row.get::<_, Option<i64>>(i).map(|x| x.to_string()),
684            Type::OID => row.get::<_, Option<u32>>(i).map(|x| x.to_string()),
685            Type::NUMERIC => row
686                .get::<_, Option<NumericStandardNotation>>(i)
687                .map(|x| x.to_string()),
688            Type::FLOAT4 => row.get::<_, Option<f32>>(i).map(|x| x.to_string()),
689            Type::FLOAT8 => row.get::<_, Option<f64>>(i).map(|x| x.to_string()),
690            Type::TIMESTAMP => row
691                .get::<_, Option<chrono::NaiveDateTime>>(i)
692                .map(|x| x.to_string()),
693            Type::TIMESTAMPTZ => row
694                .get::<_, Option<chrono::DateTime<chrono::Utc>>>(i)
695                .map(|x| x.to_string()),
696            Type::DATE => row
697                .get::<_, Option<chrono::NaiveDate>>(i)
698                .map(|x| x.to_string()),
699            Type::TIME => row
700                .get::<_, Option<chrono::NaiveTime>>(i)
701                .map(|x| x.to_string()),
702            Type::INTERVAL => row.get::<_, Option<Interval>>(i).map(|x| x.to_string()),
703            Type::JSONB => row.get::<_, Option<Jsonb>>(i).map(|v| v.0.to_string()),
704            Type::UUID => row.get::<_, Option<uuid::Uuid>>(i).map(|v| v.to_string()),
705            Type::BOOL_ARRAY => row
706                .get::<_, Option<Array<ArrayElement<bool>>>>(i)
707                .map(|a| a.to_string()),
708            Type::INT2_ARRAY | Type::INT2_VECTOR => row
709                .get::<_, Option<Array<ArrayElement<i16>>>>(i)
710                .map(|a| a.to_string()),
711            Type::INT4_ARRAY => row
712                .get::<_, Option<Array<ArrayElement<i32>>>>(i)
713                .map(|a| a.to_string()),
714            Type::INT8_ARRAY => row
715                .get::<_, Option<Array<ArrayElement<i64>>>>(i)
716                .map(|a| a.to_string()),
717            Type::OID_ARRAY => row
718                .get::<_, Option<Array<ArrayElement<u32>>>>(i)
719                .map(|x| x.to_string()),
720            Type::NUMERIC_ARRAY => row
721                .get::<_, Option<Array<ArrayElement<NumericStandardNotation>>>>(i)
722                .map(|x| x.to_string()),
723            Type::FLOAT4_ARRAY => row
724                .get::<_, Option<Array<ArrayElement<f32>>>>(i)
725                .map(|x| x.to_string()),
726            Type::FLOAT8_ARRAY => row
727                .get::<_, Option<Array<ArrayElement<f64>>>>(i)
728                .map(|x| x.to_string()),
729            Type::TIMESTAMP_ARRAY => row
730                .get::<_, Option<Array<ArrayElement<chrono::NaiveDateTime>>>>(i)
731                .map(|x| x.to_string()),
732            Type::TIMESTAMPTZ_ARRAY => row
733                .get::<_, Option<Array<ArrayElement<chrono::DateTime<chrono::Utc>>>>>(i)
734                .map(|x| x.to_string()),
735            Type::DATE_ARRAY => row
736                .get::<_, Option<Array<ArrayElement<chrono::NaiveDate>>>>(i)
737                .map(|x| x.to_string()),
738            Type::TIME_ARRAY => row
739                .get::<_, Option<Array<ArrayElement<chrono::NaiveTime>>>>(i)
740                .map(|x| x.to_string()),
741            Type::INTERVAL_ARRAY => row
742                .get::<_, Option<Array<ArrayElement<Interval>>>>(i)
743                .map(|x| x.to_string()),
744            Type::JSONB_ARRAY => row
745                .get::<_, Option<Array<ArrayElement<Jsonb>>>>(i)
746                .map(|v| v.to_string()),
747            Type::UUID_ARRAY => row
748                .get::<_, Option<Array<ArrayElement<uuid::Uuid>>>>(i)
749                .map(|v| v.to_string()),
750            Type::INT4_RANGE => row.get::<_, Option<Range<i32>>>(i).map(|v| v.to_string()),
751            Type::INT4_RANGE_ARRAY => row
752                .get::<_, Option<Array<ArrayElement<Range<i32>>>>>(i)
753                .map(|v| v.to_string()),
754            Type::INT8_RANGE => row.get::<_, Option<Range<i64>>>(i).map(|v| v.to_string()),
755            Type::INT8_RANGE_ARRAY => row
756                .get::<_, Option<Array<ArrayElement<Range<i64>>>>>(i)
757                .map(|v| v.to_string()),
758            Type::NUM_RANGE => row
759                .get::<_, Option<Range<NumericStandardNotation>>>(i)
760                .map(|v| v.to_string()),
761            Type::NUM_RANGE_ARRAY => row
762                .get::<_, Option<Array<ArrayElement<Range<NumericStandardNotation>>>>>(i)
763                .map(|v| v.to_string()),
764            Type::DATE_RANGE => row
765                .get::<_, Option<Range<chrono::NaiveDate>>>(i)
766                .map(|v| v.to_string()),
767            Type::DATE_RANGE_ARRAY => row
768                .get::<_, Option<Array<ArrayElement<Range<chrono::NaiveDate>>>>>(i)
769                .map(|v| v.to_string()),
770            Type::TS_RANGE => row
771                .get::<_, Option<Range<chrono::NaiveDateTime>>>(i)
772                .map(|v| v.to_string()),
773            Type::TS_RANGE_ARRAY => row
774                .get::<_, Option<Array<ArrayElement<Range<chrono::NaiveDateTime>>>>>(i)
775                .map(|v| v.to_string()),
776            Type::TSTZ_RANGE => row
777                .get::<_, Option<Range<chrono::DateTime<chrono::Utc>>>>(i)
778                .map(|v| v.to_string()),
779            Type::TSTZ_RANGE_ARRAY => row
780                .get::<_, Option<Array<ArrayElement<Range<chrono::DateTime<chrono::Utc>>>>>>(i)
781                .map(|v| v.to_string()),
782            _ => match ty.oid() {
783                mz_pgrepr::oid::TYPE_UINT2_OID => {
784                    row.get::<_, Option<UInt2>>(i).map(|x| x.0.to_string())
785                }
786                mz_pgrepr::oid::TYPE_UINT4_OID => {
787                    row.get::<_, Option<UInt4>>(i).map(|x| x.0.to_string())
788                }
789                mz_pgrepr::oid::TYPE_UINT8_OID => {
790                    row.get::<_, Option<UInt8>>(i).map(|x| x.0.to_string())
791                }
792                mz_pgrepr::oid::TYPE_MZ_TIMESTAMP_OID => {
793                    row.get::<_, Option<MzTimestamp>>(i).map(|x| x.0)
794                }
795                _ => bail!("unsupported SQL type in testdrive: {:?}", ty),
796            },
797        }
798        .unwrap_or_else(|| "<null>".into());
799
800        raw_out.push(value.clone());
801        if let Some(regex) = &state.regex {
802            value = regex
803                .replace_all(&value, state.regex_replacement.as_str())
804                .to_string();
805        }
806
807        out.push(value);
808    }
809    let raw_out = if out != raw_out { Some(raw_out) } else { None };
810    Ok((out, raw_out))
811}
812
813struct MzTimestamp(String);
814
815impl<'a> FromSql<'a> for MzTimestamp {
816    fn from_sql(_: &Type, raw: &'a [u8]) -> Result<MzTimestamp, Box<dyn Error + Sync + Send>> {
817        Ok(MzTimestamp(std::str::from_utf8(raw)?.to_string()))
818    }
819
820    fn accepts(ty: &Type) -> bool {
821        ty.oid() == mz_pgrepr::oid::TYPE_MZ_TIMESTAMP_OID
822    }
823}
824
825#[allow(dead_code)]
826struct MzAclItem(String);
827
828impl<'a> FromSql<'a> for MzAclItem {
829    fn from_sql(_ty: &Type, raw: &'a [u8]) -> Result<Self, Box<dyn Error + Sync + Send>> {
830        Ok(MzAclItem(std::str::from_utf8(raw)?.to_string()))
831    }
832
833    fn accepts(ty: &Type) -> bool {
834        ty.oid() == mz_pgrepr::oid::TYPE_MZ_ACL_ITEM_OID
835    }
836}
837
838struct AclItem(String);
839
840impl<'a> FromSql<'a> for AclItem {
841    fn from_sql(_ty: &Type, raw: &'a [u8]) -> Result<Self, Box<dyn Error + Sync + Send>> {
842        Ok(AclItem(std::str::from_utf8(raw)?.to_string()))
843    }
844
845    fn accepts(ty: &Type) -> bool {
846        ty.oid() == 1033
847    }
848}
849
850struct TestdriveRow<'a>(&'a Vec<String>);
851
852impl Display for TestdriveRow<'_> {
853    fn fmt(&self, f: &mut Formatter) -> fmt::Result {
854        let mut cols = Vec::<String>::new();
855
856        for col_str in &self.0[0..self.0.len()] {
857            if col_str.contains(' ') || col_str.contains('"') || col_str.is_empty() {
858                cols.push(escape_for_parser(col_str));
859            } else {
860                cols.push(col_str.to_string());
861            }
862        }
863
864        write!(f, "{}", cols.join(" "))
865    }
866}