1use std::cell::RefCell;
11use std::collections::BTreeMap;
12use std::collections::btree_map::Entry;
13use std::fmt;
14use std::fs::File;
15use std::io::{self, Write};
16use std::path::PathBuf;
17use std::process::ExitCode;
18
19use chrono::Utc;
20use clap::ArgAction;
21use mz_orchestrator_tracing::{StaticTracingConfig, TracingCliArgs};
22use mz_ore::cli::{self, CliConfig, KeyValueArg};
23use mz_ore::metrics::MetricsRegistry;
24use mz_sql::session::vars::{ENABLE_LOGICAL_COMPACTION_WINDOW, Var, VarInput};
25use mz_sqllogictest::runner::{self, Outcomes, RunConfig, Runner, WriteFmt};
26use mz_sqllogictest::util;
27use mz_tracing::CloneableEnvFilter;
28#[allow(deprecated)] use time::Instant;
30use walkdir::WalkDir;
31
/// Runs sqllogictest files and reports their outcomes.
#[derive(clap::Parser)]
struct Args {
    /// Increase output verbosity; may be repeated. With at least one `-v`, the
    /// outcome of every file is printed, not only of files with failures.
    #[clap(short = 'v', long = "verbose", action = ArgAction::Count)]
    verbosity: u8,
    /// Don't exit with a failure status when tests fail; in a JUnit report such
    /// files are recorded as successes.
    #[clap(long)]
    no_fail: bool,
    /// Prefix every line of output with a UTC timestamp.
    #[clap(long)]
    timestamps: bool,
    /// Rewrite the given files using the runner's rewrite mode instead of
    /// running them as tests. Cannot be combined with `--junit-report` or with
    /// `-` (stdin) as a path.
    #[clap(long)]
    rewrite_results: bool,
    /// Write a JUnit XML report with one test case per file to FILE.
    #[clap(long, value_name = "FILE")]
    junit_report: Option<PathBuf>,
    /// Postgres URL handed to the test runner.
    #[clap(long)]
    postgres_url: String,
    // NOTE(review): forwarded verbatim as `RunConfig::prefix`; its meaning is
    // defined by the runner and is not visible in this file.
    #[clap(long, default_value = "sqllogictest")]
    prefix: String,
    /// Test files or directories. Directories are walked recursively and every
    /// regular file found is run as a test file.
    #[clap(value_name = "PATH", required = true)]
    paths: Vec<String>,
    // NOTE(review): `fail_fast`, the `auto_*` flags and `enable_table_keys` are
    // forwarded unchanged to `RunConfig`; their semantics live in the runner.
    #[clap(long)]
    fail_fast: bool,
    #[clap(long)]
    auto_index_tables: bool,
    #[clap(long)]
    auto_index_selects: bool,
    #[clap(long)]
    auto_transactions: bool,
    #[clap(long)]
    enable_table_keys: bool,
    /// Run only shard N (zero-based) of the PATH arguments; requires
    /// `--shard-count`. PATH arguments are assigned to shards round-robin, in
    /// the order given (files inside directories are not sharded individually).
    #[clap(long, requires = "shard_count", value_name = "N")]
    shard: Option<usize>,
    /// Total number of shards the PATH arguments are split into; requires
    /// `--shard`.
    #[clap(long, requires = "shard", value_name = "N")]
    shard_count: Option<usize>,
    // NOTE(review): forwarded unchanged to `RunConfig`; presumably wraps
    // processes launched by the orchestrator — confirm in the runner.
    #[clap(long, env = "ORCHESTRATOR_PROCESS_WRAPPER")]
    orchestrator_process_wrapper: Option<String>,
    // NOTE(review): `replica_size` and `replicas` are forwarded unchanged to
    // `RunConfig`.
    #[clap(long, default_value = "scale=1,workers=2")]
    replica_size: String,
    #[clap(long, default_value = "1")]
    replicas: usize,
    /// Default values for system parameters, given as key/value pairs
    /// separated by `;`; may be repeated. Parameters the harness itself
    /// requires (currently the logical-compaction-window flag) must not be
    /// given a conflicting value.
    #[clap(
        long,
        env = "SYSTEM_PARAMETER_DEFAULT",
        action = ArgAction::Append,
        value_delimiter = ';'
    )]
    system_parameter_default: Vec<KeyValueArg<String, String>>,
    /// Tracing filter used while starting up.
    #[clap(
        long,
        env = "LOG_FILTER",
        value_name = "FILTER",
        default_value = "warn"
    )]
    pub log_filter: CloneableEnvFilter,
}
112
113#[tokio::main]
114async fn main() -> ExitCode {
115 mz_ore::panic::install_enhanced_handler();
116
117 let args: Args = cli::parse_args(CliConfig {
118 env_prefix: Some("MZ_"),
119 enable_version_flag: false,
120 });
121
122 let tracing_args = TracingCliArgs {
123 startup_log_filter: args.log_filter.clone(),
124 ..Default::default()
125 };
126 let tracing_handle = tracing_args
127 .configure_tracing(
128 StaticTracingConfig {
129 service_name: "sqllogictest",
130 build_info: mz_environmentd::BUILD_INFO,
131 },
132 MetricsRegistry::new(),
133 )
134 .await
135 .unwrap();
136
137 let required_system_defaults: Vec<_> = [(ENABLE_LOGICAL_COMPACTION_WINDOW.flag, "true")].into();
141 let mut system_parameter_defaults: BTreeMap<_, _> = args
142 .system_parameter_default
143 .clone()
144 .into_iter()
145 .map(|kv| (kv.key, kv.value))
146 .collect();
147 for (var, value) in required_system_defaults {
148 let parse = |value| {
149 var.parse(VarInput::Flat(value))
150 .expect("invalid value")
151 .format()
152 };
153 let value = parse(value);
154 match system_parameter_defaults.entry(var.name().to_string()) {
155 Entry::Vacant(entry) => {
156 entry.insert(value);
157 }
158 Entry::Occupied(entry) => {
159 assert_eq!(
160 value,
161 parse(entry.get()),
162 "sqllogictest test requires {} to have a value of {}",
163 var.name(),
164 value
165 )
166 }
167 }
168 }
169
170 let config = RunConfig {
171 stdout: &OutputStream::new(io::stdout(), args.timestamps),
172 stderr: &OutputStream::new(io::stderr(), args.timestamps),
173 verbosity: args.verbosity,
174 postgres_url: args.postgres_url.clone(),
175 prefix: args.prefix.clone(),
176 no_fail: args.no_fail,
177 fail_fast: args.fail_fast,
178 auto_index_tables: args.auto_index_tables,
179 auto_index_selects: args.auto_index_selects,
180 auto_transactions: args.auto_transactions,
181 enable_table_keys: args.enable_table_keys,
182 orchestrator_process_wrapper: args.orchestrator_process_wrapper.clone(),
183 tracing: tracing_args.clone(),
184 tracing_handle,
185 system_parameter_defaults,
186 persist_dir: match tempfile::tempdir() {
187 Ok(t) => t,
188 Err(e) => {
189 eprintln!("error creating state dir: {e}");
190 return ExitCode::FAILURE;
191 }
192 },
193 replicas: args.replicas,
194 replica_size: args.replica_size.clone(),
195 };
196
197 if let (Some(shard), Some(shard_count)) = (args.shard, args.shard_count) {
198 if shard != 0 || shard_count != 1 {
199 eprintln!("Shard: {}/{}", shard + 1, shard_count);
200 }
201 }
202
203 if args.rewrite_results {
204 return rewrite(&config, args).await;
205 }
206
207 let mut junit = match args.junit_report {
208 Some(filename) => match File::create(&filename) {
209 Ok(file) => Some((file, junit_report::TestSuite::new("sqllogictest"))),
210 Err(err) => {
211 writeln!(config.stderr, "creating {}: {}", filename.display(), err);
212 return ExitCode::FAILURE;
213 }
214 },
215 None => None,
216 };
217 let mut outcomes = Outcomes::default();
218 let mut runner = Runner::start(&config).await.unwrap();
219 let mut paths = args.paths;
220
221 if let (Some(shard), Some(shard_count)) = (args.shard, args.shard_count) {
222 paths = paths.into_iter().skip(shard).step_by(shard_count).collect();
223 }
224
225 for path in &paths {
226 for entry in WalkDir::new(path) {
227 match entry {
228 Ok(entry) if entry.file_type().is_file() => {
229 #[allow(deprecated)] let start_time = Instant::now();
231 match runner::run_file(&mut runner, entry.path()).await {
232 Ok(o) => {
233 if o.any_failed() || config.verbosity >= 1 {
234 writeln!(
235 config.stdout,
236 "{}",
237 util::indent(&o.display(config.no_fail, false).to_string(), 4)
238 );
239 }
240 if let Some((_, junit_suite)) = &mut junit {
241 let mut test_case = if o.any_failed() && !args.no_fail {
242 let mut result = junit_report::TestCase::failure(
243 &entry.path().to_string_lossy(),
244 start_time.elapsed(),
245 "failure",
246 "",
247 );
248 result.system_out = Some(
250 o.display(false, true)
251 .to_string()
252 .trim_end_matches('\n')
253 .to_string(),
254 );
255 result
256 } else {
257 junit_report::TestCase::success(
258 &entry.path().to_string_lossy(),
259 start_time.elapsed(),
260 )
261 };
262 test_case.set_classname("sqllogictest");
263 junit_suite.add_testcase(test_case);
264 }
265 outcomes += o;
266 }
267 Err(err) => {
268 writeln!(
269 config.stderr,
270 "FAIL: error: running file {}: {}",
271 entry.file_name().to_string_lossy(),
272 err
273 );
274 return ExitCode::FAILURE;
275 }
276 }
277 }
278 Ok(_) => (),
279 Err(err) => {
280 writeln!(
281 config.stderr,
282 "FAIL: error: reading directory entry: {}",
283 err
284 );
285 return ExitCode::FAILURE;
286 }
287 }
288 }
289 }
290
291 writeln!(config.stdout, "{}", outcomes.display(config.no_fail, false));
292
293 if let Some((mut junit_file, junit_suite)) = junit {
294 let report = junit_report::ReportBuilder::new()
295 .add_testsuite(junit_suite)
296 .build();
297 match report.write_xml(&mut junit_file) {
298 Ok(()) => (),
299 Err(err) => {
300 writeln!(
301 config.stderr,
302 "error: unable to write junit report: {}",
303 err
304 );
305 return ExitCode::from(2);
306 }
307 }
308 }
309
310 if outcomes.any_failed() && !args.no_fail {
311 return ExitCode::FAILURE;
312 }
313 ExitCode::SUCCESS
314}
315
316async fn rewrite(config: &RunConfig<'_>, args: Args) -> ExitCode {
317 if args.junit_report.is_some() {
318 writeln!(
319 config.stderr,
320 "--rewrite-results is not compatible with --junit-report"
321 );
322 return ExitCode::FAILURE;
323 }
324
325 if args.paths.iter().any(|path| path == "-") {
326 writeln!(config.stderr, "--rewrite-results cannot be used with stdin");
327 return ExitCode::FAILURE;
328 }
329
330 let mut runner = Runner::start(config).await.unwrap();
331 let mut paths = args.paths;
332
333 if let (Some(shard), Some(shard_count)) = (args.shard, args.shard_count) {
334 paths = paths.into_iter().skip(shard).step_by(shard_count).collect();
335 }
336
337 for path in paths {
338 for entry in WalkDir::new(path) {
339 match entry {
340 Ok(entry) => {
341 if entry.file_type().is_file() {
342 if let Err(err) = runner::rewrite_file(&mut runner, entry.path()).await {
343 writeln!(config.stderr, "FAIL: error: rewriting file: {}", err);
344 return ExitCode::FAILURE;
345 }
346 }
347 }
348 Err(err) => {
349 writeln!(
350 config.stderr,
351 "FAIL: error: reading directory entry: {}",
352 err
353 );
354 return ExitCode::FAILURE;
355 }
356 }
357 }
358 }
359 ExitCode::SUCCESS
360}
361
/// Writer wrapper used for the harness's stdout and stderr. It can prefix
/// output lines with a timestamp.
///
/// State lives in `RefCell`s, so output can be written through a shared
/// `&OutputStream`.
struct OutputStream<W> {
    /// Sink that every emitted byte is forwarded to.
    inner: RefCell<W>,
    /// Whether the next emitted text begins a new line that still needs a timestamp.
    need_timestamp: RefCell<bool>,
    /// Whether timestamp prefixes are enabled at all.
    timestamps: bool,
}

impl<W> OutputStream<W>
where
    W: Write,
{
    /// Wraps `inner`. Nothing has been written yet, so the first line is
    /// marked as needing a timestamp.
    fn new(inner: W, timestamps: bool) -> Self {
        Self {
            timestamps,
            need_timestamp: RefCell::new(true),
            inner: RefCell::new(inner),
        }
    }

    /// Writes `s` verbatim to the wrapped sink.
    ///
    /// # Panics
    ///
    /// Panics if the underlying write fails.
    fn emit_str(&self, s: &str) {
        let bytes = s.as_bytes();
        let mut sink = self.inner.borrow_mut();
        sink.write_all(bytes).unwrap();
    }
}
384
385impl<W> WriteFmt for OutputStream<W>
386where
387 W: Write,
388{
389 fn write_fmt(&self, fmt: fmt::Arguments<'_>) {
390 let s = format!("{}", fmt);
391 if self.timestamps {
392 let timestamp = Utc::now();
395 let timestamp_str = timestamp.format("%Y-%m-%d %H:%M:%S.%f %Z");
396
397 if self.need_timestamp.replace(false) {
400 self.emit_str(&format!("[{}] ", timestamp_str));
401 }
402
403 let (s, last_was_timestamp) = match s.strip_suffix('\n') {
406 None => (&*s, false),
407 Some(s) => (s, true),
408 };
409 self.emit_str(&s.replace('\n', &format!("\n[{}] ", timestamp_str)));
410
411 if last_was_timestamp {
416 *self.need_timestamp.borrow_mut() = true;
417 self.emit_str("\n");
418 }
419 } else {
420 self.emit_str(&s)
421 }
422 }
423}