Skip to main content

mz_sqllogictest/
runner.rs

1// Copyright Materialize, Inc. and contributors. All rights reserved.
2//
3// Use of this software is governed by the Business Source License
4// included in the LICENSE file.
5//
6// As of the Change Date specified in that file, in accordance with
7// the Business Source License, use of this software will be governed
8// by the Apache License, Version 2.0.
9
10//! The Materialize-specific runner for sqllogictest.
11//!
12//! slt tests expect a serialized execution of sql statements and queries.
13//! To get the same results in materialize we track current_timestamp and increment it whenever we execute a statement.
14//!
15//! The high-level workflow is:
16//!   for each record in the test file:
17//!     if record is a sql statement:
18//!       run sql in postgres, observe changes and copy them to materialize using LocalInput::Updates(..)
19//!       advance current_timestamp
20//!       promise to never send updates for times < current_timestamp using LocalInput::Watermark(..)
21//!       compare to expected results
22//!       if wrong, bail out and stop processing this file
23//!     if record is a sql query:
24//!       peek query at current_timestamp
25//!       compare to expected results
26//!       if wrong, record the error
27
28use std::collections::BTreeMap;
29use std::error::Error;
30use std::fs::{File, OpenOptions};
31use std::io::{Read, Seek, SeekFrom, Write};
32use std::net::{IpAddr, Ipv4Addr, SocketAddr};
33use std::path::Path;
34use std::sync::Arc;
35use std::sync::LazyLock;
36use std::time::Duration;
37use std::{env, fmt, ops, str, thread};
38
39use anyhow::{anyhow, bail};
40use bytes::BytesMut;
41use chrono::{DateTime, NaiveDateTime, NaiveTime, Utc};
42use fallible_iterator::FallibleIterator;
43use futures::sink::SinkExt;
44use itertools::Itertools;
45use maplit::btreemap;
46use md5::{Digest, Md5};
47use mz_adapter_types::bootstrap_builtin_cluster_config::{
48    ANALYTICS_CLUSTER_DEFAULT_REPLICATION_FACTOR, BootstrapBuiltinClusterConfig,
49    CATALOG_SERVER_CLUSTER_DEFAULT_REPLICATION_FACTOR, PROBE_CLUSTER_DEFAULT_REPLICATION_FACTOR,
50    SUPPORT_CLUSTER_DEFAULT_REPLICATION_FACTOR, SYSTEM_CLUSTER_DEFAULT_REPLICATION_FACTOR,
51};
52use mz_catalog::config::ClusterReplicaSizeMap;
53use mz_controller::{ControllerConfig, ReplicaHttpLocator};
54use mz_environmentd::CatalogConfig;
55use mz_license_keys::ValidatedLicenseKey;
56use mz_orchestrator_process::{ProcessOrchestrator, ProcessOrchestratorConfig};
57use mz_orchestrator_tracing::{TracingCliArgs, TracingOrchestrator};
58use mz_ore::cast::{CastFrom, ReinterpretCast};
59use mz_ore::channel::trigger;
60use mz_ore::error::ErrorExt;
61use mz_ore::metrics::MetricsRegistry;
62use mz_ore::now::SYSTEM_TIME;
63use mz_ore::retry::Retry;
64use mz_ore::task;
65use mz_ore::thread::{JoinHandleExt, JoinOnDropHandle};
66use mz_ore::tracing::TracingHandle;
67use mz_ore::url::SensitiveUrl;
68use mz_persist_client::PersistLocation;
69use mz_persist_client::cache::PersistClientCache;
70use mz_persist_client::cfg::PersistConfig;
71use mz_persist_client::rpc::{
72    MetricsSameProcessPubSubSender, PersistGrpcPubSubServer, PubSubClientConnection, PubSubSender,
73};
74use mz_pgrepr::{Interval, Jsonb, Numeric, UInt2, UInt4, UInt8, Value, oid};
75use mz_repr::ColumnName;
76use mz_repr::adt::date::Date;
77use mz_repr::adt::mz_acl_item::{AclItem, MzAclItem};
78use mz_repr::adt::numeric;
79use mz_secrets::SecretsController;
80use mz_server_core::listeners::{
81    AllowedRoles, AuthenticatorKind, BaseListenerConfig, HttpListenerConfig, HttpRoutesEnabled,
82    ListenersConfig, SqlListenerConfig,
83};
84use mz_sql::ast::{Expr, Raw, Statement};
85use mz_sql::catalog::EnvironmentId;
86use mz_sql_parser::ast::display::AstDisplay;
87use mz_sql_parser::ast::{
88    CreateIndexStatement, CreateViewStatement, CteBlock, Distinct, DropObjectsStatement, Ident,
89    IfExistsBehavior, ObjectType, OrderByExpr, Query, RawItemName, Select, SelectItem,
90    SelectStatement, SetExpr, Statement as AstStatement, TableFactor, TableWithJoins,
91    UnresolvedItemName, UnresolvedObjectName, ViewDefinition,
92};
93use mz_sql_parser::parser;
94use mz_storage_types::connections::ConnectionContext;
95use postgres_protocol::types;
96use regex::Regex;
97use tempfile::TempDir;
98use tokio::net::TcpListener;
99use tokio::runtime::Runtime;
100use tokio::sync::oneshot;
101use tokio_postgres::types::{FromSql, Kind as PgKind, Type as PgType};
102use tokio_postgres::{NoTls, Row, SimpleQueryMessage};
103use tokio_stream::wrappers::TcpListenerStream;
104use tower_http::cors::AllowOrigin;
105use tracing::{error, info};
106use uuid::Uuid;
107use uuid::fmt::Simple;
108
109use crate::ast::{Location, Mode, Output, QueryOutput, Record, Sort, Type};
110use crate::util;
111
/// The result of running a single sqllogictest record.
///
/// Every variant except [`Outcome::Success`] carries the [`Location`] it is
/// reported at. Expected values are borrowed for `'a` (`&'a str`,
/// `&'a Vec<ColumnName>`, `&'a Output`) rather than owned.
///
/// The declaration order of the variants matches the numbering returned by
/// `Outcome::code`, so keep the two in sync when adding a variant.
#[derive(Debug)]
pub enum Outcome<'a> {
    /// NOTE(review): presumably reported for SQL the runner cannot handle;
    /// confirm against the code that constructs this variant.
    Unsupported {
        /// The underlying error.
        error: anyhow::Error,
        location: Location,
    },
    /// NOTE(review): presumably reported when the SQL fails to parse; confirm
    /// against the code that constructs this variant.
    ParseFailure {
        /// The underlying error.
        error: anyhow::Error,
        location: Location,
    },
    /// An error was produced that is either unexpected or does not match the
    /// expected pattern.
    PlanFailure {
        /// The error that was actually produced.
        error: anyhow::Error,
        /// When present, the pattern the error was expected to match; it is
        /// displayed as `/pattern/` next to the actual error.
        expected_error: Option<String>,
        location: Location,
    },
    /// An error was expected but none occurred.
    UnexpectedPlanSuccess {
        /// The error that was expected.
        expected_error: &'a str,
        location: Location,
    },
    /// The reported row count differs from the expected one.
    WrongNumberOfRowsInserted {
        expected_count: u64,
        actual_count: u64,
        location: Location,
    },
    /// The number of columns differs from the expected one.
    WrongColumnCount {
        expected_count: usize,
        actual_count: usize,
        location: Location,
    },
    /// The column names differ from the expected ones.
    WrongColumnNames {
        expected_column_names: &'a Vec<ColumnName>,
        /// Displayed as the "inferred column names".
        actual_column_names: Vec<ColumnName>,
        /// The output that accompanied the mismatching names; it is not
        /// included in the `Display` rendering.
        actual_output: Output,
        location: Location,
    },
    /// The output differs from the expected output.
    OutputFailure {
        expected_output: &'a Output,
        /// The rows as received, before formatting.
        actual_raw_output: Vec<Row>,
        /// The formatted output that was compared against `expected_output`.
        actual_output: Output,
        location: Location,
    },
    /// A query and the indexed view built from it produced different outcomes.
    InconsistentViewOutcome {
        /// Displayed as the outcome "expected from query".
        query_outcome: Box<Outcome<'a>>,
        /// Displayed as the outcome "actually from indexed view".
        view_outcome: Box<Outcome<'a>>,
        location: Location,
    },
    /// Wraps another outcome. The module docs describe bailing out as stopping
    /// the processing of the current file.
    Bail {
        cause: Box<Outcome<'a>>,
        location: Location,
    },
    /// Wraps another outcome that is reported but, per `Outcome::failure`, not
    /// counted as a failure.
    Warning {
        cause: Box<Outcome<'a>>,
        location: Location,
    },
    /// The record behaved as expected.
    Success,
}
168
/// Number of `Outcome` variants; also the length of `Outcomes::stats`.
const NUM_OUTCOMES: usize = 12;
/// Slot of `Outcome::Warning` in the statistics array (the second to last
/// code returned by `Outcome::code`).
const WARNING_OUTCOME: usize = NUM_OUTCOMES - 2;
/// Slot of `Outcome::Success` in the statistics array (the last code returned
/// by `Outcome::code`).
const SUCCESS_OUTCOME: usize = NUM_OUTCOMES - 1;
172
173impl<'a> Outcome<'a> {
174    fn code(&self) -> usize {
175        match self {
176            Outcome::Unsupported { .. } => 0,
177            Outcome::ParseFailure { .. } => 1,
178            Outcome::PlanFailure { .. } => 2,
179            Outcome::UnexpectedPlanSuccess { .. } => 3,
180            Outcome::WrongNumberOfRowsInserted { .. } => 4,
181            Outcome::WrongColumnCount { .. } => 5,
182            Outcome::WrongColumnNames { .. } => 6,
183            Outcome::OutputFailure { .. } => 7,
184            Outcome::InconsistentViewOutcome { .. } => 8,
185            Outcome::Bail { .. } => 9,
186            Outcome::Warning { .. } => 10,
187            Outcome::Success => 11,
188        }
189    }
190
191    fn success(&self) -> bool {
192        matches!(self, Outcome::Success)
193    }
194
195    fn failure(&self) -> bool {
196        !matches!(self, Outcome::Success) && !matches!(self, Outcome::Warning { .. })
197    }
198
199    /// Returns an error message that will match self. Appropriate for
200    /// rewriting error messages (i.e. not inserting error messages where we
201    /// currently expect success).
202    fn err_msg(&self) -> Option<String> {
203        match self {
204            Outcome::Unsupported { error, .. }
205            | Outcome::ParseFailure { error, .. }
206            | Outcome::PlanFailure { error, .. } => {
207                // Take only the first line, which should be sufficient for
208                // meaningfully matching the error.
209                let err_str = error.to_string_with_causes();
210                let err_str = err_str.split('\n').next().unwrap();
211                // Strip the "db error: ERROR: " prefix added by the postgres
212                // client library, as it's noisy and not useful for matching.
213                let err_str = err_str.strip_prefix("db error: ERROR: ").unwrap_or(err_str);
214                // This value gets fed back into regex to check that it matches
215                // `self`, so escape its meta characters. We need to undo the
216                // escaping of #. `regex::escape` escapes this because it
217                // expects that we use the `x` flag when building a regex, but
218                // this is not the case, so \# would end up being an invalid
219                // escape sequence, which would choke the parsing of the slt
220                // file the next time around.
221                Some(regex::escape(err_str).replace(r"\#", "#"))
222            }
223            _ => None,
224        }
225    }
226}
227
228impl fmt::Display for Outcome<'_> {
229    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
230        use Outcome::*;
231        const INDENT: &str = "\n        ";
232        match self {
233            Unsupported { error, location } => write!(
234                f,
235                "Unsupported:{}:\n{}",
236                location,
237                error.display_with_causes()
238            ),
239            ParseFailure { error, location } => {
240                write!(
241                    f,
242                    "ParseFailure:{}:\n{}",
243                    location,
244                    error.display_with_causes()
245                )
246            }
247            PlanFailure {
248                error,
249                expected_error,
250                location,
251            } => {
252                if let Some(expected_error) = expected_error {
253                    write!(
254                        f,
255                        "PlanFailure:{}:\nerror does not match expected pattern:\n  expected: /{}/\n  actual:    {}",
256                        location,
257                        expected_error,
258                        error.display_with_causes()
259                    )
260                } else {
261                    write!(f, "PlanFailure:{}:\n{:#}", location, error)
262                }
263            }
264            UnexpectedPlanSuccess {
265                expected_error,
266                location,
267            } => write!(
268                f,
269                "UnexpectedPlanSuccess:{} expected error: {}",
270                location, expected_error
271            ),
272            WrongNumberOfRowsInserted {
273                expected_count,
274                actual_count,
275                location,
276            } => write!(
277                f,
278                "WrongNumberOfRowsInserted:{}{}expected: {}{}actually: {}",
279                location, INDENT, expected_count, INDENT, actual_count
280            ),
281            WrongColumnCount {
282                expected_count,
283                actual_count,
284                location,
285            } => write!(
286                f,
287                "WrongColumnCount:{}{}expected: {}{}actually: {}",
288                location, INDENT, expected_count, INDENT, actual_count
289            ),
290            WrongColumnNames {
291                expected_column_names,
292                actual_column_names,
293                actual_output: _,
294                location,
295            } => write!(
296                f,
297                "Wrong Column Names:{}:{}expected column names: {}{}inferred column names: {}",
298                location,
299                INDENT,
300                expected_column_names
301                    .iter()
302                    .map(|n| n.to_string())
303                    .collect::<Vec<_>>()
304                    .join(" "),
305                INDENT,
306                actual_column_names
307                    .iter()
308                    .map(|n| n.to_string())
309                    .collect::<Vec<_>>()
310                    .join(" ")
311            ),
312            OutputFailure {
313                expected_output,
314                actual_raw_output,
315                actual_output,
316                location,
317            } => write!(
318                f,
319                "OutputFailure:{}{}expected: {:?}{}actually: {:?}{}actual raw: {:?}",
320                location, INDENT, expected_output, INDENT, actual_output, INDENT, actual_raw_output
321            ),
322            InconsistentViewOutcome {
323                query_outcome,
324                view_outcome,
325                location,
326            } => write!(
327                f,
328                "InconsistentViewOutcome:{}{}expected from query: {}{}actually from indexed view: {}",
329                location, INDENT, query_outcome, INDENT, view_outcome
330            ),
331            Bail { cause, location } => write!(f, "Bail:{} {}", location, cause),
332            Warning { cause, location } => write!(f, "Warning:{} {}", location, cause),
333            Success => f.write_str("Success"),
334        }
335    }
336}
337
/// Aggregated results of running sqllogictest records.
#[derive(Default, Debug)]
pub struct Outcomes {
    /// Number of times each outcome occurred, one slot per `Outcome` variant;
    /// `as_json` shows which slot holds which outcome.
    stats: [usize; NUM_OUTCOMES],
    /// Messages printed one per line by `OutcomesDisplay` when failure
    /// details are requested.
    details: Vec<String>,
}
343
344impl ops::AddAssign<Outcomes> for Outcomes {
345    fn add_assign(&mut self, rhs: Outcomes) {
346        for (lhs, rhs) in self.stats.iter_mut().zip_eq(rhs.stats.iter()) {
347            *lhs += rhs
348        }
349    }
350}
351impl Outcomes {
352    pub fn any_failed(&self) -> bool {
353        self.stats[SUCCESS_OUTCOME] + self.stats[WARNING_OUTCOME] < self.stats.iter().sum::<usize>()
354    }
355
356    pub fn as_json(&self) -> serde_json::Value {
357        serde_json::json!({
358            "unsupported": self.stats[0],
359            "parse_failure": self.stats[1],
360            "plan_failure": self.stats[2],
361            "unexpected_plan_success": self.stats[3],
362            "wrong_number_of_rows_affected": self.stats[4],
363            "wrong_column_count": self.stats[5],
364            "wrong_column_names": self.stats[6],
365            "output_failure": self.stats[7],
366            "inconsistent_view_outcome": self.stats[8],
367            "bail": self.stats[9],
368            "warning": self.stats[10],
369            "success": self.stats[11],
370        })
371    }
372
373    pub fn display(&self, no_fail: bool, failure_details: bool) -> OutcomesDisplay<'_> {
374        OutcomesDisplay {
375            inner: self,
376            no_fail,
377            failure_details,
378        }
379    }
380}
381
/// `Display` adapter for [`Outcomes`], created by `Outcomes::display`.
pub struct OutcomesDisplay<'a> {
    /// The outcomes to render.
    inner: &'a Outcomes,
    /// When set, a failing summary is labelled `FAIL-IGNORE` instead of
    /// `FAIL`, and the detail listing is printed even if nothing failed.
    no_fail: bool,
    /// When set, the recorded details are printed instead of the summary
    /// line, provided something failed or `no_fail` is set.
    failure_details: bool,
}
387
388impl<'a> fmt::Display for OutcomesDisplay<'a> {
389    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
390        let total: usize = self.inner.stats.iter().sum();
391        if self.failure_details
392            && (self.inner.stats[SUCCESS_OUTCOME] + self.inner.stats[WARNING_OUTCOME] != total
393                || self.no_fail)
394        {
395            for outcome in &self.inner.details {
396                writeln!(f, "{}", outcome)?;
397            }
398            Ok(())
399        } else {
400            write!(
401                f,
402                "{}:",
403                if self.inner.stats[SUCCESS_OUTCOME] + self.inner.stats[WARNING_OUTCOME] == total {
404                    "PASS"
405                } else if self.no_fail {
406                    "FAIL-IGNORE"
407                } else {
408                    "FAIL"
409                }
410            )?;
411            static NAMES: LazyLock<Vec<&'static str>> = LazyLock::new(|| {
412                vec![
413                    "unsupported",
414                    "parse-failure",
415                    "plan-failure",
416                    "unexpected-plan-success",
417                    "wrong-number-of-rows-inserted",
418                    "wrong-column-count",
419                    "wrong-column-names",
420                    "output-failure",
421                    "inconsistent-view-outcome",
422                    "bail",
423                    "warning",
424                    "success",
425                    "total",
426                ]
427            });
428            for (i, n) in self.inner.stats.iter().enumerate() {
429                if *n > 0 {
430                    write!(f, " {}={}", NAMES[i], n)?;
431                }
432            }
433            write!(f, " total={}", total)
434        }
435    }
436}
437
/// Facts about a prepared query, carried by
/// `PrepareQueryOutcome::QueryPrepared`.
struct QueryInfo {
    // NOTE(review): presumably whether the statement is a `SELECT` — confirm
    // against the code that fills this in.
    is_select: bool,
    // NOTE(review): presumably the number of result columns, when known —
    // confirm against the code that fills this in.
    num_attributes: Option<usize>,
}
442
/// Either information about a prepared query or an [`Outcome`] to report
/// instead.
enum PrepareQueryOutcome<'a> {
    /// The query was prepared; carries what was learned about it.
    QueryPrepared(QueryInfo),
    /// NOTE(review): presumably preparation already determined the record's
    /// outcome — confirm against the callers.
    Outcome(Outcome<'a>),
}
447
/// Runs sqllogictest records, delegating to a [`RunnerInner`] that can be
/// thrown away and recreated via `reset`.
pub struct Runner<'a> {
    /// Configuration used every time the inner runner is (re)started.
    config: &'a RunConfig<'a>,
    /// The current inner runner. `start` and `reset` leave this as `Some`;
    /// the other methods panic with "RunnerInner missing" if it is `None`.
    inner: Option<RunnerInner<'a>>,
}
452
/// The state behind a [`Runner`]: client connections plus the handles that
/// keep the server side alive.
///
/// NOTE(review): the meaning of individual fields below is inferred from
/// their names and types; confirm against `RunnerInner::start`.
pub struct RunnerInner<'a> {
    // Addresses of the listeners; presumably used to open further
    // connections.
    server_addr: SocketAddr,
    internal_server_addr: SocketAddr,
    password_server_addr: SocketAddr,
    internal_http_server_addr: SocketAddr,
    // Drop order matters for these fields.
    // (Rust drops struct fields in declaration order, so do not reorder.)
    client: tokio_postgres::Client,
    system_client: tokio_postgres::Client,
    // Additional clients keyed by a string; presumably a connection or user
    // name.
    clients: BTreeMap<String, tokio_postgres::Client>,
    // Behavior switches; presumably copied from the run configuration.
    auto_index_tables: bool,
    auto_index_selects: bool,
    auto_transactions: bool,
    enable_table_keys: bool,
    verbose: bool,
    // Sink for the runner's textual output.
    stdout: &'a dyn WriteFmt,
    // Never read (underscore-prefixed); presumably held only for the effect
    // of dropping them.
    _shutdown_trigger: trigger::Trigger,
    _server_thread: JoinOnDropHandle<()>,
    _temp_dir: TempDir,
}
472
/// Newtype around a [`Value`] so that this module can provide its own
/// `FromSql` decoding and turn the result into sqllogictest text with
/// `format_datum`.
#[derive(Debug)]
pub struct Slt(Value);
475
476impl<'a> FromSql<'a> for Slt {
477    fn from_sql(
478        ty: &PgType,
479        mut raw: &'a [u8],
480    ) -> Result<Self, Box<dyn Error + 'static + Send + Sync>> {
481        Ok(match *ty {
482            PgType::ACLITEM => Self(Value::AclItem(AclItem::decode_binary(
483                types::bytea_from_sql(raw),
484            )?)),
485            PgType::BOOL => Self(Value::Bool(types::bool_from_sql(raw)?)),
486            PgType::BYTEA => Self(Value::Bytea(types::bytea_from_sql(raw).to_vec())),
487            PgType::CHAR => Self(Value::Char(u8::from_be_bytes(
488                types::char_from_sql(raw)?.to_be_bytes(),
489            ))),
490            PgType::FLOAT4 => Self(Value::Float4(types::float4_from_sql(raw)?)),
491            PgType::FLOAT8 => Self(Value::Float8(types::float8_from_sql(raw)?)),
492            PgType::DATE => Self(Value::Date(Date::from_pg_epoch(types::int4_from_sql(
493                raw,
494            )?)?)),
495            PgType::INT2 => Self(Value::Int2(types::int2_from_sql(raw)?)),
496            PgType::INT4 => Self(Value::Int4(types::int4_from_sql(raw)?)),
497            PgType::INT8 => Self(Value::Int8(types::int8_from_sql(raw)?)),
498            PgType::INTERVAL => Self(Value::Interval(Interval::from_sql(ty, raw)?)),
499            PgType::JSONB => Self(Value::Jsonb(Jsonb::from_sql(ty, raw)?)),
500            PgType::NAME => Self(Value::Name(types::text_from_sql(raw)?.to_string())),
501            PgType::NUMERIC => Self(Value::Numeric(Numeric::from_sql(ty, raw)?)),
502            PgType::OID => Self(Value::Oid(types::oid_from_sql(raw)?)),
503            PgType::REGCLASS => Self(Value::Oid(types::oid_from_sql(raw)?)),
504            PgType::REGPROC => Self(Value::Oid(types::oid_from_sql(raw)?)),
505            PgType::REGTYPE => Self(Value::Oid(types::oid_from_sql(raw)?)),
506            PgType::TEXT | PgType::BPCHAR | PgType::VARCHAR => {
507                Self(Value::Text(types::text_from_sql(raw)?.to_string()))
508            }
509            PgType::TIME => Self(Value::Time(NaiveTime::from_sql(ty, raw)?)),
510            PgType::TIMESTAMP => Self(Value::Timestamp(
511                NaiveDateTime::from_sql(ty, raw)?.try_into()?,
512            )),
513            PgType::TIMESTAMPTZ => Self(Value::TimestampTz(
514                DateTime::<Utc>::from_sql(ty, raw)?.try_into()?,
515            )),
516            PgType::UUID => Self(Value::Uuid(Uuid::from_sql(ty, raw)?)),
517            PgType::RECORD => {
518                let num_fields = read_be_i32(&mut raw)?;
519                let mut tuple = vec![];
520                for _ in 0..num_fields {
521                    let oid = u32::reinterpret_cast(read_be_i32(&mut raw)?);
522                    let typ = match PgType::from_oid(oid) {
523                        Some(typ) => typ,
524                        None => return Err("unknown oid".into()),
525                    };
526                    let v = read_value::<Option<Slt>>(&typ, &mut raw)?;
527                    tuple.push(v.map(|v| v.0));
528                }
529                Self(Value::Record(tuple))
530            }
531            PgType::INT4_RANGE
532            | PgType::INT8_RANGE
533            | PgType::DATE_RANGE
534            | PgType::NUM_RANGE
535            | PgType::TS_RANGE
536            | PgType::TSTZ_RANGE => {
537                use mz_repr::adt::range::Range;
538                let range: Range<Slt> = Range::from_sql(ty, raw)?;
539                Self(Value::Range(range.into_bounds(|b| Box::new(b.0))))
540            }
541
542            _ => match ty.kind() {
543                PgKind::Array(arr_type) => {
544                    let arr = types::array_from_sql(raw)?;
545                    let elements: Vec<Option<Value>> = arr
546                        .values()
547                        .map(|v| match v {
548                            Some(v) => Ok(Some(Slt::from_sql(arr_type, v)?)),
549                            None => Ok(None),
550                        })
551                        .collect::<Vec<Option<Slt>>>()?
552                        .into_iter()
553                        // Map a Vec<Option<Slt>> to Vec<Option<Value>>.
554                        .map(|v| v.map(|v| v.0))
555                        .collect();
556
557                    Self(Value::Array {
558                        dims: arr
559                            .dimensions()
560                            .map(|d| {
561                                Ok(mz_repr::adt::array::ArrayDimension {
562                                    lower_bound: isize::cast_from(d.lower_bound),
563                                    length: usize::try_from(d.len)
564                                        .expect("cannot have negative length"),
565                                })
566                            })
567                            .collect()?,
568                        elements,
569                    })
570                }
571                _ => match ty.oid() {
572                    oid::TYPE_UINT2_OID => Self(Value::UInt2(UInt2::from_sql(ty, raw)?)),
573                    oid::TYPE_UINT4_OID => Self(Value::UInt4(UInt4::from_sql(ty, raw)?)),
574                    oid::TYPE_UINT8_OID => Self(Value::UInt8(UInt8::from_sql(ty, raw)?)),
575                    oid::TYPE_MZ_TIMESTAMP_OID => {
576                        let s = types::text_from_sql(raw)?;
577                        let t: mz_repr::Timestamp = s.parse()?;
578                        Self(Value::MzTimestamp(t))
579                    }
580                    oid::TYPE_MZ_ACL_ITEM_OID => Self(Value::MzAclItem(MzAclItem::decode_binary(
581                        types::bytea_from_sql(raw),
582                    )?)),
583                    _ => unreachable!(),
584                },
585            },
586        })
587    }
588    fn accepts(ty: &PgType) -> bool {
589        match ty.kind() {
590            PgKind::Array(_) | PgKind::Composite(_) => return true,
591            _ => {}
592        }
593        match ty.oid() {
594            oid::TYPE_UINT2_OID
595            | oid::TYPE_UINT4_OID
596            | oid::TYPE_UINT8_OID
597            | oid::TYPE_MZ_TIMESTAMP_OID
598            | oid::TYPE_MZ_ACL_ITEM_OID => return true,
599            _ => {}
600        }
601        matches!(
602            *ty,
603            PgType::ACLITEM
604                | PgType::BOOL
605                | PgType::BYTEA
606                | PgType::CHAR
607                | PgType::DATE
608                | PgType::FLOAT4
609                | PgType::FLOAT8
610                | PgType::INT2
611                | PgType::INT4
612                | PgType::INT8
613                | PgType::INTERVAL
614                | PgType::JSONB
615                | PgType::NAME
616                | PgType::NUMERIC
617                | PgType::OID
618                | PgType::REGCLASS
619                | PgType::REGPROC
620                | PgType::REGTYPE
621                | PgType::RECORD
622                | PgType::TEXT
623                | PgType::BPCHAR
624                | PgType::VARCHAR
625                | PgType::TIME
626                | PgType::TIMESTAMP
627                | PgType::TIMESTAMPTZ
628                | PgType::UUID
629                | PgType::INT4_RANGE
630                | PgType::INT4_RANGE_ARRAY
631                | PgType::INT8_RANGE
632                | PgType::INT8_RANGE_ARRAY
633                | PgType::DATE_RANGE
634                | PgType::DATE_RANGE_ARRAY
635                | PgType::NUM_RANGE
636                | PgType::NUM_RANGE_ARRAY
637                | PgType::TS_RANGE
638                | PgType::TS_RANGE_ARRAY
639                | PgType::TSTZ_RANGE
640                | PgType::TSTZ_RANGE_ARRAY
641        )
642    }
643}
644
// From postgres-types/src/private.rs.
/// Reads a big-endian `i32` from the front of `buf` and advances `buf` past
/// the four bytes consumed.
///
/// Returns an "invalid buffer size" error, leaving `buf` untouched, when
/// fewer than four bytes remain.
fn read_be_i32(buf: &mut &[u8]) -> Result<i32, Box<dyn Error + Sync + Send>> {
    let word = buf
        .get(..4)
        .and_then(|prefix| <[u8; 4]>::try_from(prefix).ok())
        .ok_or("invalid buffer size")?;
    *buf = &buf[4..];
    Ok(i32::from_be_bytes(word))
}
655
656// From postgres-types/src/private.rs.
657fn read_value<'a, T>(type_: &PgType, buf: &mut &'a [u8]) -> Result<T, Box<dyn Error + Sync + Send>>
658where
659    T: FromSql<'a>,
660{
661    let value = match usize::try_from(read_be_i32(buf)?) {
662        Err(_) => None,
663        Ok(len) => {
664            if len > buf.len() {
665                return Err("invalid buffer size".into());
666            }
667            let (head, tail) = buf.split_at(len);
668            *buf = tail;
669            Some(head)
670        }
671    };
672    T::from_sql_nullable(type_, value)
673}
674
675fn format_datum(d: Slt, typ: &Type, mode: Mode, col: usize) -> String {
676    match (typ, d.0) {
677        (Type::Bool, Value::Bool(b)) => b.to_string(),
678
679        (Type::Integer, Value::Int2(i)) => i.to_string(),
680        (Type::Integer, Value::Int4(i)) => i.to_string(),
681        (Type::Integer, Value::Int8(i)) => i.to_string(),
682        (Type::Integer, Value::UInt2(u)) => u.0.to_string(),
683        (Type::Integer, Value::UInt4(u)) => u.0.to_string(),
684        (Type::Integer, Value::UInt8(u)) => u.0.to_string(),
685        (Type::Integer, Value::Oid(i)) => i.to_string(),
686        // TODO(benesch): rewrite to avoid `as`.
687        #[allow(clippy::as_conversions)]
688        (Type::Integer, Value::Float4(f)) => format!("{}", f as i64),
689        // TODO(benesch): rewrite to avoid `as`.
690        #[allow(clippy::as_conversions)]
691        (Type::Integer, Value::Float8(f)) => format!("{}", f as i64),
692        // This is so wrong, but sqlite needs it.
693        (Type::Integer, Value::Text(_)) => "0".to_string(),
694        (Type::Integer, Value::Bool(b)) => i8::from(b).to_string(),
695        (Type::Integer, Value::Numeric(d)) => {
696            let mut d = d.0.0.clone();
697            let mut cx = numeric::cx_datum();
698            // Truncate the decimal to match sqlite.
699            if mode == Mode::Standard {
700                cx.set_rounding(dec::Rounding::Down);
701            }
702            cx.round(&mut d);
703            numeric::munge_numeric(&mut d).unwrap();
704            d.to_standard_notation_string()
705        }
706
707        (Type::Real, Value::Int2(i)) => format!("{:.3}", i),
708        (Type::Real, Value::Int4(i)) => format!("{:.3}", i),
709        (Type::Real, Value::Int8(i)) => format!("{:.3}", i),
710        (Type::Real, Value::Float4(f)) => match mode {
711            Mode::Standard => format!("{:.3}", f),
712            Mode::Cockroach => format!("{}", f),
713        },
714        (Type::Real, Value::Float8(f)) => match mode {
715            Mode::Standard => format!("{:.3}", f),
716            Mode::Cockroach => format!("{}", f),
717        },
718        (Type::Real, Value::Numeric(d)) => match mode {
719            Mode::Standard => {
720                let mut d = d.0.0.clone();
721                if d.exponent() < -3 {
722                    numeric::rescale(&mut d, 3).unwrap();
723                }
724                numeric::munge_numeric(&mut d).unwrap();
725                d.to_standard_notation_string()
726            }
727            Mode::Cockroach => d.0.0.to_standard_notation_string(),
728        },
729
730        (Type::Text, Value::Text(s)) => {
731            if s.is_empty() {
732                "(empty)".to_string()
733            } else {
734                s
735            }
736        }
737        (Type::Text, Value::Bool(b)) => b.to_string(),
738        (Type::Text, Value::Float4(f)) => format!("{:.3}", f),
739        (Type::Text, Value::Float8(f)) => format!("{:.3}", f),
740        // Bytes are printed as text iff they are valid UTF-8. This
741        // seems guaranteed to confuse everyone, but it is required for
742        // compliance with the CockroachDB sqllogictest runner. [0]
743        //
744        // [0]: https://github.com/cockroachdb/cockroach/blob/970782487/pkg/sql/logictest/logic.go#L2038-L2043
745        (Type::Text, Value::Bytea(b)) => match str::from_utf8(&b) {
746            Ok(s) => s.to_string(),
747            Err(_) => format!("{:?}", b),
748        },
749        (Type::Text, Value::Numeric(d)) => d.0.0.to_standard_notation_string(),
750        // Everything else gets normal text encoding. This correctly handles things
751        // like arrays, tuples, and strings that need to be quoted.
752        (Type::Text, d) => {
753            let mut buf = BytesMut::new();
754            d.encode_text(&mut buf);
755            String::from_utf8_lossy(&buf).into_owned()
756        }
757
758        (Type::Oid, Value::Oid(o)) => o.to_string(),
759
760        (_, d) => panic!(
761            "Don't know how to format {:?} as {:?} in column {}",
762            d, typ, col,
763        ),
764    }
765}
766
767fn format_row(row: &Row, types: &[Type], mode: Mode) -> Vec<String> {
768    let mut formatted: Vec<String> = vec![];
769    for i in 0..row.len() {
770        let t: Option<Slt> = row.get::<usize, Option<Slt>>(i);
771        let t: Option<String> = t.map(|d| format_datum(d, &types[i], mode, i));
772        formatted.push(match t {
773            Some(t) => t,
774            None => "NULL".into(),
775        });
776    }
777
778    formatted
779}
780
781impl<'a> Runner<'a> {
782    pub async fn start(config: &'a RunConfig<'a>) -> Result<Runner<'a>, anyhow::Error> {
783        let mut runner = Self {
784            config,
785            inner: None,
786        };
787        runner.reset().await?;
788        Ok(runner)
789    }
790
791    pub async fn reset(&mut self) -> Result<(), anyhow::Error> {
792        // Explicitly drop the old runner here to ensure that we wait for threads to terminate
793        // before starting a new runner
794        drop(self.inner.take());
795        self.inner = Some(RunnerInner::start(self.config).await?);
796
797        Ok(())
798    }
799
800    async fn run_record<'r>(
801        &mut self,
802        record: &'r Record<'r>,
803        in_transaction: &mut bool,
804    ) -> Result<Outcome<'r>, anyhow::Error> {
805        if let Record::ResetServer = record {
806            self.reset().await?;
807            Ok(Outcome::Success)
808        } else {
809            self.inner
810                .as_mut()
811                .expect("RunnerInner missing")
812                .run_record(record, in_transaction)
813                .await
814        }
815    }
816
817    async fn check_catalog(&self) -> Result<(), anyhow::Error> {
818        self.inner
819            .as_ref()
820            .expect("RunnerInner missing")
821            .check_catalog()
822            .await
823    }
824
825    async fn reset_database(&mut self) -> Result<(), anyhow::Error> {
826        let inner = self.inner.as_mut().expect("RunnerInner missing");
827
828        inner.client.batch_execute("ROLLBACK;").await?;
829
830        inner
831            .system_client
832            .batch_execute(
833                "ROLLBACK;
834                 SET cluster = mz_catalog_server;
835                 RESET cluster_replica;",
836            )
837            .await?;
838
839        inner
840            .system_client
841            .batch_execute("ALTER SYSTEM RESET ALL")
842            .await?;
843
844        // Drop all databases, then recreate the `materialize` database.
845        for row in inner
846            .system_client
847            .query("SELECT name FROM mz_databases", &[])
848            .await?
849        {
850            let name: &str = row.get("name");
851            inner
852                .system_client
853                .batch_execute(&format!("DROP DATABASE \"{name}\""))
854                .await?;
855        }
856        inner
857            .system_client
858            .batch_execute("CREATE DATABASE materialize")
859            .await?;
860
861        // Ensure quickstart cluster exists with one replica of size `self.config.replica_size`.
862        // We don't destroy the existing quickstart cluster replica if it exists, as turning
863        // on a cluster replica is exceptionally slow.
864        let mut needs_default_cluster = true;
865        for row in inner
866            .system_client
867            .query("SELECT name FROM mz_clusters WHERE id LIKE 'u%'", &[])
868            .await?
869        {
870            match row.get("name") {
871                "quickstart" => needs_default_cluster = false,
872                name => {
873                    inner
874                        .system_client
875                        .batch_execute(&format!("DROP CLUSTER {name}"))
876                        .await?
877                }
878            }
879        }
880        if needs_default_cluster {
881            inner
882                .system_client
883                .batch_execute("CREATE CLUSTER quickstart REPLICAS ()")
884                .await?;
885        }
886        let mut needs_default_replica = false;
887        let rows = inner
888            .system_client
889            .query(
890                "SELECT name, size FROM mz_cluster_replicas
891                 WHERE cluster_id = (SELECT id FROM mz_clusters WHERE name = 'quickstart')
892                 ORDER BY name",
893                &[],
894            )
895            .await?;
896        if rows.len() != self.config.replicas {
897            needs_default_replica = true;
898        } else {
899            for (i, row) in rows.iter().enumerate() {
900                let name: &str = row.get("name");
901                let size: &str = row.get("size");
902                if name != format!("r{i}") || size != self.config.replica_size {
903                    needs_default_replica = true;
904                    break;
905                }
906            }
907        }
908
909        if needs_default_replica {
910            inner
911                .system_client
912                .batch_execute("ALTER CLUSTER quickstart SET (MANAGED = false)")
913                .await?;
914            for row in inner
915                .system_client
916                .query(
917                    "SELECT name FROM mz_cluster_replicas
918                     WHERE cluster_id = (SELECT id FROM mz_clusters WHERE name = 'quickstart')",
919                    &[],
920                )
921                .await?
922            {
923                let name: &str = row.get("name");
924                inner
925                    .system_client
926                    .batch_execute(&format!("DROP CLUSTER REPLICA quickstart.{}", name))
927                    .await?;
928            }
929            for i in 1..=self.config.replicas {
930                inner
931                    .system_client
932                    .batch_execute(&format!(
933                        "CREATE CLUSTER REPLICA quickstart.r{i} SIZE '{}'",
934                        self.config.replica_size
935                    ))
936                    .await?;
937            }
938            inner
939                .system_client
940                .batch_execute("ALTER CLUSTER quickstart SET (MANAGED = true)")
941                .await?;
942        }
943
944        // Grant initial privileges.
945        inner
946            .system_client
947            .batch_execute("GRANT USAGE ON DATABASE materialize TO PUBLIC")
948            .await?;
949        inner
950            .system_client
951            .batch_execute("GRANT CREATE ON DATABASE materialize TO materialize")
952            .await?;
953        inner
954            .system_client
955            .batch_execute("GRANT CREATE ON SCHEMA materialize.public TO materialize")
956            .await?;
957        inner
958            .system_client
959            .batch_execute("GRANT USAGE ON CLUSTER quickstart TO PUBLIC")
960            .await?;
961        inner
962            .system_client
963            .batch_execute("GRANT CREATE ON CLUSTER quickstart TO materialize")
964            .await?;
965
966        // Some sqllogictests require more than the default amount of tables, so we increase the
967        // limit for all tests.
968        inner
969            .system_client
970            .simple_query("ALTER SYSTEM SET max_tables = 100")
971            .await?;
972
973        if inner.enable_table_keys {
974            inner
975                .system_client
976                .simple_query("ALTER SYSTEM SET unsafe_enable_table_keys = true")
977                .await?;
978        }
979
980        inner.ensure_fixed_features().await?;
981
982        inner.client = connect(inner.server_addr, None, None).await.unwrap();
983        inner.system_client = connect(inner.internal_server_addr, Some("mz_system"), None)
984            .await
985            .unwrap();
986        inner.clients = BTreeMap::new();
987
988        Ok(())
989    }
990}
991
992impl<'a> RunnerInner<'a> {
    /// Boots an in-process `environmentd` server for `config` and returns a
    /// `RunnerInner` connected to it.
    ///
    /// The steps are, in order:
    /// 1. Connect to the configured Postgres (up to 5 attempts), recreate the
    ///    `{prefix}_tsoracle` schema and make sure `{prefix}_consensus` exists,
    ///    then derive the consensus and timestamp-oracle URLs from them.
    /// 2. Build the process orchestrator, the persist client cache and an
    ///    in-process persist pubsub server listening on an ephemeral port.
    /// 3. Bind the SQL (`external`, `internal`, `password`) and HTTP
    ///    (`external`, `internal`) listeners on ephemeral localhost ports.
    /// 4. Assemble the `mz_environmentd::Config` and serve it on a dedicated
    ///    thread that owns its own Tokio runtime, reporting the bound addresses
    ///    back through oneshot channels.
    /// 5. Open the `mz_system` and default user connections and apply the
    ///    fixed feature flags via `ensure_fixed_features`.
    ///
    /// # Errors
    ///
    /// Returns an error if any of the setup steps above fails, including a
    /// failure to create the runtime or to start serving on the server thread.
    ///
    /// # Panics
    ///
    /// Panics if the derived URIs fail to parse, if the pubsub listener cannot
    /// be bound, or if the initial SQL connections cannot be established.
    pub async fn start(config: &RunConfig<'a>) -> Result<RunnerInner<'a>, anyhow::Error> {
        let temp_dir = tempfile::tempdir()?;
        // NOTE(review): unlike `temp_dir`, which is stored in `_temp_dir` below,
        // `scratch_dir` is not kept anywhere and is therefore dropped when this
        // function returns, while its path has been handed to the orchestrator.
        // Confirm the orchestrator does not need the directory after startup.
        let scratch_dir = tempfile::tempdir()?;
        let environment_id = EnvironmentId::for_tests();
        // Prepare the Postgres schemas backing persist consensus and the timestamp
        // oracle, and compute the URLs that select them through `search_path`.
        let (consensus_uri, timestamp_oracle_url): (SensitiveUrl, SensitiveUrl) = {
            let postgres_url = &config.postgres_url;
            let prefix = &config.prefix;
            info!(%postgres_url, "starting server");
            let (client, conn) = Retry::default()
                .max_tries(5)
                .retry_async(|_| async {
                    match tokio_postgres::connect(postgres_url, NoTls).await {
                        Ok(c) => Ok(c),
                        Err(e) => {
                            error!(%e, "failed to connect to postgres");
                            Err(e)
                        }
                    }
                })
                .await?;
            // The connection object must be polled for `client` to make progress.
            task::spawn(|| "sqllogictest_connect", async move {
                if let Err(e) = conn.await {
                    panic!("connection error: {}", e);
                }
            });
            // The timestamp oracle schema is recreated from scratch, while the
            // consensus schema is only created if it does not exist yet.
            client
                .batch_execute(&format!(
                    "DROP SCHEMA IF EXISTS {prefix}_tsoracle CASCADE;
                     CREATE SCHEMA IF NOT EXISTS {prefix}_consensus;
                     CREATE SCHEMA {prefix}_tsoracle;"
                ))
                .await?;
            (
                format!("{postgres_url}?options=--search_path={prefix}_consensus")
                    .parse()
                    .expect("invalid consensus URI"),
                format!("{postgres_url}?options=--search_path={prefix}_tsoracle")
                    .parse()
                    .expect("invalid timestamp oracle URI"),
            )
        };

        let secrets_dir = temp_dir.path().join("secrets");
        let orchestrator = Arc::new(
            ProcessOrchestrator::new(ProcessOrchestratorConfig {
                // Service binaries are looked up next to the current executable.
                image_dir: env::current_exe()?.parent().unwrap().to_path_buf(),
                suppress_output: false,
                environment_id: environment_id.to_string(),
                secrets_dir: secrets_dir.clone(),
                // An optional wrapper command, split using shell quoting rules.
                command_wrapper: config
                    .orchestrator_process_wrapper
                    .as_ref()
                    .map_or(Ok(vec![]), |s| shell_words::split(s))?,
                propagate_crashes: true,
                tcp_proxy: None,
                scratch_directory: scratch_dir.path().to_path_buf(),
            })
            .await?,
        );
        let now = SYSTEM_TIME.clone();
        let metrics_registry = MetricsRegistry::new();

        let persist_config = PersistConfig::new(
            &mz_environmentd::BUILD_INFO,
            now.clone(),
            mz_dyncfgs::all_dyncfgs(),
        );
        let persist_pubsub_server =
            PersistGrpcPubSubServer::new(&persist_config, &metrics_registry);
        let persist_pubsub_client = persist_pubsub_server.new_same_process_connection();
        // Bind to port 0 so the OS picks a free port, then read the port back.
        let persist_pubsub_tcp_listener =
            TcpListener::bind(SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), 0))
                .await
                .expect("pubsub addr binding");
        let persist_pubsub_server_port = persist_pubsub_tcp_listener
            .local_addr()
            .expect("pubsub addr has local addr")
            .port();
        info!("listening for persist pubsub connections on localhost:{persist_pubsub_server_port}");
        mz_ore::task::spawn(|| "persist_pubsub_server", async move {
            persist_pubsub_server
                .serve_with_stream(TcpListenerStream::new(persist_pubsub_tcp_listener))
                .await
                .expect("success")
        });
        // This process talks to the pubsub server over the same-process connection
        // created above rather than over the TCP listener.
        let persist_clients =
            PersistClientCache::new(persist_config, &metrics_registry, |cfg, metrics| {
                let sender: Arc<dyn PubSubSender> = Arc::new(MetricsSameProcessPubSubSender::new(
                    cfg,
                    persist_pubsub_client.sender,
                    metrics,
                ));
                PubSubClientConnection::new(sender, persist_pubsub_client.receiver)
            });
        let persist_clients = Arc::new(persist_clients);

        // Grab what is needed from the raw orchestrator before it is wrapped (and
        // shadowed) by the tracing orchestrator.
        let secrets_controller = Arc::clone(&orchestrator);
        let connection_context = ConnectionContext::for_tests(orchestrator.reader());
        let orchestrator = Arc::new(TracingOrchestrator::new(
            orchestrator,
            config.tracing.clone(),
        ));
        // Every listener binds to an ephemeral localhost port; the actual
        // addresses are read back from the listener handles further down.
        let listeners_config = ListenersConfig {
            sql: btreemap! {
                "external".to_owned() => SqlListenerConfig {
                    addr: SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), 0),
                    authenticator_kind: AuthenticatorKind::None,
                    allowed_roles: AllowedRoles::Normal,
                    enable_tls: false,
                },
                "internal".to_owned() => SqlListenerConfig {
                    addr: SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), 0),
                    authenticator_kind: AuthenticatorKind::None,
                    allowed_roles: AllowedRoles::Internal,
                    enable_tls: false,
                },
                "password".to_owned() => SqlListenerConfig {
                    addr: SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), 0),
                    authenticator_kind: AuthenticatorKind::Password,
                    allowed_roles: AllowedRoles::Normal,
                    enable_tls: false,
                },
            },
            http: btreemap![
                "external".to_owned() => HttpListenerConfig {
                    base: BaseListenerConfig {
                        addr: SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), 0),
                        authenticator_kind: AuthenticatorKind::None,
                        allowed_roles: AllowedRoles::Normal,
                        enable_tls: false
                    },
                    routes: HttpRoutesEnabled {
                        base: true,
                        webhook: true,
                        internal: false,
                        metrics: false,
                        profiling: false,
                        mcp_agents: false,
                        mcp_observatory: false,
                        console_config: true,
                    },
                },
                "internal".to_owned() => HttpListenerConfig {
                    base: BaseListenerConfig {
                        addr: SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), 0),
                        authenticator_kind: AuthenticatorKind::None,
                        allowed_roles: AllowedRoles::NormalAndInternal,
                        enable_tls: false
                    },
                    routes: HttpRoutesEnabled {
                        base: true,
                        webhook: true,
                        internal: true,
                        metrics: true,
                        profiling: true,
                        mcp_agents: false,
                        mcp_observatory: false,
                        console_config: false,
                    },
                },
            ],
        };
        let listeners = mz_environmentd::Listeners::bind(listeners_config).await?;
        // The advertised HTTP host name uses the port the external HTTP listener
        // was actually bound to.
        let host_name = format!(
            "localhost:{}",
            listeners.http["external"].handle.local_addr.port()
        );
        let catalog_config = CatalogConfig {
            persist_clients: Arc::clone(&persist_clients),
            // Note: the catalog metrics are registered in a separate, throwaway
            // registry rather than in `metrics_registry`.
            metrics: Arc::new(mz_catalog::durable::Metrics::new(&MetricsRegistry::new())),
        };
        let server_config = mz_environmentd::Config {
            catalog_config,
            timestamp_oracle_url: Some(timestamp_oracle_url),
            controller: ControllerConfig {
                build_info: &mz_environmentd::BUILD_INFO,
                orchestrator,
                clusterd_image: "clusterd".into(),
                init_container_image: None,
                deploy_generation: 0,
                persist_location: PersistLocation {
                    // Blobs are stored on the local filesystem under the
                    // configured persist directory.
                    blob_uri: format!(
                        "file://{}/persist/blob",
                        config.persist_dir.path().display()
                    )
                    .parse()
                    .expect("invalid blob URI"),
                    consensus_uri,
                },
                persist_clients,
                now: SYSTEM_TIME.clone(),
                metrics_registry: metrics_registry.clone(),
                persist_pubsub_url: format!("http://localhost:{}", persist_pubsub_server_port),
                secrets_args: mz_service::secrets::SecretsReaderCliArgs {
                    secrets_reader: mz_service::secrets::SecretsControllerKind::LocalFile,
                    secrets_reader_local_file_dir: Some(secrets_dir),
                    secrets_reader_kubernetes_context: None,
                    secrets_reader_aws_prefix: None,
                    secrets_reader_name_prefix: None,
                },
                connection_context,
                replica_http_locator: Arc::new(ReplicaHttpLocator::default()),
            },
            secrets_controller,
            cloud_resource_controller: None,
            tls: None,
            frontegg: None,
            cors_allowed_origin: AllowOrigin::list([]),
            unsafe_mode: true,
            all_features: false,
            metrics_registry,
            now,
            environment_id,
            cluster_replica_sizes: ClusterReplicaSizeMap::for_tests(),
            bootstrap_default_cluster_replica_size: config.replica_size.clone(),
            bootstrap_default_cluster_replication_factor: config
                .replicas
                .try_into()
                .expect("replicas must fit"),
            // All builtin clusters use the configured replica size together with
            // their respective default replication factor.
            bootstrap_builtin_system_cluster_config: BootstrapBuiltinClusterConfig {
                replication_factor: SYSTEM_CLUSTER_DEFAULT_REPLICATION_FACTOR,
                size: config.replica_size.clone(),
            },
            bootstrap_builtin_catalog_server_cluster_config: BootstrapBuiltinClusterConfig {
                replication_factor: CATALOG_SERVER_CLUSTER_DEFAULT_REPLICATION_FACTOR,
                size: config.replica_size.clone(),
            },
            bootstrap_builtin_probe_cluster_config: BootstrapBuiltinClusterConfig {
                replication_factor: PROBE_CLUSTER_DEFAULT_REPLICATION_FACTOR,
                size: config.replica_size.clone(),
            },
            bootstrap_builtin_support_cluster_config: BootstrapBuiltinClusterConfig {
                replication_factor: SUPPORT_CLUSTER_DEFAULT_REPLICATION_FACTOR,
                size: config.replica_size.clone(),
            },
            bootstrap_builtin_analytics_cluster_config: BootstrapBuiltinClusterConfig {
                replication_factor: ANALYTICS_CLUSTER_DEFAULT_REPLICATION_FACTOR,
                size: config.replica_size.clone(),
            },
            system_parameter_defaults: {
                // Start from the startup log filter, then let the configured
                // defaults add to (or override) it.
                let mut params = BTreeMap::new();
                params.insert(
                    "log_filter".to_string(),
                    config.tracing.startup_log_filter.to_string(),
                );
                params.extend(config.system_parameter_defaults.clone());
                params
            },
            availability_zones: Default::default(),
            tracing_handle: config.tracing_handle.clone(),
            storage_usage_collection_interval: Duration::from_secs(3600),
            storage_usage_retention_period: None,
            segment_api_key: None,
            segment_client_side: false,
            // SLT doesn't like eternally running tasks since it waits for them to finish in
            // between SLT files.
            test_only_dummy_segment_client: false,
            egress_addresses: vec![],
            aws_account_id: None,
            aws_privatelink_availability_zones: None,
            launchdarkly_sdk_key: None,
            launchdarkly_key_map: Default::default(),
            config_sync_file_path: None,
            config_sync_timeout: Duration::from_secs(30),
            config_sync_loop_interval: None,
            bootstrap_role: Some("materialize".into()),
            http_host_name: Some(host_name),
            internal_console_redirect_url: None,
            tls_reload_certs: mz_server_core::cert_reload_never_reload(),
            helm_chart_version: None,
            license_key: ValidatedLicenseKey::for_tests(),
            external_login_password_mz_system: None,
            force_builtin_schema_migration: None,
        };
        // We need to run the server on its own Tokio runtime, which in turn
        // requires its own thread, so that we can wait for any tasks spawned
        // by the server to be shutdown at the end of each file. If we were to
        // share a Tokio runtime, tasks from the last file's server would still
        // be live at the start of the next file's server.
        //
        // Only the first channel carries a `Result`: startup failures are reported
        // through it, and the remaining senders are then dropped unsent, which
        // makes the corresponding receivers fail.
        let (server_addr_tx, server_addr_rx): (oneshot::Sender<Result<_, anyhow::Error>>, _) =
            oneshot::channel();
        let (internal_server_addr_tx, internal_server_addr_rx) = oneshot::channel();
        let (password_server_addr_tx, password_server_addr_rx) = oneshot::channel();
        let (internal_http_server_addr_tx, internal_http_server_addr_rx) = oneshot::channel();
        let (shutdown_trigger, shutdown_trigger_rx) = trigger::channel();
        let server_thread = thread::spawn(|| {
            let runtime = match Runtime::new() {
                Ok(runtime) => runtime,
                Err(e) => {
                    server_addr_tx
                        .send(Err(e.into()))
                        .expect("receiver should not drop first");
                    return;
                }
            };
            // Note: the `Ok` binding below is named `runtime`, but the value it
            // holds is the server returned by `serve`.
            let server = match runtime.block_on(listeners.serve(server_config)) {
                Ok(runtime) => runtime,
                Err(e) => {
                    server_addr_tx
                        .send(Err(e.into()))
                        .expect("receiver should not drop first");
                    return;
                }
            };
            server_addr_tx
                .send(Ok(server.sql_listener_handles["external"].local_addr))
                .expect("receiver should not drop first");
            internal_server_addr_tx
                .send(server.sql_listener_handles["internal"].local_addr)
                .expect("receiver should not drop first");
            password_server_addr_tx
                .send(server.sql_listener_handles["password"].local_addr)
                .expect("receiver should not drop first");
            internal_http_server_addr_tx
                .send(server.http_listener_handles["internal"].local_addr)
                .expect("receiver should not drop first");
            // Park this thread, keeping `runtime` and `server` alive, until the
            // shutdown receiver resolves. The matching trigger is stored in
            // `_shutdown_trigger` below; presumably dropping it is what resolves
            // the receiver. TODO confirm against the `trigger` implementation.
            runtime.block_on(shutdown_trigger_rx);
        });
        // `??` unwraps first the channel error, then the startup error sent by
        // the server thread.
        let server_addr = server_addr_rx.await??;
        let internal_server_addr = internal_server_addr_rx.await?;
        let password_server_addr = password_server_addr_rx.await?;
        let internal_http_server_addr = internal_http_server_addr_rx.await?;

        let system_client = connect(internal_server_addr, Some("mz_system"), None)
            .await
            .unwrap();
        let client = connect(server_addr, None, None).await.unwrap();

        let inner = RunnerInner {
            server_addr,
            internal_server_addr,
            password_server_addr,
            internal_http_server_addr,
            _shutdown_trigger: shutdown_trigger,
            // Joins the server thread when the `RunnerInner` is dropped.
            _server_thread: server_thread.join_on_drop(),
            _temp_dir: temp_dir,
            client,
            system_client,
            clients: BTreeMap::new(),
            auto_index_tables: config.auto_index_tables,
            auto_index_selects: config.auto_index_selects,
            auto_transactions: config.auto_transactions,
            enable_table_keys: config.enable_table_keys,
            verbose: config.verbose,
            stdout: config.stdout,
        };
        inner.ensure_fixed_features().await?;

        Ok(inner)
    }
1342
1343    /// Set features that should be enabled regardless of whether reset-server was
1344    /// called. These features may be set conditionally depending on the run configuration.
1345    async fn ensure_fixed_features(&self) -> Result<(), anyhow::Error> {
1346        // We turn on enable_reduce_mfp_fusion, as we wish
1347        // to get as much coverage of these features as we can.
1348        // TODO(vmarcos): Remove this code when we retire this feature flag.
1349        self.system_client
1350            .execute("ALTER SYSTEM SET enable_reduce_mfp_fusion = on", &[])
1351            .await?;
1352
1353        // Dangerous functions are useful for tests so we enable it for all tests.
1354        self.system_client
1355            .execute("ALTER SYSTEM SET unsafe_enable_unsafe_functions = on", &[])
1356            .await?;
1357        Ok(())
1358    }
1359
1360    async fn run_record<'r>(
1361        &mut self,
1362        record: &'r Record<'r>,
1363        in_transaction: &mut bool,
1364    ) -> Result<Outcome<'r>, anyhow::Error> {
1365        match &record {
1366            Record::Statement {
1367                expected_error,
1368                rows_affected,
1369                sql,
1370                location,
1371            } => {
1372                if self.auto_transactions && *in_transaction {
1373                    self.client.execute("COMMIT", &[]).await?;
1374                    *in_transaction = false;
1375                }
1376                match self
1377                    .run_statement(*expected_error, *rows_affected, sql, location.clone())
1378                    .await?
1379                {
1380                    Outcome::Success => {
1381                        if self.auto_index_tables {
1382                            let additional = mutate(sql);
1383                            for stmt in additional {
1384                                self.client.execute(&stmt, &[]).await?;
1385                            }
1386                        }
1387                        Ok(Outcome::Success)
1388                    }
1389                    other => {
1390                        if expected_error.is_some() {
1391                            Ok(other)
1392                        } else {
1393                            // If we failed to execute a statement that was supposed to succeed,
1394                            // running the rest of the tests in this file will probably cause
1395                            // false positives, so just give up on the file entirely.
1396                            Ok(Outcome::Bail {
1397                                cause: Box::new(other),
1398                                location: location.clone(),
1399                            })
1400                        }
1401                    }
1402                }
1403            }
1404            Record::Query {
1405                sql,
1406                output,
1407                location,
1408            } => {
1409                self.run_query(sql, output, location.clone(), in_transaction)
1410                    .await
1411            }
1412            Record::Simple {
1413                conn,
1414                user,
1415                password,
1416                sql,
1417                sort,
1418                output,
1419                location,
1420                ..
1421            } => {
1422                self.run_simple(
1423                    *conn,
1424                    *user,
1425                    *password,
1426                    sql,
1427                    sort.clone(),
1428                    output,
1429                    location.clone(),
1430                )
1431                .await
1432            }
1433            Record::Copy {
1434                table_name,
1435                tsv_path,
1436            } => {
1437                let tsv = tokio::fs::read(tsv_path).await?;
1438                let copy = self
1439                    .client
1440                    .copy_in(&*format!("COPY {} FROM STDIN", table_name))
1441                    .await?;
1442                tokio::pin!(copy);
1443                copy.send(bytes::Bytes::from(tsv)).await?;
1444                copy.finish().await?;
1445                Ok(Outcome::Success)
1446            }
1447            _ => Ok(Outcome::Success),
1448        }
1449    }
1450
1451    async fn run_statement<'r>(
1452        &self,
1453        expected_error: Option<&'r str>,
1454        expected_rows_affected: Option<u64>,
1455        sql: &'r str,
1456        location: Location,
1457    ) -> Result<Outcome<'r>, anyhow::Error> {
1458        static UNSUPPORTED_INDEX_STATEMENT_REGEX: LazyLock<Regex> =
1459            LazyLock::new(|| Regex::new("^(CREATE UNIQUE INDEX|REINDEX)").unwrap());
1460        if UNSUPPORTED_INDEX_STATEMENT_REGEX.is_match(sql) {
1461            // sure, we totally made you an index
1462            return Ok(Outcome::Success);
1463        }
1464
1465        match self.client.execute(sql, &[]).await {
1466            Ok(actual) => {
1467                if let Some(expected_error) = expected_error {
1468                    return Ok(Outcome::UnexpectedPlanSuccess {
1469                        expected_error,
1470                        location,
1471                    });
1472                }
1473                match expected_rows_affected {
1474                    None => Ok(Outcome::Success),
1475                    Some(expected) => {
1476                        if expected != actual {
1477                            Ok(Outcome::WrongNumberOfRowsInserted {
1478                                expected_count: expected,
1479                                actual_count: actual,
1480                                location,
1481                            })
1482                        } else {
1483                            Ok(Outcome::Success)
1484                        }
1485                    }
1486                }
1487            }
1488            Err(error) => {
1489                if let Some(expected_error) = expected_error {
1490                    if Regex::new(expected_error)?.is_match(&error.to_string_with_causes()) {
1491                        return Ok(Outcome::Success);
1492                    }
1493                    return Ok(Outcome::PlanFailure {
1494                        error: anyhow!(error),
1495                        expected_error: Some(expected_error.to_string()),
1496                        location,
1497                    });
1498                }
1499                Ok(Outcome::PlanFailure {
1500                    error: anyhow!(error),
1501                    expected_error: None,
1502                    location,
1503                })
1504            }
1505        }
1506    }
1507
1508    async fn prepare_query<'r>(
1509        &self,
1510        sql: &str,
1511        output: &'r Result<QueryOutput<'_>, &'r str>,
1512        location: Location,
1513        in_transaction: &mut bool,
1514    ) -> Result<PrepareQueryOutcome<'r>, anyhow::Error> {
1515        // get statement
1516        let statements = match mz_sql::parse::parse(sql) {
1517            Ok(statements) => statements,
1518            Err(e) => match output {
1519                Ok(_) => {
1520                    return Ok(PrepareQueryOutcome::Outcome(Outcome::ParseFailure {
1521                        error: e.into(),
1522                        location,
1523                    }));
1524                }
1525                Err(expected_error) => {
1526                    if Regex::new(expected_error)?.is_match(&e.to_string_with_causes()) {
1527                        return Ok(PrepareQueryOutcome::Outcome(Outcome::Success));
1528                    } else {
1529                        return Ok(PrepareQueryOutcome::Outcome(Outcome::ParseFailure {
1530                            error: e.into(),
1531                            location,
1532                        }));
1533                    }
1534                }
1535            },
1536        };
1537        let statement = match &*statements {
1538            [] => bail!("Got zero statements?"),
1539            [statement] => &statement.ast,
1540            _ => bail!("Got multiple statements: {:?}", statements),
1541        };
1542        let (is_select, num_attributes) = match statement {
1543            Statement::Select(stmt) => (true, derive_num_attributes(&stmt.query.body)),
1544            _ => (false, None),
1545        };
1546
1547        match output {
1548            Ok(_) => {
1549                if self.auto_transactions && !*in_transaction {
1550                    // No ISOLATION LEVEL SERIALIZABLE because of database-issues#5323
1551                    self.client.execute("BEGIN", &[]).await?;
1552                    *in_transaction = true;
1553                }
1554            }
1555            Err(_) => {
1556                if self.auto_transactions && *in_transaction {
1557                    self.client.execute("COMMIT", &[]).await?;
1558                    *in_transaction = false;
1559                }
1560            }
1561        }
1562
1563        // `SHOW` commands reference catalog schema, thus are not in the same timedomain and not
1564        // allowed in the same transaction, see:
1565        // https://materialize.com/docs/sql/begin/#same-timedomain-error
1566        match statement {
1567            Statement::Show(..) => {
1568                if self.auto_transactions && *in_transaction {
1569                    self.client.execute("COMMIT", &[]).await?;
1570                    *in_transaction = false;
1571                }
1572            }
1573            _ => (),
1574        }
1575        Ok(PrepareQueryOutcome::QueryPrepared(QueryInfo {
1576            is_select,
1577            num_attributes,
1578        }))
1579    }
1580
1581    async fn execute_query<'r>(
1582        &self,
1583        sql: &str,
1584        output: &'r Result<QueryOutput<'_>, &'r str>,
1585        location: Location,
1586    ) -> Result<Outcome<'r>, anyhow::Error> {
1587        let rows = match self.client.query(sql, &[]).await {
1588            Ok(rows) => rows,
1589            Err(error) => {
1590                let error_string = error.to_string_with_causes();
1591                return match output {
1592                    Ok(_) => {
1593                        if error_string.contains("supported") || error_string.contains("overload") {
1594                            // this is a failure, but it's caused by lack of support rather than by bugs
1595                            Ok(Outcome::Unsupported {
1596                                error: anyhow!(error),
1597                                location,
1598                            })
1599                        } else {
1600                            Ok(Outcome::PlanFailure {
1601                                error: anyhow!(error),
1602                                expected_error: None,
1603                                location,
1604                            })
1605                        }
1606                    }
1607                    Err(expected_error) => {
1608                        if Regex::new(expected_error)?.is_match(&error_string) {
1609                            Ok(Outcome::Success)
1610                        } else {
1611                            Ok(Outcome::PlanFailure {
1612                                error: anyhow!(error),
1613                                expected_error: Some(expected_error.to_string()),
1614                                location,
1615                            })
1616                        }
1617                    }
1618                };
1619            }
1620        };
1621
1622        // unpack expected output
1623        let QueryOutput {
1624            sort,
1625            types: expected_types,
1626            column_names: expected_column_names,
1627            output: expected_output,
1628            mode,
1629            ..
1630        } = match output {
1631            Err(expected_error) => {
1632                return Ok(Outcome::UnexpectedPlanSuccess {
1633                    expected_error,
1634                    location,
1635                });
1636            }
1637            Ok(query_output) => query_output,
1638        };
1639
1640        // format output
1641        let mut formatted_rows = vec![];
1642        for row in &rows {
1643            if row.len() != expected_types.len() {
1644                return Ok(Outcome::WrongColumnCount {
1645                    expected_count: expected_types.len(),
1646                    actual_count: row.len(),
1647                    location,
1648                });
1649            }
1650            let row = format_row(row, expected_types, *mode);
1651            formatted_rows.push(row);
1652        }
1653
1654        // sort formatted output
1655        if let Sort::Row = sort {
1656            formatted_rows.sort();
1657        }
1658        let mut values = formatted_rows.into_iter().flatten().collect::<Vec<_>>();
1659        if let Sort::Value = sort {
1660            values.sort();
1661        }
1662
1663        // Various checks as long as there are returned rows.
1664        if let Some(row) = rows.get(0) {
1665            // check column names
1666            if let Some(expected_column_names) = expected_column_names {
1667                let actual_column_names = row
1668                    .columns()
1669                    .iter()
1670                    .map(|t| ColumnName::from(t.name()))
1671                    .collect::<Vec<_>>();
1672                if expected_column_names != &actual_column_names {
1673                    return Ok(Outcome::WrongColumnNames {
1674                        expected_column_names,
1675                        actual_column_names,
1676                        actual_output: Output::Values(values),
1677                        location,
1678                    });
1679                }
1680            }
1681        }
1682
1683        // check output
1684        match expected_output {
1685            Output::Values(expected_values) => {
1686                if values != *expected_values {
1687                    return Ok(Outcome::OutputFailure {
1688                        expected_output,
1689                        actual_raw_output: rows,
1690                        actual_output: Output::Values(values),
1691                        location,
1692                    });
1693                }
1694            }
1695            Output::Hashed {
1696                num_values,
1697                md5: expected_md5,
1698            } => {
1699                let mut hasher = Md5::new();
1700                for value in &values {
1701                    hasher.update(value);
1702                    hasher.update("\n");
1703                }
1704                let md5 = format!("{:x}", hasher.finalize());
1705                if values.len() != *num_values || md5 != *expected_md5 {
1706                    return Ok(Outcome::OutputFailure {
1707                        expected_output,
1708                        actual_raw_output: rows,
1709                        actual_output: Output::Hashed {
1710                            num_values: values.len(),
1711                            md5,
1712                        },
1713                        location,
1714                    });
1715                }
1716            }
1717        }
1718
1719        Ok(Outcome::Success)
1720    }
1721
1722    async fn execute_view_inner<'r>(
1723        &self,
1724        sql: &str,
1725        output: &'r Result<QueryOutput<'_>, &'r str>,
1726        location: Location,
1727    ) -> Result<Option<Outcome<'r>>, anyhow::Error> {
1728        print_sql_if(self.stdout, sql, self.verbose);
1729        let sql_result = self.client.execute(sql, &[]).await;
1730
1731        // Evaluate if we already reached an outcome or not.
1732        let tentative_outcome = if let Err(view_error) = sql_result {
1733            if let Err(expected_error) = output {
1734                if Regex::new(expected_error)?.is_match(&view_error.to_string_with_causes()) {
1735                    Some(Outcome::Success)
1736                } else {
1737                    Some(Outcome::PlanFailure {
1738                        error: view_error.into(),
1739                        expected_error: Some(expected_error.to_string()),
1740                        location: location.clone(),
1741                    })
1742                }
1743            } else {
1744                Some(Outcome::PlanFailure {
1745                    error: view_error.into(),
1746                    expected_error: None,
1747                    location: location.clone(),
1748                })
1749            }
1750        } else {
1751            None
1752        };
1753        Ok(tentative_outcome)
1754    }
1755
1756    async fn execute_view<'r>(
1757        &self,
1758        sql: &str,
1759        num_attributes: Option<usize>,
1760        output: &'r Result<QueryOutput<'_>, &'r str>,
1761        location: Location,
1762    ) -> Result<Outcome<'r>, anyhow::Error> {
1763        // Create indexed view SQL commands and execute `CREATE VIEW`.
1764        let expected_column_names = if let Ok(QueryOutput { column_names, .. }) = output {
1765            column_names.clone()
1766        } else {
1767            None
1768        };
1769        let (create_view, create_index, view_sql, drop_view) = generate_view_sql(
1770            sql,
1771            Uuid::new_v4().as_simple(),
1772            num_attributes,
1773            expected_column_names,
1774        );
1775        let tentative_outcome = self
1776            .execute_view_inner(create_view.as_str(), output, location.clone())
1777            .await?;
1778
1779        // Either we already have an outcome or alternatively,
1780        // we proceed to index and query the view.
1781        if let Some(view_outcome) = tentative_outcome {
1782            return Ok(view_outcome);
1783        }
1784
1785        let tentative_outcome = self
1786            .execute_view_inner(create_index.as_str(), output, location.clone())
1787            .await?;
1788
1789        let view_outcome;
1790        if let Some(outcome) = tentative_outcome {
1791            view_outcome = outcome;
1792        } else {
1793            print_sql_if(self.stdout, view_sql.as_str(), self.verbose);
1794            view_outcome = self
1795                .execute_query(view_sql.as_str(), output, location.clone())
1796                .await?;
1797        }
1798
1799        // Remember to clean up after ourselves by dropping the view.
1800        print_sql_if(self.stdout, drop_view.as_str(), self.verbose);
1801        self.client.execute(drop_view.as_str(), &[]).await?;
1802
1803        Ok(view_outcome)
1804    }
1805
1806    async fn run_query<'r>(
1807        &self,
1808        sql: &'r str,
1809        output: &'r Result<QueryOutput<'_>, &'r str>,
1810        location: Location,
1811        in_transaction: &mut bool,
1812    ) -> Result<Outcome<'r>, anyhow::Error> {
1813        let prepare_outcome = self
1814            .prepare_query(sql, output, location.clone(), in_transaction)
1815            .await?;
1816        match prepare_outcome {
1817            PrepareQueryOutcome::QueryPrepared(QueryInfo {
1818                is_select,
1819                num_attributes,
1820            }) => {
1821                let query_outcome = self.execute_query(sql, output, location.clone()).await?;
1822                if is_select && self.auto_index_selects {
1823                    let view_outcome = self
1824                        .execute_view(sql, None, output, location.clone())
1825                        .await?;
1826
1827                    // We compare here the query-based and view-based outcomes.
1828                    // We only produce a test failure if the outcomes are of different
1829                    // variant types, thus accepting smaller deviations in the details
1830                    // produced for each variant.
1831                    if std::mem::discriminant::<Outcome>(&query_outcome)
1832                        != std::mem::discriminant::<Outcome>(&view_outcome)
1833                    {
1834                        // Before producing a failure outcome, we try to obtain a new
1835                        // outcome for view-based execution exploiting analysis of the
1836                        // number of attributes. This two-level strategy can avoid errors
1837                        // produced by column ambiguity in the `SELECT`.
1838                        let view_outcome = if num_attributes.is_some() {
1839                            self.execute_view(sql, num_attributes, output, location.clone())
1840                                .await?
1841                        } else {
1842                            view_outcome
1843                        };
1844
1845                        if std::mem::discriminant::<Outcome>(&query_outcome)
1846                            != std::mem::discriminant::<Outcome>(&view_outcome)
1847                        {
1848                            let inconsistent_view_outcome = Outcome::InconsistentViewOutcome {
1849                                query_outcome: Box::new(query_outcome),
1850                                view_outcome: Box::new(view_outcome),
1851                                location: location.clone(),
1852                            };
1853                            // Determine if this inconsistent view outcome should be reported
1854                            // as an error or only as a warning.
1855                            let outcome = if should_warn(&inconsistent_view_outcome) {
1856                                Outcome::Warning {
1857                                    cause: Box::new(inconsistent_view_outcome),
1858                                    location: location.clone(),
1859                                }
1860                            } else {
1861                                inconsistent_view_outcome
1862                            };
1863                            return Ok(outcome);
1864                        }
1865                    }
1866                }
1867                Ok(query_outcome)
1868            }
1869            PrepareQueryOutcome::Outcome(outcome) => Ok(outcome),
1870        }
1871    }
1872
1873    async fn get_conn(
1874        &mut self,
1875        name: Option<&str>,
1876        user: Option<&str>,
1877        password: Option<&str>,
1878    ) -> Result<&tokio_postgres::Client, tokio_postgres::Error> {
1879        match name {
1880            None => Ok(&self.client),
1881            Some(name) => {
1882                if !self.clients.contains_key(name) {
1883                    let addr = if matches!(user, Some("mz_system") | Some("mz_support")) {
1884                        self.internal_server_addr
1885                    } else if password.is_some() {
1886                        // Use password server for password authentication
1887                        self.password_server_addr
1888                    } else {
1889                        self.server_addr
1890                    };
1891                    let client = connect(addr, user, password).await?;
1892                    self.clients.insert(name.into(), client);
1893                }
1894                Ok(self.clients.get(name).unwrap())
1895            }
1896        }
1897    }
1898
1899    async fn run_simple<'r>(
1900        &mut self,
1901        conn: Option<&'r str>,
1902        user: Option<&'r str>,
1903        password: Option<&'r str>,
1904        sql: &'r str,
1905        sort: Sort,
1906        output: &'r Output,
1907        location: Location,
1908    ) -> Result<Outcome<'r>, anyhow::Error> {
1909        let actual = match self.get_conn(conn, user, password).await {
1910            Ok(client) => match client.simple_query(sql).await {
1911                Ok(result) => {
1912                    let mut rows = Vec::new();
1913
1914                    for m in result.into_iter() {
1915                        match m {
1916                            SimpleQueryMessage::Row(row) => {
1917                                let mut s = vec![];
1918                                for i in 0..row.len() {
1919                                    s.push(row.get(i).unwrap_or("NULL"));
1920                                }
1921                                rows.push(s.join(","));
1922                            }
1923                            SimpleQueryMessage::CommandComplete(count) => {
1924                                // This applies any sort on the COMPLETE line as
1925                                // well, but we do the same for the expected output.
1926                                rows.push(format!("COMPLETE {}", count));
1927                            }
1928                            SimpleQueryMessage::RowDescription(_) => {}
1929                            _ => panic!("unexpected"),
1930                        }
1931                    }
1932
1933                    if let Sort::Row = sort {
1934                        rows.sort();
1935                    }
1936
1937                    Output::Values(rows)
1938                }
1939                // Errors can contain multiple lines (say if there are details), and rewrite
1940                // sticks them each on their own line, so we need to split up the lines here to
1941                // each be its own String in the Vec.
1942                Err(error) => Output::Values(
1943                    error
1944                        .to_string_with_causes()
1945                        .lines()
1946                        .map(|s| s.to_string())
1947                        .collect(),
1948                ),
1949            },
1950            Err(error) => Output::Values(
1951                error
1952                    .to_string_with_causes()
1953                    .lines()
1954                    .map(|s| s.to_string())
1955                    .collect(),
1956            ),
1957        };
1958        if *output != actual {
1959            Ok(Outcome::OutputFailure {
1960                expected_output: output,
1961                actual_raw_output: vec![],
1962                actual_output: actual,
1963                location,
1964            })
1965        } else {
1966            Ok(Outcome::Success)
1967        }
1968    }
1969
1970    async fn check_catalog(&self) -> Result<(), anyhow::Error> {
1971        let url = format!(
1972            "http://{}/api/catalog/check",
1973            self.internal_http_server_addr
1974        );
1975        let response: serde_json::Value = reqwest::get(&url).await?.json().await?;
1976
1977        if let Some(inconsistencies) = response.get("err") {
1978            let inconsistencies = serde_json::to_string_pretty(&inconsistencies)
1979                .expect("serializing Value cannot fail");
1980            Err(anyhow::anyhow!("Catalog inconsistency\n{inconsistencies}"))
1981        } else {
1982            Ok(())
1983        }
1984    }
1985}
1986
1987async fn connect(
1988    addr: SocketAddr,
1989    user: Option<&str>,
1990    password: Option<&str>,
1991) -> Result<tokio_postgres::Client, tokio_postgres::Error> {
1992    let mut config = tokio_postgres::Config::new();
1993    config.host(addr.ip().to_string());
1994    config.port(addr.port());
1995    config.user(user.unwrap_or("materialize"));
1996    if let Some(password) = password {
1997        config.password(password);
1998    }
1999    let (client, connection) = config.connect(NoTls).await?;
2000
2001    task::spawn(|| "sqllogictest_connect", async move {
2002        if let Err(e) = connection.await {
2003            eprintln!("connection error: {}", e);
2004        }
2005    });
2006    Ok(client)
2007}
2008
/// A sink for formatted output.
///
/// The method is named `write_fmt` and takes `&self`, which lets `write!` and
/// `writeln!` be invoked directly on a shared `&dyn WriteFmt`, as the printing
/// helpers in this file do. It returns nothing, so implementations have no way to
/// report a failed write to the caller.
pub trait WriteFmt {
    /// Writes the already formatted arguments `fmt` to this sink.
    fn write_fmt(&self, fmt: fmt::Arguments<'_>);
}
2012
/// Settings that control how test files are run and how results are reported.
pub struct RunConfig<'a> {
    /// Sink that receives the printed records and the outcome reports.
    pub stdout: &'a dyn WriteFmt,
    pub stderr: &'a dyn WriteFmt,
    /// When set, every record is printed before it is run, and the details of every
    /// non-successful outcome (warnings included) are printed unless `quiet` is set.
    pub verbose: bool,
    /// When set, nothing is printed for warnings and failures while records run.
    pub quiet: bool,
    pub postgres_url: String,
    pub prefix: String,
    pub no_fail: bool,
    /// When set, a run stops at the first record whose outcome is a failure.
    pub fail_fast: bool,
    pub auto_index_tables: bool,
    pub auto_index_selects: bool,
    pub auto_transactions: bool,
    pub enable_table_keys: bool,
    pub orchestrator_process_wrapper: Option<String>,
    pub tracing: TracingCliArgs,
    pub tracing_handle: TracingHandle,
    pub system_parameter_defaults: BTreeMap<String, String>,
    /// Persist state is handled specially because:
    /// - Persist background workers do not necessarily shut down immediately once the server is
    ///   shut down, and may panic if their storage is deleted out from under them.
    /// - It's safe for different databases to reference the same state: all data is scoped by UUID.
    pub persist_dir: TempDir,
    pub replicas: usize,
    pub replica_size: String,
}
2038
/// Indentation, in spaces, applied to SQL statements and test directives when they
/// are printed.
const PRINT_INDENT: usize = 4;
2041
2042fn print_record(config: &RunConfig<'_>, record: &Record) {
2043    match record {
2044        Record::Statement { sql, .. } | Record::Query { sql, .. } => {
2045            print_sql(config.stdout, sql, None)
2046        }
2047        Record::Simple { conn, sql, .. } => print_sql(config.stdout, sql, *conn),
2048        Record::Copy {
2049            table_name,
2050            tsv_path,
2051        } => {
2052            writeln!(
2053                config.stdout,
2054                "{}slt copy {} from {}",
2055                " ".repeat(PRINT_INDENT),
2056                table_name,
2057                tsv_path
2058            )
2059        }
2060        Record::ResetServer => {
2061            writeln!(config.stdout, "{}reset-server", " ".repeat(PRINT_INDENT))
2062        }
2063        Record::Halt => {
2064            writeln!(config.stdout, "{}halt", " ".repeat(PRINT_INDENT))
2065        }
2066        Record::HashThreshold { threshold } => {
2067            writeln!(
2068                config.stdout,
2069                "{}hash-threshold {}",
2070                " ".repeat(PRINT_INDENT),
2071                threshold
2072            )
2073        }
2074    }
2075}
2076
2077fn print_sql_if<'a>(stdout: &'a dyn WriteFmt, sql: &str, cond: bool) {
2078    if cond {
2079        print_sql(stdout, sql, None)
2080    }
2081}
2082
2083fn print_sql<'a>(stdout: &'a dyn WriteFmt, sql: &str, conn: Option<&str>) {
2084    let text = if let Some(conn) = conn {
2085        format!("[conn={}] {}", conn, sql)
2086    } else {
2087        sql.to_string()
2088    };
2089    writeln!(stdout, "{}", util::indent(&text, PRINT_INDENT))
2090}
2091
/// Regular expressions for matching error messages that should force a plan failure
/// in an inconsistent view outcome into a warning if the corresponding query succeeds.
///
/// The patterns are not anchored: `should_warn` tests them with `is_match`, so a
/// match anywhere in the error text is enough.
const INCONSISTENT_VIEW_OUTCOME_WARNING_REGEXPS: [&str; 9] = [
    // The following are unfixable errors in indexed views given our
    // current constraints.
    "cannot materialize call to",
    "SHOW commands are not allowed in views",
    "cannot create view with unstable dependencies",
    "cannot use wildcard expansions or NATURAL JOINs in a view that depends on system objects",
    "no valid schema selected",
    r#"system schema '\w+' cannot be modified"#,
    r#"permission denied for (SCHEMA|CLUSTER) "(\w+\.)?\w+""#,
    // NOTE(vmarcos): Column ambiguity that could not be eliminated by our
    // currently implemented syntactic rewrites is considered unfixable.
    // In addition, if some column cannot be dealt with, e.g., in `ORDER BY`
    // references, we treat this condition as unfixable as well.
    r#"column "[\w\?]+" specified more than once"#,
    r#"column "(\w+\.)?\w+" does not exist"#,
];
2111
2112/// Evaluates if the given outcome should be returned directly or if it should
2113/// be wrapped as a warning. Note that this function should be used for outcomes
2114/// that can be judged in a context-independent manner, i.e., the outcome itself
2115/// provides enough information as to whether a warning should be emitted or not.
2116fn should_warn(outcome: &Outcome) -> bool {
2117    match outcome {
2118        Outcome::InconsistentViewOutcome {
2119            query_outcome,
2120            view_outcome,
2121            ..
2122        } => match (query_outcome.as_ref(), view_outcome.as_ref()) {
2123            (Outcome::Success, Outcome::PlanFailure { error, .. }) => {
2124                INCONSISTENT_VIEW_OUTCOME_WARNING_REGEXPS.iter().any(|s| {
2125                    Regex::new(s)
2126                        .expect("unexpected error in regular expression parsing")
2127                        .is_match(&error.to_string_with_causes())
2128                })
2129            }
2130            _ => false,
2131        },
2132        _ => false,
2133    }
2134}
2135
2136pub async fn run_string(
2137    runner: &mut Runner<'_>,
2138    source: &str,
2139    input: &str,
2140) -> Result<Outcomes, anyhow::Error> {
2141    runner.reset_database().await?;
2142
2143    let mut outcomes = Outcomes::default();
2144    let mut parser = crate::parser::Parser::new(source, input);
2145    // Transactions are currently relatively slow. Since sqllogictest runs in a single connection
2146    // there should be no difference in having longer running transactions.
2147    let mut in_transaction = false;
2148    writeln!(runner.config.stdout, "--- {}", source);
2149
2150    for record in parser.parse_records()? {
2151        // In maximal-verbose mode, print the query before attempting to run
2152        // it. Running the query might panic, so it is important to print out
2153        // what query we are trying to run *before* we panic.
2154        if runner.config.verbose {
2155            print_record(runner.config, &record);
2156        }
2157
2158        let outcome = runner
2159            .run_record(&record, &mut in_transaction)
2160            .await
2161            .map_err(|err| format!("In {}:\n{}", source, err))
2162            .unwrap();
2163
2164        // Print warnings and failures in verbose mode.
2165        if !runner.config.quiet && !outcome.success() {
2166            if !runner.config.verbose {
2167                // If `verbose` is enabled, we'll already have printed the record,
2168                // so don't print it again. Yes, this is an ugly bit of logic.
2169                // Please don't try to consolidate it with the `print_record`
2170                // call above, as it's important to have a mode in which records
2171                // are printed before they are run, so that if running the
2172                // record panics, you can tell which record caused it.
2173                if !outcome.failure() {
2174                    writeln!(
2175                        runner.config.stdout,
2176                        "{}",
2177                        util::indent("Warning detected for: ", 4)
2178                    );
2179                }
2180                print_record(runner.config, &record);
2181            }
2182            if runner.config.verbose || outcome.failure() {
2183                writeln!(
2184                    runner.config.stdout,
2185                    "{}",
2186                    util::indent(&outcome.to_string(), 4)
2187                );
2188                writeln!(runner.config.stdout, "{}", util::indent("----", 4));
2189            }
2190        }
2191
2192        outcomes.stats[outcome.code()] += 1;
2193        if outcome.failure() {
2194            outcomes.details.push(format!("{}", outcome));
2195        }
2196
2197        if let Outcome::Bail { .. } = outcome {
2198            break;
2199        }
2200
2201        if runner.config.fail_fast && outcome.failure() {
2202            break;
2203        }
2204    }
2205    Ok(outcomes)
2206}
2207
2208pub async fn run_file(runner: &mut Runner<'_>, filename: &Path) -> Result<Outcomes, anyhow::Error> {
2209    let mut input = String::new();
2210    File::open(filename)?.read_to_string(&mut input)?;
2211    let outcomes = run_string(runner, &format!("{}", filename.display()), &input).await?;
2212    runner.check_catalog().await?;
2213
2214    Ok(outcomes)
2215}
2216
2217pub async fn rewrite_file(runner: &mut Runner<'_>, filename: &Path) -> Result<(), anyhow::Error> {
2218    runner.reset_database().await?;
2219
2220    let mut file = OpenOptions::new().read(true).write(true).open(filename)?;
2221
2222    let mut input = String::new();
2223    file.read_to_string(&mut input)?;
2224
2225    let mut buf = RewriteBuffer::new(&input);
2226
2227    let mut parser = crate::parser::Parser::new(filename.to_str().unwrap_or(""), &input);
2228    writeln!(runner.config.stdout, "--- {}", filename.display());
2229    let mut in_transaction = false;
2230
2231    fn append_values_output(
2232        buf: &mut RewriteBuffer,
2233        input: &String,
2234        expected_output: &str,
2235        mode: &Mode,
2236        types: &Vec<Type>,
2237        column_names: Option<&Vec<ColumnName>>,
2238        actual_output: &Vec<String>,
2239        multiline: bool,
2240    ) {
2241        buf.append_header(input, expected_output, column_names);
2242
2243        for (i, row) in actual_output.chunks(types.len()).enumerate() {
2244            match mode {
2245                // In Cockroach mode, output each row on its own line, with
2246                // two spaces between each column.
2247                Mode::Cockroach => {
2248                    if i != 0 {
2249                        buf.append("\n");
2250                    }
2251
2252                    if row.len() == 0 {
2253                        // nothing to do
2254                    } else if row.len() == 1 {
2255                        // If there is only one column, then there is no need for space
2256                        // substitution, so we only do newline substitution.
2257                        if multiline {
2258                            buf.append(&row[0]);
2259                        } else {
2260                            buf.append(&row[0].replace('\n', "⏎"))
2261                        }
2262                    } else {
2263                        // Substitute spaces with ␠ to avoid mistaking the spaces in the result
2264                        // values with spaces that separate columns.
2265                        buf.append(
2266                            &row.iter()
2267                                .map(|col| {
2268                                    let mut col = col.replace(' ', "␠");
2269                                    if !multiline {
2270                                        col = col.replace('\n', "⏎");
2271                                    }
2272                                    col
2273                                })
2274                                .join("  "),
2275                        );
2276                    }
2277                }
2278                // In standard mode, output each value on its own line,
2279                // and ignore row boundaries.
2280                // No need to substitute spaces, because every value (not row) is on a separate
2281                // line. But we do need to substitute newlines.
2282                Mode::Standard => {
2283                    for (j, col) in row.iter().enumerate() {
2284                        if i != 0 || j != 0 {
2285                            buf.append("\n");
2286                        }
2287                        buf.append(&if multiline {
2288                            col.clone()
2289                        } else {
2290                            col.replace('\n', "⏎")
2291                        });
2292                    }
2293                }
2294            }
2295        }
2296    }
2297
2298    for record in parser.parse_records()? {
2299        let outcome = runner.run_record(&record, &mut in_transaction).await?;
2300
2301        match (&record, &outcome) {
2302            // If we see an output failure for a query, rewrite the expected output
2303            // to match the observed output.
2304            (
2305                Record::Query {
2306                    output:
2307                        Ok(QueryOutput {
2308                            mode,
2309                            output: Output::Values(_),
2310                            output_str: expected_output,
2311                            types,
2312                            column_names,
2313                            multiline,
2314                            ..
2315                        }),
2316                    ..
2317                },
2318                Outcome::OutputFailure {
2319                    actual_output: Output::Values(actual_output),
2320                    ..
2321                },
2322            ) => {
2323                append_values_output(
2324                    &mut buf,
2325                    &input,
2326                    expected_output,
2327                    mode,
2328                    types,
2329                    column_names.as_ref(),
2330                    actual_output,
2331                    *multiline,
2332                );
2333            }
2334            (
2335                Record::Query {
2336                    output:
2337                        Ok(QueryOutput {
2338                            mode,
2339                            output: Output::Values(_),
2340                            output_str: expected_output,
2341                            types,
2342                            multiline,
2343                            ..
2344                        }),
2345                    ..
2346                },
2347                Outcome::WrongColumnNames {
2348                    actual_column_names,
2349                    actual_output: Output::Values(actual_output),
2350                    ..
2351                },
2352            ) => {
2353                append_values_output(
2354                    &mut buf,
2355                    &input,
2356                    expected_output,
2357                    mode,
2358                    types,
2359                    Some(actual_column_names),
2360                    actual_output,
2361                    *multiline,
2362                );
2363            }
2364            (
2365                Record::Query {
2366                    output:
2367                        Ok(QueryOutput {
2368                            output: Output::Hashed { .. },
2369                            output_str: expected_output,
2370                            column_names,
2371                            ..
2372                        }),
2373                    ..
2374                },
2375                Outcome::OutputFailure {
2376                    actual_output: Output::Hashed { num_values, md5 },
2377                    ..
2378                },
2379            ) => {
2380                buf.append_header(&input, expected_output, column_names.as_ref());
2381
2382                buf.append(format!("{} values hashing to {}\n", num_values, md5).as_str())
2383            }
2384            (
2385                Record::Simple {
2386                    output_str: expected_output,
2387                    ..
2388                },
2389                Outcome::OutputFailure {
2390                    actual_output: Output::Values(actual_output),
2391                    ..
2392                },
2393            ) => {
2394                buf.append_header(&input, expected_output, None);
2395
2396                for (i, row) in actual_output.iter().enumerate() {
2397                    if i != 0 {
2398                        buf.append("\n");
2399                    }
2400                    buf.append(row);
2401                }
2402            }
2403            (
2404                Record::Query {
2405                    sql,
2406                    output: Err(err),
2407                    ..
2408                },
2409                outcome,
2410            )
2411            | (
2412                Record::Statement {
2413                    expected_error: Some(err),
2414                    sql,
2415                    ..
2416                },
2417                outcome,
2418            ) if outcome.err_msg().is_some() => {
2419                buf.rewrite_expected_error(&input, err, &outcome.err_msg().unwrap(), sql)
2420            }
2421            (_, Outcome::Success) => {}
2422            _ => bail!("unexpected: {:?} {:?}", record, outcome),
2423        }
2424    }
2425
2426    file.set_len(0)?;
2427    file.seek(SeekFrom::Start(0))?;
2428    file.write_all(buf.finish().as_bytes())?;
2429    file.sync_all()?;
2430    Ok(())
2431}
2432
2433/// Provides a means to rewrite the `.slt` file while iterating over it.
2434///
2435/// This struct takes the slt file as its `input`, tracks a cursor into it
2436/// (`input_offset`), and provides a buffer (`output`) to store the rewritten
2437/// results.
2438///
2439/// Functions that modify the file will lazily move `input` into `output` using
2440/// `flush_to`. However, those calls should all be interior to other functions.
2441#[derive(Debug)]
2442struct RewriteBuffer<'a> {
2443    input: &'a str,
2444    input_offset: usize,
2445    output: String,
2446}
2447
2448impl<'a> RewriteBuffer<'a> {
2449    fn new(input: &'a str) -> RewriteBuffer<'a> {
2450        RewriteBuffer {
2451            input,
2452            input_offset: 0,
2453            output: String::new(),
2454        }
2455    }
2456
2457    fn flush_to(&mut self, offset: usize) {
2458        assert!(offset >= self.input_offset);
2459        let chunk = &self.input[self.input_offset..offset];
2460        self.output.push_str(chunk);
2461        self.input_offset = offset;
2462    }
2463
2464    fn skip_to(&mut self, offset: usize) {
2465        assert!(offset >= self.input_offset);
2466        self.input_offset = offset;
2467    }
2468
2469    fn append(&mut self, s: &str) {
2470        self.output.push_str(s);
2471    }
2472
2473    fn append_header(
2474        &mut self,
2475        input: &String,
2476        expected_output: &str,
2477        column_names: Option<&Vec<ColumnName>>,
2478    ) {
2479        // Output everything before this record.
2480        // TODO(benesch): is it possible to rewrite this to avoid `as`?
2481        #[allow(clippy::as_conversions)]
2482        let offset = expected_output.as_ptr() as usize - input.as_ptr() as usize;
2483        self.flush_to(offset);
2484        self.skip_to(offset + expected_output.len());
2485
2486        // Attempt to install the result separator (----), if it does
2487        // not already exist.
2488        if self.peek_last(5) == "\n----" {
2489            self.append("\n");
2490        } else if self.peek_last(6) != "\n----\n" {
2491            self.append("\n----\n");
2492        }
2493
2494        let Some(names) = column_names else {
2495            return;
2496        };
2497        self.append(
2498            &names
2499                .iter()
2500                .map(|name| name.replace(' ', "␠"))
2501                .collect::<Vec<_>>()
2502                .join(" "),
2503        );
2504        self.append("\n");
2505    }
2506
2507    fn rewrite_expected_error(
2508        &mut self,
2509        input: &String,
2510        old_err: &str,
2511        new_err: &str,
2512        query: &str,
2513    ) {
2514        // Output everything before this error message.
2515        // TODO(benesch): is it possible to rewrite this to avoid `as`?
2516        #[allow(clippy::as_conversions)]
2517        let err_offset = old_err.as_ptr() as usize - input.as_ptr() as usize;
2518        self.flush_to(err_offset);
2519        self.append(new_err);
2520        self.append("\n");
2521        self.append(query);
2522        // TODO(benesch): is it possible to rewrite this to avoid `as`?
2523        #[allow(clippy::as_conversions)]
2524        self.skip_to(query.as_ptr() as usize - input.as_ptr() as usize + query.len())
2525    }
2526
2527    fn peek_last(&self, n: usize) -> &str {
2528        &self.output[self.output.len() - n..]
2529    }
2530
2531    fn finish(mut self) -> String {
2532        self.flush_to(self.input.len());
2533        self.output
2534    }
2535}
2536
2537/// Generates view creation, view indexing, view querying, and view
2538/// dropping SQL commands for a given `SELECT` query. If the number
2539/// of attributes produced by the query is known, the view commands
2540/// are specialized to avoid issues with column ambiguity. This
2541/// function is a helper for `--auto_index_selects` and assumes that
2542/// the provided input SQL has already been run through the parser,
2543/// resulting in a valid `SELECT` statement.
2544fn generate_view_sql(
2545    sql: &str,
2546    view_uuid: &Simple,
2547    num_attributes: Option<usize>,
2548    expected_column_names: Option<Vec<ColumnName>>,
2549) -> (String, String, String, String) {
2550    // To create the view, re-parse the sql; note that we must find exactly
2551    // one statement and it must be a `SELECT`.
2552    // NOTE(vmarcos): Direct string manipulation was attempted while
2553    // prototyping the code below, which avoids the extra parsing and
2554    // data structure cloning. However, running DDL is so slow that
2555    // it did not matter in terms of runtime. We can revisit this if
2556    // DDL cost drops dramatically in the future.
2557    let stmts = parser::parse_statements(sql).unwrap_or_default();
2558    assert!(stmts.len() == 1);
2559    let (query, query_as_of) = match &stmts[0].ast {
2560        Statement::Select(stmt) => (&stmt.query, &stmt.as_of),
2561        _ => unreachable!("This function should only be called for SELECTs"),
2562    };
2563
2564    // Prior to creating the view, process the `ORDER BY` clause of
2565    // the `SELECT` query, if any. Ordering is not preserved when a
2566    // view includes an `ORDER BY` clause and must be re-enforced by
2567    // an external `ORDER BY` clause when querying the view.
2568    let (view_order_by, extra_columns, distinct) = if num_attributes.is_none() {
2569        (query.order_by.clone(), vec![], None)
2570    } else {
2571        derive_order_by(&query.body, &query.order_by)
2572    };
2573
2574    // Since one-shot SELECT statements may contain ambiguous column names,
2575    // we either use the expected column names, if that option was
2576    // provided, or else just rename the output schema of the view
2577    // using numerically increasing attribute names, whenever possible.
2578    // This strategy makes it possible to use `CREATE INDEX`, thus
2579    // matching the behavior of the option `auto_index_tables`. However,
2580    // we may be presented with a `SELECT *` query, in which case the parser
2581    // does not produce sufficient information to allow us to compute
2582    // the number of output columns. In the latter case, we are supplied
2583    // with `None` for `num_attributes` and just employ the command
2584    // `CREATE DEFAULT INDEX` instead. Additionally, the view is created
2585    // without schema renaming. This strategy is insufficient to dodge
2586    // column name ambiguity in all cases, but we assume here that we
2587    // can adjust the (hopefully) small number of tests that eventually
2588    // challenge us in this particular way.
2589    let name = UnresolvedItemName(vec![Ident::new_unchecked(format!("v{}", view_uuid))]);
2590    let projection = expected_column_names.map_or_else(
2591        || {
2592            num_attributes.map_or(vec![], |n| {
2593                (1..=n)
2594                    .map(|i| Ident::new_unchecked(format!("a{i}")))
2595                    .collect()
2596            })
2597        },
2598        |cols| {
2599            cols.iter()
2600                .map(|c| Ident::new_unchecked(c.as_str()))
2601                .collect()
2602        },
2603    );
2604    let columns: Vec<Ident> = projection
2605        .iter()
2606        .cloned()
2607        .chain(extra_columns.iter().map(|item| {
2608            if let SelectItem::Expr {
2609                expr: _,
2610                alias: Some(ident),
2611            } = item
2612            {
2613                ident.clone()
2614            } else {
2615                unreachable!("alias must be given for extra column")
2616            }
2617        }))
2618        .collect();
2619
2620    // Build a `CREATE VIEW` with the columns computed above.
2621    let mut query = query.clone();
2622    if extra_columns.len() > 0 {
2623        match &mut query.body {
2624            SetExpr::Select(stmt) => stmt.projection.extend(extra_columns.iter().cloned()),
2625            _ => unimplemented!("cannot yet rewrite projections of nested queries"),
2626        }
2627    }
2628    let create_view = AstStatement::<Raw>::CreateView(CreateViewStatement {
2629        if_exists: IfExistsBehavior::Error,
2630        temporary: false,
2631        definition: ViewDefinition {
2632            name: name.clone(),
2633            columns: columns.clone(),
2634            query,
2635        },
2636    })
2637    .to_ast_string_stable();
2638
2639    // We then create either a `CREATE INDEX` or a `CREATE DEFAULT INDEX`
2640    // statement, depending on whether we could obtain the number of
2641    // attributes from the original `SELECT`.
2642    let create_index = AstStatement::<Raw>::CreateIndex(CreateIndexStatement {
2643        name: None,
2644        in_cluster: None,
2645        on_name: RawItemName::Name(name.clone()),
2646        key_parts: if columns.len() == 0 {
2647            None
2648        } else {
2649            Some(
2650                columns
2651                    .iter()
2652                    .map(|ident| Expr::Identifier(vec![ident.clone()]))
2653                    .collect(),
2654            )
2655        },
2656        with_options: Vec::new(),
2657        if_not_exists: false,
2658    })
2659    .to_ast_string_stable();
2660
2661    // Assert if DISTINCT semantics are unchanged from view
2662    let distinct_unneeded = extra_columns.len() == 0
2663        || match distinct {
2664            None | Some(Distinct::On(_)) => true,
2665            Some(Distinct::EntireRow) => false,
2666        };
2667    let distinct = if distinct_unneeded { None } else { distinct };
2668
2669    // `SELECT [* | {projection}] FROM {name} [ORDER BY {view_order_by}]`
2670    let view_sql = AstStatement::<Raw>::Select(SelectStatement {
2671        query: Query {
2672            ctes: CteBlock::Simple(vec![]),
2673            body: SetExpr::Select(Box::new(Select {
2674                distinct,
2675                projection: if projection.len() == 0 {
2676                    vec![SelectItem::Wildcard]
2677                } else {
2678                    projection
2679                        .iter()
2680                        .map(|ident| SelectItem::Expr {
2681                            expr: Expr::Identifier(vec![ident.clone()]),
2682                            alias: None,
2683                        })
2684                        .collect()
2685                },
2686                from: vec![TableWithJoins {
2687                    relation: TableFactor::Table {
2688                        name: RawItemName::Name(name.clone()),
2689                        alias: None,
2690                    },
2691                    joins: vec![],
2692                }],
2693                selection: None,
2694                group_by: vec![],
2695                having: None,
2696                qualify: None,
2697                options: vec![],
2698            })),
2699            order_by: view_order_by,
2700            limit: None,
2701            offset: None,
2702        },
2703        as_of: query_as_of.clone(),
2704    })
2705    .to_ast_string_stable();
2706
2707    // `DROP VIEW {name}`
2708    let drop_view = AstStatement::<Raw>::DropObjects(DropObjectsStatement {
2709        object_type: ObjectType::View,
2710        if_exists: false,
2711        names: vec![UnresolvedObjectName::Item(name)],
2712        cascade: false,
2713    })
2714    .to_ast_string_stable();
2715
2716    (create_view, create_index, view_sql, drop_view)
2717}
2718
2719/// Analyzes the provided query `body` to derive the number of
2720/// attributes in the query. We only consider syntactic cues,
2721/// so we may end up deriving `None` for the number of attributes
2722/// as a conservative approximation.
2723fn derive_num_attributes(body: &SetExpr<Raw>) -> Option<usize> {
2724    let Some((projection, _)) = find_projection(body) else {
2725        return None;
2726    };
2727    derive_num_attributes_from_projection(projection)
2728}
2729
2730/// Analyzes a query's `ORDER BY` clause to derive an `ORDER BY`
2731/// clause that makes numeric references to any expressions in
2732/// the projection and generated-attribute references to expressions
2733/// that need to be added as extra columns to the projection list.
2734/// The rewritten `ORDER BY` clause is then usable when querying a
2735/// view that contains the same `SELECT` as the given query.
2736/// This function returns both the rewritten `ORDER BY` clause
2737/// as well as a list of extra columns that need to be added
2738/// to the query's projection for the `ORDER BY` clause to
2739/// succeed.
2740fn derive_order_by(
2741    body: &SetExpr<Raw>,
2742    order_by: &Vec<OrderByExpr<Raw>>,
2743) -> (
2744    Vec<OrderByExpr<Raw>>,
2745    Vec<SelectItem<Raw>>,
2746    Option<Distinct<Raw>>,
2747) {
2748    let Some((projection, distinct)) = find_projection(body) else {
2749        return (vec![], vec![], None);
2750    };
2751    let (view_order_by, extra_columns) = derive_order_by_from_projection(projection, order_by);
2752    (view_order_by, extra_columns, distinct.clone())
2753}
2754
2755/// Finds the projection list in a `SELECT` query body.
2756fn find_projection(body: &SetExpr<Raw>) -> Option<(&Vec<SelectItem<Raw>>, &Option<Distinct<Raw>>)> {
2757    // Iterate to peel off the query body until the query's
2758    // projection list is found.
2759    let mut set_expr = body;
2760    loop {
2761        match set_expr {
2762            SetExpr::Select(select) => {
2763                return Some((&select.projection, &select.distinct));
2764            }
2765            SetExpr::SetOperation { left, .. } => set_expr = left.as_ref(),
2766            SetExpr::Query(query) => set_expr = &query.body,
2767            _ => return None,
2768        }
2769    }
2770}
2771
2772/// Computes the number of attributes that are obtained by the
2773/// projection of a `SELECT` query. The projection may include
2774/// wildcards, in which case the analysis just returns `None`.
2775fn derive_num_attributes_from_projection(projection: &Vec<SelectItem<Raw>>) -> Option<usize> {
2776    let mut num_attributes = 0usize;
2777    for item in projection.iter() {
2778        let SelectItem::Expr { expr, .. } = item else {
2779            return None;
2780        };
2781        match expr {
2782            Expr::QualifiedWildcard(..) | Expr::WildcardAccess(..) => {
2783                return None;
2784            }
2785            _ => {
2786                num_attributes += 1;
2787            }
2788        }
2789    }
2790    Some(num_attributes)
2791}
2792
2793/// Computes an `ORDER BY` clause with only numeric references
2794/// from given projection and `ORDER BY` of a `SELECT` query.
2795/// If the derivation fails to match a given expression, the
2796/// matched prefix is returned. Note that this could be empty.
2797fn derive_order_by_from_projection(
2798    projection: &Vec<SelectItem<Raw>>,
2799    order_by: &Vec<OrderByExpr<Raw>>,
2800) -> (Vec<OrderByExpr<Raw>>, Vec<SelectItem<Raw>>) {
2801    let mut view_order_by: Vec<OrderByExpr<Raw>> = vec![];
2802    let mut extra_columns: Vec<SelectItem<Raw>> = vec![];
2803    for order_by_expr in order_by.iter() {
2804        let query_expr = &order_by_expr.expr;
2805        let view_expr = match query_expr {
2806            Expr::Value(mz_sql_parser::ast::Value::Number(_)) => query_expr.clone(),
2807            _ => {
2808                // Find expression in query projection, if we can.
2809                if let Some(i) = projection.iter().position(|item| match item {
2810                    SelectItem::Expr { expr, alias } => {
2811                        expr == query_expr
2812                            || match query_expr {
2813                                Expr::Identifier(ident) => {
2814                                    ident.len() == 1 && Some(&ident[0]) == alias.as_ref()
2815                                }
2816                                _ => false,
2817                            }
2818                    }
2819                    SelectItem::Wildcard => false,
2820                }) {
2821                    Expr::Value(mz_sql_parser::ast::Value::Number((i + 1).to_string()))
2822                } else {
2823                    // If the expression is not found in the
2824                    // projection, add extra column.
2825                    let ident = Ident::new_unchecked(format!(
2826                        "a{}",
2827                        (projection.len() + extra_columns.len() + 1)
2828                    ));
2829                    extra_columns.push(SelectItem::Expr {
2830                        expr: query_expr.clone(),
2831                        alias: Some(ident.clone()),
2832                    });
2833                    Expr::Identifier(vec![ident])
2834                }
2835            }
2836        };
2837        view_order_by.push(OrderByExpr {
2838            expr: view_expr,
2839            asc: order_by_expr.asc,
2840            nulls_last: order_by_expr.nulls_last,
2841        });
2842    }
2843    (view_order_by, extra_columns)
2844}
2845
2846/// Returns extra statements to execute after `stmt` is executed.
2847fn mutate(sql: &str) -> Vec<String> {
2848    let stmts = parser::parse_statements(sql).unwrap_or_default();
2849    let mut additional = Vec::new();
2850    for stmt in stmts {
2851        match stmt.ast {
2852            AstStatement::CreateTable(stmt) => additional.push(
2853                // CREATE TABLE -> CREATE INDEX. Specify all columns manually in case CREATE
2854                // DEFAULT INDEX ever goes away.
2855                AstStatement::<Raw>::CreateIndex(CreateIndexStatement {
2856                    name: None,
2857                    in_cluster: None,
2858                    on_name: RawItemName::Name(stmt.name.clone()),
2859                    key_parts: Some(
2860                        stmt.columns
2861                            .iter()
2862                            .map(|def| Expr::Identifier(vec![def.name.clone()]))
2863                            .collect(),
2864                    ),
2865                    with_options: Vec::new(),
2866                    if_not_exists: false,
2867                })
2868                .to_ast_string_stable(),
2869            ),
2870            _ => {}
2871        }
2872    }
2873    additional
2874}
2875
2876#[mz_ore::test]
2877#[cfg_attr(miri, ignore)] // unsupported operation: can't call foreign function `rust_psm_stack_pointer` on OS `linux`
2878fn test_generate_view_sql() {
2879    let uuid = Uuid::parse_str("67e5504410b1426f9247bb680e5fe0c8").unwrap();
2880    let cases = vec![
2881        (("SELECT * FROM t", None, None),
2882        (
2883            r#"CREATE VIEW "v67e5504410b1426f9247bb680e5fe0c8" AS SELECT * FROM "t""#.to_string(),
2884            r#"CREATE DEFAULT INDEX ON "v67e5504410b1426f9247bb680e5fe0c8""#.to_string(),
2885            r#"SELECT * FROM "v67e5504410b1426f9247bb680e5fe0c8""#.to_string(),
2886            r#"DROP VIEW "v67e5504410b1426f9247bb680e5fe0c8""#.to_string(),
2887        )),
2888        (("SELECT a, b, c FROM t1, t2", Some(3), Some(vec![ColumnName::from("a"), ColumnName::from("b"), ColumnName::from("c")])),
2889        (
2890            r#"CREATE VIEW "v67e5504410b1426f9247bb680e5fe0c8" ("a", "b", "c") AS SELECT "a", "b", "c" FROM "t1", "t2""#.to_string(),
2891            r#"CREATE INDEX ON "v67e5504410b1426f9247bb680e5fe0c8" ("a", "b", "c")"#.to_string(),
2892            r#"SELECT "a", "b", "c" FROM "v67e5504410b1426f9247bb680e5fe0c8""#.to_string(),
2893            r#"DROP VIEW "v67e5504410b1426f9247bb680e5fe0c8""#.to_string(),
2894        )),
2895        (("SELECT a, b, c FROM t1, t2", Some(3), None),
2896        (
2897            r#"CREATE VIEW "v67e5504410b1426f9247bb680e5fe0c8" ("a1", "a2", "a3") AS SELECT "a", "b", "c" FROM "t1", "t2""#.to_string(),
2898            r#"CREATE INDEX ON "v67e5504410b1426f9247bb680e5fe0c8" ("a1", "a2", "a3")"#.to_string(),
2899            r#"SELECT "a1", "a2", "a3" FROM "v67e5504410b1426f9247bb680e5fe0c8""#.to_string(),
2900            r#"DROP VIEW "v67e5504410b1426f9247bb680e5fe0c8""#.to_string(),
2901        )),
2902        // A case with ambiguity that is accepted by the function, illustrating that
2903        // our measures to dodge this issue are imperfect.
2904        (("SELECT * FROM (SELECT a, sum(b) AS a FROM t GROUP BY a)", None, None),
2905        (
2906            r#"CREATE VIEW "v67e5504410b1426f9247bb680e5fe0c8" AS SELECT * FROM (SELECT "a", "sum"("b") AS "a" FROM "t" GROUP BY "a")"#.to_string(),
2907            r#"CREATE DEFAULT INDEX ON "v67e5504410b1426f9247bb680e5fe0c8""#.to_string(),
2908            r#"SELECT * FROM "v67e5504410b1426f9247bb680e5fe0c8""#.to_string(),
2909            r#"DROP VIEW "v67e5504410b1426f9247bb680e5fe0c8""#.to_string(),
2910        )),
2911        (("SELECT a, b, b + d AS c, a + b AS d FROM t1, t2 ORDER BY a, c, a + b", Some(4), Some(vec![ColumnName::from("a"), ColumnName::from("b"), ColumnName::from("c"), ColumnName::from("d")])),
2912        (
2913            r#"CREATE VIEW "v67e5504410b1426f9247bb680e5fe0c8" ("a", "b", "c", "d") AS SELECT "a", "b", "b" + "d" AS "c", "a" + "b" AS "d" FROM "t1", "t2" ORDER BY "a", "c", "a" + "b""#.to_string(),
2914            r#"CREATE INDEX ON "v67e5504410b1426f9247bb680e5fe0c8" ("a", "b", "c", "d")"#.to_string(),
2915            r#"SELECT "a", "b", "c", "d" FROM "v67e5504410b1426f9247bb680e5fe0c8" ORDER BY 1, 3, 4"#.to_string(),
2916            r#"DROP VIEW "v67e5504410b1426f9247bb680e5fe0c8""#.to_string(),
2917        )),
2918        (("((SELECT 1 AS a UNION SELECT 2 AS b) UNION SELECT 3 AS c) ORDER BY a", Some(1), None),
2919        (
2920            r#"CREATE VIEW "v67e5504410b1426f9247bb680e5fe0c8" ("a1") AS (SELECT 1 AS "a" UNION SELECT 2 AS "b") UNION SELECT 3 AS "c" ORDER BY "a""#.to_string(),
2921            r#"CREATE INDEX ON "v67e5504410b1426f9247bb680e5fe0c8" ("a1")"#.to_string(),
2922            r#"SELECT "a1" FROM "v67e5504410b1426f9247bb680e5fe0c8" ORDER BY 1"#.to_string(),
2923            r#"DROP VIEW "v67e5504410b1426f9247bb680e5fe0c8""#.to_string(),
2924        )),
2925        (("SELECT * FROM (SELECT a, sum(b) AS a FROM t GROUP BY a) ORDER BY 1", None, None),
2926        (
2927            r#"CREATE VIEW "v67e5504410b1426f9247bb680e5fe0c8" AS SELECT * FROM (SELECT "a", "sum"("b") AS "a" FROM "t" GROUP BY "a") ORDER BY 1"#.to_string(),
2928            r#"CREATE DEFAULT INDEX ON "v67e5504410b1426f9247bb680e5fe0c8""#.to_string(),
2929            r#"SELECT * FROM "v67e5504410b1426f9247bb680e5fe0c8" ORDER BY 1"#.to_string(),
2930            r#"DROP VIEW "v67e5504410b1426f9247bb680e5fe0c8""#.to_string(),
2931        )),
2932        (("SELECT * FROM (SELECT a, sum(b) AS a FROM t GROUP BY a) ORDER BY a", None, None),
2933        (
2934            r#"CREATE VIEW "v67e5504410b1426f9247bb680e5fe0c8" AS SELECT * FROM (SELECT "a", "sum"("b") AS "a" FROM "t" GROUP BY "a") ORDER BY "a""#.to_string(),
2935            r#"CREATE DEFAULT INDEX ON "v67e5504410b1426f9247bb680e5fe0c8""#.to_string(),
2936            r#"SELECT * FROM "v67e5504410b1426f9247bb680e5fe0c8" ORDER BY "a""#.to_string(),
2937            r#"DROP VIEW "v67e5504410b1426f9247bb680e5fe0c8""#.to_string(),
2938        )),
2939        (("SELECT a, sum(b) AS a FROM t GROUP BY a, c ORDER BY a, c", Some(2), None),
2940        (
2941            r#"CREATE VIEW "v67e5504410b1426f9247bb680e5fe0c8" ("a1", "a2", "a3") AS SELECT "a", "sum"("b") AS "a", "c" AS "a3" FROM "t" GROUP BY "a", "c" ORDER BY "a", "c""#.to_string(),
2942            r#"CREATE INDEX ON "v67e5504410b1426f9247bb680e5fe0c8" ("a1", "a2", "a3")"#.to_string(),
2943            r#"SELECT "a1", "a2" FROM "v67e5504410b1426f9247bb680e5fe0c8" ORDER BY 1, "a3""#.to_string(),
2944            r#"DROP VIEW "v67e5504410b1426f9247bb680e5fe0c8""#.to_string(),
2945        )),
2946        (("SELECT a, sum(b) AS a FROM t GROUP BY a, c ORDER BY c, a", Some(2), None),
2947        (
2948            r#"CREATE VIEW "v67e5504410b1426f9247bb680e5fe0c8" ("a1", "a2", "a3") AS SELECT "a", "sum"("b") AS "a", "c" AS "a3" FROM "t" GROUP BY "a", "c" ORDER BY "c", "a""#.to_string(),
2949            r#"CREATE INDEX ON "v67e5504410b1426f9247bb680e5fe0c8" ("a1", "a2", "a3")"#.to_string(),
2950            r#"SELECT "a1", "a2" FROM "v67e5504410b1426f9247bb680e5fe0c8" ORDER BY "a3", 1"#.to_string(),
2951            r#"DROP VIEW "v67e5504410b1426f9247bb680e5fe0c8""#.to_string(),
2952        )),
2953    ];
2954    for ((sql, num_attributes, expected_column_names), expected) in cases {
2955        let view_sql =
2956            generate_view_sql(sql, uuid.as_simple(), num_attributes, expected_column_names);
2957        assert_eq!(expected, view_sql);
2958    }
2959}
2960
2961#[mz_ore::test]
2962fn test_mutate() {
2963    let cases = vec![
2964        ("CREATE TABLE t ()", vec![r#"CREATE INDEX ON "t" ()"#]),
2965        (
2966            "CREATE TABLE t (a INT)",
2967            vec![r#"CREATE INDEX ON "t" ("a")"#],
2968        ),
2969        (
2970            "CREATE TABLE t (a INT, b TEXT)",
2971            vec![r#"CREATE INDEX ON "t" ("a", "b")"#],
2972        ),
2973        // Invalid syntax works, just returns nothing.
2974        ("BAD SYNTAX", Vec::new()),
2975    ];
2976    for (sql, expected) in cases {
2977        let stmts = mutate(sql);
2978        assert_eq!(expected, stmts, "sql: {sql}");
2979    }
2980}