sqllogictest/
sqllogictest.rs

// Copyright Materialize, Inc. and contributors. All rights reserved.
//
// Use of this software is governed by the Business Source License
// included in the LICENSE file.
//
// As of the Change Date specified in that file, in accordance with
// the Business Source License, use of this software will be governed
// by the Apache License, Version 2.0.
9
use std::cell::{Cell, RefCell};
use std::collections::BTreeMap;
use std::collections::btree_map::Entry;
use std::fmt;
use std::fs::File;
use std::io::{self, Write};
use std::path::PathBuf;
use std::process::ExitCode;

use chrono::Utc;
use clap::ArgAction;
use mz_orchestrator_tracing::{StaticTracingConfig, TracingCliArgs};
use mz_ore::cli::{self, CliConfig, KeyValueArg};
use mz_ore::metrics::MetricsRegistry;
use mz_sql::session::vars::{
    DISK_CLUSTER_REPLICAS_DEFAULT, ENABLE_LOGICAL_COMPACTION_WINDOW, Var, VarInput,
};
use mz_sqllogictest::runner::{self, Outcomes, RunConfig, Runner, WriteFmt};
use mz_sqllogictest::util;
use mz_tracing::CloneableEnvFilter;
#[allow(deprecated)] // fails with libraries still using old time lib
use time::Instant;
use walkdir::WalkDir;
33
34/// Runs sqllogictest scripts to verify database engine correctness.
35#[derive(clap::Parser)]
36struct Args {
37    /// Increase verbosity.
38    ///
39    /// If specified once, print summary for each source file.
40    /// If specified twice, also show descriptions of each error.
41    /// If specified thrice, also print each query before it is executed.
42    #[clap(short = 'v', long = "verbose", action = ArgAction::Count)]
43    verbosity: u8,
44    /// Don't exit with a failing code if not all queries are successful.
45    #[clap(long)]
46    no_fail: bool,
47    /// Prefix every line of output with the current time.
48    #[clap(long)]
49    timestamps: bool,
50    /// Rewrite expected output based on actual output.
51    #[clap(long)]
52    rewrite_results: bool,
53    /// Generate a JUnit-compatible XML report to the specified file.
54    #[clap(long, value_name = "FILE")]
55    junit_report: Option<PathBuf>,
56    /// PostgreSQL connection URL to use for `persist` consensus.
57    #[clap(long)]
58    postgres_url: String,
59    /// Path to sqllogictest script to run.
60    #[clap(value_name = "PATH", required = true)]
61    paths: Vec<String>,
62    /// Stop on first failure.
63    #[clap(long)]
64    fail_fast: bool,
65    /// Inject `CREATE INDEX` after all `CREATE TABLE` statements.
66    #[clap(long)]
67    auto_index_tables: bool,
68    /// Inject `CREATE VIEW <view_name> AS <select_query>` and `CREATE DEFAULT INDEX ON <view_name> ...`
69    /// to redundantly execute a given `SELECT` query and contrast outcomes.
70    #[clap(long)]
71    auto_index_selects: bool,
72    /// Inject `BEGIN` and `COMMIT` to create longer running transactions for faster testing of the
73    /// ported SQLite SLT files. Does not work generally, so don't use it for other tests.
74    #[clap(long)]
75    auto_transactions: bool,
76    /// Inject `ALTER SYSTEM SET unsafe_enable_table_keys = true` before running the SLT file.
77    #[clap(long)]
78    enable_table_keys: bool,
79    /// Divide the test files into shards and run only the test files in this shard.
80    #[clap(long, requires = "shard_count", value_name = "N")]
81    shard: Option<usize>,
82    /// Total number of shards in use.
83    #[clap(long, requires = "shard", value_name = "N")]
84    shard_count: Option<usize>,
85    /// Wrapper program to start child processes
86    #[clap(long, env = "ORCHESTRATOR_PROCESS_WRAPPER")]
87    orchestrator_process_wrapper: Option<String>,
88    /// Number of replicas, defaults to 2
89    #[clap(long, default_value = "2")]
90    replicas: usize,
91    /// An list of NAME=VALUE pairs used to override static defaults
92    /// for system parameters.
93    #[clap(
94        long,
95        env = "SYSTEM_PARAMETER_DEFAULT",
96        action = ArgAction::Append,
97        value_delimiter = ';'
98    )]
99    system_parameter_default: Vec<KeyValueArg<String, String>>,
100    #[clap(
101        long,
102        env = "LOG_FILTER",
103        value_name = "FILTER",
104        default_value = "warn"
105    )]
106    pub log_filter: CloneableEnvFilter,
107}
108
109#[tokio::main]
110async fn main() -> ExitCode {
111    mz_ore::panic::install_enhanced_handler();
112
113    let args: Args = cli::parse_args(CliConfig {
114        env_prefix: Some("MZ_"),
115        enable_version_flag: false,
116    });
117
118    let tracing_args = TracingCliArgs {
119        startup_log_filter: args.log_filter.clone(),
120        ..Default::default()
121    };
122    let (tracing_handle, _tracing_guard) = tracing_args
123        .configure_tracing(
124            StaticTracingConfig {
125                service_name: "sqllogictest",
126                build_info: mz_environmentd::BUILD_INFO,
127            },
128            MetricsRegistry::new(),
129        )
130        .await
131        .unwrap();
132
133    // sqllogictest requires that Materialize have some system variables set to some specific value
134    // to pass. If the caller hasn't set this variable, then we set it for them. If the caller has
135    // set this variable, then we assert that it's set to the right value.
136    let required_system_defaults: Vec<_> = [
137        (&DISK_CLUSTER_REPLICAS_DEFAULT, "true"),
138        (ENABLE_LOGICAL_COMPACTION_WINDOW.flag, "true"),
139    ]
140    .into();
141    let mut system_parameter_defaults: BTreeMap<_, _> = args
142        .system_parameter_default
143        .clone()
144        .into_iter()
145        .map(|kv| (kv.key, kv.value))
146        .collect();
147    for (var, value) in required_system_defaults {
148        let parse = |value| {
149            var.parse(VarInput::Flat(value))
150                .expect("invalid value")
151                .format()
152        };
153        let value = parse(value);
154        match system_parameter_defaults.entry(var.name().to_string()) {
155            Entry::Vacant(entry) => {
156                entry.insert(value);
157            }
158            Entry::Occupied(entry) => {
159                assert_eq!(
160                    value,
161                    parse(entry.get()),
162                    "sqllogictest test requires {} to have a value of {}",
163                    var.name(),
164                    value
165                )
166            }
167        }
168    }
169
170    let config = RunConfig {
171        stdout: &OutputStream::new(io::stdout(), args.timestamps),
172        stderr: &OutputStream::new(io::stderr(), args.timestamps),
173        verbosity: args.verbosity,
174        postgres_url: args.postgres_url.clone(),
175        no_fail: args.no_fail,
176        fail_fast: args.fail_fast,
177        auto_index_tables: args.auto_index_tables,
178        auto_index_selects: args.auto_index_selects,
179        auto_transactions: args.auto_transactions,
180        enable_table_keys: args.enable_table_keys,
181        orchestrator_process_wrapper: args.orchestrator_process_wrapper.clone(),
182        tracing: tracing_args.clone(),
183        tracing_handle,
184        system_parameter_defaults,
185        persist_dir: match tempfile::tempdir() {
186            Ok(t) => t,
187            Err(e) => {
188                eprintln!("error creating state dir: {e}");
189                return ExitCode::FAILURE;
190            }
191        },
192        replicas: args.replicas,
193    };
194
195    if let (Some(shard), Some(shard_count)) = (args.shard, args.shard_count) {
196        if shard != 0 || shard_count != 1 {
197            eprintln!("Shard: {}/{}", shard + 1, shard_count);
198        }
199    }
200
201    if args.rewrite_results {
202        return rewrite(&config, args).await;
203    }
204
205    let mut junit = match args.junit_report {
206        Some(filename) => match File::create(&filename) {
207            Ok(file) => Some((file, junit_report::TestSuite::new("sqllogictest"))),
208            Err(err) => {
209                writeln!(config.stderr, "creating {}: {}", filename.display(), err);
210                return ExitCode::FAILURE;
211            }
212        },
213        None => None,
214    };
215    let mut outcomes = Outcomes::default();
216    let mut runner = Runner::start(&config).await.unwrap();
217    let mut paths = args.paths;
218
219    if let (Some(shard), Some(shard_count)) = (args.shard, args.shard_count) {
220        paths = paths.into_iter().skip(shard).step_by(shard_count).collect();
221    }
222
223    for path in &paths {
224        for entry in WalkDir::new(path) {
225            match entry {
226                Ok(entry) if entry.file_type().is_file() => {
227                    #[allow(deprecated)] // fails with libraries still using old time lib
228                    let start_time = Instant::now();
229                    match runner::run_file(&mut runner, entry.path()).await {
230                        Ok(o) => {
231                            if o.any_failed() || config.verbosity >= 1 {
232                                writeln!(
233                                    config.stdout,
234                                    "{}",
235                                    util::indent(&o.display(config.no_fail, false).to_string(), 4)
236                                );
237                            }
238                            if let Some((_, junit_suite)) = &mut junit {
239                                let mut test_case = if o.any_failed() && !args.no_fail {
240                                    let mut result = junit_report::TestCase::failure(
241                                        &entry.path().to_string_lossy(),
242                                        start_time.elapsed(),
243                                        "failure",
244                                        "",
245                                    );
246                                    // Encode in system_out so we can display newlines
247                                    result.system_out = Some(
248                                        o.display(false, true)
249                                            .to_string()
250                                            .trim_end_matches('\n')
251                                            .to_string(),
252                                    );
253                                    result
254                                } else {
255                                    junit_report::TestCase::success(
256                                        &entry.path().to_string_lossy(),
257                                        start_time.elapsed(),
258                                    )
259                                };
260                                test_case.set_classname("sqllogictest");
261                                junit_suite.add_testcase(test_case);
262                            }
263                            outcomes += o;
264                        }
265                        Err(err) => {
266                            writeln!(
267                                config.stderr,
268                                "FAIL: error: running file {}: {}",
269                                entry.file_name().to_string_lossy(),
270                                err
271                            );
272                            return ExitCode::FAILURE;
273                        }
274                    }
275                }
276                Ok(_) => (),
277                Err(err) => {
278                    writeln!(
279                        config.stderr,
280                        "FAIL: error: reading directory entry: {}",
281                        err
282                    );
283                    return ExitCode::FAILURE;
284                }
285            }
286        }
287    }
288
289    writeln!(config.stdout, "{}", outcomes.display(config.no_fail, false));
290
291    if let Some((mut junit_file, junit_suite)) = junit {
292        let report = junit_report::ReportBuilder::new()
293            .add_testsuite(junit_suite)
294            .build();
295        match report.write_xml(&mut junit_file) {
296            Ok(()) => (),
297            Err(err) => {
298                writeln!(
299                    config.stderr,
300                    "error: unable to write junit report: {}",
301                    err
302                );
303                return ExitCode::from(2);
304            }
305        }
306    }
307
308    if outcomes.any_failed() && !args.no_fail {
309        return ExitCode::FAILURE;
310    }
311    ExitCode::SUCCESS
312}
313
314async fn rewrite(config: &RunConfig<'_>, args: Args) -> ExitCode {
315    if args.junit_report.is_some() {
316        writeln!(
317            config.stderr,
318            "--rewrite-results is not compatible with --junit-report"
319        );
320        return ExitCode::FAILURE;
321    }
322
323    if args.paths.iter().any(|path| path == "-") {
324        writeln!(config.stderr, "--rewrite-results cannot be used with stdin");
325        return ExitCode::FAILURE;
326    }
327
328    let mut runner = Runner::start(config).await.unwrap();
329    let mut paths = args.paths;
330
331    if let (Some(shard), Some(shard_count)) = (args.shard, args.shard_count) {
332        paths = paths.into_iter().skip(shard).step_by(shard_count).collect();
333    }
334
335    for path in paths {
336        for entry in WalkDir::new(path) {
337            match entry {
338                Ok(entry) => {
339                    if entry.file_type().is_file() {
340                        if let Err(err) = runner::rewrite_file(&mut runner, entry.path()).await {
341                            writeln!(config.stderr, "FAIL: error: rewriting file: {}", err);
342                            return ExitCode::FAILURE;
343                        }
344                    }
345                }
346                Err(err) => {
347                    writeln!(
348                        config.stderr,
349                        "FAIL: error: reading directory entry: {}",
350                        err
351                    );
352                    return ExitCode::FAILURE;
353                }
354            }
355        }
356    }
357    ExitCode::SUCCESS
358}
359
360struct OutputStream<W> {
361    inner: RefCell<W>,
362    need_timestamp: RefCell<bool>,
363    timestamps: bool,
364}
365
366impl<W> OutputStream<W>
367where
368    W: Write,
369{
370    fn new(inner: W, timestamps: bool) -> OutputStream<W> {
371        OutputStream {
372            inner: RefCell::new(inner),
373            need_timestamp: RefCell::new(true),
374            timestamps,
375        }
376    }
377
378    fn emit_str(&self, s: &str) {
379        self.inner.borrow_mut().write_all(s.as_bytes()).unwrap();
380    }
381}
382
383impl<W> WriteFmt for OutputStream<W>
384where
385    W: Write,
386{
387    fn write_fmt(&self, fmt: fmt::Arguments<'_>) {
388        let s = format!("{}", fmt);
389        if self.timestamps {
390            // We need to prefix every line in `s` with the current timestamp.
391
392            let timestamp = Utc::now();
393            let timestamp_str = timestamp.format("%Y-%m-%d %H:%M:%S.%f %Z");
394
395            // If the last character we outputted was a newline, then output a
396            // timestamp prefix at the start of this line.
397            if self.need_timestamp.replace(false) {
398                self.emit_str(&format!("[{}] ", timestamp_str));
399            }
400
401            // Emit `s`, installing a timestamp at the start of every line
402            // except the last.
403            let (s, last_was_timestamp) = match s.strip_suffix('\n') {
404                None => (&*s, false),
405                Some(s) => (s, true),
406            };
407            self.emit_str(&s.replace('\n', &format!("\n[{}] ", timestamp_str)));
408
409            // If the line ended with a newline, output the newline but *not*
410            // the timestamp prefix. We want the timestamp to reflect the moment
411            // the *next* character is output. So instead we just remember that
412            // the last character we output was a newline.
413            if last_was_timestamp {
414                *self.need_timestamp.borrow_mut() = true;
415                self.emit_str("\n");
416            }
417        } else {
418            self.emit_str(&s)
419        }
420    }
421}