1use std::collections::{BTreeMap, BTreeSet};
11use std::convert::Infallible;
12use std::error::Error;
13use std::fs::File;
14use std::path::{Path, PathBuf};
15use std::time::Duration;
16use std::{io, process};
17
18use aws_credential_types::Credentials;
19use aws_types::region::Region;
20use clap::ArgAction;
21use globset::GlobBuilder;
22use itertools::Itertools;
23use mz_build_info::{BuildInfo, build_info};
24use mz_catalog::config::ClusterReplicaSizeMap;
25use mz_license_keys::ValidatedLicenseKey;
26use mz_ore::cli::{self, CliConfig};
27use mz_ore::path::PathExt;
28use mz_ore::url::SensitiveUrl;
29use mz_testdrive::{CatalogConfig, Config, ConsistencyCheckLevel};
30use rand::SeedableRng;
31use rand::rngs::StdRng;
32use rand::seq::SliceRandom;
33#[allow(deprecated)] use time::Instant;
35use tracing::info;
36use tracing_subscriber::filter::EnvFilter;
37use url::Url;
38use walkdir::WalkDir;
39
/// Prints a formatted message to stderr and terminates the process with
/// exit status 1.
///
/// Used throughout `main` for fatal CLI-level errors where unwinding or
/// error propagation adds no value.
macro_rules! die {
    ($($e:expr),*) => {{
        eprintln!($($e),*);
        process::exit(1);
    }}
}
46
/// Build metadata captured at compile time via `mz_build_info::build_info!`;
/// handed to testdrive through `Config::build_info`.
pub const BUILD_INFO: BuildInfo = build_info!();
48
// Command-line arguments for the testdrive binary.
//
// NOTE(review): plain `//` comments are used deliberately instead of `///`
// doc comments — under `clap::Parser` derive, doc comments become `--help`
// text, which would change the program's observable output.
#[derive(clap::Parser)]
struct Args {
    // === Testdrive options ===

    // NAME=VALUE variable bindings; split on the first `=` in `main` and
    // passed to scripts via `Config::arg_vars`.
    #[clap(long, env = "VAR", value_name = "NAME=VALUE")]
    var: Vec<String>,
    // Random seed; also reused as the shuffle seed when --shuffle-tests is set.
    #[clap(long, value_name = "N", action = ArgAction::Set)]
    seed: Option<u32>,
    // Skip resetting state before/after runs (inverted into `Config::reset`).
    #[clap(long, action = ArgAction::SetTrue)]
    no_reset: bool,
    // Directory for temporary files created during the run.
    #[clap(long, value_name = "PATH")]
    temp_dir: Option<String>,
    // Source label reported in the error summary instead of file names.
    #[clap(long, value_name = "SOURCE")]
    source: Option<String>,
    // Default per-action timeout (humantime syntax, e.g. "30s").
    #[clap(long, value_parser = humantime::parse_duration, default_value = "30s", value_name = "DURATION")]
    default_timeout: Duration,
    // Default retry budget; the default is u64::MAX, i.e. effectively unlimited.
    #[clap(long, default_value = "18446744073709551615", value_name = "N")]
    default_max_tries: usize,
    // Initial retry backoff (humantime syntax).
    #[clap(long, value_parser = humantime::parse_duration, default_value = "50ms", value_name = "DURATION")]
    initial_backoff: Duration,
    // Multiplier applied to the backoff between retries.
    #[clap(long, default_value = "1.5", value_name = "FACTOR")]
    backoff_factor: f64,
    // Stop the whole run after this many failed files.
    #[clap(long, default_value = "10", value_name = "N")]
    max_errors: usize,
    // Run at most this many test files; default is effectively unlimited.
    #[clap(long, default_value = "18446744073709551615", value_name = "N")]
    max_tests: usize,
    // Shuffle the file list (seeded from --seed, or randomly).
    #[clap(long)]
    shuffle_tests: bool,
    // Zero-based shard index; requires --shard-count.
    #[clap(long, requires = "shard_count", value_name = "N")]
    shard: Option<usize>,
    // Total number of shards; requires --shard.
    #[clap(long, requires = "shard", value_name = "N")]
    shard_count: Option<usize>,
    // Write a JUnit XML report to this file (incompatible with --rewrite-results).
    #[clap(long, value_name = "FILE")]
    junit_report: Option<PathBuf>,
    // How aggressively to run consistency checks.
    #[clap(long, default_value_t = ConsistencyCheckLevel::default(), value_enum)]
    consistency_checks: ConsistencyCheckLevel,
    // Validate statement logging during the run.
    #[clap(long, action = ArgAction::SetTrue)]
    check_statement_logging: bool,
    // Tracing filter directive for the stdout subscriber.
    #[clap(
        long,
        env = "LOG_FILTER",
        value_name = "FILTER",
        default_value = "librdkafka=off,mz_kafka_util::client=off,warn"
    )]
    log_filter: String,
    // Positional glob patterns selecting test files; "-" means stdin.
    globs: Vec<String>,
    // Rewrite expected results in place instead of failing on mismatch;
    // incompatible with --junit-report and stdin input.
    #[clap(long)]
    rewrite_results: bool,

    // === Materialize options ===

    // SQL endpoint for the external (user-facing) port.
    #[clap(
        long,
        default_value = "postgres://materialize@localhost:6875",
        value_name = "URL",
        action = ArgAction::Set,
    )]
    materialize_url: tokio_postgres::Config,
    // SQL endpoint for the internal port.
    #[clap(
        long,
        default_value = "postgres://materialize@localhost:6877",
        value_name = "INTERNAL_URL",
        action = ArgAction::Set,
    )]
    materialize_internal_url: tokio_postgres::Config,
    // Use HTTPS when talking to the HTTP endpoints.
    #[clap(long)]
    materialize_use_https: bool,
    // External HTTP port.
    #[clap(long, default_value = "6876", value_name = "PORT")]
    materialize_http_port: u16,
    // Internal HTTP port.
    #[clap(long, default_value = "6878", value_name = "PORT")]
    materialize_internal_http_port: u16,
    // SQL port that requires password authentication.
    #[clap(long, default_value = "6880", value_name = "PORT")]
    materialize_password_sql_port: u16,
    // SQL port that uses SASL authentication.
    #[clap(long, default_value = "6881", value_name = "PORT")]
    materialize_sasl_sql_port: u16,
    // KEY=VAL session parameters applied to Materialize connections.
    // (Reuses the lenient KEY=VAL parser shared with --kafka-option.)
    #[clap(long, value_name = "KEY=VAL", value_parser = parse_kafka_opt)]
    materialize_param: Vec<(String, String)>,
    // Validate the catalog state in the persist-backed catalog store;
    // makes the persist URLs below required.
    #[clap(long)]
    validate_catalog_store: bool,

    // === Persist options ===

    // Consensus URL for persist; required when --validate-catalog-store is set.
    #[clap(
        long,
        value_name = "PERSIST_CONSENSUS_URL",
        required_if_eq("validate_catalog_store", "true"),
        action = ArgAction::Set,
    )]
    persist_consensus_url: Option<SensitiveUrl>,
    // Blob storage URL for persist; required when --validate-catalog-store is set.
    #[clap(
        long,
        value_name = "PERSIST_BLOB_URL",
        required_if_eq("validate_catalog_store", "true")
    )]
    persist_blob_url: Option<SensitiveUrl>,

    // === Confluent options ===

    // Kafka bootstrap address.
    #[clap(
        long,
        value_name = "ENCRYPTION://HOST:PORT",
        default_value = "localhost:9092",
        action = ArgAction::Set,
    )]
    kafka_addr: String,
    // Partition count for topics testdrive creates.
    #[clap(long, default_value = "1", value_name = "N")]
    kafka_default_partitions: usize,
    // Extra librdkafka-style KEY=VAL options (comma-delimited via env).
    #[clap(long, env = "KAFKA_OPTION", use_value_delimiter=true, value_name = "KEY=VAL", value_parser = parse_kafka_opt)]
    kafka_option: Vec<(String, String)>,
    // Confluent Schema Registry endpoint.
    #[clap(long, value_name = "URL", default_value = "http://localhost:8081")]
    schema_registry_url: Url,
    // Client TLS certificate path and its password.
    #[clap(long, value_name = "PATH")]
    cert: Option<String>,
    #[clap(long, value_name = "PASSWORD")]
    cert_password: Option<String>,
    // Basic-auth credentials for the schema registry.
    #[clap(long, value_name = "USERNAME")]
    ccsr_username: Option<String>,
    #[clap(long, value_name = "PASSWORD")]
    ccsr_password: Option<String>,

    // === AWS options ===

    // Real AWS region; mutually exclusive with --aws-endpoint (when set,
    // the account ID is fetched from STS in `main`).
    #[clap(
        long,
        conflicts_with = "aws_endpoint",
        value_name = "REGION",
        env = "AWS_REGION"
    )]
    aws_region: Option<String>,
    // Custom (e.g. LocalStack) endpoint; mutually exclusive with --aws-region.
    // Defaults to http://localhost:4566 in `main` when neither is given.
    #[clap(
        long,
        conflicts_with = "aws_region",
        value_name = "URL",
        env = "AWS_ENDPOINT"
    )]
    aws_endpoint: Option<String>,

    // Static credentials, only used on the custom-endpoint path.
    #[clap(
        long,
        value_name = "KEY_ID",
        default_value = "dummy-access-key-id",
        env = "AWS_ACCESS_KEY_ID"
    )]
    aws_access_key_id: String,

    #[clap(
        long,
        value_name = "KEY",
        default_value = "dummy-secret-access-key",
        env = "AWS_SECRET_ACCESS_KEY"
    )]
    aws_secret_access_key: String,

    // === Fivetran options ===

    // Fivetran destination gRPC/HTTP endpoint.
    #[clap(
        long,
        value_name = "FIVETRAN_DESTINATION_URL",
        default_value = "http://localhost:6874"
    )]
    fivetran_destination_url: String,
    // Directory the Fivetran destination reads/writes files in.
    #[clap(
        long,
        value_name = "FIVETRAN_DESTINATION_FILES_PATH",
        default_value = "/tmp"
    )]
    fivetran_destination_files_path: String,
    // Cluster replica size map, parsed by ClusterReplicaSizeMap::parse_from_str.
    #[clap(long, env = "CLUSTER_REPLICA_SIZES")]
    cluster_replica_sizes: String,

    // Optional license key, validated by mz_license_keys::validate in `main`.
    #[clap(long, env = "MZ_CI_LICENSE_KEY")]
    license_key: Option<String>,
}
296
/// Entry point: parses CLI arguments, initializes tracing and AWS
/// configuration, builds the testdrive `Config`, expands the test-file
/// globs, runs each selected file, and reports results — optionally as a
/// JUnit XML report. Exits non-zero on any failure.
#[tokio::main]
async fn main() {
    let args: Args = cli::parse_args(CliConfig::default());

    // Route tracing output to stdout, filtered by --log-filter.
    tracing_subscriber::fmt()
        .with_env_filter(EnvFilter::from(args.log_filter))
        .with_writer(io::stdout)
        .init();

    // Resolve AWS configuration. With --aws-region, load default credentials
    // and fetch the real account ID from STS (fatal on failure); otherwise
    // target a local endpoint (default http://localhost:4566) with the
    // dummy static credentials and a fixed all-zeros account ID.
    let (aws_config, aws_account) = match args.aws_region {
        Some(region) => {
            let config = mz_aws_util::defaults()
                .region(Region::new(region))
                .load()
                .await;
            let account = async {
                let sts_client = aws_sdk_sts::Client::new(&config);
                Ok::<_, Box<dyn Error>>(
                    sts_client
                        .get_caller_identity()
                        .send()
                        .await?
                        .account
                        .ok_or("account ID is missing")?,
                )
            };
            let account = account
                .await
                .unwrap_or_else(|e| die!("testdrive: failed fetching AWS account ID: {}", e));
            (config, account)
        }
        None => {
            let endpoint = args
                .aws_endpoint
                .unwrap_or_else(|| "http://localhost:4566".parse().unwrap());
            let config = mz_aws_util::defaults()
                .region(Region::new("us-east-1"))
                .credentials_provider(Credentials::from_keys(
                    args.aws_access_key_id,
                    args.aws_secret_access_key,
                    None,
                ))
                .endpoint_url(endpoint)
                .load()
                .await;
            let account = "000000000000".into();
            (config, account)
        }
    };

    // Echo the effective configuration so test logs are self-describing.
    info!(
        "Configuration parameters:
    Kafka address: {}
    Schema registry URL: {}
    Materialize host: {:?}
    Error limit: {}
    Consistency check level: {:?}",
        args.kafka_addr,
        args.schema_registry_url,
        args.materialize_url.get_hosts()[0],
        args.max_errors,
        args.consistency_checks,
    );
    // Only mention sharding when it actually partitions the run
    // (i.e. anything other than shard 0 of 1).
    if let (Some(shard), Some(shard_count)) = (args.shard, args.shard_count) {
        if shard != 0 || shard_count != 1 {
            eprintln!(" Shard: {}/{}", shard + 1, shard_count);
        }
    }

    // Split each --var on the first '=' into a NAME -> VALUE map.
    // A missing '=' is fatal.
    let mut arg_vars = BTreeMap::new();
    for arg in &args.var {
        let mut parts = arg.splitn(2, '=');
        let name = parts.next().expect("Clap ensures all --vars get a value");
        let val = match parts.next() {
            Some(val) => val,
            None => {
                eprintln!("No =VALUE for --var {}", name);
                process::exit(1)
            }
        };
        arg_vars.insert(name.to_string(), val.to_string());
    }

    // Validate the license key if one was supplied; otherwise fall back to
    // the default (unlicensed) key.
    let license_key = if let Some(license_key_text) = args.license_key {
        mz_license_keys::validate(license_key_text.trim())
            .unwrap_or_else(|e| die!("testdrive: failed to validate license key: {}", e))
    } else {
        ValidatedLicenseKey::default()
    };

    let cluster_replica_sizes = ClusterReplicaSizeMap::parse_from_str(
        &args.cluster_replica_sizes,
        !license_key.allow_credit_consumption_override,
    )
    .unwrap_or_else(|e| die!("testdrive: failed to parse replica size map: {}", e));

    // Catalog validation needs both persist URLs; clap's required_if_eq
    // guarantees they are present when --validate-catalog-store is set,
    // which is what the expects below rely on.
    let materialize_catalog_config = if args.validate_catalog_store {
        Some(CatalogConfig {
            persist_consensus_url: args
                .persist_consensus_url
                .clone()
                .expect("required for persist catalog"),
            persist_blob_url: args
                .persist_blob_url
                .clone()
                .expect("required for persist catalog"),
        })
    } else {
        None
    };
    // Assemble the testdrive configuration from the parsed arguments.
    let config = Config {
        arg_vars,
        seed: args.seed,
        reset: !args.no_reset,
        temp_dir: args.temp_dir,
        source: args.source,
        default_timeout: args.default_timeout,
        default_max_tries: args.default_max_tries,
        initial_backoff: args.initial_backoff,
        backoff_factor: args.backoff_factor,
        consistency_checks: args.consistency_checks,
        check_statement_logging: args.check_statement_logging,
        rewrite_results: args.rewrite_results,

        materialize_pgconfig: args.materialize_url,
        materialize_cluster_replica_sizes: cluster_replica_sizes,
        materialize_internal_pgconfig: args.materialize_internal_url,
        materialize_http_port: args.materialize_http_port,
        materialize_internal_http_port: args.materialize_internal_http_port,
        materialize_use_https: args.materialize_use_https,
        materialize_password_sql_port: args.materialize_password_sql_port,
        materialize_sasl_sql_port: args.materialize_sasl_sql_port,
        materialize_params: args.materialize_param,
        materialize_catalog_config,
        build_info: &BUILD_INFO,

        persist_consensus_url: args.persist_consensus_url,
        persist_blob_url: args.persist_blob_url,

        kafka_addr: args.kafka_addr,
        kafka_default_partitions: args.kafka_default_partitions,
        kafka_opts: args.kafka_option,
        schema_registry_url: args.schema_registry_url,
        cert_path: args.cert,
        cert_password: args.cert_password,
        ccsr_password: args.ccsr_password,
        ccsr_username: args.ccsr_username,

        aws_config,
        aws_account,

        fivetran_destination_url: args.fivetran_destination_url,
        fivetran_destination_files_path: args.fivetran_destination_files_path,
    };

    // Rewriting results in place cannot be combined with JUnit reporting.
    if args.junit_report.is_some() && args.rewrite_results {
        eprintln!("--rewrite-results is not compatible with --junit-report");
        process::exit(1);
    }

    // Expand positional globs into the list of files to run. No globs means
    // read a single script from stdin ("-"). Each glob must match at least
    // one file under the current directory; globs are matched against
    // cleaned paths with `/` treated as a literal separator.
    let mut files = vec![];
    if args.globs.is_empty() {
        files.push(PathBuf::from("-"))
    } else {
        let all_files = WalkDir::new(".")
            .sort_by_file_name()
            .into_iter()
            .map(|f| f.map(|f| f.path().clean()))
            .collect::<Result<Vec<_>, _>>()
            .unwrap_or_else(|e| die!("testdrive: failed walking directory: {}", e));
        for glob in args.globs {
            if glob == "-" {
                files.push(glob.into());
                continue;
            }
            let matcher = GlobBuilder::new(&Path::new(&glob).clean().to_string_lossy())
                .literal_separator(true)
                .build()
                .unwrap_or_else(|e| die!("testdrive: invalid glob syntax: {}: {}", glob, e))
                .compile_matcher();
            let mut found = false;
            for file in &all_files {
                if matcher.is_match(file) {
                    files.push(file.clone());
                    found = true;
                }
            }
            if !found {
                die!("testdrive: glob did not match any patterns: {}", glob)
            }
        }
    }

    // Keep every shard_count-th file starting at our shard index.
    if let (Some(shard), Some(shard_count)) = (args.shard, args.shard_count) {
        files = files.into_iter().skip(shard).step_by(shard_count).collect();
    }

    // Deterministically shuffle when --seed is given, randomly otherwise.
    if args.shuffle_tests {
        let seed = args.seed.unwrap_or_else(rand::random);
        let mut rng = StdRng::seed_from_u64(seed.into());
        files.shuffle(&mut rng);
    }

    // Create the JUnit report file up front so a bad path fails fast.
    let mut error_count = 0;
    let mut error_files = BTreeSet::new();
    let mut junit = match args.junit_report {
        Some(filename) => match File::create(&filename) {
            Ok(file) => Some((file, junit_report::TestSuite::new("testdrive"))),
            Err(err) => die!("creating {}: {}", filename.display(), err),
        },
        None => None,
    };

    // Run each selected file (capped by --max-tests), recording a JUnit test
    // case per file and stopping early once --max-errors failures accumulate.
    for file in files.into_iter().take(args.max_tests) {
        #[allow(deprecated)] let start_time = Instant::now();
        let res = if file == Path::new("-") {
            if args.rewrite_results {
                eprintln!("--rewrite-results is not compatible with stdin files");
                process::exit(1);
            }
            mz_testdrive::run_stdin(&config).await
        } else {
            mz_testdrive::run_file(&config, &file).await
        };
        if let Some((_, junit_suite)) = &mut junit {
            let mut test_case = match &res {
                Ok(()) => {
                    junit_report::TestCase::success(&file.to_string_lossy(), start_time.elapsed())
                }
                Err(error) => junit_report::TestCase::failure(
                    &file.to_string_lossy(),
                    start_time.elapsed(),
                    "failure",
                    // Newlines are flattened so the message stays on one line
                    // in the XML report.
                    &error.to_string().replace("\n", " "),
                ),
            };
            test_case.set_classname("testdrive");
            junit_suite.add_testcase(test_case);
        }
        if let Err(error) = res {
            let _ = error.print_error();
            error_count += 1;
            error_files.insert(file);
            if error_count >= args.max_errors {
                eprintln!("testdrive: maximum number of errors reached; giving up");
                break;
            }
        }
    }

    // Serialize the collected JUnit suite, if reporting was requested.
    if let Some((mut junit_file, junit_suite)) = junit {
        let report = junit_report::ReportBuilder::new()
            .add_testsuite(junit_suite)
            .build();
        match report.write_xml(&mut junit_file) {
            Ok(()) => (),
            Err(e) => die!("error: unable to write junit report: {}", e),
        }
    }

    // Summarize failures and exit non-zero if anything failed.
    if error_count > 0 {
        eprint!("+++ ");
        eprintln!("!!! Error Report");
        eprintln!("{} errors were encountered during execution", error_count);
        if config.source.is_some() {
            eprintln!("source: {}", config.source.unwrap());
        } else if !error_files.is_empty() {
            eprintln!(
                "files involved: {}",
                error_files.iter().map(|p| p.display()).join(" ")
            );
        }
        process::exit(1);
    }
}
595
/// Splits a `KEY=VAL` command-line option into owned key and value strings.
///
/// The parse is deliberately lenient — when no `=` is present the whole
/// input becomes the key and the value is empty — so the error type is
/// `Infallible`, as required for use as a clap `value_parser`.
fn parse_kafka_opt(opt: &str) -> Result<(String, String), Infallible> {
    let (key, val) = match opt.split_once('=') {
        Some((k, v)) => (k, v),
        None => (opt, ""),
    };
    Ok((key.to_owned(), val.to_owned()))
}