use std::collections::{BTreeMap, BTreeSet};
use std::convert::Infallible;
use std::error::Error;
use std::fs::File;
use std::path::{Path, PathBuf};
use std::time::Duration;
use std::{io, process};

use aws_credential_types::Credentials;
use aws_types::region::Region;
use clap::ArgAction;
use globset::GlobBuilder;
use itertools::Itertools;
use mz_build_info::{BuildInfo, build_info};
use mz_catalog::config::ClusterReplicaSizeMap;
use mz_license_keys::ValidatedLicenseKey;
use mz_ore::cli::{self, CliConfig};
use mz_ore::path::PathExt;
use mz_ore::url::SensitiveUrl;
use mz_testdrive::{CatalogConfig, Config, ConsistencyCheckLevel};
use rand::rngs::StdRng;
use rand::seq::SliceRandom;
use rand::{Rng, SeedableRng};
#[allow(deprecated)]
use time::Instant;
use tracing::info;
use tracing_subscriber::filter::EnvFilter;
use url::Url;
use walkdir::WalkDir;
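
/// Prints a message to stderr and exits the process with status 1.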
macro_rules! die {
    ($($e:expr),*) => {{
        eprintln!($($e),*);
        process::exit(1);
    }}
}

pub const BUILD_INFO: BuildInfo = build_info!();

#[derive(clap::Parser)]
struct Args {
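    // Testdrive options.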
    #[clap(
        long,
        env = "VAR",
        use_value_delimiter = true,
        value_name = "NAME=VALUE"
    )]
    var: Vec<String>,
    #[clap(long, value_name = "N", action = ArgAction::Set)]
    seed: Option<u32>,
    #[clap(long, action = ArgAction::SetTrue)]
    no_reset: bool,
    #[clap(long, value_name = "PATH")]
    temp_dir: Option<String>,
    #[clap(long, value_name = "SOURCE")]
    source: Option<String>,
    #[clap(long, value_parser = humantime::parse_duration, default_value = "30s", value_name = "DURATION")]
    default_timeout: Duration,
    #[clap(long, default_value = "18446744073709551615", value_name = "N")]
    default_max_tries: usize,
    #[clap(long, value_parser = humantime::parse_duration, default_value = "50ms", value_name = "DURATION")]
    initial_backoff: Duration,
    #[clap(long, default_value = "1.5", value_name = "FACTOR")]
    backoff_factor: f64,
    #[clap(long, default_value = "10", value_name = "N")]
    max_errors: usize,
    #[clap(long, default_value = "18446744073709551615", value_name = "N")]
    max_tests: usize,
    #[clap(long)]
    shuffle_tests: bool,
    #[clap(long, requires = "shard_count", value_name = "N")]
    shard: Option<usize>,
    #[clap(long, requires = "shard", value_name = "N")]
    shard_count: Option<usize>,
    #[clap(long, value_name = "FILE")]
    junit_report: Option<PathBuf>,
    #[clap(long, default_value_t = ConsistencyCheckLevel::default(), value_enum)]
    consistency_checks: ConsistencyCheckLevel,
    #[clap(
        long,
        env = "LOG_FILTER",
        value_name = "FILTER",
        default_value = "librdkafka=off,mz_kafka_util::client=off,warn"
    )]
    log_filter: String,
    globs: Vec<String>,
    #[clap(long)]
    rewrite_results: bool,
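
    // Materialize options.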
    #[clap(
        long,
        default_value = "postgres://materialize@localhost:6875",
        value_name = "URL",
        action = ArgAction::Set,
    )]
    materialize_url: tokio_postgres::Config,
    #[clap(
        long,
        default_value = "postgres://materialize@localhost:6877",
        value_name = "INTERNAL_URL",
        action = ArgAction::Set,
    )]
    materialize_internal_url: tokio_postgres::Config,
    #[clap(long)]
    materialize_use_https: bool,
    #[clap(long, default_value = "6876", value_name = "PORT")]
    materialize_http_port: u16,
    #[clap(long, default_value = "6878", value_name = "PORT")]
    materialize_internal_http_port: u16,
    #[clap(long, value_name = "KEY=VAL", value_parser = parse_kafka_opt)]
    materialize_param: Vec<(String, String)>,
    #[clap(long)]
    validate_catalog_store: bool,

    #[clap(
        long,
        value_name = "PERSIST_CONSENSUS_URL",
        required_if_eq("validate_catalog_store", "true"),
        action = ArgAction::Set,
    )]
    persist_consensus_url: Option<SensitiveUrl>,
    #[clap(
        long,
        value_name = "PERSIST_BLOB_URL",
        required_if_eq("validate_catalog_store", "true")
    )]
    persist_blob_url: Option<SensitiveUrl>,
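
    // Kafka and Confluent Schema Registry options.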
    #[clap(
        long,
        value_name = "ENCRYPTION://HOST:PORT",
        default_value = "localhost:9092",
        action = ArgAction::Set,
    )]
    kafka_addr: String,
    #[clap(long, default_value = "1", value_name = "N")]
    kafka_default_partitions: usize,
    #[clap(long, env = "KAFKA_OPTION", use_value_delimiter = true, value_name = "KEY=VAL", value_parser = parse_kafka_opt)]
    kafka_option: Vec<(String, String)>,
    #[clap(long, value_name = "URL", default_value = "http://localhost:8081")]
    schema_registry_url: Url,
    #[clap(long, value_name = "PATH")]
    cert: Option<String>,
    #[clap(long, value_name = "PASSWORD")]
    cert_password: Option<String>,
    #[clap(long, value_name = "USERNAME")]
    ccsr_username: Option<String>,
    #[clap(long, value_name = "PASSWORD")]
    ccsr_password: Option<String>,
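
    // AWS options.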
    #[clap(
        long,
        conflicts_with = "aws_endpoint",
        value_name = "REGION",
        env = "AWS_REGION"
    )]
    aws_region: Option<String>,
    #[clap(
        long,
        conflicts_with = "aws_region",
        value_name = "URL",
        env = "AWS_ENDPOINT"
    )]
    aws_endpoint: Option<String>,

    #[clap(
        long,
        value_name = "KEY_ID",
        default_value = "dummy-access-key-id",
        env = "AWS_ACCESS_KEY_ID"
    )]
    aws_access_key_id: String,

    #[clap(
        long,
        value_name = "KEY",
        default_value = "dummy-secret-access-key",
        env = "AWS_SECRET_ACCESS_KEY"
    )]
    aws_secret_access_key: String,
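
    // Fivetran options.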
    #[clap(
        long,
        value_name = "FIVETRAN_DESTINATION_URL",
        default_value = "http://localhost:6874"
    )]
    fivetran_destination_url: String,
    #[clap(
        long,
        value_name = "FIVETRAN_DESTINATION_FILES_PATH",
        default_value = "/tmp"
    )]
    fivetran_destination_files_path: String,
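
    // Cluster replica sizes and CI license key.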
    #[clap(long, env = "CLUSTER_REPLICA_SIZES")]
    cluster_replica_sizes: String,

    #[clap(long, env = "MZ_CI_LICENSE_KEY")]
    license_key: Option<String>,
}

#[tokio::main]
async fn main() {
    let args: Args = cli::parse_args(CliConfig::default());

    tracing_subscriber::fmt()
        .with_env_filter(EnvFilter::from(args.log_filter))
        .with_writer(io::stdout)
        .init();
    let (aws_config, aws_account) = match args.aws_region {
        Some(region) => {
            let config = mz_aws_util::defaults()
                .region(Region::new(region))
                .load()
                .await;
            let account = async {
                let sts_client = aws_sdk_sts::Client::new(&config);
                Ok::<_, Box<dyn Error>>(
                    sts_client
                        .get_caller_identity()
                        .send()
                        .await?
                        .account
                        .ok_or("account ID is missing")?,
                )
            };
            let account = account
                .await
                .unwrap_or_else(|e| die!("testdrive: failed fetching AWS account ID: {}", e));
            (config, account)
        }
        None => {
            let endpoint = args
                .aws_endpoint
                .unwrap_or_else(|| "http://localhost:4566".parse().unwrap());
            let config = mz_aws_util::defaults()
                .region(Region::new("us-east-1"))
                .credentials_provider(Credentials::from_keys(
                    args.aws_access_key_id,
                    args.aws_secret_access_key,
                    None,
                ))
                .endpoint_url(endpoint)
                .load()
                .await;
            let account = "000000000000".into();
            (config, account)
        }
    };

    info!(
        "Configuration parameters:
    Kafka address: {}
    Schema registry URL: {}
    Materialize host: {:?}
    Error limit: {}
    Consistency check level: {:?}",
        args.kafka_addr,
        args.schema_registry_url,
        args.materialize_url.get_hosts()[0],
        args.max_errors,
        args.consistency_checks,
    );
    if let (Some(shard), Some(shard_count)) = (args.shard, args.shard_count) {
        if shard != 0 || shard_count != 1 {
            eprintln!("    Shard: {}/{}", shard + 1, shard_count);
        }
    }
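
    // Parse the --var NAME=VALUE arguments into a map of variables for the scripts.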
    let mut arg_vars = BTreeMap::new();
    for arg in &args.var {
        let mut parts = arg.splitn(2, '=');
        let name = parts.next().expect("Clap ensures all --vars get a value");
        let val = match parts.next() {
            Some(val) => val,
            None => {
                eprintln!("No =VALUE for --var {}", name);
                process::exit(1)
            }
        };
        arg_vars.insert(name.to_string(), val.to_string());
    }
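
    // Validate the license key if one was provided; otherwise fall back to the default.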
    let license_key = if let Some(license_key_text) = args.license_key {
        mz_license_keys::validate(license_key_text.trim(), "00000000-0000-0000-0000-000000000000")
            .unwrap_or_else(|e| die!("testdrive: failed to validate license key: {}", e))
    } else {
        ValidatedLicenseKey::default()
    };

    let cluster_replica_sizes = ClusterReplicaSizeMap::parse_from_str(
        &args.cluster_replica_sizes,
        !license_key.allow_credit_consumption_override,
    )
    .unwrap_or_else(|e| die!("testdrive: failed to parse replica size map: {}", e));
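
    // The catalog can only be validated against the persist consensus and blob stores,
    // so a catalog config is built only when --validate-catalog-store is set.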
    let materialize_catalog_config = if args.validate_catalog_store {
        Some(CatalogConfig {
            persist_consensus_url: args
                .persist_consensus_url
                .clone()
                .expect("required for persist catalog"),
            persist_blob_url: args
                .persist_blob_url
                .clone()
                .expect("required for persist catalog"),
        })
    } else {
        None
    };
    let config = Config {
        arg_vars,
        seed: args.seed,
        reset: !args.no_reset,
        temp_dir: args.temp_dir,
        source: args.source,
        default_timeout: args.default_timeout,
        default_max_tries: args.default_max_tries,
        initial_backoff: args.initial_backoff,
        backoff_factor: args.backoff_factor,
        consistency_checks: args.consistency_checks,
        rewrite_results: args.rewrite_results,

        materialize_pgconfig: args.materialize_url,
        materialize_cluster_replica_sizes: cluster_replica_sizes,
        materialize_internal_pgconfig: args.materialize_internal_url,
        materialize_http_port: args.materialize_http_port,
        materialize_internal_http_port: args.materialize_internal_http_port,
        materialize_use_https: args.materialize_use_https,
        materialize_params: args.materialize_param,
        materialize_catalog_config,
        build_info: &BUILD_INFO,

        persist_consensus_url: args.persist_consensus_url,
        persist_blob_url: args.persist_blob_url,

        kafka_addr: args.kafka_addr,
        kafka_default_partitions: args.kafka_default_partitions,
        kafka_opts: args.kafka_option,
        schema_registry_url: args.schema_registry_url,
        cert_path: args.cert,
        cert_password: args.cert_password,
        ccsr_password: args.ccsr_password,
        ccsr_username: args.ccsr_username,

        aws_config,
        aws_account,

        fivetran_destination_url: args.fivetran_destination_url,
        fivetran_destination_files_path: args.fivetran_destination_files_path,
    };

    if args.junit_report.is_some() && args.rewrite_results {
        eprintln!("--rewrite-results is not compatible with --junit-report");
        process::exit(1);
    }
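
    // Determine which files to run: stdin ("-") when no globs are given, otherwise
    // every file under the current directory that matches at least one glob.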
    let mut files = vec![];
    if args.globs.is_empty() {
        files.push(PathBuf::from("-"))
    } else {
        let all_files = WalkDir::new(".")
            .sort_by_file_name()
            .into_iter()
            .map(|f| f.map(|f| f.path().clean()))
            .collect::<Result<Vec<_>, _>>()
            .unwrap_or_else(|e| die!("testdrive: failed walking directory: {}", e));
        for glob in args.globs {
            if glob == "-" {
                files.push(glob.into());
                continue;
            }
            let matcher = GlobBuilder::new(&Path::new(&glob).clean().to_string_lossy())
                .literal_separator(true)
                .build()
                .unwrap_or_else(|e| die!("testdrive: invalid glob syntax: {}: {}", glob, e))
                .compile_matcher();
            let mut found = false;
            for file in &all_files {
                if matcher.is_match(file) {
                    files.push(file.clone());
                    found = true;
                }
            }
            if !found {
                die!("testdrive: glob did not match any patterns: {}", glob)
            }
        }
    }
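
    // With --shard/--shard-count, each shard runs every shard_count-th file,
    // starting at its own shard index.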
    if let (Some(shard), Some(shard_count)) = (args.shard, args.shard_count) {
        files = files.into_iter().skip(shard).step_by(shard_count).collect();
    }
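
    // Optionally shuffle the test order; reusing --seed keeps the order deterministic.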
    if args.shuffle_tests {
        let seed = args.seed.unwrap_or_else(|| rand::thread_rng().r#gen());
        let mut rng = StdRng::seed_from_u64(seed.into());
        files.shuffle(&mut rng);
    }

    let mut error_count = 0;
    let mut error_files = BTreeSet::new();
    let mut junit = match args.junit_report {
        Some(filename) => match File::create(&filename) {
            Ok(file) => Some((file, junit_report::TestSuite::new("testdrive"))),
            Err(err) => die!("creating {}: {}", filename.display(), err),
        },
        None => None,
    };
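
    // Run each file (up to --max-tests), recording a JUnit test case per file and
    // stopping once --max-errors failures have accumulated.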
    for file in files.into_iter().take(args.max_tests) {
        #[allow(deprecated)]
        let start_time = Instant::now();
        let res = if file == Path::new("-") {
            if args.rewrite_results {
                eprintln!("--rewrite-results is not compatible with stdin files");
                process::exit(1);
            }
            mz_testdrive::run_stdin(&config).await
        } else {
            mz_testdrive::run_file(&config, &file).await
        };
        if let Some((_, junit_suite)) = &mut junit {
            let mut test_case = match &res {
                Ok(()) => {
                    junit_report::TestCase::success(&file.to_string_lossy(), start_time.elapsed())
                }
                Err(error) => junit_report::TestCase::failure(
                    &file.to_string_lossy(),
                    start_time.elapsed(),
                    "failure",
                    &error.to_string().replace("\n", " "),
                ),
            };
            test_case.set_classname("testdrive");
            junit_suite.add_testcase(test_case);
        }
        if let Err(error) = res {
            let _ = error.print_error();
            error_count += 1;
            error_files.insert(file);
            if error_count >= args.max_errors {
                eprintln!("testdrive: maximum number of errors reached; giving up");
                break;
            }
        }
    }
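
    // Write out the JUnit report, if one was requested.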
    if let Some((mut junit_file, junit_suite)) = junit {
        let report = junit_report::ReportBuilder::new()
            .add_testsuite(junit_suite)
            .build();
        match report.write_xml(&mut junit_file) {
            Ok(()) => (),
            Err(e) => die!("error: unable to write junit report: {}", e),
        }
    }
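
    // Summarize failures and exit with a nonzero status if any test failed.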
    if error_count > 0 {
        eprint!("+++ ");
        eprintln!("!!! Error Report");
        eprintln!("{} errors were encountered during execution", error_count);
        if config.source.is_some() {
            eprintln!("source: {}", config.source.unwrap());
        } else if !error_files.is_empty() {
            eprintln!(
                "files involved: {}",
                error_files.iter().map(|p| p.display()).join(" ")
            );
        }
        process::exit(1);
    }
}
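
/// Splits a `KEY=VAL` option into its key and value. Input without an `=` yields
/// an empty value rather than an error.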
fn parse_kafka_opt(opt: &str) -> Result<(String, String), Infallible> {
    let mut pieces = opt.splitn(2, '=');
    let key = pieces.next().unwrap_or("").to_owned();
    let val = pieces.next().unwrap_or("").to_owned();
    Ok((key, val))
}