sqllogictest/
sqllogictest.rs

1// Copyright Materialize, Inc. and contributors. All rights reserved.
2//
3// Use of this software is governed by the Business Source License
4// included in the LICENSE file.
5//
6// As of the Change Date specified in that file, in accordance with
7// the Business Source License, use of this software will be governed
8// by the Apache License, Version 2.0.
9
10use std::cell::RefCell;
11use std::collections::BTreeMap;
12use std::collections::btree_map::Entry;
13use std::fmt;
14use std::fs::File;
15use std::io::{self, Write};
16use std::path::PathBuf;
17use std::process::ExitCode;
18
19use chrono::Utc;
20use clap::ArgAction;
21use mz_orchestrator_tracing::{StaticTracingConfig, TracingCliArgs};
22use mz_ore::cli::{self, CliConfig, KeyValueArg};
23use mz_ore::metrics::MetricsRegistry;
24use mz_sql::session::vars::{ENABLE_LOGICAL_COMPACTION_WINDOW, Var, VarInput};
25use mz_sqllogictest::runner::{self, Outcomes, RunConfig, Runner, WriteFmt};
26use mz_sqllogictest::util;
27use mz_tracing::CloneableEnvFilter;
28#[allow(deprecated)] // fails with libraries still using old time lib
29use time::Instant;
30use walkdir::WalkDir;
31
32/// Runs sqllogictest scripts to verify database engine correctness.
33#[derive(clap::Parser)]
34struct Args {
35    /// Increase verbosity.
36    ///
37    /// If specified once, print summary for each source file.
38    /// If specified twice, also show descriptions of each error.
39    /// If specified thrice, also print each query before it is executed.
40    #[clap(short = 'v', long = "verbose", action = ArgAction::Count)]
41    verbosity: u8,
42    /// Don't exit with a failing code if not all queries are successful.
43    #[clap(long)]
44    no_fail: bool,
45    /// Prefix every line of output with the current time.
46    #[clap(long)]
47    timestamps: bool,
48    /// Rewrite expected output based on actual output.
49    #[clap(long)]
50    rewrite_results: bool,
51    /// Generate a JUnit-compatible XML report to the specified file.
52    #[clap(long, value_name = "FILE")]
53    junit_report: Option<PathBuf>,
54    /// PostgreSQL connection URL to use for `persist` consensus.
55    #[clap(long)]
56    postgres_url: String,
57    /// PostgreSQL prefix for this SLT run
58    #[clap(long, default_value = "sqllogictest")]
59    prefix: String,
60    /// Path to sqllogictest script to run.
61    #[clap(value_name = "PATH", required = true)]
62    paths: Vec<String>,
63    /// Stop on first failure.
64    #[clap(long)]
65    fail_fast: bool,
66    /// Inject `CREATE INDEX` after all `CREATE TABLE` statements.
67    #[clap(long)]
68    auto_index_tables: bool,
69    /// Inject `CREATE VIEW <view_name> AS <select_query>` and `CREATE DEFAULT INDEX ON <view_name> ...`
70    /// to redundantly execute a given `SELECT` query and contrast outcomes.
71    #[clap(long)]
72    auto_index_selects: bool,
73    /// Inject `BEGIN` and `COMMIT` to create longer running transactions for faster testing of the
74    /// ported SQLite SLT files. Does not work generally, so don't use it for other tests.
75    #[clap(long)]
76    auto_transactions: bool,
77    /// Inject `ALTER SYSTEM SET unsafe_enable_table_keys = true` before running the SLT file.
78    #[clap(long)]
79    enable_table_keys: bool,
80    /// Divide the test files into shards and run only the test files in this shard.
81    #[clap(long, requires = "shard_count", value_name = "N")]
82    shard: Option<usize>,
83    /// Total number of shards in use.
84    #[clap(long, requires = "shard", value_name = "N")]
85    shard_count: Option<usize>,
86    /// Wrapper program to start child processes
87    #[clap(long, env = "ORCHESTRATOR_PROCESS_WRAPPER")]
88    orchestrator_process_wrapper: Option<String>,
89    /// Replica size
90    #[clap(long, default_value = "scale=1,workers=2")]
91    replica_size: String,
92    /// Replication factor
93    #[clap(long, default_value = "1")]
94    replicas: usize,
95    /// An list of NAME=VALUE pairs used to override static defaults
96    /// for system parameters.
97    #[clap(
98        long,
99        env = "SYSTEM_PARAMETER_DEFAULT",
100        action = ArgAction::Append,
101        value_delimiter = ';'
102    )]
103    system_parameter_default: Vec<KeyValueArg<String, String>>,
104    #[clap(
105        long,
106        env = "LOG_FILTER",
107        value_name = "FILTER",
108        default_value = "warn"
109    )]
110    pub log_filter: CloneableEnvFilter,
111}
112
113#[tokio::main]
114async fn main() -> ExitCode {
115    mz_ore::panic::install_enhanced_handler();
116
117    let args: Args = cli::parse_args(CliConfig {
118        env_prefix: Some("MZ_"),
119        enable_version_flag: false,
120    });
121
122    let tracing_args = TracingCliArgs {
123        startup_log_filter: args.log_filter.clone(),
124        ..Default::default()
125    };
126    let tracing_handle = tracing_args
127        .configure_tracing(
128            StaticTracingConfig {
129                service_name: "sqllogictest",
130                build_info: mz_environmentd::BUILD_INFO,
131            },
132            MetricsRegistry::new(),
133        )
134        .await
135        .unwrap();
136
137    // sqllogictest requires that Materialize have some system variables set to some specific value
138    // to pass. If the caller hasn't set this variable, then we set it for them. If the caller has
139    // set this variable, then we assert that it's set to the right value.
140    let required_system_defaults: Vec<_> = [(ENABLE_LOGICAL_COMPACTION_WINDOW.flag, "true")].into();
141    let mut system_parameter_defaults: BTreeMap<_, _> = args
142        .system_parameter_default
143        .clone()
144        .into_iter()
145        .map(|kv| (kv.key, kv.value))
146        .collect();
147    for (var, value) in required_system_defaults {
148        let parse = |value| {
149            var.parse(VarInput::Flat(value))
150                .expect("invalid value")
151                .format()
152        };
153        let value = parse(value);
154        match system_parameter_defaults.entry(var.name().to_string()) {
155            Entry::Vacant(entry) => {
156                entry.insert(value);
157            }
158            Entry::Occupied(entry) => {
159                assert_eq!(
160                    value,
161                    parse(entry.get()),
162                    "sqllogictest test requires {} to have a value of {}",
163                    var.name(),
164                    value
165                )
166            }
167        }
168    }
169
170    let config = RunConfig {
171        stdout: &OutputStream::new(io::stdout(), args.timestamps),
172        stderr: &OutputStream::new(io::stderr(), args.timestamps),
173        verbosity: args.verbosity,
174        postgres_url: args.postgres_url.clone(),
175        prefix: args.prefix.clone(),
176        no_fail: args.no_fail,
177        fail_fast: args.fail_fast,
178        auto_index_tables: args.auto_index_tables,
179        auto_index_selects: args.auto_index_selects,
180        auto_transactions: args.auto_transactions,
181        enable_table_keys: args.enable_table_keys,
182        orchestrator_process_wrapper: args.orchestrator_process_wrapper.clone(),
183        tracing: tracing_args.clone(),
184        tracing_handle,
185        system_parameter_defaults,
186        persist_dir: match tempfile::tempdir() {
187            Ok(t) => t,
188            Err(e) => {
189                eprintln!("error creating state dir: {e}");
190                return ExitCode::FAILURE;
191            }
192        },
193        replicas: args.replicas,
194        replica_size: args.replica_size.clone(),
195    };
196
197    if let (Some(shard), Some(shard_count)) = (args.shard, args.shard_count) {
198        if shard != 0 || shard_count != 1 {
199            eprintln!("Shard: {}/{}", shard + 1, shard_count);
200        }
201    }
202
203    if args.rewrite_results {
204        return rewrite(&config, args).await;
205    }
206
207    let mut junit = match args.junit_report {
208        Some(filename) => match File::create(&filename) {
209            Ok(file) => Some((file, junit_report::TestSuite::new("sqllogictest"))),
210            Err(err) => {
211                writeln!(config.stderr, "creating {}: {}", filename.display(), err);
212                return ExitCode::FAILURE;
213            }
214        },
215        None => None,
216    };
217    let mut outcomes = Outcomes::default();
218    let mut runner = Runner::start(&config).await.unwrap();
219    let mut paths = args.paths;
220
221    if let (Some(shard), Some(shard_count)) = (args.shard, args.shard_count) {
222        paths = paths.into_iter().skip(shard).step_by(shard_count).collect();
223    }
224
225    for path in &paths {
226        for entry in WalkDir::new(path) {
227            match entry {
228                Ok(entry) if entry.file_type().is_file() => {
229                    #[allow(deprecated)] // fails with libraries still using old time lib
230                    let start_time = Instant::now();
231                    match runner::run_file(&mut runner, entry.path()).await {
232                        Ok(o) => {
233                            if o.any_failed() || config.verbosity >= 1 {
234                                writeln!(
235                                    config.stdout,
236                                    "{}",
237                                    util::indent(&o.display(config.no_fail, false).to_string(), 4)
238                                );
239                            }
240                            if let Some((_, junit_suite)) = &mut junit {
241                                let mut test_case = if o.any_failed() && !args.no_fail {
242                                    let mut result = junit_report::TestCase::failure(
243                                        &entry.path().to_string_lossy(),
244                                        start_time.elapsed(),
245                                        "failure",
246                                        "",
247                                    );
248                                    // Encode in system_out so we can display newlines
249                                    result.system_out = Some(
250                                        o.display(false, true)
251                                            .to_string()
252                                            .trim_end_matches('\n')
253                                            .to_string(),
254                                    );
255                                    result
256                                } else {
257                                    junit_report::TestCase::success(
258                                        &entry.path().to_string_lossy(),
259                                        start_time.elapsed(),
260                                    )
261                                };
262                                test_case.set_classname("sqllogictest");
263                                junit_suite.add_testcase(test_case);
264                            }
265                            outcomes += o;
266                        }
267                        Err(err) => {
268                            writeln!(
269                                config.stderr,
270                                "FAIL: error: running file {}: {}",
271                                entry.file_name().to_string_lossy(),
272                                err
273                            );
274                            return ExitCode::FAILURE;
275                        }
276                    }
277                }
278                Ok(_) => (),
279                Err(err) => {
280                    writeln!(
281                        config.stderr,
282                        "FAIL: error: reading directory entry: {}",
283                        err
284                    );
285                    return ExitCode::FAILURE;
286                }
287            }
288        }
289    }
290
291    writeln!(config.stdout, "{}", outcomes.display(config.no_fail, false));
292
293    if let Some((mut junit_file, junit_suite)) = junit {
294        let report = junit_report::ReportBuilder::new()
295            .add_testsuite(junit_suite)
296            .build();
297        match report.write_xml(&mut junit_file) {
298            Ok(()) => (),
299            Err(err) => {
300                writeln!(
301                    config.stderr,
302                    "error: unable to write junit report: {}",
303                    err
304                );
305                return ExitCode::from(2);
306            }
307        }
308    }
309
310    if outcomes.any_failed() && !args.no_fail {
311        return ExitCode::FAILURE;
312    }
313    ExitCode::SUCCESS
314}
315
316async fn rewrite(config: &RunConfig<'_>, args: Args) -> ExitCode {
317    if args.junit_report.is_some() {
318        writeln!(
319            config.stderr,
320            "--rewrite-results is not compatible with --junit-report"
321        );
322        return ExitCode::FAILURE;
323    }
324
325    if args.paths.iter().any(|path| path == "-") {
326        writeln!(config.stderr, "--rewrite-results cannot be used with stdin");
327        return ExitCode::FAILURE;
328    }
329
330    let mut runner = Runner::start(config).await.unwrap();
331    let mut paths = args.paths;
332
333    if let (Some(shard), Some(shard_count)) = (args.shard, args.shard_count) {
334        paths = paths.into_iter().skip(shard).step_by(shard_count).collect();
335    }
336
337    for path in paths {
338        for entry in WalkDir::new(path) {
339            match entry {
340                Ok(entry) => {
341                    if entry.file_type().is_file() {
342                        if let Err(err) = runner::rewrite_file(&mut runner, entry.path()).await {
343                            writeln!(config.stderr, "FAIL: error: rewriting file: {}", err);
344                            return ExitCode::FAILURE;
345                        }
346                    }
347                }
348                Err(err) => {
349                    writeln!(
350                        config.stderr,
351                        "FAIL: error: reading directory entry: {}",
352                        err
353                    );
354                    return ExitCode::FAILURE;
355                }
356            }
357        }
358    }
359    ExitCode::SUCCESS
360}
361
/// A writer wrapper that can prefix each line of output with a timestamp.
///
/// The mutable parts live in `RefCell`s so the stream can be written to
/// through a shared reference.
struct OutputStream<W> {
    /// Whether output lines are prefixed with the current time.
    timestamps: bool,
    /// Set when the next text written starts a new line and so still needs
    /// its timestamp prefix.
    need_timestamp: RefCell<bool>,
    /// The wrapped writer that receives the output.
    inner: RefCell<W>,
}
367
368impl<W> OutputStream<W>
369where
370    W: Write,
371{
372    fn new(inner: W, timestamps: bool) -> OutputStream<W> {
373        OutputStream {
374            inner: RefCell::new(inner),
375            need_timestamp: RefCell::new(true),
376            timestamps,
377        }
378    }
379
380    fn emit_str(&self, s: &str) {
381        self.inner.borrow_mut().write_all(s.as_bytes()).unwrap();
382    }
383}
384
385impl<W> WriteFmt for OutputStream<W>
386where
387    W: Write,
388{
389    fn write_fmt(&self, fmt: fmt::Arguments<'_>) {
390        let s = format!("{}", fmt);
391        if self.timestamps {
392            // We need to prefix every line in `s` with the current timestamp.
393
394            let timestamp = Utc::now();
395            let timestamp_str = timestamp.format("%Y-%m-%d %H:%M:%S.%f %Z");
396
397            // If the last character we outputted was a newline, then output a
398            // timestamp prefix at the start of this line.
399            if self.need_timestamp.replace(false) {
400                self.emit_str(&format!("[{}] ", timestamp_str));
401            }
402
403            // Emit `s`, installing a timestamp at the start of every line
404            // except the last.
405            let (s, last_was_timestamp) = match s.strip_suffix('\n') {
406                None => (&*s, false),
407                Some(s) => (s, true),
408            };
409            self.emit_str(&s.replace('\n', &format!("\n[{}] ", timestamp_str)));
410
411            // If the line ended with a newline, output the newline but *not*
412            // the timestamp prefix. We want the timestamp to reflect the moment
413            // the *next* character is output. So instead we just remember that
414            // the last character we output was a newline.
415            if last_was_timestamp {
416                *self.need_timestamp.borrow_mut() = true;
417                self.emit_str("\n");
418            }
419        } else {
420            self.emit_str(&s)
421        }
422    }
423}