1use std::collections::BTreeMap;
11use std::error::Error;
12use std::future::IntoFuture;
13use std::net::{IpAddr, Ipv4Addr, SocketAddr, TcpStream};
14use std::path::{Path, PathBuf};
15use std::pin::Pin;
16use std::str::FromStr;
17use std::sync::Arc;
18use std::sync::LazyLock;
19use std::time::Duration;
20use std::{env, fs, iter};
21
22use anyhow::anyhow;
23use futures::Future;
24use futures::future::{BoxFuture, LocalBoxFuture};
25use headers::{Header, HeaderMapExt};
26use http::Uri;
27use hyper::http::header::HeaderMap;
28use maplit::btreemap;
29use mz_adapter::TimestampExplanation;
30use mz_adapter_types::bootstrap_builtin_cluster_config::{
31 ANALYTICS_CLUSTER_DEFAULT_REPLICATION_FACTOR, BootstrapBuiltinClusterConfig,
32 CATALOG_SERVER_CLUSTER_DEFAULT_REPLICATION_FACTOR, PROBE_CLUSTER_DEFAULT_REPLICATION_FACTOR,
33 SUPPORT_CLUSTER_DEFAULT_REPLICATION_FACTOR, SYSTEM_CLUSTER_DEFAULT_REPLICATION_FACTOR,
34};
35
36use mz_auth::password::Password;
37use mz_catalog::config::ClusterReplicaSizeMap;
38use mz_controller::ControllerConfig;
39use mz_dyncfg::ConfigUpdates;
40use mz_license_keys::ValidatedLicenseKey;
41use mz_orchestrator_process::{ProcessOrchestrator, ProcessOrchestratorConfig};
42use mz_orchestrator_tracing::{TracingCliArgs, TracingOrchestrator};
43use mz_ore::metrics::MetricsRegistry;
44use mz_ore::now::{EpochMillis, NowFn, SYSTEM_TIME};
45use mz_ore::retry::Retry;
46use mz_ore::task;
47use mz_ore::tracing::{
48 OpenTelemetryConfig, StderrLogConfig, StderrLogFormat, TracingConfig, TracingHandle,
49};
50use mz_persist_client::PersistLocation;
51use mz_persist_client::cache::PersistClientCache;
52use mz_persist_client::cfg::{CONSENSUS_CONNECTION_POOL_MAX_SIZE, PersistConfig};
53use mz_persist_client::rpc::PersistGrpcPubSubServer;
54use mz_secrets::SecretsController;
55use mz_server_core::listeners::{
56 AllowedRoles, AuthenticatorKind, BaseListenerConfig, HttpRoutesEnabled,
57};
58use mz_server_core::{ReloadTrigger, TlsCertConfig};
59use mz_sql::catalog::EnvironmentId;
60use mz_storage_types::connections::ConnectionContext;
61use mz_tracing::CloneableEnvFilter;
62use openssl::asn1::Asn1Time;
63use openssl::error::ErrorStack;
64use openssl::hash::MessageDigest;
65use openssl::nid::Nid;
66use openssl::pkey::{PKey, Private};
67use openssl::rsa::Rsa;
68use openssl::ssl::{SslConnector, SslConnectorBuilder, SslMethod, SslOptions};
69use openssl::x509::extension::{BasicConstraints, SubjectAlternativeName};
70use openssl::x509::{X509, X509Name, X509NameBuilder};
71use postgres::error::DbError;
72use postgres::tls::{MakeTlsConnect, TlsConnect};
73use postgres::types::{FromSql, Type};
74use postgres::{NoTls, Socket};
75use postgres_openssl::MakeTlsConnector;
76use tempfile::TempDir;
77use tokio::net::TcpListener;
78use tokio::runtime::Runtime;
79use tokio_postgres::config::{Host, SslMode};
80use tokio_postgres::{AsyncMessage, Client};
81use tokio_stream::wrappers::TcpListenerStream;
82use tower_http::cors::AllowOrigin;
83use tracing::Level;
84use tracing_capture::SharedStorage;
85use tracing_subscriber::EnvFilter;
86use tungstenite::stream::MaybeTlsStream;
87use tungstenite::{Message, WebSocket};
88
89use crate::{
90 CatalogConfig, FronteggAuthenticator, HttpListenerConfig, ListenersConfig, SqlListenerConfig,
91 WebSocketAuth, WebSocketResponse,
92};
93
/// Kafka bootstrap address(es) used by tests.
///
/// Read from the `KAFKA_ADDRS` environment variable on first access; when the
/// variable is unset (or not valid Unicode) this falls back to
/// `localhost:9092`.
pub static KAFKA_ADDRS: LazyLock<String> = LazyLock::new(|| match env::var("KAFKA_ADDRS") {
    Ok(addrs) => addrs,
    Err(_) => String::from("localhost:9092"),
});
96
/// Builder-style configuration for starting an `environmentd` [`TestServer`].
///
/// Start from [`TestHarness::default`], adjust settings with the builder
/// methods, then call [`TestHarness::start`] (or one of its variants).
#[derive(Clone)]
pub struct TestHarness {
    /// Directory for persist blob storage and secrets. When `None`, a
    /// temporary directory is created and kept alive by the `TestServer`.
    data_directory: Option<PathBuf>,
    /// TLS certificate/key paths, set by `with_tls`.
    tls: Option<TlsCertConfig>,
    /// Frontegg authenticator, set by `with_frontegg_auth`.
    frontegg: Option<FronteggAuthenticator>,
    /// `mz_system` password, set by `with_password_auth` / `with_sasl_scram_auth`.
    external_login_password_mz_system: Option<Password>,
    /// SQL and HTTP listeners to bind, keyed by listener name
    /// (e.g. "external", "internal", "metrics").
    listeners_config: ListenersConfig,
    unsafe_mode: bool,
    // NOTE(review): set by `workers()` but not read when serving in this file
    // as far as visible — confirm whether it has any effect.
    workers: usize,
    now: NowFn,
    /// Per-harness value used to name the `consensus_<seed>` and
    /// `tsoracle_<seed>` schemas in the metadata backend.
    seed: u32,
    storage_usage_collection_interval: Duration,
    storage_usage_retention_period: Option<Duration>,
    default_cluster_replica_size: String,
    default_cluster_replication_factor: u32,
    builtin_system_cluster_config: BootstrapBuiltinClusterConfig,
    builtin_catalog_server_cluster_config: BootstrapBuiltinClusterConfig,
    builtin_probe_cluster_config: BootstrapBuiltinClusterConfig,
    builtin_support_cluster_config: BootstrapBuiltinClusterConfig,
    builtin_analytics_cluster_config: BootstrapBuiltinClusterConfig,

    /// Forwarded to the process orchestrator's `propagate_crashes` option.
    propagate_crashes: bool,
    /// When true, a JSON stderr logger plus an OpenTelemetry exporter aimed at
    /// a fake endpoint are configured; otherwise tracing is disabled.
    enable_tracing: bool,
    orchestrator_tracing_cli_args: TracingCliArgs,
    bootstrap_role: Option<String>,
    deploy_generation: u64,
    /// System parameter defaults passed to the server at startup.
    system_parameter_defaults: BTreeMap<String, String>,
    internal_console_redirect_url: Option<String>,
    /// Registry to use; a new one is created when `None`.
    metrics_registry: Option<MetricsRegistry>,
    /// Build version reported to persist.
    code_version: semver::Version,
    /// Tracing capture storage, only used when `enable_tracing` is set.
    capture: Option<SharedStorage>,
    pub environment_id: EnvironmentId,
}
133
134impl Default for TestHarness {
135 fn default() -> TestHarness {
136 TestHarness {
137 data_directory: None,
138 tls: None,
139 frontegg: None,
140 external_login_password_mz_system: None,
141 listeners_config: ListenersConfig {
142 sql: btreemap![
143 "external".to_owned() => SqlListenerConfig {
144 addr: SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), 0),
145 authenticator_kind: AuthenticatorKind::None,
146 allowed_roles: AllowedRoles::Normal,
147 enable_tls: false,
148 },
149 "internal".to_owned() => SqlListenerConfig {
150 addr: SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), 0),
151 authenticator_kind: AuthenticatorKind::None,
152 allowed_roles: AllowedRoles::NormalAndInternal,
153 enable_tls: false,
154 },
155 ],
156 http: btreemap![
157 "external".to_owned() => HttpListenerConfig {
158 base: BaseListenerConfig {
159 addr: SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), 0),
160 authenticator_kind: AuthenticatorKind::None,
161 allowed_roles: AllowedRoles::Normal,
162 enable_tls: false,
163 },
164 routes: HttpRoutesEnabled{
165 base: true,
166 webhook: true,
167 internal: false,
168 metrics: false,
169 profiling: false,
170 },
171 },
172 "internal".to_owned() => HttpListenerConfig {
173 base: BaseListenerConfig {
174 addr: SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), 0),
175 authenticator_kind: AuthenticatorKind::None,
176 allowed_roles: AllowedRoles::NormalAndInternal,
177 enable_tls: false,
178 },
179 routes: HttpRoutesEnabled{
180 base: true,
181 webhook: true,
182 internal: true,
183 metrics: true,
184 profiling: true,
185 },
186 },
187 ],
188 },
189 unsafe_mode: false,
190 workers: 1,
191 now: SYSTEM_TIME.clone(),
192 seed: rand::random(),
193 storage_usage_collection_interval: Duration::from_secs(3600),
194 storage_usage_retention_period: None,
195 default_cluster_replica_size: "scale=1,workers=1".to_string(),
196 default_cluster_replication_factor: 1,
197 builtin_system_cluster_config: BootstrapBuiltinClusterConfig {
198 size: "scale=1,workers=1".to_string(),
199 replication_factor: SYSTEM_CLUSTER_DEFAULT_REPLICATION_FACTOR,
200 },
201 builtin_catalog_server_cluster_config: BootstrapBuiltinClusterConfig {
202 size: "scale=1,workers=1".to_string(),
203 replication_factor: CATALOG_SERVER_CLUSTER_DEFAULT_REPLICATION_FACTOR,
204 },
205 builtin_probe_cluster_config: BootstrapBuiltinClusterConfig {
206 size: "scale=1,workers=1".to_string(),
207 replication_factor: PROBE_CLUSTER_DEFAULT_REPLICATION_FACTOR,
208 },
209 builtin_support_cluster_config: BootstrapBuiltinClusterConfig {
210 size: "scale=1,workers=1".to_string(),
211 replication_factor: SUPPORT_CLUSTER_DEFAULT_REPLICATION_FACTOR,
212 },
213 builtin_analytics_cluster_config: BootstrapBuiltinClusterConfig {
214 size: "scale=1,workers=1".to_string(),
215 replication_factor: ANALYTICS_CLUSTER_DEFAULT_REPLICATION_FACTOR,
216 },
217 propagate_crashes: false,
218 enable_tracing: false,
219 bootstrap_role: Some("materialize".into()),
220 deploy_generation: 0,
221 system_parameter_defaults: BTreeMap::from([(
224 "log_filter".to_string(),
225 "error".to_string(),
226 )]),
227 internal_console_redirect_url: None,
228 metrics_registry: None,
229 orchestrator_tracing_cli_args: TracingCliArgs {
230 startup_log_filter: CloneableEnvFilter::from_str("error").expect("must parse"),
231 ..Default::default()
232 },
233 code_version: crate::BUILD_INFO.semver_version(),
234 environment_id: EnvironmentId::for_tests(),
235 capture: None,
236 }
237 }
238}
239
240impl TestHarness {
241 pub async fn start(self) -> TestServer {
245 self.try_start().await.expect("Failed to start test Server")
246 }
247
248 pub async fn start_with_trigger(self, tls_reload_certs: ReloadTrigger) -> TestServer {
250 self.try_start_with_trigger(tls_reload_certs)
251 .await
252 .expect("Failed to start test Server")
253 }
254
255 pub async fn try_start(self) -> Result<TestServer, anyhow::Error> {
257 self.try_start_with_trigger(mz_server_core::cert_reload_never_reload())
258 .await
259 }
260
261 pub async fn try_start_with_trigger(
263 self,
264 tls_reload_certs: ReloadTrigger,
265 ) -> Result<TestServer, anyhow::Error> {
266 let listeners = Listeners::new(&self).await?;
267 listeners.serve_with_trigger(self, tls_reload_certs).await
268 }
269
270 pub fn start_blocking(self) -> TestServerWithRuntime {
272 stacker::grow(mz_ore::stack::STACK_SIZE, || {
273 let runtime = Runtime::new().expect("failed to spawn runtime for test");
274 let runtime = Arc::new(runtime);
275 let server = runtime.block_on(self.start());
276 TestServerWithRuntime { runtime, server }
277 })
278 }
279
280 pub fn data_directory(mut self, data_directory: impl Into<PathBuf>) -> Self {
281 self.data_directory = Some(data_directory.into());
282 self
283 }
284
285 pub fn with_tls(mut self, cert_path: impl Into<PathBuf>, key_path: impl Into<PathBuf>) -> Self {
286 self.tls = Some(TlsCertConfig {
287 cert: cert_path.into(),
288 key: key_path.into(),
289 });
290 for (_, listener) in &mut self.listeners_config.sql {
291 listener.enable_tls = true;
292 }
293 for (_, listener) in &mut self.listeners_config.http {
294 listener.base.enable_tls = true;
295 }
296 self
297 }
298
299 pub fn unsafe_mode(mut self) -> Self {
300 self.unsafe_mode = true;
301 self
302 }
303
304 pub fn workers(mut self, workers: usize) -> Self {
305 self.workers = workers;
306 self
307 }
308
309 pub fn with_frontegg_auth(mut self, frontegg: &FronteggAuthenticator) -> Self {
310 self.frontegg = Some(frontegg.clone());
311 let enable_tls = self.tls.is_some();
312 self.listeners_config = ListenersConfig {
313 sql: btreemap! {
314 "external".to_owned() => SqlListenerConfig {
315 addr: SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), 0),
316 authenticator_kind: AuthenticatorKind::Frontegg,
317 allowed_roles: AllowedRoles::Normal,
318 enable_tls,
319 },
320 "internal".to_owned() => SqlListenerConfig {
321 addr: SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), 0),
322 authenticator_kind: AuthenticatorKind::None,
323 allowed_roles: AllowedRoles::NormalAndInternal,
324 enable_tls: false,
325 },
326 },
327 http: btreemap! {
328 "external".to_owned() => HttpListenerConfig {
329 base: BaseListenerConfig {
330 addr: SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), 0),
331 authenticator_kind: AuthenticatorKind::Frontegg,
332 allowed_roles: AllowedRoles::Normal,
333 enable_tls,
334 },
335 routes: HttpRoutesEnabled{
336 base: true,
337 webhook: true,
338 internal: false,
339 metrics: false,
340 profiling: false,
341 },
342 },
343 "internal".to_owned() => HttpListenerConfig {
344 base: BaseListenerConfig {
345 addr: SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), 0),
346 authenticator_kind: AuthenticatorKind::None,
347 allowed_roles: AllowedRoles::NormalAndInternal,
348 enable_tls: false,
349 },
350 routes: HttpRoutesEnabled{
351 base: true,
352 webhook: true,
353 internal: true,
354 metrics: true,
355 profiling: true,
356 },
357 },
358 },
359 };
360 self
361 }
362
363 pub fn with_password_auth(mut self, mz_system_password: Password) -> Self {
364 self.external_login_password_mz_system = Some(mz_system_password);
365 let enable_tls = self.tls.is_some();
366 self.listeners_config = ListenersConfig {
367 sql: btreemap! {
368 "external".to_owned() => SqlListenerConfig {
369 addr: SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), 0),
370 authenticator_kind: AuthenticatorKind::Password,
371 allowed_roles: AllowedRoles::NormalAndInternal,
372 enable_tls,
373 },
374 },
375 http: btreemap! {
376 "external".to_owned() => HttpListenerConfig {
377 base: BaseListenerConfig {
378 addr: SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), 0),
379 authenticator_kind: AuthenticatorKind::Password,
380 allowed_roles: AllowedRoles::NormalAndInternal,
381 enable_tls,
382 },
383 routes: HttpRoutesEnabled{
384 base: true,
385 webhook: true,
386 internal: true,
387 metrics: false,
388 profiling: true,
389 },
390 },
391 "metrics".to_owned() => HttpListenerConfig {
392 base: BaseListenerConfig {
393 addr: SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), 0),
394 authenticator_kind: AuthenticatorKind::None,
395 allowed_roles: AllowedRoles::NormalAndInternal,
396 enable_tls: false,
397 },
398 routes: HttpRoutesEnabled{
399 base: false,
400 webhook: false,
401 internal: false,
402 metrics: true,
403 profiling: false,
404 },
405 },
406 },
407 };
408 self
409 }
410
411 pub fn with_sasl_scram_auth(mut self, mz_system_password: Password) -> Self {
412 self.external_login_password_mz_system = Some(mz_system_password);
413 let enable_tls = self.tls.is_some();
414 self.listeners_config = ListenersConfig {
415 sql: btreemap! {
416 "external".to_owned() => SqlListenerConfig {
417 addr: SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), 0),
418 authenticator_kind: AuthenticatorKind::Sasl,
419 allowed_roles: AllowedRoles::NormalAndInternal,
420 enable_tls,
421 },
422 },
423 http: btreemap! {
424 "external".to_owned() => HttpListenerConfig {
425 base: BaseListenerConfig {
426 addr: SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), 0),
427 authenticator_kind: AuthenticatorKind::Password,
428 allowed_roles: AllowedRoles::NormalAndInternal,
429 enable_tls,
430 },
431 routes: HttpRoutesEnabled{
432 base: true,
433 webhook: true,
434 internal: true,
435 metrics: false,
436 profiling: true,
437 },
438 },
439 "metrics".to_owned() => HttpListenerConfig {
440 base: BaseListenerConfig {
441 addr: SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), 0),
442 authenticator_kind: AuthenticatorKind::None,
443 allowed_roles: AllowedRoles::NormalAndInternal,
444 enable_tls: false,
445 },
446 routes: HttpRoutesEnabled{
447 base: false,
448 webhook: false,
449 internal: false,
450 metrics: true,
451 profiling: false,
452 },
453 },
454 },
455 };
456 self
457 }
458
459 pub fn with_now(mut self, now: NowFn) -> Self {
460 self.now = now;
461 self
462 }
463
464 pub fn with_storage_usage_collection_interval(
465 mut self,
466 storage_usage_collection_interval: Duration,
467 ) -> Self {
468 self.storage_usage_collection_interval = storage_usage_collection_interval;
469 self
470 }
471
472 pub fn with_storage_usage_retention_period(
473 mut self,
474 storage_usage_retention_period: Duration,
475 ) -> Self {
476 self.storage_usage_retention_period = Some(storage_usage_retention_period);
477 self
478 }
479
480 pub fn with_default_cluster_replica_size(
481 mut self,
482 default_cluster_replica_size: String,
483 ) -> Self {
484 self.default_cluster_replica_size = default_cluster_replica_size;
485 self
486 }
487
488 pub fn with_builtin_system_cluster_replica_size(
489 mut self,
490 builtin_system_cluster_replica_size: String,
491 ) -> Self {
492 self.builtin_system_cluster_config.size = builtin_system_cluster_replica_size;
493 self
494 }
495
496 pub fn with_builtin_system_cluster_replication_factor(
497 mut self,
498 builtin_system_cluster_replication_factor: u32,
499 ) -> Self {
500 self.builtin_system_cluster_config.replication_factor =
501 builtin_system_cluster_replication_factor;
502 self
503 }
504
505 pub fn with_builtin_catalog_server_cluster_replica_size(
506 mut self,
507 builtin_catalog_server_cluster_replica_size: String,
508 ) -> Self {
509 self.builtin_catalog_server_cluster_config.size =
510 builtin_catalog_server_cluster_replica_size;
511 self
512 }
513
514 pub fn with_propagate_crashes(mut self, propagate_crashes: bool) -> Self {
515 self.propagate_crashes = propagate_crashes;
516 self
517 }
518
519 pub fn with_enable_tracing(mut self, enable_tracing: bool) -> Self {
520 self.enable_tracing = enable_tracing;
521 self
522 }
523
524 pub fn with_bootstrap_role(mut self, bootstrap_role: Option<String>) -> Self {
525 self.bootstrap_role = bootstrap_role;
526 self
527 }
528
529 pub fn with_deploy_generation(mut self, deploy_generation: u64) -> Self {
530 self.deploy_generation = deploy_generation;
531 self
532 }
533
534 pub fn with_system_parameter_default(mut self, param: String, value: String) -> Self {
535 self.system_parameter_defaults.insert(param, value);
536 self
537 }
538
539 pub fn with_internal_console_redirect_url(
540 mut self,
541 internal_console_redirect_url: Option<String>,
542 ) -> Self {
543 self.internal_console_redirect_url = internal_console_redirect_url;
544 self
545 }
546
547 pub fn with_metrics_registry(mut self, registry: MetricsRegistry) -> Self {
548 self.metrics_registry = Some(registry);
549 self
550 }
551
552 pub fn with_code_version(mut self, version: semver::Version) -> Self {
553 self.code_version = version;
554 self
555 }
556
557 pub fn with_capture(mut self, storage: SharedStorage) -> Self {
558 self.capture = Some(storage);
559 self
560 }
561}
562
/// Listeners bound before a [`TestServer`] is configured, so their addresses
/// (e.g. the external HTTP port used for the server's HTTP host name) are
/// known before serving starts.
pub struct Listeners {
    /// The bound listeners.
    pub inner: crate::Listeners,
}
566
567impl Listeners {
568 pub async fn new(config: &TestHarness) -> Result<Listeners, anyhow::Error> {
569 let inner = crate::Listeners::bind(config.listeners_config.clone()).await?;
570 Ok(Listeners { inner })
571 }
572
573 pub async fn serve(self, config: TestHarness) -> Result<TestServer, anyhow::Error> {
574 self.serve_with_trigger(config, mz_server_core::cert_reload_never_reload())
575 .await
576 }
577
578 pub async fn serve_with_trigger(
579 self,
580 config: TestHarness,
581 tls_reload_certs: ReloadTrigger,
582 ) -> Result<TestServer, anyhow::Error> {
583 let (data_directory, temp_dir) = match config.data_directory {
584 None => {
585 let temp_dir = tempfile::tempdir()?;
590 (temp_dir.path().to_path_buf(), Some(temp_dir))
591 }
592 Some(data_directory) => (data_directory, None),
593 };
594 let scratch_dir = tempfile::tempdir()?;
595 let (consensus_uri, timestamp_oracle_url) = {
596 let seed = config.seed;
597 let cockroach_url = env::var("METADATA_BACKEND_URL")
598 .map_err(|_| anyhow!("METADATA_BACKEND_URL environment variable is not set"))?;
599 let (client, conn) = tokio_postgres::connect(&cockroach_url, NoTls).await?;
600 mz_ore::task::spawn(|| "startup-postgres-conn", async move {
601 if let Err(err) = conn.await {
602 panic!("connection error: {}", err);
603 };
604 });
605 client
606 .batch_execute(&format!(
607 "CREATE SCHEMA IF NOT EXISTS consensus_{seed};
608 CREATE SCHEMA IF NOT EXISTS tsoracle_{seed};"
609 ))
610 .await?;
611 (
612 format!("{cockroach_url}?options=--search_path=consensus_{seed}")
613 .parse()
614 .expect("invalid consensus URI"),
615 format!("{cockroach_url}?options=--search_path=tsoracle_{seed}")
616 .parse()
617 .expect("invalid timestamp oracle URI"),
618 )
619 };
620 let metrics_registry = config.metrics_registry.unwrap_or_else(MetricsRegistry::new);
621 let orchestrator = ProcessOrchestrator::new(ProcessOrchestratorConfig {
622 image_dir: env::current_exe()?
623 .parent()
624 .unwrap()
625 .parent()
626 .unwrap()
627 .to_path_buf(),
628 suppress_output: false,
629 environment_id: config.environment_id.to_string(),
630 secrets_dir: data_directory.join("secrets"),
631 command_wrapper: vec![],
632 propagate_crashes: config.propagate_crashes,
633 tcp_proxy: None,
634 scratch_directory: scratch_dir.path().to_path_buf(),
635 })
636 .await?;
637 let orchestrator = Arc::new(orchestrator);
638 let persist_now = SYSTEM_TIME.clone();
641 let dyncfgs = mz_dyncfgs::all_dyncfgs();
642
643 let mut updates = ConfigUpdates::default();
644 updates.add(&CONSENSUS_CONNECTION_POOL_MAX_SIZE, 1);
647 updates.apply(&dyncfgs);
648
649 let mut persist_cfg = PersistConfig::new(&crate::BUILD_INFO, persist_now.clone(), dyncfgs);
650 persist_cfg.build_version = config.code_version;
651 persist_cfg.set_rollup_threshold(5);
653
654 let persist_pubsub_server = PersistGrpcPubSubServer::new(&persist_cfg, &metrics_registry);
655 let persist_pubsub_client = persist_pubsub_server.new_same_process_connection();
656 let persist_pubsub_tcp_listener =
657 TcpListener::bind(SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), 0))
658 .await
659 .expect("pubsub addr binding");
660 let persist_pubsub_server_port = persist_pubsub_tcp_listener
661 .local_addr()
662 .expect("pubsub addr has local addr")
663 .port();
664
665 mz_ore::task::spawn(|| "persist_pubsub_server", async move {
667 persist_pubsub_server
668 .serve_with_stream(TcpListenerStream::new(persist_pubsub_tcp_listener))
669 .await
670 .expect("success")
671 });
672 let persist_clients =
673 PersistClientCache::new(persist_cfg, &metrics_registry, |_, _| persist_pubsub_client);
674 let persist_clients = Arc::new(persist_clients);
675
676 let secrets_controller = Arc::clone(&orchestrator);
677 let connection_context = ConnectionContext::for_tests(orchestrator.reader());
678 let orchestrator = Arc::new(TracingOrchestrator::new(
679 orchestrator,
680 config.orchestrator_tracing_cli_args,
681 ));
682 let tracing_handle = if config.enable_tracing {
683 let config = TracingConfig::<fn(&tracing::Metadata) -> sentry_tracing::EventFilter> {
684 service_name: "environmentd",
685 stderr_log: StderrLogConfig {
686 format: StderrLogFormat::Json,
687 filter: EnvFilter::default(),
688 },
689 opentelemetry: Some(OpenTelemetryConfig {
690 endpoint: "http://fake_address_for_testing:8080".to_string(),
691 headers: http::HeaderMap::new(),
692 filter: EnvFilter::default().add_directive(Level::DEBUG.into()),
693 resource: opentelemetry_sdk::resource::Resource::default(),
694 max_batch_queue_size: 2048,
695 max_export_batch_size: 512,
696 max_concurrent_exports: 1,
697 batch_scheduled_delay: Duration::from_millis(5000),
698 max_export_timeout: Duration::from_secs(30),
699 }),
700 tokio_console: None,
701 sentry: None,
702 build_version: crate::BUILD_INFO.version,
703 build_sha: crate::BUILD_INFO.sha,
704 registry: metrics_registry.clone(),
705 capture: config.capture,
706 };
707 mz_ore::tracing::configure(config).await?
708 } else {
709 TracingHandle::disabled()
710 };
711 let host_name = format!(
712 "localhost:{}",
713 self.inner.http["external"].handle.local_addr.port()
714 );
715 let catalog_config = CatalogConfig {
716 persist_clients: Arc::clone(&persist_clients),
717 metrics: Arc::new(mz_catalog::durable::Metrics::new(&MetricsRegistry::new())),
718 };
719
720 let inner = self
721 .inner
722 .serve(crate::Config {
723 catalog_config,
724 timestamp_oracle_url: Some(timestamp_oracle_url),
725 controller: ControllerConfig {
726 build_info: &crate::BUILD_INFO,
727 orchestrator,
728 clusterd_image: "clusterd".into(),
729 init_container_image: None,
730 deploy_generation: config.deploy_generation,
731 persist_location: PersistLocation {
732 blob_uri: format!("file://{}/persist/blob", data_directory.display())
733 .parse()
734 .expect("invalid blob URI"),
735 consensus_uri,
736 },
737 persist_clients,
738 now: config.now.clone(),
739 metrics_registry: metrics_registry.clone(),
740 persist_pubsub_url: format!("http://localhost:{}", persist_pubsub_server_port),
741 secrets_args: mz_service::secrets::SecretsReaderCliArgs {
742 secrets_reader: mz_service::secrets::SecretsControllerKind::LocalFile,
743 secrets_reader_local_file_dir: Some(data_directory.join("secrets")),
744 secrets_reader_kubernetes_context: None,
745 secrets_reader_aws_prefix: None,
746 secrets_reader_name_prefix: None,
747 },
748 connection_context,
749 },
750 secrets_controller,
751 cloud_resource_controller: None,
752 tls: config.tls,
753 frontegg: config.frontegg,
754 unsafe_mode: config.unsafe_mode,
755 all_features: false,
756 metrics_registry: metrics_registry.clone(),
757 now: config.now,
758 environment_id: config.environment_id,
759 cors_allowed_origin: AllowOrigin::list([]),
760 cluster_replica_sizes: ClusterReplicaSizeMap::for_tests(),
761 bootstrap_default_cluster_replica_size: config.default_cluster_replica_size,
762 bootstrap_default_cluster_replication_factor: config
763 .default_cluster_replication_factor,
764 bootstrap_builtin_system_cluster_config: config.builtin_system_cluster_config,
765 bootstrap_builtin_catalog_server_cluster_config: config
766 .builtin_catalog_server_cluster_config,
767 bootstrap_builtin_probe_cluster_config: config.builtin_probe_cluster_config,
768 bootstrap_builtin_support_cluster_config: config.builtin_support_cluster_config,
769 bootstrap_builtin_analytics_cluster_config: config.builtin_analytics_cluster_config,
770 system_parameter_defaults: config.system_parameter_defaults,
771 availability_zones: Default::default(),
772 tracing_handle,
773 storage_usage_collection_interval: config.storage_usage_collection_interval,
774 storage_usage_retention_period: config.storage_usage_retention_period,
775 segment_api_key: None,
776 segment_client_side: false,
777 test_only_dummy_segment_client: false,
778 egress_addresses: vec![],
779 aws_account_id: None,
780 aws_privatelink_availability_zones: None,
781 launchdarkly_sdk_key: None,
782 launchdarkly_key_map: Default::default(),
783 config_sync_file_path: None,
784 config_sync_timeout: Duration::from_secs(30),
785 config_sync_loop_interval: None,
786 bootstrap_role: config.bootstrap_role,
787 http_host_name: Some(host_name),
788 internal_console_redirect_url: config.internal_console_redirect_url,
789 tls_reload_certs,
790 helm_chart_version: None,
791 license_key: ValidatedLicenseKey::for_tests(),
792 external_login_password_mz_system: config.external_login_password_mz_system,
793 force_builtin_schema_migration: None,
794 })
795 .await?;
796
797 Ok(TestServer {
798 inner,
799 metrics_registry,
800 _temp_dir: temp_dir,
801 _scratch_dir: scratch_dir,
802 })
803 }
804}
805
/// A running `environmentd` server started from a [`TestHarness`].
pub struct TestServer {
    /// The running server.
    pub inner: crate::Server,
    /// The metrics registry the server was configured with.
    pub metrics_registry: MetricsRegistry,
    // Struct fields drop in declaration order, so `inner` is dropped before
    // the directories below; `TempDir` deletes its directory when dropped.
    /// Temporary data directory, present only when the harness had no
    /// explicit `data_directory`.
    _temp_dir: Option<TempDir>,
    /// Scratch directory handed to the process orchestrator.
    _scratch_dir: TempDir,
}
814
815impl TestServer {
816 pub fn connect(&self) -> ConnectBuilder<'_, postgres::NoTls, NoHandle> {
817 ConnectBuilder::new(self).no_tls()
818 }
819
820 pub async fn enable_feature_flags(&self, flags: &[&'static str]) {
821 let internal_client = self.connect().internal().await.unwrap();
822
823 for flag in flags {
824 internal_client
825 .batch_execute(&format!("ALTER SYSTEM SET {} = true;", flag))
826 .await
827 .unwrap();
828 }
829 }
830
831 pub async fn disable_feature_flags(&self, flags: &[&'static str]) {
832 let internal_client = self.connect().internal().await.unwrap();
833
834 for flag in flags {
835 internal_client
836 .batch_execute(&format!("ALTER SYSTEM SET {} = false;", flag))
837 .await
838 .unwrap();
839 }
840 }
841
842 pub fn ws_addr(&self) -> Uri {
843 format!(
844 "ws://{}/api/experimental/sql",
845 self.inner.http_listener_handles["external"].local_addr
846 )
847 .parse()
848 .unwrap()
849 }
850
851 pub fn internal_ws_addr(&self) -> Uri {
852 format!(
853 "ws://{}/api/experimental/sql",
854 self.inner.http_listener_handles["internal"].local_addr
855 )
856 .parse()
857 .unwrap()
858 }
859
860 pub fn http_local_addr(&self) -> SocketAddr {
861 self.inner.http_listener_handles["external"].local_addr
862 }
863
864 pub fn internal_http_local_addr(&self) -> SocketAddr {
865 self.inner.http_listener_handles["internal"].local_addr
866 }
867
868 pub fn sql_local_addr(&self) -> SocketAddr {
869 self.inner.sql_listener_handles["external"].local_addr
870 }
871
872 pub fn internal_sql_local_addr(&self) -> SocketAddr {
873 self.inner.sql_listener_handles["internal"].local_addr
874 }
875}
876
/// Builder for a `tokio_postgres` connection to a [`TestServer`]; await it
/// (it implements [`IntoFuture`]) to connect.
///
/// `T` is the TLS connector; `H` is an [`IncludeHandle`] marker selecting
/// whether the connection task's `JoinHandle` is returned with the client.
pub struct ConnectBuilder<'s, T, H> {
    /// Server whose listener ports this builder targets.
    server: &'s TestServer,

    /// Connection parameters. Must not specify a port: the port is taken
    /// from `port` when the builder is awaited.
    pg_config: tokio_postgres::Config,
    /// Port to connect to, applied to `pg_config` at connect time.
    port: u16,
    /// TLS connector passed to `tokio_postgres::Config::connect`.
    tls: T,

    /// Called for every `Notice` message received on the connection.
    notice_callback: Option<Box<dyn FnMut(tokio_postgres::error::DbError) + Send + 'static>>,

    /// Type-level marker only; see [`IncludeHandle`].
    _with_handle: H,
}
897
898impl<'s> ConnectBuilder<'s, (), NoHandle> {
899 fn new(server: &'s TestServer) -> Self {
900 let mut pg_config = tokio_postgres::Config::new();
901 pg_config
902 .host(&Ipv4Addr::LOCALHOST.to_string())
903 .user("materialize")
904 .options("--welcome_message=off")
905 .application_name("environmentd_test_framework");
906
907 ConnectBuilder {
908 server,
909 pg_config,
910 port: server.sql_local_addr().port(),
911 tls: (),
912 notice_callback: None,
913 _with_handle: NoHandle,
914 }
915 }
916}
917
918impl<'s, T, H> ConnectBuilder<'s, T, H> {
919 pub fn no_tls(self) -> ConnectBuilder<'s, postgres::NoTls, H> {
923 ConnectBuilder {
924 server: self.server,
925 pg_config: self.pg_config,
926 port: self.port,
927 tls: postgres::NoTls,
928 notice_callback: self.notice_callback,
929 _with_handle: self._with_handle,
930 }
931 }
932
933 pub fn with_tls<Tls>(self, tls: Tls) -> ConnectBuilder<'s, Tls, H>
935 where
936 Tls: MakeTlsConnect<Socket> + Send + 'static,
937 Tls::TlsConnect: Send,
938 Tls::Stream: Send,
939 <Tls::TlsConnect as TlsConnect<Socket>>::Future: Send,
940 {
941 ConnectBuilder {
942 server: self.server,
943 pg_config: self.pg_config,
944 port: self.port,
945 tls,
946 notice_callback: self.notice_callback,
947 _with_handle: self._with_handle,
948 }
949 }
950
951 pub fn with_config(mut self, pg_config: tokio_postgres::Config) -> Self {
953 self.pg_config = pg_config;
954 self
955 }
956
957 pub fn ssl_mode(mut self, mode: SslMode) -> Self {
959 self.pg_config.ssl_mode(mode);
960 self
961 }
962
963 pub fn user(mut self, user: &str) -> Self {
965 self.pg_config.user(user);
966 self
967 }
968
969 pub fn password(mut self, password: &str) -> Self {
971 self.pg_config.password(password);
972 self
973 }
974
975 pub fn application_name(mut self, application_name: &str) -> Self {
977 self.pg_config.application_name(application_name);
978 self
979 }
980
981 pub fn dbname(mut self, dbname: &str) -> Self {
983 self.pg_config.dbname(dbname);
984 self
985 }
986
987 pub fn options(mut self, options: &str) -> Self {
989 self.pg_config.options(options);
990 self
991 }
992
993 pub fn internal(mut self) -> Self {
998 self.port = self.server.internal_sql_local_addr().port();
999 self.pg_config.user(mz_sql::session::user::SYSTEM_USER_NAME);
1000 self
1001 }
1002
1003 pub fn notice_callback(self, callback: impl FnMut(DbError) + Send + 'static) -> Self {
1005 ConnectBuilder {
1006 notice_callback: Some(Box::new(callback)),
1007 ..self
1008 }
1009 }
1010
1011 pub fn with_handle(self) -> ConnectBuilder<'s, T, WithHandle> {
1014 ConnectBuilder {
1015 server: self.server,
1016 pg_config: self.pg_config,
1017 port: self.port,
1018 tls: self.tls,
1019 notice_callback: self.notice_callback,
1020 _with_handle: WithHandle,
1021 }
1022 }
1023
1024 pub fn as_pg_config(&self) -> &tokio_postgres::Config {
1026 &self.pg_config
1027 }
1028}
1029
/// Chooses what awaiting a `ConnectBuilder` produces, given the connected
/// client and the `JoinHandle` of the task that drives the connection.
pub trait IncludeHandle: Send {
    /// The value produced by awaiting the builder.
    type Output;
    /// Combines the client and the connection task's handle into
    /// [`Self::Output`].
    fn transform_result(
        client: tokio_postgres::Client,
        handle: mz_ore::task::JoinHandle<()>,
    ) -> Self::Output;
}
1039
1040pub struct NoHandle;
1043impl IncludeHandle for NoHandle {
1044 type Output = tokio_postgres::Client;
1045 fn transform_result(
1046 client: tokio_postgres::Client,
1047 _handle: mz_ore::task::JoinHandle<()>,
1048 ) -> Self::Output {
1049 client
1050 }
1051}
1052
1053pub struct WithHandle;
1056impl IncludeHandle for WithHandle {
1057 type Output = (tokio_postgres::Client, mz_ore::task::JoinHandle<()>);
1058 fn transform_result(
1059 client: tokio_postgres::Client,
1060 handle: mz_ore::task::JoinHandle<()>,
1061 ) -> Self::Output {
1062 (client, handle)
1063 }
1064}
1065
/// Awaiting a `ConnectBuilder` opens the connection.
///
/// The builder's `port` is applied to `pg_config` just before connecting, so
/// `pg_config` itself must not carry any port (asserted below). Once
/// connected, the connection is driven on a spawned task named "connect":
/// notices are forwarded to the optional notice callback, other asynchronous
/// messages are logged at debug level and dropped, and the task exits on the
/// first error or when the connection closes. `H` decides whether the caller
/// receives only the client or the client plus that task's join handle.
impl<'s, T, H> IntoFuture for ConnectBuilder<'s, T, H>
where
    T: MakeTlsConnect<Socket> + Send + 'static,
    T::TlsConnect: Send,
    T::Stream: Send,
    <T::TlsConnect as TlsConnect<Socket>>::Future: Send,
    H: IncludeHandle,
{
    type Output = Result<H::Output, postgres::Error>;
    type IntoFuture = BoxFuture<'static, Self::Output>;

    fn into_future(mut self) -> Self::IntoFuture {
        Box::pin(async move {
            // The builder's own `port` is the only port used; reject configs
            // that already specify one.
            assert!(
                self.pg_config.get_ports().is_empty(),
                "specifying multiple ports is not supported"
            );
            self.pg_config.port(self.port);

            let (client, mut conn) = self.pg_config.connect(self.tls).await?;
            // Move the callback out so the spawned task can own it.
            let mut notice_callback = self.notice_callback.take();

            // Drive the connection in the background; the client is unusable
            // unless something keeps polling `conn`.
            let handle = task::spawn(|| "connect", async move {
                while let Some(msg) = std::future::poll_fn(|cx| conn.poll_message(cx)).await {
                    match msg {
                        Ok(AsyncMessage::Notice(notice)) => {
                            if let Some(callback) = notice_callback.as_mut() {
                                callback(notice);
                            }
                        }
                        Ok(msg) => {
                            tracing::debug!(?msg, "Dropping message from database");
                        }
                        Err(e) => {
                            // tokio_postgres documents `Some(Err(_))` from
                            // `poll_message` as terminal, so stop polling here.
                            tracing::info!("connection error: {e}");
                            break;
                        }
                    }
                }
                tracing::info!("connection closed");
            });

            let output = H::transform_result(client, handle);
            Ok(output)
        })
    }
}
1117
/// A [`TestServer`] paired with a shared handle to a tokio [`Runtime`], with
/// synchronous (`postgres` crate) connection helpers.
///
/// NOTE(review): struct fields are dropped in declaration order, so `server`
/// is dropped before `runtime`; presumably intentional, so confirm before
/// reordering these fields.
pub struct TestServerWithRuntime {
    server: TestServer,
    runtime: Arc<Runtime>,
}
1126
1127impl TestServerWithRuntime {
1128 pub fn runtime(&self) -> &Arc<Runtime> {
1132 &self.runtime
1133 }
1134
1135 pub fn inner(&self) -> &crate::Server {
1137 &self.server.inner
1138 }
1139
1140 pub fn connect<T>(&self, tls: T) -> Result<postgres::Client, postgres::Error>
1142 where
1143 T: MakeTlsConnect<Socket> + Send + 'static,
1144 T::TlsConnect: Send,
1145 T::Stream: Send,
1146 <T::TlsConnect as TlsConnect<Socket>>::Future: Send,
1147 {
1148 self.pg_config().connect(tls)
1149 }
1150
1151 pub fn connect_internal<T>(&self, tls: T) -> Result<postgres::Client, anyhow::Error>
1153 where
1154 T: MakeTlsConnect<Socket> + Send + 'static,
1155 T::TlsConnect: Send,
1156 T::Stream: Send,
1157 <T::TlsConnect as TlsConnect<Socket>>::Future: Send,
1158 {
1159 Ok(self.pg_config_internal().connect(tls)?)
1160 }
1161
1162 pub fn enable_feature_flags(&self, flags: &[&'static str]) {
1164 let mut internal_client = self.connect_internal(postgres::NoTls).unwrap();
1165
1166 for flag in flags {
1167 internal_client
1168 .batch_execute(&format!("ALTER SYSTEM SET {} = true;", flag))
1169 .unwrap();
1170 }
1171 }
1172
1173 pub fn disable_feature_flags(&self, flags: &[&'static str]) {
1175 let mut internal_client = self.connect_internal(postgres::NoTls).unwrap();
1176
1177 for flag in flags {
1178 internal_client
1179 .batch_execute(&format!("ALTER SYSTEM SET {} = false;", flag))
1180 .unwrap();
1181 }
1182 }
1183
1184 pub fn pg_config(&self) -> postgres::Config {
1187 let local_addr = self.server.sql_local_addr();
1188 let mut config = postgres::Config::new();
1189 config
1190 .host(&Ipv4Addr::LOCALHOST.to_string())
1191 .port(local_addr.port())
1192 .user("materialize")
1193 .options("--welcome_message=off");
1194 config
1195 }
1196
1197 pub fn pg_config_internal(&self) -> postgres::Config {
1200 let local_addr = self.server.internal_sql_local_addr();
1201 let mut config = postgres::Config::new();
1202 config
1203 .host(&Ipv4Addr::LOCALHOST.to_string())
1204 .port(local_addr.port())
1205 .user("mz_system")
1206 .options("--welcome_message=off");
1207 config
1208 }
1209
1210 pub fn ws_addr(&self) -> Uri {
1211 self.server.ws_addr()
1212 }
1213
1214 pub fn internal_ws_addr(&self) -> Uri {
1215 self.server.internal_ws_addr()
1216 }
1217
1218 pub fn http_local_addr(&self) -> SocketAddr {
1219 self.server.http_local_addr()
1220 }
1221
1222 pub fn internal_http_local_addr(&self) -> SocketAddr {
1223 self.server.internal_http_local_addr()
1224 }
1225
1226 pub fn sql_local_addr(&self) -> SocketAddr {
1227 self.server.sql_local_addr()
1228 }
1229
1230 pub fn internal_sql_local_addr(&self) -> SocketAddr {
1231 self.server.internal_sql_local_addr()
1232 }
1233}
1234
/// A timestamp value decoded from a numeric SQL column into a `u64`.
#[derive(Clone, Copy, Debug, PartialEq, Eq, PartialOrd, Ord)]
pub struct MzTimestamp(pub u64);
1237
1238impl<'a> FromSql<'a> for MzTimestamp {
1239 fn from_sql(ty: &Type, raw: &'a [u8]) -> Result<MzTimestamp, Box<dyn Error + Sync + Send>> {
1240 let n = mz_pgrepr::Numeric::from_sql(ty, raw)?;
1241 Ok(MzTimestamp(u64::try_from(n.0.0)?))
1242 }
1243
1244 fn accepts(ty: &Type) -> bool {
1245 mz_pgrepr::Numeric::accepts(ty)
1246 }
1247}
1248
/// Extracts the server-reported [`DbError`] from an error value.
pub trait PostgresErrorExt {
    /// Returns the contained [`DbError`].
    ///
    /// # Panics
    ///
    /// Panics if no [`DbError`] can be extracted.
    fn unwrap_db_error(self) -> DbError;
}
1252
1253impl PostgresErrorExt for postgres::Error {
1254 fn unwrap_db_error(self) -> DbError {
1255 match self.source().and_then(|e| e.downcast_ref::<DbError>()) {
1256 Some(e) => e.clone(),
1257 None => panic!("expected DbError, but got: {:?}", self),
1258 }
1259 }
1260}
1261
1262impl<T, E> PostgresErrorExt for Result<T, E>
1263where
1264 E: PostgresErrorExt,
1265{
1266 fn unwrap_db_error(self) -> DbError {
1267 match self {
1268 Ok(_) => panic!("expected Err(DbError), but got Ok(_)"),
1269 Err(e) => e.unwrap_db_error(),
1270 }
1271 }
1272}
1273
1274pub async fn insert_with_deterministic_timestamps(
1278 table: &'static str,
1279 values: &'static str,
1280 server: &TestServer,
1281 now: Arc<std::sync::Mutex<EpochMillis>>,
1282) -> Result<(), Box<dyn Error>> {
1283 let client_write = server.connect().await?;
1284 let client_read = server.connect().await?;
1285
1286 let mut current_ts = get_explain_timestamp(table, &client_read).await;
1287
1288 let insert_query = format!("INSERT INTO {table} VALUES {values}");
1289
1290 let write_future = client_write.execute(&insert_query, &[]);
1291 let timestamp_interval = tokio::time::interval(Duration::from_millis(1));
1292
1293 let mut write_future = std::pin::pin!(write_future);
1294 let mut timestamp_interval = std::pin::pin!(timestamp_interval);
1295
1296 loop {
1299 tokio::select! {
1300 _ = (&mut write_future) => return Ok(()),
1301 _ = timestamp_interval.tick() => {
1302 current_ts += 1;
1303 *now.lock().expect("lock poisoned") = current_ts;
1304 }
1305 };
1306 }
1307}
1308
1309pub async fn get_explain_timestamp(from_suffix: &str, client: &Client) -> EpochMillis {
1310 try_get_explain_timestamp(from_suffix, client)
1311 .await
1312 .unwrap()
1313}
1314
1315pub async fn try_get_explain_timestamp(
1316 from_suffix: &str,
1317 client: &Client,
1318) -> Result<EpochMillis, anyhow::Error> {
1319 let det = get_explain_timestamp_determination(from_suffix, client).await?;
1320 let ts = det.determination.timestamp_context.timestamp_or_default();
1321 Ok(ts.into())
1322}
1323
1324pub async fn get_explain_timestamp_determination(
1325 from_suffix: &str,
1326 client: &Client,
1327) -> Result<TimestampExplanation<mz_repr::Timestamp>, anyhow::Error> {
1328 let row = client
1329 .query_one(
1330 &format!("EXPLAIN TIMESTAMP AS JSON FOR SELECT * FROM {from_suffix}"),
1331 &[],
1332 )
1333 .await?;
1334 let explain: String = row.get(0);
1335 Ok(serde_json::from_str(&explain).unwrap())
1336}
1337
1338pub async fn create_postgres_source_with_table<'a>(
1346 server: &TestServer,
1347 mz_client: &Client,
1348 table_name: &str,
1349 table_schema: &str,
1350 source_name: &str,
1351) -> (
1352 Client,
1353 impl FnOnce(&'a Client, &'a Client) -> LocalBoxFuture<'a, ()>,
1354) {
1355 server
1356 .enable_feature_flags(&["enable_create_table_from_source"])
1357 .await;
1358
1359 let postgres_url = env::var("POSTGRES_URL")
1360 .map_err(|_| anyhow!("POSTGRES_URL environment variable is not set"))
1361 .unwrap();
1362
1363 let (pg_client, connection) = tokio_postgres::connect(&postgres_url, postgres::NoTls)
1364 .await
1365 .unwrap();
1366
1367 let pg_config: tokio_postgres::Config = postgres_url.parse().unwrap();
1368 let user = pg_config.get_user().unwrap_or("postgres");
1369 let db_name = pg_config.get_dbname().unwrap_or(user);
1370 let ports = pg_config.get_ports();
1371 let port = if ports.is_empty() { 5432 } else { ports[0] };
1372 let hosts = pg_config.get_hosts();
1373 let host = if hosts.is_empty() {
1374 "localhost".to_string()
1375 } else {
1376 match &hosts[0] {
1377 Host::Tcp(host) => host.to_string(),
1378 Host::Unix(host) => host.to_str().unwrap().to_string(),
1379 }
1380 };
1381 let password = pg_config.get_password();
1382
1383 mz_ore::task::spawn(|| "postgres-source-connection", async move {
1384 if let Err(e) = connection.await {
1385 panic!("connection error: {}", e);
1386 }
1387 });
1388
1389 let _ = pg_client
1391 .execute(&format!("DROP TABLE IF EXISTS {table_name};"), &[])
1392 .await
1393 .unwrap();
1394 let _ = pg_client
1395 .execute(&format!("DROP PUBLICATION IF EXISTS {source_name};"), &[])
1396 .await
1397 .unwrap();
1398 let _ = pg_client
1399 .execute(&format!("CREATE TABLE {table_name} {table_schema};"), &[])
1400 .await
1401 .unwrap();
1402 let _ = pg_client
1403 .execute(
1404 &format!("ALTER TABLE {table_name} REPLICA IDENTITY FULL;"),
1405 &[],
1406 )
1407 .await
1408 .unwrap();
1409 let _ = pg_client
1410 .execute(
1411 &format!("CREATE PUBLICATION {source_name} FOR TABLE {table_name};"),
1412 &[],
1413 )
1414 .await
1415 .unwrap();
1416
1417 let mut connection_str = format!("HOST '{host}', PORT {port}, USER {user}, DATABASE {db_name}");
1419 if let Some(password) = password {
1420 let password = std::str::from_utf8(password).unwrap();
1421 mz_client
1422 .batch_execute(&format!("CREATE SECRET s AS '{password}'"))
1423 .await
1424 .unwrap();
1425 connection_str = format!("{connection_str}, PASSWORD SECRET s");
1426 }
1427 mz_client
1428 .batch_execute(&format!(
1429 "CREATE CONNECTION pgconn TO POSTGRES ({connection_str})"
1430 ))
1431 .await
1432 .unwrap();
1433 mz_client
1434 .batch_execute(&format!(
1435 "CREATE SOURCE {source_name}
1436 FROM POSTGRES
1437 CONNECTION pgconn
1438 (PUBLICATION '{source_name}')"
1439 ))
1440 .await
1441 .unwrap();
1442 mz_client
1443 .batch_execute(&format!(
1444 "CREATE TABLE {table_name}
1445 FROM SOURCE {source_name}
1446 (REFERENCE {table_name});"
1447 ))
1448 .await
1449 .unwrap();
1450
1451 let table_name = table_name.to_string();
1452 let source_name = source_name.to_string();
1453 (
1454 pg_client,
1455 move |mz_client: &'a Client, pg_client: &'a Client| {
1456 let f: Pin<Box<dyn Future<Output = ()> + 'a>> = Box::pin(async move {
1457 mz_client
1458 .batch_execute(&format!("DROP SOURCE {source_name} CASCADE;"))
1459 .await
1460 .unwrap();
1461 mz_client
1462 .batch_execute("DROP CONNECTION pgconn;")
1463 .await
1464 .unwrap();
1465
1466 let _ = pg_client
1467 .execute(&format!("DROP PUBLICATION {source_name};"), &[])
1468 .await
1469 .unwrap();
1470 let _ = pg_client
1471 .execute(&format!("DROP TABLE {table_name};"), &[])
1472 .await
1473 .unwrap();
1474 });
1475 f
1476 },
1477 )
1478}
1479
1480pub async fn wait_for_pg_table_population(mz_client: &Client, view_name: &str, source_rows: i64) {
1481 let current_isolation = mz_client
1482 .query_one("SHOW transaction_isolation", &[])
1483 .await
1484 .unwrap()
1485 .get::<_, String>(0);
1486 mz_client
1487 .batch_execute("SET transaction_isolation = SERIALIZABLE")
1488 .await
1489 .unwrap();
1490 Retry::default()
1491 .retry_async(|_| async move {
1492 let rows = mz_client
1493 .query_one(&format!("SELECT COUNT(*) FROM {view_name};"), &[])
1494 .await
1495 .unwrap()
1496 .get::<_, i64>(0);
1497 if rows == source_rows {
1498 Ok(())
1499 } else {
1500 Err(format!(
1501 "Waiting for {source_rows} row to be ingested. Currently at {rows}."
1502 ))
1503 }
1504 })
1505 .await
1506 .unwrap();
1507 mz_client
1508 .batch_execute(&format!(
1509 "SET transaction_isolation = '{current_isolation}'"
1510 ))
1511 .await
1512 .unwrap();
1513}
1514
1515pub fn auth_with_ws(
1517 ws: &mut WebSocket<MaybeTlsStream<TcpStream>>,
1518 mut options: BTreeMap<String, String>,
1519) -> Result<Vec<WebSocketResponse>, anyhow::Error> {
1520 if !options.contains_key("welcome_message") {
1521 options.insert("welcome_message".into(), "off".into());
1522 }
1523 auth_with_ws_impl(
1524 ws,
1525 Message::Text(
1526 serde_json::to_string(&WebSocketAuth::Basic {
1527 user: "materialize".into(),
1528 password: "".into(),
1529 options,
1530 })
1531 .unwrap(),
1532 ),
1533 )
1534}
1535
1536pub fn auth_with_ws_impl(
1537 ws: &mut WebSocket<MaybeTlsStream<TcpStream>>,
1538 auth_message: Message,
1539) -> Result<Vec<WebSocketResponse>, anyhow::Error> {
1540 ws.send(auth_message)?;
1541
1542 let mut msgs = Vec::new();
1544 loop {
1545 let resp = ws.read()?;
1546 match resp {
1547 Message::Text(msg) => {
1548 let msg: WebSocketResponse = serde_json::from_str(&msg).unwrap();
1549 match msg {
1550 WebSocketResponse::ReadyForQuery(_) => break,
1551 msg => {
1552 msgs.push(msg);
1553 }
1554 }
1555 }
1556 Message::Ping(_) => continue,
1557 Message::Close(None) => return Err(anyhow!("ws closed after auth")),
1558 Message::Close(Some(close_frame)) => {
1559 return Err(anyhow!("ws closed after auth").context(close_frame));
1560 }
1561 _ => panic!("unexpected response: {:?}", resp),
1562 }
1563 }
1564 Ok(msgs)
1565}
1566
1567pub fn make_header<H: Header>(h: H) -> HeaderMap {
1568 let mut map = HeaderMap::new();
1569 map.typed_insert(h);
1570 map
1571}
1572
1573pub fn make_pg_tls<F>(configure: F) -> MakeTlsConnector
1574where
1575 F: FnOnce(&mut SslConnectorBuilder) -> Result<(), ErrorStack>,
1576{
1577 let mut connector_builder = SslConnector::builder(SslMethod::tls()).unwrap();
1578 let options = connector_builder.options() | SslOptions::NO_TLSV1_3;
1599 connector_builder.set_options(options);
1600 configure(&mut connector_builder).unwrap();
1601 MakeTlsConnector::new(connector_builder.build())
1602}
1603
/// A test certificate authority: its key, certificate and subject name, plus a
/// temporary directory holding the PEM files it writes.
pub struct Ca {
    /// Temporary directory holding `ca.crt` and any issued certificates/keys
    /// (a `tempfile::TempDir`, removed when dropped).
    pub dir: TempDir,
    /// The CA's subject name.
    pub name: X509Name,
    /// The CA's certificate.
    pub cert: X509,
    /// The CA's private key, used to sign issued certificates.
    pub pkey: PKey<Private>,
}
1611
1612impl Ca {
1613 fn make_ca(name: &str, parent: Option<&Ca>) -> Result<Ca, Box<dyn Error>> {
1614 let dir = tempfile::tempdir()?;
1615 let rsa = Rsa::generate(2048)?;
1616 let pkey = PKey::from_rsa(rsa)?;
1617 let name = {
1618 let mut builder = X509NameBuilder::new()?;
1619 builder.append_entry_by_nid(Nid::COMMONNAME, name)?;
1620 builder.build()
1621 };
1622 let cert = {
1623 let mut builder = X509::builder()?;
1624 builder.set_version(2)?;
1625 builder.set_pubkey(&pkey)?;
1626 builder.set_issuer_name(parent.map(|ca| &ca.name).unwrap_or(&name))?;
1627 builder.set_subject_name(&name)?;
1628 builder.set_not_before(&*Asn1Time::days_from_now(0)?)?;
1629 builder.set_not_after(&*Asn1Time::days_from_now(365)?)?;
1630 builder.append_extension(BasicConstraints::new().critical().ca().build()?)?;
1631 builder.sign(
1632 parent.map(|ca| &ca.pkey).unwrap_or(&pkey),
1633 MessageDigest::sha256(),
1634 )?;
1635 builder.build()
1636 };
1637 fs::write(dir.path().join("ca.crt"), cert.to_pem()?)?;
1638 Ok(Ca {
1639 dir,
1640 name,
1641 cert,
1642 pkey,
1643 })
1644 }
1645
1646 pub fn new_root(name: &str) -> Result<Ca, Box<dyn Error>> {
1648 Ca::make_ca(name, None)
1649 }
1650
1651 pub fn ca_cert_path(&self) -> PathBuf {
1653 self.dir.path().join("ca.crt")
1654 }
1655
1656 pub fn request_ca(&self, name: &str) -> Result<Ca, Box<dyn Error>> {
1658 Ca::make_ca(name, Some(self))
1659 }
1660
1661 pub fn request_client_cert(&self, name: &str) -> Result<(PathBuf, PathBuf), Box<dyn Error>> {
1666 self.request_cert(name, iter::empty())
1667 }
1668
1669 pub fn request_cert<I>(&self, name: &str, ips: I) -> Result<(PathBuf, PathBuf), Box<dyn Error>>
1672 where
1673 I: IntoIterator<Item = IpAddr>,
1674 {
1675 let rsa = Rsa::generate(2048)?;
1676 let pkey = PKey::from_rsa(rsa)?;
1677 let subject_name = {
1678 let mut builder = X509NameBuilder::new()?;
1679 builder.append_entry_by_nid(Nid::COMMONNAME, name)?;
1680 builder.build()
1681 };
1682 let cert = {
1683 let mut builder = X509::builder()?;
1684 builder.set_version(2)?;
1685 builder.set_pubkey(&pkey)?;
1686 builder.set_issuer_name(self.cert.subject_name())?;
1687 builder.set_subject_name(&subject_name)?;
1688 builder.set_not_before(&*Asn1Time::days_from_now(0)?)?;
1689 builder.set_not_after(&*Asn1Time::days_from_now(365)?)?;
1690 for ip in ips {
1691 builder.append_extension(
1692 SubjectAlternativeName::new()
1693 .ip(&ip.to_string())
1694 .build(&builder.x509v3_context(None, None))?,
1695 )?;
1696 }
1697 builder.sign(&self.pkey, MessageDigest::sha256())?;
1698 builder.build()
1699 };
1700 let cert_path = self.dir.path().join(Path::new(name).with_extension("crt"));
1701 let key_path = self.dir.path().join(Path::new(name).with_extension("key"));
1702 fs::write(&cert_path, cert.to_pem()?)?;
1703 fs::write(&key_path, pkey.private_key_to_pem_pkcs8()?)?;
1704 Ok((cert_path, key_path))
1705 }
1706}