1use std::cell::RefCell;
11use std::collections::BTreeMap;
12use std::collections::btree_map::Entry;
13use std::fmt;
14use std::fs::File;
15use std::io::{self, Write};
16use std::path::PathBuf;
17use std::process::ExitCode;
18
19use chrono::Utc;
20use clap::ArgAction;
21use mz_orchestrator_tracing::{StaticTracingConfig, TracingCliArgs};
22use mz_ore::cli::{self, CliConfig, KeyValueArg};
23use mz_ore::metrics::MetricsRegistry;
24use mz_sql::session::vars::{
25 DISK_CLUSTER_REPLICAS_DEFAULT, ENABLE_LOGICAL_COMPACTION_WINDOW, Var, VarInput,
26};
27use mz_sqllogictest::runner::{self, Outcomes, RunConfig, Runner, WriteFmt};
28use mz_sqllogictest::util;
29use mz_tracing::CloneableEnvFilter;
30#[allow(deprecated)] use time::Instant;
32use walkdir::WalkDir;
33
/// Runs sqllogictest files.
#[derive(clap::Parser)]
struct Args {
    /// Increase verbosity; may be repeated (e.g. `-vv`).
    ///
    /// With at least one `-v`, the outcome of every file is printed, not only
    /// the outcomes of files that had failures.
    #[clap(short = 'v', long = "verbose", action = ArgAction::Count)]
    verbosity: u8,
    /// Exit successfully even if some tests fail.
    #[clap(long)]
    no_fail: bool,
    /// Prefix every line of output with the current UTC time.
    #[clap(long)]
    timestamps: bool,
    /// Rewrite the expected results in the given files instead of checking them.
    ///
    /// Cannot be combined with `--junit-report` or with `-` (stdin) as a PATH.
    #[clap(long)]
    rewrite_results: bool,
    /// Write a JUnit XML report with one test case per file to FILE.
    #[clap(long, value_name = "FILE")]
    junit_report: Option<PathBuf>,
    /// Postgres connection URL handed to the test runner.
    #[clap(long)]
    postgres_url: String,
    /// Test files or directories to run; directories are searched recursively.
    #[clap(value_name = "PATH", required = true)]
    paths: Vec<String>,
    // The next four flags, and `enable_table_keys`, are forwarded verbatim to
    // `RunConfig` in `main`; their behavior is defined by
    // `mz_sqllogictest::runner`, not in this file.
    #[clap(long)]
    fail_fast: bool,
    #[clap(long)]
    auto_index_tables: bool,
    #[clap(long)]
    auto_index_selects: bool,
    #[clap(long)]
    auto_transactions: bool,
    #[clap(long)]
    enable_table_keys: bool,
    /// Zero-based index of the shard to run; requires `--shard-count`.
    ///
    /// Shards are taken round-robin over the PATH arguments (not over the
    /// files found inside directories).
    #[clap(long, requires = "shard_count", value_name = "N")]
    shard: Option<usize>,
    /// Total number of shards the PATH arguments are divided into; requires
    /// `--shard`.
    #[clap(long, requires = "shard", value_name = "N")]
    shard_count: Option<usize>,
    /// Process wrapper forwarded to the test runner's orchestrator settings.
    #[clap(long, env = "ORCHESTRATOR_PROCESS_WRAPPER")]
    orchestrator_process_wrapper: Option<String>,
    /// Number of replicas requested from the test runner.
    #[clap(long, default_value = "2")]
    replicas: usize,
    /// NAME=VALUE overrides for system parameter defaults.
    ///
    /// May be repeated or given as a `;`-separated list. Some parameters are
    /// required to have fixed values; a conflicting override aborts the run.
    #[clap(
        long,
        env = "SYSTEM_PARAMETER_DEFAULT",
        action = ArgAction::Append,
        value_delimiter = ';'
    )]
    system_parameter_default: Vec<KeyValueArg<String, String>>,
    /// Tracing filter applied to log output at startup.
    #[clap(
        long,
        env = "LOG_FILTER",
        value_name = "FILTER",
        default_value = "warn"
    )]
    pub log_filter: CloneableEnvFilter,
}
108
109#[tokio::main]
110async fn main() -> ExitCode {
111 mz_ore::panic::install_enhanced_handler();
112
113 let args: Args = cli::parse_args(CliConfig {
114 env_prefix: Some("MZ_"),
115 enable_version_flag: false,
116 });
117
118 let tracing_args = TracingCliArgs {
119 startup_log_filter: args.log_filter.clone(),
120 ..Default::default()
121 };
122 let (tracing_handle, _tracing_guard) = tracing_args
123 .configure_tracing(
124 StaticTracingConfig {
125 service_name: "sqllogictest",
126 build_info: mz_environmentd::BUILD_INFO,
127 },
128 MetricsRegistry::new(),
129 )
130 .await
131 .unwrap();
132
133 let required_system_defaults: Vec<_> = [
137 (&DISK_CLUSTER_REPLICAS_DEFAULT, "true"),
138 (ENABLE_LOGICAL_COMPACTION_WINDOW.flag, "true"),
139 ]
140 .into();
141 let mut system_parameter_defaults: BTreeMap<_, _> = args
142 .system_parameter_default
143 .clone()
144 .into_iter()
145 .map(|kv| (kv.key, kv.value))
146 .collect();
147 for (var, value) in required_system_defaults {
148 let parse = |value| {
149 var.parse(VarInput::Flat(value))
150 .expect("invalid value")
151 .format()
152 };
153 let value = parse(value);
154 match system_parameter_defaults.entry(var.name().to_string()) {
155 Entry::Vacant(entry) => {
156 entry.insert(value);
157 }
158 Entry::Occupied(entry) => {
159 assert_eq!(
160 value,
161 parse(entry.get()),
162 "sqllogictest test requires {} to have a value of {}",
163 var.name(),
164 value
165 )
166 }
167 }
168 }
169
170 let config = RunConfig {
171 stdout: &OutputStream::new(io::stdout(), args.timestamps),
172 stderr: &OutputStream::new(io::stderr(), args.timestamps),
173 verbosity: args.verbosity,
174 postgres_url: args.postgres_url.clone(),
175 no_fail: args.no_fail,
176 fail_fast: args.fail_fast,
177 auto_index_tables: args.auto_index_tables,
178 auto_index_selects: args.auto_index_selects,
179 auto_transactions: args.auto_transactions,
180 enable_table_keys: args.enable_table_keys,
181 orchestrator_process_wrapper: args.orchestrator_process_wrapper.clone(),
182 tracing: tracing_args.clone(),
183 tracing_handle,
184 system_parameter_defaults,
185 persist_dir: match tempfile::tempdir() {
186 Ok(t) => t,
187 Err(e) => {
188 eprintln!("error creating state dir: {e}");
189 return ExitCode::FAILURE;
190 }
191 },
192 replicas: args.replicas,
193 };
194
195 if let (Some(shard), Some(shard_count)) = (args.shard, args.shard_count) {
196 if shard != 0 || shard_count != 1 {
197 eprintln!("Shard: {}/{}", shard + 1, shard_count);
198 }
199 }
200
201 if args.rewrite_results {
202 return rewrite(&config, args).await;
203 }
204
205 let mut junit = match args.junit_report {
206 Some(filename) => match File::create(&filename) {
207 Ok(file) => Some((file, junit_report::TestSuite::new("sqllogictest"))),
208 Err(err) => {
209 writeln!(config.stderr, "creating {}: {}", filename.display(), err);
210 return ExitCode::FAILURE;
211 }
212 },
213 None => None,
214 };
215 let mut outcomes = Outcomes::default();
216 let mut runner = Runner::start(&config).await.unwrap();
217 let mut paths = args.paths;
218
219 if let (Some(shard), Some(shard_count)) = (args.shard, args.shard_count) {
220 paths = paths.into_iter().skip(shard).step_by(shard_count).collect();
221 }
222
223 for path in &paths {
224 for entry in WalkDir::new(path) {
225 match entry {
226 Ok(entry) if entry.file_type().is_file() => {
227 #[allow(deprecated)] let start_time = Instant::now();
229 match runner::run_file(&mut runner, entry.path()).await {
230 Ok(o) => {
231 if o.any_failed() || config.verbosity >= 1 {
232 writeln!(
233 config.stdout,
234 "{}",
235 util::indent(&o.display(config.no_fail, false).to_string(), 4)
236 );
237 }
238 if let Some((_, junit_suite)) = &mut junit {
239 let mut test_case = if o.any_failed() && !args.no_fail {
240 let mut result = junit_report::TestCase::failure(
241 &entry.path().to_string_lossy(),
242 start_time.elapsed(),
243 "failure",
244 "",
245 );
246 result.system_out = Some(
248 o.display(false, true)
249 .to_string()
250 .trim_end_matches('\n')
251 .to_string(),
252 );
253 result
254 } else {
255 junit_report::TestCase::success(
256 &entry.path().to_string_lossy(),
257 start_time.elapsed(),
258 )
259 };
260 test_case.set_classname("sqllogictest");
261 junit_suite.add_testcase(test_case);
262 }
263 outcomes += o;
264 }
265 Err(err) => {
266 writeln!(
267 config.stderr,
268 "FAIL: error: running file {}: {}",
269 entry.file_name().to_string_lossy(),
270 err
271 );
272 return ExitCode::FAILURE;
273 }
274 }
275 }
276 Ok(_) => (),
277 Err(err) => {
278 writeln!(
279 config.stderr,
280 "FAIL: error: reading directory entry: {}",
281 err
282 );
283 return ExitCode::FAILURE;
284 }
285 }
286 }
287 }
288
289 writeln!(config.stdout, "{}", outcomes.display(config.no_fail, false));
290
291 if let Some((mut junit_file, junit_suite)) = junit {
292 let report = junit_report::ReportBuilder::new()
293 .add_testsuite(junit_suite)
294 .build();
295 match report.write_xml(&mut junit_file) {
296 Ok(()) => (),
297 Err(err) => {
298 writeln!(
299 config.stderr,
300 "error: unable to write junit report: {}",
301 err
302 );
303 return ExitCode::from(2);
304 }
305 }
306 }
307
308 if outcomes.any_failed() && !args.no_fail {
309 return ExitCode::FAILURE;
310 }
311 ExitCode::SUCCESS
312}
313
314async fn rewrite(config: &RunConfig<'_>, args: Args) -> ExitCode {
315 if args.junit_report.is_some() {
316 writeln!(
317 config.stderr,
318 "--rewrite-results is not compatible with --junit-report"
319 );
320 return ExitCode::FAILURE;
321 }
322
323 if args.paths.iter().any(|path| path == "-") {
324 writeln!(config.stderr, "--rewrite-results cannot be used with stdin");
325 return ExitCode::FAILURE;
326 }
327
328 let mut runner = Runner::start(config).await.unwrap();
329 let mut paths = args.paths;
330
331 if let (Some(shard), Some(shard_count)) = (args.shard, args.shard_count) {
332 paths = paths.into_iter().skip(shard).step_by(shard_count).collect();
333 }
334
335 for path in paths {
336 for entry in WalkDir::new(path) {
337 match entry {
338 Ok(entry) => {
339 if entry.file_type().is_file() {
340 if let Err(err) = runner::rewrite_file(&mut runner, entry.path()).await {
341 writeln!(config.stderr, "FAIL: error: rewriting file: {}", err);
342 return ExitCode::FAILURE;
343 }
344 }
345 }
346 Err(err) => {
347 writeln!(
348 config.stderr,
349 "FAIL: error: reading directory entry: {}",
350 err
351 );
352 return ExitCode::FAILURE;
353 }
354 }
355 }
356 }
357 ExitCode::SUCCESS
358}
359
/// Writer used for the harness's stdout and stderr.
///
/// When `timestamps` is set, output lines are prefixed with the current UTC
/// time. `need_timestamp` records whether the next write begins a new line
/// that still needs its prefix. Both mutable pieces sit behind `RefCell`s so
/// the stream can be written through a shared reference.
struct OutputStream<W> {
    /// Whether lines get a timestamp prefix at all.
    timestamps: bool,
    /// `true` when the next write starts a fresh line awaiting its prefix.
    need_timestamp: RefCell<bool>,
    /// The underlying byte sink.
    inner: RefCell<W>,
}
365
366impl<W> OutputStream<W>
367where
368 W: Write,
369{
370 fn new(inner: W, timestamps: bool) -> OutputStream<W> {
371 OutputStream {
372 inner: RefCell::new(inner),
373 need_timestamp: RefCell::new(true),
374 timestamps,
375 }
376 }
377
378 fn emit_str(&self, s: &str) {
379 self.inner.borrow_mut().write_all(s.as_bytes()).unwrap();
380 }
381}
382
383impl<W> WriteFmt for OutputStream<W>
384where
385 W: Write,
386{
387 fn write_fmt(&self, fmt: fmt::Arguments<'_>) {
388 let s = format!("{}", fmt);
389 if self.timestamps {
390 let timestamp = Utc::now();
393 let timestamp_str = timestamp.format("%Y-%m-%d %H:%M:%S.%f %Z");
394
395 if self.need_timestamp.replace(false) {
398 self.emit_str(&format!("[{}] ", timestamp_str));
399 }
400
401 let (s, last_was_timestamp) = match s.strip_suffix('\n') {
404 None => (&*s, false),
405 Some(s) => (s, true),
406 };
407 self.emit_str(&s.replace('\n', &format!("\n[{}] ", timestamp_str)));
408
409 if last_was_timestamp {
414 *self.need_timestamp.borrow_mut() = true;
415 self.emit_str("\n");
416 }
417 } else {
418 self.emit_str(&s)
419 }
420 }
421}