Skip to main content

mz_sqllogictest/
runner.rs

1// Copyright Materialize, Inc. and contributors. All rights reserved.
2//
3// Use of this software is governed by the Business Source License
4// included in the LICENSE file.
5//
6// As of the Change Date specified in that file, in accordance with
7// the Business Source License, use of this software will be governed
8// by the Apache License, Version 2.0.
9
10//! The Materialize-specific runner for sqllogictest.
11//!
12//! slt tests expect a serialized execution of sql statements and queries.
13//! To get the same results in materialize we track current_timestamp and increment it whenever we execute a statement.
14//!
15//! The high-level workflow is:
16//!   for each record in the test file:
17//!     if record is a sql statement:
18//!       run sql in postgres, observe changes and copy them to materialize using LocalInput::Updates(..)
19//!       advance current_timestamp
20//!       promise to never send updates for times < current_timestamp using LocalInput::Watermark(..)
21//!       compare to expected results
22//!       if wrong, bail out and stop processing this file
23//!     if record is a sql query:
24//!       peek query at current_timestamp
25//!       compare to expected results
26//!       if wrong, record the error
27
28use std::collections::BTreeMap;
29use std::error::Error;
30use std::fs::{File, OpenOptions};
31use std::io::{Read, Seek, SeekFrom, Write};
32use std::net::{IpAddr, Ipv4Addr, SocketAddr};
33use std::path::Path;
34use std::sync::Arc;
35use std::sync::LazyLock;
36use std::time::Duration;
37use std::{env, fmt, ops, str, thread};
38
39use anyhow::{anyhow, bail};
40use bytes::BytesMut;
41use chrono::{DateTime, NaiveDateTime, NaiveTime, Utc};
42use fallible_iterator::FallibleIterator;
43use futures::sink::SinkExt;
44use itertools::Itertools;
45use maplit::btreemap;
46use md5::{Digest, Md5};
47use mz_adapter_types::bootstrap_builtin_cluster_config::{
48    ANALYTICS_CLUSTER_DEFAULT_REPLICATION_FACTOR, BootstrapBuiltinClusterConfig,
49    CATALOG_SERVER_CLUSTER_DEFAULT_REPLICATION_FACTOR, PROBE_CLUSTER_DEFAULT_REPLICATION_FACTOR,
50    SUPPORT_CLUSTER_DEFAULT_REPLICATION_FACTOR, SYSTEM_CLUSTER_DEFAULT_REPLICATION_FACTOR,
51};
52use mz_catalog::config::ClusterReplicaSizeMap;
53use mz_controller::{ControllerConfig, ReplicaHttpLocator};
54use mz_environmentd::CatalogConfig;
55use mz_license_keys::ValidatedLicenseKey;
56use mz_orchestrator_process::{ProcessOrchestrator, ProcessOrchestratorConfig};
57use mz_orchestrator_tracing::{TracingCliArgs, TracingOrchestrator};
58use mz_ore::cast::{CastFrom, ReinterpretCast};
59use mz_ore::channel::trigger;
60use mz_ore::error::ErrorExt;
61use mz_ore::metrics::MetricsRegistry;
62use mz_ore::now::SYSTEM_TIME;
63use mz_ore::retry::Retry;
64use mz_ore::task;
65use mz_ore::thread::{JoinHandleExt, JoinOnDropHandle};
66use mz_ore::tracing::TracingHandle;
67use mz_ore::url::SensitiveUrl;
68use mz_persist_client::PersistLocation;
69use mz_persist_client::cache::PersistClientCache;
70use mz_persist_client::cfg::PersistConfig;
71use mz_persist_client::rpc::{
72    MetricsSameProcessPubSubSender, PersistGrpcPubSubServer, PubSubClientConnection, PubSubSender,
73};
74use mz_pgrepr::{Interval, Jsonb, Numeric, UInt2, UInt4, UInt8, Value, oid};
75use mz_repr::ColumnName;
76use mz_repr::adt::date::Date;
77use mz_repr::adt::mz_acl_item::{AclItem, MzAclItem};
78use mz_repr::adt::numeric;
79use mz_secrets::SecretsController;
80use mz_server_core::listeners::{
81    AllowedRoles, AuthenticatorKind, BaseListenerConfig, HttpListenerConfig, HttpRoutesEnabled,
82    ListenersConfig, SqlListenerConfig,
83};
84use mz_sql::ast::{Expr, Raw, Statement};
85use mz_sql::catalog::EnvironmentId;
86use mz_sql_parser::ast::display::AstDisplay;
87use mz_sql_parser::ast::{
88    CreateIndexStatement, CreateViewStatement, CteBlock, Distinct, DropObjectsStatement, Ident,
89    IfExistsBehavior, ObjectType, OrderByExpr, Query, RawItemName, Select, SelectItem,
90    SelectStatement, SetExpr, Statement as AstStatement, TableFactor, TableWithJoins,
91    UnresolvedItemName, UnresolvedObjectName, ViewDefinition,
92};
93use mz_sql_parser::parser;
94use mz_storage_types::connections::ConnectionContext;
95use postgres_protocol::types;
96use regex::Regex;
97use tempfile::TempDir;
98use tokio::net::TcpListener;
99use tokio::runtime::Runtime;
100use tokio::sync::oneshot;
101use tokio_postgres::types::{FromSql, Kind as PgKind, Type as PgType};
102use tokio_postgres::{NoTls, Row, SimpleQueryMessage};
103use tokio_stream::wrappers::TcpListenerStream;
104use tower_http::cors::AllowOrigin;
105use tracing::{error, info};
106use uuid::Uuid;
107use uuid::fmt::Simple;
108
109use crate::ast::{Location, Mode, Output, QueryOutput, Record, Sort, Type};
110use crate::util;
111
/// The result of running a single sqllogictest record.
///
/// Every variant except [`Outcome::Success`] carries the [`Location`] it
/// refers to. The `'a` lifetime ties borrowed expectations (expected error
/// text, column names, output) to wherever those expectations are stored —
/// presumably the parsed record; confirm against the caller.
#[derive(Debug)]
pub enum Outcome<'a> {
    /// The record failed with `error`, and the runner categorized the failure
    /// as "unsupported". The classification itself happens outside this type.
    Unsupported {
        error: anyhow::Error,
        location: Location,
    },
    /// The record failed with `error`, categorized as a parse failure.
    ParseFailure {
        error: anyhow::Error,
        location: Location,
    },
    /// The record failed with `error`, categorized as a plan failure.
    PlanFailure {
        error: anyhow::Error,
        location: Location,
    },
    /// The record expected an error matching `expected_error`, but (judging
    /// by the variant name) none was produced.
    UnexpectedPlanSuccess {
        expected_error: &'a str,
        location: Location,
    },
    /// The reported row count differs from the expected one.
    WrongNumberOfRowsInserted {
        expected_count: u64,
        actual_count: u64,
        location: Location,
    },
    /// The number of result columns differs from the expected one.
    WrongColumnCount {
        expected_count: usize,
        actual_count: usize,
        location: Location,
    },
    /// The result column names differ from the expected ones. The formatted
    /// output is kept alongside but is not part of the `Display` rendering.
    WrongColumnNames {
        expected_column_names: &'a Vec<ColumnName>,
        actual_column_names: Vec<ColumnName>,
        actual_output: Output,
        location: Location,
    },
    /// The formatted result differs from the expected output. Both the
    /// formatted (`actual_output`) and the raw rows (`actual_raw_output`) are
    /// retained for diagnostics.
    OutputFailure {
        expected_output: &'a Output,
        actual_raw_output: Vec<Row>,
        actual_output: Output,
        location: Location,
    },
    /// Two outcomes for the same record disagree; rendered as "expected from
    /// query" versus "actually from indexed view".
    InconsistentViewOutcome {
        query_outcome: Box<Outcome<'a>>,
        view_outcome: Box<Outcome<'a>>,
        location: Location,
    },
    /// Wraps the outcome that caused processing to bail out.
    /// NOTE(review): presumably stops processing of the current file, as the
    /// module docs describe — confirm at the use site.
    Bail {
        cause: Box<Outcome<'a>>,
        location: Location,
    },
    /// Wraps an outcome that is reported but not treated as a failure (see
    /// `Outcome::failure`).
    Warning {
        cause: Box<Outcome<'a>>,
        location: Location,
    },
    /// The record behaved as expected.
    Success,
}
167
/// Number of `Outcome` variants; this is the length of the counter array in
/// `Outcomes`.
const NUM_OUTCOMES: usize = 12;
/// Counter index for warnings (10, the value `Outcome::code` returns for
/// `Outcome::Warning`).
const WARNING_OUTCOME: usize = NUM_OUTCOMES - 2;
/// Counter index for successes (11, the value `Outcome::code` returns for
/// `Outcome::Success`).
const SUCCESS_OUTCOME: usize = NUM_OUTCOMES - 1;
171
172impl<'a> Outcome<'a> {
173    fn code(&self) -> usize {
174        match self {
175            Outcome::Unsupported { .. } => 0,
176            Outcome::ParseFailure { .. } => 1,
177            Outcome::PlanFailure { .. } => 2,
178            Outcome::UnexpectedPlanSuccess { .. } => 3,
179            Outcome::WrongNumberOfRowsInserted { .. } => 4,
180            Outcome::WrongColumnCount { .. } => 5,
181            Outcome::WrongColumnNames { .. } => 6,
182            Outcome::OutputFailure { .. } => 7,
183            Outcome::InconsistentViewOutcome { .. } => 8,
184            Outcome::Bail { .. } => 9,
185            Outcome::Warning { .. } => 10,
186            Outcome::Success => 11,
187        }
188    }
189
190    fn success(&self) -> bool {
191        matches!(self, Outcome::Success)
192    }
193
194    fn failure(&self) -> bool {
195        !matches!(self, Outcome::Success) && !matches!(self, Outcome::Warning { .. })
196    }
197
198    /// Returns an error message that will match self. Appropriate for
199    /// rewriting error messages (i.e. not inserting error messages where we
200    /// currently expect success).
201    fn err_msg(&self) -> Option<String> {
202        match self {
203            Outcome::Unsupported { error, .. }
204            | Outcome::ParseFailure { error, .. }
205            | Outcome::PlanFailure { error, .. } => {
206                // Take only the first line, which should be sufficient for
207                // meaningfully matching the error.
208                let err_str = error.to_string_with_causes();
209                let err_str = err_str.split('\n').next().unwrap();
210                // Strip the "db error: ERROR: " prefix added by the postgres
211                // client library, as it's noisy and not useful for matching.
212                let err_str = err_str.strip_prefix("db error: ERROR: ").unwrap_or(err_str);
213                // This value gets fed back into regex to check that it matches
214                // `self`, so escape its meta characters. We need to undo the
215                // escaping of #. `regex::escape` escapes this because it
216                // expects that we use the `x` flag when building a regex, but
217                // this is not the case, so \# would end up being an invalid
218                // escape sequence, which would choke the parsing of the slt
219                // file the next time around.
220                Some(regex::escape(err_str).replace(r"\#", "#"))
221            }
222            _ => None,
223        }
224    }
225}
226
227impl fmt::Display for Outcome<'_> {
228    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
229        use Outcome::*;
230        const INDENT: &str = "\n        ";
231        match self {
232            Unsupported { error, location } => write!(
233                f,
234                "Unsupported:{}:\n{}",
235                location,
236                error.display_with_causes()
237            ),
238            ParseFailure { error, location } => {
239                write!(
240                    f,
241                    "ParseFailure:{}:\n{}",
242                    location,
243                    error.display_with_causes()
244                )
245            }
246            PlanFailure { error, location } => write!(f, "PlanFailure:{}:\n{:#}", location, error),
247            UnexpectedPlanSuccess {
248                expected_error,
249                location,
250            } => write!(
251                f,
252                "UnexpectedPlanSuccess:{} expected error: {}",
253                location, expected_error
254            ),
255            WrongNumberOfRowsInserted {
256                expected_count,
257                actual_count,
258                location,
259            } => write!(
260                f,
261                "WrongNumberOfRowsInserted:{}{}expected: {}{}actually: {}",
262                location, INDENT, expected_count, INDENT, actual_count
263            ),
264            WrongColumnCount {
265                expected_count,
266                actual_count,
267                location,
268            } => write!(
269                f,
270                "WrongColumnCount:{}{}expected: {}{}actually: {}",
271                location, INDENT, expected_count, INDENT, actual_count
272            ),
273            WrongColumnNames {
274                expected_column_names,
275                actual_column_names,
276                actual_output: _,
277                location,
278            } => write!(
279                f,
280                "Wrong Column Names:{}:{}expected column names: {}{}inferred column names: {}",
281                location,
282                INDENT,
283                expected_column_names
284                    .iter()
285                    .map(|n| n.to_string())
286                    .collect::<Vec<_>>()
287                    .join(" "),
288                INDENT,
289                actual_column_names
290                    .iter()
291                    .map(|n| n.to_string())
292                    .collect::<Vec<_>>()
293                    .join(" ")
294            ),
295            OutputFailure {
296                expected_output,
297                actual_raw_output,
298                actual_output,
299                location,
300            } => write!(
301                f,
302                "OutputFailure:{}{}expected: {:?}{}actually: {:?}{}actual raw: {:?}",
303                location, INDENT, expected_output, INDENT, actual_output, INDENT, actual_raw_output
304            ),
305            InconsistentViewOutcome {
306                query_outcome,
307                view_outcome,
308                location,
309            } => write!(
310                f,
311                "InconsistentViewOutcome:{}{}expected from query: {}{}actually from indexed view: {}",
312                location, INDENT, query_outcome, INDENT, view_outcome
313            ),
314            Bail { cause, location } => write!(f, "Bail:{} {}", location, cause),
315            Warning { cause, location } => write!(f, "Warning:{} {}", location, cause),
316            Success => f.write_str("Success"),
317        }
318    }
319}
320
/// Aggregated results of one or more test runs.
#[derive(Default, Debug)]
pub struct Outcomes {
    /// One counter per outcome kind. The index-to-name mapping is the one
    /// spelled out in `Outcomes::as_json` (0 = unsupported … 11 = success).
    stats: [usize; NUM_OUTCOMES],
    /// Pre-rendered detail lines; `OutcomesDisplay` prints them one per line
    /// when failure details are requested.
    details: Vec<String>,
}
326
327impl ops::AddAssign<Outcomes> for Outcomes {
328    fn add_assign(&mut self, rhs: Outcomes) {
329        for (lhs, rhs) in self.stats.iter_mut().zip_eq(rhs.stats.iter()) {
330            *lhs += rhs
331        }
332    }
333}
334impl Outcomes {
335    pub fn any_failed(&self) -> bool {
336        self.stats[SUCCESS_OUTCOME] + self.stats[WARNING_OUTCOME] < self.stats.iter().sum::<usize>()
337    }
338
339    pub fn as_json(&self) -> serde_json::Value {
340        serde_json::json!({
341            "unsupported": self.stats[0],
342            "parse_failure": self.stats[1],
343            "plan_failure": self.stats[2],
344            "unexpected_plan_success": self.stats[3],
345            "wrong_number_of_rows_affected": self.stats[4],
346            "wrong_column_count": self.stats[5],
347            "wrong_column_names": self.stats[6],
348            "output_failure": self.stats[7],
349            "inconsistent_view_outcome": self.stats[8],
350            "bail": self.stats[9],
351            "warning": self.stats[10],
352            "success": self.stats[11],
353        })
354    }
355
356    pub fn display(&self, no_fail: bool, failure_details: bool) -> OutcomesDisplay<'_> {
357        OutcomesDisplay {
358            inner: self,
359            no_fail,
360            failure_details,
361        }
362    }
363}
364
/// `Display` adapter for [`Outcomes`], created by `Outcomes::display`.
pub struct OutcomesDisplay<'a> {
    /// The outcomes to render.
    inner: &'a Outcomes,
    /// When set, a failing summary is labelled `FAIL-IGNORE` instead of
    /// `FAIL`, and detail lines are printed even if nothing failed.
    no_fail: bool,
    /// When set, the stored detail lines are printed instead of the summary
    /// line (provided something failed or `no_fail` is set).
    failure_details: bool,
}
370
371impl<'a> fmt::Display for OutcomesDisplay<'a> {
372    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
373        let total: usize = self.inner.stats.iter().sum();
374        if self.failure_details
375            && (self.inner.stats[SUCCESS_OUTCOME] + self.inner.stats[WARNING_OUTCOME] != total
376                || self.no_fail)
377        {
378            for outcome in &self.inner.details {
379                writeln!(f, "{}", outcome)?;
380            }
381            Ok(())
382        } else {
383            write!(
384                f,
385                "{}:",
386                if self.inner.stats[SUCCESS_OUTCOME] + self.inner.stats[WARNING_OUTCOME] == total {
387                    "PASS"
388                } else if self.no_fail {
389                    "FAIL-IGNORE"
390                } else {
391                    "FAIL"
392                }
393            )?;
394            static NAMES: LazyLock<Vec<&'static str>> = LazyLock::new(|| {
395                vec![
396                    "unsupported",
397                    "parse-failure",
398                    "plan-failure",
399                    "unexpected-plan-success",
400                    "wrong-number-of-rows-inserted",
401                    "wrong-column-count",
402                    "wrong-column-names",
403                    "output-failure",
404                    "inconsistent-view-outcome",
405                    "bail",
406                    "warning",
407                    "success",
408                    "total",
409                ]
410            });
411            for (i, n) in self.inner.stats.iter().enumerate() {
412                if *n > 0 {
413                    write!(f, " {}={}", NAMES[i], n)?;
414                }
415            }
416            write!(f, " total={}", total)
417        }
418    }
419}
420
/// Metadata about a prepared query, carried by
/// `PrepareQueryOutcome::QueryPrepared`.
struct QueryInfo {
    /// Whether the query is a `SELECT` (as determined by whoever builds this
    /// value; the check itself is outside this view).
    is_select: bool,
    /// Number of attributes of the query, when known.
    /// NOTE(review): presumably the number of output columns — confirm at the
    /// construction site.
    num_attributes: Option<usize>,
}
425
/// Result of preparing a query: either the query's metadata or an
/// [`Outcome`] to report instead.
enum PrepareQueryOutcome<'a> {
    /// Preparation produced metadata for the query.
    QueryPrepared(QueryInfo),
    /// Preparation ended with an outcome.
    /// NOTE(review): presumably the final outcome of the record, so no
    /// further execution is needed — confirm at the use site.
    Outcome(Outcome<'a>),
}
430
/// Front end for running sqllogictest records against a server instance.
///
/// The actual state lives in [`RunnerInner`]; `Runner::reset` drops the
/// current inner runner and starts a fresh one from the same configuration.
pub struct Runner<'a> {
    /// Configuration used for every (re)start of the inner runner.
    config: &'a RunConfig<'a>,
    /// The live inner runner. It is `None` only transiently, while `reset`
    /// swaps instances; the other methods panic with "RunnerInner missing"
    /// if it is absent.
    inner: Option<RunnerInner<'a>>,
}
435
/// State of one running server instance together with the client connections
/// used to drive it.
///
/// Rust drops struct fields in declaration order, so the client connections
/// are dropped before the shutdown trigger, the server thread handle, and the
/// temporary directory at the end of the struct.
pub struct RunnerInner<'a> {
    // Addresses of the server's listeners.
    // NOTE(review): the role of each listener is inferred from the field
    // names only — confirm where they are bound.
    server_addr: SocketAddr,
    internal_server_addr: SocketAddr,
    password_server_addr: SocketAddr,
    internal_http_server_addr: SocketAddr,
    // Drop order matters for these fields.
    // Default client connection.
    client: tokio_postgres::Client,
    // Connection used for system-level statements such as `ALTER SYSTEM` and
    // dropping or creating databases and clusters.
    system_client: tokio_postgres::Client,
    // Additional client connections, keyed by a string name.
    // NOTE(review): presumably the connection/user name — confirm.
    clients: BTreeMap<String, tokio_postgres::Client>,
    // Behaviour switches.
    // NOTE(review): presumably copied from the run configuration — confirm
    // in `RunnerInner::start`.
    auto_index_tables: bool,
    auto_index_selects: bool,
    auto_transactions: bool,
    enable_table_keys: bool,
    verbose: bool,
    // Sink for the runner's textual output.
    stdout: &'a dyn WriteFmt,
    // Held only for their drop behaviour; never read.
    // NOTE(review): presumably dropping the trigger shuts the server down and
    // dropping the handle joins its thread — confirm.
    _shutdown_trigger: trigger::Trigger,
    _server_thread: JoinOnDropHandle<()>,
    _temp_dir: TempDir,
}
455
/// Newtype around [`Value`] that implements `FromSql`, so result rows can be
/// decoded straight into `Value`s for formatting.
#[derive(Debug)]
pub struct Slt(Value);
458
459impl<'a> FromSql<'a> for Slt {
460    fn from_sql(
461        ty: &PgType,
462        mut raw: &'a [u8],
463    ) -> Result<Self, Box<dyn Error + 'static + Send + Sync>> {
464        Ok(match *ty {
465            PgType::ACLITEM => Self(Value::AclItem(AclItem::decode_binary(
466                types::bytea_from_sql(raw),
467            )?)),
468            PgType::BOOL => Self(Value::Bool(types::bool_from_sql(raw)?)),
469            PgType::BYTEA => Self(Value::Bytea(types::bytea_from_sql(raw).to_vec())),
470            PgType::CHAR => Self(Value::Char(u8::from_be_bytes(
471                types::char_from_sql(raw)?.to_be_bytes(),
472            ))),
473            PgType::FLOAT4 => Self(Value::Float4(types::float4_from_sql(raw)?)),
474            PgType::FLOAT8 => Self(Value::Float8(types::float8_from_sql(raw)?)),
475            PgType::DATE => Self(Value::Date(Date::from_pg_epoch(types::int4_from_sql(
476                raw,
477            )?)?)),
478            PgType::INT2 => Self(Value::Int2(types::int2_from_sql(raw)?)),
479            PgType::INT4 => Self(Value::Int4(types::int4_from_sql(raw)?)),
480            PgType::INT8 => Self(Value::Int8(types::int8_from_sql(raw)?)),
481            PgType::INTERVAL => Self(Value::Interval(Interval::from_sql(ty, raw)?)),
482            PgType::JSONB => Self(Value::Jsonb(Jsonb::from_sql(ty, raw)?)),
483            PgType::NAME => Self(Value::Name(types::text_from_sql(raw)?.to_string())),
484            PgType::NUMERIC => Self(Value::Numeric(Numeric::from_sql(ty, raw)?)),
485            PgType::OID => Self(Value::Oid(types::oid_from_sql(raw)?)),
486            PgType::REGCLASS => Self(Value::Oid(types::oid_from_sql(raw)?)),
487            PgType::REGPROC => Self(Value::Oid(types::oid_from_sql(raw)?)),
488            PgType::REGTYPE => Self(Value::Oid(types::oid_from_sql(raw)?)),
489            PgType::TEXT | PgType::BPCHAR | PgType::VARCHAR => {
490                Self(Value::Text(types::text_from_sql(raw)?.to_string()))
491            }
492            PgType::TIME => Self(Value::Time(NaiveTime::from_sql(ty, raw)?)),
493            PgType::TIMESTAMP => Self(Value::Timestamp(
494                NaiveDateTime::from_sql(ty, raw)?.try_into()?,
495            )),
496            PgType::TIMESTAMPTZ => Self(Value::TimestampTz(
497                DateTime::<Utc>::from_sql(ty, raw)?.try_into()?,
498            )),
499            PgType::UUID => Self(Value::Uuid(Uuid::from_sql(ty, raw)?)),
500            PgType::RECORD => {
501                let num_fields = read_be_i32(&mut raw)?;
502                let mut tuple = vec![];
503                for _ in 0..num_fields {
504                    let oid = u32::reinterpret_cast(read_be_i32(&mut raw)?);
505                    let typ = match PgType::from_oid(oid) {
506                        Some(typ) => typ,
507                        None => return Err("unknown oid".into()),
508                    };
509                    let v = read_value::<Option<Slt>>(&typ, &mut raw)?;
510                    tuple.push(v.map(|v| v.0));
511                }
512                Self(Value::Record(tuple))
513            }
514            PgType::INT4_RANGE
515            | PgType::INT8_RANGE
516            | PgType::DATE_RANGE
517            | PgType::NUM_RANGE
518            | PgType::TS_RANGE
519            | PgType::TSTZ_RANGE => {
520                use mz_repr::adt::range::Range;
521                let range: Range<Slt> = Range::from_sql(ty, raw)?;
522                Self(Value::Range(range.into_bounds(|b| Box::new(b.0))))
523            }
524
525            _ => match ty.kind() {
526                PgKind::Array(arr_type) => {
527                    let arr = types::array_from_sql(raw)?;
528                    let elements: Vec<Option<Value>> = arr
529                        .values()
530                        .map(|v| match v {
531                            Some(v) => Ok(Some(Slt::from_sql(arr_type, v)?)),
532                            None => Ok(None),
533                        })
534                        .collect::<Vec<Option<Slt>>>()?
535                        .into_iter()
536                        // Map a Vec<Option<Slt>> to Vec<Option<Value>>.
537                        .map(|v| v.map(|v| v.0))
538                        .collect();
539
540                    Self(Value::Array {
541                        dims: arr
542                            .dimensions()
543                            .map(|d| {
544                                Ok(mz_repr::adt::array::ArrayDimension {
545                                    lower_bound: isize::cast_from(d.lower_bound),
546                                    length: usize::try_from(d.len)
547                                        .expect("cannot have negative length"),
548                                })
549                            })
550                            .collect()?,
551                        elements,
552                    })
553                }
554                _ => match ty.oid() {
555                    oid::TYPE_UINT2_OID => Self(Value::UInt2(UInt2::from_sql(ty, raw)?)),
556                    oid::TYPE_UINT4_OID => Self(Value::UInt4(UInt4::from_sql(ty, raw)?)),
557                    oid::TYPE_UINT8_OID => Self(Value::UInt8(UInt8::from_sql(ty, raw)?)),
558                    oid::TYPE_MZ_TIMESTAMP_OID => {
559                        let s = types::text_from_sql(raw)?;
560                        let t: mz_repr::Timestamp = s.parse()?;
561                        Self(Value::MzTimestamp(t))
562                    }
563                    oid::TYPE_MZ_ACL_ITEM_OID => Self(Value::MzAclItem(MzAclItem::decode_binary(
564                        types::bytea_from_sql(raw),
565                    )?)),
566                    _ => unreachable!(),
567                },
568            },
569        })
570    }
571    fn accepts(ty: &PgType) -> bool {
572        match ty.kind() {
573            PgKind::Array(_) | PgKind::Composite(_) => return true,
574            _ => {}
575        }
576        match ty.oid() {
577            oid::TYPE_UINT2_OID
578            | oid::TYPE_UINT4_OID
579            | oid::TYPE_UINT8_OID
580            | oid::TYPE_MZ_TIMESTAMP_OID
581            | oid::TYPE_MZ_ACL_ITEM_OID => return true,
582            _ => {}
583        }
584        matches!(
585            *ty,
586            PgType::ACLITEM
587                | PgType::BOOL
588                | PgType::BYTEA
589                | PgType::CHAR
590                | PgType::DATE
591                | PgType::FLOAT4
592                | PgType::FLOAT8
593                | PgType::INT2
594                | PgType::INT4
595                | PgType::INT8
596                | PgType::INTERVAL
597                | PgType::JSONB
598                | PgType::NAME
599                | PgType::NUMERIC
600                | PgType::OID
601                | PgType::REGCLASS
602                | PgType::REGPROC
603                | PgType::REGTYPE
604                | PgType::RECORD
605                | PgType::TEXT
606                | PgType::BPCHAR
607                | PgType::VARCHAR
608                | PgType::TIME
609                | PgType::TIMESTAMP
610                | PgType::TIMESTAMPTZ
611                | PgType::UUID
612                | PgType::INT4_RANGE
613                | PgType::INT4_RANGE_ARRAY
614                | PgType::INT8_RANGE
615                | PgType::INT8_RANGE_ARRAY
616                | PgType::DATE_RANGE
617                | PgType::DATE_RANGE_ARRAY
618                | PgType::NUM_RANGE
619                | PgType::NUM_RANGE_ARRAY
620                | PgType::TS_RANGE
621                | PgType::TS_RANGE_ARRAY
622                | PgType::TSTZ_RANGE
623                | PgType::TSTZ_RANGE_ARRAY
624        )
625    }
626}
627
// From postgres-types/src/private.rs.
/// Reads a big-endian `i32` from the front of `buf` and advances `buf` past
/// the four bytes consumed.
///
/// Fails with "invalid buffer size" (leaving `buf` untouched) when fewer than
/// four bytes remain.
fn read_be_i32(buf: &mut &[u8]) -> Result<i32, Box<dyn Error + Sync + Send>> {
    const WIDTH: usize = 4;

    let Some(prefix) = buf.get(..WIDTH) else {
        return Err("invalid buffer size".into());
    };
    let mut word = [0u8; WIDTH];
    word.copy_from_slice(prefix);

    *buf = &buf[WIDTH..];
    Ok(i32::from_be_bytes(word))
}
638
639// From postgres-types/src/private.rs.
640fn read_value<'a, T>(type_: &PgType, buf: &mut &'a [u8]) -> Result<T, Box<dyn Error + Sync + Send>>
641where
642    T: FromSql<'a>,
643{
644    let value = match usize::try_from(read_be_i32(buf)?) {
645        Err(_) => None,
646        Ok(len) => {
647            if len > buf.len() {
648                return Err("invalid buffer size".into());
649            }
650            let (head, tail) = buf.split_at(len);
651            *buf = tail;
652            Some(head)
653        }
654    };
655    T::from_sql_nullable(type_, value)
656}
657
658fn format_datum(d: Slt, typ: &Type, mode: Mode, col: usize) -> String {
659    match (typ, d.0) {
660        (Type::Bool, Value::Bool(b)) => b.to_string(),
661
662        (Type::Integer, Value::Int2(i)) => i.to_string(),
663        (Type::Integer, Value::Int4(i)) => i.to_string(),
664        (Type::Integer, Value::Int8(i)) => i.to_string(),
665        (Type::Integer, Value::UInt2(u)) => u.0.to_string(),
666        (Type::Integer, Value::UInt4(u)) => u.0.to_string(),
667        (Type::Integer, Value::UInt8(u)) => u.0.to_string(),
668        (Type::Integer, Value::Oid(i)) => i.to_string(),
669        // TODO(benesch): rewrite to avoid `as`.
670        #[allow(clippy::as_conversions)]
671        (Type::Integer, Value::Float4(f)) => format!("{}", f as i64),
672        // TODO(benesch): rewrite to avoid `as`.
673        #[allow(clippy::as_conversions)]
674        (Type::Integer, Value::Float8(f)) => format!("{}", f as i64),
675        // This is so wrong, but sqlite needs it.
676        (Type::Integer, Value::Text(_)) => "0".to_string(),
677        (Type::Integer, Value::Bool(b)) => i8::from(b).to_string(),
678        (Type::Integer, Value::Numeric(d)) => {
679            let mut d = d.0.0.clone();
680            let mut cx = numeric::cx_datum();
681            // Truncate the decimal to match sqlite.
682            if mode == Mode::Standard {
683                cx.set_rounding(dec::Rounding::Down);
684            }
685            cx.round(&mut d);
686            numeric::munge_numeric(&mut d).unwrap();
687            d.to_standard_notation_string()
688        }
689
690        (Type::Real, Value::Int2(i)) => format!("{:.3}", i),
691        (Type::Real, Value::Int4(i)) => format!("{:.3}", i),
692        (Type::Real, Value::Int8(i)) => format!("{:.3}", i),
693        (Type::Real, Value::Float4(f)) => match mode {
694            Mode::Standard => format!("{:.3}", f),
695            Mode::Cockroach => format!("{}", f),
696        },
697        (Type::Real, Value::Float8(f)) => match mode {
698            Mode::Standard => format!("{:.3}", f),
699            Mode::Cockroach => format!("{}", f),
700        },
701        (Type::Real, Value::Numeric(d)) => match mode {
702            Mode::Standard => {
703                let mut d = d.0.0.clone();
704                if d.exponent() < -3 {
705                    numeric::rescale(&mut d, 3).unwrap();
706                }
707                numeric::munge_numeric(&mut d).unwrap();
708                d.to_standard_notation_string()
709            }
710            Mode::Cockroach => d.0.0.to_standard_notation_string(),
711        },
712
713        (Type::Text, Value::Text(s)) => {
714            if s.is_empty() {
715                "(empty)".to_string()
716            } else {
717                s
718            }
719        }
720        (Type::Text, Value::Bool(b)) => b.to_string(),
721        (Type::Text, Value::Float4(f)) => format!("{:.3}", f),
722        (Type::Text, Value::Float8(f)) => format!("{:.3}", f),
723        // Bytes are printed as text iff they are valid UTF-8. This
724        // seems guaranteed to confuse everyone, but it is required for
725        // compliance with the CockroachDB sqllogictest runner. [0]
726        //
727        // [0]: https://github.com/cockroachdb/cockroach/blob/970782487/pkg/sql/logictest/logic.go#L2038-L2043
728        (Type::Text, Value::Bytea(b)) => match str::from_utf8(&b) {
729            Ok(s) => s.to_string(),
730            Err(_) => format!("{:?}", b),
731        },
732        (Type::Text, Value::Numeric(d)) => d.0.0.to_standard_notation_string(),
733        // Everything else gets normal text encoding. This correctly handles things
734        // like arrays, tuples, and strings that need to be quoted.
735        (Type::Text, d) => {
736            let mut buf = BytesMut::new();
737            d.encode_text(&mut buf);
738            String::from_utf8_lossy(&buf).into_owned()
739        }
740
741        (Type::Oid, Value::Oid(o)) => o.to_string(),
742
743        (_, d) => panic!(
744            "Don't know how to format {:?} as {:?} in column {}",
745            d, typ, col,
746        ),
747    }
748}
749
750fn format_row(row: &Row, types: &[Type], mode: Mode) -> Vec<String> {
751    let mut formatted: Vec<String> = vec![];
752    for i in 0..row.len() {
753        let t: Option<Slt> = row.get::<usize, Option<Slt>>(i);
754        let t: Option<String> = t.map(|d| format_datum(d, &types[i], mode, i));
755        formatted.push(match t {
756            Some(t) => t,
757            None => "NULL".into(),
758        });
759    }
760
761    formatted
762}
763
764impl<'a> Runner<'a> {
765    pub async fn start(config: &'a RunConfig<'a>) -> Result<Runner<'a>, anyhow::Error> {
766        let mut runner = Self {
767            config,
768            inner: None,
769        };
770        runner.reset().await?;
771        Ok(runner)
772    }
773
774    pub async fn reset(&mut self) -> Result<(), anyhow::Error> {
775        // Explicitly drop the old runner here to ensure that we wait for threads to terminate
776        // before starting a new runner
777        drop(self.inner.take());
778        self.inner = Some(RunnerInner::start(self.config).await?);
779
780        Ok(())
781    }
782
783    async fn run_record<'r>(
784        &mut self,
785        record: &'r Record<'r>,
786        in_transaction: &mut bool,
787    ) -> Result<Outcome<'r>, anyhow::Error> {
788        if let Record::ResetServer = record {
789            self.reset().await?;
790            Ok(Outcome::Success)
791        } else {
792            self.inner
793                .as_mut()
794                .expect("RunnerInner missing")
795                .run_record(record, in_transaction)
796                .await
797        }
798    }
799
800    async fn check_catalog(&self) -> Result<(), anyhow::Error> {
801        self.inner
802            .as_ref()
803            .expect("RunnerInner missing")
804            .check_catalog()
805            .await
806    }
807
808    async fn reset_database(&mut self) -> Result<(), anyhow::Error> {
809        let inner = self.inner.as_mut().expect("RunnerInner missing");
810
811        inner.client.batch_execute("ROLLBACK;").await?;
812
813        inner
814            .system_client
815            .batch_execute(
816                "ROLLBACK;
817                 SET cluster = mz_catalog_server;
818                 RESET cluster_replica;",
819            )
820            .await?;
821
822        inner
823            .system_client
824            .batch_execute("ALTER SYSTEM RESET ALL")
825            .await?;
826
827        // Drop all databases, then recreate the `materialize` database.
828        for row in inner
829            .system_client
830            .query("SELECT name FROM mz_databases", &[])
831            .await?
832        {
833            let name: &str = row.get("name");
834            inner
835                .system_client
836                .batch_execute(&format!("DROP DATABASE \"{name}\""))
837                .await?;
838        }
839        inner
840            .system_client
841            .batch_execute("CREATE DATABASE materialize")
842            .await?;
843
844        // Ensure quickstart cluster exists with one replica of size `self.config.replica_size`.
845        // We don't destroy the existing quickstart cluster replica if it exists, as turning
846        // on a cluster replica is exceptionally slow.
847        let mut needs_default_cluster = true;
848        for row in inner
849            .system_client
850            .query("SELECT name FROM mz_clusters WHERE id LIKE 'u%'", &[])
851            .await?
852        {
853            match row.get("name") {
854                "quickstart" => needs_default_cluster = false,
855                name => {
856                    inner
857                        .system_client
858                        .batch_execute(&format!("DROP CLUSTER {name}"))
859                        .await?
860                }
861            }
862        }
863        if needs_default_cluster {
864            inner
865                .system_client
866                .batch_execute("CREATE CLUSTER quickstart REPLICAS ()")
867                .await?;
868        }
869        let mut needs_default_replica = false;
870        let rows = inner
871            .system_client
872            .query(
873                "SELECT name, size FROM mz_cluster_replicas
874                 WHERE cluster_id = (SELECT id FROM mz_clusters WHERE name = 'quickstart')
875                 ORDER BY name",
876                &[],
877            )
878            .await?;
879        if rows.len() != self.config.replicas {
880            needs_default_replica = true;
881        } else {
882            for (i, row) in rows.iter().enumerate() {
883                let name: &str = row.get("name");
884                let size: &str = row.get("size");
885                if name != format!("r{i}") || size != self.config.replica_size {
886                    needs_default_replica = true;
887                    break;
888                }
889            }
890        }
891
892        if needs_default_replica {
893            inner
894                .system_client
895                .batch_execute("ALTER CLUSTER quickstart SET (MANAGED = false)")
896                .await?;
897            for row in inner
898                .system_client
899                .query(
900                    "SELECT name FROM mz_cluster_replicas
901                     WHERE cluster_id = (SELECT id FROM mz_clusters WHERE name = 'quickstart')",
902                    &[],
903                )
904                .await?
905            {
906                let name: &str = row.get("name");
907                inner
908                    .system_client
909                    .batch_execute(&format!("DROP CLUSTER REPLICA quickstart.{}", name))
910                    .await?;
911            }
912            for i in 1..=self.config.replicas {
913                inner
914                    .system_client
915                    .batch_execute(&format!(
916                        "CREATE CLUSTER REPLICA quickstart.r{i} SIZE '{}'",
917                        self.config.replica_size
918                    ))
919                    .await?;
920            }
921            inner
922                .system_client
923                .batch_execute("ALTER CLUSTER quickstart SET (MANAGED = true)")
924                .await?;
925        }
926
927        // Grant initial privileges.
928        inner
929            .system_client
930            .batch_execute("GRANT USAGE ON DATABASE materialize TO PUBLIC")
931            .await?;
932        inner
933            .system_client
934            .batch_execute("GRANT CREATE ON DATABASE materialize TO materialize")
935            .await?;
936        inner
937            .system_client
938            .batch_execute("GRANT CREATE ON SCHEMA materialize.public TO materialize")
939            .await?;
940        inner
941            .system_client
942            .batch_execute("GRANT USAGE ON CLUSTER quickstart TO PUBLIC")
943            .await?;
944        inner
945            .system_client
946            .batch_execute("GRANT CREATE ON CLUSTER quickstart TO materialize")
947            .await?;
948
949        // Some sqllogictests require more than the default amount of tables, so we increase the
950        // limit for all tests.
951        inner
952            .system_client
953            .simple_query("ALTER SYSTEM SET max_tables = 100")
954            .await?;
955
956        if inner.enable_table_keys {
957            inner
958                .system_client
959                .simple_query("ALTER SYSTEM SET unsafe_enable_table_keys = true")
960                .await?;
961        }
962
963        inner.ensure_fixed_features().await?;
964
965        inner.client = connect(inner.server_addr, None, None).await.unwrap();
966        inner.system_client = connect(inner.internal_server_addr, Some("mz_system"), None)
967            .await
968            .unwrap();
969        inner.clients = BTreeMap::new();
970
971        Ok(())
972    }
973}
974
975impl<'a> RunnerInner<'a> {
976    pub async fn start(config: &RunConfig<'a>) -> Result<RunnerInner<'a>, anyhow::Error> {
977        let temp_dir = tempfile::tempdir()?;
978        let scratch_dir = tempfile::tempdir()?;
979        let environment_id = EnvironmentId::for_tests();
980        let (consensus_uri, timestamp_oracle_url): (SensitiveUrl, SensitiveUrl) = {
981            let postgres_url = &config.postgres_url;
982            let prefix = &config.prefix;
983            info!(%postgres_url, "starting server");
984            let (client, conn) = Retry::default()
985                .max_tries(5)
986                .retry_async(|_| async {
987                    match tokio_postgres::connect(postgres_url, NoTls).await {
988                        Ok(c) => Ok(c),
989                        Err(e) => {
990                            error!(%e, "failed to connect to postgres");
991                            Err(e)
992                        }
993                    }
994                })
995                .await?;
996            task::spawn(|| "sqllogictest_connect", async move {
997                if let Err(e) = conn.await {
998                    panic!("connection error: {}", e);
999                }
1000            });
1001            client
1002                .batch_execute(&format!(
1003                    "DROP SCHEMA IF EXISTS {prefix}_tsoracle CASCADE;
1004                     CREATE SCHEMA IF NOT EXISTS {prefix}_consensus;
1005                     CREATE SCHEMA {prefix}_tsoracle;"
1006                ))
1007                .await?;
1008            (
1009                format!("{postgres_url}?options=--search_path={prefix}_consensus")
1010                    .parse()
1011                    .expect("invalid consensus URI"),
1012                format!("{postgres_url}?options=--search_path={prefix}_tsoracle")
1013                    .parse()
1014                    .expect("invalid timestamp oracle URI"),
1015            )
1016        };
1017
1018        let secrets_dir = temp_dir.path().join("secrets");
1019        let orchestrator = Arc::new(
1020            ProcessOrchestrator::new(ProcessOrchestratorConfig {
1021                image_dir: env::current_exe()?.parent().unwrap().to_path_buf(),
1022                suppress_output: false,
1023                environment_id: environment_id.to_string(),
1024                secrets_dir: secrets_dir.clone(),
1025                command_wrapper: config
1026                    .orchestrator_process_wrapper
1027                    .as_ref()
1028                    .map_or(Ok(vec![]), |s| shell_words::split(s))?,
1029                propagate_crashes: true,
1030                tcp_proxy: None,
1031                scratch_directory: scratch_dir.path().to_path_buf(),
1032            })
1033            .await?,
1034        );
1035        let now = SYSTEM_TIME.clone();
1036        let metrics_registry = MetricsRegistry::new();
1037
1038        let persist_config = PersistConfig::new(
1039            &mz_environmentd::BUILD_INFO,
1040            now.clone(),
1041            mz_dyncfgs::all_dyncfgs(),
1042        );
1043        let persist_pubsub_server =
1044            PersistGrpcPubSubServer::new(&persist_config, &metrics_registry);
1045        let persist_pubsub_client = persist_pubsub_server.new_same_process_connection();
1046        let persist_pubsub_tcp_listener =
1047            TcpListener::bind(SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), 0))
1048                .await
1049                .expect("pubsub addr binding");
1050        let persist_pubsub_server_port = persist_pubsub_tcp_listener
1051            .local_addr()
1052            .expect("pubsub addr has local addr")
1053            .port();
1054        info!("listening for persist pubsub connections on localhost:{persist_pubsub_server_port}");
1055        mz_ore::task::spawn(|| "persist_pubsub_server", async move {
1056            persist_pubsub_server
1057                .serve_with_stream(TcpListenerStream::new(persist_pubsub_tcp_listener))
1058                .await
1059                .expect("success")
1060        });
1061        let persist_clients =
1062            PersistClientCache::new(persist_config, &metrics_registry, |cfg, metrics| {
1063                let sender: Arc<dyn PubSubSender> = Arc::new(MetricsSameProcessPubSubSender::new(
1064                    cfg,
1065                    persist_pubsub_client.sender,
1066                    metrics,
1067                ));
1068                PubSubClientConnection::new(sender, persist_pubsub_client.receiver)
1069            });
1070        let persist_clients = Arc::new(persist_clients);
1071
1072        let secrets_controller = Arc::clone(&orchestrator);
1073        let connection_context = ConnectionContext::for_tests(orchestrator.reader());
1074        let orchestrator = Arc::new(TracingOrchestrator::new(
1075            orchestrator,
1076            config.tracing.clone(),
1077        ));
1078        let listeners_config = ListenersConfig {
1079            sql: btreemap! {
1080                "external".to_owned() => SqlListenerConfig {
1081                    addr: SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), 0),
1082                    authenticator_kind: AuthenticatorKind::None,
1083                    allowed_roles: AllowedRoles::Normal,
1084                    enable_tls: false,
1085                },
1086                "internal".to_owned() => SqlListenerConfig {
1087                    addr: SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), 0),
1088                    authenticator_kind: AuthenticatorKind::None,
1089                    allowed_roles: AllowedRoles::Internal,
1090                    enable_tls: false,
1091                },
1092                "password".to_owned() => SqlListenerConfig {
1093                    addr: SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), 0),
1094                    authenticator_kind: AuthenticatorKind::Password,
1095                    allowed_roles: AllowedRoles::Normal,
1096                    enable_tls: false,
1097                },
1098            },
1099            http: btreemap![
1100                "external".to_owned() => HttpListenerConfig {
1101                    base: BaseListenerConfig {
1102                        addr: SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), 0),
1103                        authenticator_kind: AuthenticatorKind::None,
1104                        allowed_roles: AllowedRoles::Normal,
1105                        enable_tls: false
1106                    },
1107                    routes: HttpRoutesEnabled {
1108                        base: true,
1109                        webhook: true,
1110                        internal: false,
1111                        metrics: false,
1112                        profiling: false,
1113                        mcp_agents: false,
1114                        mcp_observatory: false,
1115                    },
1116                },
1117                "internal".to_owned() => HttpListenerConfig {
1118                    base: BaseListenerConfig {
1119                        addr: SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), 0),
1120                        authenticator_kind: AuthenticatorKind::None,
1121                        allowed_roles: AllowedRoles::NormalAndInternal,
1122                        enable_tls: false
1123                    },
1124                    routes: HttpRoutesEnabled {
1125                        base: true,
1126                        webhook: true,
1127                        internal: true,
1128                        metrics: true,
1129                        profiling: true,
1130                        mcp_agents: false,
1131                        mcp_observatory: false,
1132                    },
1133                },
1134            ],
1135        };
1136        let listeners = mz_environmentd::Listeners::bind(listeners_config).await?;
1137        let host_name = format!(
1138            "localhost:{}",
1139            listeners.http["external"].handle.local_addr.port()
1140        );
1141        let catalog_config = CatalogConfig {
1142            persist_clients: Arc::clone(&persist_clients),
1143            metrics: Arc::new(mz_catalog::durable::Metrics::new(&MetricsRegistry::new())),
1144        };
1145        let server_config = mz_environmentd::Config {
1146            catalog_config,
1147            timestamp_oracle_url: Some(timestamp_oracle_url),
1148            controller: ControllerConfig {
1149                build_info: &mz_environmentd::BUILD_INFO,
1150                orchestrator,
1151                clusterd_image: "clusterd".into(),
1152                init_container_image: None,
1153                deploy_generation: 0,
1154                persist_location: PersistLocation {
1155                    blob_uri: format!(
1156                        "file://{}/persist/blob",
1157                        config.persist_dir.path().display()
1158                    )
1159                    .parse()
1160                    .expect("invalid blob URI"),
1161                    consensus_uri,
1162                },
1163                persist_clients,
1164                now: SYSTEM_TIME.clone(),
1165                metrics_registry: metrics_registry.clone(),
1166                persist_pubsub_url: format!("http://localhost:{}", persist_pubsub_server_port),
1167                secrets_args: mz_service::secrets::SecretsReaderCliArgs {
1168                    secrets_reader: mz_service::secrets::SecretsControllerKind::LocalFile,
1169                    secrets_reader_local_file_dir: Some(secrets_dir),
1170                    secrets_reader_kubernetes_context: None,
1171                    secrets_reader_aws_prefix: None,
1172                    secrets_reader_name_prefix: None,
1173                },
1174                connection_context,
1175                replica_http_locator: Arc::new(ReplicaHttpLocator::default()),
1176            },
1177            secrets_controller,
1178            cloud_resource_controller: None,
1179            tls: None,
1180            frontegg: None,
1181            cors_allowed_origin: AllowOrigin::list([]),
1182            unsafe_mode: true,
1183            all_features: false,
1184            metrics_registry,
1185            now,
1186            environment_id,
1187            cluster_replica_sizes: ClusterReplicaSizeMap::for_tests(),
1188            bootstrap_default_cluster_replica_size: config.replica_size.clone(),
1189            bootstrap_default_cluster_replication_factor: config
1190                .replicas
1191                .try_into()
1192                .expect("replicas must fit"),
1193            bootstrap_builtin_system_cluster_config: BootstrapBuiltinClusterConfig {
1194                replication_factor: SYSTEM_CLUSTER_DEFAULT_REPLICATION_FACTOR,
1195                size: config.replica_size.clone(),
1196            },
1197            bootstrap_builtin_catalog_server_cluster_config: BootstrapBuiltinClusterConfig {
1198                replication_factor: CATALOG_SERVER_CLUSTER_DEFAULT_REPLICATION_FACTOR,
1199                size: config.replica_size.clone(),
1200            },
1201            bootstrap_builtin_probe_cluster_config: BootstrapBuiltinClusterConfig {
1202                replication_factor: PROBE_CLUSTER_DEFAULT_REPLICATION_FACTOR,
1203                size: config.replica_size.clone(),
1204            },
1205            bootstrap_builtin_support_cluster_config: BootstrapBuiltinClusterConfig {
1206                replication_factor: SUPPORT_CLUSTER_DEFAULT_REPLICATION_FACTOR,
1207                size: config.replica_size.clone(),
1208            },
1209            bootstrap_builtin_analytics_cluster_config: BootstrapBuiltinClusterConfig {
1210                replication_factor: ANALYTICS_CLUSTER_DEFAULT_REPLICATION_FACTOR,
1211                size: config.replica_size.clone(),
1212            },
1213            system_parameter_defaults: {
1214                let mut params = BTreeMap::new();
1215                params.insert(
1216                    "log_filter".to_string(),
1217                    config.tracing.startup_log_filter.to_string(),
1218                );
1219                params.extend(config.system_parameter_defaults.clone());
1220                params
1221            },
1222            availability_zones: Default::default(),
1223            tracing_handle: config.tracing_handle.clone(),
1224            storage_usage_collection_interval: Duration::from_secs(3600),
1225            storage_usage_retention_period: None,
1226            segment_api_key: None,
1227            segment_client_side: false,
1228            // SLT doesn't like eternally running tasks since it waits for them to finish inbetween SLT files
1229            test_only_dummy_segment_client: false,
1230            egress_addresses: vec![],
1231            aws_account_id: None,
1232            aws_privatelink_availability_zones: None,
1233            launchdarkly_sdk_key: None,
1234            launchdarkly_key_map: Default::default(),
1235            config_sync_file_path: None,
1236            config_sync_timeout: Duration::from_secs(30),
1237            config_sync_loop_interval: None,
1238            bootstrap_role: Some("materialize".into()),
1239            http_host_name: Some(host_name),
1240            internal_console_redirect_url: None,
1241            tls_reload_certs: mz_server_core::cert_reload_never_reload(),
1242            helm_chart_version: None,
1243            license_key: ValidatedLicenseKey::for_tests(),
1244            external_login_password_mz_system: None,
1245            force_builtin_schema_migration: None,
1246        };
1247        // We need to run the server on its own Tokio runtime, which in turn
1248        // requires its own thread, so that we can wait for any tasks spawned
1249        // by the server to be shutdown at the end of each file. If we were to
1250        // share a Tokio runtime, tasks from the last file's server would still
1251        // be live at the start of the next file's server.
1252        let (server_addr_tx, server_addr_rx): (oneshot::Sender<Result<_, anyhow::Error>>, _) =
1253            oneshot::channel();
1254        let (internal_server_addr_tx, internal_server_addr_rx) = oneshot::channel();
1255        let (password_server_addr_tx, password_server_addr_rx) = oneshot::channel();
1256        let (internal_http_server_addr_tx, internal_http_server_addr_rx) = oneshot::channel();
1257        let (shutdown_trigger, shutdown_trigger_rx) = trigger::channel();
1258        let server_thread = thread::spawn(|| {
1259            let runtime = match Runtime::new() {
1260                Ok(runtime) => runtime,
1261                Err(e) => {
1262                    server_addr_tx
1263                        .send(Err(e.into()))
1264                        .expect("receiver should not drop first");
1265                    return;
1266                }
1267            };
1268            let server = match runtime.block_on(listeners.serve(server_config)) {
1269                Ok(runtime) => runtime,
1270                Err(e) => {
1271                    server_addr_tx
1272                        .send(Err(e.into()))
1273                        .expect("receiver should not drop first");
1274                    return;
1275                }
1276            };
1277            server_addr_tx
1278                .send(Ok(server.sql_listener_handles["external"].local_addr))
1279                .expect("receiver should not drop first");
1280            internal_server_addr_tx
1281                .send(server.sql_listener_handles["internal"].local_addr)
1282                .expect("receiver should not drop first");
1283            password_server_addr_tx
1284                .send(server.sql_listener_handles["password"].local_addr)
1285                .expect("receiver should not drop first");
1286            internal_http_server_addr_tx
1287                .send(server.http_listener_handles["internal"].local_addr)
1288                .expect("receiver should not drop first");
1289            let _ = runtime.block_on(shutdown_trigger_rx);
1290        });
1291        let server_addr = server_addr_rx.await??;
1292        let internal_server_addr = internal_server_addr_rx.await?;
1293        let password_server_addr = password_server_addr_rx.await?;
1294        let internal_http_server_addr = internal_http_server_addr_rx.await?;
1295
1296        let system_client = connect(internal_server_addr, Some("mz_system"), None)
1297            .await
1298            .unwrap();
1299        let client = connect(server_addr, None, None).await.unwrap();
1300
1301        let inner = RunnerInner {
1302            server_addr,
1303            internal_server_addr,
1304            password_server_addr,
1305            internal_http_server_addr,
1306            _shutdown_trigger: shutdown_trigger,
1307            _server_thread: server_thread.join_on_drop(),
1308            _temp_dir: temp_dir,
1309            client,
1310            system_client,
1311            clients: BTreeMap::new(),
1312            auto_index_tables: config.auto_index_tables,
1313            auto_index_selects: config.auto_index_selects,
1314            auto_transactions: config.auto_transactions,
1315            enable_table_keys: config.enable_table_keys,
1316            verbose: config.verbose,
1317            stdout: config.stdout,
1318        };
1319        inner.ensure_fixed_features().await?;
1320
1321        Ok(inner)
1322    }
1323
1324    /// Set features that should be enabled regardless of whether reset-server was
1325    /// called. These features may be set conditionally depending on the run configuration.
1326    async fn ensure_fixed_features(&self) -> Result<(), anyhow::Error> {
1327        // We turn on enable_reduce_mfp_fusion, as we wish
1328        // to get as much coverage of these features as we can.
1329        // TODO(vmarcos): Remove this code when we retire this feature flag.
1330        self.system_client
1331            .execute("ALTER SYSTEM SET enable_reduce_mfp_fusion = on", &[])
1332            .await?;
1333
1334        // Dangerous functions are useful for tests so we enable it for all tests.
1335        self.system_client
1336            .execute("ALTER SYSTEM SET unsafe_enable_unsafe_functions = on", &[])
1337            .await?;
1338        Ok(())
1339    }
1340
1341    async fn run_record<'r>(
1342        &mut self,
1343        record: &'r Record<'r>,
1344        in_transaction: &mut bool,
1345    ) -> Result<Outcome<'r>, anyhow::Error> {
1346        match &record {
1347            Record::Statement {
1348                expected_error,
1349                rows_affected,
1350                sql,
1351                location,
1352            } => {
1353                if self.auto_transactions && *in_transaction {
1354                    self.client.execute("COMMIT", &[]).await?;
1355                    *in_transaction = false;
1356                }
1357                match self
1358                    .run_statement(*expected_error, *rows_affected, sql, location.clone())
1359                    .await?
1360                {
1361                    Outcome::Success => {
1362                        if self.auto_index_tables {
1363                            let additional = mutate(sql);
1364                            for stmt in additional {
1365                                self.client.execute(&stmt, &[]).await?;
1366                            }
1367                        }
1368                        Ok(Outcome::Success)
1369                    }
1370                    other => {
1371                        if expected_error.is_some() {
1372                            Ok(other)
1373                        } else {
1374                            // If we failed to execute a statement that was supposed to succeed,
1375                            // running the rest of the tests in this file will probably cause
1376                            // false positives, so just give up on the file entirely.
1377                            Ok(Outcome::Bail {
1378                                cause: Box::new(other),
1379                                location: location.clone(),
1380                            })
1381                        }
1382                    }
1383                }
1384            }
1385            Record::Query {
1386                sql,
1387                output,
1388                location,
1389            } => {
1390                self.run_query(sql, output, location.clone(), in_transaction)
1391                    .await
1392            }
1393            Record::Simple {
1394                conn,
1395                user,
1396                password,
1397                sql,
1398                sort,
1399                output,
1400                location,
1401                ..
1402            } => {
1403                self.run_simple(
1404                    *conn,
1405                    *user,
1406                    *password,
1407                    sql,
1408                    sort.clone(),
1409                    output,
1410                    location.clone(),
1411                )
1412                .await
1413            }
1414            Record::Copy {
1415                table_name,
1416                tsv_path,
1417            } => {
1418                let tsv = tokio::fs::read(tsv_path).await?;
1419                let copy = self
1420                    .client
1421                    .copy_in(&*format!("COPY {} FROM STDIN", table_name))
1422                    .await?;
1423                tokio::pin!(copy);
1424                copy.send(bytes::Bytes::from(tsv)).await?;
1425                copy.finish().await?;
1426                Ok(Outcome::Success)
1427            }
1428            _ => Ok(Outcome::Success),
1429        }
1430    }
1431
1432    async fn run_statement<'r>(
1433        &self,
1434        expected_error: Option<&'r str>,
1435        expected_rows_affected: Option<u64>,
1436        sql: &'r str,
1437        location: Location,
1438    ) -> Result<Outcome<'r>, anyhow::Error> {
1439        static UNSUPPORTED_INDEX_STATEMENT_REGEX: LazyLock<Regex> =
1440            LazyLock::new(|| Regex::new("^(CREATE UNIQUE INDEX|REINDEX)").unwrap());
1441        if UNSUPPORTED_INDEX_STATEMENT_REGEX.is_match(sql) {
1442            // sure, we totally made you an index
1443            return Ok(Outcome::Success);
1444        }
1445
1446        match self.client.execute(sql, &[]).await {
1447            Ok(actual) => {
1448                if let Some(expected_error) = expected_error {
1449                    return Ok(Outcome::UnexpectedPlanSuccess {
1450                        expected_error,
1451                        location,
1452                    });
1453                }
1454                match expected_rows_affected {
1455                    None => Ok(Outcome::Success),
1456                    Some(expected) => {
1457                        if expected != actual {
1458                            Ok(Outcome::WrongNumberOfRowsInserted {
1459                                expected_count: expected,
1460                                actual_count: actual,
1461                                location,
1462                            })
1463                        } else {
1464                            Ok(Outcome::Success)
1465                        }
1466                    }
1467                }
1468            }
1469            Err(error) => {
1470                if let Some(expected_error) = expected_error {
1471                    if Regex::new(expected_error)?.is_match(&error.to_string_with_causes()) {
1472                        return Ok(Outcome::Success);
1473                    }
1474                }
1475                Ok(Outcome::PlanFailure {
1476                    error: anyhow!(error),
1477                    location,
1478                })
1479            }
1480        }
1481    }
1482
1483    async fn prepare_query<'r>(
1484        &self,
1485        sql: &str,
1486        output: &'r Result<QueryOutput<'_>, &'r str>,
1487        location: Location,
1488        in_transaction: &mut bool,
1489    ) -> Result<PrepareQueryOutcome<'r>, anyhow::Error> {
1490        // get statement
1491        let statements = match mz_sql::parse::parse(sql) {
1492            Ok(statements) => statements,
1493            Err(e) => match output {
1494                Ok(_) => {
1495                    return Ok(PrepareQueryOutcome::Outcome(Outcome::ParseFailure {
1496                        error: e.into(),
1497                        location,
1498                    }));
1499                }
1500                Err(expected_error) => {
1501                    if Regex::new(expected_error)?.is_match(&e.to_string_with_causes()) {
1502                        return Ok(PrepareQueryOutcome::Outcome(Outcome::Success));
1503                    } else {
1504                        return Ok(PrepareQueryOutcome::Outcome(Outcome::ParseFailure {
1505                            error: e.into(),
1506                            location,
1507                        }));
1508                    }
1509                }
1510            },
1511        };
1512        let statement = match &*statements {
1513            [] => bail!("Got zero statements?"),
1514            [statement] => &statement.ast,
1515            _ => bail!("Got multiple statements: {:?}", statements),
1516        };
1517        let (is_select, num_attributes) = match statement {
1518            Statement::Select(stmt) => (true, derive_num_attributes(&stmt.query.body)),
1519            _ => (false, None),
1520        };
1521
1522        match output {
1523            Ok(_) => {
1524                if self.auto_transactions && !*in_transaction {
1525                    // No ISOLATION LEVEL SERIALIZABLE because of database-issues#5323
1526                    self.client.execute("BEGIN", &[]).await?;
1527                    *in_transaction = true;
1528                }
1529            }
1530            Err(_) => {
1531                if self.auto_transactions && *in_transaction {
1532                    self.client.execute("COMMIT", &[]).await?;
1533                    *in_transaction = false;
1534                }
1535            }
1536        }
1537
1538        // `SHOW` commands reference catalog schema, thus are not in the same timedomain and not
1539        // allowed in the same transaction, see:
1540        // https://materialize.com/docs/sql/begin/#same-timedomain-error
1541        match statement {
1542            Statement::Show(..) => {
1543                if self.auto_transactions && *in_transaction {
1544                    self.client.execute("COMMIT", &[]).await?;
1545                    *in_transaction = false;
1546                }
1547            }
1548            _ => (),
1549        }
1550        Ok(PrepareQueryOutcome::QueryPrepared(QueryInfo {
1551            is_select,
1552            num_attributes,
1553        }))
1554    }
1555
1556    async fn execute_query<'r>(
1557        &self,
1558        sql: &str,
1559        output: &'r Result<QueryOutput<'_>, &'r str>,
1560        location: Location,
1561    ) -> Result<Outcome<'r>, anyhow::Error> {
1562        let rows = match self.client.query(sql, &[]).await {
1563            Ok(rows) => rows,
1564            Err(error) => {
1565                let error_string = error.to_string_with_causes();
1566                return match output {
1567                    Ok(_) => {
1568                        if error_string.contains("supported") || error_string.contains("overload") {
1569                            // this is a failure, but it's caused by lack of support rather than by bugs
1570                            Ok(Outcome::Unsupported {
1571                                error: anyhow!(error),
1572                                location,
1573                            })
1574                        } else {
1575                            Ok(Outcome::PlanFailure {
1576                                error: anyhow!(error),
1577                                location,
1578                            })
1579                        }
1580                    }
1581                    Err(expected_error) => {
1582                        if Regex::new(expected_error)?.is_match(&error_string) {
1583                            Ok(Outcome::Success)
1584                        } else {
1585                            Ok(Outcome::PlanFailure {
1586                                error: anyhow!(
1587                                    "error does not match expected pattern:\n  expected: /{}/\n  actual:    {}",
1588                                    expected_error,
1589                                    error_string
1590                                ),
1591                                location,
1592                            })
1593                        }
1594                    }
1595                };
1596            }
1597        };
1598
1599        // unpack expected output
1600        let QueryOutput {
1601            sort,
1602            types: expected_types,
1603            column_names: expected_column_names,
1604            output: expected_output,
1605            mode,
1606            ..
1607        } = match output {
1608            Err(expected_error) => {
1609                return Ok(Outcome::UnexpectedPlanSuccess {
1610                    expected_error,
1611                    location,
1612                });
1613            }
1614            Ok(query_output) => query_output,
1615        };
1616
1617        // format output
1618        let mut formatted_rows = vec![];
1619        for row in &rows {
1620            if row.len() != expected_types.len() {
1621                return Ok(Outcome::WrongColumnCount {
1622                    expected_count: expected_types.len(),
1623                    actual_count: row.len(),
1624                    location,
1625                });
1626            }
1627            let row = format_row(row, expected_types, *mode);
1628            formatted_rows.push(row);
1629        }
1630
1631        // sort formatted output
1632        if let Sort::Row = sort {
1633            formatted_rows.sort();
1634        }
1635        let mut values = formatted_rows.into_iter().flatten().collect::<Vec<_>>();
1636        if let Sort::Value = sort {
1637            values.sort();
1638        }
1639
1640        // Various checks as long as there are returned rows.
1641        if let Some(row) = rows.get(0) {
1642            // check column names
1643            if let Some(expected_column_names) = expected_column_names {
1644                let actual_column_names = row
1645                    .columns()
1646                    .iter()
1647                    .map(|t| ColumnName::from(t.name()))
1648                    .collect::<Vec<_>>();
1649                if expected_column_names != &actual_column_names {
1650                    return Ok(Outcome::WrongColumnNames {
1651                        expected_column_names,
1652                        actual_column_names,
1653                        actual_output: Output::Values(values),
1654                        location,
1655                    });
1656                }
1657            }
1658        }
1659
1660        // check output
1661        match expected_output {
1662            Output::Values(expected_values) => {
1663                if values != *expected_values {
1664                    return Ok(Outcome::OutputFailure {
1665                        expected_output,
1666                        actual_raw_output: rows,
1667                        actual_output: Output::Values(values),
1668                        location,
1669                    });
1670                }
1671            }
1672            Output::Hashed {
1673                num_values,
1674                md5: expected_md5,
1675            } => {
1676                let mut hasher = Md5::new();
1677                for value in &values {
1678                    hasher.update(value);
1679                    hasher.update("\n");
1680                }
1681                let md5 = format!("{:x}", hasher.finalize());
1682                if values.len() != *num_values || md5 != *expected_md5 {
1683                    return Ok(Outcome::OutputFailure {
1684                        expected_output,
1685                        actual_raw_output: rows,
1686                        actual_output: Output::Hashed {
1687                            num_values: values.len(),
1688                            md5,
1689                        },
1690                        location,
1691                    });
1692                }
1693            }
1694        }
1695
1696        Ok(Outcome::Success)
1697    }
1698
1699    async fn execute_view_inner<'r>(
1700        &self,
1701        sql: &str,
1702        output: &'r Result<QueryOutput<'_>, &'r str>,
1703        location: Location,
1704    ) -> Result<Option<Outcome<'r>>, anyhow::Error> {
1705        print_sql_if(self.stdout, sql, self.verbose);
1706        let sql_result = self.client.execute(sql, &[]).await;
1707
1708        // Evaluate if we already reached an outcome or not.
1709        let tentative_outcome = if let Err(view_error) = sql_result {
1710            if let Err(expected_error) = output {
1711                if Regex::new(expected_error)?.is_match(&view_error.to_string_with_causes()) {
1712                    Some(Outcome::Success)
1713                } else {
1714                    Some(Outcome::PlanFailure {
1715                        error: anyhow!(
1716                            "error does not match expected pattern:\n  expected: /{}/\n  actual:    {}",
1717                            expected_error,
1718                            view_error.to_string_with_causes()
1719                        ),
1720                        location: location.clone(),
1721                    })
1722                }
1723            } else {
1724                Some(Outcome::PlanFailure {
1725                    error: view_error.into(),
1726                    location: location.clone(),
1727                })
1728            }
1729        } else {
1730            None
1731        };
1732        Ok(tentative_outcome)
1733    }
1734
1735    async fn execute_view<'r>(
1736        &self,
1737        sql: &str,
1738        num_attributes: Option<usize>,
1739        output: &'r Result<QueryOutput<'_>, &'r str>,
1740        location: Location,
1741    ) -> Result<Outcome<'r>, anyhow::Error> {
1742        // Create indexed view SQL commands and execute `CREATE VIEW`.
1743        let expected_column_names = if let Ok(QueryOutput { column_names, .. }) = output {
1744            column_names.clone()
1745        } else {
1746            None
1747        };
1748        let (create_view, create_index, view_sql, drop_view) = generate_view_sql(
1749            sql,
1750            Uuid::new_v4().as_simple(),
1751            num_attributes,
1752            expected_column_names,
1753        );
1754        let tentative_outcome = self
1755            .execute_view_inner(create_view.as_str(), output, location.clone())
1756            .await?;
1757
1758        // Either we already have an outcome or alternatively,
1759        // we proceed to index and query the view.
1760        if let Some(view_outcome) = tentative_outcome {
1761            return Ok(view_outcome);
1762        }
1763
1764        let tentative_outcome = self
1765            .execute_view_inner(create_index.as_str(), output, location.clone())
1766            .await?;
1767
1768        let view_outcome;
1769        if let Some(outcome) = tentative_outcome {
1770            view_outcome = outcome;
1771        } else {
1772            print_sql_if(self.stdout, view_sql.as_str(), self.verbose);
1773            view_outcome = self
1774                .execute_query(view_sql.as_str(), output, location.clone())
1775                .await?;
1776        }
1777
1778        // Remember to clean up after ourselves by dropping the view.
1779        print_sql_if(self.stdout, drop_view.as_str(), self.verbose);
1780        self.client.execute(drop_view.as_str(), &[]).await?;
1781
1782        Ok(view_outcome)
1783    }
1784
1785    async fn run_query<'r>(
1786        &self,
1787        sql: &'r str,
1788        output: &'r Result<QueryOutput<'_>, &'r str>,
1789        location: Location,
1790        in_transaction: &mut bool,
1791    ) -> Result<Outcome<'r>, anyhow::Error> {
1792        let prepare_outcome = self
1793            .prepare_query(sql, output, location.clone(), in_transaction)
1794            .await?;
1795        match prepare_outcome {
1796            PrepareQueryOutcome::QueryPrepared(QueryInfo {
1797                is_select,
1798                num_attributes,
1799            }) => {
1800                let query_outcome = self.execute_query(sql, output, location.clone()).await?;
1801                if is_select && self.auto_index_selects {
1802                    let view_outcome = self
1803                        .execute_view(sql, None, output, location.clone())
1804                        .await?;
1805
1806                    // We compare here the query-based and view-based outcomes.
1807                    // We only produce a test failure if the outcomes are of different
1808                    // variant types, thus accepting smaller deviations in the details
1809                    // produced for each variant.
1810                    if std::mem::discriminant::<Outcome>(&query_outcome)
1811                        != std::mem::discriminant::<Outcome>(&view_outcome)
1812                    {
1813                        // Before producing a failure outcome, we try to obtain a new
1814                        // outcome for view-based execution exploiting analysis of the
1815                        // number of attributes. This two-level strategy can avoid errors
1816                        // produced by column ambiguity in the `SELECT`.
1817                        let view_outcome = if num_attributes.is_some() {
1818                            self.execute_view(sql, num_attributes, output, location.clone())
1819                                .await?
1820                        } else {
1821                            view_outcome
1822                        };
1823
1824                        if std::mem::discriminant::<Outcome>(&query_outcome)
1825                            != std::mem::discriminant::<Outcome>(&view_outcome)
1826                        {
1827                            let inconsistent_view_outcome = Outcome::InconsistentViewOutcome {
1828                                query_outcome: Box::new(query_outcome),
1829                                view_outcome: Box::new(view_outcome),
1830                                location: location.clone(),
1831                            };
1832                            // Determine if this inconsistent view outcome should be reported
1833                            // as an error or only as a warning.
1834                            let outcome = if should_warn(&inconsistent_view_outcome) {
1835                                Outcome::Warning {
1836                                    cause: Box::new(inconsistent_view_outcome),
1837                                    location: location.clone(),
1838                                }
1839                            } else {
1840                                inconsistent_view_outcome
1841                            };
1842                            return Ok(outcome);
1843                        }
1844                    }
1845                }
1846                Ok(query_outcome)
1847            }
1848            PrepareQueryOutcome::Outcome(outcome) => Ok(outcome),
1849        }
1850    }
1851
1852    async fn get_conn(
1853        &mut self,
1854        name: Option<&str>,
1855        user: Option<&str>,
1856        password: Option<&str>,
1857    ) -> Result<&tokio_postgres::Client, tokio_postgres::Error> {
1858        match name {
1859            None => Ok(&self.client),
1860            Some(name) => {
1861                if !self.clients.contains_key(name) {
1862                    let addr = if matches!(user, Some("mz_system") | Some("mz_support")) {
1863                        self.internal_server_addr
1864                    } else if password.is_some() {
1865                        // Use password server for password authentication
1866                        self.password_server_addr
1867                    } else {
1868                        self.server_addr
1869                    };
1870                    let client = connect(addr, user, password).await?;
1871                    self.clients.insert(name.into(), client);
1872                }
1873                Ok(self.clients.get(name).unwrap())
1874            }
1875        }
1876    }
1877
1878    async fn run_simple<'r>(
1879        &mut self,
1880        conn: Option<&'r str>,
1881        user: Option<&'r str>,
1882        password: Option<&'r str>,
1883        sql: &'r str,
1884        sort: Sort,
1885        output: &'r Output,
1886        location: Location,
1887    ) -> Result<Outcome<'r>, anyhow::Error> {
1888        let actual = match self.get_conn(conn, user, password).await {
1889            Ok(client) => match client.simple_query(sql).await {
1890                Ok(result) => {
1891                    let mut rows = Vec::new();
1892
1893                    for m in result.into_iter() {
1894                        match m {
1895                            SimpleQueryMessage::Row(row) => {
1896                                let mut s = vec![];
1897                                for i in 0..row.len() {
1898                                    s.push(row.get(i).unwrap_or("NULL"));
1899                                }
1900                                rows.push(s.join(","));
1901                            }
1902                            SimpleQueryMessage::CommandComplete(count) => {
1903                                // This applies any sort on the COMPLETE line as
1904                                // well, but we do the same for the expected output.
1905                                rows.push(format!("COMPLETE {}", count));
1906                            }
1907                            SimpleQueryMessage::RowDescription(_) => {}
1908                            _ => panic!("unexpected"),
1909                        }
1910                    }
1911
1912                    if let Sort::Row = sort {
1913                        rows.sort();
1914                    }
1915
1916                    Output::Values(rows)
1917                }
1918                // Errors can contain multiple lines (say if there are details), and rewrite
1919                // sticks them each on their own line, so we need to split up the lines here to
1920                // each be its own String in the Vec.
1921                Err(error) => Output::Values(
1922                    error
1923                        .to_string_with_causes()
1924                        .lines()
1925                        .map(|s| s.to_string())
1926                        .collect(),
1927                ),
1928            },
1929            Err(error) => Output::Values(
1930                error
1931                    .to_string_with_causes()
1932                    .lines()
1933                    .map(|s| s.to_string())
1934                    .collect(),
1935            ),
1936        };
1937        if *output != actual {
1938            Ok(Outcome::OutputFailure {
1939                expected_output: output,
1940                actual_raw_output: vec![],
1941                actual_output: actual,
1942                location,
1943            })
1944        } else {
1945            Ok(Outcome::Success)
1946        }
1947    }
1948
1949    async fn check_catalog(&self) -> Result<(), anyhow::Error> {
1950        let url = format!(
1951            "http://{}/api/catalog/check",
1952            self.internal_http_server_addr
1953        );
1954        let response: serde_json::Value = reqwest::get(&url).await?.json().await?;
1955
1956        if let Some(inconsistencies) = response.get("err") {
1957            let inconsistencies = serde_json::to_string_pretty(&inconsistencies)
1958                .expect("serializing Value cannot fail");
1959            Err(anyhow::anyhow!("Catalog inconsistency\n{inconsistencies}"))
1960        } else {
1961            Ok(())
1962        }
1963    }
1964}
1965
1966async fn connect(
1967    addr: SocketAddr,
1968    user: Option<&str>,
1969    password: Option<&str>,
1970) -> Result<tokio_postgres::Client, tokio_postgres::Error> {
1971    let mut config = tokio_postgres::Config::new();
1972    config.host(addr.ip().to_string());
1973    config.port(addr.port());
1974    config.user(user.unwrap_or("materialize"));
1975    if let Some(password) = password {
1976        config.password(password);
1977    }
1978    let (client, connection) = config.connect(NoTls).await?;
1979
1980    task::spawn(|| "sqllogictest_connect", async move {
1981        if let Err(e) = connection.await {
1982            eprintln!("connection error: {}", e);
1983        }
1984    });
1985    Ok(client)
1986}
1987
/// An output sink that accepts pre-built format arguments.
///
/// The method has the name the `writeln!` macro calls on its destination,
/// which is how this file writes to sinks such as `RunConfig::stdout`.
pub trait WriteFmt {
    /// Writes the formatted arguments `fmt` to the sink.
    ///
    /// Takes `&self` and returns nothing, so callers have no error to handle.
    fn write_fmt(&self, fmt: fmt::Arguments<'_>);
}
1991
/// Settings for a sqllogictest run.
pub struct RunConfig<'a> {
    /// Sink that records, warnings, and outcome details are printed to.
    pub stdout: &'a dyn WriteFmt,
    pub stderr: &'a dyn WriteFmt,
    /// When set, `run_string` prints each record before running it, and
    /// prints the details of every non-successful outcome rather than only
    /// those of failures (unless `quiet` is set).
    pub verbose: bool,
    /// When set, `run_string` does not print warnings or failures.
    pub quiet: bool,
    pub postgres_url: String,
    pub prefix: String,
    pub no_fail: bool,
    /// When set, `run_string` stops processing records after the first
    /// failing outcome.
    pub fail_fast: bool,
    pub auto_index_tables: bool,
    pub auto_index_selects: bool,
    pub auto_transactions: bool,
    pub enable_table_keys: bool,
    pub orchestrator_process_wrapper: Option<String>,
    pub tracing: TracingCliArgs,
    pub tracing_handle: TracingHandle,
    pub system_parameter_defaults: BTreeMap<String, String>,
    /// Persist state is handled specially because:
    /// - Persist background workers do not necessarily shut down immediately once the server is
    ///   shut down, and may panic if their storage is deleted out from under them.
    /// - It's safe for different databases to reference the same state: all data is scoped by UUID.
    pub persist_dir: TempDir,
    pub replicas: usize,
    pub replica_size: String,
}
2017
/// Indentation width used when printing SQL statements and other record
/// directives: directives are prefixed with this many spaces, and SQL text is
/// passed through `util::indent` with this width.
const PRINT_INDENT: usize = 4;
2020
2021fn print_record(config: &RunConfig<'_>, record: &Record) {
2022    match record {
2023        Record::Statement { sql, .. } | Record::Query { sql, .. } => {
2024            print_sql(config.stdout, sql, None)
2025        }
2026        Record::Simple { conn, sql, .. } => print_sql(config.stdout, sql, *conn),
2027        Record::Copy {
2028            table_name,
2029            tsv_path,
2030        } => {
2031            writeln!(
2032                config.stdout,
2033                "{}slt copy {} from {}",
2034                " ".repeat(PRINT_INDENT),
2035                table_name,
2036                tsv_path
2037            )
2038        }
2039        Record::ResetServer => {
2040            writeln!(config.stdout, "{}reset-server", " ".repeat(PRINT_INDENT))
2041        }
2042        Record::Halt => {
2043            writeln!(config.stdout, "{}halt", " ".repeat(PRINT_INDENT))
2044        }
2045        Record::HashThreshold { threshold } => {
2046            writeln!(
2047                config.stdout,
2048                "{}hash-threshold {}",
2049                " ".repeat(PRINT_INDENT),
2050                threshold
2051            )
2052        }
2053    }
2054}
2055
2056fn print_sql_if<'a>(stdout: &'a dyn WriteFmt, sql: &str, cond: bool) {
2057    if cond {
2058        print_sql(stdout, sql, None)
2059    }
2060}
2061
2062fn print_sql<'a>(stdout: &'a dyn WriteFmt, sql: &str, conn: Option<&str>) {
2063    let text = if let Some(conn) = conn {
2064        format!("[conn={}] {}", conn, sql)
2065    } else {
2066        sql.to_string()
2067    };
2068    writeln!(stdout, "{}", util::indent(&text, PRINT_INDENT))
2069}
2070
/// Regular expressions for matching error messages that should force a plan failure
/// in an inconsistent view outcome into a warning if the corresponding query succeeds.
///
/// `should_warn` matches each pattern against the view error rendered with
/// its causes; one matching pattern is enough.
const INCONSISTENT_VIEW_OUTCOME_WARNING_REGEXPS: [&str; 9] = [
    // The following are unfixable errors in indexed views given our
    // current constraints.
    "cannot materialize call to",
    "SHOW commands are not allowed in views",
    "cannot create view with unstable dependencies",
    "cannot use wildcard expansions or NATURAL JOINs in a view that depends on system objects",
    "no valid schema selected",
    r#"system schema '\w+' cannot be modified"#,
    r#"permission denied for (SCHEMA|CLUSTER) "(\w+\.)?\w+""#,
    // NOTE(vmarcos): Column ambiguity that could not be eliminated by our
    // currently implemented syntactic rewrites is considered unfixable.
    // In addition, if some column cannot be dealt with, e.g., in `ORDER BY`
    // references, we treat this condition as unfixable as well.
    r#"column "[\w\?]+" specified more than once"#,
    r#"column "(\w+\.)?\w+" does not exist"#,
];
2090
2091/// Evaluates if the given outcome should be returned directly or if it should
2092/// be wrapped as a warning. Note that this function should be used for outcomes
2093/// that can be judged in a context-independent manner, i.e., the outcome itself
2094/// provides enough information as to whether a warning should be emitted or not.
2095fn should_warn(outcome: &Outcome) -> bool {
2096    match outcome {
2097        Outcome::InconsistentViewOutcome {
2098            query_outcome,
2099            view_outcome,
2100            ..
2101        } => match (query_outcome.as_ref(), view_outcome.as_ref()) {
2102            (Outcome::Success, Outcome::PlanFailure { error, .. }) => {
2103                INCONSISTENT_VIEW_OUTCOME_WARNING_REGEXPS.iter().any(|s| {
2104                    Regex::new(s)
2105                        .expect("unexpected error in regular expression parsing")
2106                        .is_match(&error.to_string_with_causes())
2107                })
2108            }
2109            _ => false,
2110        },
2111        _ => false,
2112    }
2113}
2114
2115pub async fn run_string(
2116    runner: &mut Runner<'_>,
2117    source: &str,
2118    input: &str,
2119) -> Result<Outcomes, anyhow::Error> {
2120    runner.reset_database().await?;
2121
2122    let mut outcomes = Outcomes::default();
2123    let mut parser = crate::parser::Parser::new(source, input);
2124    // Transactions are currently relatively slow. Since sqllogictest runs in a single connection
2125    // there should be no difference in having longer running transactions.
2126    let mut in_transaction = false;
2127    writeln!(runner.config.stdout, "--- {}", source);
2128
2129    for record in parser.parse_records()? {
2130        // In maximal-verbose mode, print the query before attempting to run
2131        // it. Running the query might panic, so it is important to print out
2132        // what query we are trying to run *before* we panic.
2133        if runner.config.verbose {
2134            print_record(runner.config, &record);
2135        }
2136
2137        let outcome = runner
2138            .run_record(&record, &mut in_transaction)
2139            .await
2140            .map_err(|err| format!("In {}:\n{}", source, err))
2141            .unwrap();
2142
2143        // Print warnings and failures in verbose mode.
2144        if !runner.config.quiet && !outcome.success() {
2145            if !runner.config.verbose {
2146                // If `verbose` is enabled, we'll already have printed the record,
2147                // so don't print it again. Yes, this is an ugly bit of logic.
2148                // Please don't try to consolidate it with the `print_record`
2149                // call above, as it's important to have a mode in which records
2150                // are printed before they are run, so that if running the
2151                // record panics, you can tell which record caused it.
2152                if !outcome.failure() {
2153                    writeln!(
2154                        runner.config.stdout,
2155                        "{}",
2156                        util::indent("Warning detected for: ", 4)
2157                    );
2158                }
2159                print_record(runner.config, &record);
2160            }
2161            if runner.config.verbose || outcome.failure() {
2162                writeln!(
2163                    runner.config.stdout,
2164                    "{}",
2165                    util::indent(&outcome.to_string(), 4)
2166                );
2167                writeln!(runner.config.stdout, "{}", util::indent("----", 4));
2168            }
2169        }
2170
2171        outcomes.stats[outcome.code()] += 1;
2172        if outcome.failure() {
2173            outcomes.details.push(format!("{}", outcome));
2174        }
2175
2176        if let Outcome::Bail { .. } = outcome {
2177            break;
2178        }
2179
2180        if runner.config.fail_fast && outcome.failure() {
2181            break;
2182        }
2183    }
2184    Ok(outcomes)
2185}
2186
2187pub async fn run_file(runner: &mut Runner<'_>, filename: &Path) -> Result<Outcomes, anyhow::Error> {
2188    let mut input = String::new();
2189    File::open(filename)?.read_to_string(&mut input)?;
2190    let outcomes = run_string(runner, &format!("{}", filename.display()), &input).await?;
2191    runner.check_catalog().await?;
2192
2193    Ok(outcomes)
2194}
2195
2196pub async fn rewrite_file(runner: &mut Runner<'_>, filename: &Path) -> Result<(), anyhow::Error> {
2197    runner.reset_database().await?;
2198
2199    let mut file = OpenOptions::new().read(true).write(true).open(filename)?;
2200
2201    let mut input = String::new();
2202    file.read_to_string(&mut input)?;
2203
2204    let mut buf = RewriteBuffer::new(&input);
2205
2206    let mut parser = crate::parser::Parser::new(filename.to_str().unwrap_or(""), &input);
2207    writeln!(runner.config.stdout, "--- {}", filename.display());
2208    let mut in_transaction = false;
2209
2210    fn append_values_output(
2211        buf: &mut RewriteBuffer,
2212        input: &String,
2213        expected_output: &str,
2214        mode: &Mode,
2215        types: &Vec<Type>,
2216        column_names: Option<&Vec<ColumnName>>,
2217        actual_output: &Vec<String>,
2218        multiline: bool,
2219    ) {
2220        buf.append_header(input, expected_output, column_names);
2221
2222        for (i, row) in actual_output.chunks(types.len()).enumerate() {
2223            match mode {
2224                // In Cockroach mode, output each row on its own line, with
2225                // two spaces between each column.
2226                Mode::Cockroach => {
2227                    if i != 0 {
2228                        buf.append("\n");
2229                    }
2230
2231                    if row.len() == 0 {
2232                        // nothing to do
2233                    } else if row.len() == 1 {
2234                        // If there is only one column, then there is no need for space
2235                        // substitution, so we only do newline substitution.
2236                        if multiline {
2237                            buf.append(&row[0]);
2238                        } else {
2239                            buf.append(&row[0].replace('\n', "⏎"))
2240                        }
2241                    } else {
2242                        // Substitute spaces with ␠ to avoid mistaking the spaces in the result
2243                        // values with spaces that separate columns.
2244                        buf.append(
2245                            &row.iter()
2246                                .map(|col| {
2247                                    let mut col = col.replace(' ', "␠");
2248                                    if !multiline {
2249                                        col = col.replace('\n', "⏎");
2250                                    }
2251                                    col
2252                                })
2253                                .join("  "),
2254                        );
2255                    }
2256                }
2257                // In standard mode, output each value on its own line,
2258                // and ignore row boundaries.
2259                // No need to substitute spaces, because every value (not row) is on a separate
2260                // line. But we do need to substitute newlines.
2261                Mode::Standard => {
2262                    for (j, col) in row.iter().enumerate() {
2263                        if i != 0 || j != 0 {
2264                            buf.append("\n");
2265                        }
2266                        buf.append(&if multiline {
2267                            col.clone()
2268                        } else {
2269                            col.replace('\n', "⏎")
2270                        });
2271                    }
2272                }
2273            }
2274        }
2275    }
2276
2277    for record in parser.parse_records()? {
2278        let outcome = runner.run_record(&record, &mut in_transaction).await?;
2279
2280        match (&record, &outcome) {
2281            // If we see an output failure for a query, rewrite the expected output
2282            // to match the observed output.
2283            (
2284                Record::Query {
2285                    output:
2286                        Ok(QueryOutput {
2287                            mode,
2288                            output: Output::Values(_),
2289                            output_str: expected_output,
2290                            types,
2291                            column_names,
2292                            multiline,
2293                            ..
2294                        }),
2295                    ..
2296                },
2297                Outcome::OutputFailure {
2298                    actual_output: Output::Values(actual_output),
2299                    ..
2300                },
2301            ) => {
2302                append_values_output(
2303                    &mut buf,
2304                    &input,
2305                    expected_output,
2306                    mode,
2307                    types,
2308                    column_names.as_ref(),
2309                    actual_output,
2310                    *multiline,
2311                );
2312            }
2313            (
2314                Record::Query {
2315                    output:
2316                        Ok(QueryOutput {
2317                            mode,
2318                            output: Output::Values(_),
2319                            output_str: expected_output,
2320                            types,
2321                            multiline,
2322                            ..
2323                        }),
2324                    ..
2325                },
2326                Outcome::WrongColumnNames {
2327                    actual_column_names,
2328                    actual_output: Output::Values(actual_output),
2329                    ..
2330                },
2331            ) => {
2332                append_values_output(
2333                    &mut buf,
2334                    &input,
2335                    expected_output,
2336                    mode,
2337                    types,
2338                    Some(actual_column_names),
2339                    actual_output,
2340                    *multiline,
2341                );
2342            }
2343            (
2344                Record::Query {
2345                    output:
2346                        Ok(QueryOutput {
2347                            output: Output::Hashed { .. },
2348                            output_str: expected_output,
2349                            column_names,
2350                            ..
2351                        }),
2352                    ..
2353                },
2354                Outcome::OutputFailure {
2355                    actual_output: Output::Hashed { num_values, md5 },
2356                    ..
2357                },
2358            ) => {
2359                buf.append_header(&input, expected_output, column_names.as_ref());
2360
2361                buf.append(format!("{} values hashing to {}\n", num_values, md5).as_str())
2362            }
2363            (
2364                Record::Simple {
2365                    output_str: expected_output,
2366                    ..
2367                },
2368                Outcome::OutputFailure {
2369                    actual_output: Output::Values(actual_output),
2370                    ..
2371                },
2372            ) => {
2373                buf.append_header(&input, expected_output, None);
2374
2375                for (i, row) in actual_output.iter().enumerate() {
2376                    if i != 0 {
2377                        buf.append("\n");
2378                    }
2379                    buf.append(row);
2380                }
2381            }
2382            (
2383                Record::Query {
2384                    sql,
2385                    output: Err(err),
2386                    ..
2387                },
2388                outcome,
2389            )
2390            | (
2391                Record::Statement {
2392                    expected_error: Some(err),
2393                    sql,
2394                    ..
2395                },
2396                outcome,
2397            ) if outcome.err_msg().is_some() => {
2398                buf.rewrite_expected_error(&input, err, &outcome.err_msg().unwrap(), sql)
2399            }
2400            (_, Outcome::Success) => {}
2401            _ => bail!("unexpected: {:?} {:?}", record, outcome),
2402        }
2403    }
2404
2405    file.set_len(0)?;
2406    file.seek(SeekFrom::Start(0))?;
2407    file.write_all(buf.finish().as_bytes())?;
2408    file.sync_all()?;
2409    Ok(())
2410}
2411
2412/// Provides a means to rewrite the `.slt` file while iterating over it.
2413///
2414/// This struct takes the slt file as its `input`, tracks a cursor into it
2415/// (`input_offset`), and provides a buffer (`output`) to store the rewritten
2416/// results.
2417///
2418/// Functions that modify the file will lazily move `input` into `output` using
2419/// `flush_to`. However, those calls should all be interior to other functions.
2420#[derive(Debug)]
2421struct RewriteBuffer<'a> {
2422    input: &'a str,
2423    input_offset: usize,
2424    output: String,
2425}
2426
2427impl<'a> RewriteBuffer<'a> {
2428    fn new(input: &'a str) -> RewriteBuffer<'a> {
2429        RewriteBuffer {
2430            input,
2431            input_offset: 0,
2432            output: String::new(),
2433        }
2434    }
2435
2436    fn flush_to(&mut self, offset: usize) {
2437        assert!(offset >= self.input_offset);
2438        let chunk = &self.input[self.input_offset..offset];
2439        self.output.push_str(chunk);
2440        self.input_offset = offset;
2441    }
2442
2443    fn skip_to(&mut self, offset: usize) {
2444        assert!(offset >= self.input_offset);
2445        self.input_offset = offset;
2446    }
2447
2448    fn append(&mut self, s: &str) {
2449        self.output.push_str(s);
2450    }
2451
2452    fn append_header(
2453        &mut self,
2454        input: &String,
2455        expected_output: &str,
2456        column_names: Option<&Vec<ColumnName>>,
2457    ) {
2458        // Output everything before this record.
2459        // TODO(benesch): is it possible to rewrite this to avoid `as`?
2460        #[allow(clippy::as_conversions)]
2461        let offset = expected_output.as_ptr() as usize - input.as_ptr() as usize;
2462        self.flush_to(offset);
2463        self.skip_to(offset + expected_output.len());
2464
2465        // Attempt to install the result separator (----), if it does
2466        // not already exist.
2467        if self.peek_last(5) == "\n----" {
2468            self.append("\n");
2469        } else if self.peek_last(6) != "\n----\n" {
2470            self.append("\n----\n");
2471        }
2472
2473        let Some(names) = column_names else {
2474            return;
2475        };
2476        self.append(
2477            &names
2478                .iter()
2479                .map(|name| name.replace(' ', "␠"))
2480                .collect::<Vec<_>>()
2481                .join(" "),
2482        );
2483        self.append("\n");
2484    }
2485
2486    fn rewrite_expected_error(
2487        &mut self,
2488        input: &String,
2489        old_err: &str,
2490        new_err: &str,
2491        query: &str,
2492    ) {
2493        // Output everything before this error message.
2494        // TODO(benesch): is it possible to rewrite this to avoid `as`?
2495        #[allow(clippy::as_conversions)]
2496        let err_offset = old_err.as_ptr() as usize - input.as_ptr() as usize;
2497        self.flush_to(err_offset);
2498        self.append(new_err);
2499        self.append("\n");
2500        self.append(query);
2501        // TODO(benesch): is it possible to rewrite this to avoid `as`?
2502        #[allow(clippy::as_conversions)]
2503        self.skip_to(query.as_ptr() as usize - input.as_ptr() as usize + query.len())
2504    }
2505
2506    fn peek_last(&self, n: usize) -> &str {
2507        &self.output[self.output.len() - n..]
2508    }
2509
2510    fn finish(mut self) -> String {
2511        self.flush_to(self.input.len());
2512        self.output
2513    }
2514}
2515
2516/// Generates view creation, view indexing, view querying, and view
2517/// dropping SQL commands for a given `SELECT` query. If the number
2518/// of attributes produced by the query is known, the view commands
2519/// are specialized to avoid issues with column ambiguity. This
2520/// function is a helper for `--auto_index_selects` and assumes that
2521/// the provided input SQL has already been run through the parser,
2522/// resulting in a valid `SELECT` statement.
2523fn generate_view_sql(
2524    sql: &str,
2525    view_uuid: &Simple,
2526    num_attributes: Option<usize>,
2527    expected_column_names: Option<Vec<ColumnName>>,
2528) -> (String, String, String, String) {
2529    // To create the view, re-parse the sql; note that we must find exactly
2530    // one statement and it must be a `SELECT`.
2531    // NOTE(vmarcos): Direct string manipulation was attempted while
2532    // prototyping the code below, which avoids the extra parsing and
2533    // data structure cloning. However, running DDL is so slow that
2534    // it did not matter in terms of runtime. We can revisit this if
2535    // DDL cost drops dramatically in the future.
2536    let stmts = parser::parse_statements(sql).unwrap_or_default();
2537    assert!(stmts.len() == 1);
2538    let (query, query_as_of) = match &stmts[0].ast {
2539        Statement::Select(stmt) => (&stmt.query, &stmt.as_of),
2540        _ => unreachable!("This function should only be called for SELECTs"),
2541    };
2542
2543    // Prior to creating the view, process the `ORDER BY` clause of
2544    // the `SELECT` query, if any. Ordering is not preserved when a
2545    // view includes an `ORDER BY` clause and must be re-enforced by
2546    // an external `ORDER BY` clause when querying the view.
2547    let (view_order_by, extra_columns, distinct) = if num_attributes.is_none() {
2548        (query.order_by.clone(), vec![], None)
2549    } else {
2550        derive_order_by(&query.body, &query.order_by)
2551    };
2552
2553    // Since one-shot SELECT statements may contain ambiguous column names,
2554    // we either use the expected column names, if that option was
2555    // provided, or else just rename the output schema of the view
2556    // using numerically increasing attribute names, whenever possible.
2557    // This strategy makes it possible to use `CREATE INDEX`, thus
2558    // matching the behavior of the option `auto_index_tables`. However,
2559    // we may be presented with a `SELECT *` query, in which case the parser
2560    // does not produce sufficient information to allow us to compute
2561    // the number of output columns. In the latter case, we are supplied
2562    // with `None` for `num_attributes` and just employ the command
2563    // `CREATE DEFAULT INDEX` instead. Additionally, the view is created
2564    // without schema renaming. This strategy is insufficient to dodge
2565    // column name ambiguity in all cases, but we assume here that we
2566    // can adjust the (hopefully) small number of tests that eventually
2567    // challenge us in this particular way.
2568    let name = UnresolvedItemName(vec![Ident::new_unchecked(format!("v{}", view_uuid))]);
2569    let projection = expected_column_names.map_or_else(
2570        || {
2571            num_attributes.map_or(vec![], |n| {
2572                (1..=n)
2573                    .map(|i| Ident::new_unchecked(format!("a{i}")))
2574                    .collect()
2575            })
2576        },
2577        |cols| {
2578            cols.iter()
2579                .map(|c| Ident::new_unchecked(c.as_str()))
2580                .collect()
2581        },
2582    );
2583    let columns: Vec<Ident> = projection
2584        .iter()
2585        .cloned()
2586        .chain(extra_columns.iter().map(|item| {
2587            if let SelectItem::Expr {
2588                expr: _,
2589                alias: Some(ident),
2590            } = item
2591            {
2592                ident.clone()
2593            } else {
2594                unreachable!("alias must be given for extra column")
2595            }
2596        }))
2597        .collect();
2598
2599    // Build a `CREATE VIEW` with the columns computed above.
2600    let mut query = query.clone();
2601    if extra_columns.len() > 0 {
2602        match &mut query.body {
2603            SetExpr::Select(stmt) => stmt.projection.extend(extra_columns.iter().cloned()),
2604            _ => unimplemented!("cannot yet rewrite projections of nested queries"),
2605        }
2606    }
2607    let create_view = AstStatement::<Raw>::CreateView(CreateViewStatement {
2608        if_exists: IfExistsBehavior::Error,
2609        temporary: false,
2610        definition: ViewDefinition {
2611            name: name.clone(),
2612            columns: columns.clone(),
2613            query,
2614        },
2615    })
2616    .to_ast_string_stable();
2617
2618    // We then create either a `CREATE INDEX` or a `CREATE DEFAULT INDEX`
2619    // statement, depending on whether we could obtain the number of
2620    // attributes from the original `SELECT`.
2621    let create_index = AstStatement::<Raw>::CreateIndex(CreateIndexStatement {
2622        name: None,
2623        in_cluster: None,
2624        on_name: RawItemName::Name(name.clone()),
2625        key_parts: if columns.len() == 0 {
2626            None
2627        } else {
2628            Some(
2629                columns
2630                    .iter()
2631                    .map(|ident| Expr::Identifier(vec![ident.clone()]))
2632                    .collect(),
2633            )
2634        },
2635        with_options: Vec::new(),
2636        if_not_exists: false,
2637    })
2638    .to_ast_string_stable();
2639
2640    // Assert if DISTINCT semantics are unchanged from view
2641    let distinct_unneeded = extra_columns.len() == 0
2642        || match distinct {
2643            None | Some(Distinct::On(_)) => true,
2644            Some(Distinct::EntireRow) => false,
2645        };
2646    let distinct = if distinct_unneeded { None } else { distinct };
2647
2648    // `SELECT [* | {projection}] FROM {name} [ORDER BY {view_order_by}]`
2649    let view_sql = AstStatement::<Raw>::Select(SelectStatement {
2650        query: Query {
2651            ctes: CteBlock::Simple(vec![]),
2652            body: SetExpr::Select(Box::new(Select {
2653                distinct,
2654                projection: if projection.len() == 0 {
2655                    vec![SelectItem::Wildcard]
2656                } else {
2657                    projection
2658                        .iter()
2659                        .map(|ident| SelectItem::Expr {
2660                            expr: Expr::Identifier(vec![ident.clone()]),
2661                            alias: None,
2662                        })
2663                        .collect()
2664                },
2665                from: vec![TableWithJoins {
2666                    relation: TableFactor::Table {
2667                        name: RawItemName::Name(name.clone()),
2668                        alias: None,
2669                    },
2670                    joins: vec![],
2671                }],
2672                selection: None,
2673                group_by: vec![],
2674                having: None,
2675                qualify: None,
2676                options: vec![],
2677            })),
2678            order_by: view_order_by,
2679            limit: None,
2680            offset: None,
2681        },
2682        as_of: query_as_of.clone(),
2683    })
2684    .to_ast_string_stable();
2685
2686    // `DROP VIEW {name}`
2687    let drop_view = AstStatement::<Raw>::DropObjects(DropObjectsStatement {
2688        object_type: ObjectType::View,
2689        if_exists: false,
2690        names: vec![UnresolvedObjectName::Item(name)],
2691        cascade: false,
2692    })
2693    .to_ast_string_stable();
2694
2695    (create_view, create_index, view_sql, drop_view)
2696}
2697
2698/// Analyzes the provided query `body` to derive the number of
2699/// attributes in the query. We only consider syntactic cues,
2700/// so we may end up deriving `None` for the number of attributes
2701/// as a conservative approximation.
2702fn derive_num_attributes(body: &SetExpr<Raw>) -> Option<usize> {
2703    let Some((projection, _)) = find_projection(body) else {
2704        return None;
2705    };
2706    derive_num_attributes_from_projection(projection)
2707}
2708
2709/// Analyzes a query's `ORDER BY` clause to derive an `ORDER BY`
2710/// clause that makes numeric references to any expressions in
2711/// the projection and generated-attribute references to expressions
2712/// that need to be added as extra columns to the projection list.
2713/// The rewritten `ORDER BY` clause is then usable when querying a
2714/// view that contains the same `SELECT` as the given query.
2715/// This function returns both the rewritten `ORDER BY` clause
2716/// as well as a list of extra columns that need to be added
2717/// to the query's projection for the `ORDER BY` clause to
2718/// succeed.
2719fn derive_order_by(
2720    body: &SetExpr<Raw>,
2721    order_by: &Vec<OrderByExpr<Raw>>,
2722) -> (
2723    Vec<OrderByExpr<Raw>>,
2724    Vec<SelectItem<Raw>>,
2725    Option<Distinct<Raw>>,
2726) {
2727    let Some((projection, distinct)) = find_projection(body) else {
2728        return (vec![], vec![], None);
2729    };
2730    let (view_order_by, extra_columns) = derive_order_by_from_projection(projection, order_by);
2731    (view_order_by, extra_columns, distinct.clone())
2732}
2733
2734/// Finds the projection list in a `SELECT` query body.
2735fn find_projection(body: &SetExpr<Raw>) -> Option<(&Vec<SelectItem<Raw>>, &Option<Distinct<Raw>>)> {
2736    // Iterate to peel off the query body until the query's
2737    // projection list is found.
2738    let mut set_expr = body;
2739    loop {
2740        match set_expr {
2741            SetExpr::Select(select) => {
2742                return Some((&select.projection, &select.distinct));
2743            }
2744            SetExpr::SetOperation { left, .. } => set_expr = left.as_ref(),
2745            SetExpr::Query(query) => set_expr = &query.body,
2746            _ => return None,
2747        }
2748    }
2749}
2750
2751/// Computes the number of attributes that are obtained by the
2752/// projection of a `SELECT` query. The projection may include
2753/// wildcards, in which case the analysis just returns `None`.
2754fn derive_num_attributes_from_projection(projection: &Vec<SelectItem<Raw>>) -> Option<usize> {
2755    let mut num_attributes = 0usize;
2756    for item in projection.iter() {
2757        let SelectItem::Expr { expr, .. } = item else {
2758            return None;
2759        };
2760        match expr {
2761            Expr::QualifiedWildcard(..) | Expr::WildcardAccess(..) => {
2762                return None;
2763            }
2764            _ => {
2765                num_attributes += 1;
2766            }
2767        }
2768    }
2769    Some(num_attributes)
2770}
2771
2772/// Computes an `ORDER BY` clause with only numeric references
2773/// from given projection and `ORDER BY` of a `SELECT` query.
2774/// If the derivation fails to match a given expression, the
2775/// matched prefix is returned. Note that this could be empty.
2776fn derive_order_by_from_projection(
2777    projection: &Vec<SelectItem<Raw>>,
2778    order_by: &Vec<OrderByExpr<Raw>>,
2779) -> (Vec<OrderByExpr<Raw>>, Vec<SelectItem<Raw>>) {
2780    let mut view_order_by: Vec<OrderByExpr<Raw>> = vec![];
2781    let mut extra_columns: Vec<SelectItem<Raw>> = vec![];
2782    for order_by_expr in order_by.iter() {
2783        let query_expr = &order_by_expr.expr;
2784        let view_expr = match query_expr {
2785            Expr::Value(mz_sql_parser::ast::Value::Number(_)) => query_expr.clone(),
2786            _ => {
2787                // Find expression in query projection, if we can.
2788                if let Some(i) = projection.iter().position(|item| match item {
2789                    SelectItem::Expr { expr, alias } => {
2790                        expr == query_expr
2791                            || match query_expr {
2792                                Expr::Identifier(ident) => {
2793                                    ident.len() == 1 && Some(&ident[0]) == alias.as_ref()
2794                                }
2795                                _ => false,
2796                            }
2797                    }
2798                    SelectItem::Wildcard => false,
2799                }) {
2800                    Expr::Value(mz_sql_parser::ast::Value::Number((i + 1).to_string()))
2801                } else {
2802                    // If the expression is not found in the
2803                    // projection, add extra column.
2804                    let ident = Ident::new_unchecked(format!(
2805                        "a{}",
2806                        (projection.len() + extra_columns.len() + 1)
2807                    ));
2808                    extra_columns.push(SelectItem::Expr {
2809                        expr: query_expr.clone(),
2810                        alias: Some(ident.clone()),
2811                    });
2812                    Expr::Identifier(vec![ident])
2813                }
2814            }
2815        };
2816        view_order_by.push(OrderByExpr {
2817            expr: view_expr,
2818            asc: order_by_expr.asc,
2819            nulls_last: order_by_expr.nulls_last,
2820        });
2821    }
2822    (view_order_by, extra_columns)
2823}
2824
2825/// Returns extra statements to execute after `stmt` is executed.
2826fn mutate(sql: &str) -> Vec<String> {
2827    let stmts = parser::parse_statements(sql).unwrap_or_default();
2828    let mut additional = Vec::new();
2829    for stmt in stmts {
2830        match stmt.ast {
2831            AstStatement::CreateTable(stmt) => additional.push(
2832                // CREATE TABLE -> CREATE INDEX. Specify all columns manually in case CREATE
2833                // DEFAULT INDEX ever goes away.
2834                AstStatement::<Raw>::CreateIndex(CreateIndexStatement {
2835                    name: None,
2836                    in_cluster: None,
2837                    on_name: RawItemName::Name(stmt.name.clone()),
2838                    key_parts: Some(
2839                        stmt.columns
2840                            .iter()
2841                            .map(|def| Expr::Identifier(vec![def.name.clone()]))
2842                            .collect(),
2843                    ),
2844                    with_options: Vec::new(),
2845                    if_not_exists: false,
2846                })
2847                .to_ast_string_stable(),
2848            ),
2849            _ => {}
2850        }
2851    }
2852    additional
2853}
2854
// Table-driven test for `generate_view_sql`. Each case pairs the inputs
// (sql, number of attributes, expected column names) with the expected
// (create view, create index, view query, drop view) statements.
#[mz_ore::test]
#[cfg_attr(miri, ignore)] // unsupported operation: can't call foreign function `rust_psm_stack_pointer` on OS `linux`
fn test_generate_view_sql() {
    let uuid = Uuid::parse_str("67e5504410b1426f9247bb680e5fe0c8").unwrap();
    let cases = vec![
        // Neither an attribute count nor column names: the view has no column
        // list, a default index is created, and the view is queried with `*`.
        (("SELECT * FROM t", None, None),
        (
            r#"CREATE VIEW "v67e5504410b1426f9247bb680e5fe0c8" AS SELECT * FROM "t""#.to_string(),
            r#"CREATE DEFAULT INDEX ON "v67e5504410b1426f9247bb680e5fe0c8""#.to_string(),
            r#"SELECT * FROM "v67e5504410b1426f9247bb680e5fe0c8""#.to_string(),
            r#"DROP VIEW "v67e5504410b1426f9247bb680e5fe0c8""#.to_string(),
        )),
        // Expected column names are used for the view columns, the index keys
        // and the projection of the view query.
        (("SELECT a, b, c FROM t1, t2", Some(3), Some(vec![ColumnName::from("a"), ColumnName::from("b"), ColumnName::from("c")])),
        (
            r#"CREATE VIEW "v67e5504410b1426f9247bb680e5fe0c8" ("a", "b", "c") AS SELECT "a", "b", "c" FROM "t1", "t2""#.to_string(),
            r#"CREATE INDEX ON "v67e5504410b1426f9247bb680e5fe0c8" ("a", "b", "c")"#.to_string(),
            r#"SELECT "a", "b", "c" FROM "v67e5504410b1426f9247bb680e5fe0c8""#.to_string(),
            r#"DROP VIEW "v67e5504410b1426f9247bb680e5fe0c8""#.to_string(),
        )),
        // Only the attribute count is known: generated names `a1..a3` are used.
        (("SELECT a, b, c FROM t1, t2", Some(3), None),
        (
            r#"CREATE VIEW "v67e5504410b1426f9247bb680e5fe0c8" ("a1", "a2", "a3") AS SELECT "a", "b", "c" FROM "t1", "t2""#.to_string(),
            r#"CREATE INDEX ON "v67e5504410b1426f9247bb680e5fe0c8" ("a1", "a2", "a3")"#.to_string(),
            r#"SELECT "a1", "a2", "a3" FROM "v67e5504410b1426f9247bb680e5fe0c8""#.to_string(),
            r#"DROP VIEW "v67e5504410b1426f9247bb680e5fe0c8""#.to_string(),
        )),
        // A case with ambiguity that is accepted by the function, illustrating that
        // our measures to dodge this issue are imperfect.
        (("SELECT * FROM (SELECT a, sum(b) AS a FROM t GROUP BY a)", None, None),
        (
            r#"CREATE VIEW "v67e5504410b1426f9247bb680e5fe0c8" AS SELECT * FROM (SELECT "a", "sum"("b") AS "a" FROM "t" GROUP BY "a")"#.to_string(),
            r#"CREATE DEFAULT INDEX ON "v67e5504410b1426f9247bb680e5fe0c8""#.to_string(),
            r#"SELECT * FROM "v67e5504410b1426f9247bb680e5fe0c8""#.to_string(),
            r#"DROP VIEW "v67e5504410b1426f9247bb680e5fe0c8""#.to_string(),
        )),
        // `ORDER BY` entries that match a projected expression or alias are
        // rewritten to positional references in the view query.
        (("SELECT a, b, b + d AS c, a + b AS d FROM t1, t2 ORDER BY a, c, a + b", Some(4), Some(vec![ColumnName::from("a"), ColumnName::from("b"), ColumnName::from("c"), ColumnName::from("d")])),
        (
            r#"CREATE VIEW "v67e5504410b1426f9247bb680e5fe0c8" ("a", "b", "c", "d") AS SELECT "a", "b", "b" + "d" AS "c", "a" + "b" AS "d" FROM "t1", "t2" ORDER BY "a", "c", "a" + "b""#.to_string(),
            r#"CREATE INDEX ON "v67e5504410b1426f9247bb680e5fe0c8" ("a", "b", "c", "d")"#.to_string(),
            r#"SELECT "a", "b", "c", "d" FROM "v67e5504410b1426f9247bb680e5fe0c8" ORDER BY 1, 3, 4"#.to_string(),
            r#"DROP VIEW "v67e5504410b1426f9247bb680e5fe0c8""#.to_string(),
        )),
        // Set operations: the alias of the leftmost select is what `ORDER BY a`
        // is matched against, yielding `ORDER BY 1`.
        (("((SELECT 1 AS a UNION SELECT 2 AS b) UNION SELECT 3 AS c) ORDER BY a", Some(1), None),
        (
            r#"CREATE VIEW "v67e5504410b1426f9247bb680e5fe0c8" ("a1") AS (SELECT 1 AS "a" UNION SELECT 2 AS "b") UNION SELECT 3 AS "c" ORDER BY "a""#.to_string(),
            r#"CREATE INDEX ON "v67e5504410b1426f9247bb680e5fe0c8" ("a1")"#.to_string(),
            r#"SELECT "a1" FROM "v67e5504410b1426f9247bb680e5fe0c8" ORDER BY 1"#.to_string(),
            r#"DROP VIEW "v67e5504410b1426f9247bb680e5fe0c8""#.to_string(),
        )),
        // Without an attribute count the `ORDER BY` clause is carried over to
        // the view query unchanged, whether it is numeric ...
        (("SELECT * FROM (SELECT a, sum(b) AS a FROM t GROUP BY a) ORDER BY 1", None, None),
        (
            r#"CREATE VIEW "v67e5504410b1426f9247bb680e5fe0c8" AS SELECT * FROM (SELECT "a", "sum"("b") AS "a" FROM "t" GROUP BY "a") ORDER BY 1"#.to_string(),
            r#"CREATE DEFAULT INDEX ON "v67e5504410b1426f9247bb680e5fe0c8""#.to_string(),
            r#"SELECT * FROM "v67e5504410b1426f9247bb680e5fe0c8" ORDER BY 1"#.to_string(),
            r#"DROP VIEW "v67e5504410b1426f9247bb680e5fe0c8""#.to_string(),
        )),
        // ... or refers to a column by name.
        (("SELECT * FROM (SELECT a, sum(b) AS a FROM t GROUP BY a) ORDER BY a", None, None),
        (
            r#"CREATE VIEW "v67e5504410b1426f9247bb680e5fe0c8" AS SELECT * FROM (SELECT "a", "sum"("b") AS "a" FROM "t" GROUP BY "a") ORDER BY "a""#.to_string(),
            r#"CREATE DEFAULT INDEX ON "v67e5504410b1426f9247bb680e5fe0c8""#.to_string(),
            r#"SELECT * FROM "v67e5504410b1426f9247bb680e5fe0c8" ORDER BY "a""#.to_string(),
            r#"DROP VIEW "v67e5504410b1426f9247bb680e5fe0c8""#.to_string(),
        )),
        // `ORDER BY c` is not part of the projection: `c` is added to the view
        // as extra column `a3` and indexed, but left out of the view query's
        // projection, which orders by `a3` instead.
        (("SELECT a, sum(b) AS a FROM t GROUP BY a, c ORDER BY a, c", Some(2), None),
        (
            r#"CREATE VIEW "v67e5504410b1426f9247bb680e5fe0c8" ("a1", "a2", "a3") AS SELECT "a", "sum"("b") AS "a", "c" AS "a3" FROM "t" GROUP BY "a", "c" ORDER BY "a", "c""#.to_string(),
            r#"CREATE INDEX ON "v67e5504410b1426f9247bb680e5fe0c8" ("a1", "a2", "a3")"#.to_string(),
            r#"SELECT "a1", "a2" FROM "v67e5504410b1426f9247bb680e5fe0c8" ORDER BY 1, "a3""#.to_string(),
            r#"DROP VIEW "v67e5504410b1426f9247bb680e5fe0c8""#.to_string(),
        )),
        // Same as above with the `ORDER BY` entries in the opposite order.
        (("SELECT a, sum(b) AS a FROM t GROUP BY a, c ORDER BY c, a", Some(2), None),
        (
            r#"CREATE VIEW "v67e5504410b1426f9247bb680e5fe0c8" ("a1", "a2", "a3") AS SELECT "a", "sum"("b") AS "a", "c" AS "a3" FROM "t" GROUP BY "a", "c" ORDER BY "c", "a""#.to_string(),
            r#"CREATE INDEX ON "v67e5504410b1426f9247bb680e5fe0c8" ("a1", "a2", "a3")"#.to_string(),
            r#"SELECT "a1", "a2" FROM "v67e5504410b1426f9247bb680e5fe0c8" ORDER BY "a3", 1"#.to_string(),
            r#"DROP VIEW "v67e5504410b1426f9247bb680e5fe0c8""#.to_string(),
        )),
    ];
    // Run every case and compare all four generated statements at once.
    for ((sql, num_attributes, expected_column_names), expected) in cases {
        let view_sql =
            generate_view_sql(sql, uuid.as_simple(), num_attributes, expected_column_names);
        assert_eq!(expected, view_sql);
    }
}
2939
2940#[mz_ore::test]
2941fn test_mutate() {
2942    let cases = vec![
2943        ("CREATE TABLE t ()", vec![r#"CREATE INDEX ON "t" ()"#]),
2944        (
2945            "CREATE TABLE t (a INT)",
2946            vec![r#"CREATE INDEX ON "t" ("a")"#],
2947        ),
2948        (
2949            "CREATE TABLE t (a INT, b TEXT)",
2950            vec![r#"CREATE INDEX ON "t" ("a", "b")"#],
2951        ),
2952        // Invalid syntax works, just returns nothing.
2953        ("BAD SYNTAX", Vec::new()),
2954    ];
2955    for (sql, expected) in cases {
2956        let stmts = mutate(sql);
2957        assert_eq!(expected, stmts, "sql: {sql}");
2958    }
2959}