1use std::ascii;
11use std::error::Error;
12use std::fmt::{self, Display, Formatter, Write as _};
13use std::time::SystemTime;
14
15use anyhow::{Context, bail};
16use itertools::Itertools;
17use md5::{Digest, Md5};
18use mz_ore::collections::CollectionExt;
19use mz_ore::retry::Retry;
20use mz_ore::str::StrExt;
21use mz_pgrepr::{Interval, Jsonb, Numeric, UInt2, UInt4, UInt8};
22use mz_repr::adt::range::Range;
23use mz_sql_parser::ast::{Raw, Statement};
24use postgres_array::Array;
25use regex::Regex;
26use tokio_postgres::error::DbError;
27use tokio_postgres::row::Row;
28use tokio_postgres::types::{FromSql, Type};
29
30use crate::action::{ControlFlow, Rewrite, State};
31use crate::parser::{FailSqlCommand, SqlCommand, SqlExpectedError, SqlOutput};
32
33pub async fn run_sql(mut cmd: SqlCommand, state: &mut State) -> Result<ControlFlow, anyhow::Error> {
34    use Statement::*;
35
36    state.rewrite_pos_start = cmd.expected_start;
37    state.rewrite_pos_end = cmd.expected_end;
38
39    let stmts = mz_sql_parser::parser::parse_statements(&cmd.query)
40        .with_context(|| format!("unable to parse SQL: {}", cmd.query))?;
41    if stmts.len() != 1 {
42        bail!("expected one statement, but got {}", stmts.len());
43    }
44    let stmt = stmts.into_element().ast;
45    if let SqlOutput::Full { expected_rows, .. } = &mut cmd.expected_output {
46        expected_rows.sort();
48    }
49
50    let should_retry = match &stmt {
51        Fetch(_) => false,
54        ExplainPlan(_) => false,
57        CreateConnection(_)
59        | CreateCluster(_)
60        | CreateClusterReplica(_)
61        | CreateDatabase(_)
62        | CreateSchema(_)
63        | CreateSource(_)
64        | CreateWebhookSource(_)
65        | CreateSink(_)
66        | CreateMaterializedView(_)
67        | CreateView(_)
68        | CreateTable(_)
69        | CreateTableFromSource(_)
70        | CreateIndex(_)
71        | CreateType(_)
72        | CreateRole(_)
73        | AlterObjectRename(_)
74        | AlterIndex(_)
75        | AlterSink(_)
76        | Discard(_)
77        | DropObjects(_)
78        | SetVariable(_) => false,
79        _ => true,
80    };
81
82    let query = &cmd.query;
83    print_query(query, Some(&stmt));
84    let expected_output = &cmd.expected_output;
85    state.error_line_count = 0;
86    state.error_string = "".to_string();
87    let (state, res) = match should_retry {
88        true => Retry::default()
89            .initial_backoff(state.initial_backoff)
90            .factor(state.backoff_factor)
91            .max_duration(state.timeout)
92            .max_tries(state.max_tries),
93        false => Retry::default().max_duration(state.timeout).max_tries(1),
94    }
95    .retry_async_with_state(state, |retry_state, state| async move {
96        let should_continue = retry_state.i + 1 < state.max_tries && should_retry;
97        let start = SystemTime::now();
98        match try_run_sql(state, query, expected_output, should_continue).await {
99            Ok(()) => {
100                let now = SystemTime::now();
101                let epoch = SystemTime::UNIX_EPOCH;
102                let ts = now.duration_since(epoch).unwrap().as_secs_f64();
103                let delay = now.duration_since(start).unwrap().as_secs_f64();
104                println!("rows match; continuing at ts {ts}, took {delay}s");
105                (state, Ok(()))
106            }
107            Err(e) => {
108                if let Some(backoff) = retry_state.next_backoff {
109                    if !backoff.is_zero()
110                        && (retry_state.i == 0 || !mz_ore::env::is_var_truthy("CI"))
111                    {
112                        let error_string = format!("{:?}", e);
113                        if error_string != state.error_string {
114                            for _ in 0..state.error_line_count {
116                                print!("\x1B[1A\x1B[2K");
117                            }
118                            state.error_line_count = 0;
119                            state.error_string = error_string;
120                            state.error_line_count = state.error_string.lines().count() + 1;
121                            if state.error_string.ends_with('\n') {
122                                print!("{}", state.error_string);
123                            } else {
124                                println!("{}", state.error_string);
125                                state.error_line_count += 1;
126                            }
127                            println!(
128                                "rows didn't match; sleeping to see if dataflow catches up 🕑 {:.0?}",
129                                retry_state.next_backoff.unwrap_or_default()
130                            );
131                        }
132                    }
133                } else {
134                    for _ in 0..state.error_line_count {
135                        print!("\x1B[1A\x1B[2K");
136                    }
137                    state.error_line_count = 0;
138                }
139                (state, Err(e))
140            }
141        }
142    })
143    .await;
144    if let Err(e) = res {
145        return Err(e);
146    }
147    if state.consistency_checks == super::consistency::Level::Statement {
148        super::consistency::run_consistency_checks(state).await?;
149    }
150
151    Ok(ControlFlow::Continue)
152}
153
154fn rewrite_result(
155    state: &mut State,
156    columns: Vec<&str>,
157    content: Vec<Vec<String>>,
158) -> Result<(), anyhow::Error> {
159    let mut buf = String::new();
160    writeln!(buf, "{}", columns.join(" "))?;
161    writeln!(buf, "----")?;
162    for row in content {
163        let mut formatted_row = Vec::<String>::new();
164        for value in row {
165            if value.is_empty() || value.contains(|x: char| char::is_ascii_whitespace(&x)) {
166                formatted_row.push("\"".to_owned() + &value + "\"");
167            } else {
168                formatted_row.push(value);
169            }
170        }
171        writeln!(buf, "{}", formatted_row.join(" "))?;
172    }
173    state.rewrites.push(Rewrite {
174        content: buf,
175        start: state.rewrite_pos_start,
176        end: state.rewrite_pos_end,
177    });
178
179    Ok(())
180}
181
182async fn try_run_sql(
183    state: &mut State,
184    query: &str,
185    expected_output: &SqlOutput,
186    should_retry: bool,
187) -> Result<(), anyhow::Error> {
188    let stmt = state
189        .materialize
190        .pgclient
191        .prepare(query)
192        .await
193        .context("preparing query failed")?;
194
195    let query_with_timeout = tokio::time::timeout(
196        state.timeout.clone(),
197        state.materialize.pgclient.query(&stmt, &[]),
198    )
199    .await;
200
201    if query_with_timeout.is_err() {
202        bail!("query timed out\n")
203    }
204
205    let rows: Vec<_> = query_with_timeout
206        .unwrap()
207        .context("executing query failed")?
208        .into_iter()
209        .map(|row| decode_row(state, row))
210        .collect::<Result<_, _>>()?;
211
212    let (mut actual, raw_actual): (Vec<_>, Vec<_>) = rows.into_iter().unzip();
213
214    let raw_actual: Option<Vec<_>> = if raw_actual.iter().any(|r| r.is_some()) {
215        Some(
218            actual
219                .iter()
220                .zip_eq(raw_actual.into_iter())
221                .map(|(actual, unreplaced)| match unreplaced {
222                    Some(raw_row) => raw_row,
223                    None => actual.clone(),
224                })
225                .collect(),
226        )
227    } else {
228        None
229    };
230
231    actual.sort();
232    let actual_columns: Vec<_> = stmt.columns().iter().map(|c| c.name()).collect();
233
234    match expected_output {
235        SqlOutput::Full {
236            expected_rows,
237            column_names,
238        } => {
239            if let Some(column_names) = column_names {
240                if actual_columns.iter().ne(column_names) {
241                    if state.rewrite_results && !should_retry {
242                        rewrite_result(state, actual_columns, actual)?;
243                        return Ok(());
244                    } else {
245                        bail!(
246                            "column name mismatch\nexpected: {:?}\nactual:   {:?}\n",
247                            column_names,
248                            actual_columns
249                        );
250                    }
251                }
252            }
253            if &actual == expected_rows {
254                Ok(())
255            } else if state.rewrite_results && !should_retry {
256                rewrite_result(state, actual_columns, actual)?;
257                Ok(())
258            } else {
259                let (mut left, mut right) = (0, 0);
260                let mut buf = String::new();
261                while let (Some(e), Some(a)) = (expected_rows.get(left), actual.get(right)) {
262                    match e.cmp(a) {
263                        std::cmp::Ordering::Less => {
264                            writeln!(buf, "- {}", TestdriveRow(e)).unwrap();
265                            left += 1;
266                        }
267                        std::cmp::Ordering::Equal => {
268                            left += 1;
269                            right += 1;
270                        }
271                        std::cmp::Ordering::Greater => {
272                            writeln!(buf, "+ {}", TestdriveRow(a)).unwrap();
273                            right += 1;
274                        }
275                    }
276                }
277                while let Some(e) = expected_rows.get(left) {
278                    writeln!(buf, "- {}", TestdriveRow(e)).unwrap();
279                    left += 1;
280                }
281                while let Some(a) = actual.get(right) {
282                    writeln!(buf, "+ {}", TestdriveRow(a)).unwrap();
283                    right += 1;
284                }
285                if state.rewrite_results && !should_retry {
286                    rewrite_result(state, actual_columns, actual)?;
287                    Ok(())
288                } else if let Some(raw_actual) = raw_actual {
289                    bail!(
290                        "non-matching rows: expected:\n{:?}\ngot:\n{:?}\ngot raw rows:\n{:?}\nPoor diff:\n{}",
291                        expected_rows,
292                        actual,
293                        raw_actual,
294                        buf,
295                    )
296                } else {
297                    bail!(
298                        "non-matching rows: expected:\n{:?}\ngot:\n{:?}\nPoor diff:\n{}",
299                        expected_rows,
300                        actual,
301                        buf
302                    )
303                }
304            }
305        }
306        SqlOutput::Hashed { num_values, md5 } => {
307            if &actual.len() != num_values {
308                bail!(
309                    "wrong row count: expected:\n{:?}\ngot:\n{:?}\n",
310                    num_values,
311                    actual.len(),
312                )
313            } else {
314                let mut hasher = Md5::new();
315                for row in &actual {
316                    for entry in row {
317                        hasher.update(entry);
318                    }
319                }
320                let actual = format!("{:x}", hasher.finalize());
321                if &actual != md5 {
322                    bail!("wrong hash value: expected:{:?} got:{:?}", md5, actual)
323                } else {
324                    Ok(())
325                }
326            }
327        }
328    }
329}
330
331#[derive(Clone)]
332enum ErrorMatcher {
333    Contains(String),
334    Exact(String),
335    Regex(Regex),
336    Timeout,
337}
338
339impl ErrorMatcher {
340    fn is_match(&self, err: &String) -> bool {
341        match self {
342            ErrorMatcher::Contains(s) => err.contains(s),
343            ErrorMatcher::Exact(s) => err == s,
344            ErrorMatcher::Regex(r) => r.is_match(err),
345            ErrorMatcher::Timeout => false,
349        }
350    }
351}
352
353impl fmt::Display for ErrorMatcher {
354    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
355        match self {
356            ErrorMatcher::Contains(s) => write!(f, "error containing {}", s.quoted()),
357            ErrorMatcher::Exact(s) => write!(f, "exact error {}", s.quoted()),
358            ErrorMatcher::Regex(s) => write!(f, "error matching regex {}", s.as_str().quoted()),
359            ErrorMatcher::Timeout => f.write_str("timeout"),
360        }
361    }
362}
363
364impl ErrorMatcher {
365    fn fmt_with_type(&self, type_: &str) -> String {
366        match self {
367            ErrorMatcher::Contains(s) => format!("{} containing {}", type_, s.quoted()),
368            ErrorMatcher::Exact(s) => format!("exact {} {}", type_, s.quoted()),
369            ErrorMatcher::Regex(s) => format!("{} matching regex {}", type_, s.as_str().quoted()),
370            ErrorMatcher::Timeout => "timeout".to_string(),
371        }
372    }
373}
374
375pub async fn run_fail_sql(
376    cmd: FailSqlCommand,
377    state: &mut State,
378) -> Result<ControlFlow, anyhow::Error> {
379    use Statement::{AlterSink, Commit, CreateConnection, Fetch, Rollback};
380
381    let stmts = mz_sql_parser::parser::parse_statements(&cmd.query)
382        .map_err(|e| format!("unable to parse SQL: {}: {}", cmd.query, e));
383
384    let stmt = match stmts {
387        Ok(s) => {
388            if s.len() != 1 {
389                bail!("expected one statement, but got {}", s.len());
390            }
391            Some(s.into_element().ast)
392        }
393        Err(_) => None,
394    };
395
396    let expected_error = match cmd.expected_error {
397        SqlExpectedError::Contains(s) => ErrorMatcher::Contains(s),
398        SqlExpectedError::Exact(s) => ErrorMatcher::Exact(s),
399        SqlExpectedError::Regex(s) => ErrorMatcher::Regex(s.parse()?),
400        SqlExpectedError::Timeout => ErrorMatcher::Timeout,
401    };
402    let expected_detail = cmd.expected_detail.map(ErrorMatcher::Contains);
403    let expected_hint = cmd.expected_hint.map(ErrorMatcher::Contains);
404
405    let query = &cmd.query;
406    print_query(query, stmt.as_ref());
407
408    let should_retry = match &stmt {
409        None => false,
411        Some(Commit(_)) | Some(Rollback(_)) => false,
415        Some(Fetch(_)) => false,
417        Some(AlterSink(_)) => false,
418        Some(CreateConnection(_)) => false,
419        Some(_) => true,
420    };
421
422    state.error_line_count = 0;
423    state.error_string = "".to_string();
424    let res = match should_retry {
425        true => Retry::default()
426            .initial_backoff(state.initial_backoff)
427            .factor(state.backoff_factor)
428            .max_duration(state.timeout)
429            .max_tries(state.max_tries),
430        false => Retry::default().max_duration(state.timeout).max_tries(1),
431    }
432    .retry_async_with_state_canceling(state, |retry_state, state| {
433        let expected_error = expected_error.clone();
434        let expected_detail = expected_detail.clone();
435        let expected_hint = expected_hint.clone();
436        async move {
437            match try_run_fail_sql(
438                state,
439                query,
440                &expected_error,
441                expected_detail.as_ref(),
442                expected_hint.as_ref(),
443            )
444            .await
445            {
446                Ok(()) => {
447                    println!("query error matches; continuing");
448                    (state, Ok(()))
449                }
450                Err(e) => {
451                    if let Some(backoff) = retry_state.next_backoff {
452                        if !backoff.is_zero() && (retry_state.i == 0 || !mz_ore::env::is_var_truthy("CI")) {
453                            let error_string = format!("{:?}", e);
454                            if error_string != state.error_string {
455                                for _ in 0..state.error_line_count {
457                                    print!("\x1B[1A\x1B[2K");
458                                }
459                                state.error_line_count = 0;
460                                state.error_string = error_string;
461                                state.error_line_count = state.error_string.lines().count() + 1;
462                                if state.error_string.ends_with('\n') {
463                                    print!("{}", state.error_string);
464                                } else {
465                                    println!("{}", state.error_string);
466                                    state.error_line_count += 1;
467                                }
468                                println!("query error didn't match; sleeping to see if dataflow produces error shortly 🕑 {:.0?}", retry_state.next_backoff.unwrap_or_default());
469                            }
470                        }
471                    } else {
472                        for _ in 0..state.error_line_count {
473                            print!("\x1B[1A\x1B[2K");
474                        }
475                        state.error_line_count = 0;
476                    }
477                    (state, Err(e))
478                }
479            }
480        }})
481    .await;
482
483    if let ErrorMatcher::Timeout = expected_error {
486        if let Err(e) = &res {
487            if e.is::<tokio::time::error::Elapsed>() {
488                println!("query timed out as expected");
489                return Ok(ControlFlow::Continue);
490            }
491        }
492    }
493
494    res?;
498    Ok(ControlFlow::Continue)
499}
500
501async fn try_run_fail_sql(
502    state: &State,
503    query: &str,
504    expected_error: &ErrorMatcher,
505    expected_detail: Option<&ErrorMatcher>,
506    expected_hint: Option<&ErrorMatcher>,
507) -> Result<(), anyhow::Error> {
508    match state.materialize.pgclient.query(query, &[]).await {
509        Ok(_) => bail!("query succeeded, but expected {}", expected_error),
510        Err(err) => match err.source().and_then(|err| err.downcast_ref::<DbError>()) {
511            Some(err) => {
512                let mut err_string = err.message().to_string();
513                if let Some(regex) = &state.regex {
514                    err_string = regex
515                        .replace_all(&err_string, state.regex_replacement.as_str())
516                        .to_string();
517                }
518                if !expected_error.is_match(&err_string) {
519                    bail!("expected {}, got {}", expected_error, err_string.quoted());
520                }
521
522                let check_additional =
523                    |extra: Option<&str>, matcher: Option<&ErrorMatcher>, type_| {
524                        let extra = extra.map(|s| s.to_string());
525                        match (extra, matcher) {
526                            (Some(extra), Some(expected)) => {
527                                if !expected.is_match(&extra) {
528                                    bail!(
529                                        "expected {}, got {}",
530                                        expected.fmt_with_type(type_),
531                                        extra.quoted()
532                                    );
533                                }
534                            }
535                            (None, Some(expected)) => {
536                                bail!("expected {}, but found none", expected.fmt_with_type(type_));
537                            }
538                            _ => {}
539                        }
540                        Ok(())
541                    };
542
543                check_additional(err.detail(), expected_detail, "DETAIL")?;
544                check_additional(err.hint(), expected_hint, "HINT")?;
545
546                Ok(())
547            }
548            None => Err(err.into()),
549        },
550    }
551}
552
553pub fn print_query(query: &str, stmt: Option<&Statement<Raw>>) {
554    use Statement::*;
555    if let Some(CreateSecret(_)) = stmt {
556        println!(
557            "> CREATE SECRET [query truncated on purpose so as to not reveal the secret in the log]"
558        );
559    } else {
560        println!("> {}", query)
561    }
562}
563
564pub fn decode_row(
566    state: &State,
567    row: Row,
568) -> Result<(Vec<String>, Option<Vec<String>>), anyhow::Error> {
569    enum ArrayElement<T> {
570        Null,
571        NonNull(T),
572    }
573
574    impl<T> fmt::Display for ArrayElement<T>
575    where
576        T: fmt::Display,
577    {
578        fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
579            match self {
580                ArrayElement::Null => f.write_str("NULL"),
581                ArrayElement::NonNull(t) => t.fmt(f),
582            }
583        }
584    }
585
586    impl<'a, T> FromSql<'a> for ArrayElement<T>
587    where
588        T: FromSql<'a>,
589    {
590        fn from_sql(
591            ty: &Type,
592            raw: &'a [u8],
593        ) -> Result<ArrayElement<T>, Box<dyn Error + Sync + Send>> {
594            T::from_sql(ty, raw).map(ArrayElement::NonNull)
595        }
596
597        fn from_sql_null(_: &Type) -> Result<ArrayElement<T>, Box<dyn Error + Sync + Send>> {
598            Ok(ArrayElement::Null)
599        }
600
601        fn accepts(ty: &Type) -> bool {
602            T::accepts(ty)
603        }
604    }
605
606    struct NumericStandardNotation(Numeric);
611
612    impl fmt::Display for NumericStandardNotation {
613        fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
614            write!(f, "{}", self.0.0.0.to_standard_notation_string())
615        }
616    }
617
618    impl<'a> FromSql<'a> for NumericStandardNotation {
619        fn from_sql(
620            ty: &Type,
621            raw: &'a [u8],
622        ) -> Result<NumericStandardNotation, Box<dyn Error + Sync + Send>> {
623            Ok(NumericStandardNotation(Numeric::from_sql(ty, raw)?))
624        }
625
626        fn from_sql_null(
627            ty: &Type,
628        ) -> Result<NumericStandardNotation, Box<dyn Error + Sync + Send>> {
629            Ok(NumericStandardNotation(Numeric::from_sql_null(ty)?))
630        }
631
632        fn accepts(ty: &Type) -> bool {
633            Numeric::accepts(ty)
634        }
635    }
636
637    let mut out = vec![];
638    let mut raw_out = vec![];
639    for (i, col) in row.columns().iter().enumerate() {
640        let ty = col.type_();
641        let mut value: String = match *ty {
642            Type::ACLITEM => row.get::<_, Option<AclItem>>(i).map(|x| x.0),
643            Type::BOOL => row.get::<_, Option<bool>>(i).map(|x| x.to_string()),
644            Type::BPCHAR | Type::TEXT | Type::VARCHAR => row.get::<_, Option<String>>(i),
645            Type::TEXT_ARRAY => row
646                .get::<_, Option<Array<ArrayElement<String>>>>(i)
647                .map(|a| a.to_string()),
648            Type::BYTEA => row.get::<_, Option<Vec<u8>>>(i).map(|x| {
649                let s = x.into_iter().map(ascii::escape_default).flatten().collect();
650                String::from_utf8(s).unwrap()
651            }),
652            Type::CHAR => row.get::<_, Option<i8>>(i).map(|x| x.to_string()),
653            Type::INT2 => row.get::<_, Option<i16>>(i).map(|x| x.to_string()),
654            Type::INT4 => row.get::<_, Option<i32>>(i).map(|x| x.to_string()),
655            Type::INT8 => row.get::<_, Option<i64>>(i).map(|x| x.to_string()),
656            Type::OID => row.get::<_, Option<u32>>(i).map(|x| x.to_string()),
657            Type::NUMERIC => row
658                .get::<_, Option<NumericStandardNotation>>(i)
659                .map(|x| x.to_string()),
660            Type::FLOAT4 => row.get::<_, Option<f32>>(i).map(|x| x.to_string()),
661            Type::FLOAT8 => row.get::<_, Option<f64>>(i).map(|x| x.to_string()),
662            Type::TIMESTAMP => row
663                .get::<_, Option<chrono::NaiveDateTime>>(i)
664                .map(|x| x.to_string()),
665            Type::TIMESTAMPTZ => row
666                .get::<_, Option<chrono::DateTime<chrono::Utc>>>(i)
667                .map(|x| x.to_string()),
668            Type::DATE => row
669                .get::<_, Option<chrono::NaiveDate>>(i)
670                .map(|x| x.to_string()),
671            Type::TIME => row
672                .get::<_, Option<chrono::NaiveTime>>(i)
673                .map(|x| x.to_string()),
674            Type::INTERVAL => row.get::<_, Option<Interval>>(i).map(|x| x.to_string()),
675            Type::JSONB => row.get::<_, Option<Jsonb>>(i).map(|v| v.0.to_string()),
676            Type::UUID => row.get::<_, Option<uuid::Uuid>>(i).map(|v| v.to_string()),
677            Type::BOOL_ARRAY => row
678                .get::<_, Option<Array<ArrayElement<bool>>>>(i)
679                .map(|a| a.to_string()),
680            Type::INT2_ARRAY => row
681                .get::<_, Option<Array<ArrayElement<i16>>>>(i)
682                .map(|a| a.to_string()),
683            Type::INT4_ARRAY => row
684                .get::<_, Option<Array<ArrayElement<i32>>>>(i)
685                .map(|a| a.to_string()),
686            Type::INT8_ARRAY => row
687                .get::<_, Option<Array<ArrayElement<i64>>>>(i)
688                .map(|a| a.to_string()),
689            Type::OID_ARRAY => row
690                .get::<_, Option<Array<ArrayElement<u32>>>>(i)
691                .map(|x| x.to_string()),
692            Type::NUMERIC_ARRAY => row
693                .get::<_, Option<Array<ArrayElement<NumericStandardNotation>>>>(i)
694                .map(|x| x.to_string()),
695            Type::FLOAT4_ARRAY => row
696                .get::<_, Option<Array<ArrayElement<f32>>>>(i)
697                .map(|x| x.to_string()),
698            Type::FLOAT8_ARRAY => row
699                .get::<_, Option<Array<ArrayElement<f64>>>>(i)
700                .map(|x| x.to_string()),
701            Type::TIMESTAMP_ARRAY => row
702                .get::<_, Option<Array<ArrayElement<chrono::NaiveDateTime>>>>(i)
703                .map(|x| x.to_string()),
704            Type::TIMESTAMPTZ_ARRAY => row
705                .get::<_, Option<Array<ArrayElement<chrono::DateTime<chrono::Utc>>>>>(i)
706                .map(|x| x.to_string()),
707            Type::DATE_ARRAY => row
708                .get::<_, Option<Array<ArrayElement<chrono::NaiveDate>>>>(i)
709                .map(|x| x.to_string()),
710            Type::TIME_ARRAY => row
711                .get::<_, Option<Array<ArrayElement<chrono::NaiveTime>>>>(i)
712                .map(|x| x.to_string()),
713            Type::INTERVAL_ARRAY => row
714                .get::<_, Option<Array<ArrayElement<Interval>>>>(i)
715                .map(|x| x.to_string()),
716            Type::JSONB_ARRAY => row
717                .get::<_, Option<Array<ArrayElement<Jsonb>>>>(i)
718                .map(|v| v.to_string()),
719            Type::UUID_ARRAY => row
720                .get::<_, Option<Array<ArrayElement<uuid::Uuid>>>>(i)
721                .map(|v| v.to_string()),
722            Type::INT4_RANGE => row.get::<_, Option<Range<i32>>>(i).map(|v| v.to_string()),
723            Type::INT4_RANGE_ARRAY => row
724                .get::<_, Option<Array<ArrayElement<Range<i32>>>>>(i)
725                .map(|v| v.to_string()),
726            Type::INT8_RANGE => row.get::<_, Option<Range<i64>>>(i).map(|v| v.to_string()),
727            Type::INT8_RANGE_ARRAY => row
728                .get::<_, Option<Array<ArrayElement<Range<i64>>>>>(i)
729                .map(|v| v.to_string()),
730            Type::NUM_RANGE => row
731                .get::<_, Option<Range<NumericStandardNotation>>>(i)
732                .map(|v| v.to_string()),
733            Type::NUM_RANGE_ARRAY => row
734                .get::<_, Option<Array<ArrayElement<Range<NumericStandardNotation>>>>>(i)
735                .map(|v| v.to_string()),
736            Type::DATE_RANGE => row
737                .get::<_, Option<Range<chrono::NaiveDate>>>(i)
738                .map(|v| v.to_string()),
739            Type::DATE_RANGE_ARRAY => row
740                .get::<_, Option<Array<ArrayElement<Range<chrono::NaiveDate>>>>>(i)
741                .map(|v| v.to_string()),
742            Type::TS_RANGE => row
743                .get::<_, Option<Range<chrono::NaiveDateTime>>>(i)
744                .map(|v| v.to_string()),
745            Type::TS_RANGE_ARRAY => row
746                .get::<_, Option<Array<ArrayElement<Range<chrono::NaiveDateTime>>>>>(i)
747                .map(|v| v.to_string()),
748            Type::TSTZ_RANGE => row
749                .get::<_, Option<Range<chrono::DateTime<chrono::Utc>>>>(i)
750                .map(|v| v.to_string()),
751            Type::TSTZ_RANGE_ARRAY => row
752                .get::<_, Option<Array<ArrayElement<Range<chrono::DateTime<chrono::Utc>>>>>>(i)
753                .map(|v| v.to_string()),
754            _ => match ty.oid() {
755                mz_pgrepr::oid::TYPE_UINT2_OID => {
756                    row.get::<_, Option<UInt2>>(i).map(|x| x.0.to_string())
757                }
758                mz_pgrepr::oid::TYPE_UINT4_OID => {
759                    row.get::<_, Option<UInt4>>(i).map(|x| x.0.to_string())
760                }
761                mz_pgrepr::oid::TYPE_UINT8_OID => {
762                    row.get::<_, Option<UInt8>>(i).map(|x| x.0.to_string())
763                }
764                mz_pgrepr::oid::TYPE_MZ_TIMESTAMP_OID => {
765                    row.get::<_, Option<MzTimestamp>>(i).map(|x| x.0)
766                }
767                _ => bail!("unsupported SQL type in testdrive: {:?}", ty),
768            },
769        }
770        .unwrap_or_else(|| "<null>".into());
771
772        raw_out.push(value.clone());
773        if let Some(regex) = &state.regex {
774            value = regex
775                .replace_all(&value, state.regex_replacement.as_str())
776                .to_string();
777        }
778
779        out.push(value);
780    }
781    let raw_out = if out != raw_out { Some(raw_out) } else { None };
782    Ok((out, raw_out))
783}
784
785struct MzTimestamp(String);
786
787impl<'a> FromSql<'a> for MzTimestamp {
788    fn from_sql(_: &Type, raw: &'a [u8]) -> Result<MzTimestamp, Box<dyn Error + Sync + Send>> {
789        Ok(MzTimestamp(std::str::from_utf8(raw)?.to_string()))
790    }
791
792    fn accepts(ty: &Type) -> bool {
793        ty.oid() == mz_pgrepr::oid::TYPE_MZ_TIMESTAMP_OID
794    }
795}
796
797#[allow(dead_code)]
798struct MzAclItem(String);
799
800impl<'a> FromSql<'a> for MzAclItem {
801    fn from_sql(_ty: &Type, raw: &'a [u8]) -> Result<Self, Box<dyn Error + Sync + Send>> {
802        Ok(MzAclItem(std::str::from_utf8(raw)?.to_string()))
803    }
804
805    fn accepts(ty: &Type) -> bool {
806        ty.oid() == mz_pgrepr::oid::TYPE_MZ_ACL_ITEM_OID
807    }
808}
809
810struct AclItem(String);
811
812impl<'a> FromSql<'a> for AclItem {
813    fn from_sql(_ty: &Type, raw: &'a [u8]) -> Result<Self, Box<dyn Error + Sync + Send>> {
814        Ok(AclItem(std::str::from_utf8(raw)?.to_string()))
815    }
816
817    fn accepts(ty: &Type) -> bool {
818        ty.oid() == 1033
819    }
820}
821
/// Displays a decoded row in testdrive's expected-output style for the "poor diff"
/// printed on mismatches.
///
/// Columns are separated by single spaces. A column is printed as a Rust-escaped,
/// double-quoted string when it is empty or contains whitespace, a control character,
/// or a double quote. Quoting every kind of whitespace (not just spaces) keeps values
/// with tabs or newlines from looking like extra columns or splitting the diff across
/// lines.
struct TestdriveRow<'a>(&'a Vec<String>);

impl Display for TestdriveRow<'_> {
    fn fmt(&self, f: &mut Formatter) -> fmt::Result {
        for (i, col) in self.0.iter().enumerate() {
            if i > 0 {
                f.write_str(" ")?;
            }
            let needs_quotes = col.is_empty()
                || col
                    .chars()
                    .any(|c| c.is_whitespace() || c.is_control() || c == '"');
            if needs_quotes {
                write!(f, "{:?}", col)?;
            } else {
                f.write_str(col)?;
            }
        }
        Ok(())
    }
}
838}