1use std::collections::BTreeMap;
11use std::error::Error;
12use std::future::IntoFuture;
13use std::net::{IpAddr, Ipv4Addr, SocketAddr, TcpStream};
14use std::path::{Path, PathBuf};
15use std::pin::Pin;
16use std::str::FromStr;
17use std::sync::Arc;
18use std::sync::LazyLock;
19use std::time::Duration;
20use std::{env, fs, iter};
21
22use anyhow::anyhow;
23use futures::Future;
24use futures::future::{BoxFuture, LocalBoxFuture};
25use headers::{Header, HeaderMapExt};
26use hyper::http::header::HeaderMap;
27use mz_adapter::TimestampExplanation;
28use mz_adapter_types::bootstrap_builtin_cluster_config::{
29 ANALYTICS_CLUSTER_DEFAULT_REPLICATION_FACTOR, BootstrapBuiltinClusterConfig,
30 CATALOG_SERVER_CLUSTER_DEFAULT_REPLICATION_FACTOR, PROBE_CLUSTER_DEFAULT_REPLICATION_FACTOR,
31 SUPPORT_CLUSTER_DEFAULT_REPLICATION_FACTOR, SYSTEM_CLUSTER_DEFAULT_REPLICATION_FACTOR,
32};
33
34use mz_catalog::config::ClusterReplicaSizeMap;
35use mz_controller::ControllerConfig;
36use mz_dyncfg::ConfigUpdates;
37use mz_license_keys::ValidatedLicenseKey;
38use mz_orchestrator_process::{ProcessOrchestrator, ProcessOrchestratorConfig};
39use mz_orchestrator_tracing::{TracingCliArgs, TracingOrchestrator};
40use mz_ore::metrics::MetricsRegistry;
41use mz_ore::now::{EpochMillis, NowFn, SYSTEM_TIME};
42use mz_ore::retry::Retry;
43use mz_ore::task;
44use mz_ore::tracing::{
45 OpenTelemetryConfig, StderrLogConfig, StderrLogFormat, TracingConfig, TracingGuard,
46 TracingHandle,
47};
48use mz_persist_client::PersistLocation;
49use mz_persist_client::cache::PersistClientCache;
50use mz_persist_client::cfg::{CONSENSUS_CONNECTION_POOL_MAX_SIZE, PersistConfig};
51use mz_persist_client::rpc::PersistGrpcPubSubServer;
52use mz_secrets::SecretsController;
53use mz_server_core::{ReloadTrigger, TlsCertConfig};
54use mz_sql::catalog::EnvironmentId;
55use mz_storage_types::connections::ConnectionContext;
56use mz_tracing::CloneableEnvFilter;
57use openssl::asn1::Asn1Time;
58use openssl::error::ErrorStack;
59use openssl::hash::MessageDigest;
60use openssl::nid::Nid;
61use openssl::pkey::{PKey, Private};
62use openssl::rsa::Rsa;
63use openssl::ssl::{SslConnector, SslConnectorBuilder, SslMethod, SslOptions};
64use openssl::x509::extension::{BasicConstraints, SubjectAlternativeName};
65use openssl::x509::{X509, X509Name, X509NameBuilder};
66use postgres::error::DbError;
67use postgres::tls::{MakeTlsConnect, TlsConnect};
68use postgres::types::{FromSql, Type};
69use postgres::{NoTls, Socket};
70use postgres_openssl::MakeTlsConnector;
71use tempfile::TempDir;
72use tokio::net::TcpListener;
73use tokio::runtime::Runtime;
74use tokio_postgres::config::{Host, SslMode};
75use tokio_postgres::{AsyncMessage, Client};
76use tokio_stream::wrappers::TcpListenerStream;
77use tower_http::cors::AllowOrigin;
78use tracing::Level;
79use tracing_capture::SharedStorage;
80use tracing_subscriber::EnvFilter;
81use tungstenite::stream::MaybeTlsStream;
82use tungstenite::{Message, WebSocket};
83use url::Url;
84
85use crate::{CatalogConfig, FronteggAuthentication, WebSocketAuth, WebSocketResponse};
86
/// Kafka bootstrap address(es) used by tests.
///
/// Read once, on first access, from the `KAFKA_ADDRS` environment variable. Falls back
/// to `localhost:9092` when the variable is unset or not valid Unicode.
pub static KAFKA_ADDRS: LazyLock<String> = LazyLock::new(|| match env::var("KAFKA_ADDRS") {
    Ok(addrs) => addrs,
    Err(_) => String::from("localhost:9092"),
});
89
/// Builder-style configuration for a test `environmentd` server.
///
/// Start from [`TestHarness::default`], adjust it with the setter / `with_*` methods,
/// and launch it with [`TestHarness::start`] or one of its variants.
#[derive(Clone)]
pub struct TestHarness {
    /// Directory holding secrets and the file-backed persist blob store. When `None`,
    /// a temporary directory is created and kept alive by the returned `TestServer`.
    data_directory: Option<PathBuf>,
    /// TLS certificate/key paths forwarded to the server config.
    tls: Option<TlsCertConfig>,
    /// Frontegg authentication settings forwarded to the server config, if any.
    frontegg: Option<FronteggAuthentication>,
    /// Forwarded as the server's `self_hosted_auth` setting.
    self_hosted_auth: bool,
    /// Forwarded as the server's `self_hosted_auth_internal` setting.
    self_hosted_auth_internal: bool,
    /// Forwarded as the server's `unsafe_mode` setting.
    unsafe_mode: bool,
    /// Worker count. NOTE(review): not read by `Listeners::serve_with_trigger` in this
    /// file; confirm whether it is still used anywhere.
    workers: usize,
    /// Clock handed to the controller and the server.
    now: NowFn,
    /// Per-server seed; names the `consensus_{seed}` and `tsoracle_{seed}` schemas
    /// created in CockroachDB, so servers with different seeds use separate schemas.
    seed: u32,
    /// Interval at which storage usage is collected.
    storage_usage_collection_interval: Duration,
    /// Retention period for collected storage usage, if any.
    storage_usage_retention_period: Option<Duration>,
    /// Replica size of the bootstrapped default cluster.
    default_cluster_replica_size: String,
    /// Replication factor of the bootstrapped default cluster.
    default_cluster_replication_factor: u32,
    /// Bootstrap size/replication factor of the builtin system cluster.
    builtin_system_cluster_config: BootstrapBuiltinClusterConfig,
    /// Bootstrap size/replication factor of the builtin catalog server cluster.
    builtin_catalog_server_cluster_config: BootstrapBuiltinClusterConfig,
    /// Bootstrap size/replication factor of the builtin probe cluster.
    builtin_probe_cluster_config: BootstrapBuiltinClusterConfig,
    /// Bootstrap size/replication factor of the builtin support cluster.
    builtin_support_cluster_config: BootstrapBuiltinClusterConfig,
    /// Bootstrap size/replication factor of the builtin analytics cluster.
    builtin_analytics_cluster_config: BootstrapBuiltinClusterConfig,

    /// Forwarded to the process orchestrator's `propagate_crashes` setting.
    propagate_crashes: bool,
    /// When true, tracing is configured (JSON stderr logs plus an OpenTelemetry
    /// exporter pointed at a fake endpoint).
    enable_tracing: bool,
    /// Tracing arguments for the `TracingOrchestrator` wrapping the process orchestrator.
    orchestrator_tracing_cli_args: TracingCliArgs,
    /// Role created at bootstrap, if any.
    bootstrap_role: Option<String>,
    /// Deploy generation passed to the controller.
    deploy_generation: u64,
    /// Default values for system parameters, keyed by parameter name.
    system_parameter_defaults: BTreeMap<String, String>,
    /// Internal console redirect URL forwarded to the server config.
    internal_console_redirect_url: Option<String>,
    /// Metrics registry to use; a fresh one is created when `None`.
    metrics_registry: Option<MetricsRegistry>,
    /// Build version reported by persist.
    code_version: semver::Version,
    /// Tracing capture storage; only used when `enable_tracing` is set.
    capture: Option<SharedStorage>,
    /// Environment ID used by the orchestrator and the server.
    pub environment_id: EnvironmentId,
}
126
127impl Default for TestHarness {
128 fn default() -> TestHarness {
129 TestHarness {
130 data_directory: None,
131 tls: None,
132 frontegg: None,
133 self_hosted_auth: false,
134 self_hosted_auth_internal: false,
135 unsafe_mode: false,
136 workers: 1,
137 now: SYSTEM_TIME.clone(),
138 seed: rand::random(),
139 storage_usage_collection_interval: Duration::from_secs(3600),
140 storage_usage_retention_period: None,
141 default_cluster_replica_size: "1".to_string(),
142 default_cluster_replication_factor: 2,
143 builtin_system_cluster_config: BootstrapBuiltinClusterConfig {
144 size: "1".to_string(),
145 replication_factor: SYSTEM_CLUSTER_DEFAULT_REPLICATION_FACTOR,
146 },
147 builtin_catalog_server_cluster_config: BootstrapBuiltinClusterConfig {
148 size: "1".to_string(),
149 replication_factor: CATALOG_SERVER_CLUSTER_DEFAULT_REPLICATION_FACTOR,
150 },
151 builtin_probe_cluster_config: BootstrapBuiltinClusterConfig {
152 size: "1".to_string(),
153 replication_factor: PROBE_CLUSTER_DEFAULT_REPLICATION_FACTOR,
154 },
155 builtin_support_cluster_config: BootstrapBuiltinClusterConfig {
156 size: "1".to_string(),
157 replication_factor: SUPPORT_CLUSTER_DEFAULT_REPLICATION_FACTOR,
158 },
159 builtin_analytics_cluster_config: BootstrapBuiltinClusterConfig {
160 size: "1".to_string(),
161 replication_factor: ANALYTICS_CLUSTER_DEFAULT_REPLICATION_FACTOR,
162 },
163 propagate_crashes: false,
164 enable_tracing: false,
165 bootstrap_role: Some("materialize".into()),
166 deploy_generation: 0,
167 system_parameter_defaults: BTreeMap::from([(
170 "log_filter".to_string(),
171 "error".to_string(),
172 )]),
173 internal_console_redirect_url: None,
174 metrics_registry: None,
175 orchestrator_tracing_cli_args: TracingCliArgs {
176 startup_log_filter: CloneableEnvFilter::from_str("error").expect("must parse"),
177 ..Default::default()
178 },
179 code_version: crate::BUILD_INFO.semver_version(),
180 environment_id: EnvironmentId::for_tests(),
181 capture: None,
182 }
183 }
184}
185
186impl TestHarness {
187 pub async fn start(self) -> TestServer {
191 self.try_start().await.expect("Failed to start test Server")
192 }
193
194 pub async fn start_with_trigger(self, tls_reload_certs: ReloadTrigger) -> TestServer {
196 self.try_start_with_trigger(tls_reload_certs)
197 .await
198 .expect("Failed to start test Server")
199 }
200
201 pub async fn try_start(self) -> Result<TestServer, anyhow::Error> {
203 self.try_start_with_trigger(mz_server_core::cert_reload_never_reload())
204 .await
205 }
206
207 pub async fn try_start_with_trigger(
209 self,
210 tls_reload_certs: ReloadTrigger,
211 ) -> Result<TestServer, anyhow::Error> {
212 let listeners = Listeners::new().await?;
213 listeners.serve_with_trigger(self, tls_reload_certs).await
214 }
215
216 pub fn start_blocking(self) -> TestServerWithRuntime {
218 stacker::grow(mz_ore::stack::STACK_SIZE, || {
219 let runtime = Runtime::new().expect("failed to spawn runtime for test");
220 let runtime = Arc::new(runtime);
221 let server = runtime.block_on(self.start());
222 TestServerWithRuntime { runtime, server }
223 })
224 }
225
226 pub fn data_directory(mut self, data_directory: impl Into<PathBuf>) -> Self {
227 self.data_directory = Some(data_directory.into());
228 self
229 }
230
231 pub fn with_tls(mut self, cert_path: impl Into<PathBuf>, key_path: impl Into<PathBuf>) -> Self {
232 self.tls = Some(TlsCertConfig {
233 cert: cert_path.into(),
234 key: key_path.into(),
235 });
236 self
237 }
238
239 pub fn unsafe_mode(mut self) -> Self {
240 self.unsafe_mode = true;
241 self
242 }
243
244 pub fn workers(mut self, workers: usize) -> Self {
245 self.workers = workers;
246 self
247 }
248
249 pub fn with_frontegg(mut self, frontegg: &FronteggAuthentication) -> Self {
250 self.frontegg = Some(frontegg.clone());
251 self
252 }
253
254 pub fn with_self_hosted_auth(mut self, self_hosted_auth: bool) -> Self {
255 self.self_hosted_auth = self_hosted_auth;
256 self
257 }
258
259 pub fn with_self_hosted_auth_internal(mut self, self_hosted_auth_internal: bool) -> Self {
260 self.self_hosted_auth_internal = self_hosted_auth_internal;
261 self
262 }
263
264 pub fn with_now(mut self, now: NowFn) -> Self {
265 self.now = now;
266 self
267 }
268
269 pub fn with_storage_usage_collection_interval(
270 mut self,
271 storage_usage_collection_interval: Duration,
272 ) -> Self {
273 self.storage_usage_collection_interval = storage_usage_collection_interval;
274 self
275 }
276
277 pub fn with_storage_usage_retention_period(
278 mut self,
279 storage_usage_retention_period: Duration,
280 ) -> Self {
281 self.storage_usage_retention_period = Some(storage_usage_retention_period);
282 self
283 }
284
285 pub fn with_default_cluster_replica_size(
286 mut self,
287 default_cluster_replica_size: String,
288 ) -> Self {
289 self.default_cluster_replica_size = default_cluster_replica_size;
290 self
291 }
292
293 pub fn with_builtin_system_cluster_replica_size(
294 mut self,
295 builtin_system_cluster_replica_size: String,
296 ) -> Self {
297 self.builtin_system_cluster_config.size = builtin_system_cluster_replica_size;
298 self
299 }
300
301 pub fn with_builtin_system_cluster_replication_factor(
302 mut self,
303 builtin_system_cluster_replication_factor: u32,
304 ) -> Self {
305 self.builtin_system_cluster_config.replication_factor =
306 builtin_system_cluster_replication_factor;
307 self
308 }
309
310 pub fn with_builtin_catalog_server_cluster_replica_size(
311 mut self,
312 builtin_catalog_server_cluster_replica_size: String,
313 ) -> Self {
314 self.builtin_catalog_server_cluster_config.size =
315 builtin_catalog_server_cluster_replica_size;
316 self
317 }
318
319 pub fn with_propagate_crashes(mut self, propagate_crashes: bool) -> Self {
320 self.propagate_crashes = propagate_crashes;
321 self
322 }
323
324 pub fn with_enable_tracing(mut self, enable_tracing: bool) -> Self {
325 self.enable_tracing = enable_tracing;
326 self
327 }
328
329 pub fn with_bootstrap_role(mut self, bootstrap_role: Option<String>) -> Self {
330 self.bootstrap_role = bootstrap_role;
331 self
332 }
333
334 pub fn with_deploy_generation(mut self, deploy_generation: u64) -> Self {
335 self.deploy_generation = deploy_generation;
336 self
337 }
338
339 pub fn with_system_parameter_default(mut self, param: String, value: String) -> Self {
340 self.system_parameter_defaults.insert(param, value);
341 self
342 }
343
344 pub fn with_internal_console_redirect_url(
345 mut self,
346 internal_console_redirect_url: Option<String>,
347 ) -> Self {
348 self.internal_console_redirect_url = internal_console_redirect_url;
349 self
350 }
351
352 pub fn with_metrics_registry(mut self, registry: MetricsRegistry) -> Self {
353 self.metrics_registry = Some(registry);
354 self
355 }
356
357 pub fn with_code_version(mut self, version: semver::Version) -> Self {
358 self.code_version = version;
359 self
360 }
361
362 pub fn with_capture(mut self, storage: SharedStorage) -> Self {
363 self.capture = Some(storage);
364 self
365 }
366}
367
/// Network listeners for a test server.
///
/// The listeners are bound (via `crate::Listeners::bind_any_local`) before the server
/// is configured, so their addresses can be used while building its config.
pub struct Listeners {
    /// The bound listeners.
    pub inner: crate::Listeners,
}
371
372impl Listeners {
373 pub async fn new() -> Result<Listeners, anyhow::Error> {
374 let inner = crate::Listeners::bind_any_local().await?;
375 Ok(Listeners { inner })
376 }
377
378 pub async fn serve(self, config: TestHarness) -> Result<TestServer, anyhow::Error> {
379 self.serve_with_trigger(config, mz_server_core::cert_reload_never_reload())
380 .await
381 }
382
383 pub async fn serve_with_trigger(
384 self,
385 config: TestHarness,
386 tls_reload_certs: ReloadTrigger,
387 ) -> Result<TestServer, anyhow::Error> {
388 let (data_directory, temp_dir) = match config.data_directory {
389 None => {
390 let temp_dir = tempfile::tempdir()?;
395 (temp_dir.path().to_path_buf(), Some(temp_dir))
396 }
397 Some(data_directory) => (data_directory, None),
398 };
399 let scratch_dir = tempfile::tempdir()?;
400 let (consensus_uri, timestamp_oracle_url) = {
401 let seed = config.seed;
402 let cockroach_url = env::var("COCKROACH_URL")
403 .map_err(|_| anyhow!("COCKROACH_URL environment variable is not set"))?;
404 let (client, conn) = tokio_postgres::connect(&cockroach_url, NoTls).await?;
405 mz_ore::task::spawn(|| "startup-postgres-conn", async move {
406 if let Err(err) = conn.await {
407 panic!("connection error: {}", err);
408 };
409 });
410 client
411 .batch_execute(&format!(
412 "CREATE SCHEMA IF NOT EXISTS consensus_{seed};
413 CREATE SCHEMA IF NOT EXISTS tsoracle_{seed};"
414 ))
415 .await?;
416 (
417 format!("{cockroach_url}?options=--search_path=consensus_{seed}")
418 .parse()
419 .expect("invalid consensus URI"),
420 format!("{cockroach_url}?options=--search_path=tsoracle_{seed}")
421 .parse()
422 .expect("invalid timestamp oracle URI"),
423 )
424 };
425 let metrics_registry = config.metrics_registry.unwrap_or_else(MetricsRegistry::new);
426 let orchestrator = ProcessOrchestrator::new(ProcessOrchestratorConfig {
427 image_dir: env::current_exe()?
428 .parent()
429 .unwrap()
430 .parent()
431 .unwrap()
432 .to_path_buf(),
433 suppress_output: false,
434 environment_id: config.environment_id.to_string(),
435 secrets_dir: data_directory.join("secrets"),
436 command_wrapper: vec![],
437 propagate_crashes: config.propagate_crashes,
438 tcp_proxy: None,
439 scratch_directory: scratch_dir.path().to_path_buf(),
440 })
441 .await?;
442 let orchestrator = Arc::new(orchestrator);
443 let persist_now = SYSTEM_TIME.clone();
446 let dyncfgs = mz_dyncfgs::all_dyncfgs();
447
448 let mut updates = ConfigUpdates::default();
449 updates.add(&CONSENSUS_CONNECTION_POOL_MAX_SIZE, 1);
452 updates.apply(&dyncfgs);
453
454 let mut persist_cfg = PersistConfig::new(&crate::BUILD_INFO, persist_now.clone(), dyncfgs);
455 persist_cfg.build_version = config.code_version;
456 persist_cfg.set_rollup_threshold(5);
458
459 let persist_pubsub_server = PersistGrpcPubSubServer::new(&persist_cfg, &metrics_registry);
460 let persist_pubsub_client = persist_pubsub_server.new_same_process_connection();
461 let persist_pubsub_tcp_listener =
462 TcpListener::bind(SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), 0))
463 .await
464 .expect("pubsub addr binding");
465 let persist_pubsub_server_port = persist_pubsub_tcp_listener
466 .local_addr()
467 .expect("pubsub addr has local addr")
468 .port();
469
470 mz_ore::task::spawn(|| "persist_pubsub_server", async move {
472 persist_pubsub_server
473 .serve_with_stream(TcpListenerStream::new(persist_pubsub_tcp_listener))
474 .await
475 .expect("success")
476 });
477 let persist_clients =
478 PersistClientCache::new(persist_cfg, &metrics_registry, |_, _| persist_pubsub_client);
479 let persist_clients = Arc::new(persist_clients);
480
481 let secrets_controller = Arc::clone(&orchestrator);
482 let connection_context = ConnectionContext::for_tests(orchestrator.reader());
483 let orchestrator = Arc::new(TracingOrchestrator::new(
484 orchestrator,
485 config.orchestrator_tracing_cli_args,
486 ));
487 let (tracing_handle, tracing_guard) = if config.enable_tracing {
488 let config = TracingConfig::<fn(&tracing::Metadata) -> sentry_tracing::EventFilter> {
489 service_name: "environmentd",
490 stderr_log: StderrLogConfig {
491 format: StderrLogFormat::Json,
492 filter: EnvFilter::default(),
493 },
494 opentelemetry: Some(OpenTelemetryConfig {
495 endpoint: "http://fake_address_for_testing:8080".to_string(),
496 headers: http::HeaderMap::new(),
497 filter: EnvFilter::default().add_directive(Level::DEBUG.into()),
498 resource: opentelemetry_sdk::resource::Resource::default(),
499 max_batch_queue_size: 2048,
500 max_export_batch_size: 512,
501 max_concurrent_exports: 1,
502 batch_scheduled_delay: Duration::from_millis(5000),
503 max_export_timeout: Duration::from_secs(30),
504 }),
505 tokio_console: None,
506 sentry: None,
507 build_version: crate::BUILD_INFO.version,
508 build_sha: crate::BUILD_INFO.sha,
509 registry: metrics_registry.clone(),
510 capture: config.capture,
511 };
512 let (tracing_handle, tracing_guard) = mz_ore::tracing::configure(config).await?;
513 (tracing_handle, Some(tracing_guard))
514 } else {
515 (TracingHandle::disabled(), None)
516 };
517 let host_name = format!("localhost:{}", self.inner.http_local_addr().port());
518 let catalog_config = CatalogConfig {
519 persist_clients: Arc::clone(&persist_clients),
520 metrics: Arc::new(mz_catalog::durable::Metrics::new(&MetricsRegistry::new())),
521 };
522
523 let inner = self
524 .inner
525 .serve(crate::Config {
526 catalog_config,
527 timestamp_oracle_url: Some(timestamp_oracle_url),
528 controller: ControllerConfig {
529 build_info: &crate::BUILD_INFO,
530 orchestrator,
531 clusterd_image: "clusterd".into(),
532 init_container_image: None,
533 deploy_generation: config.deploy_generation,
534 persist_location: PersistLocation {
535 blob_uri: format!("file://{}/persist/blob", data_directory.display())
536 .parse()
537 .expect("invalid blob URI"),
538 consensus_uri,
539 },
540 persist_clients,
541 now: config.now.clone(),
542 metrics_registry: metrics_registry.clone(),
543 persist_pubsub_url: format!("http://localhost:{}", persist_pubsub_server_port),
544 secrets_args: mz_service::secrets::SecretsReaderCliArgs {
545 secrets_reader: mz_service::secrets::SecretsControllerKind::LocalFile,
546 secrets_reader_local_file_dir: Some(data_directory.join("secrets")),
547 secrets_reader_kubernetes_context: None,
548 secrets_reader_aws_prefix: None,
549 secrets_reader_name_prefix: None,
550 },
551 connection_context,
552 },
553 secrets_controller,
554 cloud_resource_controller: None,
555 tls: config.tls,
556 frontegg: config.frontegg,
557 self_hosted_auth: config.self_hosted_auth,
558 self_hosted_auth_internal: config.self_hosted_auth_internal,
559 unsafe_mode: config.unsafe_mode,
560 all_features: false,
561 metrics_registry: metrics_registry.clone(),
562 now: config.now,
563 environment_id: config.environment_id,
564 cors_allowed_origin: AllowOrigin::list([]),
565 cluster_replica_sizes: ClusterReplicaSizeMap::for_tests(),
566 bootstrap_default_cluster_replica_size: config.default_cluster_replica_size,
567 bootstrap_default_cluster_replication_factor: config
568 .default_cluster_replication_factor,
569 bootstrap_builtin_system_cluster_config: config.builtin_system_cluster_config,
570 bootstrap_builtin_catalog_server_cluster_config: config
571 .builtin_catalog_server_cluster_config,
572 bootstrap_builtin_probe_cluster_config: config.builtin_probe_cluster_config,
573 bootstrap_builtin_support_cluster_config: config.builtin_support_cluster_config,
574 bootstrap_builtin_analytics_cluster_config: config.builtin_analytics_cluster_config,
575 system_parameter_defaults: config.system_parameter_defaults,
576 availability_zones: Default::default(),
577 tracing_handle,
578 storage_usage_collection_interval: config.storage_usage_collection_interval,
579 storage_usage_retention_period: config.storage_usage_retention_period,
580 segment_api_key: None,
581 segment_client_side: false,
582 egress_addresses: vec![],
583 aws_account_id: None,
584 aws_privatelink_availability_zones: None,
585 launchdarkly_sdk_key: None,
586 launchdarkly_key_map: Default::default(),
587 config_sync_timeout: Duration::from_secs(30),
588 config_sync_loop_interval: None,
589 bootstrap_role: config.bootstrap_role,
590 http_host_name: Some(host_name),
591 internal_console_redirect_url: config.internal_console_redirect_url,
592 tls_reload_certs,
593 helm_chart_version: None,
594 license_key: ValidatedLicenseKey::for_tests(),
595 })
596 .await?;
597
598 Ok(TestServer {
599 inner,
600 metrics_registry,
601 _temp_dir: temp_dir,
602 _tracing_guard: tracing_guard,
603 })
604 }
605}
606
/// A running test server, as returned by [`TestHarness::start`] and its variants.
pub struct TestServer {
    /// The underlying server.
    pub inner: crate::Server,
    /// The metrics registry passed to the server, controller and persist.
    pub metrics_registry: MetricsRegistry,
    /// Temporary data directory, present only when no data directory was configured.
    /// Held so the directory is not deleted while the server uses it.
    _temp_dir: Option<TempDir>,
    /// Tracing guard, present only when tracing was enabled.
    _tracing_guard: Option<TracingGuard>,
}
614
615impl TestServer {
616 pub fn connect(&self) -> ConnectBuilder<'_, postgres::NoTls, NoHandle> {
617 ConnectBuilder::new(self).no_tls()
618 }
619
620 pub async fn enable_feature_flags(&self, flags: &[&'static str]) {
621 let internal_client = self.connect().internal().await.unwrap();
622
623 for flag in flags {
624 internal_client
625 .batch_execute(&format!("ALTER SYSTEM SET {} = true;", flag))
626 .await
627 .unwrap();
628 }
629 }
630
631 pub fn ws_addr(&self) -> Url {
632 Url::parse(&format!(
633 "ws://{}/api/experimental/sql",
634 self.inner.http_local_addr()
635 ))
636 .unwrap()
637 }
638
639 pub fn internal_ws_addr(&self) -> Url {
640 Url::parse(&format!(
641 "ws://{}/api/experimental/sql",
642 self.inner.internal_http_local_addr()
643 ))
644 .unwrap()
645 }
646}
647
/// Builder for a `tokio_postgres` connection to a [`TestServer`].
///
/// It implements `IntoFuture`, so `.await` it to connect. The type parameters are:
/// - `T`: the TLS connector type, `()` until one is chosen.
/// - `H`: a marker ([`NoHandle`] or [`WithHandle`]) selecting whether the connection
///   task's handle is returned alongside the client.
pub struct ConnectBuilder<'s, T, H> {
    /// Server to connect to; used to look up listener ports.
    server: &'s TestServer,

    /// Connection parameters. Must not contain a port; `port` is applied on connect.
    pg_config: tokio_postgres::Config,
    /// Port to connect to.
    port: u16,
    /// TLS connector passed to `tokio_postgres::Config::connect`.
    tls: T,

    /// Optional callback invoked for every notice received on the connection.
    notice_callback: Option<Box<dyn FnMut(tokio_postgres::error::DbError) + Send + 'static>>,

    /// Output-selection marker; see [`IncludeHandle`].
    _with_handle: H,
}
668
669impl<'s> ConnectBuilder<'s, (), NoHandle> {
670 fn new(server: &'s TestServer) -> Self {
671 let mut pg_config = tokio_postgres::Config::new();
672 pg_config
673 .host(&Ipv4Addr::LOCALHOST.to_string())
674 .user("materialize")
675 .options("--welcome_message=off")
676 .application_name("environmentd_test_framework");
677
678 ConnectBuilder {
679 server,
680 pg_config,
681 port: server.inner.sql_local_addr().port(),
682 tls: (),
683 notice_callback: None,
684 _with_handle: NoHandle,
685 }
686 }
687}
688
689impl<'s, T, H> ConnectBuilder<'s, T, H> {
690 pub fn no_tls(self) -> ConnectBuilder<'s, postgres::NoTls, H> {
694 ConnectBuilder {
695 server: self.server,
696 pg_config: self.pg_config,
697 port: self.port,
698 tls: postgres::NoTls,
699 notice_callback: self.notice_callback,
700 _with_handle: self._with_handle,
701 }
702 }
703
704 pub fn with_tls<Tls>(self, tls: Tls) -> ConnectBuilder<'s, Tls, H>
706 where
707 Tls: MakeTlsConnect<Socket> + Send + 'static,
708 Tls::TlsConnect: Send,
709 Tls::Stream: Send,
710 <Tls::TlsConnect as TlsConnect<Socket>>::Future: Send,
711 {
712 ConnectBuilder {
713 server: self.server,
714 pg_config: self.pg_config,
715 port: self.port,
716 tls,
717 notice_callback: self.notice_callback,
718 _with_handle: self._with_handle,
719 }
720 }
721
722 pub fn with_config(mut self, pg_config: tokio_postgres::Config) -> Self {
724 self.pg_config = pg_config;
725 self
726 }
727
728 pub fn ssl_mode(mut self, mode: SslMode) -> Self {
730 self.pg_config.ssl_mode(mode);
731 self
732 }
733
734 pub fn user(mut self, user: &str) -> Self {
736 self.pg_config.user(user);
737 self
738 }
739
740 pub fn password(mut self, password: &str) -> Self {
742 self.pg_config.password(password);
743 self
744 }
745
746 pub fn application_name(mut self, application_name: &str) -> Self {
748 self.pg_config.application_name(application_name);
749 self
750 }
751
752 pub fn dbname(mut self, dbname: &str) -> Self {
754 self.pg_config.dbname(dbname);
755 self
756 }
757
758 pub fn options(mut self, options: &str) -> Self {
760 self.pg_config.options(options);
761 self
762 }
763
764 pub fn internal(mut self) -> Self {
769 self.port = self.server.inner.internal_sql_local_addr().port();
770 self.pg_config.user(mz_sql::session::user::SYSTEM_USER_NAME);
771 self
772 }
773
774 pub fn balancer(mut self) -> Self {
779 self.port = self.server.inner.sql_local_addr().port();
780 self.pg_config.user("materialize");
781 self
782 }
783
784 pub fn notice_callback(self, callback: impl FnMut(DbError) + Send + 'static) -> Self {
786 ConnectBuilder {
787 notice_callback: Some(Box::new(callback)),
788 ..self
789 }
790 }
791
792 pub fn with_handle(self) -> ConnectBuilder<'s, T, WithHandle> {
795 ConnectBuilder {
796 server: self.server,
797 pg_config: self.pg_config,
798 port: self.port,
799 tls: self.tls,
800 notice_callback: self.notice_callback,
801 _with_handle: WithHandle,
802 }
803 }
804
805 pub fn as_pg_config(&self) -> &tokio_postgres::Config {
807 &self.pg_config
808 }
809}
810
/// Selects what awaiting a [`ConnectBuilder`] yields, given the connected client and
/// the handle of the task driving its connection.
pub trait IncludeHandle: Send {
    /// The value produced by awaiting the builder.
    type Output;
    /// Builds [`Self::Output`] from the client and the connection task's handle.
    fn transform_result(
        client: tokio_postgres::Client,
        handle: mz_ore::task::JoinHandle<()>,
    ) -> Self::Output;
}
820
821pub struct NoHandle;
824impl IncludeHandle for NoHandle {
825 type Output = tokio_postgres::Client;
826 fn transform_result(
827 client: tokio_postgres::Client,
828 _handle: mz_ore::task::JoinHandle<()>,
829 ) -> Self::Output {
830 client
831 }
832}
833
834pub struct WithHandle;
837impl IncludeHandle for WithHandle {
838 type Output = (tokio_postgres::Client, mz_ore::task::JoinHandle<()>);
839 fn transform_result(
840 client: tokio_postgres::Client,
841 handle: mz_ore::task::JoinHandle<()>,
842 ) -> Self::Output {
843 (client, handle)
844 }
845}
846
/// Awaiting a [`ConnectBuilder`] connects and spawns a task that drives the
/// connection. It yields `H::Output`: the client, optionally with that task's handle.
impl<'s, T, H> IntoFuture for ConnectBuilder<'s, T, H>
where
    T: MakeTlsConnect<Socket> + Send + 'static,
    T::TlsConnect: Send,
    T::Stream: Send,
    <T::TlsConnect as TlsConnect<Socket>>::Future: Send,
    H: IncludeHandle,
{
    type Output = Result<H::Output, postgres::Error>;
    type IntoFuture = BoxFuture<'static, Self::Output>;

    fn into_future(mut self) -> Self::IntoFuture {
        Box::pin(async move {
            // The port always comes from `self.port`. A port already present in
            // `pg_config` (e.g. one supplied through `with_config`) is rejected
            // instead of being combined with it.
            assert!(
                self.pg_config.get_ports().is_empty(),
                "specifying multiple ports is not supported"
            );
            self.pg_config.port(self.port);

            let (client, mut conn) = self.pg_config.connect(self.tls).await?;
            let mut notice_callback = self.notice_callback.take();

            // Drive the connection on its own task, forwarding notices to the
            // callback (if any) and logging every other message.
            let handle = task::spawn(|| "connect", async move {
                while let Some(msg) = std::future::poll_fn(|cx| conn.poll_message(cx)).await {
                    match msg {
                        Ok(AsyncMessage::Notice(notice)) => {
                            if let Some(callback) = notice_callback.as_mut() {
                                callback(notice);
                            }
                        }
                        Ok(msg) => {
                            tracing::debug!(?msg, "Dropping message from database");
                        }
                        Err(e) => {
                            // A connection error is logged (not panicked on) and
                            // ends the task.
                            tracing::info!("connection error: {e}");
                            break;
                        }
                    }
                }
                tracing::info!("connection closed");
            });

            let output = H::transform_result(client, handle);
            Ok(output)
        })
    }
}
898
/// A [`TestServer`] bundled with the Tokio runtime it runs on, for tests that use
/// blocking clients; see [`TestHarness::start_blocking`].
pub struct TestServerWithRuntime {
    // Fields drop in declaration order, so `server` is dropped before `runtime`.
    // Presumably intentional (the server shuts down while its runtime still
    // exists); keep this order.
    server: TestServer,
    runtime: Arc<Runtime>,
}
907
908impl TestServerWithRuntime {
909 pub fn runtime(&self) -> &Arc<Runtime> {
913 &self.runtime
914 }
915
916 pub fn inner(&self) -> &crate::Server {
918 &self.server.inner
919 }
920
921 pub fn connect<T>(&self, tls: T) -> Result<postgres::Client, postgres::Error>
923 where
924 T: MakeTlsConnect<Socket> + Send + 'static,
925 T::TlsConnect: Send,
926 T::Stream: Send,
927 <T::TlsConnect as TlsConnect<Socket>>::Future: Send,
928 {
929 self.pg_config().connect(tls)
930 }
931
932 pub fn connect_internal<T>(&self, tls: T) -> Result<postgres::Client, anyhow::Error>
934 where
935 T: MakeTlsConnect<Socket> + Send + 'static,
936 T::TlsConnect: Send,
937 T::Stream: Send,
938 <T::TlsConnect as TlsConnect<Socket>>::Future: Send,
939 {
940 Ok(self.pg_config_internal().connect(tls)?)
941 }
942
943 pub fn enable_feature_flags(&self, flags: &[&'static str]) {
945 let mut internal_client = self.connect_internal(postgres::NoTls).unwrap();
946
947 for flag in flags {
948 internal_client
949 .batch_execute(&format!("ALTER SYSTEM SET {} = true;", flag))
950 .unwrap();
951 }
952 }
953
954 pub fn pg_config(&self) -> postgres::Config {
957 let local_addr = self.server.inner.sql_local_addr();
958 let mut config = postgres::Config::new();
959 config
960 .host(&Ipv4Addr::LOCALHOST.to_string())
961 .port(local_addr.port())
962 .user("materialize")
963 .options("--welcome_message=off");
964 config
965 }
966
967 pub fn pg_config_internal(&self) -> postgres::Config {
970 let local_addr = self.server.inner.internal_sql_local_addr();
971 let mut config = postgres::Config::new();
972 config
973 .host(&Ipv4Addr::LOCALHOST.to_string())
974 .port(local_addr.port())
975 .user("mz_system")
976 .options("--welcome_message=off");
977 config
978 }
979
980 pub fn pg_config_balancer(&self) -> postgres::Config {
983 let local_addr = self.server.inner.sql_local_addr();
984 let mut config = postgres::Config::new();
985 config
986 .host(&Ipv4Addr::LOCALHOST.to_string())
987 .port(local_addr.port())
988 .user("materialize")
989 .options("--welcome_message=off")
990 .ssl_mode(tokio_postgres::config::SslMode::Disable);
991 config
992 }
993
994 pub fn ws_addr(&self) -> Url {
995 self.server.ws_addr()
996 }
997
998 pub fn internal_ws_addr(&self) -> Url {
999 self.server.internal_ws_addr()
1000 }
1001}
1002
/// A `u64` timestamp read from a query result. Its [`FromSql`] impl decodes it from a
/// `numeric` value.
#[derive(Debug, Clone, Copy, Eq, PartialEq, Ord, PartialOrd)]
pub struct MzTimestamp(pub u64);
1005
1006impl<'a> FromSql<'a> for MzTimestamp {
1007 fn from_sql(ty: &Type, raw: &'a [u8]) -> Result<MzTimestamp, Box<dyn Error + Sync + Send>> {
1008 let n = mz_pgrepr::Numeric::from_sql(ty, raw)?;
1009 Ok(MzTimestamp(u64::try_from(n.0.0)?))
1010 }
1011
1012 fn accepts(ty: &Type) -> bool {
1013 mz_pgrepr::Numeric::accepts(ty)
1014 }
1015}
1016
/// Extracts the server-reported [`DbError`] from an error (or an error `Result`).
pub trait PostgresErrorExt {
    /// Returns the underlying [`DbError`].
    ///
    /// # Panics
    /// Panics when no `DbError` is available. Examples: the value is an `Ok` result,
    /// or the error's source is not a `DbError`.
    fn unwrap_db_error(self) -> DbError;
}
1020
1021impl PostgresErrorExt for postgres::Error {
1022 fn unwrap_db_error(self) -> DbError {
1023 match self.source().and_then(|e| e.downcast_ref::<DbError>()) {
1024 Some(e) => e.clone(),
1025 None => panic!("expected DbError, but got: {:?}", self),
1026 }
1027 }
1028}
1029
1030impl<T, E> PostgresErrorExt for Result<T, E>
1031where
1032 E: PostgresErrorExt,
1033{
1034 fn unwrap_db_error(self) -> DbError {
1035 match self {
1036 Ok(_) => panic!("expected Err(DbError), but got Ok(_)"),
1037 Err(e) => e.unwrap_db_error(),
1038 }
1039 }
1040}
1041
1042pub async fn insert_with_deterministic_timestamps(
1046 table: &'static str,
1047 values: &'static str,
1048 server: &TestServer,
1049 now: Arc<std::sync::Mutex<EpochMillis>>,
1050) -> Result<(), Box<dyn Error>> {
1051 let client_write = server.connect().await?;
1052 let client_read = server.connect().await?;
1053
1054 let mut current_ts = get_explain_timestamp(table, &client_read).await;
1055
1056 let insert_query = format!("INSERT INTO {table} VALUES {values}");
1057
1058 let write_future = client_write.execute(&insert_query, &[]);
1059 let timestamp_interval = tokio::time::interval(Duration::from_millis(1));
1060
1061 let mut write_future = std::pin::pin!(write_future);
1062 let mut timestamp_interval = std::pin::pin!(timestamp_interval);
1063
1064 loop {
1067 tokio::select! {
1068 _ = (&mut write_future) => return Ok(()),
1069 _ = timestamp_interval.tick() => {
1070 current_ts += 1;
1071 *now.lock().expect("lock poisoned") = current_ts;
1072 }
1073 };
1074 }
1075}
1076
1077pub async fn get_explain_timestamp(from_suffix: &str, client: &Client) -> EpochMillis {
1078 try_get_explain_timestamp(from_suffix, client)
1079 .await
1080 .unwrap()
1081}
1082
1083pub async fn try_get_explain_timestamp(
1084 from_suffix: &str,
1085 client: &Client,
1086) -> Result<EpochMillis, anyhow::Error> {
1087 let det = get_explain_timestamp_determination(from_suffix, client).await?;
1088 let ts = det.determination.timestamp_context.timestamp_or_default();
1089 Ok(ts.into())
1090}
1091
1092pub async fn get_explain_timestamp_determination(
1093 from_suffix: &str,
1094 client: &Client,
1095) -> Result<TimestampExplanation<mz_repr::Timestamp>, anyhow::Error> {
1096 let row = client
1097 .query_one(
1098 &format!("EXPLAIN TIMESTAMP AS JSON FOR SELECT * FROM {from_suffix}"),
1099 &[],
1100 )
1101 .await?;
1102 let explain: String = row.get(0);
1103 Ok(serde_json::from_str(&explain).unwrap())
1104}
1105
1106pub async fn create_postgres_source_with_table<'a>(
1114 server: &TestServer,
1115 mz_client: &Client,
1116 table_name: &str,
1117 table_schema: &str,
1118 source_name: &str,
1119) -> (
1120 Client,
1121 impl FnOnce(&'a Client, &'a Client) -> LocalBoxFuture<'a, ()>,
1122) {
1123 server
1124 .enable_feature_flags(&["enable_create_table_from_source"])
1125 .await;
1126
1127 let postgres_url = env::var("POSTGRES_URL")
1128 .map_err(|_| anyhow!("POSTGRES_URL environment variable is not set"))
1129 .unwrap();
1130
1131 let (pg_client, connection) = tokio_postgres::connect(&postgres_url, postgres::NoTls)
1132 .await
1133 .unwrap();
1134
1135 let pg_config: tokio_postgres::Config = postgres_url.parse().unwrap();
1136 let user = pg_config.get_user().unwrap_or("postgres");
1137 let db_name = pg_config.get_dbname().unwrap_or(user);
1138 let ports = pg_config.get_ports();
1139 let port = if ports.is_empty() { 5432 } else { ports[0] };
1140 let hosts = pg_config.get_hosts();
1141 let host = if hosts.is_empty() {
1142 "localhost".to_string()
1143 } else {
1144 match &hosts[0] {
1145 Host::Tcp(host) => host.to_string(),
1146 Host::Unix(host) => host.to_str().unwrap().to_string(),
1147 }
1148 };
1149 let password = pg_config.get_password();
1150
1151 mz_ore::task::spawn(|| "postgres-source-connection", async move {
1152 if let Err(e) = connection.await {
1153 panic!("connection error: {}", e);
1154 }
1155 });
1156
1157 let _ = pg_client
1159 .execute(&format!("DROP TABLE IF EXISTS {table_name};"), &[])
1160 .await
1161 .unwrap();
1162 let _ = pg_client
1163 .execute(&format!("DROP PUBLICATION IF EXISTS {source_name};"), &[])
1164 .await
1165 .unwrap();
1166 let _ = pg_client
1167 .execute(&format!("CREATE TABLE {table_name} {table_schema};"), &[])
1168 .await
1169 .unwrap();
1170 let _ = pg_client
1171 .execute(
1172 &format!("ALTER TABLE {table_name} REPLICA IDENTITY FULL;"),
1173 &[],
1174 )
1175 .await
1176 .unwrap();
1177 let _ = pg_client
1178 .execute(
1179 &format!("CREATE PUBLICATION {source_name} FOR TABLE {table_name};"),
1180 &[],
1181 )
1182 .await
1183 .unwrap();
1184
1185 let mut connection_str = format!("HOST '{host}', PORT {port}, USER {user}, DATABASE {db_name}");
1187 if let Some(password) = password {
1188 let password = std::str::from_utf8(password).unwrap();
1189 mz_client
1190 .batch_execute(&format!("CREATE SECRET s AS '{password}'"))
1191 .await
1192 .unwrap();
1193 connection_str = format!("{connection_str}, PASSWORD SECRET s");
1194 }
1195 mz_client
1196 .batch_execute(&format!(
1197 "CREATE CONNECTION pgconn TO POSTGRES ({connection_str})"
1198 ))
1199 .await
1200 .unwrap();
1201 mz_client
1202 .batch_execute(&format!(
1203 "CREATE SOURCE {source_name}
1204 FROM POSTGRES
1205 CONNECTION pgconn
1206 (PUBLICATION '{source_name}')"
1207 ))
1208 .await
1209 .unwrap();
1210 mz_client
1211 .batch_execute(&format!(
1212 "CREATE TABLE {table_name}
1213 FROM SOURCE {source_name}
1214 (REFERENCE {table_name});"
1215 ))
1216 .await
1217 .unwrap();
1218
1219 let table_name = table_name.to_string();
1220 let source_name = source_name.to_string();
1221 (
1222 pg_client,
1223 move |mz_client: &'a Client, pg_client: &'a Client| {
1224 let f: Pin<Box<dyn Future<Output = ()> + 'a>> = Box::pin(async move {
1225 mz_client
1226 .batch_execute(&format!("DROP SOURCE {source_name} CASCADE;"))
1227 .await
1228 .unwrap();
1229 mz_client
1230 .batch_execute("DROP CONNECTION pgconn;")
1231 .await
1232 .unwrap();
1233
1234 let _ = pg_client
1235 .execute(&format!("DROP PUBLICATION {source_name};"), &[])
1236 .await
1237 .unwrap();
1238 let _ = pg_client
1239 .execute(&format!("DROP TABLE {table_name};"), &[])
1240 .await
1241 .unwrap();
1242 });
1243 f
1244 },
1245 )
1246}
1247
1248pub async fn wait_for_pg_table_population(mz_client: &Client, view_name: &str, source_rows: i64) {
1249 let current_isolation = mz_client
1250 .query_one("SHOW transaction_isolation", &[])
1251 .await
1252 .unwrap()
1253 .get::<_, String>(0);
1254 mz_client
1255 .batch_execute("SET transaction_isolation = SERIALIZABLE")
1256 .await
1257 .unwrap();
1258 Retry::default()
1259 .retry_async(|_| async move {
1260 let rows = mz_client
1261 .query_one(&format!("SELECT COUNT(*) FROM {view_name};"), &[])
1262 .await
1263 .unwrap()
1264 .get::<_, i64>(0);
1265 if rows == source_rows {
1266 Ok(())
1267 } else {
1268 Err(format!(
1269 "Waiting for {source_rows} row to be ingested. Currently at {rows}."
1270 ))
1271 }
1272 })
1273 .await
1274 .unwrap();
1275 mz_client
1276 .batch_execute(&format!(
1277 "SET transaction_isolation = '{current_isolation}'"
1278 ))
1279 .await
1280 .unwrap();
1281}
1282
1283pub fn auth_with_ws(
1285 ws: &mut WebSocket<MaybeTlsStream<TcpStream>>,
1286 mut options: BTreeMap<String, String>,
1287) -> Result<Vec<WebSocketResponse>, anyhow::Error> {
1288 if !options.contains_key("welcome_message") {
1289 options.insert("welcome_message".into(), "off".into());
1290 }
1291 auth_with_ws_impl(
1292 ws,
1293 Message::Text(
1294 serde_json::to_string(&WebSocketAuth::Basic {
1295 user: "materialize".into(),
1296 password: "".into(),
1297 options,
1298 })
1299 .unwrap(),
1300 ),
1301 )
1302}
1303
1304pub fn auth_with_ws_impl(
1305 ws: &mut WebSocket<MaybeTlsStream<TcpStream>>,
1306 auth_message: Message,
1307) -> Result<Vec<WebSocketResponse>, anyhow::Error> {
1308 ws.send(auth_message)?;
1309
1310 let mut msgs = Vec::new();
1312 loop {
1313 let resp = ws.read()?;
1314 match resp {
1315 Message::Text(msg) => {
1316 let msg: WebSocketResponse = serde_json::from_str(&msg).unwrap();
1317 match msg {
1318 WebSocketResponse::ReadyForQuery(_) => break,
1319 msg => {
1320 msgs.push(msg);
1321 }
1322 }
1323 }
1324 Message::Ping(_) => continue,
1325 Message::Close(None) => return Err(anyhow!("ws closed after auth")),
1326 Message::Close(Some(close_frame)) => {
1327 return Err(anyhow!("ws closed after auth").context(close_frame));
1328 }
1329 _ => panic!("unexpected response: {:?}", resp),
1330 }
1331 }
1332 Ok(msgs)
1333}
1334
1335pub fn make_header<H: Header>(h: H) -> HeaderMap {
1336 let mut map = HeaderMap::new();
1337 map.typed_insert(h);
1338 map
1339}
1340
1341pub fn make_pg_tls<F>(configure: F) -> MakeTlsConnector
1342where
1343 F: FnOnce(&mut SslConnectorBuilder) -> Result<(), ErrorStack>,
1344{
1345 let mut connector_builder = SslConnector::builder(SslMethod::tls()).unwrap();
1346 let options = connector_builder.options() | SslOptions::NO_TLSV1_3;
1367 connector_builder.set_options(options);
1368 configure(&mut connector_builder).unwrap();
1369 MakeTlsConnector::new(connector_builder.build())
1370}
1371
1372pub struct Ca {
1374 pub dir: TempDir,
1375 pub name: X509Name,
1376 pub cert: X509,
1377 pub pkey: PKey<Private>,
1378}
1379
1380impl Ca {
1381 fn make_ca(name: &str, parent: Option<&Ca>) -> Result<Ca, Box<dyn Error>> {
1382 let dir = tempfile::tempdir()?;
1383 let rsa = Rsa::generate(2048)?;
1384 let pkey = PKey::from_rsa(rsa)?;
1385 let name = {
1386 let mut builder = X509NameBuilder::new()?;
1387 builder.append_entry_by_nid(Nid::COMMONNAME, name)?;
1388 builder.build()
1389 };
1390 let cert = {
1391 let mut builder = X509::builder()?;
1392 builder.set_version(2)?;
1393 builder.set_pubkey(&pkey)?;
1394 builder.set_issuer_name(parent.map(|ca| &ca.name).unwrap_or(&name))?;
1395 builder.set_subject_name(&name)?;
1396 builder.set_not_before(&*Asn1Time::days_from_now(0)?)?;
1397 builder.set_not_after(&*Asn1Time::days_from_now(365)?)?;
1398 builder.append_extension(BasicConstraints::new().critical().ca().build()?)?;
1399 builder.sign(
1400 parent.map(|ca| &ca.pkey).unwrap_or(&pkey),
1401 MessageDigest::sha256(),
1402 )?;
1403 builder.build()
1404 };
1405 fs::write(dir.path().join("ca.crt"), cert.to_pem()?)?;
1406 Ok(Ca {
1407 dir,
1408 name,
1409 cert,
1410 pkey,
1411 })
1412 }
1413
1414 pub fn new_root(name: &str) -> Result<Ca, Box<dyn Error>> {
1416 Ca::make_ca(name, None)
1417 }
1418
1419 pub fn ca_cert_path(&self) -> PathBuf {
1421 self.dir.path().join("ca.crt")
1422 }
1423
1424 pub fn request_ca(&self, name: &str) -> Result<Ca, Box<dyn Error>> {
1426 Ca::make_ca(name, Some(self))
1427 }
1428
1429 pub fn request_client_cert(&self, name: &str) -> Result<(PathBuf, PathBuf), Box<dyn Error>> {
1434 self.request_cert(name, iter::empty())
1435 }
1436
1437 pub fn request_cert<I>(&self, name: &str, ips: I) -> Result<(PathBuf, PathBuf), Box<dyn Error>>
1440 where
1441 I: IntoIterator<Item = IpAddr>,
1442 {
1443 let rsa = Rsa::generate(2048)?;
1444 let pkey = PKey::from_rsa(rsa)?;
1445 let subject_name = {
1446 let mut builder = X509NameBuilder::new()?;
1447 builder.append_entry_by_nid(Nid::COMMONNAME, name)?;
1448 builder.build()
1449 };
1450 let cert = {
1451 let mut builder = X509::builder()?;
1452 builder.set_version(2)?;
1453 builder.set_pubkey(&pkey)?;
1454 builder.set_issuer_name(self.cert.subject_name())?;
1455 builder.set_subject_name(&subject_name)?;
1456 builder.set_not_before(&*Asn1Time::days_from_now(0)?)?;
1457 builder.set_not_after(&*Asn1Time::days_from_now(365)?)?;
1458 for ip in ips {
1459 builder.append_extension(
1460 SubjectAlternativeName::new()
1461 .ip(&ip.to_string())
1462 .build(&builder.x509v3_context(None, None))?,
1463 )?;
1464 }
1465 builder.sign(&self.pkey, MessageDigest::sha256())?;
1466 builder.build()
1467 };
1468 let cert_path = self.dir.path().join(Path::new(name).with_extension("crt"));
1469 let key_path = self.dir.path().join(Path::new(name).with_extension("key"));
1470 fs::write(&cert_path, cert.to_pem()?)?;
1471 fs::write(&key_path, pkey.private_key_to_pem_pkcs8()?)?;
1472 Ok((cert_path, key_path))
1473 }
1474}