Skip to main content

mz_sqllogictest/
runner.rs

1// Copyright Materialize, Inc. and contributors. All rights reserved.
2//
3// Use of this software is governed by the Business Source License
4// included in the LICENSE file.
5//
6// As of the Change Date specified in that file, in accordance with
7// the Business Source License, use of this software will be governed
8// by the Apache License, Version 2.0.
9
10//! The Materialize-specific runner for sqllogictest.
11//!
12//! slt tests expect a serialized execution of sql statements and queries.
13//! To get the same results in materialize we track current_timestamp and increment it whenever we execute a statement.
14//!
15//! The high-level workflow is:
16//!   for each record in the test file:
17//!     if record is a sql statement:
18//!       run sql in postgres, observe changes and copy them to materialize using LocalInput::Updates(..)
19//!       advance current_timestamp
20//!       promise to never send updates for times < current_timestamp using LocalInput::Watermark(..)
21//!       compare to expected results
22//!       if wrong, bail out and stop processing this file
23//!     if record is a sql query:
24//!       peek query at current_timestamp
25//!       compare to expected results
26//!       if wrong, record the error
27
28use std::collections::BTreeMap;
29use std::error::Error;
30use std::fs::{File, OpenOptions};
31use std::io::{Read, Seek, SeekFrom, Write};
32use std::net::{IpAddr, Ipv4Addr, SocketAddr};
33use std::path::Path;
34use std::sync::Arc;
35use std::sync::LazyLock;
36use std::time::Duration;
37use std::{env, fmt, ops, str, thread};
38
39use anyhow::{anyhow, bail};
40use bytes::BytesMut;
41use chrono::{DateTime, NaiveDateTime, NaiveTime, Utc};
42use fallible_iterator::FallibleIterator;
43use futures::sink::SinkExt;
44use itertools::Itertools;
45use maplit::btreemap;
46use md5::{Digest, Md5};
47use mz_adapter_types::bootstrap_builtin_cluster_config::{
48    ANALYTICS_CLUSTER_DEFAULT_REPLICATION_FACTOR, BootstrapBuiltinClusterConfig,
49    CATALOG_SERVER_CLUSTER_DEFAULT_REPLICATION_FACTOR, PROBE_CLUSTER_DEFAULT_REPLICATION_FACTOR,
50    SUPPORT_CLUSTER_DEFAULT_REPLICATION_FACTOR, SYSTEM_CLUSTER_DEFAULT_REPLICATION_FACTOR,
51};
52use mz_catalog::config::ClusterReplicaSizeMap;
53use mz_controller::{ControllerConfig, ReplicaHttpLocator};
54use mz_environmentd::CatalogConfig;
55use mz_license_keys::ValidatedLicenseKey;
56use mz_orchestrator_process::{ProcessOrchestrator, ProcessOrchestratorConfig};
57use mz_orchestrator_tracing::{TracingCliArgs, TracingOrchestrator};
58use mz_ore::cast::{CastFrom, ReinterpretCast};
59use mz_ore::channel::trigger;
60use mz_ore::error::ErrorExt;
61use mz_ore::metrics::MetricsRegistry;
62use mz_ore::now::SYSTEM_TIME;
63use mz_ore::retry::Retry;
64use mz_ore::task;
65use mz_ore::thread::{JoinHandleExt, JoinOnDropHandle};
66use mz_ore::tracing::TracingHandle;
67use mz_ore::url::SensitiveUrl;
68use mz_persist_client::PersistLocation;
69use mz_persist_client::cache::PersistClientCache;
70use mz_persist_client::cfg::PersistConfig;
71use mz_persist_client::rpc::{
72    MetricsSameProcessPubSubSender, PersistGrpcPubSubServer, PubSubClientConnection, PubSubSender,
73};
74use mz_pgrepr::{Interval, Jsonb, Numeric, UInt2, UInt4, UInt8, Value, oid};
75use mz_repr::ColumnName;
76use mz_repr::adt::date::Date;
77use mz_repr::adt::mz_acl_item::{AclItem, MzAclItem};
78use mz_repr::adt::numeric;
79use mz_secrets::SecretsController;
80use mz_server_core::listeners::{
81    AllowedRoles, AuthenticatorKind, BaseListenerConfig, HttpListenerConfig, HttpRoutesEnabled,
82    ListenersConfig, SqlListenerConfig,
83};
84use mz_sql::ast::{Expr, Raw, Statement};
85use mz_sql::catalog::EnvironmentId;
86use mz_sql_parser::ast::display::AstDisplay;
87use mz_sql_parser::ast::{
88    CreateIndexStatement, CreateViewStatement, CteBlock, Distinct, DropObjectsStatement, Ident,
89    IfExistsBehavior, ObjectType, OrderByExpr, Query, RawItemName, Select, SelectItem,
90    SelectStatement, SetExpr, Statement as AstStatement, TableFactor, TableWithJoins,
91    UnresolvedItemName, UnresolvedObjectName, ViewDefinition,
92};
93use mz_sql_parser::parser;
94use mz_storage_types::connections::ConnectionContext;
95use postgres_protocol::types;
96use regex::Regex;
97use tempfile::TempDir;
98use tokio::net::TcpListener;
99use tokio::runtime::Runtime;
100use tokio::sync::oneshot;
101use tokio_postgres::types::{FromSql, Kind as PgKind, Type as PgType};
102use tokio_postgres::{NoTls, Row, SimpleQueryMessage};
103use tokio_stream::wrappers::TcpListenerStream;
104use tower_http::cors::AllowOrigin;
105use tracing::{error, info};
106use uuid::Uuid;
107use uuid::fmt::Simple;
108
109use crate::ast::{Location, Mode, Output, QueryOutput, Record, Sort, Type};
110use crate::util;
111
/// The result of running a single sqllogictest record.
///
/// The `'a` lifetime covers data borrowed from the record that produced the
/// outcome (expected error text, expected column names, expected output).
/// Each variant maps to a fixed counter index via `Outcome::code`.
#[derive(Debug)]
pub enum Outcome<'a> {
    /// The record was rejected as unsupported; `error` carries the reason.
    Unsupported {
        error: anyhow::Error,
        location: Location,
    },
    /// An error attributed to parsing; `error` carries the details.
    ParseFailure {
        error: anyhow::Error,
        location: Location,
    },
    /// An error attributed to planning/executing; `error` carries the details.
    PlanFailure {
        error: anyhow::Error,
        location: Location,
    },
    /// The record expected an error matching `expected_error`, but succeeded.
    UnexpectedPlanSuccess {
        expected_error: &'a str,
        location: Location,
    },
    /// A statement reported a different row count than the record expected.
    WrongNumberOfRowsInserted {
        expected_count: u64,
        actual_count: u64,
        location: Location,
    },
    /// A query returned a different number of columns than the record expected.
    WrongColumnCount {
        expected_count: usize,
        actual_count: usize,
        location: Location,
    },
    /// A query returned different column names than the record expected.
    ///
    /// `actual_output` is carried along but not printed by the `Display` impl.
    WrongColumnNames {
        expected_column_names: &'a Vec<ColumnName>,
        actual_column_names: Vec<ColumnName>,
        actual_output: Output,
        location: Location,
    },
    /// A query's formatted output differed from the expected output.
    ///
    /// Both the formatted output and the raw rows are kept for diagnostics.
    OutputFailure {
        expected_output: &'a Output,
        actual_raw_output: Vec<Row>,
        actual_output: Output,
        location: Location,
    },
    /// The outcome of a query and the outcome of the same query run through
    /// an indexed view disagreed.
    InconsistentViewOutcome {
        query_outcome: Box<Outcome<'a>>,
        view_outcome: Box<Outcome<'a>>,
        location: Location,
    },
    /// Wraps the outcome in `cause`.
    // NOTE(review): presumably emitted when processing of the file is
    // abandoned after `cause` (see the module docs) — confirm against callers.
    Bail {
        cause: Box<Outcome<'a>>,
        location: Location,
    },
    /// Wraps an outcome that is reported but, per `Outcome::failure` and
    /// `Outcomes::any_failed`, is not counted as a failure.
    Warning {
        cause: Box<Outcome<'a>>,
        location: Location,
    },
    /// The record behaved as expected.
    Success,
}
167
/// Number of `Outcome` variants; the length of the per-variant counter array
/// in `Outcomes::stats`.
const NUM_OUTCOMES: usize = 12;
/// Counter index of `Outcome::Warning` (agrees with `Outcome::code`).
const WARNING_OUTCOME: usize = NUM_OUTCOMES - 2;
/// Counter index of `Outcome::Success` (agrees with `Outcome::code`).
const SUCCESS_OUTCOME: usize = NUM_OUTCOMES - 1;
171
172impl<'a> Outcome<'a> {
173    fn code(&self) -> usize {
174        match self {
175            Outcome::Unsupported { .. } => 0,
176            Outcome::ParseFailure { .. } => 1,
177            Outcome::PlanFailure { .. } => 2,
178            Outcome::UnexpectedPlanSuccess { .. } => 3,
179            Outcome::WrongNumberOfRowsInserted { .. } => 4,
180            Outcome::WrongColumnCount { .. } => 5,
181            Outcome::WrongColumnNames { .. } => 6,
182            Outcome::OutputFailure { .. } => 7,
183            Outcome::InconsistentViewOutcome { .. } => 8,
184            Outcome::Bail { .. } => 9,
185            Outcome::Warning { .. } => 10,
186            Outcome::Success => 11,
187        }
188    }
189
190    fn success(&self) -> bool {
191        matches!(self, Outcome::Success)
192    }
193
194    fn failure(&self) -> bool {
195        !matches!(self, Outcome::Success) && !matches!(self, Outcome::Warning { .. })
196    }
197
198    /// Returns an error message that will match self. Appropriate for
199    /// rewriting error messages (i.e. not inserting error messages where we
200    /// currently expect success).
201    fn err_msg(&self) -> Option<String> {
202        match self {
203            Outcome::Unsupported { error, .. }
204            | Outcome::ParseFailure { error, .. }
205            | Outcome::PlanFailure { error, .. } => Some(
206                // This value gets fed back into regex to check that it matches
207                // `self`, so escape its meta characters.
208                regex::escape(
209                    // Take only the first string in the error message, which should be
210                    // sufficient for meaningfully matching the error.
211                    error.to_string_with_causes().split('\n').next().unwrap(),
212                )
213                // We need to undo the escaping of #. `regex::escape` escapes this because it
214                // expects that we use the `x` flag when building a regex, but this is not the case,
215                // so \# would end up being an invalid escape sequence, which would choke the
216                // parsing of the slt file the next time around.
217                .replace(r"\#", "#"),
218            ),
219            _ => None,
220        }
221    }
222}
223
224impl fmt::Display for Outcome<'_> {
225    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
226        use Outcome::*;
227        const INDENT: &str = "\n        ";
228        match self {
229            Unsupported { error, location } => write!(
230                f,
231                "Unsupported:{}:\n{}",
232                location,
233                error.display_with_causes()
234            ),
235            ParseFailure { error, location } => {
236                write!(
237                    f,
238                    "ParseFailure:{}:\n{}",
239                    location,
240                    error.display_with_causes()
241                )
242            }
243            PlanFailure { error, location } => write!(f, "PlanFailure:{}:\n{:#}", location, error),
244            UnexpectedPlanSuccess {
245                expected_error,
246                location,
247            } => write!(
248                f,
249                "UnexpectedPlanSuccess:{} expected error: {}",
250                location, expected_error
251            ),
252            WrongNumberOfRowsInserted {
253                expected_count,
254                actual_count,
255                location,
256            } => write!(
257                f,
258                "WrongNumberOfRowsInserted:{}{}expected: {}{}actually: {}",
259                location, INDENT, expected_count, INDENT, actual_count
260            ),
261            WrongColumnCount {
262                expected_count,
263                actual_count,
264                location,
265            } => write!(
266                f,
267                "WrongColumnCount:{}{}expected: {}{}actually: {}",
268                location, INDENT, expected_count, INDENT, actual_count
269            ),
270            WrongColumnNames {
271                expected_column_names,
272                actual_column_names,
273                actual_output: _,
274                location,
275            } => write!(
276                f,
277                "Wrong Column Names:{}:{}expected column names: {}{}inferred column names: {}",
278                location,
279                INDENT,
280                expected_column_names
281                    .iter()
282                    .map(|n| n.to_string())
283                    .collect::<Vec<_>>()
284                    .join(" "),
285                INDENT,
286                actual_column_names
287                    .iter()
288                    .map(|n| n.to_string())
289                    .collect::<Vec<_>>()
290                    .join(" ")
291            ),
292            OutputFailure {
293                expected_output,
294                actual_raw_output,
295                actual_output,
296                location,
297            } => write!(
298                f,
299                "OutputFailure:{}{}expected: {:?}{}actually: {:?}{}actual raw: {:?}",
300                location, INDENT, expected_output, INDENT, actual_output, INDENT, actual_raw_output
301            ),
302            InconsistentViewOutcome {
303                query_outcome,
304                view_outcome,
305                location,
306            } => write!(
307                f,
308                "InconsistentViewOutcome:{}{}expected from query: {:?}{}actually from indexed view: {:?}{}",
309                location, INDENT, query_outcome, INDENT, view_outcome, INDENT
310            ),
311            Bail { cause, location } => write!(f, "Bail:{} {}", location, cause),
312            Warning { cause, location } => write!(f, "Warning:{} {}", location, cause),
313            Success => f.write_str("Success"),
314        }
315    }
316}
317
/// Aggregated results for a set of records.
#[derive(Default, Debug)]
pub struct Outcomes {
    /// Number of outcomes seen per variant, indexed by `Outcome::code`.
    stats: [usize; NUM_OUTCOMES],
    /// Messages printed line by line by `OutcomesDisplay` when failure
    /// details are requested.
    details: Vec<String>,
}
323
324impl ops::AddAssign<Outcomes> for Outcomes {
325    fn add_assign(&mut self, rhs: Outcomes) {
326        for (lhs, rhs) in self.stats.iter_mut().zip_eq(rhs.stats.iter()) {
327            *lhs += rhs
328        }
329    }
330}
331impl Outcomes {
332    pub fn any_failed(&self) -> bool {
333        self.stats[SUCCESS_OUTCOME] + self.stats[WARNING_OUTCOME] < self.stats.iter().sum::<usize>()
334    }
335
336    pub fn as_json(&self) -> serde_json::Value {
337        serde_json::json!({
338            "unsupported": self.stats[0],
339            "parse_failure": self.stats[1],
340            "plan_failure": self.stats[2],
341            "unexpected_plan_success": self.stats[3],
342            "wrong_number_of_rows_affected": self.stats[4],
343            "wrong_column_count": self.stats[5],
344            "wrong_column_names": self.stats[6],
345            "output_failure": self.stats[7],
346            "inconsistent_view_outcome": self.stats[8],
347            "bail": self.stats[9],
348            "warning": self.stats[10],
349            "success": self.stats[11],
350        })
351    }
352
353    pub fn display(&self, no_fail: bool, failure_details: bool) -> OutcomesDisplay<'_> {
354        OutcomesDisplay {
355            inner: self,
356            no_fail,
357            failure_details,
358        }
359    }
360}
361
/// `Display` adapter for `Outcomes`, created by `Outcomes::display`.
pub struct OutcomesDisplay<'a> {
    /// The outcomes being rendered.
    inner: &'a Outcomes,
    /// When set, failures are labelled `FAIL-IGNORE` instead of `FAIL` in the
    /// summary, and the detail listing is printed even when nothing failed.
    no_fail: bool,
    /// When set, the recorded detail lines are printed instead of the
    /// one-line summary (if anything failed, or if `no_fail` is set).
    failure_details: bool,
}
367
368impl<'a> fmt::Display for OutcomesDisplay<'a> {
369    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
370        let total: usize = self.inner.stats.iter().sum();
371        if self.failure_details
372            && (self.inner.stats[SUCCESS_OUTCOME] + self.inner.stats[WARNING_OUTCOME] != total
373                || self.no_fail)
374        {
375            for outcome in &self.inner.details {
376                writeln!(f, "{}", outcome)?;
377            }
378            Ok(())
379        } else {
380            write!(
381                f,
382                "{}:",
383                if self.inner.stats[SUCCESS_OUTCOME] + self.inner.stats[WARNING_OUTCOME] == total {
384                    "PASS"
385                } else if self.no_fail {
386                    "FAIL-IGNORE"
387                } else {
388                    "FAIL"
389                }
390            )?;
391            static NAMES: LazyLock<Vec<&'static str>> = LazyLock::new(|| {
392                vec![
393                    "unsupported",
394                    "parse-failure",
395                    "plan-failure",
396                    "unexpected-plan-success",
397                    "wrong-number-of-rows-inserted",
398                    "wrong-column-count",
399                    "wrong-column-names",
400                    "output-failure",
401                    "inconsistent-view-outcome",
402                    "bail",
403                    "warning",
404                    "success",
405                    "total",
406                ]
407            });
408            for (i, n) in self.inner.stats.iter().enumerate() {
409                if *n > 0 {
410                    write!(f, " {}={}", NAMES[i], n)?;
411                }
412            }
413            write!(f, " total={}", total)
414        }
415    }
416}
417
/// Information gathered about a query while preparing it (carried by
/// `PrepareQueryOutcome::QueryPrepared`).
struct QueryInfo {
    // NOTE(review): presumably true when the statement is a `SELECT` — confirm
    // against the code that constructs this struct.
    is_select: bool,
    // NOTE(review): presumably the number of output columns when it could be
    // determined, `None` otherwise — confirm against the constructing code.
    num_attributes: Option<usize>,
}
422
/// Result of preparing a query: either the query information needed to go on,
/// or an `Outcome` to report for the record.
enum PrepareQueryOutcome<'a> {
    /// Preparation produced the information in `QueryInfo`.
    QueryPrepared(QueryInfo),
    /// Preparation already determined the record's outcome.
    Outcome(Outcome<'a>),
}
427
/// Runs sqllogictest records against a server that can be restarted.
///
/// The actual work is delegated to `inner`, which `Runner::reset` tears down
/// and recreates.
pub struct Runner<'a> {
    /// Configuration used every time a new `RunnerInner` is started.
    config: &'a RunConfig<'a>,
    /// The currently running instance. It is `None` only before the first
    /// `reset` completes or if a `reset` failed to start a replacement.
    inner: Option<RunnerInner<'a>>,
}
432
/// State for one running server instance together with the client
/// connections and settings used to run records against it.
pub struct RunnerInner<'a> {
    server_addr: SocketAddr,
    internal_server_addr: SocketAddr,
    password_server_addr: SocketAddr,
    internal_http_server_addr: SocketAddr,
    // Drop order matters for these fields.
    // Rust drops struct fields in declaration order, so the clients below are
    // dropped before `_shutdown_trigger`, `_server_thread` and `_temp_dir`.
    client: tokio_postgres::Client,
    system_client: tokio_postgres::Client,
    // Additional client connections, keyed by a string name.
    clients: BTreeMap<String, tokio_postgres::Client>,
    auto_index_tables: bool,
    auto_index_selects: bool,
    auto_transactions: bool,
    enable_table_keys: bool,
    verbosity: u8,
    stdout: &'a dyn WriteFmt,
    // The underscore-prefixed fields are never read; they are held only so
    // that they are dropped together with the runner.
    _shutdown_trigger: trigger::Trigger,
    _server_thread: JoinOnDropHandle<()>,
    _temp_dir: TempDir,
}
452
/// Newtype around `mz_pgrepr::Value` that can be decoded from the binary
/// encoding of a result column via its `FromSql` implementation.
#[derive(Debug)]
pub struct Slt(Value);
455
456impl<'a> FromSql<'a> for Slt {
457    fn from_sql(
458        ty: &PgType,
459        mut raw: &'a [u8],
460    ) -> Result<Self, Box<dyn Error + 'static + Send + Sync>> {
461        Ok(match *ty {
462            PgType::ACLITEM => Self(Value::AclItem(AclItem::decode_binary(
463                types::bytea_from_sql(raw),
464            )?)),
465            PgType::BOOL => Self(Value::Bool(types::bool_from_sql(raw)?)),
466            PgType::BYTEA => Self(Value::Bytea(types::bytea_from_sql(raw).to_vec())),
467            PgType::CHAR => Self(Value::Char(u8::from_be_bytes(
468                types::char_from_sql(raw)?.to_be_bytes(),
469            ))),
470            PgType::FLOAT4 => Self(Value::Float4(types::float4_from_sql(raw)?)),
471            PgType::FLOAT8 => Self(Value::Float8(types::float8_from_sql(raw)?)),
472            PgType::DATE => Self(Value::Date(Date::from_pg_epoch(types::int4_from_sql(
473                raw,
474            )?)?)),
475            PgType::INT2 => Self(Value::Int2(types::int2_from_sql(raw)?)),
476            PgType::INT4 => Self(Value::Int4(types::int4_from_sql(raw)?)),
477            PgType::INT8 => Self(Value::Int8(types::int8_from_sql(raw)?)),
478            PgType::INTERVAL => Self(Value::Interval(Interval::from_sql(ty, raw)?)),
479            PgType::JSONB => Self(Value::Jsonb(Jsonb::from_sql(ty, raw)?)),
480            PgType::NAME => Self(Value::Name(types::text_from_sql(raw)?.to_string())),
481            PgType::NUMERIC => Self(Value::Numeric(Numeric::from_sql(ty, raw)?)),
482            PgType::OID => Self(Value::Oid(types::oid_from_sql(raw)?)),
483            PgType::REGCLASS => Self(Value::Oid(types::oid_from_sql(raw)?)),
484            PgType::REGPROC => Self(Value::Oid(types::oid_from_sql(raw)?)),
485            PgType::REGTYPE => Self(Value::Oid(types::oid_from_sql(raw)?)),
486            PgType::TEXT | PgType::BPCHAR | PgType::VARCHAR => {
487                Self(Value::Text(types::text_from_sql(raw)?.to_string()))
488            }
489            PgType::TIME => Self(Value::Time(NaiveTime::from_sql(ty, raw)?)),
490            PgType::TIMESTAMP => Self(Value::Timestamp(
491                NaiveDateTime::from_sql(ty, raw)?.try_into()?,
492            )),
493            PgType::TIMESTAMPTZ => Self(Value::TimestampTz(
494                DateTime::<Utc>::from_sql(ty, raw)?.try_into()?,
495            )),
496            PgType::UUID => Self(Value::Uuid(Uuid::from_sql(ty, raw)?)),
497            PgType::RECORD => {
498                let num_fields = read_be_i32(&mut raw)?;
499                let mut tuple = vec![];
500                for _ in 0..num_fields {
501                    let oid = u32::reinterpret_cast(read_be_i32(&mut raw)?);
502                    let typ = match PgType::from_oid(oid) {
503                        Some(typ) => typ,
504                        None => return Err("unknown oid".into()),
505                    };
506                    let v = read_value::<Option<Slt>>(&typ, &mut raw)?;
507                    tuple.push(v.map(|v| v.0));
508                }
509                Self(Value::Record(tuple))
510            }
511            PgType::INT4_RANGE
512            | PgType::INT8_RANGE
513            | PgType::DATE_RANGE
514            | PgType::NUM_RANGE
515            | PgType::TS_RANGE
516            | PgType::TSTZ_RANGE => {
517                use mz_repr::adt::range::Range;
518                let range: Range<Slt> = Range::from_sql(ty, raw)?;
519                Self(Value::Range(range.into_bounds(|b| Box::new(b.0))))
520            }
521
522            _ => match ty.kind() {
523                PgKind::Array(arr_type) => {
524                    let arr = types::array_from_sql(raw)?;
525                    let elements: Vec<Option<Value>> = arr
526                        .values()
527                        .map(|v| match v {
528                            Some(v) => Ok(Some(Slt::from_sql(arr_type, v)?)),
529                            None => Ok(None),
530                        })
531                        .collect::<Vec<Option<Slt>>>()?
532                        .into_iter()
533                        // Map a Vec<Option<Slt>> to Vec<Option<Value>>.
534                        .map(|v| v.map(|v| v.0))
535                        .collect();
536
537                    Self(Value::Array {
538                        dims: arr
539                            .dimensions()
540                            .map(|d| {
541                                Ok(mz_repr::adt::array::ArrayDimension {
542                                    lower_bound: isize::cast_from(d.lower_bound),
543                                    length: usize::try_from(d.len)
544                                        .expect("cannot have negative length"),
545                                })
546                            })
547                            .collect()?,
548                        elements,
549                    })
550                }
551                _ => match ty.oid() {
552                    oid::TYPE_UINT2_OID => Self(Value::UInt2(UInt2::from_sql(ty, raw)?)),
553                    oid::TYPE_UINT4_OID => Self(Value::UInt4(UInt4::from_sql(ty, raw)?)),
554                    oid::TYPE_UINT8_OID => Self(Value::UInt8(UInt8::from_sql(ty, raw)?)),
555                    oid::TYPE_MZ_TIMESTAMP_OID => {
556                        let s = types::text_from_sql(raw)?;
557                        let t: mz_repr::Timestamp = s.parse()?;
558                        Self(Value::MzTimestamp(t))
559                    }
560                    oid::TYPE_MZ_ACL_ITEM_OID => Self(Value::MzAclItem(MzAclItem::decode_binary(
561                        types::bytea_from_sql(raw),
562                    )?)),
563                    _ => unreachable!(),
564                },
565            },
566        })
567    }
568    fn accepts(ty: &PgType) -> bool {
569        match ty.kind() {
570            PgKind::Array(_) | PgKind::Composite(_) => return true,
571            _ => {}
572        }
573        match ty.oid() {
574            oid::TYPE_UINT2_OID
575            | oid::TYPE_UINT4_OID
576            | oid::TYPE_UINT8_OID
577            | oid::TYPE_MZ_TIMESTAMP_OID
578            | oid::TYPE_MZ_ACL_ITEM_OID => return true,
579            _ => {}
580        }
581        matches!(
582            *ty,
583            PgType::ACLITEM
584                | PgType::BOOL
585                | PgType::BYTEA
586                | PgType::CHAR
587                | PgType::DATE
588                | PgType::FLOAT4
589                | PgType::FLOAT8
590                | PgType::INT2
591                | PgType::INT4
592                | PgType::INT8
593                | PgType::INTERVAL
594                | PgType::JSONB
595                | PgType::NAME
596                | PgType::NUMERIC
597                | PgType::OID
598                | PgType::REGCLASS
599                | PgType::REGPROC
600                | PgType::REGTYPE
601                | PgType::RECORD
602                | PgType::TEXT
603                | PgType::BPCHAR
604                | PgType::VARCHAR
605                | PgType::TIME
606                | PgType::TIMESTAMP
607                | PgType::TIMESTAMPTZ
608                | PgType::UUID
609                | PgType::INT4_RANGE
610                | PgType::INT4_RANGE_ARRAY
611                | PgType::INT8_RANGE
612                | PgType::INT8_RANGE_ARRAY
613                | PgType::DATE_RANGE
614                | PgType::DATE_RANGE_ARRAY
615                | PgType::NUM_RANGE
616                | PgType::NUM_RANGE_ARRAY
617                | PgType::TS_RANGE
618                | PgType::TS_RANGE_ARRAY
619                | PgType::TSTZ_RANGE
620                | PgType::TSTZ_RANGE_ARRAY
621        )
622    }
623}
624
// From postgres-types/src/private.rs.
/// Reads a big-endian `i32` from the front of `buf` and advances `buf` past
/// the four bytes consumed.
///
/// Returns an error, leaving `buf` untouched, if fewer than four bytes remain.
fn read_be_i32(buf: &mut &[u8]) -> Result<i32, Box<dyn Error + Sync + Send>> {
    match *buf {
        [b0, b1, b2, b3, rest @ ..] => {
            let value = i32::from_be_bytes([*b0, *b1, *b2, *b3]);
            *buf = rest;
            Ok(value)
        }
        _ => Err("invalid buffer size".into()),
    }
}
635
636// From postgres-types/src/private.rs.
637fn read_value<'a, T>(type_: &PgType, buf: &mut &'a [u8]) -> Result<T, Box<dyn Error + Sync + Send>>
638where
639    T: FromSql<'a>,
640{
641    let value = match usize::try_from(read_be_i32(buf)?) {
642        Err(_) => None,
643        Ok(len) => {
644            if len > buf.len() {
645                return Err("invalid buffer size".into());
646            }
647            let (head, tail) = buf.split_at(len);
648            *buf = tail;
649            Some(head)
650        }
651    };
652    T::from_sql_nullable(type_, value)
653}
654
655fn format_datum(d: Slt, typ: &Type, mode: Mode, col: usize) -> String {
656    match (typ, d.0) {
657        (Type::Bool, Value::Bool(b)) => b.to_string(),
658
659        (Type::Integer, Value::Int2(i)) => i.to_string(),
660        (Type::Integer, Value::Int4(i)) => i.to_string(),
661        (Type::Integer, Value::Int8(i)) => i.to_string(),
662        (Type::Integer, Value::UInt2(u)) => u.0.to_string(),
663        (Type::Integer, Value::UInt4(u)) => u.0.to_string(),
664        (Type::Integer, Value::UInt8(u)) => u.0.to_string(),
665        (Type::Integer, Value::Oid(i)) => i.to_string(),
666        // TODO(benesch): rewrite to avoid `as`.
667        #[allow(clippy::as_conversions)]
668        (Type::Integer, Value::Float4(f)) => format!("{}", f as i64),
669        // TODO(benesch): rewrite to avoid `as`.
670        #[allow(clippy::as_conversions)]
671        (Type::Integer, Value::Float8(f)) => format!("{}", f as i64),
672        // This is so wrong, but sqlite needs it.
673        (Type::Integer, Value::Text(_)) => "0".to_string(),
674        (Type::Integer, Value::Bool(b)) => i8::from(b).to_string(),
675        (Type::Integer, Value::Numeric(d)) => {
676            let mut d = d.0.0.clone();
677            let mut cx = numeric::cx_datum();
678            // Truncate the decimal to match sqlite.
679            if mode == Mode::Standard {
680                cx.set_rounding(dec::Rounding::Down);
681            }
682            cx.round(&mut d);
683            numeric::munge_numeric(&mut d).unwrap();
684            d.to_standard_notation_string()
685        }
686
687        (Type::Real, Value::Int2(i)) => format!("{:.3}", i),
688        (Type::Real, Value::Int4(i)) => format!("{:.3}", i),
689        (Type::Real, Value::Int8(i)) => format!("{:.3}", i),
690        (Type::Real, Value::Float4(f)) => match mode {
691            Mode::Standard => format!("{:.3}", f),
692            Mode::Cockroach => format!("{}", f),
693        },
694        (Type::Real, Value::Float8(f)) => match mode {
695            Mode::Standard => format!("{:.3}", f),
696            Mode::Cockroach => format!("{}", f),
697        },
698        (Type::Real, Value::Numeric(d)) => match mode {
699            Mode::Standard => {
700                let mut d = d.0.0.clone();
701                if d.exponent() < -3 {
702                    numeric::rescale(&mut d, 3).unwrap();
703                }
704                numeric::munge_numeric(&mut d).unwrap();
705                d.to_standard_notation_string()
706            }
707            Mode::Cockroach => d.0.0.to_standard_notation_string(),
708        },
709
710        (Type::Text, Value::Text(s)) => {
711            if s.is_empty() {
712                "(empty)".to_string()
713            } else {
714                s
715            }
716        }
717        (Type::Text, Value::Bool(b)) => b.to_string(),
718        (Type::Text, Value::Float4(f)) => format!("{:.3}", f),
719        (Type::Text, Value::Float8(f)) => format!("{:.3}", f),
720        // Bytes are printed as text iff they are valid UTF-8. This
721        // seems guaranteed to confuse everyone, but it is required for
722        // compliance with the CockroachDB sqllogictest runner. [0]
723        //
724        // [0]: https://github.com/cockroachdb/cockroach/blob/970782487/pkg/sql/logictest/logic.go#L2038-L2043
725        (Type::Text, Value::Bytea(b)) => match str::from_utf8(&b) {
726            Ok(s) => s.to_string(),
727            Err(_) => format!("{:?}", b),
728        },
729        (Type::Text, Value::Numeric(d)) => d.0.0.to_standard_notation_string(),
730        // Everything else gets normal text encoding. This correctly handles things
731        // like arrays, tuples, and strings that need to be quoted.
732        (Type::Text, d) => {
733            let mut buf = BytesMut::new();
734            d.encode_text(&mut buf);
735            String::from_utf8_lossy(&buf).into_owned()
736        }
737
738        (Type::Oid, Value::Oid(o)) => o.to_string(),
739
740        (_, d) => panic!(
741            "Don't know how to format {:?} as {:?} in column {}",
742            d, typ, col,
743        ),
744    }
745}
746
747fn format_row(row: &Row, types: &[Type], mode: Mode) -> Vec<String> {
748    let mut formatted: Vec<String> = vec![];
749    for i in 0..row.len() {
750        let t: Option<Slt> = row.get::<usize, Option<Slt>>(i);
751        let t: Option<String> = t.map(|d| format_datum(d, &types[i], mode, i));
752        formatted.push(match t {
753            Some(t) => t,
754            None => "NULL".into(),
755        });
756    }
757
758    formatted
759}
760
761impl<'a> Runner<'a> {
762    pub async fn start(config: &'a RunConfig<'a>) -> Result<Runner<'a>, anyhow::Error> {
763        let mut runner = Self {
764            config,
765            inner: None,
766        };
767        runner.reset().await?;
768        Ok(runner)
769    }
770
771    pub async fn reset(&mut self) -> Result<(), anyhow::Error> {
772        // Explicitly drop the old runner here to ensure that we wait for threads to terminate
773        // before starting a new runner
774        drop(self.inner.take());
775        self.inner = Some(RunnerInner::start(self.config).await?);
776
777        Ok(())
778    }
779
780    async fn run_record<'r>(
781        &mut self,
782        record: &'r Record<'r>,
783        in_transaction: &mut bool,
784    ) -> Result<Outcome<'r>, anyhow::Error> {
785        if let Record::ResetServer = record {
786            self.reset().await?;
787            Ok(Outcome::Success)
788        } else {
789            self.inner
790                .as_mut()
791                .expect("RunnerInner missing")
792                .run_record(record, in_transaction)
793                .await
794        }
795    }
796
797    async fn check_catalog(&self) -> Result<(), anyhow::Error> {
798        self.inner
799            .as_ref()
800            .expect("RunnerInner missing")
801            .check_catalog()
802            .await
803    }
804
805    async fn reset_database(&mut self) -> Result<(), anyhow::Error> {
806        let inner = self.inner.as_mut().expect("RunnerInner missing");
807
808        inner.client.batch_execute("ROLLBACK;").await?;
809
810        inner
811            .system_client
812            .batch_execute(
813                "ROLLBACK;
814                 SET cluster = mz_catalog_server;
815                 RESET cluster_replica;",
816            )
817            .await?;
818
819        inner
820            .system_client
821            .batch_execute("ALTER SYSTEM RESET ALL")
822            .await?;
823
824        // Drop all databases, then recreate the `materialize` database.
825        for row in inner
826            .system_client
827            .query("SELECT name FROM mz_databases", &[])
828            .await?
829        {
830            let name: &str = row.get("name");
831            inner
832                .system_client
833                .batch_execute(&format!("DROP DATABASE \"{name}\""))
834                .await?;
835        }
836        inner
837            .system_client
838            .batch_execute("CREATE DATABASE materialize")
839            .await?;
840
841        // Ensure quickstart cluster exists with one replica of size `self.config.replica_size`.
842        // We don't destroy the existing quickstart cluster replica if it exists, as turning
843        // on a cluster replica is exceptionally slow.
844        let mut needs_default_cluster = true;
845        for row in inner
846            .system_client
847            .query("SELECT name FROM mz_clusters WHERE id LIKE 'u%'", &[])
848            .await?
849        {
850            match row.get("name") {
851                "quickstart" => needs_default_cluster = false,
852                name => {
853                    inner
854                        .system_client
855                        .batch_execute(&format!("DROP CLUSTER {name}"))
856                        .await?
857                }
858            }
859        }
860        if needs_default_cluster {
861            inner
862                .system_client
863                .batch_execute("CREATE CLUSTER quickstart REPLICAS ()")
864                .await?;
865        }
866        let mut needs_default_replica = false;
867        let rows = inner
868            .system_client
869            .query(
870                "SELECT name, size FROM mz_cluster_replicas
871                 WHERE cluster_id = (SELECT id FROM mz_clusters WHERE name = 'quickstart')
872                 ORDER BY name",
873                &[],
874            )
875            .await?;
876        if rows.len() != self.config.replicas {
877            needs_default_replica = true;
878        } else {
879            for (i, row) in rows.iter().enumerate() {
880                let name: &str = row.get("name");
881                let size: &str = row.get("size");
882                if name != format!("r{i}") || size != self.config.replica_size {
883                    needs_default_replica = true;
884                    break;
885                }
886            }
887        }
888
889        if needs_default_replica {
890            inner
891                .system_client
892                .batch_execute("ALTER CLUSTER quickstart SET (MANAGED = false)")
893                .await?;
894            for row in inner
895                .system_client
896                .query(
897                    "SELECT name FROM mz_cluster_replicas
898                     WHERE cluster_id = (SELECT id FROM mz_clusters WHERE name = 'quickstart')",
899                    &[],
900                )
901                .await?
902            {
903                let name: &str = row.get("name");
904                inner
905                    .system_client
906                    .batch_execute(&format!("DROP CLUSTER REPLICA quickstart.{}", name))
907                    .await?;
908            }
909            for i in 1..=self.config.replicas {
910                inner
911                    .system_client
912                    .batch_execute(&format!(
913                        "CREATE CLUSTER REPLICA quickstart.r{i} SIZE '{}'",
914                        self.config.replica_size
915                    ))
916                    .await?;
917            }
918            inner
919                .system_client
920                .batch_execute("ALTER CLUSTER quickstart SET (MANAGED = true)")
921                .await?;
922        }
923
924        // Grant initial privileges.
925        inner
926            .system_client
927            .batch_execute("GRANT USAGE ON DATABASE materialize TO PUBLIC")
928            .await?;
929        inner
930            .system_client
931            .batch_execute("GRANT CREATE ON DATABASE materialize TO materialize")
932            .await?;
933        inner
934            .system_client
935            .batch_execute("GRANT CREATE ON SCHEMA materialize.public TO materialize")
936            .await?;
937        inner
938            .system_client
939            .batch_execute("GRANT USAGE ON CLUSTER quickstart TO PUBLIC")
940            .await?;
941        inner
942            .system_client
943            .batch_execute("GRANT CREATE ON CLUSTER quickstart TO materialize")
944            .await?;
945
946        // Some sqllogictests require more than the default amount of tables, so we increase the
947        // limit for all tests.
948        inner
949            .system_client
950            .simple_query("ALTER SYSTEM SET max_tables = 100")
951            .await?;
952
953        if inner.enable_table_keys {
954            inner
955                .system_client
956                .simple_query("ALTER SYSTEM SET unsafe_enable_table_keys = true")
957                .await?;
958        }
959
960        inner.ensure_fixed_features().await?;
961
962        inner.client = connect(inner.server_addr, None, None).await.unwrap();
963        inner.system_client = connect(inner.internal_server_addr, Some("mz_system"), None)
964            .await
965            .unwrap();
966        inner.clients = BTreeMap::new();
967
968        Ok(())
969    }
970}
971
972impl<'a> RunnerInner<'a> {
973    pub async fn start(config: &RunConfig<'a>) -> Result<RunnerInner<'a>, anyhow::Error> {
974        let temp_dir = tempfile::tempdir()?;
975        let scratch_dir = tempfile::tempdir()?;
976        let environment_id = EnvironmentId::for_tests();
977        let (consensus_uri, timestamp_oracle_url): (SensitiveUrl, SensitiveUrl) = {
978            let postgres_url = &config.postgres_url;
979            let prefix = &config.prefix;
980            info!(%postgres_url, "starting server");
981            let (client, conn) = Retry::default()
982                .max_tries(5)
983                .retry_async(|_| async {
984                    match tokio_postgres::connect(postgres_url, NoTls).await {
985                        Ok(c) => Ok(c),
986                        Err(e) => {
987                            error!(%e, "failed to connect to postgres");
988                            Err(e)
989                        }
990                    }
991                })
992                .await?;
993            task::spawn(|| "sqllogictest_connect", async move {
994                if let Err(e) = conn.await {
995                    panic!("connection error: {}", e);
996                }
997            });
998            client
999                .batch_execute(&format!(
1000                    "DROP SCHEMA IF EXISTS {prefix}_tsoracle CASCADE;
1001                     CREATE SCHEMA IF NOT EXISTS {prefix}_consensus;
1002                     CREATE SCHEMA {prefix}_tsoracle;"
1003                ))
1004                .await?;
1005            (
1006                format!("{postgres_url}?options=--search_path={prefix}_consensus")
1007                    .parse()
1008                    .expect("invalid consensus URI"),
1009                format!("{postgres_url}?options=--search_path={prefix}_tsoracle")
1010                    .parse()
1011                    .expect("invalid timestamp oracle URI"),
1012            )
1013        };
1014
1015        let secrets_dir = temp_dir.path().join("secrets");
1016        let orchestrator = Arc::new(
1017            ProcessOrchestrator::new(ProcessOrchestratorConfig {
1018                image_dir: env::current_exe()?.parent().unwrap().to_path_buf(),
1019                suppress_output: false,
1020                environment_id: environment_id.to_string(),
1021                secrets_dir: secrets_dir.clone(),
1022                command_wrapper: config
1023                    .orchestrator_process_wrapper
1024                    .as_ref()
1025                    .map_or(Ok(vec![]), |s| shell_words::split(s))?,
1026                propagate_crashes: true,
1027                tcp_proxy: None,
1028                scratch_directory: scratch_dir.path().to_path_buf(),
1029            })
1030            .await?,
1031        );
1032        let now = SYSTEM_TIME.clone();
1033        let metrics_registry = MetricsRegistry::new();
1034
1035        let persist_config = PersistConfig::new(
1036            &mz_environmentd::BUILD_INFO,
1037            now.clone(),
1038            mz_dyncfgs::all_dyncfgs(),
1039        );
1040        let persist_pubsub_server =
1041            PersistGrpcPubSubServer::new(&persist_config, &metrics_registry);
1042        let persist_pubsub_client = persist_pubsub_server.new_same_process_connection();
1043        let persist_pubsub_tcp_listener =
1044            TcpListener::bind(SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), 0))
1045                .await
1046                .expect("pubsub addr binding");
1047        let persist_pubsub_server_port = persist_pubsub_tcp_listener
1048            .local_addr()
1049            .expect("pubsub addr has local addr")
1050            .port();
1051        info!("listening for persist pubsub connections on localhost:{persist_pubsub_server_port}");
1052        mz_ore::task::spawn(|| "persist_pubsub_server", async move {
1053            persist_pubsub_server
1054                .serve_with_stream(TcpListenerStream::new(persist_pubsub_tcp_listener))
1055                .await
1056                .expect("success")
1057        });
1058        let persist_clients =
1059            PersistClientCache::new(persist_config, &metrics_registry, |cfg, metrics| {
1060                let sender: Arc<dyn PubSubSender> = Arc::new(MetricsSameProcessPubSubSender::new(
1061                    cfg,
1062                    persist_pubsub_client.sender,
1063                    metrics,
1064                ));
1065                PubSubClientConnection::new(sender, persist_pubsub_client.receiver)
1066            });
1067        let persist_clients = Arc::new(persist_clients);
1068
1069        let secrets_controller = Arc::clone(&orchestrator);
1070        let connection_context = ConnectionContext::for_tests(orchestrator.reader());
1071        let orchestrator = Arc::new(TracingOrchestrator::new(
1072            orchestrator,
1073            config.tracing.clone(),
1074        ));
1075        let listeners_config = ListenersConfig {
1076            sql: btreemap! {
1077                "external".to_owned() => SqlListenerConfig {
1078                    addr: SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), 0),
1079                    authenticator_kind: AuthenticatorKind::None,
1080                    allowed_roles: AllowedRoles::Normal,
1081                    enable_tls: false,
1082                },
1083                "internal".to_owned() => SqlListenerConfig {
1084                    addr: SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), 0),
1085                    authenticator_kind: AuthenticatorKind::None,
1086                    allowed_roles: AllowedRoles::Internal,
1087                    enable_tls: false,
1088                },
1089                "password".to_owned() => SqlListenerConfig {
1090                    addr: SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), 0),
1091                    authenticator_kind: AuthenticatorKind::Password,
1092                    allowed_roles: AllowedRoles::Normal,
1093                    enable_tls: false,
1094                },
1095            },
1096            http: btreemap![
1097                "external".to_owned() => HttpListenerConfig {
1098                    base: BaseListenerConfig {
1099                        addr: SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), 0),
1100                        authenticator_kind: AuthenticatorKind::None,
1101                        allowed_roles: AllowedRoles::Normal,
1102                        enable_tls: false
1103                    },
1104                    routes: HttpRoutesEnabled {
1105                        base: true,
1106                        webhook: true,
1107                        internal: false,
1108                        metrics: false,
1109                        profiling: false,
1110                    },
1111                },
1112                "internal".to_owned() => HttpListenerConfig {
1113                    base: BaseListenerConfig {
1114                        addr: SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), 0),
1115                        authenticator_kind: AuthenticatorKind::None,
1116                        allowed_roles: AllowedRoles::NormalAndInternal,
1117                        enable_tls: false
1118                    },
1119                    routes: HttpRoutesEnabled {
1120                        base: true,
1121                        webhook: true,
1122                        internal: true,
1123                        metrics: true,
1124                        profiling: true,
1125                    },
1126                },
1127            ],
1128        };
1129        let listeners = mz_environmentd::Listeners::bind(listeners_config).await?;
1130        let host_name = format!(
1131            "localhost:{}",
1132            listeners.http["external"].handle.local_addr.port()
1133        );
1134        let catalog_config = CatalogConfig {
1135            persist_clients: Arc::clone(&persist_clients),
1136            metrics: Arc::new(mz_catalog::durable::Metrics::new(&MetricsRegistry::new())),
1137        };
1138        let server_config = mz_environmentd::Config {
1139            catalog_config,
1140            timestamp_oracle_url: Some(timestamp_oracle_url),
1141            controller: ControllerConfig {
1142                build_info: &mz_environmentd::BUILD_INFO,
1143                orchestrator,
1144                clusterd_image: "clusterd".into(),
1145                init_container_image: None,
1146                deploy_generation: 0,
1147                persist_location: PersistLocation {
1148                    blob_uri: format!(
1149                        "file://{}/persist/blob",
1150                        config.persist_dir.path().display()
1151                    )
1152                    .parse()
1153                    .expect("invalid blob URI"),
1154                    consensus_uri,
1155                },
1156                persist_clients,
1157                now: SYSTEM_TIME.clone(),
1158                metrics_registry: metrics_registry.clone(),
1159                persist_pubsub_url: format!("http://localhost:{}", persist_pubsub_server_port),
1160                secrets_args: mz_service::secrets::SecretsReaderCliArgs {
1161                    secrets_reader: mz_service::secrets::SecretsControllerKind::LocalFile,
1162                    secrets_reader_local_file_dir: Some(secrets_dir),
1163                    secrets_reader_kubernetes_context: None,
1164                    secrets_reader_aws_prefix: None,
1165                    secrets_reader_name_prefix: None,
1166                },
1167                connection_context,
1168                replica_http_locator: Arc::new(ReplicaHttpLocator::default()),
1169            },
1170            secrets_controller,
1171            cloud_resource_controller: None,
1172            tls: None,
1173            frontegg: None,
1174            cors_allowed_origin: AllowOrigin::list([]),
1175            unsafe_mode: true,
1176            all_features: false,
1177            metrics_registry,
1178            now,
1179            environment_id,
1180            cluster_replica_sizes: ClusterReplicaSizeMap::for_tests(),
1181            bootstrap_default_cluster_replica_size: config.replica_size.clone(),
1182            bootstrap_default_cluster_replication_factor: config
1183                .replicas
1184                .try_into()
1185                .expect("replicas must fit"),
1186            bootstrap_builtin_system_cluster_config: BootstrapBuiltinClusterConfig {
1187                replication_factor: SYSTEM_CLUSTER_DEFAULT_REPLICATION_FACTOR,
1188                size: config.replica_size.clone(),
1189            },
1190            bootstrap_builtin_catalog_server_cluster_config: BootstrapBuiltinClusterConfig {
1191                replication_factor: CATALOG_SERVER_CLUSTER_DEFAULT_REPLICATION_FACTOR,
1192                size: config.replica_size.clone(),
1193            },
1194            bootstrap_builtin_probe_cluster_config: BootstrapBuiltinClusterConfig {
1195                replication_factor: PROBE_CLUSTER_DEFAULT_REPLICATION_FACTOR,
1196                size: config.replica_size.clone(),
1197            },
1198            bootstrap_builtin_support_cluster_config: BootstrapBuiltinClusterConfig {
1199                replication_factor: SUPPORT_CLUSTER_DEFAULT_REPLICATION_FACTOR,
1200                size: config.replica_size.clone(),
1201            },
1202            bootstrap_builtin_analytics_cluster_config: BootstrapBuiltinClusterConfig {
1203                replication_factor: ANALYTICS_CLUSTER_DEFAULT_REPLICATION_FACTOR,
1204                size: config.replica_size.clone(),
1205            },
1206            system_parameter_defaults: {
1207                let mut params = BTreeMap::new();
1208                params.insert(
1209                    "log_filter".to_string(),
1210                    config.tracing.startup_log_filter.to_string(),
1211                );
1212                params.extend(config.system_parameter_defaults.clone());
1213                params
1214            },
1215            availability_zones: Default::default(),
1216            tracing_handle: config.tracing_handle.clone(),
1217            storage_usage_collection_interval: Duration::from_secs(3600),
1218            storage_usage_retention_period: None,
1219            segment_api_key: None,
1220            segment_client_side: false,
1221            // SLT doesn't like eternally running tasks since it waits for them to finish inbetween SLT files
1222            test_only_dummy_segment_client: false,
1223            egress_addresses: vec![],
1224            aws_account_id: None,
1225            aws_privatelink_availability_zones: None,
1226            launchdarkly_sdk_key: None,
1227            launchdarkly_key_map: Default::default(),
1228            config_sync_file_path: None,
1229            config_sync_timeout: Duration::from_secs(30),
1230            config_sync_loop_interval: None,
1231            bootstrap_role: Some("materialize".into()),
1232            http_host_name: Some(host_name),
1233            internal_console_redirect_url: None,
1234            tls_reload_certs: mz_server_core::cert_reload_never_reload(),
1235            helm_chart_version: None,
1236            license_key: ValidatedLicenseKey::for_tests(),
1237            external_login_password_mz_system: None,
1238            force_builtin_schema_migration: None,
1239        };
1240        // We need to run the server on its own Tokio runtime, which in turn
1241        // requires its own thread, so that we can wait for any tasks spawned
1242        // by the server to be shutdown at the end of each file. If we were to
1243        // share a Tokio runtime, tasks from the last file's server would still
1244        // be live at the start of the next file's server.
1245        let (server_addr_tx, server_addr_rx): (oneshot::Sender<Result<_, anyhow::Error>>, _) =
1246            oneshot::channel();
1247        let (internal_server_addr_tx, internal_server_addr_rx) = oneshot::channel();
1248        let (password_server_addr_tx, password_server_addr_rx) = oneshot::channel();
1249        let (internal_http_server_addr_tx, internal_http_server_addr_rx) = oneshot::channel();
1250        let (shutdown_trigger, shutdown_trigger_rx) = trigger::channel();
1251        let server_thread = thread::spawn(|| {
1252            let runtime = match Runtime::new() {
1253                Ok(runtime) => runtime,
1254                Err(e) => {
1255                    server_addr_tx
1256                        .send(Err(e.into()))
1257                        .expect("receiver should not drop first");
1258                    return;
1259                }
1260            };
1261            let server = match runtime.block_on(listeners.serve(server_config)) {
1262                Ok(runtime) => runtime,
1263                Err(e) => {
1264                    server_addr_tx
1265                        .send(Err(e.into()))
1266                        .expect("receiver should not drop first");
1267                    return;
1268                }
1269            };
1270            server_addr_tx
1271                .send(Ok(server.sql_listener_handles["external"].local_addr))
1272                .expect("receiver should not drop first");
1273            internal_server_addr_tx
1274                .send(server.sql_listener_handles["internal"].local_addr)
1275                .expect("receiver should not drop first");
1276            password_server_addr_tx
1277                .send(server.sql_listener_handles["password"].local_addr)
1278                .expect("receiver should not drop first");
1279            internal_http_server_addr_tx
1280                .send(server.http_listener_handles["internal"].local_addr)
1281                .expect("receiver should not drop first");
1282            let _ = runtime.block_on(shutdown_trigger_rx);
1283        });
1284        let server_addr = server_addr_rx.await??;
1285        let internal_server_addr = internal_server_addr_rx.await?;
1286        let password_server_addr = password_server_addr_rx.await?;
1287        let internal_http_server_addr = internal_http_server_addr_rx.await?;
1288
1289        let system_client = connect(internal_server_addr, Some("mz_system"), None)
1290            .await
1291            .unwrap();
1292        let client = connect(server_addr, None, None).await.unwrap();
1293
1294        let inner = RunnerInner {
1295            server_addr,
1296            internal_server_addr,
1297            password_server_addr,
1298            internal_http_server_addr,
1299            _shutdown_trigger: shutdown_trigger,
1300            _server_thread: server_thread.join_on_drop(),
1301            _temp_dir: temp_dir,
1302            client,
1303            system_client,
1304            clients: BTreeMap::new(),
1305            auto_index_tables: config.auto_index_tables,
1306            auto_index_selects: config.auto_index_selects,
1307            auto_transactions: config.auto_transactions,
1308            enable_table_keys: config.enable_table_keys,
1309            verbosity: config.verbosity,
1310            stdout: config.stdout,
1311        };
1312        inner.ensure_fixed_features().await?;
1313
1314        Ok(inner)
1315    }
1316
1317    /// Set features that should be enabled regardless of whether reset-server was
1318    /// called. These features may be set conditionally depending on the run configuration.
1319    async fn ensure_fixed_features(&self) -> Result<(), anyhow::Error> {
1320        // We turn on enable_reduce_mfp_fusion, as we wish
1321        // to get as much coverage of these features as we can.
1322        // TODO(vmarcos): Remove this code when we retire this feature flag.
1323        self.system_client
1324            .execute("ALTER SYSTEM SET enable_reduce_mfp_fusion = on", &[])
1325            .await?;
1326
1327        // Dangerous functions are useful for tests so we enable it for all tests.
1328        self.system_client
1329            .execute("ALTER SYSTEM SET unsafe_enable_unsafe_functions = on", &[])
1330            .await?;
1331        Ok(())
1332    }
1333
1334    async fn run_record<'r>(
1335        &mut self,
1336        record: &'r Record<'r>,
1337        in_transaction: &mut bool,
1338    ) -> Result<Outcome<'r>, anyhow::Error> {
1339        match &record {
1340            Record::Statement {
1341                expected_error,
1342                rows_affected,
1343                sql,
1344                location,
1345            } => {
1346                if self.auto_transactions && *in_transaction {
1347                    self.client.execute("COMMIT", &[]).await?;
1348                    *in_transaction = false;
1349                }
1350                match self
1351                    .run_statement(*expected_error, *rows_affected, sql, location.clone())
1352                    .await?
1353                {
1354                    Outcome::Success => {
1355                        if self.auto_index_tables {
1356                            let additional = mutate(sql);
1357                            for stmt in additional {
1358                                self.client.execute(&stmt, &[]).await?;
1359                            }
1360                        }
1361                        Ok(Outcome::Success)
1362                    }
1363                    other => {
1364                        if expected_error.is_some() {
1365                            Ok(other)
1366                        } else {
1367                            // If we failed to execute a statement that was supposed to succeed,
1368                            // running the rest of the tests in this file will probably cause
1369                            // false positives, so just give up on the file entirely.
1370                            Ok(Outcome::Bail {
1371                                cause: Box::new(other),
1372                                location: location.clone(),
1373                            })
1374                        }
1375                    }
1376                }
1377            }
1378            Record::Query {
1379                sql,
1380                output,
1381                location,
1382            } => {
1383                self.run_query(sql, output, location.clone(), in_transaction)
1384                    .await
1385            }
1386            Record::Simple {
1387                conn,
1388                user,
1389                password,
1390                sql,
1391                sort,
1392                output,
1393                location,
1394                ..
1395            } => {
1396                self.run_simple(
1397                    *conn,
1398                    *user,
1399                    *password,
1400                    sql,
1401                    sort.clone(),
1402                    output,
1403                    location.clone(),
1404                )
1405                .await
1406            }
1407            Record::Copy {
1408                table_name,
1409                tsv_path,
1410            } => {
1411                let tsv = tokio::fs::read(tsv_path).await?;
1412                let copy = self
1413                    .client
1414                    .copy_in(&*format!("COPY {} FROM STDIN", table_name))
1415                    .await?;
1416                tokio::pin!(copy);
1417                copy.send(bytes::Bytes::from(tsv)).await?;
1418                copy.finish().await?;
1419                Ok(Outcome::Success)
1420            }
1421            _ => Ok(Outcome::Success),
1422        }
1423    }
1424
1425    async fn run_statement<'r>(
1426        &self,
1427        expected_error: Option<&'r str>,
1428        expected_rows_affected: Option<u64>,
1429        sql: &'r str,
1430        location: Location,
1431    ) -> Result<Outcome<'r>, anyhow::Error> {
1432        static UNSUPPORTED_INDEX_STATEMENT_REGEX: LazyLock<Regex> =
1433            LazyLock::new(|| Regex::new("^(CREATE UNIQUE INDEX|REINDEX)").unwrap());
1434        if UNSUPPORTED_INDEX_STATEMENT_REGEX.is_match(sql) {
1435            // sure, we totally made you an index
1436            return Ok(Outcome::Success);
1437        }
1438
1439        match self.client.execute(sql, &[]).await {
1440            Ok(actual) => {
1441                if let Some(expected_error) = expected_error {
1442                    return Ok(Outcome::UnexpectedPlanSuccess {
1443                        expected_error,
1444                        location,
1445                    });
1446                }
1447                match expected_rows_affected {
1448                    None => Ok(Outcome::Success),
1449                    Some(expected) => {
1450                        if expected != actual {
1451                            Ok(Outcome::WrongNumberOfRowsInserted {
1452                                expected_count: expected,
1453                                actual_count: actual,
1454                                location,
1455                            })
1456                        } else {
1457                            Ok(Outcome::Success)
1458                        }
1459                    }
1460                }
1461            }
1462            Err(error) => {
1463                if let Some(expected_error) = expected_error {
1464                    if Regex::new(expected_error)?.is_match(&error.to_string_with_causes()) {
1465                        return Ok(Outcome::Success);
1466                    }
1467                }
1468                Ok(Outcome::PlanFailure {
1469                    error: anyhow!(error),
1470                    location,
1471                })
1472            }
1473        }
1474    }
1475
1476    async fn prepare_query<'r>(
1477        &self,
1478        sql: &str,
1479        output: &'r Result<QueryOutput<'_>, &'r str>,
1480        location: Location,
1481        in_transaction: &mut bool,
1482    ) -> Result<PrepareQueryOutcome<'r>, anyhow::Error> {
1483        // get statement
1484        let statements = match mz_sql::parse::parse(sql) {
1485            Ok(statements) => statements,
1486            Err(e) => match output {
1487                Ok(_) => {
1488                    return Ok(PrepareQueryOutcome::Outcome(Outcome::ParseFailure {
1489                        error: e.into(),
1490                        location,
1491                    }));
1492                }
1493                Err(expected_error) => {
1494                    if Regex::new(expected_error)?.is_match(&e.to_string_with_causes()) {
1495                        return Ok(PrepareQueryOutcome::Outcome(Outcome::Success));
1496                    } else {
1497                        return Ok(PrepareQueryOutcome::Outcome(Outcome::ParseFailure {
1498                            error: e.into(),
1499                            location,
1500                        }));
1501                    }
1502                }
1503            },
1504        };
1505        let statement = match &*statements {
1506            [] => bail!("Got zero statements?"),
1507            [statement] => &statement.ast,
1508            _ => bail!("Got multiple statements: {:?}", statements),
1509        };
1510        let (is_select, num_attributes) = match statement {
1511            Statement::Select(stmt) => (true, derive_num_attributes(&stmt.query.body)),
1512            _ => (false, None),
1513        };
1514
1515        match output {
1516            Ok(_) => {
1517                if self.auto_transactions && !*in_transaction {
1518                    // No ISOLATION LEVEL SERIALIZABLE because of database-issues#5323
1519                    self.client.execute("BEGIN", &[]).await?;
1520                    *in_transaction = true;
1521                }
1522            }
1523            Err(_) => {
1524                if self.auto_transactions && *in_transaction {
1525                    self.client.execute("COMMIT", &[]).await?;
1526                    *in_transaction = false;
1527                }
1528            }
1529        }
1530
1531        // `SHOW` commands reference catalog schema, thus are not in the same timedomain and not
1532        // allowed in the same transaction, see:
1533        // https://materialize.com/docs/sql/begin/#same-timedomain-error
1534        match statement {
1535            Statement::Show(..) => {
1536                if self.auto_transactions && *in_transaction {
1537                    self.client.execute("COMMIT", &[]).await?;
1538                    *in_transaction = false;
1539                }
1540            }
1541            _ => (),
1542        }
1543        Ok(PrepareQueryOutcome::QueryPrepared(QueryInfo {
1544            is_select,
1545            num_attributes,
1546        }))
1547    }
1548
1549    async fn execute_query<'r>(
1550        &self,
1551        sql: &str,
1552        output: &'r Result<QueryOutput<'_>, &'r str>,
1553        location: Location,
1554    ) -> Result<Outcome<'r>, anyhow::Error> {
1555        let rows = match self.client.query(sql, &[]).await {
1556            Ok(rows) => rows,
1557            Err(error) => {
1558                let error_string = error.to_string_with_causes();
1559                return match output {
1560                    Ok(_) => {
1561                        if error_string.contains("supported") || error_string.contains("overload") {
1562                            // this is a failure, but it's caused by lack of support rather than by bugs
1563                            Ok(Outcome::Unsupported {
1564                                error: anyhow!(error),
1565                                location,
1566                            })
1567                        } else {
1568                            Ok(Outcome::PlanFailure {
1569                                error: anyhow!(error),
1570                                location,
1571                            })
1572                        }
1573                    }
1574                    Err(expected_error) => {
1575                        if Regex::new(expected_error)?.is_match(&error_string) {
1576                            Ok(Outcome::Success)
1577                        } else {
1578                            Ok(Outcome::PlanFailure {
1579                                error: anyhow!(error),
1580                                location,
1581                            })
1582                        }
1583                    }
1584                };
1585            }
1586        };
1587
1588        // unpack expected output
1589        let QueryOutput {
1590            sort,
1591            types: expected_types,
1592            column_names: expected_column_names,
1593            output: expected_output,
1594            mode,
1595            ..
1596        } = match output {
1597            Err(expected_error) => {
1598                return Ok(Outcome::UnexpectedPlanSuccess {
1599                    expected_error,
1600                    location,
1601                });
1602            }
1603            Ok(query_output) => query_output,
1604        };
1605
1606        // format output
1607        let mut formatted_rows = vec![];
1608        for row in &rows {
1609            if row.len() != expected_types.len() {
1610                return Ok(Outcome::WrongColumnCount {
1611                    expected_count: expected_types.len(),
1612                    actual_count: row.len(),
1613                    location,
1614                });
1615            }
1616            let row = format_row(row, expected_types, *mode);
1617            formatted_rows.push(row);
1618        }
1619
1620        // sort formatted output
1621        if let Sort::Row = sort {
1622            formatted_rows.sort();
1623        }
1624        let mut values = formatted_rows.into_iter().flatten().collect::<Vec<_>>();
1625        if let Sort::Value = sort {
1626            values.sort();
1627        }
1628
1629        // Various checks as long as there are returned rows.
1630        if let Some(row) = rows.get(0) {
1631            // check column names
1632            if let Some(expected_column_names) = expected_column_names {
1633                let actual_column_names = row
1634                    .columns()
1635                    .iter()
1636                    .map(|t| ColumnName::from(t.name()))
1637                    .collect::<Vec<_>>();
1638                if expected_column_names != &actual_column_names {
1639                    return Ok(Outcome::WrongColumnNames {
1640                        expected_column_names,
1641                        actual_column_names,
1642                        actual_output: Output::Values(values),
1643                        location,
1644                    });
1645                }
1646            }
1647        }
1648
1649        // check output
1650        match expected_output {
1651            Output::Values(expected_values) => {
1652                if values != *expected_values {
1653                    return Ok(Outcome::OutputFailure {
1654                        expected_output,
1655                        actual_raw_output: rows,
1656                        actual_output: Output::Values(values),
1657                        location,
1658                    });
1659                }
1660            }
1661            Output::Hashed {
1662                num_values,
1663                md5: expected_md5,
1664            } => {
1665                let mut hasher = Md5::new();
1666                for value in &values {
1667                    hasher.update(value);
1668                    hasher.update("\n");
1669                }
1670                let md5 = format!("{:x}", hasher.finalize());
1671                if values.len() != *num_values || md5 != *expected_md5 {
1672                    return Ok(Outcome::OutputFailure {
1673                        expected_output,
1674                        actual_raw_output: rows,
1675                        actual_output: Output::Hashed {
1676                            num_values: values.len(),
1677                            md5,
1678                        },
1679                        location,
1680                    });
1681                }
1682            }
1683        }
1684
1685        Ok(Outcome::Success)
1686    }
1687
1688    async fn execute_view_inner<'r>(
1689        &self,
1690        sql: &str,
1691        output: &'r Result<QueryOutput<'_>, &'r str>,
1692        location: Location,
1693    ) -> Result<Option<Outcome<'r>>, anyhow::Error> {
1694        print_sql_if(self.stdout, sql, self.verbosity >= 2);
1695        let sql_result = self.client.execute(sql, &[]).await;
1696
1697        // Evaluate if we already reached an outcome or not.
1698        let tentative_outcome = if let Err(view_error) = sql_result {
1699            if let Err(expected_error) = output {
1700                if Regex::new(expected_error)?.is_match(&view_error.to_string_with_causes()) {
1701                    Some(Outcome::Success)
1702                } else {
1703                    Some(Outcome::PlanFailure {
1704                        error: view_error.into(),
1705                        location: location.clone(),
1706                    })
1707                }
1708            } else {
1709                Some(Outcome::PlanFailure {
1710                    error: view_error.into(),
1711                    location: location.clone(),
1712                })
1713            }
1714        } else {
1715            None
1716        };
1717        Ok(tentative_outcome)
1718    }
1719
1720    async fn execute_view<'r>(
1721        &self,
1722        sql: &str,
1723        num_attributes: Option<usize>,
1724        output: &'r Result<QueryOutput<'_>, &'r str>,
1725        location: Location,
1726    ) -> Result<Outcome<'r>, anyhow::Error> {
1727        // Create indexed view SQL commands and execute `CREATE VIEW`.
1728        let expected_column_names = if let Ok(QueryOutput { column_names, .. }) = output {
1729            column_names.clone()
1730        } else {
1731            None
1732        };
1733        let (create_view, create_index, view_sql, drop_view) = generate_view_sql(
1734            sql,
1735            Uuid::new_v4().as_simple(),
1736            num_attributes,
1737            expected_column_names,
1738        );
1739        let tentative_outcome = self
1740            .execute_view_inner(create_view.as_str(), output, location.clone())
1741            .await?;
1742
1743        // Either we already have an outcome or alternatively,
1744        // we proceed to index and query the view.
1745        if let Some(view_outcome) = tentative_outcome {
1746            return Ok(view_outcome);
1747        }
1748
1749        let tentative_outcome = self
1750            .execute_view_inner(create_index.as_str(), output, location.clone())
1751            .await?;
1752
1753        let view_outcome;
1754        if let Some(outcome) = tentative_outcome {
1755            view_outcome = outcome;
1756        } else {
1757            print_sql_if(self.stdout, view_sql.as_str(), self.verbosity >= 2);
1758            view_outcome = self
1759                .execute_query(view_sql.as_str(), output, location.clone())
1760                .await?;
1761        }
1762
1763        // Remember to clean up after ourselves by dropping the view.
1764        print_sql_if(self.stdout, drop_view.as_str(), self.verbosity >= 2);
1765        self.client.execute(drop_view.as_str(), &[]).await?;
1766
1767        Ok(view_outcome)
1768    }
1769
1770    async fn run_query<'r>(
1771        &self,
1772        sql: &'r str,
1773        output: &'r Result<QueryOutput<'_>, &'r str>,
1774        location: Location,
1775        in_transaction: &mut bool,
1776    ) -> Result<Outcome<'r>, anyhow::Error> {
1777        let prepare_outcome = self
1778            .prepare_query(sql, output, location.clone(), in_transaction)
1779            .await?;
1780        match prepare_outcome {
1781            PrepareQueryOutcome::QueryPrepared(QueryInfo {
1782                is_select,
1783                num_attributes,
1784            }) => {
1785                let query_outcome = self.execute_query(sql, output, location.clone()).await?;
1786                if is_select && self.auto_index_selects {
1787                    let view_outcome = self
1788                        .execute_view(sql, None, output, location.clone())
1789                        .await?;
1790
1791                    // We compare here the query-based and view-based outcomes.
1792                    // We only produce a test failure if the outcomes are of different
1793                    // variant types, thus accepting smaller deviations in the details
1794                    // produced for each variant.
1795                    if std::mem::discriminant::<Outcome>(&query_outcome)
1796                        != std::mem::discriminant::<Outcome>(&view_outcome)
1797                    {
1798                        // Before producing a failure outcome, we try to obtain a new
1799                        // outcome for view-based execution exploiting analysis of the
1800                        // number of attributes. This two-level strategy can avoid errors
1801                        // produced by column ambiguity in the `SELECT`.
1802                        let view_outcome = if num_attributes.is_some() {
1803                            self.execute_view(sql, num_attributes, output, location.clone())
1804                                .await?
1805                        } else {
1806                            view_outcome
1807                        };
1808
1809                        if std::mem::discriminant::<Outcome>(&query_outcome)
1810                            != std::mem::discriminant::<Outcome>(&view_outcome)
1811                        {
1812                            let inconsistent_view_outcome = Outcome::InconsistentViewOutcome {
1813                                query_outcome: Box::new(query_outcome),
1814                                view_outcome: Box::new(view_outcome),
1815                                location: location.clone(),
1816                            };
1817                            // Determine if this inconsistent view outcome should be reported
1818                            // as an error or only as a warning.
1819                            let outcome = if should_warn(&inconsistent_view_outcome) {
1820                                Outcome::Warning {
1821                                    cause: Box::new(inconsistent_view_outcome),
1822                                    location: location.clone(),
1823                                }
1824                            } else {
1825                                inconsistent_view_outcome
1826                            };
1827                            return Ok(outcome);
1828                        }
1829                    }
1830                }
1831                Ok(query_outcome)
1832            }
1833            PrepareQueryOutcome::Outcome(outcome) => Ok(outcome),
1834        }
1835    }
1836
1837    async fn get_conn(
1838        &mut self,
1839        name: Option<&str>,
1840        user: Option<&str>,
1841        password: Option<&str>,
1842    ) -> Result<&tokio_postgres::Client, tokio_postgres::Error> {
1843        match name {
1844            None => Ok(&self.client),
1845            Some(name) => {
1846                if !self.clients.contains_key(name) {
1847                    let addr = if matches!(user, Some("mz_system") | Some("mz_support")) {
1848                        self.internal_server_addr
1849                    } else if password.is_some() {
1850                        // Use password server for password authentication
1851                        self.password_server_addr
1852                    } else {
1853                        self.server_addr
1854                    };
1855                    let client = connect(addr, user, password).await?;
1856                    self.clients.insert(name.into(), client);
1857                }
1858                Ok(self.clients.get(name).unwrap())
1859            }
1860        }
1861    }
1862
1863    async fn run_simple<'r>(
1864        &mut self,
1865        conn: Option<&'r str>,
1866        user: Option<&'r str>,
1867        password: Option<&'r str>,
1868        sql: &'r str,
1869        sort: Sort,
1870        output: &'r Output,
1871        location: Location,
1872    ) -> Result<Outcome<'r>, anyhow::Error> {
1873        let actual = match self.get_conn(conn, user, password).await {
1874            Ok(client) => match client.simple_query(sql).await {
1875                Ok(result) => {
1876                    let mut rows = Vec::new();
1877
1878                    for m in result.into_iter() {
1879                        match m {
1880                            SimpleQueryMessage::Row(row) => {
1881                                let mut s = vec![];
1882                                for i in 0..row.len() {
1883                                    s.push(row.get(i).unwrap_or("NULL"));
1884                                }
1885                                rows.push(s.join(","));
1886                            }
1887                            SimpleQueryMessage::CommandComplete(count) => {
1888                                // This applies any sort on the COMPLETE line as
1889                                // well, but we do the same for the expected output.
1890                                rows.push(format!("COMPLETE {}", count));
1891                            }
1892                            SimpleQueryMessage::RowDescription(_) => {}
1893                            _ => panic!("unexpected"),
1894                        }
1895                    }
1896
1897                    if let Sort::Row = sort {
1898                        rows.sort();
1899                    }
1900
1901                    Output::Values(rows)
1902                }
1903                // Errors can contain multiple lines (say if there are details), and rewrite
1904                // sticks them each on their own line, so we need to split up the lines here to
1905                // each be its own String in the Vec.
1906                Err(error) => Output::Values(
1907                    error
1908                        .to_string_with_causes()
1909                        .lines()
1910                        .map(|s| s.to_string())
1911                        .collect(),
1912                ),
1913            },
1914            Err(error) => Output::Values(
1915                error
1916                    .to_string_with_causes()
1917                    .lines()
1918                    .map(|s| s.to_string())
1919                    .collect(),
1920            ),
1921        };
1922        if *output != actual {
1923            Ok(Outcome::OutputFailure {
1924                expected_output: output,
1925                actual_raw_output: vec![],
1926                actual_output: actual,
1927                location,
1928            })
1929        } else {
1930            Ok(Outcome::Success)
1931        }
1932    }
1933
1934    async fn check_catalog(&self) -> Result<(), anyhow::Error> {
1935        let url = format!(
1936            "http://{}/api/catalog/check",
1937            self.internal_http_server_addr
1938        );
1939        let response: serde_json::Value = reqwest::get(&url).await?.json().await?;
1940
1941        if let Some(inconsistencies) = response.get("err") {
1942            let inconsistencies = serde_json::to_string_pretty(&inconsistencies)
1943                .expect("serializing Value cannot fail");
1944            Err(anyhow::anyhow!("Catalog inconsistency\n{inconsistencies}"))
1945        } else {
1946            Ok(())
1947        }
1948    }
1949}
1950
1951async fn connect(
1952    addr: SocketAddr,
1953    user: Option<&str>,
1954    password: Option<&str>,
1955) -> Result<tokio_postgres::Client, tokio_postgres::Error> {
1956    let mut config = tokio_postgres::Config::new();
1957    config.host(addr.ip().to_string());
1958    config.port(addr.port());
1959    config.user(user.unwrap_or("materialize"));
1960    if let Some(password) = password {
1961        config.password(password);
1962    }
1963    let (client, connection) = config.connect(NoTls).await?;
1964
1965    task::spawn(|| "sqllogictest_connect", async move {
1966        if let Err(e) = connection.await {
1967            eprintln!("connection error: {}", e);
1968        }
1969    });
1970    Ok(client)
1971}
1972
/// An output sink that accepts pre-formatted arguments.
///
/// `write_fmt` takes `&self` and returns nothing, so `writeln!` can be used on a
/// shared `&dyn WriteFmt` (as `print_sql` does) without a result to handle.
pub trait WriteFmt {
    /// Writes the formatted arguments to the sink.
    fn write_fmt(&self, fmt: fmt::Arguments<'_>);
}
1976
/// Configuration for a sqllogictest run.
///
/// Only the fields whose use is visible near this definition are documented
/// below; the remaining options are consumed elsewhere.
pub struct RunConfig<'a> {
    /// Sink for progress output: `run_string` and `print_sql` write the records
    /// and their outcomes here.
    pub stdout: &'a dyn WriteFmt,
    // NOTE(review): presumably the sink for error output — confirm against callers.
    pub stderr: &'a dyn WriteFmt,
    /// Output verbosity. `run_string` prints warnings and failures at `>= 1` and
    /// prints every record before running it at `>= 2`.
    pub verbosity: u8,
    pub postgres_url: String,
    pub prefix: String,
    pub no_fail: bool,
    /// When set, `run_string` stops processing a file at the first failing outcome.
    pub fail_fast: bool,
    pub auto_index_tables: bool,
    pub auto_index_selects: bool,
    pub auto_transactions: bool,
    pub enable_table_keys: bool,
    pub orchestrator_process_wrapper: Option<String>,
    pub tracing: TracingCliArgs,
    pub tracing_handle: TracingHandle,
    pub system_parameter_defaults: BTreeMap<String, String>,
    /// Persist state is handled specially because:
    /// - Persist background workers do not necessarily shut down immediately once the server is
    ///   shut down, and may panic if their storage is deleted out from under them.
    /// - It's safe for different databases to reference the same state: all data is scoped by UUID.
    pub persist_dir: TempDir,
    pub replicas: usize,
    pub replica_size: String,
}
2001
2002fn print_record(config: &RunConfig<'_>, record: &Record) {
2003    match record {
2004        Record::Statement { sql, .. } | Record::Query { sql, .. } => print_sql(config.stdout, sql),
2005        _ => (),
2006    }
2007}
2008
2009fn print_sql_if<'a>(stdout: &'a dyn WriteFmt, sql: &str, cond: bool) {
2010    if cond {
2011        print_sql(stdout, sql)
2012    }
2013}
2014
2015fn print_sql<'a>(stdout: &'a dyn WriteFmt, sql: &str) {
2016    writeln!(stdout, "{}", crate::util::indent(sql, 4))
2017}
2018
/// Regular expressions for matching error messages that should force a plan failure
/// in an inconsistent view outcome into a warning if the corresponding query succeeds.
///
/// `should_warn` compiles each entry with `Regex::new` and tests it against the
/// view error's message (including causes), so every entry must be a valid regex.
const INCONSISTENT_VIEW_OUTCOME_WARNING_REGEXPS: [&str; 9] = [
    // The following are unfixable errors in indexed views given our
    // current constraints.
    "cannot materialize call to",
    "SHOW commands are not allowed in views",
    "cannot create view with unstable dependencies",
    "cannot use wildcard expansions or NATURAL JOINs in a view that depends on system objects",
    "no valid schema selected",
    r#"system schema '\w+' cannot be modified"#,
    r#"permission denied for (SCHEMA|CLUSTER) "(\w+\.)?\w+""#,
    // NOTE(vmarcos): Column ambiguity that could not be eliminated by our
    // currently implemented syntactic rewrites is considered unfixable.
    // In addition, if some column cannot be dealt with, e.g., in `ORDER BY`
    // references, we treat this condition as unfixable as well.
    r#"column "[\w\?]+" specified more than once"#,
    r#"column "(\w+\.)?\w+" does not exist"#,
];
2038
2039/// Evaluates if the given outcome should be returned directly or if it should
2040/// be wrapped as a warning. Note that this function should be used for outcomes
2041/// that can be judged in a context-independent manner, i.e., the outcome itself
2042/// provides enough information as to whether a warning should be emitted or not.
2043fn should_warn(outcome: &Outcome) -> bool {
2044    match outcome {
2045        Outcome::InconsistentViewOutcome {
2046            query_outcome,
2047            view_outcome,
2048            ..
2049        } => match (query_outcome.as_ref(), view_outcome.as_ref()) {
2050            (Outcome::Success, Outcome::PlanFailure { error, .. }) => {
2051                INCONSISTENT_VIEW_OUTCOME_WARNING_REGEXPS.iter().any(|s| {
2052                    Regex::new(s)
2053                        .expect("unexpected error in regular expression parsing")
2054                        .is_match(&error.to_string_with_causes())
2055                })
2056            }
2057            _ => false,
2058        },
2059        _ => false,
2060    }
2061}
2062
2063pub async fn run_string(
2064    runner: &mut Runner<'_>,
2065    source: &str,
2066    input: &str,
2067) -> Result<Outcomes, anyhow::Error> {
2068    runner.reset_database().await?;
2069
2070    let mut outcomes = Outcomes::default();
2071    let mut parser = crate::parser::Parser::new(source, input);
2072    // Transactions are currently relatively slow. Since sqllogictest runs in a single connection
2073    // there should be no difference in having longer running transactions.
2074    let mut in_transaction = false;
2075    writeln!(runner.config.stdout, "--- {}", source);
2076
2077    for record in parser.parse_records()? {
2078        // In maximal-verbosity mode, print the query before attempting to run
2079        // it. Running the query might panic, so it is important to print out
2080        // what query we are trying to run *before* we panic.
2081        if runner.config.verbosity >= 2 {
2082            print_record(runner.config, &record);
2083        }
2084
2085        let outcome = runner
2086            .run_record(&record, &mut in_transaction)
2087            .await
2088            .map_err(|err| format!("In {}:\n{}", source, err))
2089            .unwrap();
2090
2091        // Print warnings and failures in verbose mode.
2092        if runner.config.verbosity >= 1 && !outcome.success() {
2093            if runner.config.verbosity < 2 {
2094                // If `verbosity >= 2`, we'll already have printed the record,
2095                // so don't print it again. Yes, this is an ugly bit of logic.
2096                // Please don't try to consolidate it with the `print_record`
2097                // call above, as it's important to have a mode in which records
2098                // are printed before they are run, so that if running the
2099                // record panics, you can tell which record caused it.
2100                if !outcome.failure() {
2101                    writeln!(
2102                        runner.config.stdout,
2103                        "{}",
2104                        util::indent("Warning detected for: ", 4)
2105                    );
2106                }
2107                print_record(runner.config, &record);
2108            }
2109            if runner.config.verbosity >= 2 || outcome.failure() {
2110                writeln!(
2111                    runner.config.stdout,
2112                    "{}",
2113                    util::indent(&outcome.to_string(), 4)
2114                );
2115                writeln!(runner.config.stdout, "{}", util::indent("----", 4));
2116            }
2117        }
2118
2119        outcomes.stats[outcome.code()] += 1;
2120        if outcome.failure() {
2121            outcomes.details.push(format!("{}", outcome));
2122        }
2123
2124        if let Outcome::Bail { .. } = outcome {
2125            break;
2126        }
2127
2128        if runner.config.fail_fast && outcome.failure() {
2129            break;
2130        }
2131    }
2132    Ok(outcomes)
2133}
2134
2135pub async fn run_file(runner: &mut Runner<'_>, filename: &Path) -> Result<Outcomes, anyhow::Error> {
2136    let mut input = String::new();
2137    File::open(filename)?.read_to_string(&mut input)?;
2138    let outcomes = run_string(runner, &format!("{}", filename.display()), &input).await?;
2139    runner.check_catalog().await?;
2140
2141    Ok(outcomes)
2142}
2143
2144pub async fn rewrite_file(runner: &mut Runner<'_>, filename: &Path) -> Result<(), anyhow::Error> {
2145    runner.reset_database().await?;
2146
2147    let mut file = OpenOptions::new().read(true).write(true).open(filename)?;
2148
2149    let mut input = String::new();
2150    file.read_to_string(&mut input)?;
2151
2152    let mut buf = RewriteBuffer::new(&input);
2153
2154    let mut parser = crate::parser::Parser::new(filename.to_str().unwrap_or(""), &input);
2155    writeln!(runner.config.stdout, "--- {}", filename.display());
2156    let mut in_transaction = false;
2157
2158    fn append_values_output(
2159        buf: &mut RewriteBuffer,
2160        input: &String,
2161        expected_output: &str,
2162        mode: &Mode,
2163        types: &Vec<Type>,
2164        column_names: Option<&Vec<ColumnName>>,
2165        actual_output: &Vec<String>,
2166        multiline: bool,
2167    ) {
2168        buf.append_header(input, expected_output, column_names);
2169
2170        for (i, row) in actual_output.chunks(types.len()).enumerate() {
2171            match mode {
2172                // In Cockroach mode, output each row on its own line, with
2173                // two spaces between each column.
2174                Mode::Cockroach => {
2175                    if i != 0 {
2176                        buf.append("\n");
2177                    }
2178
2179                    if row.len() == 0 {
2180                        // nothing to do
2181                    } else if row.len() == 1 {
2182                        // If there is only one column, then there is no need for space
2183                        // substitution, so we only do newline substitution.
2184                        if multiline {
2185                            buf.append(&row[0]);
2186                        } else {
2187                            buf.append(&row[0].replace('\n', "⏎"))
2188                        }
2189                    } else {
2190                        // Substitute spaces with ␠ to avoid mistaking the spaces in the result
2191                        // values with spaces that separate columns.
2192                        buf.append(
2193                            &row.iter()
2194                                .map(|col| {
2195                                    let mut col = col.replace(' ', "␠");
2196                                    if !multiline {
2197                                        col = col.replace('\n', "⏎");
2198                                    }
2199                                    col
2200                                })
2201                                .join("  "),
2202                        );
2203                    }
2204                }
2205                // In standard mode, output each value on its own line,
2206                // and ignore row boundaries.
2207                // No need to substitute spaces, because every value (not row) is on a separate
2208                // line. But we do need to substitute newlines.
2209                Mode::Standard => {
2210                    for (j, col) in row.iter().enumerate() {
2211                        if i != 0 || j != 0 {
2212                            buf.append("\n");
2213                        }
2214                        buf.append(&if multiline {
2215                            col.clone()
2216                        } else {
2217                            col.replace('\n', "⏎")
2218                        });
2219                    }
2220                }
2221            }
2222        }
2223    }
2224
2225    for record in parser.parse_records()? {
2226        let outcome = runner.run_record(&record, &mut in_transaction).await?;
2227
2228        match (&record, &outcome) {
2229            // If we see an output failure for a query, rewrite the expected output
2230            // to match the observed output.
2231            (
2232                Record::Query {
2233                    output:
2234                        Ok(QueryOutput {
2235                            mode,
2236                            output: Output::Values(_),
2237                            output_str: expected_output,
2238                            types,
2239                            column_names,
2240                            multiline,
2241                            ..
2242                        }),
2243                    ..
2244                },
2245                Outcome::OutputFailure {
2246                    actual_output: Output::Values(actual_output),
2247                    ..
2248                },
2249            ) => {
2250                append_values_output(
2251                    &mut buf,
2252                    &input,
2253                    expected_output,
2254                    mode,
2255                    types,
2256                    column_names.as_ref(),
2257                    actual_output,
2258                    *multiline,
2259                );
2260            }
2261            (
2262                Record::Query {
2263                    output:
2264                        Ok(QueryOutput {
2265                            mode,
2266                            output: Output::Values(_),
2267                            output_str: expected_output,
2268                            types,
2269                            multiline,
2270                            ..
2271                        }),
2272                    ..
2273                },
2274                Outcome::WrongColumnNames {
2275                    actual_column_names,
2276                    actual_output: Output::Values(actual_output),
2277                    ..
2278                },
2279            ) => {
2280                append_values_output(
2281                    &mut buf,
2282                    &input,
2283                    expected_output,
2284                    mode,
2285                    types,
2286                    Some(actual_column_names),
2287                    actual_output,
2288                    *multiline,
2289                );
2290            }
2291            (
2292                Record::Query {
2293                    output:
2294                        Ok(QueryOutput {
2295                            output: Output::Hashed { .. },
2296                            output_str: expected_output,
2297                            column_names,
2298                            ..
2299                        }),
2300                    ..
2301                },
2302                Outcome::OutputFailure {
2303                    actual_output: Output::Hashed { num_values, md5 },
2304                    ..
2305                },
2306            ) => {
2307                buf.append_header(&input, expected_output, column_names.as_ref());
2308
2309                buf.append(format!("{} values hashing to {}\n", num_values, md5).as_str())
2310            }
2311            (
2312                Record::Simple {
2313                    output_str: expected_output,
2314                    ..
2315                },
2316                Outcome::OutputFailure {
2317                    actual_output: Output::Values(actual_output),
2318                    ..
2319                },
2320            ) => {
2321                buf.append_header(&input, expected_output, None);
2322
2323                for (i, row) in actual_output.iter().enumerate() {
2324                    if i != 0 {
2325                        buf.append("\n");
2326                    }
2327                    buf.append(row);
2328                }
2329            }
2330            (
2331                Record::Query {
2332                    sql,
2333                    output: Err(err),
2334                    ..
2335                },
2336                outcome,
2337            )
2338            | (
2339                Record::Statement {
2340                    expected_error: Some(err),
2341                    sql,
2342                    ..
2343                },
2344                outcome,
2345            ) if outcome.err_msg().is_some() => {
2346                buf.rewrite_expected_error(&input, err, &outcome.err_msg().unwrap(), sql)
2347            }
2348            (_, Outcome::Success) => {}
2349            _ => bail!("unexpected: {:?} {:?}", record, outcome),
2350        }
2351    }
2352
2353    file.set_len(0)?;
2354    file.seek(SeekFrom::Start(0))?;
2355    file.write_all(buf.finish().as_bytes())?;
2356    file.sync_all()?;
2357    Ok(())
2358}
2359
2360/// Provides a means to rewrite the `.slt` file while iterating over it.
2361///
2362/// This struct takes the slt file as its `input`, tracks a cursor into it
2363/// (`input_offset`), and provides a buffer (`output`) to store the rewritten
2364/// results.
2365///
2366/// Functions that modify the file will lazily move `input` into `output` using
2367/// `flush_to`. However, those calls should all be interior to other functions.
2368#[derive(Debug)]
2369struct RewriteBuffer<'a> {
2370    input: &'a str,
2371    input_offset: usize,
2372    output: String,
2373}
2374
2375impl<'a> RewriteBuffer<'a> {
2376    fn new(input: &'a str) -> RewriteBuffer<'a> {
2377        RewriteBuffer {
2378            input,
2379            input_offset: 0,
2380            output: String::new(),
2381        }
2382    }
2383
2384    fn flush_to(&mut self, offset: usize) {
2385        assert!(offset >= self.input_offset);
2386        let chunk = &self.input[self.input_offset..offset];
2387        self.output.push_str(chunk);
2388        self.input_offset = offset;
2389    }
2390
2391    fn skip_to(&mut self, offset: usize) {
2392        assert!(offset >= self.input_offset);
2393        self.input_offset = offset;
2394    }
2395
2396    fn append(&mut self, s: &str) {
2397        self.output.push_str(s);
2398    }
2399
2400    fn append_header(
2401        &mut self,
2402        input: &String,
2403        expected_output: &str,
2404        column_names: Option<&Vec<ColumnName>>,
2405    ) {
2406        // Output everything before this record.
2407        // TODO(benesch): is it possible to rewrite this to avoid `as`?
2408        #[allow(clippy::as_conversions)]
2409        let offset = expected_output.as_ptr() as usize - input.as_ptr() as usize;
2410        self.flush_to(offset);
2411        self.skip_to(offset + expected_output.len());
2412
2413        // Attempt to install the result separator (----), if it does
2414        // not already exist.
2415        if self.peek_last(5) == "\n----" {
2416            self.append("\n");
2417        } else if self.peek_last(6) != "\n----\n" {
2418            self.append("\n----\n");
2419        }
2420
2421        let Some(names) = column_names else {
2422            return;
2423        };
2424        self.append(
2425            &names
2426                .iter()
2427                .map(|name| name.replace(' ', "␠"))
2428                .collect::<Vec<_>>()
2429                .join(" "),
2430        );
2431        self.append("\n");
2432    }
2433
2434    fn rewrite_expected_error(
2435        &mut self,
2436        input: &String,
2437        old_err: &str,
2438        new_err: &str,
2439        query: &str,
2440    ) {
2441        // Output everything before this error message.
2442        // TODO(benesch): is it possible to rewrite this to avoid `as`?
2443        #[allow(clippy::as_conversions)]
2444        let err_offset = old_err.as_ptr() as usize - input.as_ptr() as usize;
2445        self.flush_to(err_offset);
2446        self.append(new_err);
2447        self.append("\n");
2448        self.append(query);
2449        // TODO(benesch): is it possible to rewrite this to avoid `as`?
2450        #[allow(clippy::as_conversions)]
2451        self.skip_to(query.as_ptr() as usize - input.as_ptr() as usize + query.len())
2452    }
2453
2454    fn peek_last(&self, n: usize) -> &str {
2455        &self.output[self.output.len() - n..]
2456    }
2457
2458    fn finish(mut self) -> String {
2459        self.flush_to(self.input.len());
2460        self.output
2461    }
2462}
2463
2464/// Generates view creation, view indexing, view querying, and view
2465/// dropping SQL commands for a given `SELECT` query. If the number
2466/// of attributes produced by the query is known, the view commands
2467/// are specialized to avoid issues with column ambiguity. This
2468/// function is a helper for `--auto_index_selects` and assumes that
2469/// the provided input SQL has already been run through the parser,
2470/// resulting in a valid `SELECT` statement.
2471fn generate_view_sql(
2472    sql: &str,
2473    view_uuid: &Simple,
2474    num_attributes: Option<usize>,
2475    expected_column_names: Option<Vec<ColumnName>>,
2476) -> (String, String, String, String) {
2477    // To create the view, re-parse the sql; note that we must find exactly
2478    // one statement and it must be a `SELECT`.
2479    // NOTE(vmarcos): Direct string manipulation was attempted while
2480    // prototyping the code below, which avoids the extra parsing and
2481    // data structure cloning. However, running DDL is so slow that
2482    // it did not matter in terms of runtime. We can revisit this if
2483    // DDL cost drops dramatically in the future.
2484    let stmts = parser::parse_statements(sql).unwrap_or_default();
2485    assert!(stmts.len() == 1);
2486    let (query, query_as_of) = match &stmts[0].ast {
2487        Statement::Select(stmt) => (&stmt.query, &stmt.as_of),
2488        _ => unreachable!("This function should only be called for SELECTs"),
2489    };
2490
2491    // Prior to creating the view, process the `ORDER BY` clause of
2492    // the `SELECT` query, if any. Ordering is not preserved when a
2493    // view includes an `ORDER BY` clause and must be re-enforced by
2494    // an external `ORDER BY` clause when querying the view.
2495    let (view_order_by, extra_columns, distinct) = if num_attributes.is_none() {
2496        (query.order_by.clone(), vec![], None)
2497    } else {
2498        derive_order_by(&query.body, &query.order_by)
2499    };
2500
2501    // Since one-shot SELECT statements may contain ambiguous column names,
2502    // we either use the expected column names, if that option was
2503    // provided, or else just rename the output schema of the view
2504    // using numerically increasing attribute names, whenever possible.
2505    // This strategy makes it possible to use `CREATE INDEX`, thus
2506    // matching the behavior of the option `auto_index_tables`. However,
2507    // we may be presented with a `SELECT *` query, in which case the parser
2508    // does not produce sufficient information to allow us to compute
2509    // the number of output columns. In the latter case, we are supplied
2510    // with `None` for `num_attributes` and just employ the command
2511    // `CREATE DEFAULT INDEX` instead. Additionally, the view is created
2512    // without schema renaming. This strategy is insufficient to dodge
2513    // column name ambiguity in all cases, but we assume here that we
2514    // can adjust the (hopefully) small number of tests that eventually
2515    // challenge us in this particular way.
2516    let name = UnresolvedItemName(vec![Ident::new_unchecked(format!("v{}", view_uuid))]);
2517    let projection = expected_column_names.map_or_else(
2518        || {
2519            num_attributes.map_or(vec![], |n| {
2520                (1..=n)
2521                    .map(|i| Ident::new_unchecked(format!("a{i}")))
2522                    .collect()
2523            })
2524        },
2525        |cols| {
2526            cols.iter()
2527                .map(|c| Ident::new_unchecked(c.as_str()))
2528                .collect()
2529        },
2530    );
2531    let columns: Vec<Ident> = projection
2532        .iter()
2533        .cloned()
2534        .chain(extra_columns.iter().map(|item| {
2535            if let SelectItem::Expr {
2536                expr: _,
2537                alias: Some(ident),
2538            } = item
2539            {
2540                ident.clone()
2541            } else {
2542                unreachable!("alias must be given for extra column")
2543            }
2544        }))
2545        .collect();
2546
2547    // Build a `CREATE VIEW` with the columns computed above.
2548    let mut query = query.clone();
2549    if extra_columns.len() > 0 {
2550        match &mut query.body {
2551            SetExpr::Select(stmt) => stmt.projection.extend(extra_columns.iter().cloned()),
2552            _ => unimplemented!("cannot yet rewrite projections of nested queries"),
2553        }
2554    }
2555    let create_view = AstStatement::<Raw>::CreateView(CreateViewStatement {
2556        if_exists: IfExistsBehavior::Error,
2557        temporary: false,
2558        definition: ViewDefinition {
2559            name: name.clone(),
2560            columns: columns.clone(),
2561            query,
2562        },
2563    })
2564    .to_ast_string_stable();
2565
2566    // We then create either a `CREATE INDEX` or a `CREATE DEFAULT INDEX`
2567    // statement, depending on whether we could obtain the number of
2568    // attributes from the original `SELECT`.
2569    let create_index = AstStatement::<Raw>::CreateIndex(CreateIndexStatement {
2570        name: None,
2571        in_cluster: None,
2572        on_name: RawItemName::Name(name.clone()),
2573        key_parts: if columns.len() == 0 {
2574            None
2575        } else {
2576            Some(
2577                columns
2578                    .iter()
2579                    .map(|ident| Expr::Identifier(vec![ident.clone()]))
2580                    .collect(),
2581            )
2582        },
2583        with_options: Vec::new(),
2584        if_not_exists: false,
2585    })
2586    .to_ast_string_stable();
2587
2588    // Assert if DISTINCT semantics are unchanged from view
2589    let distinct_unneeded = extra_columns.len() == 0
2590        || match distinct {
2591            None | Some(Distinct::On(_)) => true,
2592            Some(Distinct::EntireRow) => false,
2593        };
2594    let distinct = if distinct_unneeded { None } else { distinct };
2595
2596    // `SELECT [* | {projection}] FROM {name} [ORDER BY {view_order_by}]`
2597    let view_sql = AstStatement::<Raw>::Select(SelectStatement {
2598        query: Query {
2599            ctes: CteBlock::Simple(vec![]),
2600            body: SetExpr::Select(Box::new(Select {
2601                distinct,
2602                projection: if projection.len() == 0 {
2603                    vec![SelectItem::Wildcard]
2604                } else {
2605                    projection
2606                        .iter()
2607                        .map(|ident| SelectItem::Expr {
2608                            expr: Expr::Identifier(vec![ident.clone()]),
2609                            alias: None,
2610                        })
2611                        .collect()
2612                },
2613                from: vec![TableWithJoins {
2614                    relation: TableFactor::Table {
2615                        name: RawItemName::Name(name.clone()),
2616                        alias: None,
2617                    },
2618                    joins: vec![],
2619                }],
2620                selection: None,
2621                group_by: vec![],
2622                having: None,
2623                qualify: None,
2624                options: vec![],
2625            })),
2626            order_by: view_order_by,
2627            limit: None,
2628            offset: None,
2629        },
2630        as_of: query_as_of.clone(),
2631    })
2632    .to_ast_string_stable();
2633
2634    // `DROP VIEW {name}`
2635    let drop_view = AstStatement::<Raw>::DropObjects(DropObjectsStatement {
2636        object_type: ObjectType::View,
2637        if_exists: false,
2638        names: vec![UnresolvedObjectName::Item(name)],
2639        cascade: false,
2640    })
2641    .to_ast_string_stable();
2642
2643    (create_view, create_index, view_sql, drop_view)
2644}
2645
2646/// Analyzes the provided query `body` to derive the number of
2647/// attributes in the query. We only consider syntactic cues,
2648/// so we may end up deriving `None` for the number of attributes
2649/// as a conservative approximation.
2650fn derive_num_attributes(body: &SetExpr<Raw>) -> Option<usize> {
2651    let Some((projection, _)) = find_projection(body) else {
2652        return None;
2653    };
2654    derive_num_attributes_from_projection(projection)
2655}
2656
2657/// Analyzes a query's `ORDER BY` clause to derive an `ORDER BY`
2658/// clause that makes numeric references to any expressions in
2659/// the projection and generated-attribute references to expressions
2660/// that need to be added as extra columns to the projection list.
2661/// The rewritten `ORDER BY` clause is then usable when querying a
2662/// view that contains the same `SELECT` as the given query.
2663/// This function returns both the rewritten `ORDER BY` clause
2664/// as well as a list of extra columns that need to be added
2665/// to the query's projection for the `ORDER BY` clause to
2666/// succeed.
2667fn derive_order_by(
2668    body: &SetExpr<Raw>,
2669    order_by: &Vec<OrderByExpr<Raw>>,
2670) -> (
2671    Vec<OrderByExpr<Raw>>,
2672    Vec<SelectItem<Raw>>,
2673    Option<Distinct<Raw>>,
2674) {
2675    let Some((projection, distinct)) = find_projection(body) else {
2676        return (vec![], vec![], None);
2677    };
2678    let (view_order_by, extra_columns) = derive_order_by_from_projection(projection, order_by);
2679    (view_order_by, extra_columns, distinct.clone())
2680}
2681
2682/// Finds the projection list in a `SELECT` query body.
2683fn find_projection(body: &SetExpr<Raw>) -> Option<(&Vec<SelectItem<Raw>>, &Option<Distinct<Raw>>)> {
2684    // Iterate to peel off the query body until the query's
2685    // projection list is found.
2686    let mut set_expr = body;
2687    loop {
2688        match set_expr {
2689            SetExpr::Select(select) => {
2690                return Some((&select.projection, &select.distinct));
2691            }
2692            SetExpr::SetOperation { left, .. } => set_expr = left.as_ref(),
2693            SetExpr::Query(query) => set_expr = &query.body,
2694            _ => return None,
2695        }
2696    }
2697}
2698
2699/// Computes the number of attributes that are obtained by the
2700/// projection of a `SELECT` query. The projection may include
2701/// wildcards, in which case the analysis just returns `None`.
2702fn derive_num_attributes_from_projection(projection: &Vec<SelectItem<Raw>>) -> Option<usize> {
2703    let mut num_attributes = 0usize;
2704    for item in projection.iter() {
2705        let SelectItem::Expr { expr, .. } = item else {
2706            return None;
2707        };
2708        match expr {
2709            Expr::QualifiedWildcard(..) | Expr::WildcardAccess(..) => {
2710                return None;
2711            }
2712            _ => {
2713                num_attributes += 1;
2714            }
2715        }
2716    }
2717    Some(num_attributes)
2718}
2719
2720/// Computes an `ORDER BY` clause with only numeric references
2721/// from given projection and `ORDER BY` of a `SELECT` query.
2722/// If the derivation fails to match a given expression, the
2723/// matched prefix is returned. Note that this could be empty.
2724fn derive_order_by_from_projection(
2725    projection: &Vec<SelectItem<Raw>>,
2726    order_by: &Vec<OrderByExpr<Raw>>,
2727) -> (Vec<OrderByExpr<Raw>>, Vec<SelectItem<Raw>>) {
2728    let mut view_order_by: Vec<OrderByExpr<Raw>> = vec![];
2729    let mut extra_columns: Vec<SelectItem<Raw>> = vec![];
2730    for order_by_expr in order_by.iter() {
2731        let query_expr = &order_by_expr.expr;
2732        let view_expr = match query_expr {
2733            Expr::Value(mz_sql_parser::ast::Value::Number(_)) => query_expr.clone(),
2734            _ => {
2735                // Find expression in query projection, if we can.
2736                if let Some(i) = projection.iter().position(|item| match item {
2737                    SelectItem::Expr { expr, alias } => {
2738                        expr == query_expr
2739                            || match query_expr {
2740                                Expr::Identifier(ident) => {
2741                                    ident.len() == 1 && Some(&ident[0]) == alias.as_ref()
2742                                }
2743                                _ => false,
2744                            }
2745                    }
2746                    SelectItem::Wildcard => false,
2747                }) {
2748                    Expr::Value(mz_sql_parser::ast::Value::Number((i + 1).to_string()))
2749                } else {
2750                    // If the expression is not found in the
2751                    // projection, add extra column.
2752                    let ident = Ident::new_unchecked(format!(
2753                        "a{}",
2754                        (projection.len() + extra_columns.len() + 1)
2755                    ));
2756                    extra_columns.push(SelectItem::Expr {
2757                        expr: query_expr.clone(),
2758                        alias: Some(ident.clone()),
2759                    });
2760                    Expr::Identifier(vec![ident])
2761                }
2762            }
2763        };
2764        view_order_by.push(OrderByExpr {
2765            expr: view_expr,
2766            asc: order_by_expr.asc,
2767            nulls_last: order_by_expr.nulls_last,
2768        });
2769    }
2770    (view_order_by, extra_columns)
2771}
2772
2773/// Returns extra statements to execute after `stmt` is executed.
2774fn mutate(sql: &str) -> Vec<String> {
2775    let stmts = parser::parse_statements(sql).unwrap_or_default();
2776    let mut additional = Vec::new();
2777    for stmt in stmts {
2778        match stmt.ast {
2779            AstStatement::CreateTable(stmt) => additional.push(
2780                // CREATE TABLE -> CREATE INDEX. Specify all columns manually in case CREATE
2781                // DEFAULT INDEX ever goes away.
2782                AstStatement::<Raw>::CreateIndex(CreateIndexStatement {
2783                    name: None,
2784                    in_cluster: None,
2785                    on_name: RawItemName::Name(stmt.name.clone()),
2786                    key_parts: Some(
2787                        stmt.columns
2788                            .iter()
2789                            .map(|def| Expr::Identifier(vec![def.name.clone()]))
2790                            .collect(),
2791                    ),
2792                    with_options: Vec::new(),
2793                    if_not_exists: false,
2794                })
2795                .to_ast_string_stable(),
2796            ),
2797            _ => {}
2798        }
2799    }
2800    additional
2801}
2802
2803#[mz_ore::test]
2804#[cfg_attr(miri, ignore)] // unsupported operation: can't call foreign function `rust_psm_stack_pointer` on OS `linux`
2805fn test_generate_view_sql() {
2806    let uuid = Uuid::parse_str("67e5504410b1426f9247bb680e5fe0c8").unwrap();
2807    let cases = vec![
2808        (("SELECT * FROM t", None, None),
2809        (
2810            r#"CREATE VIEW "v67e5504410b1426f9247bb680e5fe0c8" AS SELECT * FROM "t""#.to_string(),
2811            r#"CREATE DEFAULT INDEX ON "v67e5504410b1426f9247bb680e5fe0c8""#.to_string(),
2812            r#"SELECT * FROM "v67e5504410b1426f9247bb680e5fe0c8""#.to_string(),
2813            r#"DROP VIEW "v67e5504410b1426f9247bb680e5fe0c8""#.to_string(),
2814        )),
2815        (("SELECT a, b, c FROM t1, t2", Some(3), Some(vec![ColumnName::from("a"), ColumnName::from("b"), ColumnName::from("c")])),
2816        (
2817            r#"CREATE VIEW "v67e5504410b1426f9247bb680e5fe0c8" ("a", "b", "c") AS SELECT "a", "b", "c" FROM "t1", "t2""#.to_string(),
2818            r#"CREATE INDEX ON "v67e5504410b1426f9247bb680e5fe0c8" ("a", "b", "c")"#.to_string(),
2819            r#"SELECT "a", "b", "c" FROM "v67e5504410b1426f9247bb680e5fe0c8""#.to_string(),
2820            r#"DROP VIEW "v67e5504410b1426f9247bb680e5fe0c8""#.to_string(),
2821        )),
2822        (("SELECT a, b, c FROM t1, t2", Some(3), None),
2823        (
2824            r#"CREATE VIEW "v67e5504410b1426f9247bb680e5fe0c8" ("a1", "a2", "a3") AS SELECT "a", "b", "c" FROM "t1", "t2""#.to_string(),
2825            r#"CREATE INDEX ON "v67e5504410b1426f9247bb680e5fe0c8" ("a1", "a2", "a3")"#.to_string(),
2826            r#"SELECT "a1", "a2", "a3" FROM "v67e5504410b1426f9247bb680e5fe0c8""#.to_string(),
2827            r#"DROP VIEW "v67e5504410b1426f9247bb680e5fe0c8""#.to_string(),
2828        )),
2829        // A case with ambiguity that is accepted by the function, illustrating that
2830        // our measures to dodge this issue are imperfect.
2831        (("SELECT * FROM (SELECT a, sum(b) AS a FROM t GROUP BY a)", None, None),
2832        (
2833            r#"CREATE VIEW "v67e5504410b1426f9247bb680e5fe0c8" AS SELECT * FROM (SELECT "a", "sum"("b") AS "a" FROM "t" GROUP BY "a")"#.to_string(),
2834            r#"CREATE DEFAULT INDEX ON "v67e5504410b1426f9247bb680e5fe0c8""#.to_string(),
2835            r#"SELECT * FROM "v67e5504410b1426f9247bb680e5fe0c8""#.to_string(),
2836            r#"DROP VIEW "v67e5504410b1426f9247bb680e5fe0c8""#.to_string(),
2837        )),
2838        (("SELECT a, b, b + d AS c, a + b AS d FROM t1, t2 ORDER BY a, c, a + b", Some(4), Some(vec![ColumnName::from("a"), ColumnName::from("b"), ColumnName::from("c"), ColumnName::from("d")])),
2839        (
2840            r#"CREATE VIEW "v67e5504410b1426f9247bb680e5fe0c8" ("a", "b", "c", "d") AS SELECT "a", "b", "b" + "d" AS "c", "a" + "b" AS "d" FROM "t1", "t2" ORDER BY "a", "c", "a" + "b""#.to_string(),
2841            r#"CREATE INDEX ON "v67e5504410b1426f9247bb680e5fe0c8" ("a", "b", "c", "d")"#.to_string(),
2842            r#"SELECT "a", "b", "c", "d" FROM "v67e5504410b1426f9247bb680e5fe0c8" ORDER BY 1, 3, 4"#.to_string(),
2843            r#"DROP VIEW "v67e5504410b1426f9247bb680e5fe0c8""#.to_string(),
2844        )),
2845        (("((SELECT 1 AS a UNION SELECT 2 AS b) UNION SELECT 3 AS c) ORDER BY a", Some(1), None),
2846        (
2847            r#"CREATE VIEW "v67e5504410b1426f9247bb680e5fe0c8" ("a1") AS (SELECT 1 AS "a" UNION SELECT 2 AS "b") UNION SELECT 3 AS "c" ORDER BY "a""#.to_string(),
2848            r#"CREATE INDEX ON "v67e5504410b1426f9247bb680e5fe0c8" ("a1")"#.to_string(),
2849            r#"SELECT "a1" FROM "v67e5504410b1426f9247bb680e5fe0c8" ORDER BY 1"#.to_string(),
2850            r#"DROP VIEW "v67e5504410b1426f9247bb680e5fe0c8""#.to_string(),
2851        )),
2852        (("SELECT * FROM (SELECT a, sum(b) AS a FROM t GROUP BY a) ORDER BY 1", None, None),
2853        (
2854            r#"CREATE VIEW "v67e5504410b1426f9247bb680e5fe0c8" AS SELECT * FROM (SELECT "a", "sum"("b") AS "a" FROM "t" GROUP BY "a") ORDER BY 1"#.to_string(),
2855            r#"CREATE DEFAULT INDEX ON "v67e5504410b1426f9247bb680e5fe0c8""#.to_string(),
2856            r#"SELECT * FROM "v67e5504410b1426f9247bb680e5fe0c8" ORDER BY 1"#.to_string(),
2857            r#"DROP VIEW "v67e5504410b1426f9247bb680e5fe0c8""#.to_string(),
2858        )),
2859        (("SELECT * FROM (SELECT a, sum(b) AS a FROM t GROUP BY a) ORDER BY a", None, None),
2860        (
2861            r#"CREATE VIEW "v67e5504410b1426f9247bb680e5fe0c8" AS SELECT * FROM (SELECT "a", "sum"("b") AS "a" FROM "t" GROUP BY "a") ORDER BY "a""#.to_string(),
2862            r#"CREATE DEFAULT INDEX ON "v67e5504410b1426f9247bb680e5fe0c8""#.to_string(),
2863            r#"SELECT * FROM "v67e5504410b1426f9247bb680e5fe0c8" ORDER BY "a""#.to_string(),
2864            r#"DROP VIEW "v67e5504410b1426f9247bb680e5fe0c8""#.to_string(),
2865        )),
2866        (("SELECT a, sum(b) AS a FROM t GROUP BY a, c ORDER BY a, c", Some(2), None),
2867        (
2868            r#"CREATE VIEW "v67e5504410b1426f9247bb680e5fe0c8" ("a1", "a2", "a3") AS SELECT "a", "sum"("b") AS "a", "c" AS "a3" FROM "t" GROUP BY "a", "c" ORDER BY "a", "c""#.to_string(),
2869            r#"CREATE INDEX ON "v67e5504410b1426f9247bb680e5fe0c8" ("a1", "a2", "a3")"#.to_string(),
2870            r#"SELECT "a1", "a2" FROM "v67e5504410b1426f9247bb680e5fe0c8" ORDER BY 1, "a3""#.to_string(),
2871            r#"DROP VIEW "v67e5504410b1426f9247bb680e5fe0c8""#.to_string(),
2872        )),
2873        (("SELECT a, sum(b) AS a FROM t GROUP BY a, c ORDER BY c, a", Some(2), None),
2874        (
2875            r#"CREATE VIEW "v67e5504410b1426f9247bb680e5fe0c8" ("a1", "a2", "a3") AS SELECT "a", "sum"("b") AS "a", "c" AS "a3" FROM "t" GROUP BY "a", "c" ORDER BY "c", "a""#.to_string(),
2876            r#"CREATE INDEX ON "v67e5504410b1426f9247bb680e5fe0c8" ("a1", "a2", "a3")"#.to_string(),
2877            r#"SELECT "a1", "a2" FROM "v67e5504410b1426f9247bb680e5fe0c8" ORDER BY "a3", 1"#.to_string(),
2878            r#"DROP VIEW "v67e5504410b1426f9247bb680e5fe0c8""#.to_string(),
2879        )),
2880    ];
2881    for ((sql, num_attributes, expected_column_names), expected) in cases {
2882        let view_sql =
2883            generate_view_sql(sql, uuid.as_simple(), num_attributes, expected_column_names);
2884        assert_eq!(expected, view_sql);
2885    }
2886}
2887
2888#[mz_ore::test]
2889fn test_mutate() {
2890    let cases = vec![
2891        ("CREATE TABLE t ()", vec![r#"CREATE INDEX ON "t" ()"#]),
2892        (
2893            "CREATE TABLE t (a INT)",
2894            vec![r#"CREATE INDEX ON "t" ("a")"#],
2895        ),
2896        (
2897            "CREATE TABLE t (a INT, b TEXT)",
2898            vec![r#"CREATE INDEX ON "t" ("a", "b")"#],
2899        ),
2900        // Invalid syntax works, just returns nothing.
2901        ("BAD SYNTAX", Vec::new()),
2902    ];
2903    for (sql, expected) in cases {
2904        let stmts = mutate(sql);
2905        assert_eq!(expected, stmts, "sql: {sql}");
2906    }
2907}