mz_testdrive/action/
sql.rs

1// Copyright Materialize, Inc. and contributors. All rights reserved.
2//
3// Use of this software is governed by the Business Source License
4// included in the LICENSE file.
5//
6// As of the Change Date specified in that file, in accordance with
7// the Business Source License, use of this software will be governed
8// by the Apache License, Version 2.0.
9
10use std::ascii;
11use std::error::Error;
12use std::fmt::{self, Display, Formatter, Write as _};
13use std::io::{self, Write};
14use std::time::SystemTime;
15
16use anyhow::{Context, bail};
17use md5::{Digest, Md5};
18use mz_ore::collections::CollectionExt;
19use mz_ore::retry::Retry;
20use mz_ore::str::StrExt;
21use mz_pgrepr::{Interval, Jsonb, Numeric, UInt2, UInt4, UInt8};
22use mz_repr::adt::range::Range;
23use mz_sql_parser::ast::{Raw, Statement};
24use postgres_array::Array;
25use regex::Regex;
26use tokio_postgres::error::DbError;
27use tokio_postgres::row::Row;
28use tokio_postgres::types::{FromSql, Type};
29
30use crate::action::{ControlFlow, Rewrite, State};
31use crate::parser::{FailSqlCommand, SqlCommand, SqlExpectedError, SqlOutput};
32
33pub async fn run_sql(mut cmd: SqlCommand, state: &mut State) -> Result<ControlFlow, anyhow::Error> {
34    use Statement::*;
35
36    state.rewrite_pos_start = cmd.expected_start;
37    state.rewrite_pos_end = cmd.expected_end;
38
39    let stmts = mz_sql_parser::parser::parse_statements(&cmd.query)
40        .with_context(|| format!("unable to parse SQL: {}", cmd.query))?;
41    if stmts.len() != 1 {
42        bail!("expected one statement, but got {}", stmts.len());
43    }
44    let stmt = stmts.into_element().ast;
45    if let SqlOutput::Full { expected_rows, .. } = &mut cmd.expected_output {
46        // TODO(benesch): one day we'll support SQL queries where order matters.
47        expected_rows.sort();
48    }
49
50    let should_retry = match &stmt {
51        // Do not retry FETCH statements as subsequent executions are likely
52        // to return an empty result. The original result would thus be lost.
53        Fetch(_) => false,
54        // EXPLAIN ... PLAN statements should always provide the expected result
55        // on the first try
56        ExplainPlan(_) => false,
57        // DDL statements should always provide the expected result on the first try
58        CreateConnection(_)
59        | CreateCluster(_)
60        | CreateClusterReplica(_)
61        | CreateDatabase(_)
62        | CreateSchema(_)
63        | CreateSource(_)
64        | CreateSink(_)
65        | CreateMaterializedView(_)
66        | CreateView(_)
67        | CreateTable(_)
68        | CreateTableFromSource(_)
69        | CreateIndex(_)
70        | CreateType(_)
71        | CreateRole(_)
72        | AlterObjectRename(_)
73        | AlterIndex(_)
74        | AlterSink(_)
75        | Discard(_)
76        | DropObjects(_)
77        | SetVariable(_) => false,
78        _ => true,
79    };
80
81    let query = &cmd.query;
82    print_query(query, Some(&stmt));
83    let expected_output = &cmd.expected_output;
84    let (state, res) = match should_retry {
85        true => Retry::default()
86            .initial_backoff(state.initial_backoff)
87            .factor(state.backoff_factor)
88            .max_duration(state.timeout)
89            .max_tries(state.max_tries),
90        false => Retry::default().max_duration(state.timeout).max_tries(1),
91    }
92    .retry_async_with_state(state, |retry_state, state| async move {
93        let should_continue = retry_state.i + 1 < state.max_tries && should_retry;
94        let start = SystemTime::now();
95        match try_run_sql(state, query, expected_output, should_continue).await {
96            Ok(()) => {
97                let now = SystemTime::now();
98                let epoch = SystemTime::UNIX_EPOCH;
99                let ts = now.duration_since(epoch).unwrap().as_secs_f64();
100                let delay = now.duration_since(start).unwrap().as_secs_f64();
101                if retry_state.i != 0 {
102                    println!();
103                }
104                println!("rows match; continuing at ts {ts}, took {delay}s");
105                (state, Ok(()))
106            }
107            Err(e) => {
108                if retry_state.i == 0 && should_retry {
109                    print!("rows didn't match; sleeping to see if dataflow catches up");
110                }
111                if let Some(backoff) = retry_state.next_backoff {
112                    if !backoff.is_zero() {
113                        print!(" {:.0?}", backoff);
114                        io::stdout().flush().unwrap();
115                    }
116                }
117                (state, Err(e))
118            }
119        }
120    })
121    .await;
122    if let Err(e) = res {
123        println!();
124        return Err(e);
125    }
126    if state.consistency_checks == super::consistency::Level::Statement {
127        super::consistency::run_consistency_checks(state).await?;
128    }
129
130    Ok(ControlFlow::Continue)
131}
132
133fn rewrite_result(
134    state: &mut State,
135    columns: Vec<&str>,
136    content: Vec<Vec<String>>,
137) -> Result<(), anyhow::Error> {
138    let mut buf = String::new();
139    writeln!(buf, "{}", columns.join(" "))?;
140    writeln!(buf, "----")?;
141    for row in content {
142        let mut formatted_row = Vec::<String>::new();
143        for value in row {
144            if value.is_empty() || value.contains(|x: char| char::is_ascii_whitespace(&x)) {
145                formatted_row.push("\"".to_owned() + &value + "\"");
146            } else {
147                formatted_row.push(value);
148            }
149        }
150        writeln!(buf, "{}", formatted_row.join(" "))?;
151    }
152    state.rewrites.push(Rewrite {
153        content: buf,
154        start: state.rewrite_pos_start,
155        end: state.rewrite_pos_end,
156    });
157
158    Ok(())
159}
160
161async fn try_run_sql(
162    state: &mut State,
163    query: &str,
164    expected_output: &SqlOutput,
165    should_retry: bool,
166) -> Result<(), anyhow::Error> {
167    let stmt = state
168        .materialize
169        .pgclient
170        .prepare(query)
171        .await
172        .context("preparing query failed")?;
173
174    let query_with_timeout = tokio::time::timeout(
175        state.timeout.clone(),
176        state.materialize.pgclient.query(&stmt, &[]),
177    )
178    .await;
179
180    if query_with_timeout.is_err() {
181        bail!("query timed out")
182    }
183
184    let rows: Vec<_> = query_with_timeout
185        .unwrap()
186        .context("executing query failed")?
187        .into_iter()
188        .map(|row| decode_row(state, row))
189        .collect::<Result<_, _>>()?;
190
191    let (mut actual, raw_actual): (Vec<_>, Vec<_>) = rows.into_iter().unzip();
192
193    let raw_actual: Option<Vec<_>> = if raw_actual.iter().any(|r| r.is_some()) {
194        // TODO(guswynn): Note we don't sort the raw rows, because
195        // there is no easy way of ensuring they sort the same way as actual.
196        Some(
197            actual
198                .iter()
199                .zip(raw_actual.into_iter())
200                .map(|(actual, unreplaced)| match unreplaced {
201                    Some(raw_row) => raw_row,
202                    None => actual.clone(),
203                })
204                .collect(),
205        )
206    } else {
207        None
208    };
209
210    actual.sort();
211    let actual_columns: Vec<_> = stmt.columns().iter().map(|c| c.name()).collect();
212
213    match expected_output {
214        SqlOutput::Full {
215            expected_rows,
216            column_names,
217        } => {
218            if let Some(column_names) = column_names {
219                if actual_columns.iter().ne(column_names) {
220                    if state.rewrite_results && !should_retry {
221                        rewrite_result(state, actual_columns, actual)?;
222                        return Ok(());
223                    } else {
224                        bail!(
225                            "column name mismatch\nexpected: {:?}\nactual:   {:?}",
226                            column_names,
227                            actual_columns
228                        );
229                    }
230                }
231            }
232            if &actual == expected_rows {
233                Ok(())
234            } else if state.rewrite_results && !should_retry {
235                rewrite_result(state, actual_columns, actual)?;
236                Ok(())
237            } else {
238                let (mut left, mut right) = (0, 0);
239                let mut buf = String::new();
240                while let (Some(e), Some(a)) = (expected_rows.get(left), actual.get(right)) {
241                    match e.cmp(a) {
242                        std::cmp::Ordering::Less => {
243                            writeln!(buf, "- {}", TestdriveRow(e)).unwrap();
244                            left += 1;
245                        }
246                        std::cmp::Ordering::Equal => {
247                            left += 1;
248                            right += 1;
249                        }
250                        std::cmp::Ordering::Greater => {
251                            writeln!(buf, "+ {}", TestdriveRow(a)).unwrap();
252                            right += 1;
253                        }
254                    }
255                }
256                while let Some(e) = expected_rows.get(left) {
257                    writeln!(buf, "- {}", TestdriveRow(e)).unwrap();
258                    left += 1;
259                }
260                while let Some(a) = actual.get(right) {
261                    writeln!(buf, "+ {}", TestdriveRow(a)).unwrap();
262                    right += 1;
263                }
264                if state.rewrite_results && !should_retry {
265                    rewrite_result(state, actual_columns, actual)?;
266                    Ok(())
267                } else if let Some(raw_actual) = raw_actual {
268                    bail!(
269                        "non-matching rows: expected:\n{:?}\ngot:\n{:?}\ngot raw rows:\n{:?}\nPoor diff:\n{}",
270                        expected_rows,
271                        actual,
272                        raw_actual,
273                        buf,
274                    )
275                } else {
276                    bail!(
277                        "non-matching rows: expected:\n{:?}\ngot:\n{:?}\nPoor diff:\n{}",
278                        expected_rows,
279                        actual,
280                        buf
281                    )
282                }
283            }
284        }
285        SqlOutput::Hashed { num_values, md5 } => {
286            if &actual.len() != num_values {
287                bail!(
288                    "wrong row count: expected:\n{:?}\ngot:\n{:?}\n",
289                    num_values,
290                    actual.len(),
291                )
292            } else {
293                let mut hasher = Md5::new();
294                for row in &actual {
295                    for entry in row {
296                        hasher.update(entry);
297                    }
298                }
299                let actual = format!("{:x}", hasher.finalize());
300                if &actual != md5 {
301                    bail!("wrong hash value: expected:{:?} got:{:?}", md5, actual)
302                } else {
303                    Ok(())
304                }
305            }
306        }
307    }
308}
309
310enum ErrorMatcher {
311    Contains(String),
312    Exact(String),
313    Regex(Regex),
314    Timeout,
315}
316
317impl ErrorMatcher {
318    fn is_match(&self, err: &String) -> bool {
319        match self {
320            ErrorMatcher::Contains(s) => err.contains(s),
321            ErrorMatcher::Exact(s) => err == s,
322            ErrorMatcher::Regex(r) => r.is_match(err),
323            // Timeouts never match errors directly. If we are matching an error
324            // message, it means the query returned a result (i.e., an error
325            // result), which means the query did not time out as expected.
326            ErrorMatcher::Timeout => false,
327        }
328    }
329}
330
331impl fmt::Display for ErrorMatcher {
332    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
333        match self {
334            ErrorMatcher::Contains(s) => write!(f, "error containing {}", s.quoted()),
335            ErrorMatcher::Exact(s) => write!(f, "exact error {}", s.quoted()),
336            ErrorMatcher::Regex(s) => write!(f, "error matching regex {}", s.as_str().quoted()),
337            ErrorMatcher::Timeout => f.write_str("timeout"),
338        }
339    }
340}
341
342impl ErrorMatcher {
343    fn fmt_with_type(&self, type_: &str) -> String {
344        match self {
345            ErrorMatcher::Contains(s) => format!("{} containing {}", type_, s.quoted()),
346            ErrorMatcher::Exact(s) => format!("exact {} {}", type_, s.quoted()),
347            ErrorMatcher::Regex(s) => format!("{} matching regex {}", type_, s.as_str().quoted()),
348            ErrorMatcher::Timeout => "timeout".to_string(),
349        }
350    }
351}
352
353pub async fn run_fail_sql(
354    cmd: FailSqlCommand,
355    state: &State,
356) -> Result<ControlFlow, anyhow::Error> {
357    use Statement::{AlterSink, Commit, CreateConnection, Fetch, Rollback};
358
359    let stmts = mz_sql_parser::parser::parse_statements(&cmd.query)
360        .map_err(|e| format!("unable to parse SQL: {}: {}", cmd.query, e));
361
362    // Allow for statements that could not be parsed.
363    // This way such statements can be used for negative testing in .td files
364    let stmt = match stmts {
365        Ok(s) => {
366            if s.len() != 1 {
367                bail!("expected one statement, but got {}", s.len());
368            }
369            Some(s.into_element().ast)
370        }
371        Err(_) => None,
372    };
373
374    let expected_error = match cmd.expected_error {
375        SqlExpectedError::Contains(s) => ErrorMatcher::Contains(s),
376        SqlExpectedError::Exact(s) => ErrorMatcher::Exact(s),
377        SqlExpectedError::Regex(s) => ErrorMatcher::Regex(s.parse()?),
378        SqlExpectedError::Timeout => ErrorMatcher::Timeout,
379    };
380    let expected_detail = cmd.expected_detail.map(ErrorMatcher::Contains);
381    let expected_hint = cmd.expected_hint.map(ErrorMatcher::Contains);
382
383    let query = &cmd.query;
384    print_query(query, stmt.as_ref());
385
386    let should_retry = match &stmt {
387        // Do not retry statements that could not be parsed
388        None => false,
389        // Do not retry COMMIT and ROLLBACK. Once the transaction has errored out and has
390        // been aborted, retrying COMMIT or ROLLBACK will actually start succeeding, which
391        // causes testdrive to emit a confusing "query succeded but expected error" message.
392        Some(Commit(_)) | Some(Rollback(_)) => false,
393        // FETCH should not be retried because it consumes data on each response.
394        Some(Fetch(_)) => false,
395        Some(AlterSink(_)) => false,
396        Some(CreateConnection(_)) => false,
397        Some(_) => true,
398    };
399
400    let state = &state;
401    let res = match should_retry {
402        true => Retry::default()
403            .initial_backoff(state.initial_backoff)
404            .factor(state.backoff_factor)
405            .max_duration(state.timeout)
406            .max_tries(state.max_tries),
407        false => Retry::default().max_duration(state.timeout).max_tries(1),
408    }
409    .retry_async_canceling(|retry_state| {
410        let expected_error = &expected_error;
411        let expected_detail = &expected_detail;
412        let expected_hint = &expected_hint;
413        async move {
414            match try_run_fail_sql(
415                state,
416                query,
417                expected_error,
418                expected_detail.as_ref(),
419                expected_hint.as_ref(),
420            )
421            .await
422            {
423                Ok(()) => {
424                    if retry_state.i != 0 {
425                        println!();
426                    }
427                    println!("query error matches; continuing");
428                    Ok(())
429                }
430                Err(e) => {
431                    if retry_state.i == 0 && should_retry {
432                        print!(
433                            "query error didn't match; \
434                                sleeping to see if dataflow produces error shortly"
435                        );
436                    }
437                    if let Some(backoff) = retry_state.next_backoff {
438                        print!(" {:.0?}", backoff);
439                        io::stdout().flush().unwrap();
440                    } else {
441                        println!();
442                    }
443                    Err(e)
444                }
445            }
446        }
447    })
448    .await;
449
450    // If a timeout was expected, check whether the retry operation timed
451    // out, which indicates that the test passed.
452    if let ErrorMatcher::Timeout = expected_error {
453        if let Err(e) = &res {
454            if e.is::<tokio::time::error::Elapsed>() {
455                println!("query timed out as expected");
456                return Ok(ControlFlow::Continue);
457            }
458        }
459    }
460
461    // Otherwise, return the error if any. Note that this is the error
462    // returned by the retry operation (e.g., "expected timeout, but query
463    // succeeded"), *not* an error returned from Materialize itself.
464    res?;
465    Ok(ControlFlow::Continue)
466}
467
468async fn try_run_fail_sql(
469    state: &State,
470    query: &str,
471    expected_error: &ErrorMatcher,
472    expected_detail: Option<&ErrorMatcher>,
473    expected_hint: Option<&ErrorMatcher>,
474) -> Result<(), anyhow::Error> {
475    match state.materialize.pgclient.query(query, &[]).await {
476        Ok(_) => bail!("query succeeded, but expected {}", expected_error),
477        Err(err) => match err.source().and_then(|err| err.downcast_ref::<DbError>()) {
478            Some(err) => {
479                let mut err_string = err.message().to_string();
480                if let Some(regex) = &state.regex {
481                    err_string = regex
482                        .replace_all(&err_string, state.regex_replacement.as_str())
483                        .to_string();
484                }
485                if !expected_error.is_match(&err_string) {
486                    bail!("expected {}, got {}", expected_error, err_string.quoted());
487                }
488
489                let check_additional =
490                    |extra: Option<&str>, matcher: Option<&ErrorMatcher>, type_| {
491                        let extra = extra.map(|s| s.to_string());
492                        match (extra, matcher) {
493                            (Some(extra), Some(expected)) => {
494                                if !expected.is_match(&extra) {
495                                    bail!(
496                                        "expected {}, got {}",
497                                        expected.fmt_with_type(type_),
498                                        extra.quoted()
499                                    );
500                                }
501                            }
502                            (None, Some(expected)) => {
503                                bail!("expected {}, but found none", expected.fmt_with_type(type_));
504                            }
505                            _ => {}
506                        }
507                        Ok(())
508                    };
509
510                check_additional(err.detail(), expected_detail, "DETAIL")?;
511                check_additional(err.hint(), expected_hint, "HINT")?;
512
513                Ok(())
514            }
515            None => Err(err.into()),
516        },
517    }
518}
519
520pub fn print_query(query: &str, stmt: Option<&Statement<Raw>>) {
521    use Statement::*;
522    if let Some(CreateSecret(_)) = stmt {
523        println!(
524            "> CREATE SECRET [query truncated on purpose so as to not reveal the secret in the log]"
525        );
526    } else {
527        println!("> {}", query)
528    }
529}
530
531// Returns the row after regex replacments, and the before, if its different
532pub fn decode_row(
533    state: &State,
534    row: Row,
535) -> Result<(Vec<String>, Option<Vec<String>>), anyhow::Error> {
536    enum ArrayElement<T> {
537        Null,
538        NonNull(T),
539    }
540
541    impl<T> fmt::Display for ArrayElement<T>
542    where
543        T: fmt::Display,
544    {
545        fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
546            match self {
547                ArrayElement::Null => f.write_str("NULL"),
548                ArrayElement::NonNull(t) => t.fmt(f),
549            }
550        }
551    }
552
553    impl<'a, T> FromSql<'a> for ArrayElement<T>
554    where
555        T: FromSql<'a>,
556    {
557        fn from_sql(
558            ty: &Type,
559            raw: &'a [u8],
560        ) -> Result<ArrayElement<T>, Box<dyn Error + Sync + Send>> {
561            T::from_sql(ty, raw).map(ArrayElement::NonNull)
562        }
563
564        fn from_sql_null(_: &Type) -> Result<ArrayElement<T>, Box<dyn Error + Sync + Send>> {
565            Ok(ArrayElement::Null)
566        }
567
568        fn accepts(ty: &Type) -> bool {
569            T::accepts(ty)
570        }
571    }
572
573    /// This lets us:
574    /// - Continue using the default method of printing array elements while
575    /// preserving SQL-looking output w/ `dec::to_standard_notation_string`.
576    /// - Avoid upstreaming a complicated change to `rust-postgres-array`.
577    struct NumericStandardNotation(Numeric);
578
579    impl fmt::Display for NumericStandardNotation {
580        fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
581            write!(f, "{}", self.0.0.0.to_standard_notation_string())
582        }
583    }
584
585    impl<'a> FromSql<'a> for NumericStandardNotation {
586        fn from_sql(
587            ty: &Type,
588            raw: &'a [u8],
589        ) -> Result<NumericStandardNotation, Box<dyn Error + Sync + Send>> {
590            Ok(NumericStandardNotation(Numeric::from_sql(ty, raw)?))
591        }
592
593        fn from_sql_null(
594            ty: &Type,
595        ) -> Result<NumericStandardNotation, Box<dyn Error + Sync + Send>> {
596            Ok(NumericStandardNotation(Numeric::from_sql_null(ty)?))
597        }
598
599        fn accepts(ty: &Type) -> bool {
600            Numeric::accepts(ty)
601        }
602    }
603
604    let mut out = vec![];
605    let mut raw_out = vec![];
606    for (i, col) in row.columns().iter().enumerate() {
607        let ty = col.type_();
608        let mut value: String = match *ty {
609            Type::ACLITEM => row.get::<_, Option<AclItem>>(i).map(|x| x.0),
610            Type::BOOL => row.get::<_, Option<bool>>(i).map(|x| x.to_string()),
611            Type::BPCHAR | Type::TEXT | Type::VARCHAR => row.get::<_, Option<String>>(i),
612            Type::TEXT_ARRAY => row
613                .get::<_, Option<Array<ArrayElement<String>>>>(i)
614                .map(|a| a.to_string()),
615            Type::BYTEA => row.get::<_, Option<Vec<u8>>>(i).map(|x| {
616                let s = x.into_iter().map(ascii::escape_default).flatten().collect();
617                String::from_utf8(s).unwrap()
618            }),
619            Type::CHAR => row.get::<_, Option<i8>>(i).map(|x| x.to_string()),
620            Type::INT2 => row.get::<_, Option<i16>>(i).map(|x| x.to_string()),
621            Type::INT4 => row.get::<_, Option<i32>>(i).map(|x| x.to_string()),
622            Type::INT8 => row.get::<_, Option<i64>>(i).map(|x| x.to_string()),
623            Type::OID => row.get::<_, Option<u32>>(i).map(|x| x.to_string()),
624            Type::NUMERIC => row
625                .get::<_, Option<NumericStandardNotation>>(i)
626                .map(|x| x.to_string()),
627            Type::FLOAT4 => row.get::<_, Option<f32>>(i).map(|x| x.to_string()),
628            Type::FLOAT8 => row.get::<_, Option<f64>>(i).map(|x| x.to_string()),
629            Type::TIMESTAMP => row
630                .get::<_, Option<chrono::NaiveDateTime>>(i)
631                .map(|x| x.to_string()),
632            Type::TIMESTAMPTZ => row
633                .get::<_, Option<chrono::DateTime<chrono::Utc>>>(i)
634                .map(|x| x.to_string()),
635            Type::DATE => row
636                .get::<_, Option<chrono::NaiveDate>>(i)
637                .map(|x| x.to_string()),
638            Type::TIME => row
639                .get::<_, Option<chrono::NaiveTime>>(i)
640                .map(|x| x.to_string()),
641            Type::INTERVAL => row.get::<_, Option<Interval>>(i).map(|x| x.to_string()),
642            Type::JSONB => row.get::<_, Option<Jsonb>>(i).map(|v| v.0.to_string()),
643            Type::UUID => row.get::<_, Option<uuid::Uuid>>(i).map(|v| v.to_string()),
644            Type::BOOL_ARRAY => row
645                .get::<_, Option<Array<ArrayElement<bool>>>>(i)
646                .map(|a| a.to_string()),
647            Type::INT2_ARRAY => row
648                .get::<_, Option<Array<ArrayElement<i16>>>>(i)
649                .map(|a| a.to_string()),
650            Type::INT4_ARRAY => row
651                .get::<_, Option<Array<ArrayElement<i32>>>>(i)
652                .map(|a| a.to_string()),
653            Type::INT8_ARRAY => row
654                .get::<_, Option<Array<ArrayElement<i64>>>>(i)
655                .map(|a| a.to_string()),
656            Type::OID_ARRAY => row
657                .get::<_, Option<Array<ArrayElement<u32>>>>(i)
658                .map(|x| x.to_string()),
659            Type::NUMERIC_ARRAY => row
660                .get::<_, Option<Array<ArrayElement<NumericStandardNotation>>>>(i)
661                .map(|x| x.to_string()),
662            Type::FLOAT4_ARRAY => row
663                .get::<_, Option<Array<ArrayElement<f32>>>>(i)
664                .map(|x| x.to_string()),
665            Type::FLOAT8_ARRAY => row
666                .get::<_, Option<Array<ArrayElement<f64>>>>(i)
667                .map(|x| x.to_string()),
668            Type::TIMESTAMP_ARRAY => row
669                .get::<_, Option<Array<ArrayElement<chrono::NaiveDateTime>>>>(i)
670                .map(|x| x.to_string()),
671            Type::TIMESTAMPTZ_ARRAY => row
672                .get::<_, Option<Array<ArrayElement<chrono::DateTime<chrono::Utc>>>>>(i)
673                .map(|x| x.to_string()),
674            Type::DATE_ARRAY => row
675                .get::<_, Option<Array<ArrayElement<chrono::NaiveDate>>>>(i)
676                .map(|x| x.to_string()),
677            Type::TIME_ARRAY => row
678                .get::<_, Option<Array<ArrayElement<chrono::NaiveTime>>>>(i)
679                .map(|x| x.to_string()),
680            Type::INTERVAL_ARRAY => row
681                .get::<_, Option<Array<ArrayElement<Interval>>>>(i)
682                .map(|x| x.to_string()),
683            Type::JSONB_ARRAY => row
684                .get::<_, Option<Array<ArrayElement<Jsonb>>>>(i)
685                .map(|v| v.to_string()),
686            Type::UUID_ARRAY => row
687                .get::<_, Option<Array<ArrayElement<uuid::Uuid>>>>(i)
688                .map(|v| v.to_string()),
689            Type::INT4_RANGE => row.get::<_, Option<Range<i32>>>(i).map(|v| v.to_string()),
690            Type::INT4_RANGE_ARRAY => row
691                .get::<_, Option<Array<ArrayElement<Range<i32>>>>>(i)
692                .map(|v| v.to_string()),
693            Type::INT8_RANGE => row.get::<_, Option<Range<i64>>>(i).map(|v| v.to_string()),
694            Type::INT8_RANGE_ARRAY => row
695                .get::<_, Option<Array<ArrayElement<Range<i64>>>>>(i)
696                .map(|v| v.to_string()),
697            Type::NUM_RANGE => row
698                .get::<_, Option<Range<NumericStandardNotation>>>(i)
699                .map(|v| v.to_string()),
700            Type::NUM_RANGE_ARRAY => row
701                .get::<_, Option<Array<ArrayElement<Range<NumericStandardNotation>>>>>(i)
702                .map(|v| v.to_string()),
703            Type::DATE_RANGE => row
704                .get::<_, Option<Range<chrono::NaiveDate>>>(i)
705                .map(|v| v.to_string()),
706            Type::DATE_RANGE_ARRAY => row
707                .get::<_, Option<Array<ArrayElement<Range<chrono::NaiveDate>>>>>(i)
708                .map(|v| v.to_string()),
709            Type::TS_RANGE => row
710                .get::<_, Option<Range<chrono::NaiveDateTime>>>(i)
711                .map(|v| v.to_string()),
712            Type::TS_RANGE_ARRAY => row
713                .get::<_, Option<Array<ArrayElement<Range<chrono::NaiveDateTime>>>>>(i)
714                .map(|v| v.to_string()),
715            Type::TSTZ_RANGE => row
716                .get::<_, Option<Range<chrono::DateTime<chrono::Utc>>>>(i)
717                .map(|v| v.to_string()),
718            Type::TSTZ_RANGE_ARRAY => row
719                .get::<_, Option<Array<ArrayElement<Range<chrono::DateTime<chrono::Utc>>>>>>(i)
720                .map(|v| v.to_string()),
721            _ => match ty.oid() {
722                mz_pgrepr::oid::TYPE_UINT2_OID => {
723                    row.get::<_, Option<UInt2>>(i).map(|x| x.0.to_string())
724                }
725                mz_pgrepr::oid::TYPE_UINT4_OID => {
726                    row.get::<_, Option<UInt4>>(i).map(|x| x.0.to_string())
727                }
728                mz_pgrepr::oid::TYPE_UINT8_OID => {
729                    row.get::<_, Option<UInt8>>(i).map(|x| x.0.to_string())
730                }
731                mz_pgrepr::oid::TYPE_MZ_TIMESTAMP_OID => {
732                    row.get::<_, Option<MzTimestamp>>(i).map(|x| x.0)
733                }
734                _ => bail!("unsupported SQL type in testdrive: {:?}", ty),
735            },
736        }
737        .unwrap_or_else(|| "<null>".into());
738
739        raw_out.push(value.clone());
740        if let Some(regex) = &state.regex {
741            value = regex
742                .replace_all(&value, state.regex_replacement.as_str())
743                .to_string();
744        }
745
746        out.push(value);
747    }
748    let raw_out = if out != raw_out { Some(raw_out) } else { None };
749    Ok((out, raw_out))
750}
751
752struct MzTimestamp(String);
753
754impl<'a> FromSql<'a> for MzTimestamp {
755    fn from_sql(_: &Type, raw: &'a [u8]) -> Result<MzTimestamp, Box<dyn Error + Sync + Send>> {
756        Ok(MzTimestamp(std::str::from_utf8(raw)?.to_string()))
757    }
758
759    fn accepts(ty: &Type) -> bool {
760        ty.oid() == mz_pgrepr::oid::TYPE_MZ_TIMESTAMP_OID
761    }
762}
763
764struct MzAclItem(#[allow(dead_code)] String);
765
766impl<'a> FromSql<'a> for MzAclItem {
767    fn from_sql(_ty: &Type, raw: &'a [u8]) -> Result<Self, Box<dyn Error + Sync + Send>> {
768        Ok(MzAclItem(std::str::from_utf8(raw)?.to_string()))
769    }
770
771    fn accepts(ty: &Type) -> bool {
772        ty.oid() == mz_pgrepr::oid::TYPE_MZ_ACL_ITEM_OID
773    }
774}
775
776struct AclItem(String);
777
778impl<'a> FromSql<'a> for AclItem {
779    fn from_sql(_ty: &Type, raw: &'a [u8]) -> Result<Self, Box<dyn Error + Sync + Send>> {
780        Ok(AclItem(std::str::from_utf8(raw)?.to_string()))
781    }
782
783    fn accepts(ty: &Type) -> bool {
784        ty.oid() == 1033
785    }
786}
787
/// Formats a decoded row for the human-readable diff in row-mismatch errors.
///
/// Columns are separated by single spaces. A column that is empty, contains a
/// double quote, or contains any ASCII whitespace (not just spaces) is printed
/// in Rust debug form, i.e. quoted and escaped. This keeps column boundaries
/// and invisible characters such as tabs or newlines unambiguous. The
/// whitespace rule matches the quoting rule used when rewriting results.
struct TestdriveRow<'a>(&'a [String]);

impl Display for TestdriveRow<'_> {
    fn fmt(&self, f: &mut Formatter) -> fmt::Result {
        for (i, col) in self.0.iter().enumerate() {
            if i > 0 {
                f.write_str(" ")?;
            }
            let needs_quoting = col.is_empty()
                || col.contains('"')
                || col.contains(|c: char| c.is_ascii_whitespace());
            if needs_quoting {
                write!(f, "{:?}", col)?;
            } else {
                f.write_str(col)?;
            }
        }
        Ok(())
    }
}