1use std::collections::BTreeMap;
11use std::error::Error;
12use std::future::IntoFuture;
13use std::net::{IpAddr, Ipv4Addr, SocketAddr, TcpStream};
14use std::path::{Path, PathBuf};
15use std::pin::Pin;
16use std::str::FromStr;
17use std::sync::Arc;
18use std::sync::LazyLock;
19use std::time::Duration;
20use std::{env, fs, iter};
21
22use anyhow::anyhow;
23use futures::Future;
24use futures::future::{BoxFuture, LocalBoxFuture};
25use headers::{Header, HeaderMapExt};
26use http::Uri;
27use hyper::http::header::HeaderMap;
28use maplit::btreemap;
29use mz_adapter::TimestampExplanation;
30use mz_adapter_types::bootstrap_builtin_cluster_config::{
31 ANALYTICS_CLUSTER_DEFAULT_REPLICATION_FACTOR, BootstrapBuiltinClusterConfig,
32 CATALOG_SERVER_CLUSTER_DEFAULT_REPLICATION_FACTOR, PROBE_CLUSTER_DEFAULT_REPLICATION_FACTOR,
33 SUPPORT_CLUSTER_DEFAULT_REPLICATION_FACTOR, SYSTEM_CLUSTER_DEFAULT_REPLICATION_FACTOR,
34};
35
36use mz_auth::password::Password;
37use mz_catalog::config::ClusterReplicaSizeMap;
38use mz_controller::ControllerConfig;
39use mz_dyncfg::ConfigUpdates;
40use mz_license_keys::ValidatedLicenseKey;
41use mz_orchestrator_process::{ProcessOrchestrator, ProcessOrchestratorConfig};
42use mz_orchestrator_tracing::{TracingCliArgs, TracingOrchestrator};
43use mz_ore::metrics::MetricsRegistry;
44use mz_ore::now::{EpochMillis, NowFn, SYSTEM_TIME};
45use mz_ore::retry::Retry;
46use mz_ore::task;
47use mz_ore::tracing::{
48 OpenTelemetryConfig, StderrLogConfig, StderrLogFormat, TracingConfig, TracingGuard,
49 TracingHandle,
50};
51use mz_persist_client::PersistLocation;
52use mz_persist_client::cache::PersistClientCache;
53use mz_persist_client::cfg::{CONSENSUS_CONNECTION_POOL_MAX_SIZE, PersistConfig};
54use mz_persist_client::rpc::PersistGrpcPubSubServer;
55use mz_secrets::SecretsController;
56use mz_server_core::listeners::{
57 AllowedRoles, AuthenticatorKind, BaseListenerConfig, HttpRoutesEnabled,
58};
59use mz_server_core::{ReloadTrigger, TlsCertConfig};
60use mz_sql::catalog::EnvironmentId;
61use mz_storage_types::connections::ConnectionContext;
62use mz_tracing::CloneableEnvFilter;
63use openssl::asn1::Asn1Time;
64use openssl::error::ErrorStack;
65use openssl::hash::MessageDigest;
66use openssl::nid::Nid;
67use openssl::pkey::{PKey, Private};
68use openssl::rsa::Rsa;
69use openssl::ssl::{SslConnector, SslConnectorBuilder, SslMethod, SslOptions};
70use openssl::x509::extension::{BasicConstraints, SubjectAlternativeName};
71use openssl::x509::{X509, X509Name, X509NameBuilder};
72use postgres::error::DbError;
73use postgres::tls::{MakeTlsConnect, TlsConnect};
74use postgres::types::{FromSql, Type};
75use postgres::{NoTls, Socket};
76use postgres_openssl::MakeTlsConnector;
77use tempfile::TempDir;
78use tokio::net::TcpListener;
79use tokio::runtime::Runtime;
80use tokio_postgres::config::{Host, SslMode};
81use tokio_postgres::{AsyncMessage, Client};
82use tokio_stream::wrappers::TcpListenerStream;
83use tower_http::cors::AllowOrigin;
84use tracing::Level;
85use tracing_capture::SharedStorage;
86use tracing_subscriber::EnvFilter;
87use tungstenite::stream::MaybeTlsStream;
88use tungstenite::{Message, WebSocket};
89
90use crate::{
91 CatalogConfig, FronteggAuthenticator, HttpListenerConfig, ListenersConfig, SqlListenerConfig,
92 WebSocketAuth, WebSocketResponse,
93};
94
/// Value of the `KAFKA_ADDRS` environment variable, resolved lazily on first access.
///
/// Falls back to `localhost:9092` when the variable cannot be read.
pub static KAFKA_ADDRS: LazyLock<String> = LazyLock::new(|| match env::var("KAFKA_ADDRS") {
    Ok(addrs) => addrs,
    Err(_) => String::from("localhost:9092"),
});
97
/// Builder-style configuration for a test server.
///
/// Obtain one via [`Default`], adjust it with the chained `with_*` methods, then turn it into a
/// running [`TestServer`] with one of the `start*` methods.
#[derive(Clone)]
pub struct TestHarness {
    /// Directory for server data. When `None`, a temporary directory is created at startup.
    data_directory: Option<PathBuf>,
    /// Certificate and key paths forwarded to the server's TLS configuration.
    tls: Option<TlsCertConfig>,
    /// Frontegg authenticator forwarded to the server configuration.
    frontegg: Option<FronteggAuthenticator>,
    /// Password forwarded as the server's `external_login_password_mz_system`.
    external_login_password_mz_system: Option<Password>,
    /// SQL and HTTP listeners that are bound when the harness is started.
    listeners_config: ListenersConfig,
    /// Forwarded as the server's `unsafe_mode` flag.
    unsafe_mode: bool,
    /// Set by [`TestHarness::workers`]; not read in the visible part of this file.
    workers: usize,
    /// Clock handed to both the controller and the server.
    now: NowFn,
    /// Suffix of the `consensus_{seed}` and `tsoracle_{seed}` schemas created at startup.
    seed: u32,
    /// Forwarded as the server's storage usage collection interval.
    storage_usage_collection_interval: Duration,
    /// Forwarded as the server's storage usage retention period.
    storage_usage_retention_period: Option<Duration>,
    /// Replica size used when bootstrapping the default cluster.
    default_cluster_replica_size: String,
    /// Replication factor used when bootstrapping the default cluster.
    default_cluster_replication_factor: u32,
    /// Bootstrap configuration of the builtin system cluster.
    builtin_system_cluster_config: BootstrapBuiltinClusterConfig,
    /// Bootstrap configuration of the builtin catalog server cluster.
    builtin_catalog_server_cluster_config: BootstrapBuiltinClusterConfig,
    /// Bootstrap configuration of the builtin probe cluster.
    builtin_probe_cluster_config: BootstrapBuiltinClusterConfig,
    /// Bootstrap configuration of the builtin support cluster.
    builtin_support_cluster_config: BootstrapBuiltinClusterConfig,
    /// Bootstrap configuration of the builtin analytics cluster.
    builtin_analytics_cluster_config: BootstrapBuiltinClusterConfig,

    /// Forwarded to the process orchestrator's `propagate_crashes` setting.
    propagate_crashes: bool,
    /// When `true`, tracing is configured at startup; otherwise a disabled handle is used.
    enable_tracing: bool,
    /// Tracing CLI arguments handed to the tracing orchestrator wrapper.
    orchestrator_tracing_cli_args: TracingCliArgs,
    /// Forwarded as the server's bootstrap role.
    bootstrap_role: Option<String>,
    /// Forwarded as the controller's deploy generation.
    deploy_generation: u64,
    /// System parameter defaults forwarded to the server, keyed by parameter name.
    system_parameter_defaults: BTreeMap<String, String>,
    /// Forwarded as the server's internal console redirect URL.
    internal_console_redirect_url: Option<String>,
    /// Metrics registry to use; a new one is created at startup when `None`.
    metrics_registry: Option<MetricsRegistry>,
    /// Overrides the persist configuration's `build_version`.
    code_version: semver::Version,
    /// Storage handed to the tracing configuration's `capture` field when tracing is enabled.
    capture: Option<SharedStorage>,
    /// Environment ID forwarded to the orchestrator and the server.
    pub environment_id: EnvironmentId,
}
134
135impl Default for TestHarness {
136 fn default() -> TestHarness {
137 TestHarness {
138 data_directory: None,
139 tls: None,
140 frontegg: None,
141 external_login_password_mz_system: None,
142 listeners_config: ListenersConfig {
143 sql: btreemap![
144 "external".to_owned() => SqlListenerConfig {
145 addr: SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), 0),
146 authenticator_kind: AuthenticatorKind::None,
147 allowed_roles: AllowedRoles::Normal,
148 enable_tls: false,
149 },
150 "internal".to_owned() => SqlListenerConfig {
151 addr: SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), 0),
152 authenticator_kind: AuthenticatorKind::None,
153 allowed_roles: AllowedRoles::NormalAndInternal,
154 enable_tls: false,
155 },
156 ],
157 http: btreemap![
158 "external".to_owned() => HttpListenerConfig {
159 base: BaseListenerConfig {
160 addr: SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), 0),
161 authenticator_kind: AuthenticatorKind::None,
162 allowed_roles: AllowedRoles::Normal,
163 enable_tls: false,
164 },
165 routes: HttpRoutesEnabled{
166 base: true,
167 webhook: true,
168 internal: false,
169 metrics: false,
170 profiling: false,
171 },
172 },
173 "internal".to_owned() => HttpListenerConfig {
174 base: BaseListenerConfig {
175 addr: SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), 0),
176 authenticator_kind: AuthenticatorKind::None,
177 allowed_roles: AllowedRoles::NormalAndInternal,
178 enable_tls: false,
179 },
180 routes: HttpRoutesEnabled{
181 base: true,
182 webhook: true,
183 internal: true,
184 metrics: true,
185 profiling: true,
186 },
187 },
188 ],
189 },
190 unsafe_mode: false,
191 workers: 1,
192 now: SYSTEM_TIME.clone(),
193 seed: rand::random(),
194 storage_usage_collection_interval: Duration::from_secs(3600),
195 storage_usage_retention_period: None,
196 default_cluster_replica_size: "scale=1,workers=1".to_string(),
197 default_cluster_replication_factor: 1,
198 builtin_system_cluster_config: BootstrapBuiltinClusterConfig {
199 size: "scale=1,workers=1".to_string(),
200 replication_factor: SYSTEM_CLUSTER_DEFAULT_REPLICATION_FACTOR,
201 },
202 builtin_catalog_server_cluster_config: BootstrapBuiltinClusterConfig {
203 size: "scale=1,workers=1".to_string(),
204 replication_factor: CATALOG_SERVER_CLUSTER_DEFAULT_REPLICATION_FACTOR,
205 },
206 builtin_probe_cluster_config: BootstrapBuiltinClusterConfig {
207 size: "scale=1,workers=1".to_string(),
208 replication_factor: PROBE_CLUSTER_DEFAULT_REPLICATION_FACTOR,
209 },
210 builtin_support_cluster_config: BootstrapBuiltinClusterConfig {
211 size: "scale=1,workers=1".to_string(),
212 replication_factor: SUPPORT_CLUSTER_DEFAULT_REPLICATION_FACTOR,
213 },
214 builtin_analytics_cluster_config: BootstrapBuiltinClusterConfig {
215 size: "scale=1,workers=1".to_string(),
216 replication_factor: ANALYTICS_CLUSTER_DEFAULT_REPLICATION_FACTOR,
217 },
218 propagate_crashes: false,
219 enable_tracing: false,
220 bootstrap_role: Some("materialize".into()),
221 deploy_generation: 0,
222 system_parameter_defaults: BTreeMap::from([(
225 "log_filter".to_string(),
226 "error".to_string(),
227 )]),
228 internal_console_redirect_url: None,
229 metrics_registry: None,
230 orchestrator_tracing_cli_args: TracingCliArgs {
231 startup_log_filter: CloneableEnvFilter::from_str("error").expect("must parse"),
232 ..Default::default()
233 },
234 code_version: crate::BUILD_INFO.semver_version(),
235 environment_id: EnvironmentId::for_tests(),
236 capture: None,
237 }
238 }
239}
240
241impl TestHarness {
242 pub async fn start(self) -> TestServer {
246 self.try_start().await.expect("Failed to start test Server")
247 }
248
249 pub async fn start_with_trigger(self, tls_reload_certs: ReloadTrigger) -> TestServer {
251 self.try_start_with_trigger(tls_reload_certs)
252 .await
253 .expect("Failed to start test Server")
254 }
255
256 pub async fn try_start(self) -> Result<TestServer, anyhow::Error> {
258 self.try_start_with_trigger(mz_server_core::cert_reload_never_reload())
259 .await
260 }
261
262 pub async fn try_start_with_trigger(
264 self,
265 tls_reload_certs: ReloadTrigger,
266 ) -> Result<TestServer, anyhow::Error> {
267 let listeners = Listeners::new(&self).await?;
268 listeners.serve_with_trigger(self, tls_reload_certs).await
269 }
270
271 pub fn start_blocking(self) -> TestServerWithRuntime {
273 stacker::grow(mz_ore::stack::STACK_SIZE, || {
274 let runtime = Runtime::new().expect("failed to spawn runtime for test");
275 let runtime = Arc::new(runtime);
276 let server = runtime.block_on(self.start());
277 TestServerWithRuntime { runtime, server }
278 })
279 }
280
281 pub fn data_directory(mut self, data_directory: impl Into<PathBuf>) -> Self {
282 self.data_directory = Some(data_directory.into());
283 self
284 }
285
286 pub fn with_tls(mut self, cert_path: impl Into<PathBuf>, key_path: impl Into<PathBuf>) -> Self {
287 self.tls = Some(TlsCertConfig {
288 cert: cert_path.into(),
289 key: key_path.into(),
290 });
291 for (_, listener) in &mut self.listeners_config.sql {
292 listener.enable_tls = true;
293 }
294 for (_, listener) in &mut self.listeners_config.http {
295 listener.base.enable_tls = true;
296 }
297 self
298 }
299
300 pub fn unsafe_mode(mut self) -> Self {
301 self.unsafe_mode = true;
302 self
303 }
304
305 pub fn workers(mut self, workers: usize) -> Self {
306 self.workers = workers;
307 self
308 }
309
310 pub fn with_frontegg_auth(mut self, frontegg: &FronteggAuthenticator) -> Self {
311 self.frontegg = Some(frontegg.clone());
312 let enable_tls = self.tls.is_some();
313 self.listeners_config = ListenersConfig {
314 sql: btreemap! {
315 "external".to_owned() => SqlListenerConfig {
316 addr: SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), 0),
317 authenticator_kind: AuthenticatorKind::Frontegg,
318 allowed_roles: AllowedRoles::Normal,
319 enable_tls,
320 },
321 "internal".to_owned() => SqlListenerConfig {
322 addr: SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), 0),
323 authenticator_kind: AuthenticatorKind::None,
324 allowed_roles: AllowedRoles::NormalAndInternal,
325 enable_tls: false,
326 },
327 },
328 http: btreemap! {
329 "external".to_owned() => HttpListenerConfig {
330 base: BaseListenerConfig {
331 addr: SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), 0),
332 authenticator_kind: AuthenticatorKind::Frontegg,
333 allowed_roles: AllowedRoles::Normal,
334 enable_tls,
335 },
336 routes: HttpRoutesEnabled{
337 base: true,
338 webhook: true,
339 internal: false,
340 metrics: false,
341 profiling: false,
342 },
343 },
344 "internal".to_owned() => HttpListenerConfig {
345 base: BaseListenerConfig {
346 addr: SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), 0),
347 authenticator_kind: AuthenticatorKind::None,
348 allowed_roles: AllowedRoles::NormalAndInternal,
349 enable_tls: false,
350 },
351 routes: HttpRoutesEnabled{
352 base: true,
353 webhook: true,
354 internal: true,
355 metrics: true,
356 profiling: true,
357 },
358 },
359 },
360 };
361 self
362 }
363
364 pub fn with_password_auth(mut self, mz_system_password: Password) -> Self {
365 self.external_login_password_mz_system = Some(mz_system_password);
366 let enable_tls = self.tls.is_some();
367 self.listeners_config = ListenersConfig {
368 sql: btreemap! {
369 "external".to_owned() => SqlListenerConfig {
370 addr: SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), 0),
371 authenticator_kind: AuthenticatorKind::Password,
372 allowed_roles: AllowedRoles::NormalAndInternal,
373 enable_tls,
374 },
375 },
376 http: btreemap! {
377 "external".to_owned() => HttpListenerConfig {
378 base: BaseListenerConfig {
379 addr: SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), 0),
380 authenticator_kind: AuthenticatorKind::Password,
381 allowed_roles: AllowedRoles::NormalAndInternal,
382 enable_tls,
383 },
384 routes: HttpRoutesEnabled{
385 base: true,
386 webhook: true,
387 internal: true,
388 metrics: false,
389 profiling: true,
390 },
391 },
392 "metrics".to_owned() => HttpListenerConfig {
393 base: BaseListenerConfig {
394 addr: SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), 0),
395 authenticator_kind: AuthenticatorKind::None,
396 allowed_roles: AllowedRoles::NormalAndInternal,
397 enable_tls: false,
398 },
399 routes: HttpRoutesEnabled{
400 base: false,
401 webhook: false,
402 internal: false,
403 metrics: true,
404 profiling: false,
405 },
406 },
407 },
408 };
409 self
410 }
411
412 pub fn with_sasl_scram_auth(mut self, mz_system_password: Password) -> Self {
413 self.external_login_password_mz_system = Some(mz_system_password);
414 let enable_tls = self.tls.is_some();
415 self.listeners_config = ListenersConfig {
416 sql: btreemap! {
417 "external".to_owned() => SqlListenerConfig {
418 addr: SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), 0),
419 authenticator_kind: AuthenticatorKind::Sasl,
420 allowed_roles: AllowedRoles::NormalAndInternal,
421 enable_tls,
422 },
423 },
424 http: btreemap! {
425 "external".to_owned() => HttpListenerConfig {
426 base: BaseListenerConfig {
427 addr: SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), 0),
428 authenticator_kind: AuthenticatorKind::Password,
429 allowed_roles: AllowedRoles::NormalAndInternal,
430 enable_tls,
431 },
432 routes: HttpRoutesEnabled{
433 base: true,
434 webhook: true,
435 internal: true,
436 metrics: false,
437 profiling: true,
438 },
439 },
440 "metrics".to_owned() => HttpListenerConfig {
441 base: BaseListenerConfig {
442 addr: SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), 0),
443 authenticator_kind: AuthenticatorKind::None,
444 allowed_roles: AllowedRoles::NormalAndInternal,
445 enable_tls: false,
446 },
447 routes: HttpRoutesEnabled{
448 base: false,
449 webhook: false,
450 internal: false,
451 metrics: true,
452 profiling: false,
453 },
454 },
455 },
456 };
457 self
458 }
459
460 pub fn with_now(mut self, now: NowFn) -> Self {
461 self.now = now;
462 self
463 }
464
465 pub fn with_storage_usage_collection_interval(
466 mut self,
467 storage_usage_collection_interval: Duration,
468 ) -> Self {
469 self.storage_usage_collection_interval = storage_usage_collection_interval;
470 self
471 }
472
473 pub fn with_storage_usage_retention_period(
474 mut self,
475 storage_usage_retention_period: Duration,
476 ) -> Self {
477 self.storage_usage_retention_period = Some(storage_usage_retention_period);
478 self
479 }
480
481 pub fn with_default_cluster_replica_size(
482 mut self,
483 default_cluster_replica_size: String,
484 ) -> Self {
485 self.default_cluster_replica_size = default_cluster_replica_size;
486 self
487 }
488
489 pub fn with_builtin_system_cluster_replica_size(
490 mut self,
491 builtin_system_cluster_replica_size: String,
492 ) -> Self {
493 self.builtin_system_cluster_config.size = builtin_system_cluster_replica_size;
494 self
495 }
496
497 pub fn with_builtin_system_cluster_replication_factor(
498 mut self,
499 builtin_system_cluster_replication_factor: u32,
500 ) -> Self {
501 self.builtin_system_cluster_config.replication_factor =
502 builtin_system_cluster_replication_factor;
503 self
504 }
505
506 pub fn with_builtin_catalog_server_cluster_replica_size(
507 mut self,
508 builtin_catalog_server_cluster_replica_size: String,
509 ) -> Self {
510 self.builtin_catalog_server_cluster_config.size =
511 builtin_catalog_server_cluster_replica_size;
512 self
513 }
514
515 pub fn with_propagate_crashes(mut self, propagate_crashes: bool) -> Self {
516 self.propagate_crashes = propagate_crashes;
517 self
518 }
519
520 pub fn with_enable_tracing(mut self, enable_tracing: bool) -> Self {
521 self.enable_tracing = enable_tracing;
522 self
523 }
524
525 pub fn with_bootstrap_role(mut self, bootstrap_role: Option<String>) -> Self {
526 self.bootstrap_role = bootstrap_role;
527 self
528 }
529
530 pub fn with_deploy_generation(mut self, deploy_generation: u64) -> Self {
531 self.deploy_generation = deploy_generation;
532 self
533 }
534
535 pub fn with_system_parameter_default(mut self, param: String, value: String) -> Self {
536 self.system_parameter_defaults.insert(param, value);
537 self
538 }
539
540 pub fn with_internal_console_redirect_url(
541 mut self,
542 internal_console_redirect_url: Option<String>,
543 ) -> Self {
544 self.internal_console_redirect_url = internal_console_redirect_url;
545 self
546 }
547
548 pub fn with_metrics_registry(mut self, registry: MetricsRegistry) -> Self {
549 self.metrics_registry = Some(registry);
550 self
551 }
552
553 pub fn with_code_version(mut self, version: semver::Version) -> Self {
554 self.code_version = version;
555 self
556 }
557
558 pub fn with_capture(mut self, storage: SharedStorage) -> Self {
559 self.capture = Some(storage);
560 self
561 }
562}
563
/// Bound listeners for a test server, wrapping [`crate::Listeners`].
///
/// Created with [`Listeners::new`] and turned into a running [`TestServer`] by
/// [`Listeners::serve`] or [`Listeners::serve_with_trigger`].
pub struct Listeners {
    /// The wrapped listeners returned by `crate::Listeners::bind`.
    pub inner: crate::Listeners,
}
567
impl Listeners {
    /// Binds the listeners described by `config.listeners_config`.
    ///
    /// # Errors
    ///
    /// Returns the error from `crate::Listeners::bind` if binding fails.
    pub async fn new(config: &TestHarness) -> Result<Listeners, anyhow::Error> {
        let inner = crate::Listeners::bind(config.listeners_config.clone()).await?;
        Ok(Listeners { inner })
    }

    /// Starts a server on these listeners using the trigger returned by
    /// `mz_server_core::cert_reload_never_reload()`.
    pub async fn serve(self, config: TestHarness) -> Result<TestServer, anyhow::Error> {
        self.serve_with_trigger(config, mz_server_core::cert_reload_never_reload())
            .await
    }

    /// Starts a server on these listeners with the given harness configuration.
    ///
    /// Startup reads the `METADATA_BACKEND_URL` environment variable, creates per-seed schemas
    /// in that database, and sets up the process orchestrator, persist, and (optionally)
    /// tracing before serving.
    ///
    /// # Errors
    ///
    /// Returns an error if `METADATA_BACKEND_URL` is not set, or if any of the fallible startup
    /// steps (temporary directories, the metadata connection, the orchestrator, tracing, or the
    /// server itself) fails.
    ///
    /// # Panics
    ///
    /// Panics if the current executable's path has fewer than two parent directories, if the
    /// derived consensus / timestamp oracle / blob URIs fail to parse, or if the pubsub
    /// listener cannot be bound.
    pub async fn serve_with_trigger(
        self,
        config: TestHarness,
        tls_reload_certs: ReloadTrigger,
    ) -> Result<TestServer, anyhow::Error> {
        // Use the configured data directory, or fall back to a fresh temporary directory. The
        // `TempDir` handle is kept so that it can be stored on the returned `TestServer`.
        let (data_directory, temp_dir) = match config.data_directory {
            None => {
                let temp_dir = tempfile::tempdir()?;
                (temp_dir.path().to_path_buf(), Some(temp_dir))
            }
            Some(data_directory) => (data_directory, None),
        };
        // Handed to the process orchestrator as its scratch directory.
        let scratch_dir = tempfile::tempdir()?;
        // Create one schema for consensus and one for the timestamp oracle, both suffixed with
        // the harness seed, and derive URLs whose `search_path` selects them.
        let (consensus_uri, timestamp_oracle_url) = {
            let seed = config.seed;
            let cockroach_url = env::var("METADATA_BACKEND_URL")
                .map_err(|_| anyhow!("METADATA_BACKEND_URL environment variable is not set"))?;
            let (client, conn) = tokio_postgres::connect(&cockroach_url, NoTls).await?;
            // Drive the connection in a background task; a connection error panics that task.
            mz_ore::task::spawn(|| "startup-postgres-conn", async move {
                if let Err(err) = conn.await {
                    panic!("connection error: {}", err);
                };
            });
            client
                .batch_execute(&format!(
                    "CREATE SCHEMA IF NOT EXISTS consensus_{seed};
 CREATE SCHEMA IF NOT EXISTS tsoracle_{seed};"
                ))
                .await?;
            (
                format!("{cockroach_url}?options=--search_path=consensus_{seed}")
                    .parse()
                    .expect("invalid consensus URI"),
                format!("{cockroach_url}?options=--search_path=tsoracle_{seed}")
                    .parse()
                    .expect("invalid timestamp oracle URI"),
            )
        };
        let metrics_registry = config.metrics_registry.unwrap_or_else(MetricsRegistry::new);
        let orchestrator = ProcessOrchestrator::new(ProcessOrchestratorConfig {
            // Images are looked up two directory levels above the current executable.
            image_dir: env::current_exe()?
                .parent()
                .unwrap()
                .parent()
                .unwrap()
                .to_path_buf(),
            suppress_output: false,
            environment_id: config.environment_id.to_string(),
            secrets_dir: data_directory.join("secrets"),
            command_wrapper: vec![],
            propagate_crashes: config.propagate_crashes,
            tcp_proxy: None,
            scratch_directory: scratch_dir.path().to_path_buf(),
        })
        .await?;
        let orchestrator = Arc::new(orchestrator);
        // Persist uses the system clock here rather than `config.now`.
        let persist_now = SYSTEM_TIME.clone();
        let dyncfgs = mz_dyncfgs::all_dyncfgs();

        // Set the consensus connection pool max size to 1 for tests.
        let mut updates = ConfigUpdates::default();
        updates.add(&CONSENSUS_CONNECTION_POOL_MAX_SIZE, 1);
        updates.apply(&dyncfgs);

        let mut persist_cfg = PersistConfig::new(&crate::BUILD_INFO, persist_now.clone(), dyncfgs);
        // Report the harness's (possibly overridden) code version as the build version.
        persist_cfg.build_version = config.code_version;
        persist_cfg.set_rollup_threshold(5);

        // Start a persist pubsub server on a localhost port and remember that port so it can be
        // advertised through `persist_pubsub_url` below.
        let persist_pubsub_server = PersistGrpcPubSubServer::new(&persist_cfg, &metrics_registry);
        let persist_pubsub_client = persist_pubsub_server.new_same_process_connection();
        let persist_pubsub_tcp_listener =
            TcpListener::bind(SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), 0))
                .await
                .expect("pubsub addr binding");
        let persist_pubsub_server_port = persist_pubsub_tcp_listener
            .local_addr()
            .expect("pubsub addr has local addr")
            .port();

        mz_ore::task::spawn(|| "persist_pubsub_server", async move {
            persist_pubsub_server
                .serve_with_stream(TcpListenerStream::new(persist_pubsub_tcp_listener))
                .await
                .expect("success")
        });
        let persist_clients =
            PersistClientCache::new(persist_cfg, &metrics_registry, |_, _| persist_pubsub_client);
        let persist_clients = Arc::new(persist_clients);

        // The process orchestrator doubles as the secrets controller and supplies the secrets
        // reader for the connection context; after that it is wrapped for tracing.
        let secrets_controller = Arc::clone(&orchestrator);
        let connection_context = ConnectionContext::for_tests(orchestrator.reader());
        let orchestrator = Arc::new(TracingOrchestrator::new(
            orchestrator,
            config.orchestrator_tracing_cli_args,
        ));
        // Only configure tracing when requested; otherwise use a disabled handle and no guard.
        let (tracing_handle, tracing_guard) = if config.enable_tracing {
            // Note: this `config` shadows the harness configuration within this branch only.
            let config = TracingConfig::<fn(&tracing::Metadata) -> sentry_tracing::EventFilter> {
                service_name: "environmentd",
                stderr_log: StderrLogConfig {
                    format: StderrLogFormat::Json,
                    filter: EnvFilter::default(),
                },
                opentelemetry: Some(OpenTelemetryConfig {
                    // Placeholder endpoint, as its host name says.
                    endpoint: "http://fake_address_for_testing:8080".to_string(),
                    headers: http::HeaderMap::new(),
                    filter: EnvFilter::default().add_directive(Level::DEBUG.into()),
                    resource: opentelemetry_sdk::resource::Resource::default(),
                    max_batch_queue_size: 2048,
                    max_export_batch_size: 512,
                    max_concurrent_exports: 1,
                    batch_scheduled_delay: Duration::from_millis(5000),
                    max_export_timeout: Duration::from_secs(30),
                }),
                tokio_console: None,
                sentry: None,
                build_version: crate::BUILD_INFO.version,
                build_sha: crate::BUILD_INFO.sha,
                registry: metrics_registry.clone(),
                capture: config.capture,
            };
            let (tracing_handle, tracing_guard) = mz_ore::tracing::configure(config).await?;
            (tracing_handle, Some(tracing_guard))
        } else {
            (TracingHandle::disabled(), None)
        };
        // The advertised HTTP host name uses the port of the bound `external` HTTP listener.
        let host_name = format!(
            "localhost:{}",
            self.inner.http["external"].handle.local_addr.port()
        );
        let catalog_config = CatalogConfig {
            persist_clients: Arc::clone(&persist_clients),
            // NOTE(review): catalog metrics are registered in a fresh registry rather than in
            // `metrics_registry` — confirm this is intentional.
            metrics: Arc::new(mz_catalog::durable::Metrics::new(&MetricsRegistry::new())),
        };

        let inner = self
            .inner
            .serve(crate::Config {
                catalog_config,
                timestamp_oracle_url: Some(timestamp_oracle_url),
                controller: ControllerConfig {
                    build_info: &crate::BUILD_INFO,
                    orchestrator,
                    clusterd_image: "clusterd".into(),
                    init_container_image: None,
                    deploy_generation: config.deploy_generation,
                    persist_location: PersistLocation {
                        blob_uri: format!("file://{}/persist/blob", data_directory.display())
                            .parse()
                            .expect("invalid blob URI"),
                        consensus_uri,
                    },
                    persist_clients,
                    now: config.now.clone(),
                    metrics_registry: metrics_registry.clone(),
                    persist_pubsub_url: format!("http://localhost:{}", persist_pubsub_server_port),
                    secrets_args: mz_service::secrets::SecretsReaderCliArgs {
                        secrets_reader: mz_service::secrets::SecretsControllerKind::LocalFile,
                        secrets_reader_local_file_dir: Some(data_directory.join("secrets")),
                        secrets_reader_kubernetes_context: None,
                        secrets_reader_aws_prefix: None,
                        secrets_reader_name_prefix: None,
                    },
                    connection_context,
                },
                secrets_controller,
                cloud_resource_controller: None,
                tls: config.tls,
                frontegg: config.frontegg,
                unsafe_mode: config.unsafe_mode,
                all_features: false,
                metrics_registry: metrics_registry.clone(),
                now: config.now,
                environment_id: config.environment_id,
                cors_allowed_origin: AllowOrigin::list([]),
                cluster_replica_sizes: ClusterReplicaSizeMap::for_tests(),
                bootstrap_default_cluster_replica_size: config.default_cluster_replica_size,
                bootstrap_default_cluster_replication_factor: config
                    .default_cluster_replication_factor,
                bootstrap_builtin_system_cluster_config: config.builtin_system_cluster_config,
                bootstrap_builtin_catalog_server_cluster_config: config
                    .builtin_catalog_server_cluster_config,
                bootstrap_builtin_probe_cluster_config: config.builtin_probe_cluster_config,
                bootstrap_builtin_support_cluster_config: config.builtin_support_cluster_config,
                bootstrap_builtin_analytics_cluster_config: config.builtin_analytics_cluster_config,
                system_parameter_defaults: config.system_parameter_defaults,
                availability_zones: Default::default(),
                tracing_handle,
                storage_usage_collection_interval: config.storage_usage_collection_interval,
                storage_usage_retention_period: config.storage_usage_retention_period,
                segment_api_key: None,
                segment_client_side: false,
                test_only_dummy_segment_client: false,
                egress_addresses: vec![],
                aws_account_id: None,
                aws_privatelink_availability_zones: None,
                launchdarkly_sdk_key: None,
                launchdarkly_key_map: Default::default(),
                config_sync_file_path: None,
                config_sync_timeout: Duration::from_secs(30),
                config_sync_loop_interval: None,
                bootstrap_role: config.bootstrap_role,
                http_host_name: Some(host_name),
                internal_console_redirect_url: config.internal_console_redirect_url,
                tls_reload_certs,
                helm_chart_version: None,
                license_key: ValidatedLicenseKey::for_tests(),
                external_login_password_mz_system: config.external_login_password_mz_system,
            })
            .await?;

        // Keep the temporary directories and the tracing guard alive alongside the server.
        Ok(TestServer {
            inner,
            metrics_registry,
            _temp_dir: temp_dir,
            _scratch_dir: scratch_dir,
            _tracing_guard: tracing_guard,
        })
    }
}
807
/// A running test server, together with the resources that must outlive it.
///
/// Rust drops struct fields in declaration order, so `inner` is dropped before the
/// directories and the tracing guard.
pub struct TestServer {
    /// The running server.
    pub inner: crate::Server,
    /// The metrics registry the server was started with.
    pub metrics_registry: MetricsRegistry,
    /// The temporary data directory; `Some` only when the harness had no explicit
    /// `data_directory`.
    _temp_dir: Option<TempDir>,
    /// The scratch directory handed to the process orchestrator.
    _scratch_dir: TempDir,
    /// The tracing guard; `Some` only when tracing was enabled.
    _tracing_guard: Option<TracingGuard>,
}
817
818impl TestServer {
819 pub fn connect(&self) -> ConnectBuilder<'_, postgres::NoTls, NoHandle> {
820 ConnectBuilder::new(self).no_tls()
821 }
822
823 pub async fn enable_feature_flags(&self, flags: &[&'static str]) {
824 let internal_client = self.connect().internal().await.unwrap();
825
826 for flag in flags {
827 internal_client
828 .batch_execute(&format!("ALTER SYSTEM SET {} = true;", flag))
829 .await
830 .unwrap();
831 }
832 }
833
834 pub fn ws_addr(&self) -> Uri {
835 format!(
836 "ws://{}/api/experimental/sql",
837 self.inner.http_listener_handles["external"].local_addr
838 )
839 .parse()
840 .unwrap()
841 }
842
843 pub fn internal_ws_addr(&self) -> Uri {
844 format!(
845 "ws://{}/api/experimental/sql",
846 self.inner.http_listener_handles["internal"].local_addr
847 )
848 .parse()
849 .unwrap()
850 }
851
852 pub fn http_local_addr(&self) -> SocketAddr {
853 self.inner.http_listener_handles["external"].local_addr
854 }
855
856 pub fn internal_http_local_addr(&self) -> SocketAddr {
857 self.inner.http_listener_handles["internal"].local_addr
858 }
859
860 pub fn sql_local_addr(&self) -> SocketAddr {
861 self.inner.sql_listener_handles["external"].local_addr
862 }
863
864 pub fn internal_sql_local_addr(&self) -> SocketAddr {
865 self.inner.sql_listener_handles["internal"].local_addr
866 }
867}
868
/// Builder for a SQL connection to a [`TestServer`].
///
/// `T` is the TLS connector and `H` selects, via [`IncludeHandle`], whether awaiting the
/// builder also yields the join handle of the connection task.
pub struct ConnectBuilder<'s, T, H> {
    /// The server to connect to.
    server: &'s TestServer,

    /// Connection parameters; the port is applied separately when the builder is awaited.
    pg_config: tokio_postgres::Config,
    /// Port to connect to.
    port: u16,
    /// TLS connector passed to `connect`.
    tls: T,

    /// Callback invoked for every notice received on the connection.
    notice_callback: Option<Box<dyn FnMut(tokio_postgres::error::DbError) + Send + 'static>>,

    /// Marker value for the handle-selection type parameter `H`.
    _with_handle: H,
}
889
890impl<'s> ConnectBuilder<'s, (), NoHandle> {
891 fn new(server: &'s TestServer) -> Self {
892 let mut pg_config = tokio_postgres::Config::new();
893 pg_config
894 .host(&Ipv4Addr::LOCALHOST.to_string())
895 .user("materialize")
896 .options("--welcome_message=off")
897 .application_name("environmentd_test_framework");
898
899 ConnectBuilder {
900 server,
901 pg_config,
902 port: server.sql_local_addr().port(),
903 tls: (),
904 notice_callback: None,
905 _with_handle: NoHandle,
906 }
907 }
908}
909
910impl<'s, T, H> ConnectBuilder<'s, T, H> {
911 pub fn no_tls(self) -> ConnectBuilder<'s, postgres::NoTls, H> {
915 ConnectBuilder {
916 server: self.server,
917 pg_config: self.pg_config,
918 port: self.port,
919 tls: postgres::NoTls,
920 notice_callback: self.notice_callback,
921 _with_handle: self._with_handle,
922 }
923 }
924
925 pub fn with_tls<Tls>(self, tls: Tls) -> ConnectBuilder<'s, Tls, H>
927 where
928 Tls: MakeTlsConnect<Socket> + Send + 'static,
929 Tls::TlsConnect: Send,
930 Tls::Stream: Send,
931 <Tls::TlsConnect as TlsConnect<Socket>>::Future: Send,
932 {
933 ConnectBuilder {
934 server: self.server,
935 pg_config: self.pg_config,
936 port: self.port,
937 tls,
938 notice_callback: self.notice_callback,
939 _with_handle: self._with_handle,
940 }
941 }
942
943 pub fn with_config(mut self, pg_config: tokio_postgres::Config) -> Self {
945 self.pg_config = pg_config;
946 self
947 }
948
949 pub fn ssl_mode(mut self, mode: SslMode) -> Self {
951 self.pg_config.ssl_mode(mode);
952 self
953 }
954
955 pub fn user(mut self, user: &str) -> Self {
957 self.pg_config.user(user);
958 self
959 }
960
961 pub fn password(mut self, password: &str) -> Self {
963 self.pg_config.password(password);
964 self
965 }
966
967 pub fn application_name(mut self, application_name: &str) -> Self {
969 self.pg_config.application_name(application_name);
970 self
971 }
972
973 pub fn dbname(mut self, dbname: &str) -> Self {
975 self.pg_config.dbname(dbname);
976 self
977 }
978
979 pub fn options(mut self, options: &str) -> Self {
981 self.pg_config.options(options);
982 self
983 }
984
985 pub fn internal(mut self) -> Self {
990 self.port = self.server.internal_sql_local_addr().port();
991 self.pg_config.user(mz_sql::session::user::SYSTEM_USER_NAME);
992 self
993 }
994
995 pub fn notice_callback(self, callback: impl FnMut(DbError) + Send + 'static) -> Self {
997 ConnectBuilder {
998 notice_callback: Some(Box::new(callback)),
999 ..self
1000 }
1001 }
1002
1003 pub fn with_handle(self) -> ConnectBuilder<'s, T, WithHandle> {
1006 ConnectBuilder {
1007 server: self.server,
1008 pg_config: self.pg_config,
1009 port: self.port,
1010 tls: self.tls,
1011 notice_callback: self.notice_callback,
1012 _with_handle: WithHandle,
1013 }
1014 }
1015
1016 pub fn as_pg_config(&self) -> &tokio_postgres::Config {
1018 &self.pg_config
1019 }
1020}
1021
/// Decides what awaiting a [`ConnectBuilder`] yields: the client alone, or the client plus the
/// join handle of the task that drives the connection.
pub trait IncludeHandle: Send {
    /// The value produced by a successful connection attempt.
    type Output;
    /// Combines the connected `client` and the connection task's `handle` into
    /// [`IncludeHandle::Output`].
    fn transform_result(
        client: tokio_postgres::Client,
        handle: mz_ore::task::JoinHandle<()>,
    ) -> Self::Output;
}
1031
/// [`IncludeHandle`] marker that yields only the client.
pub struct NoHandle;
impl IncludeHandle for NoHandle {
    type Output = tokio_postgres::Client;
    /// Returns `client` and discards the connection task's handle.
    fn transform_result(
        client: tokio_postgres::Client,
        _handle: mz_ore::task::JoinHandle<()>,
    ) -> Self::Output {
        client
    }
}
1044
/// [`IncludeHandle`] marker that yields the client together with the connection task's handle.
pub struct WithHandle;
impl IncludeHandle for WithHandle {
    type Output = (tokio_postgres::Client, mz_ore::task::JoinHandle<()>);
    /// Returns `client` and `handle` as a pair.
    fn transform_result(
        client: tokio_postgres::Client,
        handle: mz_ore::task::JoinHandle<()>,
    ) -> Self::Output {
        (client, handle)
    }
}
1057
1058impl<'s, T, H> IntoFuture for ConnectBuilder<'s, T, H>
1059where
1060 T: MakeTlsConnect<Socket> + Send + 'static,
1061 T::TlsConnect: Send,
1062 T::Stream: Send,
1063 <T::TlsConnect as TlsConnect<Socket>>::Future: Send,
1064 H: IncludeHandle,
1065{
1066 type Output = Result<H::Output, postgres::Error>;
1067 type IntoFuture = BoxFuture<'static, Self::Output>;
1068
1069 fn into_future(mut self) -> Self::IntoFuture {
1070 Box::pin(async move {
1071 assert!(
1072 self.pg_config.get_ports().is_empty(),
1073 "specifying multiple ports is not supported"
1074 );
1075 self.pg_config.port(self.port);
1076
1077 let (client, mut conn) = self.pg_config.connect(self.tls).await?;
1078 let mut notice_callback = self.notice_callback.take();
1079
1080 let handle = task::spawn(|| "connect", async move {
1081 while let Some(msg) = std::future::poll_fn(|cx| conn.poll_message(cx)).await {
1082 match msg {
1083 Ok(AsyncMessage::Notice(notice)) => {
1084 if let Some(callback) = notice_callback.as_mut() {
1085 callback(notice);
1086 }
1087 }
1088 Ok(msg) => {
1089 tracing::debug!(?msg, "Dropping message from database");
1090 }
1091 Err(e) => {
1092 tracing::info!("connection error: {e}");
1097 break;
1098 }
1099 }
1100 }
1101 tracing::info!("connection closed");
1102 });
1103
1104 let output = H::transform_result(client, handle);
1105 Ok(output)
1106 })
1107 }
1108}
1109
1110pub struct TestServerWithRuntime {
1115 server: TestServer,
1116 runtime: Arc<Runtime>,
1117}
1118
1119impl TestServerWithRuntime {
1120 pub fn runtime(&self) -> &Arc<Runtime> {
1124 &self.runtime
1125 }
1126
1127 pub fn inner(&self) -> &crate::Server {
1129 &self.server.inner
1130 }
1131
1132 pub fn connect<T>(&self, tls: T) -> Result<postgres::Client, postgres::Error>
1134 where
1135 T: MakeTlsConnect<Socket> + Send + 'static,
1136 T::TlsConnect: Send,
1137 T::Stream: Send,
1138 <T::TlsConnect as TlsConnect<Socket>>::Future: Send,
1139 {
1140 self.pg_config().connect(tls)
1141 }
1142
1143 pub fn connect_internal<T>(&self, tls: T) -> Result<postgres::Client, anyhow::Error>
1145 where
1146 T: MakeTlsConnect<Socket> + Send + 'static,
1147 T::TlsConnect: Send,
1148 T::Stream: Send,
1149 <T::TlsConnect as TlsConnect<Socket>>::Future: Send,
1150 {
1151 Ok(self.pg_config_internal().connect(tls)?)
1152 }
1153
1154 pub fn enable_feature_flags(&self, flags: &[&'static str]) {
1156 let mut internal_client = self.connect_internal(postgres::NoTls).unwrap();
1157
1158 for flag in flags {
1159 internal_client
1160 .batch_execute(&format!("ALTER SYSTEM SET {} = true;", flag))
1161 .unwrap();
1162 }
1163 }
1164
1165 pub fn pg_config(&self) -> postgres::Config {
1168 let local_addr = self.server.sql_local_addr();
1169 let mut config = postgres::Config::new();
1170 config
1171 .host(&Ipv4Addr::LOCALHOST.to_string())
1172 .port(local_addr.port())
1173 .user("materialize")
1174 .options("--welcome_message=off");
1175 config
1176 }
1177
1178 pub fn pg_config_internal(&self) -> postgres::Config {
1181 let local_addr = self.server.internal_sql_local_addr();
1182 let mut config = postgres::Config::new();
1183 config
1184 .host(&Ipv4Addr::LOCALHOST.to_string())
1185 .port(local_addr.port())
1186 .user("mz_system")
1187 .options("--welcome_message=off");
1188 config
1189 }
1190
1191 pub fn ws_addr(&self) -> Uri {
1192 self.server.ws_addr()
1193 }
1194
1195 pub fn internal_ws_addr(&self) -> Uri {
1196 self.server.internal_ws_addr()
1197 }
1198
1199 pub fn http_local_addr(&self) -> SocketAddr {
1200 self.server.http_local_addr()
1201 }
1202
1203 pub fn internal_http_local_addr(&self) -> SocketAddr {
1204 self.server.internal_http_local_addr()
1205 }
1206
1207 pub fn sql_local_addr(&self) -> SocketAddr {
1208 self.server.sql_local_addr()
1209 }
1210
1211 pub fn internal_sql_local_addr(&self) -> SocketAddr {
1212 self.server.internal_sql_local_addr()
1213 }
1214}
1215
1216#[derive(Debug, Clone, Copy, Eq, PartialEq, Ord, PartialOrd)]
1217pub struct MzTimestamp(pub u64);
1218
1219impl<'a> FromSql<'a> for MzTimestamp {
1220 fn from_sql(ty: &Type, raw: &'a [u8]) -> Result<MzTimestamp, Box<dyn Error + Sync + Send>> {
1221 let n = mz_pgrepr::Numeric::from_sql(ty, raw)?;
1222 Ok(MzTimestamp(u64::try_from(n.0.0)?))
1223 }
1224
1225 fn accepts(ty: &Type) -> bool {
1226 mz_pgrepr::Numeric::accepts(ty)
1227 }
1228}
1229
1230pub trait PostgresErrorExt {
1231 fn unwrap_db_error(self) -> DbError;
1232}
1233
1234impl PostgresErrorExt for postgres::Error {
1235 fn unwrap_db_error(self) -> DbError {
1236 match self.source().and_then(|e| e.downcast_ref::<DbError>()) {
1237 Some(e) => e.clone(),
1238 None => panic!("expected DbError, but got: {:?}", self),
1239 }
1240 }
1241}
1242
1243impl<T, E> PostgresErrorExt for Result<T, E>
1244where
1245 E: PostgresErrorExt,
1246{
1247 fn unwrap_db_error(self) -> DbError {
1248 match self {
1249 Ok(_) => panic!("expected Err(DbError), but got Ok(_)"),
1250 Err(e) => e.unwrap_db_error(),
1251 }
1252 }
1253}
1254
1255pub async fn insert_with_deterministic_timestamps(
1259 table: &'static str,
1260 values: &'static str,
1261 server: &TestServer,
1262 now: Arc<std::sync::Mutex<EpochMillis>>,
1263) -> Result<(), Box<dyn Error>> {
1264 let client_write = server.connect().await?;
1265 let client_read = server.connect().await?;
1266
1267 let mut current_ts = get_explain_timestamp(table, &client_read).await;
1268
1269 let insert_query = format!("INSERT INTO {table} VALUES {values}");
1270
1271 let write_future = client_write.execute(&insert_query, &[]);
1272 let timestamp_interval = tokio::time::interval(Duration::from_millis(1));
1273
1274 let mut write_future = std::pin::pin!(write_future);
1275 let mut timestamp_interval = std::pin::pin!(timestamp_interval);
1276
1277 loop {
1280 tokio::select! {
1281 _ = (&mut write_future) => return Ok(()),
1282 _ = timestamp_interval.tick() => {
1283 current_ts += 1;
1284 *now.lock().expect("lock poisoned") = current_ts;
1285 }
1286 };
1287 }
1288}
1289
1290pub async fn get_explain_timestamp(from_suffix: &str, client: &Client) -> EpochMillis {
1291 try_get_explain_timestamp(from_suffix, client)
1292 .await
1293 .unwrap()
1294}
1295
1296pub async fn try_get_explain_timestamp(
1297 from_suffix: &str,
1298 client: &Client,
1299) -> Result<EpochMillis, anyhow::Error> {
1300 let det = get_explain_timestamp_determination(from_suffix, client).await?;
1301 let ts = det.determination.timestamp_context.timestamp_or_default();
1302 Ok(ts.into())
1303}
1304
1305pub async fn get_explain_timestamp_determination(
1306 from_suffix: &str,
1307 client: &Client,
1308) -> Result<TimestampExplanation<mz_repr::Timestamp>, anyhow::Error> {
1309 let row = client
1310 .query_one(
1311 &format!("EXPLAIN TIMESTAMP AS JSON FOR SELECT * FROM {from_suffix}"),
1312 &[],
1313 )
1314 .await?;
1315 let explain: String = row.get(0);
1316 Ok(serde_json::from_str(&explain).unwrap())
1317}
1318
1319pub async fn create_postgres_source_with_table<'a>(
1327 server: &TestServer,
1328 mz_client: &Client,
1329 table_name: &str,
1330 table_schema: &str,
1331 source_name: &str,
1332) -> (
1333 Client,
1334 impl FnOnce(&'a Client, &'a Client) -> LocalBoxFuture<'a, ()>,
1335) {
1336 server
1337 .enable_feature_flags(&["enable_create_table_from_source"])
1338 .await;
1339
1340 let postgres_url = env::var("POSTGRES_URL")
1341 .map_err(|_| anyhow!("POSTGRES_URL environment variable is not set"))
1342 .unwrap();
1343
1344 let (pg_client, connection) = tokio_postgres::connect(&postgres_url, postgres::NoTls)
1345 .await
1346 .unwrap();
1347
1348 let pg_config: tokio_postgres::Config = postgres_url.parse().unwrap();
1349 let user = pg_config.get_user().unwrap_or("postgres");
1350 let db_name = pg_config.get_dbname().unwrap_or(user);
1351 let ports = pg_config.get_ports();
1352 let port = if ports.is_empty() { 5432 } else { ports[0] };
1353 let hosts = pg_config.get_hosts();
1354 let host = if hosts.is_empty() {
1355 "localhost".to_string()
1356 } else {
1357 match &hosts[0] {
1358 Host::Tcp(host) => host.to_string(),
1359 Host::Unix(host) => host.to_str().unwrap().to_string(),
1360 }
1361 };
1362 let password = pg_config.get_password();
1363
1364 mz_ore::task::spawn(|| "postgres-source-connection", async move {
1365 if let Err(e) = connection.await {
1366 panic!("connection error: {}", e);
1367 }
1368 });
1369
1370 let _ = pg_client
1372 .execute(&format!("DROP TABLE IF EXISTS {table_name};"), &[])
1373 .await
1374 .unwrap();
1375 let _ = pg_client
1376 .execute(&format!("DROP PUBLICATION IF EXISTS {source_name};"), &[])
1377 .await
1378 .unwrap();
1379 let _ = pg_client
1380 .execute(&format!("CREATE TABLE {table_name} {table_schema};"), &[])
1381 .await
1382 .unwrap();
1383 let _ = pg_client
1384 .execute(
1385 &format!("ALTER TABLE {table_name} REPLICA IDENTITY FULL;"),
1386 &[],
1387 )
1388 .await
1389 .unwrap();
1390 let _ = pg_client
1391 .execute(
1392 &format!("CREATE PUBLICATION {source_name} FOR TABLE {table_name};"),
1393 &[],
1394 )
1395 .await
1396 .unwrap();
1397
1398 let mut connection_str = format!("HOST '{host}', PORT {port}, USER {user}, DATABASE {db_name}");
1400 if let Some(password) = password {
1401 let password = std::str::from_utf8(password).unwrap();
1402 mz_client
1403 .batch_execute(&format!("CREATE SECRET s AS '{password}'"))
1404 .await
1405 .unwrap();
1406 connection_str = format!("{connection_str}, PASSWORD SECRET s");
1407 }
1408 mz_client
1409 .batch_execute(&format!(
1410 "CREATE CONNECTION pgconn TO POSTGRES ({connection_str})"
1411 ))
1412 .await
1413 .unwrap();
1414 mz_client
1415 .batch_execute(&format!(
1416 "CREATE SOURCE {source_name}
1417 FROM POSTGRES
1418 CONNECTION pgconn
1419 (PUBLICATION '{source_name}')"
1420 ))
1421 .await
1422 .unwrap();
1423 mz_client
1424 .batch_execute(&format!(
1425 "CREATE TABLE {table_name}
1426 FROM SOURCE {source_name}
1427 (REFERENCE {table_name});"
1428 ))
1429 .await
1430 .unwrap();
1431
1432 let table_name = table_name.to_string();
1433 let source_name = source_name.to_string();
1434 (
1435 pg_client,
1436 move |mz_client: &'a Client, pg_client: &'a Client| {
1437 let f: Pin<Box<dyn Future<Output = ()> + 'a>> = Box::pin(async move {
1438 mz_client
1439 .batch_execute(&format!("DROP SOURCE {source_name} CASCADE;"))
1440 .await
1441 .unwrap();
1442 mz_client
1443 .batch_execute("DROP CONNECTION pgconn;")
1444 .await
1445 .unwrap();
1446
1447 let _ = pg_client
1448 .execute(&format!("DROP PUBLICATION {source_name};"), &[])
1449 .await
1450 .unwrap();
1451 let _ = pg_client
1452 .execute(&format!("DROP TABLE {table_name};"), &[])
1453 .await
1454 .unwrap();
1455 });
1456 f
1457 },
1458 )
1459}
1460
1461pub async fn wait_for_pg_table_population(mz_client: &Client, view_name: &str, source_rows: i64) {
1462 let current_isolation = mz_client
1463 .query_one("SHOW transaction_isolation", &[])
1464 .await
1465 .unwrap()
1466 .get::<_, String>(0);
1467 mz_client
1468 .batch_execute("SET transaction_isolation = SERIALIZABLE")
1469 .await
1470 .unwrap();
1471 Retry::default()
1472 .retry_async(|_| async move {
1473 let rows = mz_client
1474 .query_one(&format!("SELECT COUNT(*) FROM {view_name};"), &[])
1475 .await
1476 .unwrap()
1477 .get::<_, i64>(0);
1478 if rows == source_rows {
1479 Ok(())
1480 } else {
1481 Err(format!(
1482 "Waiting for {source_rows} row to be ingested. Currently at {rows}."
1483 ))
1484 }
1485 })
1486 .await
1487 .unwrap();
1488 mz_client
1489 .batch_execute(&format!(
1490 "SET transaction_isolation = '{current_isolation}'"
1491 ))
1492 .await
1493 .unwrap();
1494}
1495
1496pub fn auth_with_ws(
1498 ws: &mut WebSocket<MaybeTlsStream<TcpStream>>,
1499 mut options: BTreeMap<String, String>,
1500) -> Result<Vec<WebSocketResponse>, anyhow::Error> {
1501 if !options.contains_key("welcome_message") {
1502 options.insert("welcome_message".into(), "off".into());
1503 }
1504 auth_with_ws_impl(
1505 ws,
1506 Message::Text(
1507 serde_json::to_string(&WebSocketAuth::Basic {
1508 user: "materialize".into(),
1509 password: "".into(),
1510 options,
1511 })
1512 .unwrap(),
1513 ),
1514 )
1515}
1516
1517pub fn auth_with_ws_impl(
1518 ws: &mut WebSocket<MaybeTlsStream<TcpStream>>,
1519 auth_message: Message,
1520) -> Result<Vec<WebSocketResponse>, anyhow::Error> {
1521 ws.send(auth_message)?;
1522
1523 let mut msgs = Vec::new();
1525 loop {
1526 let resp = ws.read()?;
1527 match resp {
1528 Message::Text(msg) => {
1529 let msg: WebSocketResponse = serde_json::from_str(&msg).unwrap();
1530 match msg {
1531 WebSocketResponse::ReadyForQuery(_) => break,
1532 msg => {
1533 msgs.push(msg);
1534 }
1535 }
1536 }
1537 Message::Ping(_) => continue,
1538 Message::Close(None) => return Err(anyhow!("ws closed after auth")),
1539 Message::Close(Some(close_frame)) => {
1540 return Err(anyhow!("ws closed after auth").context(close_frame));
1541 }
1542 _ => panic!("unexpected response: {:?}", resp),
1543 }
1544 }
1545 Ok(msgs)
1546}
1547
1548pub fn make_header<H: Header>(h: H) -> HeaderMap {
1549 let mut map = HeaderMap::new();
1550 map.typed_insert(h);
1551 map
1552}
1553
1554pub fn make_pg_tls<F>(configure: F) -> MakeTlsConnector
1555where
1556 F: FnOnce(&mut SslConnectorBuilder) -> Result<(), ErrorStack>,
1557{
1558 let mut connector_builder = SslConnector::builder(SslMethod::tls()).unwrap();
1559 let options = connector_builder.options() | SslOptions::NO_TLSV1_3;
1580 connector_builder.set_options(options);
1581 configure(&mut connector_builder).unwrap();
1582 MakeTlsConnector::new(connector_builder.build())
1583}
1584
1585pub struct Ca {
1587 pub dir: TempDir,
1588 pub name: X509Name,
1589 pub cert: X509,
1590 pub pkey: PKey<Private>,
1591}
1592
1593impl Ca {
1594 fn make_ca(name: &str, parent: Option<&Ca>) -> Result<Ca, Box<dyn Error>> {
1595 let dir = tempfile::tempdir()?;
1596 let rsa = Rsa::generate(2048)?;
1597 let pkey = PKey::from_rsa(rsa)?;
1598 let name = {
1599 let mut builder = X509NameBuilder::new()?;
1600 builder.append_entry_by_nid(Nid::COMMONNAME, name)?;
1601 builder.build()
1602 };
1603 let cert = {
1604 let mut builder = X509::builder()?;
1605 builder.set_version(2)?;
1606 builder.set_pubkey(&pkey)?;
1607 builder.set_issuer_name(parent.map(|ca| &ca.name).unwrap_or(&name))?;
1608 builder.set_subject_name(&name)?;
1609 builder.set_not_before(&*Asn1Time::days_from_now(0)?)?;
1610 builder.set_not_after(&*Asn1Time::days_from_now(365)?)?;
1611 builder.append_extension(BasicConstraints::new().critical().ca().build()?)?;
1612 builder.sign(
1613 parent.map(|ca| &ca.pkey).unwrap_or(&pkey),
1614 MessageDigest::sha256(),
1615 )?;
1616 builder.build()
1617 };
1618 fs::write(dir.path().join("ca.crt"), cert.to_pem()?)?;
1619 Ok(Ca {
1620 dir,
1621 name,
1622 cert,
1623 pkey,
1624 })
1625 }
1626
1627 pub fn new_root(name: &str) -> Result<Ca, Box<dyn Error>> {
1629 Ca::make_ca(name, None)
1630 }
1631
1632 pub fn ca_cert_path(&self) -> PathBuf {
1634 self.dir.path().join("ca.crt")
1635 }
1636
1637 pub fn request_ca(&self, name: &str) -> Result<Ca, Box<dyn Error>> {
1639 Ca::make_ca(name, Some(self))
1640 }
1641
1642 pub fn request_client_cert(&self, name: &str) -> Result<(PathBuf, PathBuf), Box<dyn Error>> {
1647 self.request_cert(name, iter::empty())
1648 }
1649
1650 pub fn request_cert<I>(&self, name: &str, ips: I) -> Result<(PathBuf, PathBuf), Box<dyn Error>>
1653 where
1654 I: IntoIterator<Item = IpAddr>,
1655 {
1656 let rsa = Rsa::generate(2048)?;
1657 let pkey = PKey::from_rsa(rsa)?;
1658 let subject_name = {
1659 let mut builder = X509NameBuilder::new()?;
1660 builder.append_entry_by_nid(Nid::COMMONNAME, name)?;
1661 builder.build()
1662 };
1663 let cert = {
1664 let mut builder = X509::builder()?;
1665 builder.set_version(2)?;
1666 builder.set_pubkey(&pkey)?;
1667 builder.set_issuer_name(self.cert.subject_name())?;
1668 builder.set_subject_name(&subject_name)?;
1669 builder.set_not_before(&*Asn1Time::days_from_now(0)?)?;
1670 builder.set_not_after(&*Asn1Time::days_from_now(365)?)?;
1671 for ip in ips {
1672 builder.append_extension(
1673 SubjectAlternativeName::new()
1674 .ip(&ip.to_string())
1675 .build(&builder.x509v3_context(None, None))?,
1676 )?;
1677 }
1678 builder.sign(&self.pkey, MessageDigest::sha256())?;
1679 builder.build()
1680 };
1681 let cert_path = self.dir.path().join(Path::new(name).with_extension("crt"));
1682 let key_path = self.dir.path().join(Path::new(name).with_extension("key"));
1683 fs::write(&cert_path, cert.to_pem()?)?;
1684 fs::write(&key_path, pkey.private_key_to_pem_pkcs8()?)?;
1685 Ok((cert_path, key_path))
1686 }
1687}