use std::cell::RefCell;
use std::collections::btree_map::Entry;
use std::collections::BTreeMap;
use std::fmt;
use std::fs::File;
use std::io::{self, Write};
use std::path::PathBuf;
use std::process::ExitCode;
use chrono::Utc;
use mz_orchestrator_tracing::{StaticTracingConfig, TracingCliArgs};
use mz_ore::cli::{self, CliConfig, KeyValueArg};
use mz_ore::metrics::MetricsRegistry;
use mz_sql::session::vars::{
Var, VarInput, DISK_CLUSTER_REPLICAS_DEFAULT, ENABLE_LOGICAL_COMPACTION_WINDOW,
};
use mz_sqllogictest::runner::{self, Outcomes, RunConfig, Runner, WriteFmt};
use mz_sqllogictest::util;
use mz_tracing::CloneableEnvFilter;
use time::Instant;
use walkdir::WalkDir;
// Command-line options for the sqllogictest harness.
//
// NOTE(review): plain `//` comments are used here deliberately: clap derives
// `--help` text from `///` doc comments, so adding doc comments would change
// the CLI's observable output.
#[derive(clap::Parser)]
struct Args {
    // Output verbosity; repeatable (-v, -vv, ...). Per-file outcome summaries
    // are printed when this is >= 1.
    #[clap(short = 'v', long = "verbose", parse(from_occurrences))]
    verbosity: usize,
    // Report failures but exit with success anyway.
    #[clap(long)]
    no_fail: bool,
    // Prefix every output line with a wall-clock timestamp (see OutputStream).
    #[clap(long)]
    timestamps: bool,
    // Rewrite expected results inside the test files instead of checking them.
    #[clap(long)]
    rewrite_results: bool,
    // Write a JUnit XML report of per-file results to FILE.
    #[clap(long, value_name = "FILE")]
    junit_report: Option<PathBuf>,
    // Postgres connection URL, forwarded to the runner via RunConfig.
    #[clap(long)]
    postgres_url: String,
    // Files or directories (walked recursively) containing the test files.
    #[clap(value_name = "PATH", required = true)]
    paths: Vec<String>,
    // Forwarded to the runner via RunConfig; presumably stops at the first
    // failing record — confirm against the runner implementation.
    #[clap(long)]
    fail_fast: bool,
    // Forwarded to the runner via RunConfig.
    #[clap(long)]
    auto_index_tables: bool,
    // Forwarded to the runner via RunConfig.
    #[clap(long)]
    auto_index_selects: bool,
    // Forwarded to the runner via RunConfig.
    #[clap(long)]
    auto_transactions: bool,
    // Forwarded to the runner via RunConfig.
    #[clap(long)]
    enable_table_keys: bool,
    // 0-based shard index: this invocation takes every shard-count-th PATH
    // starting at this offset. Must be given together with --shard-count.
    #[clap(long, requires = "shard-count", value_name = "N")]
    shard: Option<usize>,
    // Total number of shards; must be given together with --shard.
    #[clap(long, requires = "shard", value_name = "N")]
    shard_count: Option<usize>,
    // Forwarded to the runner via RunConfig.
    #[clap(long, env = "ORCHESTRATOR_PROCESS_WRAPPER")]
    orchestrator_process_wrapper: Option<String>,
    // Replica count, forwarded to the runner via RunConfig.
    #[clap(long, default_value = "2")]
    replicas: usize,
    // key=value system parameter defaults; ';'-delimited and repeatable.
    // Merged with — and checked for consistency against — the defaults the
    // harness itself requires (see main).
    #[clap(
        long,
        env = "SYSTEM_PARAMETER_DEFAULT",
        multiple = true,
        value_delimiter = ';'
    )]
    system_parameter_default: Vec<KeyValueArg<String, String>>,
    // Tracing/log filter applied at startup (e.g. "warn", "info").
    #[clap(
        long,
        env = "LOG_FILTER",
        value_name = "FILTER",
        default_value = "warn"
    )]
    pub log_filter: CloneableEnvFilter,
}
#[tokio::main]
async fn main() -> ExitCode {
    // Abort the process on any panic rather than unwinding: a wedged test
    // harness is worse than a crashed one.
    mz_ore::panic::set_abort_on_panic();
    let args: Args = cli::parse_args(CliConfig {
        // Every flag can also be supplied via an MZ_-prefixed env var.
        env_prefix: Some("MZ_"),
        enable_version_flag: false,
    });
    // Configure tracing from --log-filter before anything else runs so that
    // startup messages are captured too.
    let tracing_args = TracingCliArgs {
        startup_log_filter: args.log_filter.clone(),
        ..Default::default()
    };
    let (tracing_handle, _tracing_guard) = tracing_args
        .configure_tracing(
            StaticTracingConfig {
                service_name: "sqllogictest",
                build_info: mz_environmentd::BUILD_INFO,
            },
            MetricsRegistry::new(),
        )
        .await
        .unwrap();
    // System parameter defaults the harness itself depends on. Users may
    // restate these via --system-parameter-default, but only with the same
    // value; a conflicting value trips the assert below.
    let required_system_defaults: Vec<_> = [
        (&DISK_CLUSTER_REPLICAS_DEFAULT, "true"),
        (ENABLE_LOGICAL_COMPACTION_WINDOW.flag, "true"),
    ]
    .into();
    let mut system_parameter_defaults: BTreeMap<_, _> = args
        .system_parameter_default
        .clone()
        .into_iter()
        .map(|kv| (kv.key, kv.value))
        .collect();
    for (var, value) in required_system_defaults {
        // Round-trip values through the var's own parser so comparisons use
        // canonicalized forms rather than raw strings.
        let parse = |value| {
            var.parse(VarInput::Flat(value))
                .expect("invalid value")
                .format()
        };
        let value = parse(value);
        match system_parameter_defaults.entry(var.name().to_string()) {
            Entry::Vacant(entry) => {
                entry.insert(value);
            }
            Entry::Occupied(entry) => {
                assert_eq!(
                    value,
                    parse(entry.get()),
                    "sqllogictest test requires {} to have a value of {}",
                    var.name(),
                    value
                )
            }
        }
    }
    let config = RunConfig {
        stdout: &OutputStream::new(io::stdout(), args.timestamps),
        stderr: &OutputStream::new(io::stderr(), args.timestamps),
        verbosity: args.verbosity,
        postgres_url: args.postgres_url.clone(),
        no_fail: args.no_fail,
        fail_fast: args.fail_fast,
        auto_index_tables: args.auto_index_tables,
        auto_index_selects: args.auto_index_selects,
        auto_transactions: args.auto_transactions,
        enable_table_keys: args.enable_table_keys,
        orchestrator_process_wrapper: args.orchestrator_process_wrapper.clone(),
        tracing: tracing_args.clone(),
        tracing_handle,
        system_parameter_defaults,
        // Persist state lives in a fresh temp dir, cleaned up when `config`
        // (and the TempDir inside it) is dropped at process exit.
        persist_dir: match tempfile::tempdir() {
            Ok(t) => t,
            Err(e) => {
                eprintln!("error creating state dir: {e}");
                return ExitCode::FAILURE;
            }
        },
        replicas: args.replicas,
    };
    // Announce which shard this invocation covers, except in the degenerate
    // single-shard case where the message would be noise.
    if let (Some(shard), Some(shard_count)) = (args.shard, args.shard_count) {
        if shard != 0 || shard_count != 1 {
            eprintln!("Shard: {}/{}", shard + 1, shard_count);
        }
    }
    // --rewrite-results switches to a separate mode that updates expected
    // output in the test files instead of verifying it.
    if args.rewrite_results {
        return rewrite(&config, args).await;
    }
    // Create the JUnit report file up front (if requested) so an unwritable
    // path fails before any tests run.
    let mut junit = match args.junit_report {
        Some(filename) => match File::create(&filename) {
            Ok(file) => Some((file, junit_report::TestSuite::new("sqllogictest"))),
            Err(err) => {
                writeln!(config.stderr, "creating {}: {}", filename.display(), err);
                return ExitCode::FAILURE;
            }
        },
        None => None,
    };
    let mut outcomes = Outcomes::default();
    let mut runner = Runner::start(&config).await.unwrap();
    // Sharding: shard N takes every `shard_count`-th path starting at offset
    // N. Note this shards over the command-line paths, not over the files
    // discovered beneath them.
    let mut paths = args.paths;
    if let (Some(shard), Some(shard_count)) = (args.shard, args.shard_count) {
        paths = paths.into_iter().skip(shard).step_by(shard_count).collect();
    }
    for path in &paths {
        // Each path may be a single file or a directory tree of test files.
        for entry in WalkDir::new(path) {
            match entry {
                Ok(entry) if entry.file_type().is_file() => {
                    let start_time = Instant::now();
                    match runner::run_file(&mut runner, entry.path()).await {
                        Ok(o) => {
                            // Per-file summaries print only on failure or
                            // when -v was given.
                            if o.any_failed() || config.verbosity >= 1 {
                                writeln!(
                                    config.stdout,
                                    "{}",
                                    util::indent(&o.display(config.no_fail).to_string(), 4)
                                );
                            }
                            if let Some((_, junit_suite)) = &mut junit {
                                let mut test_case = if o.any_failed() {
                                    junit_report::TestCase::failure(
                                        &entry.path().to_string_lossy(),
                                        start_time.elapsed(),
                                        "failure",
                                        &o.display(false).to_string(),
                                    )
                                } else {
                                    junit_report::TestCase::success(
                                        &entry.path().to_string_lossy(),
                                        start_time.elapsed(),
                                    )
                                };
                                test_case.set_classname("sqllogictest");
                                junit_suite.add_testcase(test_case);
                            }
                            outcomes += o;
                        }
                        Err(err) => {
                            writeln!(
                                config.stderr,
                                "FAIL: error: running file {}: {}",
                                entry.file_name().to_string_lossy(),
                                err
                            );
                            return ExitCode::FAILURE;
                        }
                    }
                }
                // Non-file entries (directories, etc.) are skipped silently.
                Ok(_) => (),
                Err(err) => {
                    writeln!(
                        config.stderr,
                        "FAIL: error: reading directory entry: {}",
                        err
                    );
                    return ExitCode::FAILURE;
                }
            }
        }
    }
    // Aggregate totals across all files.
    writeln!(config.stdout, "{}", outcomes.display(config.no_fail));
    if let Some((mut junit_file, junit_suite)) = junit {
        let report = junit_report::ReportBuilder::new()
            .add_testsuite(junit_suite)
            .build();
        match report.write_xml(&mut junit_file) {
            Ok(()) => (),
            Err(err) => {
                writeln!(
                    config.stderr,
                    "error: unable to write junit report: {}",
                    err
                );
                // Exit code 2 distinguishes "could not write the report"
                // from an ordinary test failure (exit code 1).
                return ExitCode::from(2);
            }
        }
    }
    // Failed tests fail the run unless --no-fail downgrades them.
    if outcomes.any_failed() && !args.no_fail {
        return ExitCode::FAILURE;
    }
    ExitCode::SUCCESS
}
/// Rewrites the expected results in place for every test file found under
/// `args.paths`, returning a process exit code.
///
/// Incompatible with `--junit-report` and with reading from stdin; either
/// combination aborts before any file is touched.
async fn rewrite(config: &RunConfig<'_>, args: Args) -> ExitCode {
    // Rewriting mutates files on disk, so reject modes that imply read-only
    // or streaming input up front.
    if args.junit_report.is_some() {
        writeln!(
            config.stderr,
            "--rewrite-results is not compatible with --junit-report"
        );
        return ExitCode::FAILURE;
    }
    if args.paths.iter().any(|path| path == "-") {
        writeln!(config.stderr, "--rewrite-results cannot be used with stdin");
        return ExitCode::FAILURE;
    }
    let mut runner = Runner::start(config).await.unwrap();
    // Apply sharding the same way the normal run does: shard N takes every
    // `shard_count`-th path starting at offset N.
    let paths: Vec<String> = match (args.shard, args.shard_count) {
        (Some(shard), Some(shard_count)) => args
            .paths
            .into_iter()
            .skip(shard)
            .step_by(shard_count)
            .collect(),
        _ => args.paths,
    };
    for path in paths {
        for walked in WalkDir::new(path) {
            match walked {
                // Directories and other non-file entries are skipped silently.
                Ok(dirent) if !dirent.file_type().is_file() => {}
                Ok(dirent) => {
                    if let Err(err) = runner::rewrite_file(&mut runner, dirent.path()).await {
                        writeln!(config.stderr, "FAIL: error: rewriting file: {}", err);
                        return ExitCode::FAILURE;
                    }
                }
                Err(err) => {
                    writeln!(
                        config.stderr,
                        "FAIL: error: reading directory entry: {}",
                        err
                    );
                    return ExitCode::FAILURE;
                }
            }
        }
    }
    ExitCode::SUCCESS
}
// A shared-reference output sink that can prefix each line it writes with a
// wall-clock timestamp (the prefixing itself lives in the WriteFmt impl).
struct OutputStream<W> {
    // The underlying writer; RefCell provides interior mutability so writes
    // can happen through a shared `&self`.
    inner: RefCell<W>,
    // True when the next write must begin with a timestamp — i.e. the
    // previous write ended in a newline, or nothing has been written yet.
    need_timestamp: RefCell<bool>,
    // Whether timestamp prefixing is enabled at all (--timestamps flag).
    timestamps: bool,
}
impl<W> OutputStream<W>
where
    W: Write,
{
    /// Builds a stream over `inner`; timestamps are prefixed iff `timestamps`
    /// is true. The stream starts as if the previous output ended in a
    /// newline, so the very first write gets a timestamp.
    fn new(inner: W, timestamps: bool) -> OutputStream<W> {
        OutputStream {
            timestamps,
            need_timestamp: RefCell::new(true),
            inner: RefCell::new(inner),
        }
    }

    /// Writes `s` verbatim to the underlying writer, panicking on I/O error.
    fn emit_str(&self, s: &str) {
        let mut sink = self.inner.borrow_mut();
        sink.write_all(s.as_bytes()).unwrap();
    }
}
impl<W> WriteFmt for OutputStream<W>
where
    W: Write,
{
    /// Writes the formatted text, prefixing each output line with a wall-clock
    /// timestamp when timestamps are enabled.
    fn write_fmt(&self, fmt: fmt::Arguments<'_>) {
        let text = fmt.to_string();
        if !self.timestamps {
            self.emit_str(&text);
            return;
        }
        // One timestamp is rendered per call; every line produced by this
        // write shares it.
        let stamp = format!("[{}] ", Utc::now().format("%Y-%m-%d %H:%M:%S.%f %Z"));
        // A timestamp left pending by the previous write (which ended in a
        // newline) is emitted before any of this call's output.
        if self.need_timestamp.replace(false) {
            self.emit_str(&stamp);
        }
        // Peel off a trailing newline so we never print a dangling timestamp
        // for a line that has no content yet; instead, remember that the next
        // write must start with one.
        let (body, ends_with_newline) = match text.strip_suffix('\n') {
            Some(stripped) => (stripped, true),
            None => (text.as_str(), false),
        };
        // Prefix the line following each embedded newline with the timestamp.
        self.emit_str(&body.replace('\n', &format!("\n{}", stamp)));
        if ends_with_newline {
            *self.need_timestamp.borrow_mut() = true;
            self.emit_str("\n");
        }
    }
}