1use std::collections::BTreeMap;
29use std::error::Error;
30use std::fs::{File, OpenOptions};
31use std::io::{Read, Seek, SeekFrom, Write};
32use std::net::{IpAddr, Ipv4Addr, SocketAddr};
33use std::path::Path;
34use std::sync::Arc;
35use std::sync::LazyLock;
36use std::time::Duration;
37use std::{env, fmt, ops, str, thread};
38
39use anyhow::{anyhow, bail};
40use bytes::BytesMut;
41use chrono::{DateTime, NaiveDateTime, NaiveTime, Utc};
42use fallible_iterator::FallibleIterator;
43use futures::sink::SinkExt;
44use itertools::Itertools;
45use md5::{Digest, Md5};
46use mz_adapter_types::bootstrap_builtin_cluster_config::{
47 ANALYTICS_CLUSTER_DEFAULT_REPLICATION_FACTOR, BootstrapBuiltinClusterConfig,
48 CATALOG_SERVER_CLUSTER_DEFAULT_REPLICATION_FACTOR, PROBE_CLUSTER_DEFAULT_REPLICATION_FACTOR,
49 SUPPORT_CLUSTER_DEFAULT_REPLICATION_FACTOR, SYSTEM_CLUSTER_DEFAULT_REPLICATION_FACTOR,
50};
51use mz_catalog::config::ClusterReplicaSizeMap;
52use mz_controller::ControllerConfig;
53use mz_environmentd::CatalogConfig;
54use mz_license_keys::ValidatedLicenseKey;
55use mz_orchestrator_process::{ProcessOrchestrator, ProcessOrchestratorConfig};
56use mz_orchestrator_tracing::{TracingCliArgs, TracingOrchestrator};
57use mz_ore::cast::{CastFrom, ReinterpretCast};
58use mz_ore::channel::trigger;
59use mz_ore::error::ErrorExt;
60use mz_ore::metrics::MetricsRegistry;
61use mz_ore::now::SYSTEM_TIME;
62use mz_ore::retry::Retry;
63use mz_ore::task;
64use mz_ore::thread::{JoinHandleExt, JoinOnDropHandle};
65use mz_ore::tracing::TracingHandle;
66use mz_ore::url::SensitiveUrl;
67use mz_persist_client::PersistLocation;
68use mz_persist_client::cache::PersistClientCache;
69use mz_persist_client::cfg::PersistConfig;
70use mz_persist_client::rpc::{
71 MetricsSameProcessPubSubSender, PersistGrpcPubSubServer, PubSubClientConnection, PubSubSender,
72};
73use mz_pgrepr::{Interval, Jsonb, Numeric, UInt2, UInt4, UInt8, Value, oid};
74use mz_repr::ColumnName;
75use mz_repr::adt::date::Date;
76use mz_repr::adt::mz_acl_item::{AclItem, MzAclItem};
77use mz_repr::adt::numeric;
78use mz_secrets::SecretsController;
79use mz_sql::ast::{Expr, Raw, Statement};
80use mz_sql::catalog::EnvironmentId;
81use mz_sql_parser::ast::display::AstDisplay;
82use mz_sql_parser::ast::{
83 CreateIndexStatement, CreateViewStatement, CteBlock, Distinct, DropObjectsStatement, Ident,
84 IfExistsBehavior, ObjectType, OrderByExpr, Query, RawItemName, Select, SelectItem,
85 SelectStatement, SetExpr, Statement as AstStatement, TableFactor, TableWithJoins,
86 UnresolvedItemName, UnresolvedObjectName, ViewDefinition,
87};
88use mz_sql_parser::parser;
89use mz_storage_types::connections::ConnectionContext;
90use postgres_protocol::types;
91use regex::Regex;
92use tempfile::TempDir;
93use tokio::net::TcpListener;
94use tokio::runtime::Runtime;
95use tokio::sync::oneshot;
96use tokio_postgres::types::{FromSql, Kind as PgKind, Type as PgType};
97use tokio_postgres::{NoTls, Row, SimpleQueryMessage};
98use tokio_stream::wrappers::TcpListenerStream;
99use tower_http::cors::AllowOrigin;
100use tracing::{error, info};
101use uuid::Uuid;
102use uuid::fmt::Simple;
103
104use crate::ast::{Location, Mode, Output, QueryOutput, Record, Sort, Type};
105use crate::util;
106
/// The classified result of running one record.
///
/// Every variant except [`Outcome::Success`] carries the [`Location`] it is
/// attributed to. Fields borrowed for `'a` point at data owned elsewhere
/// (NOTE(review): presumably the parsed record — confirm against callers).
#[derive(Debug)]
pub enum Outcome<'a> {
    /// An `Unsupported` classification together with the error behind it.
    Unsupported {
        error: anyhow::Error,
        location: Location,
    },
    /// A parse failure together with the error behind it.
    ParseFailure {
        error: anyhow::Error,
        location: Location,
    },
    /// A planning failure together with the error behind it.
    PlanFailure {
        error: anyhow::Error,
        location: Location,
    },
    /// An error was expected (`expected_error`) but none occurred.
    UnexpectedPlanSuccess {
        expected_error: &'a str,
        location: Location,
    },
    /// The expected and actual row counts differ.
    WrongNumberOfRowsInserted {
        expected_count: u64,
        actual_count: u64,
        location: Location,
    },
    /// The expected and actual number of columns differ.
    WrongColumnCount {
        expected_count: usize,
        actual_count: usize,
        location: Location,
    },
    /// The expected and actual column names differ. `actual_output` is
    /// carried along but is not included in the `Display` rendering.
    WrongColumnNames {
        expected_column_names: &'a Vec<ColumnName>,
        actual_column_names: Vec<ColumnName>,
        actual_output: Output,
        location: Location,
    },
    /// The produced output differs from the expected output. Both the
    /// formatted (`actual_output`) and raw (`actual_raw_output`) results are
    /// kept so they can be printed side by side.
    OutputFailure {
        expected_output: &'a Output,
        actual_raw_output: Vec<Row>,
        actual_output: Output,
        location: Location,
    },
    /// The outcome obtained from the query and the outcome obtained from the
    /// indexed view disagree.
    InconsistentViewOutcome {
        query_outcome: Box<Outcome<'a>>,
        view_outcome: Box<Outcome<'a>>,
        location: Location,
    },
    /// Wraps another outcome (`cause`) as a `Bail`.
    Bail {
        cause: Box<Outcome<'a>>,
        location: Location,
    },
    /// Wraps another outcome (`cause`) as a `Warning`; warnings are not
    /// counted as failures (see `Outcome::failure`).
    Warning {
        cause: Box<Outcome<'a>>,
        location: Location,
    },
    /// The record behaved as expected.
    Success,
}
162
/// Number of `Outcome` variants; also the length of `Outcomes::stats`.
const NUM_OUTCOMES: usize = 12;
/// Index of the `Warning` counter; equals the value `Outcome::code` returns
/// for `Outcome::Warning` (10).
const WARNING_OUTCOME: usize = NUM_OUTCOMES - 2;
/// Index of the `Success` counter; equals the value `Outcome::code` returns
/// for `Outcome::Success` (11).
const SUCCESS_OUTCOME: usize = NUM_OUTCOMES - 1;
166
167impl<'a> Outcome<'a> {
168 fn code(&self) -> usize {
169 match self {
170 Outcome::Unsupported { .. } => 0,
171 Outcome::ParseFailure { .. } => 1,
172 Outcome::PlanFailure { .. } => 2,
173 Outcome::UnexpectedPlanSuccess { .. } => 3,
174 Outcome::WrongNumberOfRowsInserted { .. } => 4,
175 Outcome::WrongColumnCount { .. } => 5,
176 Outcome::WrongColumnNames { .. } => 6,
177 Outcome::OutputFailure { .. } => 7,
178 Outcome::InconsistentViewOutcome { .. } => 8,
179 Outcome::Bail { .. } => 9,
180 Outcome::Warning { .. } => 10,
181 Outcome::Success => 11,
182 }
183 }
184
185 fn success(&self) -> bool {
186 matches!(self, Outcome::Success)
187 }
188
189 fn failure(&self) -> bool {
190 !matches!(self, Outcome::Success) && !matches!(self, Outcome::Warning { .. })
191 }
192
193 fn err_msg(&self) -> Option<String> {
197 match self {
198 Outcome::Unsupported { error, .. }
199 | Outcome::ParseFailure { error, .. }
200 | Outcome::PlanFailure { error, .. } => Some(
201 regex::escape(
204 error.to_string().split('\n').next().unwrap(),
207 ),
208 ),
209 _ => None,
210 }
211 }
212}
213
214impl fmt::Display for Outcome<'_> {
215 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
216 use Outcome::*;
217 const INDENT: &str = "\n ";
218 match self {
219 Unsupported { error, location } => write!(
220 f,
221 "Unsupported:{}:\n{}",
222 location,
223 error.display_with_causes()
224 ),
225 ParseFailure { error, location } => {
226 write!(
227 f,
228 "ParseFailure:{}:\n{}",
229 location,
230 error.display_with_causes()
231 )
232 }
233 PlanFailure { error, location } => write!(f, "PlanFailure:{}:\n{:#}", location, error),
234 UnexpectedPlanSuccess {
235 expected_error,
236 location,
237 } => write!(
238 f,
239 "UnexpectedPlanSuccess:{} expected error: {}",
240 location, expected_error
241 ),
242 WrongNumberOfRowsInserted {
243 expected_count,
244 actual_count,
245 location,
246 } => write!(
247 f,
248 "WrongNumberOfRowsInserted:{}{}expected: {}{}actually: {}",
249 location, INDENT, expected_count, INDENT, actual_count
250 ),
251 WrongColumnCount {
252 expected_count,
253 actual_count,
254 location,
255 } => write!(
256 f,
257 "WrongColumnCount:{}{}expected: {}{}actually: {}",
258 location, INDENT, expected_count, INDENT, actual_count
259 ),
260 WrongColumnNames {
261 expected_column_names,
262 actual_column_names,
263 actual_output: _,
264 location,
265 } => write!(
266 f,
267 "Wrong Column Names:{}:{}expected column names: {}{}inferred column names: {}",
268 location,
269 INDENT,
270 expected_column_names
271 .iter()
272 .map(|n| n.to_string())
273 .collect::<Vec<_>>()
274 .join(" "),
275 INDENT,
276 actual_column_names
277 .iter()
278 .map(|n| n.to_string())
279 .collect::<Vec<_>>()
280 .join(" ")
281 ),
282 OutputFailure {
283 expected_output,
284 actual_raw_output,
285 actual_output,
286 location,
287 } => write!(
288 f,
289 "OutputFailure:{}{}expected: {:?}{}actually: {:?}{}actual raw: {:?}",
290 location, INDENT, expected_output, INDENT, actual_output, INDENT, actual_raw_output
291 ),
292 InconsistentViewOutcome {
293 query_outcome,
294 view_outcome,
295 location,
296 } => write!(
297 f,
298 "InconsistentViewOutcome:{}{}expected from query: {:?}{}actually from indexed view: {:?}{}",
299 location, INDENT, query_outcome, INDENT, view_outcome, INDENT
300 ),
301 Bail { cause, location } => write!(f, "Bail:{} {}", location, cause),
302 Warning { cause, location } => write!(f, "Warning:{} {}", location, cause),
303 Success => f.write_str("Success"),
304 }
305 }
306}
307
/// Aggregated results: one counter per outcome kind plus free-form detail
/// lines.
#[derive(Default, Debug)]
pub struct Outcomes {
    /// Counters per outcome kind, in the same order as the codes returned by
    /// `Outcome::code` (see the index-to-name mapping in `as_json`).
    stats: [usize; NUM_OUTCOMES],
    /// Lines written out by `OutcomesDisplay` when failure details are
    /// requested.
    details: Vec<String>,
}
313
314impl ops::AddAssign<Outcomes> for Outcomes {
315 fn add_assign(&mut self, rhs: Outcomes) {
316 for (lhs, rhs) in self.stats.iter_mut().zip(rhs.stats.iter()) {
317 *lhs += rhs
318 }
319 }
320}
321impl Outcomes {
322 pub fn any_failed(&self) -> bool {
323 self.stats[SUCCESS_OUTCOME] + self.stats[WARNING_OUTCOME] < self.stats.iter().sum::<usize>()
324 }
325
326 pub fn as_json(&self) -> serde_json::Value {
327 serde_json::json!({
328 "unsupported": self.stats[0],
329 "parse_failure": self.stats[1],
330 "plan_failure": self.stats[2],
331 "unexpected_plan_success": self.stats[3],
332 "wrong_number_of_rows_affected": self.stats[4],
333 "wrong_column_count": self.stats[5],
334 "wrong_column_names": self.stats[6],
335 "output_failure": self.stats[7],
336 "inconsistent_view_outcome": self.stats[8],
337 "bail": self.stats[9],
338 "warning": self.stats[10],
339 "success": self.stats[11],
340 })
341 }
342
343 pub fn display(&self, no_fail: bool, failure_details: bool) -> OutcomesDisplay<'_> {
344 OutcomesDisplay {
345 inner: self,
346 no_fail,
347 failure_details,
348 }
349 }
350}
351
/// `Display` adapter for [`Outcomes`], created by `Outcomes::display`.
pub struct OutcomesDisplay<'a> {
    /// The results being rendered.
    inner: &'a Outcomes,
    /// When set, a run with failures is labelled `FAIL-IGNORE` instead of
    /// `FAIL`, and detail lines are printed whenever `failure_details` is on.
    no_fail: bool,
    /// When set, detail lines replace the summary line if there are failures
    /// (or if `no_fail` is set).
    failure_details: bool,
}
357
358impl<'a> fmt::Display for OutcomesDisplay<'a> {
359 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
360 let total: usize = self.inner.stats.iter().sum();
361 if self.failure_details
362 && (self.inner.stats[SUCCESS_OUTCOME] + self.inner.stats[WARNING_OUTCOME] != total
363 || self.no_fail)
364 {
365 for outcome in &self.inner.details {
366 writeln!(f, "{}", outcome)?;
367 }
368 Ok(())
369 } else {
370 write!(
371 f,
372 "{}:",
373 if self.inner.stats[SUCCESS_OUTCOME] + self.inner.stats[WARNING_OUTCOME] == total {
374 "PASS"
375 } else if self.no_fail {
376 "FAIL-IGNORE"
377 } else {
378 "FAIL"
379 }
380 )?;
381 static NAMES: LazyLock<Vec<&'static str>> = LazyLock::new(|| {
382 vec![
383 "unsupported",
384 "parse-failure",
385 "plan-failure",
386 "unexpected-plan-success",
387 "wrong-number-of-rows-inserted",
388 "wrong-column-count",
389 "wrong-column-names",
390 "output-failure",
391 "inconsistent-view-outcome",
392 "bail",
393 "warning",
394 "success",
395 "total",
396 ]
397 });
398 for (i, n) in self.inner.stats.iter().enumerate() {
399 if *n > 0 {
400 write!(f, " {}={}", NAMES[i], n)?;
401 }
402 }
403 write!(f, " total={}", total)
404 }
405 }
406}
407
/// Facts recorded about a prepared query.
struct QueryInfo {
    /// Whether the statement is a `SELECT`.
    is_select: bool,
    /// Number of result attributes, when known.
    /// NOTE(review): the meaning of `None` is decided where this is
    /// populated, outside this definition — confirm there.
    num_attributes: Option<usize>,
}
412
/// Result of preparing a query: either information about the prepared query
/// or an [`Outcome`] to report instead.
enum PrepareQueryOutcome<'a> {
    /// Preparation produced query information.
    QueryPrepared(QueryInfo),
    /// Preparation already determined the record's outcome.
    Outcome(Outcome<'a>),
}
417
/// Handle that owns the running server state and can replace it wholesale
/// (see `Runner::reset`).
pub struct Runner<'a> {
    /// Configuration used every time a new `RunnerInner` is started.
    config: &'a RunConfig<'a>,
    /// The live server state. `None` until the first successful `reset`, and
    /// again if a later `reset` fails after the old state was dropped.
    inner: Option<RunnerInner<'a>>,
}
422
/// State of one started server instance together with its client
/// connections.
pub struct RunnerInner<'a> {
    /// SQL address reported by the server (`sql_local_addr`).
    server_addr: SocketAddr,
    /// Internal SQL address reported by the server (`internal_sql_local_addr`).
    internal_server_addr: SocketAddr,
    /// Internal HTTP address reported by the server
    /// (`internal_http_local_addr`).
    internal_http_server_addr: SocketAddr,
    /// Connection to `server_addr` opened without an explicit user.
    client: tokio_postgres::Client,
    /// Connection to `internal_server_addr` opened as `mz_system`.
    system_client: tokio_postgres::Client,
    /// Additional connections keyed by name; emptied by
    /// `Runner::reset_database`.
    clients: BTreeMap<String, tokio_postgres::Client>,
    /// Copied from `RunConfig::auto_index_tables`.
    auto_index_tables: bool,
    /// Copied from `RunConfig::auto_index_selects`.
    auto_index_selects: bool,
    /// Copied from `RunConfig::auto_transactions`.
    auto_transactions: bool,
    /// Copied from `RunConfig::enable_table_keys`; when set,
    /// `Runner::reset_database` turns on `unsafe_enable_table_keys`.
    enable_table_keys: bool,
    /// Copied from `RunConfig::verbosity`.
    verbosity: u8,
    /// Output sink copied from `RunConfig::stdout`.
    stdout: &'a dyn WriteFmt,
    /// Held only for its lifetime: the server thread blocks on the receiving
    /// end of this trigger after startup.
    _shutdown_trigger: trigger::Trigger,
    /// Handle of the server thread, joined when this value is dropped.
    _server_thread: JoinOnDropHandle<()>,
    /// Temporary directory (contains the secrets directory) kept alive for
    /// as long as the server runs.
    _temp_dir: TempDir,
}
441
/// Wrapper around a [`Value`] that carries this file's `FromSql`
/// implementation, so query results can be decoded straight into it.
#[derive(Debug)]
pub struct Slt(Value);
444
445impl<'a> FromSql<'a> for Slt {
446 fn from_sql(
447 ty: &PgType,
448 mut raw: &'a [u8],
449 ) -> Result<Self, Box<dyn Error + 'static + Send + Sync>> {
450 Ok(match *ty {
451 PgType::ACLITEM => Self(Value::AclItem(AclItem::decode_binary(
452 types::bytea_from_sql(raw),
453 )?)),
454 PgType::BOOL => Self(Value::Bool(types::bool_from_sql(raw)?)),
455 PgType::BYTEA => Self(Value::Bytea(types::bytea_from_sql(raw).to_vec())),
456 PgType::CHAR => Self(Value::Char(u8::from_be_bytes(
457 types::char_from_sql(raw)?.to_be_bytes(),
458 ))),
459 PgType::FLOAT4 => Self(Value::Float4(types::float4_from_sql(raw)?)),
460 PgType::FLOAT8 => Self(Value::Float8(types::float8_from_sql(raw)?)),
461 PgType::DATE => Self(Value::Date(Date::from_pg_epoch(types::int4_from_sql(
462 raw,
463 )?)?)),
464 PgType::INT2 => Self(Value::Int2(types::int2_from_sql(raw)?)),
465 PgType::INT4 => Self(Value::Int4(types::int4_from_sql(raw)?)),
466 PgType::INT8 => Self(Value::Int8(types::int8_from_sql(raw)?)),
467 PgType::INTERVAL => Self(Value::Interval(Interval::from_sql(ty, raw)?)),
468 PgType::JSONB => Self(Value::Jsonb(Jsonb::from_sql(ty, raw)?)),
469 PgType::NAME => Self(Value::Name(types::text_from_sql(raw)?.to_string())),
470 PgType::NUMERIC => Self(Value::Numeric(Numeric::from_sql(ty, raw)?)),
471 PgType::OID => Self(Value::Oid(types::oid_from_sql(raw)?)),
472 PgType::REGCLASS => Self(Value::Oid(types::oid_from_sql(raw)?)),
473 PgType::REGPROC => Self(Value::Oid(types::oid_from_sql(raw)?)),
474 PgType::REGTYPE => Self(Value::Oid(types::oid_from_sql(raw)?)),
475 PgType::TEXT | PgType::BPCHAR | PgType::VARCHAR => {
476 Self(Value::Text(types::text_from_sql(raw)?.to_string()))
477 }
478 PgType::TIME => Self(Value::Time(NaiveTime::from_sql(ty, raw)?)),
479 PgType::TIMESTAMP => Self(Value::Timestamp(
480 NaiveDateTime::from_sql(ty, raw)?.try_into()?,
481 )),
482 PgType::TIMESTAMPTZ => Self(Value::TimestampTz(
483 DateTime::<Utc>::from_sql(ty, raw)?.try_into()?,
484 )),
485 PgType::UUID => Self(Value::Uuid(Uuid::from_sql(ty, raw)?)),
486 PgType::RECORD => {
487 let num_fields = read_be_i32(&mut raw)?;
488 let mut tuple = vec![];
489 for _ in 0..num_fields {
490 let oid = u32::reinterpret_cast(read_be_i32(&mut raw)?);
491 let typ = match PgType::from_oid(oid) {
492 Some(typ) => typ,
493 None => return Err("unknown oid".into()),
494 };
495 let v = read_value::<Option<Slt>>(&typ, &mut raw)?;
496 tuple.push(v.map(|v| v.0));
497 }
498 Self(Value::Record(tuple))
499 }
500 PgType::INT4_RANGE
501 | PgType::INT8_RANGE
502 | PgType::DATE_RANGE
503 | PgType::NUM_RANGE
504 | PgType::TS_RANGE
505 | PgType::TSTZ_RANGE => {
506 use mz_repr::adt::range::Range;
507 let range: Range<Slt> = Range::from_sql(ty, raw)?;
508 Self(Value::Range(range.into_bounds(|b| Box::new(b.0))))
509 }
510
511 _ => match ty.kind() {
512 PgKind::Array(arr_type) => {
513 let arr = types::array_from_sql(raw)?;
514 let elements: Vec<Option<Value>> = arr
515 .values()
516 .map(|v| match v {
517 Some(v) => Ok(Some(Slt::from_sql(arr_type, v)?)),
518 None => Ok(None),
519 })
520 .collect::<Vec<Option<Slt>>>()?
521 .into_iter()
522 .map(|v| v.map(|v| v.0))
524 .collect();
525
526 Self(Value::Array {
527 dims: arr
528 .dimensions()
529 .map(|d| {
530 Ok(mz_repr::adt::array::ArrayDimension {
531 lower_bound: isize::cast_from(d.lower_bound),
532 length: usize::try_from(d.len)
533 .expect("cannot have negative length"),
534 })
535 })
536 .collect()?,
537 elements,
538 })
539 }
540 _ => match ty.oid() {
541 oid::TYPE_UINT2_OID => Self(Value::UInt2(UInt2::from_sql(ty, raw)?)),
542 oid::TYPE_UINT4_OID => Self(Value::UInt4(UInt4::from_sql(ty, raw)?)),
543 oid::TYPE_UINT8_OID => Self(Value::UInt8(UInt8::from_sql(ty, raw)?)),
544 oid::TYPE_MZ_TIMESTAMP_OID => {
545 let s = types::text_from_sql(raw)?;
546 let t: mz_repr::Timestamp = s.parse()?;
547 Self(Value::MzTimestamp(t))
548 }
549 oid::TYPE_MZ_ACL_ITEM_OID => Self(Value::MzAclItem(MzAclItem::decode_binary(
550 types::bytea_from_sql(raw),
551 )?)),
552 _ => unreachable!(),
553 },
554 },
555 })
556 }
557 fn accepts(ty: &PgType) -> bool {
558 match ty.kind() {
559 PgKind::Array(_) | PgKind::Composite(_) => return true,
560 _ => {}
561 }
562 match ty.oid() {
563 oid::TYPE_UINT2_OID
564 | oid::TYPE_UINT4_OID
565 | oid::TYPE_UINT8_OID
566 | oid::TYPE_MZ_TIMESTAMP_OID
567 | oid::TYPE_MZ_ACL_ITEM_OID => return true,
568 _ => {}
569 }
570 matches!(
571 *ty,
572 PgType::ACLITEM
573 | PgType::BOOL
574 | PgType::BYTEA
575 | PgType::CHAR
576 | PgType::DATE
577 | PgType::FLOAT4
578 | PgType::FLOAT8
579 | PgType::INT2
580 | PgType::INT4
581 | PgType::INT8
582 | PgType::INTERVAL
583 | PgType::JSONB
584 | PgType::NAME
585 | PgType::NUMERIC
586 | PgType::OID
587 | PgType::REGCLASS
588 | PgType::REGPROC
589 | PgType::REGTYPE
590 | PgType::RECORD
591 | PgType::TEXT
592 | PgType::BPCHAR
593 | PgType::VARCHAR
594 | PgType::TIME
595 | PgType::TIMESTAMP
596 | PgType::TIMESTAMPTZ
597 | PgType::UUID
598 | PgType::INT4_RANGE
599 | PgType::INT4_RANGE_ARRAY
600 | PgType::INT8_RANGE
601 | PgType::INT8_RANGE_ARRAY
602 | PgType::DATE_RANGE
603 | PgType::DATE_RANGE_ARRAY
604 | PgType::NUM_RANGE
605 | PgType::NUM_RANGE_ARRAY
606 | PgType::TS_RANGE
607 | PgType::TS_RANGE_ARRAY
608 | PgType::TSTZ_RANGE
609 | PgType::TSTZ_RANGE_ARRAY
610 )
611 }
612}
613
/// Reads a big-endian `i32` from the front of `buf` and advances `buf` past
/// the four bytes consumed.
///
/// # Errors
///
/// Returns `"invalid buffer size"` when fewer than four bytes remain; `buf`
/// is left untouched in that case.
fn read_be_i32(buf: &mut &[u8]) -> Result<i32, Box<dyn Error + Sync + Send>> {
    let Some((head, rest)) = buf.split_first_chunk::<4>() else {
        return Err("invalid buffer size".into());
    };
    *buf = rest;
    Ok(i32::from_be_bytes(*head))
}
624
625fn read_value<'a, T>(type_: &PgType, buf: &mut &'a [u8]) -> Result<T, Box<dyn Error + Sync + Send>>
627where
628 T: FromSql<'a>,
629{
630 let value = match usize::try_from(read_be_i32(buf)?) {
631 Err(_) => None,
632 Ok(len) => {
633 if len > buf.len() {
634 return Err("invalid buffer size".into());
635 }
636 let (head, tail) = buf.split_at(len);
637 *buf = tail;
638 Some(head)
639 }
640 };
641 T::from_sql_nullable(type_, value)
642}
643
644fn format_datum(d: Slt, typ: &Type, mode: Mode, col: usize) -> String {
645 match (typ, d.0) {
646 (Type::Bool, Value::Bool(b)) => b.to_string(),
647
648 (Type::Integer, Value::Int2(i)) => i.to_string(),
649 (Type::Integer, Value::Int4(i)) => i.to_string(),
650 (Type::Integer, Value::Int8(i)) => i.to_string(),
651 (Type::Integer, Value::UInt2(u)) => u.0.to_string(),
652 (Type::Integer, Value::UInt4(u)) => u.0.to_string(),
653 (Type::Integer, Value::UInt8(u)) => u.0.to_string(),
654 (Type::Integer, Value::Oid(i)) => i.to_string(),
655 #[allow(clippy::as_conversions)]
657 (Type::Integer, Value::Float4(f)) => format!("{}", f as i64),
658 #[allow(clippy::as_conversions)]
660 (Type::Integer, Value::Float8(f)) => format!("{}", f as i64),
661 (Type::Integer, Value::Text(_)) => "0".to_string(),
663 (Type::Integer, Value::Bool(b)) => i8::from(b).to_string(),
664 (Type::Integer, Value::Numeric(d)) => {
665 let mut d = d.0.0.clone();
666 let mut cx = numeric::cx_datum();
667 if mode == Mode::Standard {
669 cx.set_rounding(dec::Rounding::Down);
670 }
671 cx.round(&mut d);
672 numeric::munge_numeric(&mut d).unwrap();
673 d.to_standard_notation_string()
674 }
675
676 (Type::Real, Value::Int2(i)) => format!("{:.3}", i),
677 (Type::Real, Value::Int4(i)) => format!("{:.3}", i),
678 (Type::Real, Value::Int8(i)) => format!("{:.3}", i),
679 (Type::Real, Value::Float4(f)) => match mode {
680 Mode::Standard => format!("{:.3}", f),
681 Mode::Cockroach => format!("{}", f),
682 },
683 (Type::Real, Value::Float8(f)) => match mode {
684 Mode::Standard => format!("{:.3}", f),
685 Mode::Cockroach => format!("{}", f),
686 },
687 (Type::Real, Value::Numeric(d)) => match mode {
688 Mode::Standard => {
689 let mut d = d.0.0.clone();
690 if d.exponent() < -3 {
691 numeric::rescale(&mut d, 3).unwrap();
692 }
693 numeric::munge_numeric(&mut d).unwrap();
694 d.to_standard_notation_string()
695 }
696 Mode::Cockroach => d.0.0.to_standard_notation_string(),
697 },
698
699 (Type::Text, Value::Text(s)) => {
700 if s.is_empty() {
701 "(empty)".to_string()
702 } else {
703 s
704 }
705 }
706 (Type::Text, Value::Bool(b)) => b.to_string(),
707 (Type::Text, Value::Float4(f)) => format!("{:.3}", f),
708 (Type::Text, Value::Float8(f)) => format!("{:.3}", f),
709 (Type::Text, Value::Bytea(b)) => match str::from_utf8(&b) {
715 Ok(s) => s.to_string(),
716 Err(_) => format!("{:?}", b),
717 },
718 (Type::Text, Value::Numeric(d)) => d.0.0.to_standard_notation_string(),
719 (Type::Text, d) => {
722 let mut buf = BytesMut::new();
723 d.encode_text(&mut buf);
724 String::from_utf8_lossy(&buf).into_owned()
725 }
726
727 (Type::Oid, Value::Oid(o)) => o.to_string(),
728
729 (_, d) => panic!(
730 "Don't know how to format {:?} as {:?} in column {}",
731 d, typ, col,
732 ),
733 }
734}
735
736fn format_row(row: &Row, types: &[Type], mode: Mode) -> Vec<String> {
737 let mut formatted: Vec<String> = vec![];
738 for i in 0..row.len() {
739 let t: Option<Slt> = row.get::<usize, Option<Slt>>(i);
740 let t: Option<String> = t.map(|d| format_datum(d, &types[i], mode, i));
741 formatted.push(match t {
742 Some(t) => t,
743 None => "NULL".into(),
744 });
745 }
746
747 formatted
748}
749
750impl<'a> Runner<'a> {
751 pub async fn start(config: &'a RunConfig<'a>) -> Result<Runner<'a>, anyhow::Error> {
752 let mut runner = Self {
753 config,
754 inner: None,
755 };
756 runner.reset().await?;
757 Ok(runner)
758 }
759
760 pub async fn reset(&mut self) -> Result<(), anyhow::Error> {
761 drop(self.inner.take());
764 self.inner = Some(RunnerInner::start(self.config).await?);
765
766 Ok(())
767 }
768
769 async fn run_record<'r>(
770 &mut self,
771 record: &'r Record<'r>,
772 in_transaction: &mut bool,
773 ) -> Result<Outcome<'r>, anyhow::Error> {
774 if let Record::ResetServer = record {
775 self.reset().await?;
776 Ok(Outcome::Success)
777 } else {
778 self.inner
779 .as_mut()
780 .expect("RunnerInner missing")
781 .run_record(record, in_transaction)
782 .await
783 }
784 }
785
786 async fn check_catalog(&self) -> Result<(), anyhow::Error> {
787 self.inner
788 .as_ref()
789 .expect("RunnerInner missing")
790 .check_catalog()
791 .await
792 }
793
794 async fn reset_database(&mut self) -> Result<(), anyhow::Error> {
795 let inner = self.inner.as_mut().expect("RunnerInner missing");
796
797 inner.client.batch_execute("ROLLBACK;").await?;
798
799 inner
800 .system_client
801 .batch_execute(
802 "ROLLBACK;
803 SET cluster = mz_catalog_server;
804 RESET cluster_replica;",
805 )
806 .await?;
807
808 inner
809 .system_client
810 .batch_execute("ALTER SYSTEM RESET ALL")
811 .await?;
812
813 for row in inner
815 .system_client
816 .query("SELECT name FROM mz_databases", &[])
817 .await?
818 {
819 let name: &str = row.get("name");
820 inner
821 .system_client
822 .batch_execute(&format!("DROP DATABASE \"{name}\""))
823 .await?;
824 }
825 inner
826 .system_client
827 .batch_execute("CREATE DATABASE materialize")
828 .await?;
829
830 let mut needs_default_cluster = true;
834 for row in inner
835 .system_client
836 .query("SELECT name FROM mz_clusters WHERE id LIKE 'u%'", &[])
837 .await?
838 {
839 match row.get("name") {
840 "quickstart" => needs_default_cluster = false,
841 name => {
842 inner
843 .system_client
844 .batch_execute(&format!("DROP CLUSTER {name}"))
845 .await?
846 }
847 }
848 }
849 if needs_default_cluster {
850 inner
851 .system_client
852 .batch_execute("CREATE CLUSTER quickstart REPLICAS ()")
853 .await?;
854 }
855
856 inner
858 .system_client
859 .batch_execute("GRANT USAGE ON DATABASE materialize TO PUBLIC")
860 .await?;
861 inner
862 .system_client
863 .batch_execute("GRANT CREATE ON DATABASE materialize TO materialize")
864 .await?;
865 inner
866 .system_client
867 .batch_execute("GRANT CREATE ON SCHEMA materialize.public TO materialize")
868 .await?;
869 inner
870 .system_client
871 .batch_execute("GRANT USAGE ON CLUSTER quickstart TO PUBLIC")
872 .await?;
873 inner
874 .system_client
875 .batch_execute("GRANT CREATE ON CLUSTER quickstart TO materialize")
876 .await?;
877
878 inner
881 .system_client
882 .simple_query("ALTER SYSTEM SET max_tables = 100")
883 .await?;
884
885 if inner.enable_table_keys {
886 inner
887 .system_client
888 .simple_query("ALTER SYSTEM SET unsafe_enable_table_keys = true")
889 .await?;
890 }
891
892 inner.ensure_fixed_features().await?;
893
894 inner.client = connect(inner.server_addr, None).await;
895 inner.system_client = connect(inner.internal_server_addr, Some("mz_system")).await;
896 inner.clients = BTreeMap::new();
897
898 Ok(())
899 }
900}
901
902impl<'a> RunnerInner<'a> {
903 pub async fn start(config: &RunConfig<'a>) -> Result<RunnerInner<'a>, anyhow::Error> {
904 let temp_dir = tempfile::tempdir()?;
905 let scratch_dir = tempfile::tempdir()?;
906 let environment_id = EnvironmentId::for_tests();
907 let (consensus_uri, timestamp_oracle_url): (SensitiveUrl, SensitiveUrl) = {
908 let postgres_url = &config.postgres_url;
909 info!(%postgres_url, "starting server");
910 let (client, conn) = Retry::default()
911 .max_tries(5)
912 .retry_async(|_| async {
913 match tokio_postgres::connect(postgres_url, NoTls).await {
914 Ok(c) => Ok(c),
915 Err(e) => {
916 error!(%e, "failed to connect to postgres");
917 Err(e)
918 }
919 }
920 })
921 .await?;
922 task::spawn(|| "sqllogictest_connect", async move {
923 if let Err(e) = conn.await {
924 panic!("connection error: {}", e);
925 }
926 });
927 client
928 .batch_execute(
929 "DROP SCHEMA IF EXISTS sqllogictest_tsoracle CASCADE;
930 CREATE SCHEMA IF NOT EXISTS sqllogictest_consensus;
931 CREATE SCHEMA sqllogictest_tsoracle;",
932 )
933 .await?;
934 (
935 format!("{postgres_url}?options=--search_path=sqllogictest_consensus")
936 .parse()
937 .expect("invalid consensus URI"),
938 format!("{postgres_url}?options=--search_path=sqllogictest_tsoracle")
939 .parse()
940 .expect("invalid timestamp oracle URI"),
941 )
942 };
943
944 let secrets_dir = temp_dir.path().join("secrets");
945 let orchestrator = Arc::new(
946 ProcessOrchestrator::new(ProcessOrchestratorConfig {
947 image_dir: env::current_exe()?.parent().unwrap().to_path_buf(),
948 suppress_output: false,
949 environment_id: environment_id.to_string(),
950 secrets_dir: secrets_dir.clone(),
951 command_wrapper: config
952 .orchestrator_process_wrapper
953 .as_ref()
954 .map_or(Ok(vec![]), |s| shell_words::split(s))?,
955 propagate_crashes: true,
956 tcp_proxy: None,
957 scratch_directory: scratch_dir.path().to_path_buf(),
958 })
959 .await?,
960 );
961 let now = SYSTEM_TIME.clone();
962 let metrics_registry = MetricsRegistry::new();
963
964 let persist_config = PersistConfig::new(
965 &mz_environmentd::BUILD_INFO,
966 now.clone(),
967 mz_dyncfgs::all_dyncfgs(),
968 );
969 let persist_pubsub_server =
970 PersistGrpcPubSubServer::new(&persist_config, &metrics_registry);
971 let persist_pubsub_client = persist_pubsub_server.new_same_process_connection();
972 let persist_pubsub_tcp_listener =
973 TcpListener::bind(SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), 0))
974 .await
975 .expect("pubsub addr binding");
976 let persist_pubsub_server_port = persist_pubsub_tcp_listener
977 .local_addr()
978 .expect("pubsub addr has local addr")
979 .port();
980 info!("listening for persist pubsub connections on localhost:{persist_pubsub_server_port}");
981 mz_ore::task::spawn(|| "persist_pubsub_server", async move {
982 persist_pubsub_server
983 .serve_with_stream(TcpListenerStream::new(persist_pubsub_tcp_listener))
984 .await
985 .expect("success")
986 });
987 let persist_clients =
988 PersistClientCache::new(persist_config, &metrics_registry, |cfg, metrics| {
989 let sender: Arc<dyn PubSubSender> = Arc::new(MetricsSameProcessPubSubSender::new(
990 cfg,
991 persist_pubsub_client.sender,
992 metrics,
993 ));
994 PubSubClientConnection::new(sender, persist_pubsub_client.receiver)
995 });
996 let persist_clients = Arc::new(persist_clients);
997
998 let secrets_controller = Arc::clone(&orchestrator);
999 let connection_context = ConnectionContext::for_tests(orchestrator.reader());
1000 let orchestrator = Arc::new(TracingOrchestrator::new(
1001 orchestrator,
1002 config.tracing.clone(),
1003 ));
1004 let listeners = mz_environmentd::Listeners::bind_any_local().await?;
1005 let host_name = format!("localhost:{}", listeners.http_local_addr().port());
1006 let catalog_config = CatalogConfig {
1007 persist_clients: Arc::clone(&persist_clients),
1008 metrics: Arc::new(mz_catalog::durable::Metrics::new(&MetricsRegistry::new())),
1009 };
1010 let server_config = mz_environmentd::Config {
1011 catalog_config,
1012 timestamp_oracle_url: Some(timestamp_oracle_url),
1013 controller: ControllerConfig {
1014 build_info: &mz_environmentd::BUILD_INFO,
1015 orchestrator,
1016 clusterd_image: "clusterd".into(),
1017 init_container_image: None,
1018 deploy_generation: 0,
1019 persist_location: PersistLocation {
1020 blob_uri: format!(
1021 "file://{}/persist/blob",
1022 config.persist_dir.path().display()
1023 )
1024 .parse()
1025 .expect("invalid blob URI"),
1026 consensus_uri,
1027 },
1028 persist_clients,
1029 now: SYSTEM_TIME.clone(),
1030 metrics_registry: metrics_registry.clone(),
1031 persist_pubsub_url: format!("http://localhost:{}", persist_pubsub_server_port),
1032 secrets_args: mz_service::secrets::SecretsReaderCliArgs {
1033 secrets_reader: mz_service::secrets::SecretsControllerKind::LocalFile,
1034 secrets_reader_local_file_dir: Some(secrets_dir),
1035 secrets_reader_kubernetes_context: None,
1036 secrets_reader_aws_prefix: None,
1037 secrets_reader_name_prefix: None,
1038 },
1039 connection_context,
1040 },
1041 secrets_controller,
1042 cloud_resource_controller: None,
1043 tls: None,
1044 frontegg: None,
1045 self_hosted_auth: false,
1046 self_hosted_auth_internal: false,
1047 cors_allowed_origin: AllowOrigin::list([]),
1048 unsafe_mode: true,
1049 all_features: false,
1050 metrics_registry,
1051 now,
1052 environment_id,
1053 cluster_replica_sizes: ClusterReplicaSizeMap::for_tests(),
1054 bootstrap_default_cluster_replica_size: config.replicas.to_string(),
1055 bootstrap_default_cluster_replication_factor: 1,
1056 bootstrap_builtin_system_cluster_config: BootstrapBuiltinClusterConfig {
1057 replication_factor: SYSTEM_CLUSTER_DEFAULT_REPLICATION_FACTOR,
1058 size: config.replicas.to_string(),
1059 },
1060 bootstrap_builtin_catalog_server_cluster_config: BootstrapBuiltinClusterConfig {
1061 replication_factor: CATALOG_SERVER_CLUSTER_DEFAULT_REPLICATION_FACTOR,
1062 size: config.replicas.to_string(),
1063 },
1064 bootstrap_builtin_probe_cluster_config: BootstrapBuiltinClusterConfig {
1065 replication_factor: PROBE_CLUSTER_DEFAULT_REPLICATION_FACTOR,
1066 size: config.replicas.to_string(),
1067 },
1068 bootstrap_builtin_support_cluster_config: BootstrapBuiltinClusterConfig {
1069 replication_factor: SUPPORT_CLUSTER_DEFAULT_REPLICATION_FACTOR,
1070 size: config.replicas.to_string(),
1071 },
1072 bootstrap_builtin_analytics_cluster_config: BootstrapBuiltinClusterConfig {
1073 replication_factor: ANALYTICS_CLUSTER_DEFAULT_REPLICATION_FACTOR,
1074 size: config.replicas.to_string(),
1075 },
1076 system_parameter_defaults: {
1077 let mut params = BTreeMap::new();
1078 params.insert(
1079 "log_filter".to_string(),
1080 config.tracing.startup_log_filter.to_string(),
1081 );
1082 params.extend(config.system_parameter_defaults.clone());
1083 params
1084 },
1085 availability_zones: Default::default(),
1086 tracing_handle: config.tracing_handle.clone(),
1087 storage_usage_collection_interval: Duration::from_secs(3600),
1088 storage_usage_retention_period: None,
1089 segment_api_key: None,
1090 segment_client_side: false,
1091 egress_addresses: vec![],
1092 aws_account_id: None,
1093 aws_privatelink_availability_zones: None,
1094 launchdarkly_sdk_key: None,
1095 launchdarkly_key_map: Default::default(),
1096 config_sync_timeout: Duration::from_secs(30),
1097 config_sync_loop_interval: None,
1098 bootstrap_role: Some("materialize".into()),
1099 http_host_name: Some(host_name),
1100 internal_console_redirect_url: None,
1101 tls_reload_certs: mz_server_core::cert_reload_never_reload(),
1102 helm_chart_version: None,
1103 license_key: ValidatedLicenseKey::for_tests(),
1104 };
1105 let (server_addr_tx, server_addr_rx): (oneshot::Sender<Result<_, anyhow::Error>>, _) =
1111 oneshot::channel();
1112 let (internal_server_addr_tx, internal_server_addr_rx) = oneshot::channel();
1113 let (internal_http_server_addr_tx, internal_http_server_addr_rx) = oneshot::channel();
1114 let (shutdown_trigger, shutdown_trigger_rx) = trigger::channel();
1115 let server_thread = thread::spawn(|| {
1116 let runtime = match Runtime::new() {
1117 Ok(runtime) => runtime,
1118 Err(e) => {
1119 server_addr_tx
1120 .send(Err(e.into()))
1121 .expect("receiver should not drop first");
1122 return;
1123 }
1124 };
1125 let server = match runtime.block_on(listeners.serve(server_config)) {
1126 Ok(runtime) => runtime,
1127 Err(e) => {
1128 server_addr_tx
1129 .send(Err(e.into()))
1130 .expect("receiver should not drop first");
1131 return;
1132 }
1133 };
1134 server_addr_tx
1135 .send(Ok(server.sql_local_addr()))
1136 .expect("receiver should not drop first");
1137 internal_server_addr_tx
1138 .send(server.internal_sql_local_addr())
1139 .expect("receiver should not drop first");
1140 internal_http_server_addr_tx
1141 .send(server.internal_http_local_addr())
1142 .expect("receiver should not drop first");
1143 let _ = runtime.block_on(shutdown_trigger_rx);
1144 });
1145 let server_addr = server_addr_rx.await??;
1146 let internal_server_addr = internal_server_addr_rx.await?;
1147 let internal_http_server_addr = internal_http_server_addr_rx.await?;
1148
1149 let system_client = connect(internal_server_addr, Some("mz_system")).await;
1150 let client = connect(server_addr, None).await;
1151
1152 let inner = RunnerInner {
1153 server_addr,
1154 internal_server_addr,
1155 internal_http_server_addr,
1156 _shutdown_trigger: shutdown_trigger,
1157 _server_thread: server_thread.join_on_drop(),
1158 _temp_dir: temp_dir,
1159 client,
1160 system_client,
1161 clients: BTreeMap::new(),
1162 auto_index_tables: config.auto_index_tables,
1163 auto_index_selects: config.auto_index_selects,
1164 auto_transactions: config.auto_transactions,
1165 enable_table_keys: config.enable_table_keys,
1166 verbosity: config.verbosity,
1167 stdout: config.stdout,
1168 };
1169 inner.ensure_fixed_features().await?;
1170
1171 Ok(inner)
1172 }
1173
1174 async fn ensure_fixed_features(&self) -> Result<(), anyhow::Error> {
1177 self.system_client
1181 .execute("ALTER SYSTEM SET enable_reduce_mfp_fusion = on", &[])
1182 .await?;
1183
1184 self.system_client
1186 .execute("ALTER SYSTEM SET unsafe_enable_unsafe_functions = on", &[])
1187 .await?;
1188 Ok(())
1189 }
1190
1191 async fn run_record<'r>(
1192 &mut self,
1193 record: &'r Record<'r>,
1194 in_transaction: &mut bool,
1195 ) -> Result<Outcome<'r>, anyhow::Error> {
1196 match &record {
1197 Record::Statement {
1198 expected_error,
1199 rows_affected,
1200 sql,
1201 location,
1202 } => {
1203 if self.auto_transactions && *in_transaction {
1204 self.client.execute("COMMIT", &[]).await?;
1205 *in_transaction = false;
1206 }
1207 match self
1208 .run_statement(*expected_error, *rows_affected, sql, location.clone())
1209 .await?
1210 {
1211 Outcome::Success => {
1212 if self.auto_index_tables {
1213 let additional = mutate(sql);
1214 for stmt in additional {
1215 self.client.execute(&stmt, &[]).await?;
1216 }
1217 }
1218 Ok(Outcome::Success)
1219 }
1220 other => {
1221 if expected_error.is_some() {
1222 Ok(other)
1223 } else {
1224 Ok(Outcome::Bail {
1228 cause: Box::new(other),
1229 location: location.clone(),
1230 })
1231 }
1232 }
1233 }
1234 }
1235 Record::Query {
1236 sql,
1237 output,
1238 location,
1239 } => {
1240 self.run_query(sql, output, location.clone(), in_transaction)
1241 .await
1242 }
1243 Record::Simple {
1244 conn,
1245 user,
1246 sql,
1247 output,
1248 location,
1249 ..
1250 } => {
1251 self.run_simple(*conn, *user, sql, output, location.clone())
1252 .await
1253 }
1254 Record::Copy {
1255 table_name,
1256 tsv_path,
1257 } => {
1258 let tsv = tokio::fs::read(tsv_path).await?;
1259 let copy = self
1260 .client
1261 .copy_in(&*format!("COPY {} FROM STDIN", table_name))
1262 .await?;
1263 tokio::pin!(copy);
1264 copy.send(bytes::Bytes::from(tsv)).await?;
1265 copy.finish().await?;
1266 Ok(Outcome::Success)
1267 }
1268 _ => Ok(Outcome::Success),
1269 }
1270 }
1271
1272 async fn run_statement<'r>(
1273 &self,
1274 expected_error: Option<&'r str>,
1275 expected_rows_affected: Option<u64>,
1276 sql: &'r str,
1277 location: Location,
1278 ) -> Result<Outcome<'r>, anyhow::Error> {
1279 static UNSUPPORTED_INDEX_STATEMENT_REGEX: LazyLock<Regex> =
1280 LazyLock::new(|| Regex::new("^(CREATE UNIQUE INDEX|REINDEX)").unwrap());
1281 if UNSUPPORTED_INDEX_STATEMENT_REGEX.is_match(sql) {
1282 return Ok(Outcome::Success);
1284 }
1285
1286 match self.client.execute(sql, &[]).await {
1287 Ok(actual) => {
1288 if let Some(expected_error) = expected_error {
1289 return Ok(Outcome::UnexpectedPlanSuccess {
1290 expected_error,
1291 location,
1292 });
1293 }
1294 match expected_rows_affected {
1295 None => Ok(Outcome::Success),
1296 Some(expected) => {
1297 if expected != actual {
1298 Ok(Outcome::WrongNumberOfRowsInserted {
1299 expected_count: expected,
1300 actual_count: actual,
1301 location,
1302 })
1303 } else {
1304 Ok(Outcome::Success)
1305 }
1306 }
1307 }
1308 }
1309 Err(error) => {
1310 if let Some(expected_error) = expected_error {
1311 if Regex::new(expected_error)?.is_match(&format!("{:#}", error)) {
1312 return Ok(Outcome::Success);
1313 }
1314 }
1315 Ok(Outcome::PlanFailure {
1316 error: anyhow!(error),
1317 location,
1318 })
1319 }
1320 }
1321 }
1322
1323 async fn prepare_query<'r>(
1324 &self,
1325 sql: &str,
1326 output: &'r Result<QueryOutput<'_>, &'r str>,
1327 location: Location,
1328 in_transaction: &mut bool,
1329 ) -> Result<PrepareQueryOutcome<'r>, anyhow::Error> {
1330 let statements = match mz_sql::parse::parse(sql) {
1332 Ok(statements) => statements,
1333 Err(e) => match output {
1334 Ok(_) => {
1335 return Ok(PrepareQueryOutcome::Outcome(Outcome::ParseFailure {
1336 error: e.into(),
1337 location,
1338 }));
1339 }
1340 Err(expected_error) => {
1341 if Regex::new(expected_error)?.is_match(&format!("{:#}", e)) {
1342 return Ok(PrepareQueryOutcome::Outcome(Outcome::Success));
1343 } else {
1344 return Ok(PrepareQueryOutcome::Outcome(Outcome::ParseFailure {
1345 error: e.into(),
1346 location,
1347 }));
1348 }
1349 }
1350 },
1351 };
1352 let statement = match &*statements {
1353 [] => bail!("Got zero statements?"),
1354 [statement] => &statement.ast,
1355 _ => bail!("Got multiple statements: {:?}", statements),
1356 };
1357 let (is_select, num_attributes) = match statement {
1358 Statement::Select(stmt) => (true, derive_num_attributes(&stmt.query.body)),
1359 _ => (false, None),
1360 };
1361
1362 match output {
1363 Ok(_) => {
1364 if self.auto_transactions && !*in_transaction {
1365 self.client.execute("BEGIN", &[]).await?;
1367 *in_transaction = true;
1368 }
1369 }
1370 Err(_) => {
1371 if self.auto_transactions && *in_transaction {
1372 self.client.execute("COMMIT", &[]).await?;
1373 *in_transaction = false;
1374 }
1375 }
1376 }
1377
1378 match statement {
1382 Statement::Show(..) => {
1383 if self.auto_transactions && *in_transaction {
1384 self.client.execute("COMMIT", &[]).await?;
1385 *in_transaction = false;
1386 }
1387 }
1388 _ => (),
1389 }
1390 Ok(PrepareQueryOutcome::QueryPrepared(QueryInfo {
1391 is_select,
1392 num_attributes,
1393 }))
1394 }
1395
    /// Runs `sql` as a query and compares the rows with the expected `output`.
    ///
    /// `output` is `Ok(QueryOutput)` when rows are expected and
    /// `Err(pattern)` when the query is expected to fail with an error
    /// matching the regex `pattern`.
    ///
    /// Possible outcomes:
    /// * query failed, rows were expected: `Unsupported` when the error text
    ///   contains "supported" or "overload", otherwise `PlanFailure`;
    /// * query failed, error was expected: `Success` when the regex matches
    ///   the error formatted with `{:#}`, otherwise `PlanFailure`;
    /// * query succeeded, error was expected: `UnexpectedPlanSuccess`;
    /// * `WrongColumnCount`, `WrongColumnNames`, `OutputFailure` when the
    ///   respective comparison fails;
    /// * `Success` otherwise.
    ///
    /// Returns `Err` only when the expected-error pattern is not a valid regex.
    async fn execute_query<'r>(
        &self,
        sql: &str,
        output: &'r Result<QueryOutput<'_>, &'r str>,
        location: Location,
    ) -> Result<Outcome<'r>, anyhow::Error> {
        let rows = match self.client.query(sql, &[]).await {
            Ok(rows) => rows,
            Err(error) => {
                return match output {
                    Ok(_) => {
                        // Classify the failure by substring of the error text.
                        let error_string = format!("{}", error);
                        if error_string.contains("supported") || error_string.contains("overload") {
                            Ok(Outcome::Unsupported {
                                error: anyhow!(error),
                                location,
                            })
                        } else {
                            Ok(Outcome::PlanFailure {
                                error: anyhow!(error),
                                location,
                            })
                        }
                    }
                    Err(expected_error) => {
                        if Regex::new(expected_error)?.is_match(&format!("{:#}", error)) {
                            Ok(Outcome::Success)
                        } else {
                            Ok(Outcome::PlanFailure {
                                error: anyhow!(error),
                                location,
                            })
                        }
                    }
                };
            }
        };

        // The query succeeded; from here on an expected error is a mismatch.
        let QueryOutput {
            sort,
            types: expected_types,
            column_names: expected_column_names,
            output: expected_output,
            mode,
            ..
        } = match output {
            Err(expected_error) => {
                return Ok(Outcome::UnexpectedPlanSuccess {
                    expected_error,
                    location,
                });
            }
            Ok(query_output) => query_output,
        };

        // Format every row with `format_row`, rejecting rows whose width
        // differs from the number of expected types.
        let mut formatted_rows = vec![];
        for row in &rows {
            if row.len() != expected_types.len() {
                return Ok(Outcome::WrongColumnCount {
                    expected_count: expected_types.len(),
                    actual_count: row.len(),
                    location,
                });
            }
            let row = format_row(row, expected_types, *mode);
            formatted_rows.push(row);
        }

        // `Sort::Row` orders whole rows before flattening; `Sort::Value`
        // orders the individual values after flattening.
        if let Sort::Row = sort {
            formatted_rows.sort();
        }
        let mut values = formatted_rows.into_iter().flatten().collect::<Vec<_>>();
        if let Sort::Value = sort {
            values.sort();
        }

        // Column names are read from the first row, so they are only checked
        // when at least one row came back.
        if let Some(row) = rows.get(0) {
            if let Some(expected_column_names) = expected_column_names {
                let actual_column_names = row
                    .columns()
                    .iter()
                    .map(|t| ColumnName::from(t.name()))
                    .collect::<Vec<_>>();
                if expected_column_names != &actual_column_names {
                    return Ok(Outcome::WrongColumnNames {
                        expected_column_names,
                        actual_column_names,
                        actual_output: Output::Values(values),
                        location,
                    });
                }
            }
        }

        match expected_output {
            Output::Values(expected_values) => {
                if values != *expected_values {
                    return Ok(Outcome::OutputFailure {
                        expected_output,
                        actual_raw_output: rows,
                        actual_output: Output::Values(values),
                        location,
                    });
                }
            }
            Output::Hashed {
                num_values,
                md5: expected_md5,
            } => {
                // The digest covers every value followed by a newline.
                let mut hasher = Md5::new();
                for value in &values {
                    hasher.update(value);
                    hasher.update("\n");
                }
                let md5 = format!("{:x}", hasher.finalize());
                if values.len() != *num_values || md5 != *expected_md5 {
                    return Ok(Outcome::OutputFailure {
                        expected_output,
                        actual_raw_output: rows,
                        actual_output: Output::Hashed {
                            num_values: values.len(),
                            md5,
                        },
                        location,
                    });
                }
            }
        }

        Ok(Outcome::Success)
    }
1534
1535 async fn execute_view_inner<'r>(
1536 &self,
1537 sql: &str,
1538 output: &'r Result<QueryOutput<'_>, &'r str>,
1539 location: Location,
1540 ) -> Result<Option<Outcome<'r>>, anyhow::Error> {
1541 print_sql_if(self.stdout, sql, self.verbosity >= 2);
1542 let sql_result = self.client.execute(sql, &[]).await;
1543
1544 let tentative_outcome = if let Err(view_error) = sql_result {
1546 if let Err(expected_error) = output {
1547 if Regex::new(expected_error)?.is_match(&format!("{:#}", view_error)) {
1548 Some(Outcome::Success)
1549 } else {
1550 Some(Outcome::PlanFailure {
1551 error: view_error.into(),
1552 location: location.clone(),
1553 })
1554 }
1555 } else {
1556 Some(Outcome::PlanFailure {
1557 error: view_error.into(),
1558 location: location.clone(),
1559 })
1560 }
1561 } else {
1562 None
1563 };
1564 Ok(tentative_outcome)
1565 }
1566
1567 async fn execute_view<'r>(
1568 &self,
1569 sql: &str,
1570 num_attributes: Option<usize>,
1571 output: &'r Result<QueryOutput<'_>, &'r str>,
1572 location: Location,
1573 ) -> Result<Outcome<'r>, anyhow::Error> {
1574 let expected_column_names = if let Ok(QueryOutput { column_names, .. }) = output {
1576 column_names.clone()
1577 } else {
1578 None
1579 };
1580 let (create_view, create_index, view_sql, drop_view) = generate_view_sql(
1581 sql,
1582 Uuid::new_v4().as_simple(),
1583 num_attributes,
1584 expected_column_names,
1585 );
1586 let tentative_outcome = self
1587 .execute_view_inner(create_view.as_str(), output, location.clone())
1588 .await?;
1589
1590 if let Some(view_outcome) = tentative_outcome {
1593 return Ok(view_outcome);
1594 }
1595
1596 let tentative_outcome = self
1597 .execute_view_inner(create_index.as_str(), output, location.clone())
1598 .await?;
1599
1600 let view_outcome;
1601 if let Some(outcome) = tentative_outcome {
1602 view_outcome = outcome;
1603 } else {
1604 print_sql_if(self.stdout, view_sql.as_str(), self.verbosity >= 2);
1605 view_outcome = self
1606 .execute_query(view_sql.as_str(), output, location.clone())
1607 .await?;
1608 }
1609
1610 print_sql_if(self.stdout, drop_view.as_str(), self.verbosity >= 2);
1612 self.client.execute(drop_view.as_str(), &[]).await?;
1613
1614 Ok(view_outcome)
1615 }
1616
1617 async fn run_query<'r>(
1618 &self,
1619 sql: &'r str,
1620 output: &'r Result<QueryOutput<'_>, &'r str>,
1621 location: Location,
1622 in_transaction: &mut bool,
1623 ) -> Result<Outcome<'r>, anyhow::Error> {
1624 let prepare_outcome = self
1625 .prepare_query(sql, output, location.clone(), in_transaction)
1626 .await?;
1627 match prepare_outcome {
1628 PrepareQueryOutcome::QueryPrepared(QueryInfo {
1629 is_select,
1630 num_attributes,
1631 }) => {
1632 let query_outcome = self.execute_query(sql, output, location.clone()).await?;
1633 if is_select && self.auto_index_selects {
1634 let view_outcome = self
1635 .execute_view(sql, None, output, location.clone())
1636 .await?;
1637
1638 if std::mem::discriminant::<Outcome>(&query_outcome)
1643 != std::mem::discriminant::<Outcome>(&view_outcome)
1644 {
1645 let view_outcome = if num_attributes.is_some() {
1650 self.execute_view(sql, num_attributes, output, location.clone())
1651 .await?
1652 } else {
1653 view_outcome
1654 };
1655
1656 if std::mem::discriminant::<Outcome>(&query_outcome)
1657 != std::mem::discriminant::<Outcome>(&view_outcome)
1658 {
1659 let inconsistent_view_outcome = Outcome::InconsistentViewOutcome {
1660 query_outcome: Box::new(query_outcome),
1661 view_outcome: Box::new(view_outcome),
1662 location: location.clone(),
1663 };
1664 let outcome = if should_warn(&inconsistent_view_outcome) {
1667 Outcome::Warning {
1668 cause: Box::new(inconsistent_view_outcome),
1669 location: location.clone(),
1670 }
1671 } else {
1672 inconsistent_view_outcome
1673 };
1674 return Ok(outcome);
1675 }
1676 }
1677 }
1678 Ok(query_outcome)
1679 }
1680 PrepareQueryOutcome::Outcome(outcome) => Ok(outcome),
1681 }
1682 }
1683
1684 async fn get_conn(
1685 &mut self,
1686 name: Option<&str>,
1687 user: Option<&str>,
1688 ) -> &tokio_postgres::Client {
1689 match name {
1690 None => &self.client,
1691 Some(name) => {
1692 if !self.clients.contains_key(name) {
1693 let addr = if matches!(user, Some("mz_system") | Some("mz_support")) {
1694 self.internal_server_addr
1695 } else {
1696 self.server_addr
1697 };
1698 let client = connect(addr, user).await;
1699 self.clients.insert(name.into(), client);
1700 }
1701 self.clients.get(name).unwrap()
1702 }
1703 }
1704 }
1705
1706 async fn run_simple<'r>(
1707 &mut self,
1708 conn: Option<&'r str>,
1709 user: Option<&'r str>,
1710 sql: &'r str,
1711 output: &'r Output,
1712 location: Location,
1713 ) -> Result<Outcome<'r>, anyhow::Error> {
1714 let client = self.get_conn(conn, user).await;
1715 let actual = Output::Values(match client.simple_query(sql).await {
1716 Ok(result) => result
1717 .into_iter()
1718 .filter_map(|m| match m {
1719 SimpleQueryMessage::Row(row) => {
1720 let mut s = vec![];
1721 for i in 0..row.len() {
1722 s.push(row.get(i).unwrap_or("NULL"));
1723 }
1724 Some(s.join(","))
1725 }
1726 SimpleQueryMessage::CommandComplete(count) => {
1727 Some(format!("COMPLETE {}", count))
1728 }
1729 SimpleQueryMessage::RowDescription(_) => None,
1730 _ => panic!("unexpected"),
1731 })
1732 .collect::<Vec<_>>(),
1733 Err(error) => error.to_string().lines().map(|s| s.to_string()).collect(),
1737 });
1738 if *output != actual {
1739 Ok(Outcome::OutputFailure {
1740 expected_output: output,
1741 actual_raw_output: vec![],
1742 actual_output: actual,
1743 location,
1744 })
1745 } else {
1746 Ok(Outcome::Success)
1747 }
1748 }
1749
1750 async fn check_catalog(&self) -> Result<(), anyhow::Error> {
1751 let url = format!(
1752 "http://{}/api/catalog/check",
1753 self.internal_http_server_addr
1754 );
1755 let response: serde_json::Value = reqwest::get(&url).await?.json().await?;
1756
1757 if let Some(inconsistencies) = response.get("err") {
1758 let inconsistencies = serde_json::to_string_pretty(&inconsistencies)
1759 .expect("serializing Value cannot fail");
1760 Err(anyhow::anyhow!("Catalog inconsistency\n{inconsistencies}"))
1761 } else {
1762 Ok(())
1763 }
1764 }
1765}
1766
1767async fn connect(addr: SocketAddr, user: Option<&str>) -> tokio_postgres::Client {
1768 let (client, connection) = tokio_postgres::connect(
1769 &format!(
1770 "host={} port={} user={}",
1771 addr.ip(),
1772 addr.port(),
1773 user.unwrap_or("materialize")
1774 ),
1775 NoTls,
1776 )
1777 .await
1778 .unwrap();
1779
1780 task::spawn(|| "sqllogictest_connect", async move {
1781 if let Err(e) = connection.await {
1782 eprintln!("connection error: {}", e);
1783 }
1784 });
1785 client
1786}
1787
/// Output sink used by the runner for everything it prints.
///
/// The method takes `&self`, so `write!`/`writeln!` can be used directly on a
/// shared `&dyn WriteFmt`; note that it returns `()` rather than a `Result`.
pub trait WriteFmt {
    /// Writes the pre-formatted arguments to the sink.
    fn write_fmt(&self, fmt: fmt::Arguments<'_>);
}
1791
/// Settings that control how sqllogictest files are run.
pub struct RunConfig<'a> {
    /// Sink for regular output (printed SQL, outcomes, file headers).
    pub stdout: &'a dyn WriteFmt,
    /// Sink for error output. NOTE(review): not used in this part of the
    /// file; confirm against the callers.
    pub stderr: &'a dyn WriteFmt,
    /// Verbosity level: at `>= 1` records that did not succeed are reported,
    /// at `>= 2` every record and every generated view statement is printed.
    pub verbosity: u8,
    /// NOTE(review): presumably the URL of the backing PostgreSQL instance;
    /// its use is not visible here.
    pub postgres_url: String,
    /// NOTE(review): not used in this part of the file; confirm semantics.
    pub no_fail: bool,
    /// Stop processing a file at the first failing record.
    pub fail_fast: bool,
    /// After a successful statement, also execute the statements returned by
    /// `mutate` for it.
    pub auto_index_tables: bool,
    /// Additionally run every `SELECT` through an indexed view and compare
    /// the outcomes.
    pub auto_index_selects: bool,
    /// Automatically issue `BEGIN`/`COMMIT` around queries.
    pub auto_transactions: bool,
    /// NOTE(review): copied into the runner state; its effect is not visible
    /// in this part of the file.
    pub enable_table_keys: bool,
    /// NOTE(review): presumably a command wrapper for the process
    /// orchestrator; confirm against the server setup code.
    pub orchestrator_process_wrapper: Option<String>,
    /// Tracing options; `startup_log_filter` seeds the `log_filter` system
    /// parameter default.
    pub tracing: TracingCliArgs,
    /// Tracing handle passed on to the server.
    pub tracing_handle: TracingHandle,
    /// Extra system parameter defaults; these are applied after `log_filter`
    /// and therefore override it.
    pub system_parameter_defaults: BTreeMap<String, String>,
    /// NOTE(review): presumably the directory holding persist state; confirm.
    pub persist_dir: TempDir,
    /// Rendered with `to_string()` and used as the `size` of the builtin
    /// probe, support and analytics cluster configurations.
    pub replicas: usize,
}
1814
1815fn print_record(config: &RunConfig<'_>, record: &Record) {
1816 match record {
1817 Record::Statement { sql, .. } | Record::Query { sql, .. } => print_sql(config.stdout, sql),
1818 _ => (),
1819 }
1820}
1821
1822fn print_sql_if<'a>(stdout: &'a dyn WriteFmt, sql: &str, cond: bool) {
1823 if cond {
1824 print_sql(stdout, sql)
1825 }
1826}
1827
1828fn print_sql<'a>(stdout: &'a dyn WriteFmt, sql: &str) {
1829 writeln!(stdout, "{}", crate::util::indent(sql, 4))
1830}
1831
/// Regular expressions describing view-creation errors that are tolerated.
///
/// `should_warn` matches these against the error of a failed view run whose
/// plain query succeeded; on a match the inconsistency is reported as a
/// warning instead of a failure. Each entry is compiled with `Regex::new`,
/// so regex metacharacters in the plain-text entries are significant.
const INCONSISTENT_VIEW_OUTCOME_WARNING_REGEXPS: [&str; 9] = [
    "cannot materialize call to",
    "SHOW commands are not allowed in views",
    "cannot create view with unstable dependencies",
    "cannot use wildcard expansions or NATURAL JOINs in a view that depends on system objects",
    "no schema has been selected to create in",
    r#"system schema '\w+' cannot be modified"#,
    r#"permission denied for (SCHEMA|CLUSTER) "(\w+\.)?\w+""#,
    r#"column "[\w\?]+" specified more than once"#,
    r#"column "(\w+\.)?\w+" does not exist"#,
];
1851
1852fn should_warn(outcome: &Outcome) -> bool {
1857 match outcome {
1858 Outcome::InconsistentViewOutcome {
1859 query_outcome,
1860 view_outcome,
1861 ..
1862 } => match (query_outcome.as_ref(), view_outcome.as_ref()) {
1863 (Outcome::Success, Outcome::PlanFailure { error, .. }) => {
1864 INCONSISTENT_VIEW_OUTCOME_WARNING_REGEXPS.iter().any(|s| {
1865 Regex::new(s)
1866 .expect("unexpected error in regular expression parsing")
1867 .is_match(&format!("{:#}", error))
1868 })
1869 }
1870 _ => false,
1871 },
1872 _ => false,
1873 }
1874}
1875
1876pub async fn run_string(
1877 runner: &mut Runner<'_>,
1878 source: &str,
1879 input: &str,
1880) -> Result<Outcomes, anyhow::Error> {
1881 runner.reset_database().await?;
1882
1883 let mut outcomes = Outcomes::default();
1884 let mut parser = crate::parser::Parser::new(source, input);
1885 let mut in_transaction = false;
1888 writeln!(runner.config.stdout, "--- {}", source);
1889
1890 for record in parser.parse_records()? {
1891 if runner.config.verbosity >= 2 {
1895 print_record(runner.config, &record);
1896 }
1897
1898 let outcome = runner
1899 .run_record(&record, &mut in_transaction)
1900 .await
1901 .map_err(|err| format!("In {}:\n{}", source, err))
1902 .unwrap();
1903
1904 if runner.config.verbosity >= 1 && !outcome.success() {
1906 if runner.config.verbosity < 2 {
1907 if !outcome.failure() {
1914 writeln!(
1915 runner.config.stdout,
1916 "{}",
1917 util::indent("Warning detected for: ", 4)
1918 );
1919 }
1920 print_record(runner.config, &record);
1921 }
1922 if runner.config.verbosity >= 2 || outcome.failure() {
1923 writeln!(
1924 runner.config.stdout,
1925 "{}",
1926 util::indent(&outcome.to_string(), 4)
1927 );
1928 writeln!(runner.config.stdout, "{}", util::indent("----", 4));
1929 }
1930 }
1931
1932 outcomes.stats[outcome.code()] += 1;
1933 if outcome.failure() {
1934 outcomes.details.push(format!("{}", outcome));
1935 }
1936
1937 if let Outcome::Bail { .. } = outcome {
1938 break;
1939 }
1940
1941 if runner.config.fail_fast && outcome.failure() {
1942 break;
1943 }
1944 }
1945 Ok(outcomes)
1946}
1947
1948pub async fn run_file(runner: &mut Runner<'_>, filename: &Path) -> Result<Outcomes, anyhow::Error> {
1949 let mut input = String::new();
1950 File::open(filename)?.read_to_string(&mut input)?;
1951 let outcomes = run_string(runner, &format!("{}", filename.display()), &input).await?;
1952 runner.check_catalog().await?;
1953
1954 Ok(outcomes)
1955}
1956
/// Runs the records in `filename` and rewrites the file in place so that the
/// expected results match what was actually observed.
///
/// The database is reset first. For every record whose outcome shows a
/// mismatch that can be rewritten (query values, column names, hashed
/// output, simple-query output, or an expected error message), the
/// corresponding span of the original text is replaced through a
/// [`RewriteBuffer`]. Successful records are left untouched.
///
/// # Errors
///
/// Returns an error on I/O failures, when running a record fails, or when a
/// record/outcome combination is encountered that cannot be rewritten. The
/// file is only truncated and rewritten after all records were processed.
pub async fn rewrite_file(runner: &mut Runner<'_>, filename: &Path) -> Result<(), anyhow::Error> {
    runner.reset_database().await?;

    let mut file = OpenOptions::new().read(true).write(true).open(filename)?;

    let mut input = String::new();
    file.read_to_string(&mut input)?;

    let mut buf = RewriteBuffer::new(&input);

    let mut parser = crate::parser::Parser::new(filename.to_str().unwrap_or(""), &input);
    writeln!(runner.config.stdout, "--- {}", filename.display());
    let mut in_transaction = false;

    /// Replaces the expected values of a query with `actual_output`.
    ///
    /// `actual_output` is a flat list of values that is split into rows of
    /// `types.len()` values each. In `Mode::Cockroach` each row is written on
    /// its own line with columns separated by a single space; in rows with
    /// more than one column, spaces inside values become `␠`. In
    /// `Mode::Standard` every value is written on its own line. Unless
    /// `multiline` is set, newlines inside values become `⏎`.
    fn append_values_output(
        buf: &mut RewriteBuffer,
        input: &String,
        expected_output: &str,
        mode: &Mode,
        types: &Vec<Type>,
        column_names: Option<&Vec<ColumnName>>,
        actual_output: &Vec<String>,
        multiline: bool,
    ) {
        buf.append_header(input, expected_output, column_names);

        // NOTE(review): `chunks` panics when `types` is empty — confirm the
        // parser never produces a query record without column types.
        for (i, row) in actual_output.chunks(types.len()).enumerate() {
            match mode {
                Mode::Cockroach => {
                    if i != 0 {
                        buf.append("\n");
                    }

                    if row.len() == 0 {
                        // Nothing to write for an empty row.
                    } else if row.len() == 1 {
                        // A single column is written without escaping spaces.
                        if multiline {
                            buf.append(&row[0]);
                        } else {
                            buf.append(&row[0].replace('\n', "⏎"))
                        }
                    } else {
                        buf.append(
                            &row.iter()
                                .map(|col| {
                                    let mut col = col.replace(' ', "␠");
                                    if !multiline {
                                        col = col.replace('\n', "⏎");
                                    }
                                    col
                                })
                                .join(" "),
                        );
                    }
                }
                Mode::Standard => {
                    for (j, col) in row.iter().enumerate() {
                        if i != 0 || j != 0 {
                            buf.append("\n");
                        }
                        buf.append(&if multiline {
                            col.clone()
                        } else {
                            col.replace('\n', "⏎")
                        });
                    }
                }
            }
        }
    }

    for record in parser.parse_records()? {
        let outcome = runner.run_record(&record, &mut in_transaction).await?;

        match (&record, &outcome) {
            // Query with literal expected values whose actual values differ:
            // keep the expected column names, replace the values.
            (
                Record::Query {
                    output:
                        Ok(QueryOutput {
                            mode,
                            output: Output::Values(_),
                            output_str: expected_output,
                            types,
                            column_names,
                            multiline,
                            ..
                        }),
                    ..
                },
                Outcome::OutputFailure {
                    actual_output: Output::Values(actual_output),
                    ..
                },
            ) => {
                append_values_output(
                    &mut buf,
                    &input,
                    expected_output,
                    mode,
                    types,
                    column_names.as_ref(),
                    actual_output,
                    *multiline,
                );
            }
            // Query whose column names differ: write the actual column names
            // together with the actual values.
            (
                Record::Query {
                    output:
                        Ok(QueryOutput {
                            mode,
                            output: Output::Values(_),
                            output_str: expected_output,
                            types,
                            multiline,
                            ..
                        }),
                    ..
                },
                Outcome::WrongColumnNames {
                    actual_column_names,
                    actual_output: Output::Values(actual_output),
                    ..
                },
            ) => {
                append_values_output(
                    &mut buf,
                    &input,
                    expected_output,
                    mode,
                    types,
                    Some(actual_column_names),
                    actual_output,
                    *multiline,
                );
            }
            // Query with hashed expected output: replace the summary line.
            (
                Record::Query {
                    output:
                        Ok(QueryOutput {
                            output: Output::Hashed { .. },
                            output_str: expected_output,
                            column_names,
                            ..
                        }),
                    ..
                },
                Outcome::OutputFailure {
                    actual_output: Output::Hashed { num_values, md5 },
                    ..
                },
            ) => {
                buf.append_header(&input, expected_output, column_names.as_ref());

                buf.append(format!("{} values hashing to {}\n", num_values, md5).as_str())
            }
            // Simple-query record: one actual line per output row.
            (
                Record::Simple {
                    output_str: expected_output,
                    ..
                },
                Outcome::OutputFailure {
                    actual_output: Output::Values(actual_output),
                    ..
                },
            ) => {
                buf.append_header(&input, expected_output, None);

                for (i, row) in actual_output.iter().enumerate() {
                    if i != 0 {
                        buf.append("\n");
                    }
                    buf.append(row);
                }
            }
            // Query or statement with an expected error, where the outcome
            // carries an error message: replace the expected error text.
            (
                Record::Query {
                    sql,
                    output: Err(err),
                    ..
                },
                outcome,
            )
            | (
                Record::Statement {
                    expected_error: Some(err),
                    sql,
                    ..
                },
                outcome,
            ) if outcome.err_msg().is_some() => {
                buf.rewrite_expected_error(&input, err, &outcome.err_msg().unwrap(), sql)
            }
            (_, Outcome::Success) => {}
            _ => bail!("unexpected: {:?} {:?}", record, outcome),
        }
    }

    // Replace the file contents with the rewritten text.
    file.set_len(0)?;
    file.seek(SeekFrom::Start(0))?;
    file.write_all(buf.finish().as_bytes())?;
    file.sync_all()?;
    Ok(())
}
2172
2173#[derive(Debug)]
2182struct RewriteBuffer<'a> {
2183 input: &'a str,
2184 input_offset: usize,
2185 output: String,
2186}
2187
2188impl<'a> RewriteBuffer<'a> {
2189 fn new(input: &'a str) -> RewriteBuffer<'a> {
2190 RewriteBuffer {
2191 input,
2192 input_offset: 0,
2193 output: String::new(),
2194 }
2195 }
2196
2197 fn flush_to(&mut self, offset: usize) {
2198 assert!(offset >= self.input_offset);
2199 let chunk = &self.input[self.input_offset..offset];
2200 self.output.push_str(chunk);
2201 self.input_offset = offset;
2202 }
2203
2204 fn skip_to(&mut self, offset: usize) {
2205 assert!(offset >= self.input_offset);
2206 self.input_offset = offset;
2207 }
2208
2209 fn append(&mut self, s: &str) {
2210 self.output.push_str(s);
2211 }
2212
2213 fn append_header(
2214 &mut self,
2215 input: &String,
2216 expected_output: &str,
2217 column_names: Option<&Vec<ColumnName>>,
2218 ) {
2219 #[allow(clippy::as_conversions)]
2222 let offset = expected_output.as_ptr() as usize - input.as_ptr() as usize;
2223 self.flush_to(offset);
2224 self.skip_to(offset + expected_output.len());
2225
2226 if self.peek_last(5) == "\n----" {
2229 self.append("\n");
2230 } else if self.peek_last(6) != "\n----\n" {
2231 self.append("\n----\n");
2232 }
2233
2234 let Some(names) = column_names else {
2235 return;
2236 };
2237 self.append(
2238 &names
2239 .iter()
2240 .map(|name| name.as_str().replace(' ', "␠"))
2241 .collect::<Vec<_>>()
2242 .join(" "),
2243 );
2244 self.append("\n");
2245 }
2246
2247 fn rewrite_expected_error(
2248 &mut self,
2249 input: &String,
2250 old_err: &str,
2251 new_err: &str,
2252 query: &str,
2253 ) {
2254 #[allow(clippy::as_conversions)]
2257 let err_offset = old_err.as_ptr() as usize - input.as_ptr() as usize;
2258 self.flush_to(err_offset);
2259 self.append(new_err);
2260 self.append("\n");
2261 self.append(query);
2262 #[allow(clippy::as_conversions)]
2264 self.skip_to(query.as_ptr() as usize - input.as_ptr() as usize + query.len())
2265 }
2266
2267 fn peek_last(&self, n: usize) -> &str {
2268 &self.output[self.output.len() - n..]
2269 }
2270
2271 fn finish(mut self) -> String {
2272 self.flush_to(self.input.len());
2273 self.output
2274 }
2275}
2276
2277fn generate_view_sql(
2285 sql: &str,
2286 view_uuid: &Simple,
2287 num_attributes: Option<usize>,
2288 expected_column_names: Option<Vec<ColumnName>>,
2289) -> (String, String, String, String) {
2290 let stmts = parser::parse_statements(sql).unwrap_or_default();
2298 assert!(stmts.len() == 1);
2299 let (query, query_as_of) = match &stmts[0].ast {
2300 Statement::Select(stmt) => (&stmt.query, &stmt.as_of),
2301 _ => unreachable!("This function should only be called for SELECTs"),
2302 };
2303
2304 let (view_order_by, extra_columns, distinct) = if num_attributes.is_none() {
2309 (query.order_by.clone(), vec![], None)
2310 } else {
2311 derive_order_by(&query.body, &query.order_by)
2312 };
2313
2314 let name = UnresolvedItemName(vec![Ident::new_unchecked(format!("v{}", view_uuid))]);
2330 let projection = expected_column_names.map_or(
2331 num_attributes.map_or(vec![], |n| {
2332 (1..=n)
2333 .map(|i| Ident::new_unchecked(format!("a{i}")))
2334 .collect()
2335 }),
2336 |cols| {
2337 cols.iter()
2338 .map(|c| Ident::new_unchecked(c.as_str()))
2339 .collect()
2340 },
2341 );
2342 let columns: Vec<Ident> = projection
2343 .iter()
2344 .cloned()
2345 .chain(extra_columns.iter().map(|item| {
2346 if let SelectItem::Expr {
2347 expr: _,
2348 alias: Some(ident),
2349 } = item
2350 {
2351 ident.clone()
2352 } else {
2353 unreachable!("alias must be given for extra column")
2354 }
2355 }))
2356 .collect();
2357
2358 let mut query = query.clone();
2360 if extra_columns.len() > 0 {
2361 match &mut query.body {
2362 SetExpr::Select(stmt) => stmt.projection.extend(extra_columns.iter().cloned()),
2363 _ => unimplemented!("cannot yet rewrite projections of nested queries"),
2364 }
2365 }
2366 let create_view = AstStatement::<Raw>::CreateView(CreateViewStatement {
2367 if_exists: IfExistsBehavior::Error,
2368 temporary: false,
2369 definition: ViewDefinition {
2370 name: name.clone(),
2371 columns: columns.clone(),
2372 query,
2373 },
2374 })
2375 .to_ast_string_stable();
2376
2377 let create_index = AstStatement::<Raw>::CreateIndex(CreateIndexStatement {
2381 name: None,
2382 in_cluster: None,
2383 on_name: RawItemName::Name(name.clone()),
2384 key_parts: if columns.len() == 0 {
2385 None
2386 } else {
2387 Some(
2388 columns
2389 .iter()
2390 .map(|ident| Expr::Identifier(vec![ident.clone()]))
2391 .collect(),
2392 )
2393 },
2394 with_options: Vec::new(),
2395 if_not_exists: false,
2396 })
2397 .to_ast_string_stable();
2398
2399 let distinct_unneeded = extra_columns.len() == 0
2401 || match distinct {
2402 None | Some(Distinct::On(_)) => true,
2403 Some(Distinct::EntireRow) => false,
2404 };
2405 let distinct = if distinct_unneeded { None } else { distinct };
2406
2407 let view_sql = AstStatement::<Raw>::Select(SelectStatement {
2409 query: Query {
2410 ctes: CteBlock::Simple(vec![]),
2411 body: SetExpr::Select(Box::new(Select {
2412 distinct,
2413 projection: if projection.len() == 0 {
2414 vec![SelectItem::Wildcard]
2415 } else {
2416 projection
2417 .iter()
2418 .map(|ident| SelectItem::Expr {
2419 expr: Expr::Identifier(vec![ident.clone()]),
2420 alias: None,
2421 })
2422 .collect()
2423 },
2424 from: vec![TableWithJoins {
2425 relation: TableFactor::Table {
2426 name: RawItemName::Name(name.clone()),
2427 alias: None,
2428 },
2429 joins: vec![],
2430 }],
2431 selection: None,
2432 group_by: vec![],
2433 having: None,
2434 qualify: None,
2435 options: vec![],
2436 })),
2437 order_by: view_order_by,
2438 limit: None,
2439 offset: None,
2440 },
2441 as_of: query_as_of.clone(),
2442 })
2443 .to_ast_string_stable();
2444
2445 let drop_view = AstStatement::<Raw>::DropObjects(DropObjectsStatement {
2447 object_type: ObjectType::View,
2448 if_exists: false,
2449 names: vec![UnresolvedObjectName::Item(name)],
2450 cascade: false,
2451 })
2452 .to_ast_string_stable();
2453
2454 (create_view, create_index, view_sql, drop_view)
2455}
2456
2457fn derive_num_attributes(body: &SetExpr<Raw>) -> Option<usize> {
2462 let Some((projection, _)) = find_projection(body) else {
2463 return None;
2464 };
2465 derive_num_attributes_from_projection(projection)
2466}
2467
2468fn derive_order_by(
2479 body: &SetExpr<Raw>,
2480 order_by: &Vec<OrderByExpr<Raw>>,
2481) -> (
2482 Vec<OrderByExpr<Raw>>,
2483 Vec<SelectItem<Raw>>,
2484 Option<Distinct<Raw>>,
2485) {
2486 let Some((projection, distinct)) = find_projection(body) else {
2487 return (vec![], vec![], None);
2488 };
2489 let (view_order_by, extra_columns) = derive_order_by_from_projection(projection, order_by);
2490 (view_order_by, extra_columns, distinct.clone())
2491}
2492
2493fn find_projection(body: &SetExpr<Raw>) -> Option<(&Vec<SelectItem<Raw>>, &Option<Distinct<Raw>>)> {
2495 let mut set_expr = body;
2498 loop {
2499 match set_expr {
2500 SetExpr::Select(select) => {
2501 return Some((&select.projection, &select.distinct));
2502 }
2503 SetExpr::SetOperation { left, .. } => set_expr = left.as_ref(),
2504 SetExpr::Query(query) => set_expr = &query.body,
2505 _ => return None,
2506 }
2507 }
2508}
2509
2510fn derive_num_attributes_from_projection(projection: &Vec<SelectItem<Raw>>) -> Option<usize> {
2514 let mut num_attributes = 0usize;
2515 for item in projection.iter() {
2516 let SelectItem::Expr { expr, .. } = item else {
2517 return None;
2518 };
2519 match expr {
2520 Expr::QualifiedWildcard(..) | Expr::WildcardAccess(..) => {
2521 return None;
2522 }
2523 _ => {
2524 num_attributes += 1;
2525 }
2526 }
2527 }
2528 Some(num_attributes)
2529}
2530
2531fn derive_order_by_from_projection(
2536 projection: &Vec<SelectItem<Raw>>,
2537 order_by: &Vec<OrderByExpr<Raw>>,
2538) -> (Vec<OrderByExpr<Raw>>, Vec<SelectItem<Raw>>) {
2539 let mut view_order_by: Vec<OrderByExpr<Raw>> = vec![];
2540 let mut extra_columns: Vec<SelectItem<Raw>> = vec![];
2541 for order_by_expr in order_by.iter() {
2542 let query_expr = &order_by_expr.expr;
2543 let view_expr = match query_expr {
2544 Expr::Value(mz_sql_parser::ast::Value::Number(_)) => query_expr.clone(),
2545 _ => {
2546 if let Some(i) = projection.iter().position(|item| match item {
2548 SelectItem::Expr { expr, alias } => {
2549 expr == query_expr
2550 || match query_expr {
2551 Expr::Identifier(ident) => {
2552 ident.len() == 1 && Some(&ident[0]) == alias.as_ref()
2553 }
2554 _ => false,
2555 }
2556 }
2557 SelectItem::Wildcard => false,
2558 }) {
2559 Expr::Value(mz_sql_parser::ast::Value::Number((i + 1).to_string()))
2560 } else {
2561 let ident = Ident::new_unchecked(format!(
2564 "a{}",
2565 (projection.len() + extra_columns.len() + 1)
2566 ));
2567 extra_columns.push(SelectItem::Expr {
2568 expr: query_expr.clone(),
2569 alias: Some(ident.clone()),
2570 });
2571 Expr::Identifier(vec![ident])
2572 }
2573 }
2574 };
2575 view_order_by.push(OrderByExpr {
2576 expr: view_expr,
2577 asc: order_by_expr.asc,
2578 nulls_last: order_by_expr.nulls_last,
2579 });
2580 }
2581 (view_order_by, extra_columns)
2582}
2583
2584fn mutate(sql: &str) -> Vec<String> {
2586 let stmts = parser::parse_statements(sql).unwrap_or_default();
2587 let mut additional = Vec::new();
2588 for stmt in stmts {
2589 match stmt.ast {
2590 AstStatement::CreateTable(stmt) => additional.push(
2591 AstStatement::<Raw>::CreateIndex(CreateIndexStatement {
2594 name: None,
2595 in_cluster: None,
2596 on_name: RawItemName::Name(stmt.name.clone()),
2597 key_parts: Some(
2598 stmt.columns
2599 .iter()
2600 .map(|def| Expr::Identifier(vec![def.name.clone()]))
2601 .collect(),
2602 ),
2603 with_options: Vec::new(),
2604 if_not_exists: false,
2605 })
2606 .to_ast_string_stable(),
2607 ),
2608 _ => {}
2609 }
2610 }
2611 additional
2612}
2613
2614#[mz_ore::test]
2615#[cfg_attr(miri, ignore)] fn test_generate_view_sql() {
2617 let uuid = Uuid::parse_str("67e5504410b1426f9247bb680e5fe0c8").unwrap();
2618 let cases = vec![
2619 (("SELECT * FROM t", None, None),
2620 (
2621 r#"CREATE VIEW "v67e5504410b1426f9247bb680e5fe0c8" AS SELECT * FROM "t""#.to_string(),
2622 r#"CREATE DEFAULT INDEX ON "v67e5504410b1426f9247bb680e5fe0c8""#.to_string(),
2623 r#"SELECT * FROM "v67e5504410b1426f9247bb680e5fe0c8""#.to_string(),
2624 r#"DROP VIEW "v67e5504410b1426f9247bb680e5fe0c8""#.to_string(),
2625 )),
2626 (("SELECT a, b, c FROM t1, t2", Some(3), Some(vec![ColumnName::from("a"), ColumnName::from("b"), ColumnName::from("c")])),
2627 (
2628 r#"CREATE VIEW "v67e5504410b1426f9247bb680e5fe0c8" ("a", "b", "c") AS SELECT "a", "b", "c" FROM "t1", "t2""#.to_string(),
2629 r#"CREATE INDEX ON "v67e5504410b1426f9247bb680e5fe0c8" ("a", "b", "c")"#.to_string(),
2630 r#"SELECT "a", "b", "c" FROM "v67e5504410b1426f9247bb680e5fe0c8""#.to_string(),
2631 r#"DROP VIEW "v67e5504410b1426f9247bb680e5fe0c8""#.to_string(),
2632 )),
2633 (("SELECT a, b, c FROM t1, t2", Some(3), None),
2634 (
2635 r#"CREATE VIEW "v67e5504410b1426f9247bb680e5fe0c8" ("a1", "a2", "a3") AS SELECT "a", "b", "c" FROM "t1", "t2""#.to_string(),
2636 r#"CREATE INDEX ON "v67e5504410b1426f9247bb680e5fe0c8" ("a1", "a2", "a3")"#.to_string(),
2637 r#"SELECT "a1", "a2", "a3" FROM "v67e5504410b1426f9247bb680e5fe0c8""#.to_string(),
2638 r#"DROP VIEW "v67e5504410b1426f9247bb680e5fe0c8""#.to_string(),
2639 )),
2640 (("SELECT * FROM (SELECT a, sum(b) AS a FROM t GROUP BY a)", None, None),
2643 (
2644 r#"CREATE VIEW "v67e5504410b1426f9247bb680e5fe0c8" AS SELECT * FROM (SELECT "a", "sum"("b") AS "a" FROM "t" GROUP BY "a")"#.to_string(),
2645 r#"CREATE DEFAULT INDEX ON "v67e5504410b1426f9247bb680e5fe0c8""#.to_string(),
2646 r#"SELECT * FROM "v67e5504410b1426f9247bb680e5fe0c8""#.to_string(),
2647 r#"DROP VIEW "v67e5504410b1426f9247bb680e5fe0c8""#.to_string(),
2648 )),
2649 (("SELECT a, b, b + d AS c, a + b AS d FROM t1, t2 ORDER BY a, c, a + b", Some(4), Some(vec![ColumnName::from("a"), ColumnName::from("b"), ColumnName::from("c"), ColumnName::from("d")])),
2650 (
2651 r#"CREATE VIEW "v67e5504410b1426f9247bb680e5fe0c8" ("a", "b", "c", "d") AS SELECT "a", "b", "b" + "d" AS "c", "a" + "b" AS "d" FROM "t1", "t2" ORDER BY "a", "c", "a" + "b""#.to_string(),
2652 r#"CREATE INDEX ON "v67e5504410b1426f9247bb680e5fe0c8" ("a", "b", "c", "d")"#.to_string(),
2653 r#"SELECT "a", "b", "c", "d" FROM "v67e5504410b1426f9247bb680e5fe0c8" ORDER BY 1, 3, 4"#.to_string(),
2654 r#"DROP VIEW "v67e5504410b1426f9247bb680e5fe0c8""#.to_string(),
2655 )),
2656 (("((SELECT 1 AS a UNION SELECT 2 AS b) UNION SELECT 3 AS c) ORDER BY a", Some(1), None),
2657 (
2658 r#"CREATE VIEW "v67e5504410b1426f9247bb680e5fe0c8" ("a1") AS (SELECT 1 AS "a" UNION SELECT 2 AS "b") UNION SELECT 3 AS "c" ORDER BY "a""#.to_string(),
2659 r#"CREATE INDEX ON "v67e5504410b1426f9247bb680e5fe0c8" ("a1")"#.to_string(),
2660 r#"SELECT "a1" FROM "v67e5504410b1426f9247bb680e5fe0c8" ORDER BY 1"#.to_string(),
2661 r#"DROP VIEW "v67e5504410b1426f9247bb680e5fe0c8""#.to_string(),
2662 )),
2663 (("SELECT * FROM (SELECT a, sum(b) AS a FROM t GROUP BY a) ORDER BY 1", None, None),
2664 (
2665 r#"CREATE VIEW "v67e5504410b1426f9247bb680e5fe0c8" AS SELECT * FROM (SELECT "a", "sum"("b") AS "a" FROM "t" GROUP BY "a") ORDER BY 1"#.to_string(),
2666 r#"CREATE DEFAULT INDEX ON "v67e5504410b1426f9247bb680e5fe0c8""#.to_string(),
2667 r#"SELECT * FROM "v67e5504410b1426f9247bb680e5fe0c8" ORDER BY 1"#.to_string(),
2668 r#"DROP VIEW "v67e5504410b1426f9247bb680e5fe0c8""#.to_string(),
2669 )),
2670 (("SELECT * FROM (SELECT a, sum(b) AS a FROM t GROUP BY a) ORDER BY a", None, None),
2671 (
2672 r#"CREATE VIEW "v67e5504410b1426f9247bb680e5fe0c8" AS SELECT * FROM (SELECT "a", "sum"("b") AS "a" FROM "t" GROUP BY "a") ORDER BY "a""#.to_string(),
2673 r#"CREATE DEFAULT INDEX ON "v67e5504410b1426f9247bb680e5fe0c8""#.to_string(),
2674 r#"SELECT * FROM "v67e5504410b1426f9247bb680e5fe0c8" ORDER BY "a""#.to_string(),
2675 r#"DROP VIEW "v67e5504410b1426f9247bb680e5fe0c8""#.to_string(),
2676 )),
2677 (("SELECT a, sum(b) AS a FROM t GROUP BY a, c ORDER BY a, c", Some(2), None),
2678 (
2679 r#"CREATE VIEW "v67e5504410b1426f9247bb680e5fe0c8" ("a1", "a2", "a3") AS SELECT "a", "sum"("b") AS "a", "c" AS "a3" FROM "t" GROUP BY "a", "c" ORDER BY "a", "c""#.to_string(),
2680 r#"CREATE INDEX ON "v67e5504410b1426f9247bb680e5fe0c8" ("a1", "a2", "a3")"#.to_string(),
2681 r#"SELECT "a1", "a2" FROM "v67e5504410b1426f9247bb680e5fe0c8" ORDER BY 1, "a3""#.to_string(),
2682 r#"DROP VIEW "v67e5504410b1426f9247bb680e5fe0c8""#.to_string(),
2683 )),
2684 (("SELECT a, sum(b) AS a FROM t GROUP BY a, c ORDER BY c, a", Some(2), None),
2685 (
2686 r#"CREATE VIEW "v67e5504410b1426f9247bb680e5fe0c8" ("a1", "a2", "a3") AS SELECT "a", "sum"("b") AS "a", "c" AS "a3" FROM "t" GROUP BY "a", "c" ORDER BY "c", "a""#.to_string(),
2687 r#"CREATE INDEX ON "v67e5504410b1426f9247bb680e5fe0c8" ("a1", "a2", "a3")"#.to_string(),
2688 r#"SELECT "a1", "a2" FROM "v67e5504410b1426f9247bb680e5fe0c8" ORDER BY "a3", 1"#.to_string(),
2689 r#"DROP VIEW "v67e5504410b1426f9247bb680e5fe0c8""#.to_string(),
2690 )),
2691 ];
2692 for ((sql, num_attributes, expected_column_names), expected) in cases {
2693 let view_sql =
2694 generate_view_sql(sql, uuid.as_simple(), num_attributes, expected_column_names);
2695 assert_eq!(expected, view_sql);
2696 }
2697}
2698
2699#[mz_ore::test]
2700fn test_mutate() {
2701 let cases = vec![
2702 ("CREATE TABLE t ()", vec![r#"CREATE INDEX ON "t" ()"#]),
2703 (
2704 "CREATE TABLE t (a INT)",
2705 vec![r#"CREATE INDEX ON "t" ("a")"#],
2706 ),
2707 (
2708 "CREATE TABLE t (a INT, b TEXT)",
2709 vec![r#"CREATE INDEX ON "t" ("a", "b")"#],
2710 ),
2711 ("BAD SYNTAX", Vec::new()),
2713 ];
2714 for (sql, expected) in cases {
2715 let stmts = mutate(sql);
2716 assert_eq!(expected, stmts, "sql: {sql}");
2717 }
2718}