use std::collections::{BTreeMap, BTreeSet};
use std::convert::Infallible;
use std::error::Error;
use std::fs::File;
use std::path::{Path, PathBuf};
use std::time::Duration;
use std::{io, process};
use aws_credential_types::Credentials;
use aws_types::region::Region;
use clap::ArgAction;
use globset::GlobBuilder;
use itertools::Itertools;
use mz_build_info::{build_info, BuildInfo};
use mz_catalog::config::ClusterReplicaSizeMap;
use mz_ore::cli::{self, CliConfig};
use mz_ore::path::PathExt;
use mz_ore::url::SensitiveUrl;
use mz_testdrive::{CatalogConfig, Config, ConsistencyCheckLevel};
use rand::rngs::StdRng;
use rand::seq::SliceRandom;
use rand::{Rng, SeedableRng};
use time::Instant;
use tracing::info;
use tracing_subscriber::filter::EnvFilter;
use url::Url;
use walkdir::WalkDir;
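/// Prints an error message to stderr and terminates the process with a
/// failing exit status.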
macro_rules! die {
($($e:expr),*) => {{
eprintln!($($e),*);
process::exit(1);
}}
}
pub const BUILD_INFO: BuildInfo = build_info!();
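/// Command-line arguments for the testdrive binary, parsed with clap.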
#[derive(clap::Parser)]
struct Args {
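// === Testdrive options. ===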
#[clap(
long,
env = "VAR",
use_value_delimiter = true,
value_name = "NAME=VALUE"
)]
var: Vec<String>,
#[clap(long, value_name = "N", action = ArgAction::Set)]
seed: Option<u32>,
#[clap(long, action = ArgAction::SetTrue)]
no_reset: bool,
#[clap(long, value_name = "PATH")]
temp_dir: Option<String>,
#[clap(long, value_name = "SOURCE")]
source: Option<String>,
#[clap(long, value_parser = humantime::parse_duration, default_value = "30s", value_name = "DURATION")]
default_timeout: Duration,
#[clap(long, default_value = "18446744073709551615", value_name = "N")]
default_max_tries: usize,
#[clap(long, value_parser = humantime::parse_duration, default_value = "50ms", value_name = "DURATION")]
initial_backoff: Duration,
#[clap(long, default_value = "1.5", value_name = "FACTOR")]
backoff_factor: f64,
#[clap(long, default_value = "10", value_name = "N")]
max_errors: usize,
#[clap(long, default_value = "18446744073709551615", value_name = "N")]
max_tests: usize,
#[clap(long)]
shuffle_tests: bool,
#[clap(long, requires = "shard_count", value_name = "N")]
shard: Option<usize>,
#[clap(long, requires = "shard", value_name = "N")]
shard_count: Option<usize>,
#[clap(long, value_name = "FILE")]
junit_report: Option<PathBuf>,
#[clap(long, default_value_t = ConsistencyCheckLevel::default(), value_enum)]
consistency_checks: ConsistencyCheckLevel,
#[clap(
long,
env = "LOG_FILTER",
value_name = "FILTER",
default_value = "librdkafka=off,mz_kafka_util::client=off,warn"
)]
log_filter: String,
globs: Vec<String>,
#[clap(long)]
rewrite_results: bool,
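// === Materialize options. ===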
#[clap(
long,
default_value = "postgres://materialize@localhost:6875",
value_name = "URL",
action = ArgAction::Set,
)]
materialize_url: tokio_postgres::Config,
#[clap(
long,
default_value = "postgres://materialize@localhost:6877",
value_name = "INTERNAL_URL",
action = ArgAction::Set,
)]
materialize_internal_url: tokio_postgres::Config,
#[clap(long)]
materialize_use_https: bool,
#[clap(long, default_value = "6876", value_name = "PORT")]
materialize_http_port: u16,
#[clap(long, default_value = "6878", value_name = "PORT")]
materialize_internal_http_port: u16,
#[clap(long, value_name = "KEY=VAL", value_parser = parse_kafka_opt)]
materialize_param: Vec<(String, String)>,
#[clap(long)]
validate_catalog_store: bool,
#[clap(
long,
value_name = "PERSIST_CONSENSUS_URL",
required_if_eq("validate_catalog_store", "true"),
action = ArgAction::Set,
)]
persist_consensus_url: Option<SensitiveUrl>,
#[clap(
long,
value_name = "PERSIST_BLOB_URL",
required_if_eq("validate_catalog_store", "true")
)]
persist_blob_url: Option<SensitiveUrl>,
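// === Kafka and schema registry options. ===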
#[clap(
long,
value_name = "ENCRYPTION://HOST:PORT",
default_value = "localhost:9092",
action = ArgAction::Set,
)]
kafka_addr: String,
#[clap(long, default_value = "1", value_name = "N")]
kafka_default_partitions: usize,
#[clap(long, env = "KAFKA_OPTION", use_value_delimiter=true, value_name = "KEY=VAL", value_parser = parse_kafka_opt)]
kafka_option: Vec<(String, String)>,
#[clap(long, value_name = "URL", default_value = "http://localhost:8081")]
schema_registry_url: Url,
#[clap(long, value_name = "PATH")]
cert: Option<String>,
#[clap(long, value_name = "PASSWORD")]
cert_password: Option<String>,
#[clap(long, value_name = "USERNAME")]
ccsr_username: Option<String>,
#[clap(long, value_name = "PASSWORD")]
ccsr_password: Option<String>,
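// === AWS options. ===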
#[clap(
long,
conflicts_with = "aws_endpoint",
value_name = "REGION",
env = "AWS_REGION"
)]
aws_region: Option<String>,
#[clap(
long,
conflicts_with = "aws_region",
value_name = "URL",
env = "AWS_ENDPOINT"
)]
aws_endpoint: Option<String>,
#[clap(
long,
value_name = "KEY_ID",
default_value = "dummy-access-key-id",
env = "AWS_ACCESS_KEY_ID"
)]
aws_access_key_id: String,
#[clap(
long,
value_name = "KEY",
default_value = "dummy-secret-access-key",
env = "AWS_SECRET_ACCESS_KEY"
)]
aws_secret_access_key: String,
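// === Fivetran options. ===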
#[clap(
long,
value_name = "FIVETRAN_DESTINATION_URL",
default_value = "http://localhost:6874"
)]
fivetran_destination_url: String,
#[clap(
long,
value_name = "FIVETRAN_DESTINATION_FILES_PATH",
default_value = "/tmp"
)]
fivetran_destination_files_path: String,
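/// A JSON map from cluster replica size name to resource allocation, parsed
/// into a `ClusterReplicaSizeMap`.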
#[clap(long, env = "CLUSTER_REPLICA_SIZES")]
cluster_replica_sizes: String,
}
#[tokio::main]
async fn main() {
let args: Args = cli::parse_args(CliConfig::default());
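// Send log output to stdout, filtered according to --log-filter.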
tracing_subscriber::fmt()
.with_env_filter(EnvFilter::from(args.log_filter))
.with_writer(io::stdout)
.init();
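// Assemble the AWS SDK configuration. With --aws-region, target real AWS and
// resolve the account ID via STS; otherwise fall back to a local endpoint
// (localhost:4566, LocalStack's default) with dummy credentials and a fixed
// test account ID.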
let (aws_config, aws_account) = match args.aws_region {
Some(region) => {
let config = mz_aws_util::defaults()
.region(Region::new(region))
.load()
.await;
let account = async {
let sts_client = aws_sdk_sts::Client::new(&config);
Ok::<_, Box<dyn Error>>(
sts_client
.get_caller_identity()
.send()
.await?
.account
.ok_or("account ID is missing")?,
)
};
let account = account
.await
.unwrap_or_else(|e| die!("testdrive: failed fetching AWS account ID: {}", e));
(config, account)
}
None => {
let endpoint = args
.aws_endpoint
.unwrap_or_else(|| "http://localhost:4566".parse().unwrap());
let config = mz_aws_util::defaults()
.region(Region::new("us-east-1"))
.credentials_provider(Credentials::from_keys(
args.aws_access_key_id,
args.aws_secret_access_key,
None,
))
.endpoint_url(endpoint)
.load()
.await;
let account = "000000000000".into();
(config, account)
}
};
info!(
"Configuration parameters:
Kafka address: {}
Schema registry URL: {}
Materialize host: {:?}
Error limit: {}
Consistency check level: {:?}",
args.kafka_addr,
args.schema_registry_url,
args.materialize_url.get_hosts()[0],
args.max_errors,
args.consistency_checks,
);
if let (Some(shard), Some(shard_count)) = (args.shard, args.shard_count) {
if shard != 0 || shard_count != 1 {
eprintln!(" Shard: {}/{}", shard + 1, shard_count);
}
}
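// Collect --var NAME=VALUE pairs into a map of variables made available to
// the scripts.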
let mut arg_vars = BTreeMap::new();
for arg in &args.var {
let mut parts = arg.splitn(2, '=');
let name = parts.next().expect("splitn always returns at least one part");
let val = match parts.next() {
Some(val) => val,
None => {
eprintln!("No =VALUE for --var {}", name);
process::exit(1)
}
};
arg_vars.insert(name.to_string(), val.to_string());
}
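// The cluster replica size map is supplied as JSON.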
let cluster_replica_sizes: ClusterReplicaSizeMap =
serde_json::from_str(&args.cluster_replica_sizes)
.unwrap_or_else(|e| die!("testdrive: failed to parse replica size map: {}", e));
let materialize_catalog_config = if args.validate_catalog_store {
Some(CatalogConfig {
persist_consensus_url: args
.persist_consensus_url
.clone()
.expect("required for persist catalog"),
persist_blob_url: args
.persist_blob_url
.clone()
.expect("required for persist catalog"),
})
} else {
None
};
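// Assemble the testdrive configuration shared by all scripts.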
let config = Config {
arg_vars,
seed: args.seed,
reset: !args.no_reset,
temp_dir: args.temp_dir,
source: args.source,
default_timeout: args.default_timeout,
default_max_tries: args.default_max_tries,
initial_backoff: args.initial_backoff,
backoff_factor: args.backoff_factor,
consistency_checks: args.consistency_checks,
rewrite_results: args.rewrite_results,
materialize_pgconfig: args.materialize_url,
materialize_cluster_replica_sizes: cluster_replica_sizes,
materialize_internal_pgconfig: args.materialize_internal_url,
materialize_http_port: args.materialize_http_port,
materialize_internal_http_port: args.materialize_internal_http_port,
materialize_use_https: args.materialize_use_https,
materialize_params: args.materialize_param,
materialize_catalog_config,
build_info: &BUILD_INFO,
persist_consensus_url: args.persist_consensus_url,
persist_blob_url: args.persist_blob_url,
kafka_addr: args.kafka_addr,
kafka_default_partitions: args.kafka_default_partitions,
kafka_opts: args.kafka_option,
schema_registry_url: args.schema_registry_url,
cert_path: args.cert,
cert_password: args.cert_password,
ccsr_password: args.ccsr_password,
ccsr_username: args.ccsr_username,
aws_config,
aws_account,
fivetran_destination_url: args.fivetran_destination_url,
fivetran_destination_files_path: args.fivetran_destination_files_path,
};
if args.junit_report.is_some() && args.rewrite_results {
eprintln!("--rewrite-results is not compatible with --junit-report");
process::exit(1);
}
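// Determine which files to run: with no globs, read a single script from
// stdin; otherwise walk the current directory and collect every file matching
// each glob.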
let mut files = vec![];
if args.globs.is_empty() {
files.push(PathBuf::from("-"))
} else {
let all_files = WalkDir::new(".")
.sort_by_file_name()
.into_iter()
.map(|f| f.map(|f| f.path().clean()))
.collect::<Result<Vec<_>, _>>()
.unwrap_or_else(|e| die!("testdrive: failed walking directory: {}", e));
for glob in args.globs {
if glob == "-" {
files.push(glob.into());
continue;
}
let matcher = GlobBuilder::new(&Path::new(&glob).clean().to_string_lossy())
.literal_separator(true)
.build()
.unwrap_or_else(|e| die!("testdrive: invalid glob syntax: {}: {}", glob, e))
.compile_matcher();
let mut found = false;
for file in &all_files {
if matcher.is_match(file) {
files.push(file.clone());
found = true;
}
}
if !found {
die!("testdrive: glob did not match any patterns: {}", glob)
}
}
}
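// With --shard and --shard-count, this shard runs every shard_count-th file,
// starting at its shard index.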
if let (Some(shard), Some(shard_count)) = (args.shard, args.shard_count) {
files = files.into_iter().skip(shard).step_by(shard_count).collect();
}
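// Optionally shuffle the run order; passing --seed makes the shuffle
// deterministic.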
if args.shuffle_tests {
let seed = args.seed.unwrap_or_else(|| rand::thread_rng().gen());
let mut rng = StdRng::seed_from_u64(seed.into());
files.shuffle(&mut rng);
}
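// Track failing files and, if --junit-report was given, open the report file
// and start a test suite.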
let mut error_count = 0;
let mut error_files = BTreeSet::new();
let mut junit = match args.junit_report {
Some(filename) => match File::create(&filename) {
Ok(file) => Some((file, junit_report::TestSuite::new("testdrive"))),
Err(err) => die!("creating {}: {}", filename.display(), err),
},
None => None,
};
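// Run each selected file in turn (stdin when the file is "-"), recording
// per-file results for the JUnit report and giving up once --max-errors
// failures have accumulated.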
for file in files.into_iter().take(args.max_tests) {
let start_time = Instant::now();
let res = if file == Path::new("-") {
if args.rewrite_results {
eprintln!("--rewrite-results is not compatible with stdin files");
process::exit(1);
}
mz_testdrive::run_stdin(&config).await
} else {
mz_testdrive::run_file(&config, &file).await
};
if let Some((_, junit_suite)) = &mut junit {
let mut test_case = match &res {
Ok(()) => {
junit_report::TestCase::success(&file.to_string_lossy(), start_time.elapsed())
}
Err(error) => junit_report::TestCase::failure(
&file.to_string_lossy(),
start_time.elapsed(),
"failure",
&error.to_string().replace("\n", " "),
),
};
test_case.set_classname("testdrive");
junit_suite.add_testcase(test_case);
}
if let Err(error) = res {
let _ = error.print_error();
error_count += 1;
error_files.insert(file);
if error_count >= args.max_errors {
eprintln!("testdrive: maximum number of errors reached; giving up");
break;
}
}
}
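// Write the JUnit XML report, if one was requested.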
if let Some((mut junit_file, junit_suite)) = junit {
let report = junit_report::ReportBuilder::new()
.add_testsuite(junit_suite)
.build();
match report.write_xml(&mut junit_file) {
Ok(()) => (),
Err(e) => die!("error: unable to write junit report: {}", e),
}
}
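// Print an error summary and exit with a nonzero status if any script failed.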
if error_count > 0 {
eprint!("+++ ");
eprintln!("!!! Error Report");
eprintln!("{} errors were encountered during execution", error_count);
if config.source.is_some() {
eprintln!("source: {}", config.source.unwrap());
} else if !error_files.is_empty() {
eprintln!(
"files involved: {}",
error_files.iter().map(|p| p.display()).join(" ")
);
}
process::exit(1);
}
}
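/// Splits a `KEY=VALUE` command-line option into its key and value; an option
/// with no `=` yields an empty value. Shared by `--kafka-option` and
/// `--materialize-param`.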
fn parse_kafka_opt(opt: &str) -> Result<(String, String), Infallible> {
let mut pieces = opt.splitn(2, '=');
let key = pieces.next().unwrap_or("").to_owned();
let val = pieces.next().unwrap_or("").to_owned();
Ok((key, val))
}