1use std::collections::BTreeMap;
11use std::error::Error;
12use std::future::IntoFuture;
13use std::net::{IpAddr, Ipv4Addr, SocketAddr, TcpStream};
14use std::path::{Path, PathBuf};
15use std::pin::Pin;
16use std::str::FromStr;
17use std::sync::Arc;
18use std::sync::LazyLock;
19use std::time::Duration;
20use std::{env, fs, iter};
21
22use anyhow::anyhow;
23use futures::Future;
24use futures::future::{BoxFuture, LocalBoxFuture};
25use headers::{Header, HeaderMapExt};
26use http::Uri;
27use hyper::http::header::HeaderMap;
28use maplit::btreemap;
29use mz_adapter::TimestampExplanation;
30use mz_adapter_types::bootstrap_builtin_cluster_config::{
31 ANALYTICS_CLUSTER_DEFAULT_REPLICATION_FACTOR, BootstrapBuiltinClusterConfig,
32 CATALOG_SERVER_CLUSTER_DEFAULT_REPLICATION_FACTOR, PROBE_CLUSTER_DEFAULT_REPLICATION_FACTOR,
33 SUPPORT_CLUSTER_DEFAULT_REPLICATION_FACTOR, SYSTEM_CLUSTER_DEFAULT_REPLICATION_FACTOR,
34};
35
36use mz_auth::password::Password;
37use mz_catalog::config::ClusterReplicaSizeMap;
38use mz_controller::ControllerConfig;
39use mz_dyncfg::ConfigUpdates;
40use mz_license_keys::ValidatedLicenseKey;
41use mz_orchestrator_process::{ProcessOrchestrator, ProcessOrchestratorConfig};
42use mz_orchestrator_tracing::{TracingCliArgs, TracingOrchestrator};
43use mz_ore::metrics::MetricsRegistry;
44use mz_ore::now::{EpochMillis, NowFn, SYSTEM_TIME};
45use mz_ore::retry::Retry;
46use mz_ore::task;
47use mz_ore::tracing::{
48 OpenTelemetryConfig, StderrLogConfig, StderrLogFormat, TracingConfig, TracingHandle,
49};
50use mz_persist_client::PersistLocation;
51use mz_persist_client::cache::PersistClientCache;
52use mz_persist_client::cfg::{CONSENSUS_CONNECTION_POOL_MAX_SIZE, PersistConfig};
53use mz_persist_client::rpc::PersistGrpcPubSubServer;
54use mz_secrets::SecretsController;
55use mz_server_core::listeners::{
56 AllowedRoles, AuthenticatorKind, BaseListenerConfig, HttpRoutesEnabled,
57};
58use mz_server_core::{ReloadTrigger, TlsCertConfig};
59use mz_sql::catalog::EnvironmentId;
60use mz_storage_types::connections::ConnectionContext;
61use mz_tracing::CloneableEnvFilter;
62use openssl::asn1::Asn1Time;
63use openssl::error::ErrorStack;
64use openssl::hash::MessageDigest;
65use openssl::nid::Nid;
66use openssl::pkey::{PKey, Private};
67use openssl::rsa::Rsa;
68use openssl::ssl::{SslConnector, SslConnectorBuilder, SslMethod, SslOptions};
69use openssl::x509::extension::{BasicConstraints, SubjectAlternativeName};
70use openssl::x509::{X509, X509Name, X509NameBuilder};
71use postgres::error::DbError;
72use postgres::tls::{MakeTlsConnect, TlsConnect};
73use postgres::types::{FromSql, Type};
74use postgres::{NoTls, Socket};
75use postgres_openssl::MakeTlsConnector;
76use tempfile::TempDir;
77use tokio::net::TcpListener;
78use tokio::runtime::Runtime;
79use tokio_postgres::config::{Host, SslMode};
80use tokio_postgres::{AsyncMessage, Client};
81use tokio_stream::wrappers::TcpListenerStream;
82use tower_http::cors::AllowOrigin;
83use tracing::Level;
84use tracing_capture::SharedStorage;
85use tracing_subscriber::EnvFilter;
86use tungstenite::stream::MaybeTlsStream;
87use tungstenite::{Message, WebSocket};
88
89use crate::{
90 CatalogConfig, FronteggAuthenticator, HttpListenerConfig, ListenersConfig, SqlListenerConfig,
91 WebSocketAuth, WebSocketResponse,
92};
93
/// Kafka bootstrap address used by tests.
///
/// Read once, on first access, from the `KAFKA_ADDRS` environment variable;
/// falls back to `localhost:9092` when the variable cannot be read.
pub static KAFKA_ADDRS: LazyLock<String> = LazyLock::new(|| match env::var("KAFKA_ADDRS") {
    Ok(addrs) => addrs,
    Err(_) => String::from("localhost:9092"),
});
96
97#[derive(Clone)]
99pub struct TestHarness {
100 data_directory: Option<PathBuf>,
101 tls: Option<TlsCertConfig>,
102 frontegg: Option<FronteggAuthenticator>,
103 external_login_password_mz_system: Option<Password>,
104 listeners_config: ListenersConfig,
105 unsafe_mode: bool,
106 workers: usize,
107 now: NowFn,
108 seed: u32,
109 storage_usage_collection_interval: Duration,
110 storage_usage_retention_period: Option<Duration>,
111 default_cluster_replica_size: String,
112 default_cluster_replication_factor: u32,
113 builtin_system_cluster_config: BootstrapBuiltinClusterConfig,
114 builtin_catalog_server_cluster_config: BootstrapBuiltinClusterConfig,
115 builtin_probe_cluster_config: BootstrapBuiltinClusterConfig,
116 builtin_support_cluster_config: BootstrapBuiltinClusterConfig,
117 builtin_analytics_cluster_config: BootstrapBuiltinClusterConfig,
118
119 propagate_crashes: bool,
120 enable_tracing: bool,
121 orchestrator_tracing_cli_args: TracingCliArgs,
124 bootstrap_role: Option<String>,
125 deploy_generation: u64,
126 system_parameter_defaults: BTreeMap<String, String>,
127 internal_console_redirect_url: Option<String>,
128 metrics_registry: Option<MetricsRegistry>,
129 code_version: semver::Version,
130 capture: Option<SharedStorage>,
131 pub environment_id: EnvironmentId,
132}
133
134impl Default for TestHarness {
135 fn default() -> TestHarness {
136 TestHarness {
137 data_directory: None,
138 tls: None,
139 frontegg: None,
140 external_login_password_mz_system: None,
141 listeners_config: ListenersConfig {
142 sql: btreemap![
143 "external".to_owned() => SqlListenerConfig {
144 addr: SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), 0),
145 authenticator_kind: AuthenticatorKind::None,
146 allowed_roles: AllowedRoles::Normal,
147 enable_tls: false,
148 },
149 "internal".to_owned() => SqlListenerConfig {
150 addr: SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), 0),
151 authenticator_kind: AuthenticatorKind::None,
152 allowed_roles: AllowedRoles::NormalAndInternal,
153 enable_tls: false,
154 },
155 ],
156 http: btreemap![
157 "external".to_owned() => HttpListenerConfig {
158 base: BaseListenerConfig {
159 addr: SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), 0),
160 authenticator_kind: AuthenticatorKind::None,
161 allowed_roles: AllowedRoles::Normal,
162 enable_tls: false,
163 },
164 routes: HttpRoutesEnabled{
165 base: true,
166 webhook: true,
167 internal: false,
168 metrics: false,
169 profiling: false,
170 },
171 },
172 "internal".to_owned() => HttpListenerConfig {
173 base: BaseListenerConfig {
174 addr: SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), 0),
175 authenticator_kind: AuthenticatorKind::None,
176 allowed_roles: AllowedRoles::NormalAndInternal,
177 enable_tls: false,
178 },
179 routes: HttpRoutesEnabled{
180 base: true,
181 webhook: true,
182 internal: true,
183 metrics: true,
184 profiling: true,
185 },
186 },
187 ],
188 },
189 unsafe_mode: false,
190 workers: 1,
191 now: SYSTEM_TIME.clone(),
192 seed: rand::random(),
193 storage_usage_collection_interval: Duration::from_secs(3600),
194 storage_usage_retention_period: None,
195 default_cluster_replica_size: "scale=1,workers=1".to_string(),
196 default_cluster_replication_factor: 1,
197 builtin_system_cluster_config: BootstrapBuiltinClusterConfig {
198 size: "scale=1,workers=1".to_string(),
199 replication_factor: SYSTEM_CLUSTER_DEFAULT_REPLICATION_FACTOR,
200 },
201 builtin_catalog_server_cluster_config: BootstrapBuiltinClusterConfig {
202 size: "scale=1,workers=1".to_string(),
203 replication_factor: CATALOG_SERVER_CLUSTER_DEFAULT_REPLICATION_FACTOR,
204 },
205 builtin_probe_cluster_config: BootstrapBuiltinClusterConfig {
206 size: "scale=1,workers=1".to_string(),
207 replication_factor: PROBE_CLUSTER_DEFAULT_REPLICATION_FACTOR,
208 },
209 builtin_support_cluster_config: BootstrapBuiltinClusterConfig {
210 size: "scale=1,workers=1".to_string(),
211 replication_factor: SUPPORT_CLUSTER_DEFAULT_REPLICATION_FACTOR,
212 },
213 builtin_analytics_cluster_config: BootstrapBuiltinClusterConfig {
214 size: "scale=1,workers=1".to_string(),
215 replication_factor: ANALYTICS_CLUSTER_DEFAULT_REPLICATION_FACTOR,
216 },
217 propagate_crashes: false,
218 enable_tracing: false,
219 bootstrap_role: Some("materialize".into()),
220 deploy_generation: 0,
221 system_parameter_defaults: BTreeMap::from([(
224 "log_filter".to_string(),
225 "error".to_string(),
226 )]),
227 internal_console_redirect_url: None,
228 metrics_registry: None,
229 orchestrator_tracing_cli_args: TracingCliArgs {
230 startup_log_filter: CloneableEnvFilter::from_str("error").expect("must parse"),
231 ..Default::default()
232 },
233 code_version: crate::BUILD_INFO.semver_version(),
234 environment_id: EnvironmentId::for_tests(),
235 capture: None,
236 }
237 }
238}
239
240impl TestHarness {
241 pub async fn start(self) -> TestServer {
245 self.try_start().await.expect("Failed to start test Server")
246 }
247
248 pub async fn start_with_trigger(self, tls_reload_certs: ReloadTrigger) -> TestServer {
250 self.try_start_with_trigger(tls_reload_certs)
251 .await
252 .expect("Failed to start test Server")
253 }
254
255 pub async fn try_start(self) -> Result<TestServer, anyhow::Error> {
257 self.try_start_with_trigger(mz_server_core::cert_reload_never_reload())
258 .await
259 }
260
261 pub async fn try_start_with_trigger(
263 self,
264 tls_reload_certs: ReloadTrigger,
265 ) -> Result<TestServer, anyhow::Error> {
266 let listeners = Listeners::new(&self).await?;
267 listeners.serve_with_trigger(self, tls_reload_certs).await
268 }
269
270 pub fn start_blocking(self) -> TestServerWithRuntime {
272 let runtime = tokio::runtime::Builder::new_multi_thread()
273 .enable_all()
274 .thread_stack_size(mz_ore::stack::STACK_SIZE)
275 .build()
276 .expect("failed to spawn runtime for test");
277 let runtime = Arc::new(runtime);
278 let server = runtime.block_on(self.start());
279 TestServerWithRuntime { runtime, server }
280 }
281
282 pub fn data_directory(mut self, data_directory: impl Into<PathBuf>) -> Self {
283 self.data_directory = Some(data_directory.into());
284 self
285 }
286
287 pub fn with_tls(mut self, cert_path: impl Into<PathBuf>, key_path: impl Into<PathBuf>) -> Self {
288 self.tls = Some(TlsCertConfig {
289 cert: cert_path.into(),
290 key: key_path.into(),
291 });
292 for (_, listener) in &mut self.listeners_config.sql {
293 listener.enable_tls = true;
294 }
295 for (_, listener) in &mut self.listeners_config.http {
296 listener.base.enable_tls = true;
297 }
298 self
299 }
300
301 pub fn unsafe_mode(mut self) -> Self {
302 self.unsafe_mode = true;
303 self
304 }
305
306 pub fn workers(mut self, workers: usize) -> Self {
307 self.workers = workers;
308 self
309 }
310
311 pub fn with_frontegg_auth(mut self, frontegg: &FronteggAuthenticator) -> Self {
312 self.frontegg = Some(frontegg.clone());
313 let enable_tls = self.tls.is_some();
314 self.listeners_config = ListenersConfig {
315 sql: btreemap! {
316 "external".to_owned() => SqlListenerConfig {
317 addr: SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), 0),
318 authenticator_kind: AuthenticatorKind::Frontegg,
319 allowed_roles: AllowedRoles::Normal,
320 enable_tls,
321 },
322 "internal".to_owned() => SqlListenerConfig {
323 addr: SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), 0),
324 authenticator_kind: AuthenticatorKind::None,
325 allowed_roles: AllowedRoles::NormalAndInternal,
326 enable_tls: false,
327 },
328 },
329 http: btreemap! {
330 "external".to_owned() => HttpListenerConfig {
331 base: BaseListenerConfig {
332 addr: SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), 0),
333 authenticator_kind: AuthenticatorKind::Frontegg,
334 allowed_roles: AllowedRoles::Normal,
335 enable_tls,
336 },
337 routes: HttpRoutesEnabled{
338 base: true,
339 webhook: true,
340 internal: false,
341 metrics: false,
342 profiling: false,
343 },
344 },
345 "internal".to_owned() => HttpListenerConfig {
346 base: BaseListenerConfig {
347 addr: SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), 0),
348 authenticator_kind: AuthenticatorKind::None,
349 allowed_roles: AllowedRoles::NormalAndInternal,
350 enable_tls: false,
351 },
352 routes: HttpRoutesEnabled{
353 base: true,
354 webhook: true,
355 internal: true,
356 metrics: true,
357 profiling: true,
358 },
359 },
360 },
361 };
362 self
363 }
364
365 pub fn with_password_auth(mut self, mz_system_password: Password) -> Self {
366 self.external_login_password_mz_system = Some(mz_system_password);
367 let enable_tls = self.tls.is_some();
368 self.listeners_config = ListenersConfig {
369 sql: btreemap! {
370 "external".to_owned() => SqlListenerConfig {
371 addr: SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), 0),
372 authenticator_kind: AuthenticatorKind::Password,
373 allowed_roles: AllowedRoles::NormalAndInternal,
374 enable_tls,
375 },
376 },
377 http: btreemap! {
378 "external".to_owned() => HttpListenerConfig {
379 base: BaseListenerConfig {
380 addr: SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), 0),
381 authenticator_kind: AuthenticatorKind::Password,
382 allowed_roles: AllowedRoles::NormalAndInternal,
383 enable_tls,
384 },
385 routes: HttpRoutesEnabled{
386 base: true,
387 webhook: true,
388 internal: true,
389 metrics: false,
390 profiling: true,
391 },
392 },
393 "metrics".to_owned() => HttpListenerConfig {
394 base: BaseListenerConfig {
395 addr: SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), 0),
396 authenticator_kind: AuthenticatorKind::None,
397 allowed_roles: AllowedRoles::NormalAndInternal,
398 enable_tls: false,
399 },
400 routes: HttpRoutesEnabled{
401 base: false,
402 webhook: false,
403 internal: false,
404 metrics: true,
405 profiling: false,
406 },
407 },
408 },
409 };
410 self
411 }
412
413 pub fn with_sasl_scram_auth(mut self, mz_system_password: Password) -> Self {
414 self.external_login_password_mz_system = Some(mz_system_password);
415 let enable_tls = self.tls.is_some();
416 self.listeners_config = ListenersConfig {
417 sql: btreemap! {
418 "external".to_owned() => SqlListenerConfig {
419 addr: SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), 0),
420 authenticator_kind: AuthenticatorKind::Sasl,
421 allowed_roles: AllowedRoles::NormalAndInternal,
422 enable_tls,
423 },
424 },
425 http: btreemap! {
426 "external".to_owned() => HttpListenerConfig {
427 base: BaseListenerConfig {
428 addr: SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), 0),
429 authenticator_kind: AuthenticatorKind::Password,
430 allowed_roles: AllowedRoles::NormalAndInternal,
431 enable_tls,
432 },
433 routes: HttpRoutesEnabled{
434 base: true,
435 webhook: true,
436 internal: true,
437 metrics: false,
438 profiling: true,
439 },
440 },
441 "metrics".to_owned() => HttpListenerConfig {
442 base: BaseListenerConfig {
443 addr: SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), 0),
444 authenticator_kind: AuthenticatorKind::None,
445 allowed_roles: AllowedRoles::NormalAndInternal,
446 enable_tls: false,
447 },
448 routes: HttpRoutesEnabled{
449 base: false,
450 webhook: false,
451 internal: false,
452 metrics: true,
453 profiling: false,
454 },
455 },
456 },
457 };
458 self
459 }
460
461 pub fn with_now(mut self, now: NowFn) -> Self {
462 self.now = now;
463 self
464 }
465
466 pub fn with_storage_usage_collection_interval(
467 mut self,
468 storage_usage_collection_interval: Duration,
469 ) -> Self {
470 self.storage_usage_collection_interval = storage_usage_collection_interval;
471 self
472 }
473
474 pub fn with_storage_usage_retention_period(
475 mut self,
476 storage_usage_retention_period: Duration,
477 ) -> Self {
478 self.storage_usage_retention_period = Some(storage_usage_retention_period);
479 self
480 }
481
482 pub fn with_default_cluster_replica_size(
483 mut self,
484 default_cluster_replica_size: String,
485 ) -> Self {
486 self.default_cluster_replica_size = default_cluster_replica_size;
487 self
488 }
489
490 pub fn with_builtin_system_cluster_replica_size(
491 mut self,
492 builtin_system_cluster_replica_size: String,
493 ) -> Self {
494 self.builtin_system_cluster_config.size = builtin_system_cluster_replica_size;
495 self
496 }
497
498 pub fn with_builtin_system_cluster_replication_factor(
499 mut self,
500 builtin_system_cluster_replication_factor: u32,
501 ) -> Self {
502 self.builtin_system_cluster_config.replication_factor =
503 builtin_system_cluster_replication_factor;
504 self
505 }
506
507 pub fn with_builtin_catalog_server_cluster_replica_size(
508 mut self,
509 builtin_catalog_server_cluster_replica_size: String,
510 ) -> Self {
511 self.builtin_catalog_server_cluster_config.size =
512 builtin_catalog_server_cluster_replica_size;
513 self
514 }
515
516 pub fn with_propagate_crashes(mut self, propagate_crashes: bool) -> Self {
517 self.propagate_crashes = propagate_crashes;
518 self
519 }
520
521 pub fn with_enable_tracing(mut self, enable_tracing: bool) -> Self {
522 self.enable_tracing = enable_tracing;
523 self
524 }
525
526 pub fn with_bootstrap_role(mut self, bootstrap_role: Option<String>) -> Self {
527 self.bootstrap_role = bootstrap_role;
528 self
529 }
530
531 pub fn with_deploy_generation(mut self, deploy_generation: u64) -> Self {
532 self.deploy_generation = deploy_generation;
533 self
534 }
535
536 pub fn with_system_parameter_default(mut self, param: String, value: String) -> Self {
537 self.system_parameter_defaults.insert(param, value);
538 self
539 }
540
541 pub fn with_internal_console_redirect_url(
542 mut self,
543 internal_console_redirect_url: Option<String>,
544 ) -> Self {
545 self.internal_console_redirect_url = internal_console_redirect_url;
546 self
547 }
548
549 pub fn with_metrics_registry(mut self, registry: MetricsRegistry) -> Self {
550 self.metrics_registry = Some(registry);
551 self
552 }
553
554 pub fn with_code_version(mut self, version: semver::Version) -> Self {
555 self.code_version = version;
556 self
557 }
558
559 pub fn with_capture(mut self, storage: SharedStorage) -> Self {
560 self.capture = Some(storage);
561 self
562 }
563}
564
565pub struct Listeners {
566 pub inner: crate::Listeners,
567}
568
569impl Listeners {
570 pub async fn new(config: &TestHarness) -> Result<Listeners, anyhow::Error> {
571 let inner = crate::Listeners::bind(config.listeners_config.clone()).await?;
572 Ok(Listeners { inner })
573 }
574
575 pub async fn serve(self, config: TestHarness) -> Result<TestServer, anyhow::Error> {
576 self.serve_with_trigger(config, mz_server_core::cert_reload_never_reload())
577 .await
578 }
579
580 pub async fn serve_with_trigger(
581 self,
582 config: TestHarness,
583 tls_reload_certs: ReloadTrigger,
584 ) -> Result<TestServer, anyhow::Error> {
585 let (data_directory, temp_dir) = match config.data_directory {
586 None => {
587 let temp_dir = tempfile::tempdir()?;
592 (temp_dir.path().to_path_buf(), Some(temp_dir))
593 }
594 Some(data_directory) => (data_directory, None),
595 };
596 let scratch_dir = tempfile::tempdir()?;
597 let (consensus_uri, timestamp_oracle_url) = {
598 let seed = config.seed;
599 let cockroach_url = env::var("METADATA_BACKEND_URL")
600 .map_err(|_| anyhow!("METADATA_BACKEND_URL environment variable is not set"))?;
601 let (client, conn) = tokio_postgres::connect(&cockroach_url, NoTls).await?;
602 mz_ore::task::spawn(|| "startup-postgres-conn", async move {
603 if let Err(err) = conn.await {
604 panic!("connection error: {}", err);
605 };
606 });
607 client
608 .batch_execute(&format!(
609 "CREATE SCHEMA IF NOT EXISTS consensus_{seed};
610 CREATE SCHEMA IF NOT EXISTS tsoracle_{seed};"
611 ))
612 .await?;
613 (
614 format!("{cockroach_url}?options=--search_path=consensus_{seed}")
615 .parse()
616 .expect("invalid consensus URI"),
617 format!("{cockroach_url}?options=--search_path=tsoracle_{seed}")
618 .parse()
619 .expect("invalid timestamp oracle URI"),
620 )
621 };
622 let metrics_registry = config.metrics_registry.unwrap_or_else(MetricsRegistry::new);
623 let orchestrator = ProcessOrchestrator::new(ProcessOrchestratorConfig {
624 image_dir: env::current_exe()?
625 .parent()
626 .unwrap()
627 .parent()
628 .unwrap()
629 .to_path_buf(),
630 suppress_output: false,
631 environment_id: config.environment_id.to_string(),
632 secrets_dir: data_directory.join("secrets"),
633 command_wrapper: vec![],
634 propagate_crashes: config.propagate_crashes,
635 tcp_proxy: None,
636 scratch_directory: scratch_dir.path().to_path_buf(),
637 })
638 .await?;
639 let orchestrator = Arc::new(orchestrator);
640 let persist_now = SYSTEM_TIME.clone();
643 let dyncfgs = mz_dyncfgs::all_dyncfgs();
644
645 let mut updates = ConfigUpdates::default();
646 updates.add(&CONSENSUS_CONNECTION_POOL_MAX_SIZE, 1);
649 updates.apply(&dyncfgs);
650
651 let mut persist_cfg = PersistConfig::new(&crate::BUILD_INFO, persist_now.clone(), dyncfgs);
652 persist_cfg.build_version = config.code_version;
653 persist_cfg.set_rollup_threshold(5);
655
656 let persist_pubsub_server = PersistGrpcPubSubServer::new(&persist_cfg, &metrics_registry);
657 let persist_pubsub_client = persist_pubsub_server.new_same_process_connection();
658 let persist_pubsub_tcp_listener =
659 TcpListener::bind(SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), 0))
660 .await
661 .expect("pubsub addr binding");
662 let persist_pubsub_server_port = persist_pubsub_tcp_listener
663 .local_addr()
664 .expect("pubsub addr has local addr")
665 .port();
666
667 mz_ore::task::spawn(|| "persist_pubsub_server", async move {
669 persist_pubsub_server
670 .serve_with_stream(TcpListenerStream::new(persist_pubsub_tcp_listener))
671 .await
672 .expect("success")
673 });
674 let persist_clients =
675 PersistClientCache::new(persist_cfg, &metrics_registry, |_, _| persist_pubsub_client);
676 let persist_clients = Arc::new(persist_clients);
677
678 let secrets_controller = Arc::clone(&orchestrator);
679 let connection_context = ConnectionContext::for_tests(orchestrator.reader());
680 let orchestrator = Arc::new(TracingOrchestrator::new(
681 orchestrator,
682 config.orchestrator_tracing_cli_args,
683 ));
684 let tracing_handle = if config.enable_tracing {
685 let config = TracingConfig::<fn(&tracing::Metadata) -> sentry_tracing::EventFilter> {
686 service_name: "environmentd",
687 stderr_log: StderrLogConfig {
688 format: StderrLogFormat::Json,
689 filter: EnvFilter::default(),
690 },
691 opentelemetry: Some(OpenTelemetryConfig {
692 endpoint: "http://fake_address_for_testing:8080".to_string(),
693 headers: http::HeaderMap::new(),
694 filter: EnvFilter::default().add_directive(Level::DEBUG.into()),
695 resource: opentelemetry_sdk::resource::Resource::builder().build(),
696 max_batch_queue_size: 2048,
697 max_export_batch_size: 512,
698 max_concurrent_exports: 1,
699 batch_scheduled_delay: Duration::from_millis(5000),
700 max_export_timeout: Duration::from_secs(30),
701 }),
702 tokio_console: None,
703 sentry: None,
704 build_version: crate::BUILD_INFO.version,
705 build_sha: crate::BUILD_INFO.sha,
706 registry: metrics_registry.clone(),
707 capture: config.capture,
708 };
709 mz_ore::tracing::configure(config).await?
710 } else {
711 TracingHandle::disabled()
712 };
713 let host_name = format!(
714 "localhost:{}",
715 self.inner.http["external"].handle.local_addr.port()
716 );
717 let catalog_config = CatalogConfig {
718 persist_clients: Arc::clone(&persist_clients),
719 metrics: Arc::new(mz_catalog::durable::Metrics::new(&MetricsRegistry::new())),
720 };
721
722 let inner = self
723 .inner
724 .serve(crate::Config {
725 catalog_config,
726 timestamp_oracle_url: Some(timestamp_oracle_url),
727 controller: ControllerConfig {
728 build_info: &crate::BUILD_INFO,
729 orchestrator,
730 clusterd_image: "clusterd".into(),
731 init_container_image: None,
732 deploy_generation: config.deploy_generation,
733 persist_location: PersistLocation {
734 blob_uri: format!("file://{}/persist/blob", data_directory.display())
735 .parse()
736 .expect("invalid blob URI"),
737 consensus_uri,
738 },
739 persist_clients,
740 now: config.now.clone(),
741 metrics_registry: metrics_registry.clone(),
742 persist_pubsub_url: format!("http://localhost:{}", persist_pubsub_server_port),
743 secrets_args: mz_service::secrets::SecretsReaderCliArgs {
744 secrets_reader: mz_service::secrets::SecretsControllerKind::LocalFile,
745 secrets_reader_local_file_dir: Some(data_directory.join("secrets")),
746 secrets_reader_kubernetes_context: None,
747 secrets_reader_aws_prefix: None,
748 secrets_reader_name_prefix: None,
749 },
750 connection_context,
751 replica_http_locator: Default::default(),
752 },
753 secrets_controller,
754 cloud_resource_controller: None,
755 tls: config.tls,
756 frontegg: config.frontegg,
757 unsafe_mode: config.unsafe_mode,
758 all_features: false,
759 metrics_registry: metrics_registry.clone(),
760 now: config.now,
761 environment_id: config.environment_id,
762 cors_allowed_origin: AllowOrigin::list([]),
763 cluster_replica_sizes: ClusterReplicaSizeMap::for_tests(),
764 bootstrap_default_cluster_replica_size: config.default_cluster_replica_size,
765 bootstrap_default_cluster_replication_factor: config
766 .default_cluster_replication_factor,
767 bootstrap_builtin_system_cluster_config: config.builtin_system_cluster_config,
768 bootstrap_builtin_catalog_server_cluster_config: config
769 .builtin_catalog_server_cluster_config,
770 bootstrap_builtin_probe_cluster_config: config.builtin_probe_cluster_config,
771 bootstrap_builtin_support_cluster_config: config.builtin_support_cluster_config,
772 bootstrap_builtin_analytics_cluster_config: config.builtin_analytics_cluster_config,
773 system_parameter_defaults: config.system_parameter_defaults,
774 availability_zones: Default::default(),
775 tracing_handle,
776 storage_usage_collection_interval: config.storage_usage_collection_interval,
777 storage_usage_retention_period: config.storage_usage_retention_period,
778 segment_api_key: None,
779 segment_client_side: false,
780 test_only_dummy_segment_client: false,
781 egress_addresses: vec![],
782 aws_account_id: None,
783 aws_privatelink_availability_zones: None,
784 launchdarkly_sdk_key: None,
785 launchdarkly_key_map: Default::default(),
786 config_sync_file_path: None,
787 config_sync_timeout: Duration::from_secs(30),
788 config_sync_loop_interval: None,
789 bootstrap_role: config.bootstrap_role,
790 http_host_name: Some(host_name),
791 internal_console_redirect_url: config.internal_console_redirect_url,
792 tls_reload_certs,
793 helm_chart_version: None,
794 license_key: ValidatedLicenseKey::for_tests(),
795 external_login_password_mz_system: config.external_login_password_mz_system,
796 force_builtin_schema_migration: None,
797 })
798 .await?;
799
800 Ok(TestServer {
801 inner,
802 metrics_registry,
803 _temp_dir: temp_dir,
804 _scratch_dir: scratch_dir,
805 })
806 }
807}
808
809pub struct TestServer {
811 pub inner: crate::Server,
812 pub metrics_registry: MetricsRegistry,
813 _temp_dir: Option<TempDir>,
815 _scratch_dir: TempDir,
816}
817
818impl TestServer {
819 pub fn connect(&self) -> ConnectBuilder<'_, postgres::NoTls, NoHandle> {
820 ConnectBuilder::new(self).no_tls()
821 }
822
823 pub async fn enable_feature_flags(&self, flags: &[&'static str]) {
824 let internal_client = self.connect().internal().await.unwrap();
825
826 for flag in flags {
827 internal_client
828 .batch_execute(&format!("ALTER SYSTEM SET {} = true;", flag))
829 .await
830 .unwrap();
831 }
832 }
833
834 pub async fn disable_feature_flags(&self, flags: &[&'static str]) {
835 let internal_client = self.connect().internal().await.unwrap();
836
837 for flag in flags {
838 internal_client
839 .batch_execute(&format!("ALTER SYSTEM SET {} = false;", flag))
840 .await
841 .unwrap();
842 }
843 }
844
845 pub fn ws_addr(&self) -> Uri {
846 format!(
847 "ws://{}/api/experimental/sql",
848 self.inner.http_listener_handles["external"].local_addr
849 )
850 .parse()
851 .unwrap()
852 }
853
854 pub fn internal_ws_addr(&self) -> Uri {
855 format!(
856 "ws://{}/api/experimental/sql",
857 self.inner.http_listener_handles["internal"].local_addr
858 )
859 .parse()
860 .unwrap()
861 }
862
863 pub fn http_local_addr(&self) -> SocketAddr {
864 self.inner.http_listener_handles["external"].local_addr
865 }
866
867 pub fn internal_http_local_addr(&self) -> SocketAddr {
868 self.inner.http_listener_handles["internal"].local_addr
869 }
870
871 pub fn sql_local_addr(&self) -> SocketAddr {
872 self.inner.sql_listener_handles["external"].local_addr
873 }
874
875 pub fn internal_sql_local_addr(&self) -> SocketAddr {
876 self.inner.sql_listener_handles["internal"].local_addr
877 }
878}
879
880pub struct ConnectBuilder<'s, T, H> {
884 server: &'s TestServer,
886
887 pg_config: tokio_postgres::Config,
889 port: u16,
891 tls: T,
893
894 notice_callback: Option<Box<dyn FnMut(tokio_postgres::error::DbError) + Send + 'static>>,
896
897 _with_handle: H,
899}
900
901impl<'s> ConnectBuilder<'s, (), NoHandle> {
902 fn new(server: &'s TestServer) -> Self {
903 let mut pg_config = tokio_postgres::Config::new();
904 pg_config
905 .host(&Ipv4Addr::LOCALHOST.to_string())
906 .user("materialize")
907 .options("--welcome_message=off")
908 .application_name("environmentd_test_framework");
909
910 ConnectBuilder {
911 server,
912 pg_config,
913 port: server.sql_local_addr().port(),
914 tls: (),
915 notice_callback: None,
916 _with_handle: NoHandle,
917 }
918 }
919}
920
921impl<'s, T, H> ConnectBuilder<'s, T, H> {
922 pub fn no_tls(self) -> ConnectBuilder<'s, postgres::NoTls, H> {
926 ConnectBuilder {
927 server: self.server,
928 pg_config: self.pg_config,
929 port: self.port,
930 tls: postgres::NoTls,
931 notice_callback: self.notice_callback,
932 _with_handle: self._with_handle,
933 }
934 }
935
936 pub fn with_tls<Tls>(self, tls: Tls) -> ConnectBuilder<'s, Tls, H>
938 where
939 Tls: MakeTlsConnect<Socket> + Send + 'static,
940 Tls::TlsConnect: Send,
941 Tls::Stream: Send,
942 <Tls::TlsConnect as TlsConnect<Socket>>::Future: Send,
943 {
944 ConnectBuilder {
945 server: self.server,
946 pg_config: self.pg_config,
947 port: self.port,
948 tls,
949 notice_callback: self.notice_callback,
950 _with_handle: self._with_handle,
951 }
952 }
953
954 pub fn with_config(mut self, pg_config: tokio_postgres::Config) -> Self {
956 self.pg_config = pg_config;
957 self
958 }
959
960 pub fn ssl_mode(mut self, mode: SslMode) -> Self {
962 self.pg_config.ssl_mode(mode);
963 self
964 }
965
966 pub fn user(mut self, user: &str) -> Self {
968 self.pg_config.user(user);
969 self
970 }
971
972 pub fn password(mut self, password: &str) -> Self {
974 self.pg_config.password(password);
975 self
976 }
977
978 pub fn application_name(mut self, application_name: &str) -> Self {
980 self.pg_config.application_name(application_name);
981 self
982 }
983
984 pub fn dbname(mut self, dbname: &str) -> Self {
986 self.pg_config.dbname(dbname);
987 self
988 }
989
990 pub fn options(mut self, options: &str) -> Self {
992 self.pg_config.options(options);
993 self
994 }
995
996 pub fn internal(mut self) -> Self {
1001 self.port = self.server.internal_sql_local_addr().port();
1002 self.pg_config.user(mz_sql::session::user::SYSTEM_USER_NAME);
1003 self
1004 }
1005
1006 pub fn notice_callback(self, callback: impl FnMut(DbError) + Send + 'static) -> Self {
1008 ConnectBuilder {
1009 notice_callback: Some(Box::new(callback)),
1010 ..self
1011 }
1012 }
1013
1014 pub fn with_handle(self) -> ConnectBuilder<'s, T, WithHandle> {
1017 ConnectBuilder {
1018 server: self.server,
1019 pg_config: self.pg_config,
1020 port: self.port,
1021 tls: self.tls,
1022 notice_callback: self.notice_callback,
1023 _with_handle: WithHandle,
1024 }
1025 }
1026
1027 pub fn as_pg_config(&self) -> &tokio_postgres::Config {
1029 &self.pg_config
1030 }
1031}
1032
1033pub trait IncludeHandle: Send {
1036 type Output;
1037 fn transform_result(
1038 client: tokio_postgres::Client,
1039 handle: mz_ore::task::JoinHandle<()>,
1040 ) -> Self::Output;
1041}
1042
1043pub struct NoHandle;
1046impl IncludeHandle for NoHandle {
1047 type Output = tokio_postgres::Client;
1048 fn transform_result(
1049 client: tokio_postgres::Client,
1050 _handle: mz_ore::task::JoinHandle<()>,
1051 ) -> Self::Output {
1052 client
1053 }
1054}
1055
1056pub struct WithHandle;
1059impl IncludeHandle for WithHandle {
1060 type Output = (tokio_postgres::Client, mz_ore::task::JoinHandle<()>);
1061 fn transform_result(
1062 client: tokio_postgres::Client,
1063 handle: mz_ore::task::JoinHandle<()>,
1064 ) -> Self::Output {
1065 (client, handle)
1066 }
1067}
1068
1069impl<'s, T, H> IntoFuture for ConnectBuilder<'s, T, H>
1070where
1071 T: MakeTlsConnect<Socket> + Send + 'static,
1072 T::TlsConnect: Send,
1073 T::Stream: Send,
1074 <T::TlsConnect as TlsConnect<Socket>>::Future: Send,
1075 H: IncludeHandle,
1076{
1077 type Output = Result<H::Output, postgres::Error>;
1078 type IntoFuture = BoxFuture<'static, Self::Output>;
1079
1080 fn into_future(mut self) -> Self::IntoFuture {
1081 Box::pin(async move {
1082 assert!(
1083 self.pg_config.get_ports().is_empty(),
1084 "specifying multiple ports is not supported"
1085 );
1086 self.pg_config.port(self.port);
1087
1088 let (client, mut conn) = self.pg_config.connect(self.tls).await?;
1089 let mut notice_callback = self.notice_callback.take();
1090
1091 let handle = task::spawn(|| "connect", async move {
1092 while let Some(msg) = std::future::poll_fn(|cx| conn.poll_message(cx)).await {
1093 match msg {
1094 Ok(AsyncMessage::Notice(notice)) => {
1095 if let Some(callback) = notice_callback.as_mut() {
1096 callback(notice);
1097 }
1098 }
1099 Ok(msg) => {
1100 tracing::debug!(?msg, "Dropping message from database");
1101 }
1102 Err(e) => {
1103 tracing::info!("connection error: {e}");
1108 break;
1109 }
1110 }
1111 }
1112 tracing::info!("connection closed");
1113 });
1114
1115 let output = H::transform_result(client, handle);
1116 Ok(output)
1117 })
1118 }
1119}
1120
1121pub struct TestServerWithRuntime {
1126 server: TestServer,
1127 runtime: Arc<Runtime>,
1128}
1129
1130impl TestServerWithRuntime {
1131 pub fn runtime(&self) -> &Arc<Runtime> {
1135 &self.runtime
1136 }
1137
1138 pub fn inner(&self) -> &crate::Server {
1140 &self.server.inner
1141 }
1142
1143 pub fn connect<T>(&self, tls: T) -> Result<postgres::Client, postgres::Error>
1145 where
1146 T: MakeTlsConnect<Socket> + Send + 'static,
1147 T::TlsConnect: Send,
1148 T::Stream: Send,
1149 <T::TlsConnect as TlsConnect<Socket>>::Future: Send,
1150 {
1151 self.pg_config().connect(tls)
1152 }
1153
1154 pub fn connect_internal<T>(&self, tls: T) -> Result<postgres::Client, anyhow::Error>
1156 where
1157 T: MakeTlsConnect<Socket> + Send + 'static,
1158 T::TlsConnect: Send,
1159 T::Stream: Send,
1160 <T::TlsConnect as TlsConnect<Socket>>::Future: Send,
1161 {
1162 Ok(self.pg_config_internal().connect(tls)?)
1163 }
1164
1165 pub fn enable_feature_flags(&self, flags: &[&'static str]) {
1167 let mut internal_client = self.connect_internal(postgres::NoTls).unwrap();
1168
1169 for flag in flags {
1170 internal_client
1171 .batch_execute(&format!("ALTER SYSTEM SET {} = true;", flag))
1172 .unwrap();
1173 }
1174 }
1175
1176 pub fn disable_feature_flags(&self, flags: &[&'static str]) {
1178 let mut internal_client = self.connect_internal(postgres::NoTls).unwrap();
1179
1180 for flag in flags {
1181 internal_client
1182 .batch_execute(&format!("ALTER SYSTEM SET {} = false;", flag))
1183 .unwrap();
1184 }
1185 }
1186
1187 pub fn pg_config(&self) -> postgres::Config {
1190 let local_addr = self.server.sql_local_addr();
1191 let mut config = postgres::Config::new();
1192 config
1193 .host(&Ipv4Addr::LOCALHOST.to_string())
1194 .port(local_addr.port())
1195 .user("materialize")
1196 .options("--welcome_message=off");
1197 config
1198 }
1199
1200 pub fn pg_config_internal(&self) -> postgres::Config {
1203 let local_addr = self.server.internal_sql_local_addr();
1204 let mut config = postgres::Config::new();
1205 config
1206 .host(&Ipv4Addr::LOCALHOST.to_string())
1207 .port(local_addr.port())
1208 .user("mz_system")
1209 .options("--welcome_message=off");
1210 config
1211 }
1212
1213 pub fn ws_addr(&self) -> Uri {
1214 self.server.ws_addr()
1215 }
1216
1217 pub fn internal_ws_addr(&self) -> Uri {
1218 self.server.internal_ws_addr()
1219 }
1220
1221 pub fn http_local_addr(&self) -> SocketAddr {
1222 self.server.http_local_addr()
1223 }
1224
1225 pub fn internal_http_local_addr(&self) -> SocketAddr {
1226 self.server.internal_http_local_addr()
1227 }
1228
1229 pub fn sql_local_addr(&self) -> SocketAddr {
1230 self.server.sql_local_addr()
1231 }
1232
1233 pub fn internal_sql_local_addr(&self) -> SocketAddr {
1234 self.server.internal_sql_local_addr()
1235 }
1236
1237 pub fn metrics_registry(&self) -> &MetricsRegistry {
1239 &self.server.metrics_registry
1240 }
1241}
1242
1243#[derive(Debug, Clone, Copy, Eq, PartialEq, Ord, PartialOrd)]
1244pub struct MzTimestamp(pub u64);
1245
1246impl<'a> FromSql<'a> for MzTimestamp {
1247 fn from_sql(ty: &Type, raw: &'a [u8]) -> Result<MzTimestamp, Box<dyn Error + Sync + Send>> {
1248 let n = mz_pgrepr::Numeric::from_sql(ty, raw)?;
1249 Ok(MzTimestamp(u64::try_from(n.0.0)?))
1250 }
1251
1252 fn accepts(ty: &Type) -> bool {
1253 mz_pgrepr::Numeric::accepts(ty)
1254 }
1255}
1256
1257pub trait PostgresErrorExt {
1258 fn unwrap_db_error(self) -> DbError;
1259}
1260
1261impl PostgresErrorExt for postgres::Error {
1262 fn unwrap_db_error(self) -> DbError {
1263 match self.source().and_then(|e| e.downcast_ref::<DbError>()) {
1264 Some(e) => e.clone(),
1265 None => panic!("expected DbError, but got: {:?}", self),
1266 }
1267 }
1268}
1269
1270impl<T, E> PostgresErrorExt for Result<T, E>
1271where
1272 E: PostgresErrorExt,
1273{
1274 fn unwrap_db_error(self) -> DbError {
1275 match self {
1276 Ok(_) => panic!("expected Err(DbError), but got Ok(_)"),
1277 Err(e) => e.unwrap_db_error(),
1278 }
1279 }
1280}
1281
1282pub async fn insert_with_deterministic_timestamps(
1286 table: &'static str,
1287 values: &'static str,
1288 server: &TestServer,
1289 now: Arc<std::sync::Mutex<EpochMillis>>,
1290) -> Result<(), Box<dyn Error>> {
1291 let client_write = server.connect().await?;
1292 let client_read = server.connect().await?;
1293
1294 let mut current_ts = get_explain_timestamp(table, &client_read).await;
1295
1296 let insert_query = format!("INSERT INTO {table} VALUES {values}");
1297
1298 let write_future = client_write.execute(&insert_query, &[]);
1299 let timestamp_interval = tokio::time::interval(Duration::from_millis(1));
1300
1301 let mut write_future = std::pin::pin!(write_future);
1302 let mut timestamp_interval = std::pin::pin!(timestamp_interval);
1303
1304 loop {
1307 tokio::select! {
1308 _ = (&mut write_future) => return Ok(()),
1309 _ = timestamp_interval.tick() => {
1310 current_ts += 1;
1311 *now.lock().expect("lock poisoned") = current_ts;
1312 }
1313 };
1314 }
1315}
1316
1317pub async fn get_explain_timestamp(from_suffix: &str, client: &Client) -> EpochMillis {
1318 try_get_explain_timestamp(from_suffix, client)
1319 .await
1320 .unwrap()
1321}
1322
1323pub async fn try_get_explain_timestamp(
1324 from_suffix: &str,
1325 client: &Client,
1326) -> Result<EpochMillis, anyhow::Error> {
1327 let det = get_explain_timestamp_determination(from_suffix, client).await?;
1328 let ts = det.determination.timestamp_context.timestamp_or_default();
1329 Ok(ts.into())
1330}
1331
1332pub async fn get_explain_timestamp_determination(
1333 from_suffix: &str,
1334 client: &Client,
1335) -> Result<TimestampExplanation<mz_repr::Timestamp>, anyhow::Error> {
1336 let row = client
1337 .query_one(
1338 &format!("EXPLAIN TIMESTAMP AS JSON FOR SELECT * FROM {from_suffix}"),
1339 &[],
1340 )
1341 .await?;
1342 let explain: String = row.get(0);
1343 Ok(serde_json::from_str(&explain).unwrap())
1344}
1345
1346pub async fn create_postgres_source_with_table<'a>(
1354 server: &TestServer,
1355 mz_client: &Client,
1356 table_name: &str,
1357 table_schema: &str,
1358 source_name: &str,
1359) -> (
1360 Client,
1361 impl FnOnce(&'a Client, &'a Client) -> LocalBoxFuture<'a, ()>,
1362) {
1363 server
1364 .enable_feature_flags(&["enable_create_table_from_source"])
1365 .await;
1366
1367 let postgres_url = env::var("POSTGRES_URL")
1368 .map_err(|_| anyhow!("POSTGRES_URL environment variable is not set"))
1369 .unwrap();
1370
1371 let (pg_client, connection) = tokio_postgres::connect(&postgres_url, postgres::NoTls)
1372 .await
1373 .unwrap();
1374
1375 let pg_config: tokio_postgres::Config = postgres_url.parse().unwrap();
1376 let user = pg_config.get_user().unwrap_or("postgres");
1377 let db_name = pg_config.get_dbname().unwrap_or(user);
1378 let ports = pg_config.get_ports();
1379 let port = if ports.is_empty() { 5432 } else { ports[0] };
1380 let hosts = pg_config.get_hosts();
1381 let host = if hosts.is_empty() {
1382 "localhost".to_string()
1383 } else {
1384 match &hosts[0] {
1385 Host::Tcp(host) => host.to_string(),
1386 Host::Unix(host) => host.to_str().unwrap().to_string(),
1387 }
1388 };
1389 let password = pg_config.get_password();
1390
1391 mz_ore::task::spawn(|| "postgres-source-connection", async move {
1392 if let Err(e) = connection.await {
1393 panic!("connection error: {}", e);
1394 }
1395 });
1396
1397 let _ = pg_client
1399 .execute(&format!("DROP TABLE IF EXISTS {table_name};"), &[])
1400 .await
1401 .unwrap();
1402 let _ = pg_client
1403 .execute(&format!("DROP PUBLICATION IF EXISTS {source_name};"), &[])
1404 .await
1405 .unwrap();
1406 let _ = pg_client
1407 .execute(&format!("CREATE TABLE {table_name} {table_schema};"), &[])
1408 .await
1409 .unwrap();
1410 let _ = pg_client
1411 .execute(
1412 &format!("ALTER TABLE {table_name} REPLICA IDENTITY FULL;"),
1413 &[],
1414 )
1415 .await
1416 .unwrap();
1417 let _ = pg_client
1418 .execute(
1419 &format!("CREATE PUBLICATION {source_name} FOR TABLE {table_name};"),
1420 &[],
1421 )
1422 .await
1423 .unwrap();
1424
1425 let mut connection_str = format!("HOST '{host}', PORT {port}, USER {user}, DATABASE {db_name}");
1427 if let Some(password) = password {
1428 let password = std::str::from_utf8(password).unwrap();
1429 mz_client
1430 .batch_execute(&format!("CREATE SECRET s AS '{password}'"))
1431 .await
1432 .unwrap();
1433 connection_str = format!("{connection_str}, PASSWORD SECRET s");
1434 }
1435 mz_client
1436 .batch_execute(&format!(
1437 "CREATE CONNECTION pgconn TO POSTGRES ({connection_str})"
1438 ))
1439 .await
1440 .unwrap();
1441 mz_client
1442 .batch_execute(&format!(
1443 "CREATE SOURCE {source_name}
1444 FROM POSTGRES
1445 CONNECTION pgconn
1446 (PUBLICATION '{source_name}')"
1447 ))
1448 .await
1449 .unwrap();
1450 mz_client
1451 .batch_execute(&format!(
1452 "CREATE TABLE {table_name}
1453 FROM SOURCE {source_name}
1454 (REFERENCE {table_name});"
1455 ))
1456 .await
1457 .unwrap();
1458
1459 let table_name = table_name.to_string();
1460 let source_name = source_name.to_string();
1461 (
1462 pg_client,
1463 move |mz_client: &'a Client, pg_client: &'a Client| {
1464 let f: Pin<Box<dyn Future<Output = ()> + 'a>> = Box::pin(async move {
1465 mz_client
1466 .batch_execute(&format!("DROP SOURCE {source_name} CASCADE;"))
1467 .await
1468 .unwrap();
1469 mz_client
1470 .batch_execute("DROP CONNECTION pgconn;")
1471 .await
1472 .unwrap();
1473
1474 let _ = pg_client
1475 .execute(&format!("DROP PUBLICATION {source_name};"), &[])
1476 .await
1477 .unwrap();
1478 let _ = pg_client
1479 .execute(&format!("DROP TABLE {table_name};"), &[])
1480 .await
1481 .unwrap();
1482 });
1483 f
1484 },
1485 )
1486}
1487
1488pub async fn wait_for_pg_table_population(mz_client: &Client, view_name: &str, source_rows: i64) {
1489 let current_isolation = mz_client
1490 .query_one("SHOW transaction_isolation", &[])
1491 .await
1492 .unwrap()
1493 .get::<_, String>(0);
1494 mz_client
1495 .batch_execute("SET transaction_isolation = SERIALIZABLE")
1496 .await
1497 .unwrap();
1498 Retry::default()
1499 .retry_async(|_| async move {
1500 let rows = mz_client
1501 .query_one(&format!("SELECT COUNT(*) FROM {view_name};"), &[])
1502 .await
1503 .unwrap()
1504 .get::<_, i64>(0);
1505 if rows == source_rows {
1506 Ok(())
1507 } else {
1508 Err(format!(
1509 "Waiting for {source_rows} row to be ingested. Currently at {rows}."
1510 ))
1511 }
1512 })
1513 .await
1514 .unwrap();
1515 mz_client
1516 .batch_execute(&format!(
1517 "SET transaction_isolation = '{current_isolation}'"
1518 ))
1519 .await
1520 .unwrap();
1521}
1522
1523pub fn auth_with_ws(
1525 ws: &mut WebSocket<MaybeTlsStream<TcpStream>>,
1526 mut options: BTreeMap<String, String>,
1527) -> Result<Vec<WebSocketResponse>, anyhow::Error> {
1528 if !options.contains_key("welcome_message") {
1529 options.insert("welcome_message".into(), "off".into());
1530 }
1531 auth_with_ws_impl(
1532 ws,
1533 Message::Text(
1534 serde_json::to_string(&WebSocketAuth::Basic {
1535 user: "materialize".into(),
1536 password: "".into(),
1537 options,
1538 })
1539 .unwrap()
1540 .into(),
1541 ),
1542 )
1543}
1544
1545pub fn auth_with_ws_impl(
1546 ws: &mut WebSocket<MaybeTlsStream<TcpStream>>,
1547 auth_message: Message,
1548) -> Result<Vec<WebSocketResponse>, anyhow::Error> {
1549 ws.send(auth_message)?;
1550
1551 let mut msgs = Vec::new();
1553 loop {
1554 let resp = ws.read()?;
1555 match resp {
1556 Message::Text(msg) => {
1557 let msg: WebSocketResponse = serde_json::from_str(&msg).unwrap();
1558 match msg {
1559 WebSocketResponse::ReadyForQuery(_) => break,
1560 msg => {
1561 msgs.push(msg);
1562 }
1563 }
1564 }
1565 Message::Ping(_) => continue,
1566 Message::Close(None) => return Err(anyhow!("ws closed after auth")),
1567 Message::Close(Some(close_frame)) => {
1568 return Err(anyhow!("ws closed after auth").context(close_frame));
1569 }
1570 _ => panic!("unexpected response: {:?}", resp),
1571 }
1572 }
1573 Ok(msgs)
1574}
1575
1576pub fn make_header<H: Header>(h: H) -> HeaderMap {
1577 let mut map = HeaderMap::new();
1578 map.typed_insert(h);
1579 map
1580}
1581
1582pub fn make_pg_tls<F>(configure: F) -> MakeTlsConnector
1583where
1584 F: FnOnce(&mut SslConnectorBuilder) -> Result<(), ErrorStack>,
1585{
1586 let mut connector_builder = SslConnector::builder(SslMethod::tls()).unwrap();
1587 let options = connector_builder.options() | SslOptions::NO_TLSV1_3;
1608 connector_builder.set_options(options);
1609 configure(&mut connector_builder).unwrap();
1610 MakeTlsConnector::new(connector_builder.build())
1611}
1612
1613pub struct Ca {
1615 pub dir: TempDir,
1616 pub name: X509Name,
1617 pub cert: X509,
1618 pub pkey: PKey<Private>,
1619}
1620
1621impl Ca {
1622 fn make_ca(name: &str, parent: Option<&Ca>) -> Result<Ca, Box<dyn Error>> {
1623 let dir = tempfile::tempdir()?;
1624 let rsa = Rsa::generate(2048)?;
1625 let pkey = PKey::from_rsa(rsa)?;
1626 let name = {
1627 let mut builder = X509NameBuilder::new()?;
1628 builder.append_entry_by_nid(Nid::COMMONNAME, name)?;
1629 builder.build()
1630 };
1631 let cert = {
1632 let mut builder = X509::builder()?;
1633 builder.set_version(2)?;
1634 builder.set_pubkey(&pkey)?;
1635 builder.set_issuer_name(parent.map(|ca| &ca.name).unwrap_or(&name))?;
1636 builder.set_subject_name(&name)?;
1637 builder.set_not_before(&*Asn1Time::days_from_now(0)?)?;
1638 builder.set_not_after(&*Asn1Time::days_from_now(365)?)?;
1639 builder.append_extension(BasicConstraints::new().critical().ca().build()?)?;
1640 builder.sign(
1641 parent.map(|ca| &ca.pkey).unwrap_or(&pkey),
1642 MessageDigest::sha256(),
1643 )?;
1644 builder.build()
1645 };
1646 fs::write(dir.path().join("ca.crt"), cert.to_pem()?)?;
1647 Ok(Ca {
1648 dir,
1649 name,
1650 cert,
1651 pkey,
1652 })
1653 }
1654
1655 pub fn new_root(name: &str) -> Result<Ca, Box<dyn Error>> {
1657 Ca::make_ca(name, None)
1658 }
1659
1660 pub fn ca_cert_path(&self) -> PathBuf {
1662 self.dir.path().join("ca.crt")
1663 }
1664
1665 pub fn request_ca(&self, name: &str) -> Result<Ca, Box<dyn Error>> {
1667 Ca::make_ca(name, Some(self))
1668 }
1669
1670 pub fn request_client_cert(&self, name: &str) -> Result<(PathBuf, PathBuf), Box<dyn Error>> {
1675 self.request_cert(name, iter::empty())
1676 }
1677
1678 pub fn request_cert<I>(&self, name: &str, ips: I) -> Result<(PathBuf, PathBuf), Box<dyn Error>>
1681 where
1682 I: IntoIterator<Item = IpAddr>,
1683 {
1684 let rsa = Rsa::generate(2048)?;
1685 let pkey = PKey::from_rsa(rsa)?;
1686 let subject_name = {
1687 let mut builder = X509NameBuilder::new()?;
1688 builder.append_entry_by_nid(Nid::COMMONNAME, name)?;
1689 builder.build()
1690 };
1691 let cert = {
1692 let mut builder = X509::builder()?;
1693 builder.set_version(2)?;
1694 builder.set_pubkey(&pkey)?;
1695 builder.set_issuer_name(self.cert.subject_name())?;
1696 builder.set_subject_name(&subject_name)?;
1697 builder.set_not_before(&*Asn1Time::days_from_now(0)?)?;
1698 builder.set_not_after(&*Asn1Time::days_from_now(365)?)?;
1699 for ip in ips {
1700 builder.append_extension(
1701 SubjectAlternativeName::new()
1702 .ip(&ip.to_string())
1703 .build(&builder.x509v3_context(None, None))?,
1704 )?;
1705 }
1706 builder.sign(&self.pkey, MessageDigest::sha256())?;
1707 builder.build()
1708 };
1709 let cert_path = self.dir.path().join(Path::new(name).with_extension("crt"));
1710 let key_path = self.dir.path().join(Path::new(name).with_extension("key"));
1711 fs::write(&cert_path, cert.to_pem()?)?;
1712 fs::write(&key_path, pkey.private_key_to_pem_pkcs8()?)?;
1713 Ok((cert_path, key_path))
1714 }
1715}