1use std::collections::BTreeMap;
11use std::error::Error;
12use std::future::IntoFuture;
13use std::net::{IpAddr, Ipv4Addr, SocketAddr, TcpStream};
14use std::path::{Path, PathBuf};
15use std::pin::Pin;
16use std::str::FromStr;
17use std::sync::Arc;
18use std::sync::LazyLock;
19use std::time::Duration;
20use std::{env, fs, iter};
21
22use anyhow::anyhow;
23use futures::Future;
24use futures::future::{BoxFuture, LocalBoxFuture};
25use headers::{Header, HeaderMapExt};
26use http::Uri;
27use hyper::http::header::HeaderMap;
28use maplit::btreemap;
29use mz_adapter::TimestampExplanation;
30use mz_adapter_types::bootstrap_builtin_cluster_config::{
31 ANALYTICS_CLUSTER_DEFAULT_REPLICATION_FACTOR, BootstrapBuiltinClusterConfig,
32 CATALOG_SERVER_CLUSTER_DEFAULT_REPLICATION_FACTOR, PROBE_CLUSTER_DEFAULT_REPLICATION_FACTOR,
33 SUPPORT_CLUSTER_DEFAULT_REPLICATION_FACTOR, SYSTEM_CLUSTER_DEFAULT_REPLICATION_FACTOR,
34};
35
36use mz_auth::password::Password;
37use mz_catalog::config::ClusterReplicaSizeMap;
38use mz_controller::ControllerConfig;
39use mz_dyncfg::ConfigUpdates;
40use mz_license_keys::ValidatedLicenseKey;
41use mz_orchestrator_process::{ProcessOrchestrator, ProcessOrchestratorConfig};
42use mz_orchestrator_tracing::{TracingCliArgs, TracingOrchestrator};
43use mz_ore::metrics::MetricsRegistry;
44use mz_ore::now::{EpochMillis, NowFn, SYSTEM_TIME};
45use mz_ore::retry::Retry;
46use mz_ore::task;
47use mz_ore::tracing::{
48 OpenTelemetryConfig, StderrLogConfig, StderrLogFormat, TracingConfig, TracingHandle,
49};
50use mz_persist_client::PersistLocation;
51use mz_persist_client::cache::PersistClientCache;
52use mz_persist_client::cfg::{CONSENSUS_CONNECTION_POOL_MAX_SIZE, PersistConfig};
53use mz_persist_client::rpc::PersistGrpcPubSubServer;
54use mz_secrets::SecretsController;
55use mz_server_core::listeners::{
56 AllowedRoles, AuthenticatorKind, BaseListenerConfig, HttpRoutesEnabled,
57};
58use mz_server_core::{ReloadTrigger, TlsCertConfig};
59use mz_sql::catalog::EnvironmentId;
60use mz_storage_types::connections::ConnectionContext;
61use mz_tracing::CloneableEnvFilter;
62use openssl::asn1::Asn1Time;
63use openssl::error::ErrorStack;
64use openssl::hash::MessageDigest;
65use openssl::nid::Nid;
66use openssl::pkey::{PKey, Private};
67use openssl::rsa::Rsa;
68use openssl::ssl::{SslConnector, SslConnectorBuilder, SslMethod, SslOptions};
69use openssl::x509::extension::{BasicConstraints, SubjectAlternativeName};
70use openssl::x509::{X509, X509Name, X509NameBuilder};
71use postgres::error::DbError;
72use postgres::tls::{MakeTlsConnect, TlsConnect};
73use postgres::types::{FromSql, Type};
74use postgres::{NoTls, Socket};
75use postgres_openssl::MakeTlsConnector;
76use tempfile::TempDir;
77use tokio::net::TcpListener;
78use tokio::runtime::Runtime;
79use tokio_postgres::config::{Host, SslMode};
80use tokio_postgres::{AsyncMessage, Client};
81use tokio_stream::wrappers::TcpListenerStream;
82use tower_http::cors::AllowOrigin;
83use tracing::Level;
84use tracing_capture::SharedStorage;
85use tracing_subscriber::EnvFilter;
86use tungstenite::stream::MaybeTlsStream;
87use tungstenite::{Message, WebSocket};
88
89use crate::{
90 CatalogConfig, FronteggAuthenticator, HttpListenerConfig, ListenersConfig, SqlListenerConfig,
91 WebSocketAuth, WebSocketResponse,
92};
93
/// Kafka bootstrap address(es) that tests should connect to.
///
/// Resolved once, on first access, from the `KAFKA_ADDRS` environment variable.
/// If the variable is unset (or not valid Unicode), `localhost:9092` is used.
pub static KAFKA_ADDRS: LazyLock<String> = LazyLock::new(|| match env::var("KAFKA_ADDRS") {
    Ok(addrs) => addrs,
    Err(_) => String::from("localhost:9092"),
});
96
/// Builder-style configuration for starting an in-process test server.
///
/// Construct with [`Default::default`], adjust with the `with_*` methods, then
/// call one of the `start*` / `try_start*` methods.
#[derive(Clone)]
pub struct TestHarness {
    /// Directory for server state (persist blobs, secrets). When `None`, a
    /// temporary directory is created and kept alive by the returned `TestServer`.
    data_directory: Option<PathBuf>,
    /// TLS certificate and key paths, set by `with_tls`.
    tls: Option<TlsCertConfig>,
    /// Frontegg authenticator, set by `with_frontegg_auth`.
    frontegg: Option<FronteggAuthenticator>,
    /// Password for `mz_system`, set by the password / SASL auth builders.
    external_login_password_mz_system: Option<Password>,
    /// SQL and HTTP listeners to bind, keyed by listener name.
    listeners_config: ListenersConfig,
    /// Forwarded to the server's `unsafe_mode` setting.
    unsafe_mode: bool,
    /// Requested worker count.
    /// NOTE(review): not read by `Listeners::serve_with_trigger` in this file —
    /// confirm whether it still has any effect.
    workers: usize,
    /// Clock handed to both the controller and the server.
    now: NowFn,
    /// Random per-harness value used to name the `consensus_{seed}` and
    /// `tsoracle_{seed}` metadata schemas.
    seed: u32,
    /// Forwarded to the server's storage-usage collection interval.
    storage_usage_collection_interval: Duration,
    /// Forwarded to the server's storage-usage retention period.
    storage_usage_retention_period: Option<Duration>,
    /// Bootstrap replica size for the default cluster.
    default_cluster_replica_size: String,
    /// Bootstrap replication factor for the default cluster.
    default_cluster_replication_factor: u32,
    /// Bootstrap size/replication factor for the builtin system cluster.
    builtin_system_cluster_config: BootstrapBuiltinClusterConfig,
    /// Bootstrap size/replication factor for the builtin catalog server cluster.
    builtin_catalog_server_cluster_config: BootstrapBuiltinClusterConfig,
    /// Bootstrap size/replication factor for the builtin probe cluster.
    builtin_probe_cluster_config: BootstrapBuiltinClusterConfig,
    /// Bootstrap size/replication factor for the builtin support cluster.
    builtin_support_cluster_config: BootstrapBuiltinClusterConfig,
    /// Bootstrap size/replication factor for the builtin analytics cluster.
    builtin_analytics_cluster_config: BootstrapBuiltinClusterConfig,

    /// Forwarded to the process orchestrator's `propagate_crashes` option.
    propagate_crashes: bool,
    /// When true, tracing is configured (JSON stderr logs plus an OpenTelemetry
    /// exporter pointed at a fake endpoint); otherwise tracing is disabled.
    enable_tracing: bool,
    /// Tracing arguments for the orchestrator-launched processes.
    orchestrator_tracing_cli_args: TracingCliArgs,
    /// Role bootstrapped at startup.
    bootstrap_role: Option<String>,
    /// Forwarded to the controller's deploy generation.
    deploy_generation: u64,
    /// System parameter defaults; `log_filter = error` unless overridden.
    system_parameter_defaults: BTreeMap<String, String>,
    /// Forwarded to the server's internal console redirect URL.
    internal_console_redirect_url: Option<String>,
    /// Registry to record metrics into; a fresh one is created when `None`.
    metrics_registry: Option<MetricsRegistry>,
    /// Version reported as the persist `build_version`.
    code_version: semver::Version,
    /// Tracing capture storage; only used when `enable_tracing` is true.
    capture: Option<SharedStorage>,
    /// Environment identifier passed to the orchestrator and server.
    pub environment_id: EnvironmentId,
}
133
134impl Default for TestHarness {
135 fn default() -> TestHarness {
136 TestHarness {
137 data_directory: None,
138 tls: None,
139 frontegg: None,
140 external_login_password_mz_system: None,
141 listeners_config: ListenersConfig {
142 sql: btreemap![
143 "external".to_owned() => SqlListenerConfig {
144 addr: SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), 0),
145 authenticator_kind: AuthenticatorKind::None,
146 allowed_roles: AllowedRoles::Normal,
147 enable_tls: false,
148 },
149 "internal".to_owned() => SqlListenerConfig {
150 addr: SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), 0),
151 authenticator_kind: AuthenticatorKind::None,
152 allowed_roles: AllowedRoles::NormalAndInternal,
153 enable_tls: false,
154 },
155 ],
156 http: btreemap![
157 "external".to_owned() => HttpListenerConfig {
158 base: BaseListenerConfig {
159 addr: SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), 0),
160 authenticator_kind: AuthenticatorKind::None,
161 allowed_roles: AllowedRoles::Normal,
162 enable_tls: false,
163 },
164 routes: HttpRoutesEnabled{
165 base: true,
166 webhook: true,
167 internal: false,
168 metrics: false,
169 profiling: false,
170 },
171 },
172 "internal".to_owned() => HttpListenerConfig {
173 base: BaseListenerConfig {
174 addr: SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), 0),
175 authenticator_kind: AuthenticatorKind::None,
176 allowed_roles: AllowedRoles::NormalAndInternal,
177 enable_tls: false,
178 },
179 routes: HttpRoutesEnabled{
180 base: true,
181 webhook: true,
182 internal: true,
183 metrics: true,
184 profiling: true,
185 },
186 },
187 ],
188 },
189 unsafe_mode: false,
190 workers: 1,
191 now: SYSTEM_TIME.clone(),
192 seed: rand::random(),
193 storage_usage_collection_interval: Duration::from_secs(3600),
194 storage_usage_retention_period: None,
195 default_cluster_replica_size: "scale=1,workers=1".to_string(),
196 default_cluster_replication_factor: 1,
197 builtin_system_cluster_config: BootstrapBuiltinClusterConfig {
198 size: "scale=1,workers=1".to_string(),
199 replication_factor: SYSTEM_CLUSTER_DEFAULT_REPLICATION_FACTOR,
200 },
201 builtin_catalog_server_cluster_config: BootstrapBuiltinClusterConfig {
202 size: "scale=1,workers=1".to_string(),
203 replication_factor: CATALOG_SERVER_CLUSTER_DEFAULT_REPLICATION_FACTOR,
204 },
205 builtin_probe_cluster_config: BootstrapBuiltinClusterConfig {
206 size: "scale=1,workers=1".to_string(),
207 replication_factor: PROBE_CLUSTER_DEFAULT_REPLICATION_FACTOR,
208 },
209 builtin_support_cluster_config: BootstrapBuiltinClusterConfig {
210 size: "scale=1,workers=1".to_string(),
211 replication_factor: SUPPORT_CLUSTER_DEFAULT_REPLICATION_FACTOR,
212 },
213 builtin_analytics_cluster_config: BootstrapBuiltinClusterConfig {
214 size: "scale=1,workers=1".to_string(),
215 replication_factor: ANALYTICS_CLUSTER_DEFAULT_REPLICATION_FACTOR,
216 },
217 propagate_crashes: false,
218 enable_tracing: false,
219 bootstrap_role: Some("materialize".into()),
220 deploy_generation: 0,
221 system_parameter_defaults: BTreeMap::from([(
224 "log_filter".to_string(),
225 "error".to_string(),
226 )]),
227 internal_console_redirect_url: None,
228 metrics_registry: None,
229 orchestrator_tracing_cli_args: TracingCliArgs {
230 startup_log_filter: CloneableEnvFilter::from_str("error").expect("must parse"),
231 ..Default::default()
232 },
233 code_version: crate::BUILD_INFO.semver_version(),
234 environment_id: EnvironmentId::for_tests(),
235 capture: None,
236 }
237 }
238}
239
240impl TestHarness {
241 pub async fn start(self) -> TestServer {
245 self.try_start().await.expect("Failed to start test Server")
246 }
247
248 pub async fn start_with_trigger(self, tls_reload_certs: ReloadTrigger) -> TestServer {
250 self.try_start_with_trigger(tls_reload_certs)
251 .await
252 .expect("Failed to start test Server")
253 }
254
255 pub async fn try_start(self) -> Result<TestServer, anyhow::Error> {
257 self.try_start_with_trigger(mz_server_core::cert_reload_never_reload())
258 .await
259 }
260
261 pub async fn try_start_with_trigger(
263 self,
264 tls_reload_certs: ReloadTrigger,
265 ) -> Result<TestServer, anyhow::Error> {
266 let listeners = Listeners::new(&self).await?;
267 listeners.serve_with_trigger(self, tls_reload_certs).await
268 }
269
270 pub fn start_blocking(self) -> TestServerWithRuntime {
272 let runtime = tokio::runtime::Builder::new_multi_thread()
273 .enable_all()
274 .thread_stack_size(mz_ore::stack::STACK_SIZE)
275 .build()
276 .expect("failed to spawn runtime for test");
277 let runtime = Arc::new(runtime);
278 let server = runtime.block_on(self.start());
279 TestServerWithRuntime { runtime, server }
280 }
281
282 pub fn data_directory(mut self, data_directory: impl Into<PathBuf>) -> Self {
283 self.data_directory = Some(data_directory.into());
284 self
285 }
286
287 pub fn with_tls(mut self, cert_path: impl Into<PathBuf>, key_path: impl Into<PathBuf>) -> Self {
288 self.tls = Some(TlsCertConfig {
289 cert: cert_path.into(),
290 key: key_path.into(),
291 });
292 for (_, listener) in &mut self.listeners_config.sql {
293 listener.enable_tls = true;
294 }
295 for (_, listener) in &mut self.listeners_config.http {
296 listener.base.enable_tls = true;
297 }
298 self
299 }
300
301 pub fn unsafe_mode(mut self) -> Self {
302 self.unsafe_mode = true;
303 self
304 }
305
306 pub fn workers(mut self, workers: usize) -> Self {
307 self.workers = workers;
308 self
309 }
310
311 pub fn with_frontegg_auth(mut self, frontegg: &FronteggAuthenticator) -> Self {
312 self.frontegg = Some(frontegg.clone());
313 let enable_tls = self.tls.is_some();
314 self.listeners_config = ListenersConfig {
315 sql: btreemap! {
316 "external".to_owned() => SqlListenerConfig {
317 addr: SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), 0),
318 authenticator_kind: AuthenticatorKind::Frontegg,
319 allowed_roles: AllowedRoles::Normal,
320 enable_tls,
321 },
322 "internal".to_owned() => SqlListenerConfig {
323 addr: SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), 0),
324 authenticator_kind: AuthenticatorKind::None,
325 allowed_roles: AllowedRoles::NormalAndInternal,
326 enable_tls: false,
327 },
328 },
329 http: btreemap! {
330 "external".to_owned() => HttpListenerConfig {
331 base: BaseListenerConfig {
332 addr: SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), 0),
333 authenticator_kind: AuthenticatorKind::Frontegg,
334 allowed_roles: AllowedRoles::Normal,
335 enable_tls,
336 },
337 routes: HttpRoutesEnabled{
338 base: true,
339 webhook: true,
340 internal: false,
341 metrics: false,
342 profiling: false,
343 },
344 },
345 "internal".to_owned() => HttpListenerConfig {
346 base: BaseListenerConfig {
347 addr: SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), 0),
348 authenticator_kind: AuthenticatorKind::None,
349 allowed_roles: AllowedRoles::NormalAndInternal,
350 enable_tls: false,
351 },
352 routes: HttpRoutesEnabled{
353 base: true,
354 webhook: true,
355 internal: true,
356 metrics: true,
357 profiling: true,
358 },
359 },
360 },
361 };
362 self
363 }
364
365 pub fn with_oidc_auth(mut self, issuer: Option<String>, audience: Option<String>) -> Self {
366 let enable_tls = self.tls.is_some();
367 self.listeners_config = ListenersConfig {
368 sql: btreemap! {
369 "external".to_owned() => SqlListenerConfig {
370 addr: SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), 0),
371 authenticator_kind: AuthenticatorKind::Oidc,
372 allowed_roles: AllowedRoles::Normal,
373 enable_tls,
374 },
375 "internal".to_owned() => SqlListenerConfig {
376 addr: SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), 0),
377 authenticator_kind: AuthenticatorKind::None,
378 allowed_roles: AllowedRoles::NormalAndInternal,
379 enable_tls: false,
380 },
381 },
382 http: btreemap! {
383 "external".to_owned() => HttpListenerConfig {
384 base: BaseListenerConfig {
385 addr: SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), 0),
386 authenticator_kind: AuthenticatorKind::Oidc,
387 allowed_roles: AllowedRoles::Normal,
388 enable_tls,
389 },
390 routes: HttpRoutesEnabled{
391 base: true,
392 webhook: true,
393 internal: false,
394 metrics: false,
395 profiling: false,
396 },
397 },
398 "internal".to_owned() => HttpListenerConfig {
399 base: BaseListenerConfig {
400 addr: SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), 0),
401 authenticator_kind: AuthenticatorKind::None,
402 allowed_roles: AllowedRoles::NormalAndInternal,
403 enable_tls: false,
404 },
405 routes: HttpRoutesEnabled{
406 base: true,
407 webhook: true,
408 internal: true,
409 metrics: true,
410 profiling: true,
411 },
412 },
413 },
414 };
415
416 if let Some(issuer) = issuer {
417 self.system_parameter_defaults
418 .insert("oidc_issuer".to_string(), issuer);
419 }
420
421 if let Some(audience) = audience {
422 self.system_parameter_defaults
423 .insert("oidc_audience".to_string(), audience);
424 }
425
426 self
427 }
428
429 pub fn with_password_auth(mut self, mz_system_password: Password) -> Self {
430 self.external_login_password_mz_system = Some(mz_system_password);
431 let enable_tls = self.tls.is_some();
432 self.listeners_config = ListenersConfig {
433 sql: btreemap! {
434 "external".to_owned() => SqlListenerConfig {
435 addr: SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), 0),
436 authenticator_kind: AuthenticatorKind::Password,
437 allowed_roles: AllowedRoles::NormalAndInternal,
438 enable_tls,
439 },
440 },
441 http: btreemap! {
442 "external".to_owned() => HttpListenerConfig {
443 base: BaseListenerConfig {
444 addr: SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), 0),
445 authenticator_kind: AuthenticatorKind::Password,
446 allowed_roles: AllowedRoles::NormalAndInternal,
447 enable_tls,
448 },
449 routes: HttpRoutesEnabled{
450 base: true,
451 webhook: true,
452 internal: true,
453 metrics: false,
454 profiling: true,
455 },
456 },
457 "metrics".to_owned() => HttpListenerConfig {
458 base: BaseListenerConfig {
459 addr: SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), 0),
460 authenticator_kind: AuthenticatorKind::None,
461 allowed_roles: AllowedRoles::NormalAndInternal,
462 enable_tls: false,
463 },
464 routes: HttpRoutesEnabled{
465 base: false,
466 webhook: false,
467 internal: false,
468 metrics: true,
469 profiling: false,
470 },
471 },
472 },
473 };
474 self
475 }
476
477 pub fn with_sasl_scram_auth(mut self, mz_system_password: Password) -> Self {
478 self.external_login_password_mz_system = Some(mz_system_password);
479 let enable_tls = self.tls.is_some();
480 self.listeners_config = ListenersConfig {
481 sql: btreemap! {
482 "external".to_owned() => SqlListenerConfig {
483 addr: SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), 0),
484 authenticator_kind: AuthenticatorKind::Sasl,
485 allowed_roles: AllowedRoles::NormalAndInternal,
486 enable_tls,
487 },
488 },
489 http: btreemap! {
490 "external".to_owned() => HttpListenerConfig {
491 base: BaseListenerConfig {
492 addr: SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), 0),
493 authenticator_kind: AuthenticatorKind::Password,
494 allowed_roles: AllowedRoles::NormalAndInternal,
495 enable_tls,
496 },
497 routes: HttpRoutesEnabled{
498 base: true,
499 webhook: true,
500 internal: true,
501 metrics: false,
502 profiling: true,
503 },
504 },
505 "metrics".to_owned() => HttpListenerConfig {
506 base: BaseListenerConfig {
507 addr: SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), 0),
508 authenticator_kind: AuthenticatorKind::None,
509 allowed_roles: AllowedRoles::NormalAndInternal,
510 enable_tls: false,
511 },
512 routes: HttpRoutesEnabled{
513 base: false,
514 webhook: false,
515 internal: false,
516 metrics: true,
517 profiling: false,
518 },
519 },
520 },
521 };
522 self
523 }
524
525 pub fn with_now(mut self, now: NowFn) -> Self {
526 self.now = now;
527 self
528 }
529
530 pub fn with_storage_usage_collection_interval(
531 mut self,
532 storage_usage_collection_interval: Duration,
533 ) -> Self {
534 self.storage_usage_collection_interval = storage_usage_collection_interval;
535 self
536 }
537
538 pub fn with_storage_usage_retention_period(
539 mut self,
540 storage_usage_retention_period: Duration,
541 ) -> Self {
542 self.storage_usage_retention_period = Some(storage_usage_retention_period);
543 self
544 }
545
546 pub fn with_default_cluster_replica_size(
547 mut self,
548 default_cluster_replica_size: String,
549 ) -> Self {
550 self.default_cluster_replica_size = default_cluster_replica_size;
551 self
552 }
553
554 pub fn with_builtin_system_cluster_replica_size(
555 mut self,
556 builtin_system_cluster_replica_size: String,
557 ) -> Self {
558 self.builtin_system_cluster_config.size = builtin_system_cluster_replica_size;
559 self
560 }
561
562 pub fn with_builtin_system_cluster_replication_factor(
563 mut self,
564 builtin_system_cluster_replication_factor: u32,
565 ) -> Self {
566 self.builtin_system_cluster_config.replication_factor =
567 builtin_system_cluster_replication_factor;
568 self
569 }
570
571 pub fn with_builtin_catalog_server_cluster_replica_size(
572 mut self,
573 builtin_catalog_server_cluster_replica_size: String,
574 ) -> Self {
575 self.builtin_catalog_server_cluster_config.size =
576 builtin_catalog_server_cluster_replica_size;
577 self
578 }
579
580 pub fn with_propagate_crashes(mut self, propagate_crashes: bool) -> Self {
581 self.propagate_crashes = propagate_crashes;
582 self
583 }
584
585 pub fn with_enable_tracing(mut self, enable_tracing: bool) -> Self {
586 self.enable_tracing = enable_tracing;
587 self
588 }
589
590 pub fn with_bootstrap_role(mut self, bootstrap_role: Option<String>) -> Self {
591 self.bootstrap_role = bootstrap_role;
592 self
593 }
594
595 pub fn with_deploy_generation(mut self, deploy_generation: u64) -> Self {
596 self.deploy_generation = deploy_generation;
597 self
598 }
599
600 pub fn with_system_parameter_default(mut self, param: String, value: String) -> Self {
601 self.system_parameter_defaults.insert(param, value);
602 self
603 }
604
605 pub fn with_internal_console_redirect_url(
606 mut self,
607 internal_console_redirect_url: Option<String>,
608 ) -> Self {
609 self.internal_console_redirect_url = internal_console_redirect_url;
610 self
611 }
612
613 pub fn with_metrics_registry(mut self, registry: MetricsRegistry) -> Self {
614 self.metrics_registry = Some(registry);
615 self
616 }
617
618 pub fn with_code_version(mut self, version: semver::Version) -> Self {
619 self.code_version = version;
620 self
621 }
622
623 pub fn with_capture(mut self, storage: SharedStorage) -> Self {
624 self.capture = Some(storage);
625 self
626 }
627}
628
/// Listeners bound from a [`TestHarness`] configuration, ready to be served.
///
/// Splitting binding from serving lets a test learn the bound addresses
/// before the server starts.
pub struct Listeners {
    /// The bound crate-level listeners.
    pub inner: crate::Listeners,
}
632
633impl Listeners {
634 pub async fn new(config: &TestHarness) -> Result<Listeners, anyhow::Error> {
635 let inner = crate::Listeners::bind(config.listeners_config.clone()).await?;
636 Ok(Listeners { inner })
637 }
638
639 pub async fn serve(self, config: TestHarness) -> Result<TestServer, anyhow::Error> {
640 self.serve_with_trigger(config, mz_server_core::cert_reload_never_reload())
641 .await
642 }
643
644 pub async fn serve_with_trigger(
645 self,
646 config: TestHarness,
647 tls_reload_certs: ReloadTrigger,
648 ) -> Result<TestServer, anyhow::Error> {
649 let (data_directory, temp_dir) = match config.data_directory {
650 None => {
651 let temp_dir = tempfile::tempdir()?;
656 (temp_dir.path().to_path_buf(), Some(temp_dir))
657 }
658 Some(data_directory) => (data_directory, None),
659 };
660 let scratch_dir = tempfile::tempdir()?;
661 let (consensus_uri, timestamp_oracle_url) = {
662 let seed = config.seed;
663 let cockroach_url = env::var("METADATA_BACKEND_URL")
664 .map_err(|_| anyhow!("METADATA_BACKEND_URL environment variable is not set"))?;
665 let (client, conn) = tokio_postgres::connect(&cockroach_url, NoTls).await?;
666 mz_ore::task::spawn(|| "startup-postgres-conn", async move {
667 if let Err(err) = conn.await {
668 panic!("connection error: {}", err);
669 };
670 });
671 client
672 .batch_execute(&format!(
673 "CREATE SCHEMA IF NOT EXISTS consensus_{seed};
674 CREATE SCHEMA IF NOT EXISTS tsoracle_{seed};"
675 ))
676 .await?;
677 (
678 format!("{cockroach_url}?options=--search_path=consensus_{seed}")
679 .parse()
680 .expect("invalid consensus URI"),
681 format!("{cockroach_url}?options=--search_path=tsoracle_{seed}")
682 .parse()
683 .expect("invalid timestamp oracle URI"),
684 )
685 };
686 let metrics_registry = config.metrics_registry.unwrap_or_else(MetricsRegistry::new);
687 let orchestrator = ProcessOrchestrator::new(ProcessOrchestratorConfig {
688 image_dir: env::current_exe()?
689 .parent()
690 .unwrap()
691 .parent()
692 .unwrap()
693 .to_path_buf(),
694 suppress_output: false,
695 environment_id: config.environment_id.to_string(),
696 secrets_dir: data_directory.join("secrets"),
697 command_wrapper: vec![],
698 propagate_crashes: config.propagate_crashes,
699 tcp_proxy: None,
700 scratch_directory: scratch_dir.path().to_path_buf(),
701 })
702 .await?;
703 let orchestrator = Arc::new(orchestrator);
704 let persist_now = SYSTEM_TIME.clone();
707 let dyncfgs = mz_dyncfgs::all_dyncfgs();
708
709 let mut updates = ConfigUpdates::default();
710 updates.add(&CONSENSUS_CONNECTION_POOL_MAX_SIZE, 1);
713 updates.apply(&dyncfgs);
714
715 let mut persist_cfg = PersistConfig::new(&crate::BUILD_INFO, persist_now.clone(), dyncfgs);
716 persist_cfg.build_version = config.code_version;
717 persist_cfg.set_rollup_threshold(5);
719
720 let persist_pubsub_server = PersistGrpcPubSubServer::new(&persist_cfg, &metrics_registry);
721 let persist_pubsub_client = persist_pubsub_server.new_same_process_connection();
722 let persist_pubsub_tcp_listener =
723 TcpListener::bind(SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), 0))
724 .await
725 .expect("pubsub addr binding");
726 let persist_pubsub_server_port = persist_pubsub_tcp_listener
727 .local_addr()
728 .expect("pubsub addr has local addr")
729 .port();
730
731 mz_ore::task::spawn(|| "persist_pubsub_server", async move {
733 persist_pubsub_server
734 .serve_with_stream(TcpListenerStream::new(persist_pubsub_tcp_listener))
735 .await
736 .expect("success")
737 });
738 let persist_clients =
739 PersistClientCache::new(persist_cfg, &metrics_registry, |_, _| persist_pubsub_client);
740 let persist_clients = Arc::new(persist_clients);
741
742 let secrets_controller = Arc::clone(&orchestrator);
743 let connection_context = ConnectionContext::for_tests(orchestrator.reader());
744 let orchestrator = Arc::new(TracingOrchestrator::new(
745 orchestrator,
746 config.orchestrator_tracing_cli_args,
747 ));
748 let tracing_handle = if config.enable_tracing {
749 let config = TracingConfig::<fn(&tracing::Metadata) -> sentry_tracing::EventFilter> {
750 service_name: "environmentd",
751 stderr_log: StderrLogConfig {
752 format: StderrLogFormat::Json,
753 filter: EnvFilter::default(),
754 },
755 opentelemetry: Some(OpenTelemetryConfig {
756 endpoint: "http://fake_address_for_testing:8080".to_string(),
757 headers: http::HeaderMap::new(),
758 filter: EnvFilter::default().add_directive(Level::DEBUG.into()),
759 resource: opentelemetry_sdk::resource::Resource::builder().build(),
760 max_batch_queue_size: 2048,
761 max_export_batch_size: 512,
762 max_concurrent_exports: 1,
763 batch_scheduled_delay: Duration::from_millis(5000),
764 max_export_timeout: Duration::from_secs(30),
765 }),
766 tokio_console: None,
767 sentry: None,
768 build_version: crate::BUILD_INFO.version,
769 build_sha: crate::BUILD_INFO.sha,
770 registry: metrics_registry.clone(),
771 capture: config.capture,
772 };
773 mz_ore::tracing::configure(config).await?
774 } else {
775 TracingHandle::disabled()
776 };
777 let host_name = format!(
778 "localhost:{}",
779 self.inner.http["external"].handle.local_addr.port()
780 );
781 let catalog_config = CatalogConfig {
782 persist_clients: Arc::clone(&persist_clients),
783 metrics: Arc::new(mz_catalog::durable::Metrics::new(&MetricsRegistry::new())),
784 };
785
786 let inner = self
787 .inner
788 .serve(crate::Config {
789 catalog_config,
790 timestamp_oracle_url: Some(timestamp_oracle_url),
791 controller: ControllerConfig {
792 build_info: &crate::BUILD_INFO,
793 orchestrator,
794 clusterd_image: "clusterd".into(),
795 init_container_image: None,
796 deploy_generation: config.deploy_generation,
797 persist_location: PersistLocation {
798 blob_uri: format!("file://{}/persist/blob", data_directory.display())
799 .parse()
800 .expect("invalid blob URI"),
801 consensus_uri,
802 },
803 persist_clients,
804 now: config.now.clone(),
805 metrics_registry: metrics_registry.clone(),
806 persist_pubsub_url: format!("http://localhost:{}", persist_pubsub_server_port),
807 secrets_args: mz_service::secrets::SecretsReaderCliArgs {
808 secrets_reader: mz_service::secrets::SecretsControllerKind::LocalFile,
809 secrets_reader_local_file_dir: Some(data_directory.join("secrets")),
810 secrets_reader_kubernetes_context: None,
811 secrets_reader_aws_prefix: None,
812 secrets_reader_name_prefix: None,
813 },
814 connection_context,
815 replica_http_locator: Default::default(),
816 },
817 secrets_controller,
818 cloud_resource_controller: None,
819 tls: config.tls,
820 frontegg: config.frontegg,
821 unsafe_mode: config.unsafe_mode,
822 all_features: false,
823 metrics_registry: metrics_registry.clone(),
824 now: config.now,
825 environment_id: config.environment_id,
826 cors_allowed_origin: AllowOrigin::list([]),
827 cluster_replica_sizes: ClusterReplicaSizeMap::for_tests(),
828 bootstrap_default_cluster_replica_size: config.default_cluster_replica_size,
829 bootstrap_default_cluster_replication_factor: config
830 .default_cluster_replication_factor,
831 bootstrap_builtin_system_cluster_config: config.builtin_system_cluster_config,
832 bootstrap_builtin_catalog_server_cluster_config: config
833 .builtin_catalog_server_cluster_config,
834 bootstrap_builtin_probe_cluster_config: config.builtin_probe_cluster_config,
835 bootstrap_builtin_support_cluster_config: config.builtin_support_cluster_config,
836 bootstrap_builtin_analytics_cluster_config: config.builtin_analytics_cluster_config,
837 system_parameter_defaults: config.system_parameter_defaults,
838 availability_zones: Default::default(),
839 tracing_handle,
840 storage_usage_collection_interval: config.storage_usage_collection_interval,
841 storage_usage_retention_period: config.storage_usage_retention_period,
842 segment_api_key: None,
843 segment_client_side: false,
844 test_only_dummy_segment_client: false,
845 egress_addresses: vec![],
846 aws_account_id: None,
847 aws_privatelink_availability_zones: None,
848 launchdarkly_sdk_key: None,
849 launchdarkly_key_map: Default::default(),
850 config_sync_file_path: None,
851 config_sync_timeout: Duration::from_secs(30),
852 config_sync_loop_interval: None,
853 bootstrap_role: config.bootstrap_role,
854 http_host_name: Some(host_name),
855 internal_console_redirect_url: config.internal_console_redirect_url,
856 tls_reload_certs,
857 helm_chart_version: None,
858 license_key: ValidatedLicenseKey::for_tests(),
859 external_login_password_mz_system: config.external_login_password_mz_system,
860 force_builtin_schema_migration: None,
861 })
862 .await?;
863
864 Ok(TestServer {
865 inner,
866 metrics_registry,
867 _temp_dir: temp_dir,
868 _scratch_dir: scratch_dir,
869 })
870 }
871}
872
/// A running test server, together with the temporary directories it uses.
pub struct TestServer {
    /// The underlying server.
    pub inner: crate::Server,
    /// Registry the server records metrics into (the harness's, or a fresh one).
    pub metrics_registry: MetricsRegistry,
    /// Data directory created when the harness did not specify one; held so
    /// it is not removed while the server runs.
    _temp_dir: Option<TempDir>,
    /// Scratch directory handed to the process orchestrator; held for the
    /// server's lifetime.
    _scratch_dir: TempDir,
}
881
882impl TestServer {
883 pub fn connect(&self) -> ConnectBuilder<'_, postgres::NoTls, NoHandle> {
884 ConnectBuilder::new(self).no_tls()
885 }
886
887 pub async fn enable_feature_flags(&self, flags: &[&'static str]) {
888 let internal_client = self.connect().internal().await.unwrap();
889
890 for flag in flags {
891 internal_client
892 .batch_execute(&format!("ALTER SYSTEM SET {} = true;", flag))
893 .await
894 .unwrap();
895 }
896 }
897
898 pub async fn disable_feature_flags(&self, flags: &[&'static str]) {
899 let internal_client = self.connect().internal().await.unwrap();
900
901 for flag in flags {
902 internal_client
903 .batch_execute(&format!("ALTER SYSTEM SET {} = false;", flag))
904 .await
905 .unwrap();
906 }
907 }
908
909 pub fn ws_addr(&self) -> Uri {
910 format!(
911 "ws://{}/api/experimental/sql",
912 self.inner.http_listener_handles["external"].local_addr
913 )
914 .parse()
915 .unwrap()
916 }
917
918 pub fn internal_ws_addr(&self) -> Uri {
919 format!(
920 "ws://{}/api/experimental/sql",
921 self.inner.http_listener_handles["internal"].local_addr
922 )
923 .parse()
924 .unwrap()
925 }
926
927 pub fn http_local_addr(&self) -> SocketAddr {
928 self.inner.http_listener_handles["external"].local_addr
929 }
930
931 pub fn internal_http_local_addr(&self) -> SocketAddr {
932 self.inner.http_listener_handles["internal"].local_addr
933 }
934
935 pub fn sql_local_addr(&self) -> SocketAddr {
936 self.inner.sql_listener_handles["external"].local_addr
937 }
938
939 pub fn internal_sql_local_addr(&self) -> SocketAddr {
940 self.inner.sql_listener_handles["internal"].local_addr
941 }
942}
943
/// Builder for a `tokio_postgres` connection to a [`TestServer`].
///
/// Type parameters:
/// - `T` is the TLS connector used for the connection.
/// - `H` is an [`IncludeHandle`] marker that decides what awaiting the builder
///   yields: `H::Output` (see the `IntoFuture` impl).
pub struct ConnectBuilder<'s, T, H> {
    /// Server being connected to; used to look up listener ports.
    server: &'s TestServer,

    /// Connection parameters (host, user, options, ...).
    pg_config: tokio_postgres::Config,
    /// SQL listener port: the external listener by default, the internal one
    /// after `internal()`.
    port: u16,
    /// TLS connector used when connecting.
    tls: T,

    /// Optional callback for database notices.
    /// NOTE(review): presumably invoked by the `IntoFuture` impl, whose body is
    /// not visible here — confirm.
    notice_callback: Option<Box<dyn FnMut(tokio_postgres::error::DbError) + Send + 'static>>,

    /// Zero-sized marker carrying the `H` type parameter.
    _with_handle: H,
}
964
965impl<'s> ConnectBuilder<'s, (), NoHandle> {
966 fn new(server: &'s TestServer) -> Self {
967 let mut pg_config = tokio_postgres::Config::new();
968 pg_config
969 .host(&Ipv4Addr::LOCALHOST.to_string())
970 .user("materialize")
971 .options("--welcome_message=off")
972 .application_name("environmentd_test_framework");
973
974 ConnectBuilder {
975 server,
976 pg_config,
977 port: server.sql_local_addr().port(),
978 tls: (),
979 notice_callback: None,
980 _with_handle: NoHandle,
981 }
982 }
983}
984
985impl<'s, T, H> ConnectBuilder<'s, T, H> {
986 pub fn no_tls(self) -> ConnectBuilder<'s, postgres::NoTls, H> {
990 ConnectBuilder {
991 server: self.server,
992 pg_config: self.pg_config,
993 port: self.port,
994 tls: postgres::NoTls,
995 notice_callback: self.notice_callback,
996 _with_handle: self._with_handle,
997 }
998 }
999
1000 pub fn with_tls<Tls>(self, tls: Tls) -> ConnectBuilder<'s, Tls, H>
1002 where
1003 Tls: MakeTlsConnect<Socket> + Send + 'static,
1004 Tls::TlsConnect: Send,
1005 Tls::Stream: Send,
1006 <Tls::TlsConnect as TlsConnect<Socket>>::Future: Send,
1007 {
1008 ConnectBuilder {
1009 server: self.server,
1010 pg_config: self.pg_config,
1011 port: self.port,
1012 tls,
1013 notice_callback: self.notice_callback,
1014 _with_handle: self._with_handle,
1015 }
1016 }
1017
1018 pub fn with_config(mut self, pg_config: tokio_postgres::Config) -> Self {
1020 self.pg_config = pg_config;
1021 self
1022 }
1023
1024 pub fn ssl_mode(mut self, mode: SslMode) -> Self {
1026 self.pg_config.ssl_mode(mode);
1027 self
1028 }
1029
1030 pub fn user(mut self, user: &str) -> Self {
1032 self.pg_config.user(user);
1033 self
1034 }
1035
1036 pub fn password(mut self, password: &str) -> Self {
1038 self.pg_config.password(password);
1039 self
1040 }
1041
1042 pub fn application_name(mut self, application_name: &str) -> Self {
1044 self.pg_config.application_name(application_name);
1045 self
1046 }
1047
1048 pub fn dbname(mut self, dbname: &str) -> Self {
1050 self.pg_config.dbname(dbname);
1051 self
1052 }
1053
1054 pub fn options(mut self, options: &str) -> Self {
1056 self.pg_config.options(options);
1057 self
1058 }
1059
1060 pub fn internal(mut self) -> Self {
1065 self.port = self.server.internal_sql_local_addr().port();
1066 self.pg_config.user(mz_sql::session::user::SYSTEM_USER_NAME);
1067 self
1068 }
1069
1070 pub fn notice_callback(self, callback: impl FnMut(DbError) + Send + 'static) -> Self {
1072 ConnectBuilder {
1073 notice_callback: Some(Box::new(callback)),
1074 ..self
1075 }
1076 }
1077
1078 pub fn with_handle(self) -> ConnectBuilder<'s, T, WithHandle> {
1081 ConnectBuilder {
1082 server: self.server,
1083 pg_config: self.pg_config,
1084 port: self.port,
1085 tls: self.tls,
1086 notice_callback: self.notice_callback,
1087 _with_handle: WithHandle,
1088 }
1089 }
1090
1091 pub fn as_pg_config(&self) -> &tokio_postgres::Config {
1093 &self.pg_config
1094 }
1095}
1096
/// Selects what awaiting a [`ConnectBuilder`] yields: either just the client, or the
/// client together with the join handle of the task spawned to drive its connection.
pub trait IncludeHandle: Send {
    /// The value produced when the builder's future resolves successfully.
    type Output;
    /// Builds [`Self::Output`] from a freshly connected client and the handle of the
    /// spawned task that polls the connection for messages.
    fn transform_result(
        client: tokio_postgres::Client,
        handle: mz_ore::task::JoinHandle<()>,
    ) -> Self::Output;
}
1106
1107pub struct NoHandle;
1110impl IncludeHandle for NoHandle {
1111 type Output = tokio_postgres::Client;
1112 fn transform_result(
1113 client: tokio_postgres::Client,
1114 _handle: mz_ore::task::JoinHandle<()>,
1115 ) -> Self::Output {
1116 client
1117 }
1118}
1119
1120pub struct WithHandle;
1123impl IncludeHandle for WithHandle {
1124 type Output = (tokio_postgres::Client, mz_ore::task::JoinHandle<()>);
1125 fn transform_result(
1126 client: tokio_postgres::Client,
1127 handle: mz_ore::task::JoinHandle<()>,
1128 ) -> Self::Output {
1129 (client, handle)
1130 }
1131}
1132
1133impl<'s, T, H> IntoFuture for ConnectBuilder<'s, T, H>
1134where
1135 T: MakeTlsConnect<Socket> + Send + 'static,
1136 T::TlsConnect: Send,
1137 T::Stream: Send,
1138 <T::TlsConnect as TlsConnect<Socket>>::Future: Send,
1139 H: IncludeHandle,
1140{
1141 type Output = Result<H::Output, postgres::Error>;
1142 type IntoFuture = BoxFuture<'static, Self::Output>;
1143
1144 fn into_future(mut self) -> Self::IntoFuture {
1145 Box::pin(async move {
1146 assert!(
1147 self.pg_config.get_ports().is_empty(),
1148 "specifying multiple ports is not supported"
1149 );
1150 self.pg_config.port(self.port);
1151
1152 let (client, mut conn) = self.pg_config.connect(self.tls).await?;
1153 let mut notice_callback = self.notice_callback.take();
1154
1155 let handle = task::spawn(|| "connect", async move {
1156 while let Some(msg) = std::future::poll_fn(|cx| conn.poll_message(cx)).await {
1157 match msg {
1158 Ok(AsyncMessage::Notice(notice)) => {
1159 if let Some(callback) = notice_callback.as_mut() {
1160 callback(notice);
1161 }
1162 }
1163 Ok(msg) => {
1164 tracing::debug!(?msg, "Dropping message from database");
1165 }
1166 Err(e) => {
1167 tracing::info!("connection error: {e}");
1172 break;
1173 }
1174 }
1175 }
1176 tracing::info!("connection closed");
1177 });
1178
1179 let output = H::transform_result(client, handle);
1180 Ok(output)
1181 })
1182 }
1183}
1184
/// A [`TestServer`] paired with a runtime, exposing blocking (`postgres::Client`) helpers
/// for tests that do not run inside an async context.
pub struct TestServerWithRuntime {
    // The wrapped server whose listeners the helpers connect to.
    server: TestServer,
    // The runtime handed out by `runtime()`; presumably the one the server runs on — TODO confirm.
    runtime: Arc<Runtime>,
}
1193
1194impl TestServerWithRuntime {
1195 pub fn runtime(&self) -> &Arc<Runtime> {
1199 &self.runtime
1200 }
1201
1202 pub fn inner(&self) -> &crate::Server {
1204 &self.server.inner
1205 }
1206
1207 pub fn connect<T>(&self, tls: T) -> Result<postgres::Client, postgres::Error>
1209 where
1210 T: MakeTlsConnect<Socket> + Send + 'static,
1211 T::TlsConnect: Send,
1212 T::Stream: Send,
1213 <T::TlsConnect as TlsConnect<Socket>>::Future: Send,
1214 {
1215 self.pg_config().connect(tls)
1216 }
1217
1218 pub fn connect_internal<T>(&self, tls: T) -> Result<postgres::Client, anyhow::Error>
1220 where
1221 T: MakeTlsConnect<Socket> + Send + 'static,
1222 T::TlsConnect: Send,
1223 T::Stream: Send,
1224 <T::TlsConnect as TlsConnect<Socket>>::Future: Send,
1225 {
1226 Ok(self.pg_config_internal().connect(tls)?)
1227 }
1228
1229 pub fn enable_feature_flags(&self, flags: &[&'static str]) {
1231 let mut internal_client = self.connect_internal(postgres::NoTls).unwrap();
1232
1233 for flag in flags {
1234 internal_client
1235 .batch_execute(&format!("ALTER SYSTEM SET {} = true;", flag))
1236 .unwrap();
1237 }
1238 }
1239
1240 pub fn disable_feature_flags(&self, flags: &[&'static str]) {
1242 let mut internal_client = self.connect_internal(postgres::NoTls).unwrap();
1243
1244 for flag in flags {
1245 internal_client
1246 .batch_execute(&format!("ALTER SYSTEM SET {} = false;", flag))
1247 .unwrap();
1248 }
1249 }
1250
1251 pub fn pg_config(&self) -> postgres::Config {
1254 let local_addr = self.server.sql_local_addr();
1255 let mut config = postgres::Config::new();
1256 config
1257 .host(&Ipv4Addr::LOCALHOST.to_string())
1258 .port(local_addr.port())
1259 .user("materialize")
1260 .options("--welcome_message=off");
1261 config
1262 }
1263
1264 pub fn pg_config_internal(&self) -> postgres::Config {
1267 let local_addr = self.server.internal_sql_local_addr();
1268 let mut config = postgres::Config::new();
1269 config
1270 .host(&Ipv4Addr::LOCALHOST.to_string())
1271 .port(local_addr.port())
1272 .user("mz_system")
1273 .options("--welcome_message=off");
1274 config
1275 }
1276
1277 pub fn ws_addr(&self) -> Uri {
1278 self.server.ws_addr()
1279 }
1280
1281 pub fn internal_ws_addr(&self) -> Uri {
1282 self.server.internal_ws_addr()
1283 }
1284
1285 pub fn http_local_addr(&self) -> SocketAddr {
1286 self.server.http_local_addr()
1287 }
1288
1289 pub fn internal_http_local_addr(&self) -> SocketAddr {
1290 self.server.internal_http_local_addr()
1291 }
1292
1293 pub fn sql_local_addr(&self) -> SocketAddr {
1294 self.server.sql_local_addr()
1295 }
1296
1297 pub fn internal_sql_local_addr(&self) -> SocketAddr {
1298 self.server.internal_sql_local_addr()
1299 }
1300
1301 pub fn metrics_registry(&self) -> &MetricsRegistry {
1303 &self.server.metrics_registry
1304 }
1305}
1306
/// A timestamp held as a `u64`; its [`FromSql`] impl decodes it from a SQL numeric value.
#[derive(Debug, Clone, Copy, Eq, PartialEq, Ord, PartialOrd)]
pub struct MzTimestamp(pub u64);
1309
1310impl<'a> FromSql<'a> for MzTimestamp {
1311 fn from_sql(ty: &Type, raw: &'a [u8]) -> Result<MzTimestamp, Box<dyn Error + Sync + Send>> {
1312 let n = mz_pgrepr::Numeric::from_sql(ty, raw)?;
1313 Ok(MzTimestamp(u64::try_from(n.0.0)?))
1314 }
1315
1316 fn accepts(ty: &Type) -> bool {
1317 mz_pgrepr::Numeric::accepts(ty)
1318 }
1319}
1320
/// Test helper for extracting the server-reported [`DbError`] from Postgres errors.
pub trait PostgresErrorExt {
    /// Returns the wrapped [`DbError`].
    ///
    /// The implementations in this file panic when no `DbError` is present (or, for
    /// `Result`, when the value is `Ok`).
    fn unwrap_db_error(self) -> DbError;
}
1324
1325impl PostgresErrorExt for postgres::Error {
1326 fn unwrap_db_error(self) -> DbError {
1327 match self.source().and_then(|e| e.downcast_ref::<DbError>()) {
1328 Some(e) => e.clone(),
1329 None => panic!("expected DbError, but got: {:?}", self),
1330 }
1331 }
1332}
1333
1334impl<T, E> PostgresErrorExt for Result<T, E>
1335where
1336 E: PostgresErrorExt,
1337{
1338 fn unwrap_db_error(self) -> DbError {
1339 match self {
1340 Ok(_) => panic!("expected Err(DbError), but got Ok(_)"),
1341 Err(e) => e.unwrap_db_error(),
1342 }
1343 }
1344}
1345
1346pub async fn insert_with_deterministic_timestamps(
1350 table: &'static str,
1351 values: &'static str,
1352 server: &TestServer,
1353 now: Arc<std::sync::Mutex<EpochMillis>>,
1354) -> Result<(), Box<dyn Error>> {
1355 let client_write = server.connect().await?;
1356 let client_read = server.connect().await?;
1357
1358 let mut current_ts = get_explain_timestamp(table, &client_read).await;
1359
1360 let insert_query = format!("INSERT INTO {table} VALUES {values}");
1361
1362 let write_future = client_write.execute(&insert_query, &[]);
1363 let timestamp_interval = tokio::time::interval(Duration::from_millis(1));
1364
1365 let mut write_future = std::pin::pin!(write_future);
1366 let mut timestamp_interval = std::pin::pin!(timestamp_interval);
1367
1368 loop {
1371 tokio::select! {
1372 _ = (&mut write_future) => return Ok(()),
1373 _ = timestamp_interval.tick() => {
1374 current_ts += 1;
1375 *now.lock().expect("lock poisoned") = current_ts;
1376 }
1377 };
1378 }
1379}
1380
1381pub async fn get_explain_timestamp(from_suffix: &str, client: &Client) -> EpochMillis {
1382 try_get_explain_timestamp(from_suffix, client)
1383 .await
1384 .unwrap()
1385}
1386
1387pub async fn try_get_explain_timestamp(
1388 from_suffix: &str,
1389 client: &Client,
1390) -> Result<EpochMillis, anyhow::Error> {
1391 let det = get_explain_timestamp_determination(from_suffix, client).await?;
1392 let ts = det.determination.timestamp_context.timestamp_or_default();
1393 Ok(ts.into())
1394}
1395
1396pub async fn get_explain_timestamp_determination(
1397 from_suffix: &str,
1398 client: &Client,
1399) -> Result<TimestampExplanation<mz_repr::Timestamp>, anyhow::Error> {
1400 let row = client
1401 .query_one(
1402 &format!("EXPLAIN TIMESTAMP AS JSON FOR SELECT * FROM {from_suffix}"),
1403 &[],
1404 )
1405 .await?;
1406 let explain: String = row.get(0);
1407 Ok(serde_json::from_str(&explain).unwrap())
1408}
1409
1410pub async fn create_postgres_source_with_table<'a>(
1418 server: &TestServer,
1419 mz_client: &Client,
1420 table_name: &str,
1421 table_schema: &str,
1422 source_name: &str,
1423) -> (
1424 Client,
1425 impl FnOnce(&'a Client, &'a Client) -> LocalBoxFuture<'a, ()>,
1426) {
1427 server
1428 .enable_feature_flags(&["enable_create_table_from_source"])
1429 .await;
1430
1431 let postgres_url = env::var("POSTGRES_URL")
1432 .map_err(|_| anyhow!("POSTGRES_URL environment variable is not set"))
1433 .unwrap();
1434
1435 let (pg_client, connection) = tokio_postgres::connect(&postgres_url, postgres::NoTls)
1436 .await
1437 .unwrap();
1438
1439 let pg_config: tokio_postgres::Config = postgres_url.parse().unwrap();
1440 let user = pg_config.get_user().unwrap_or("postgres");
1441 let db_name = pg_config.get_dbname().unwrap_or(user);
1442 let ports = pg_config.get_ports();
1443 let port = if ports.is_empty() { 5432 } else { ports[0] };
1444 let hosts = pg_config.get_hosts();
1445 let host = if hosts.is_empty() {
1446 "localhost".to_string()
1447 } else {
1448 match &hosts[0] {
1449 Host::Tcp(host) => host.to_string(),
1450 Host::Unix(host) => host.to_str().unwrap().to_string(),
1451 }
1452 };
1453 let password = pg_config.get_password();
1454
1455 mz_ore::task::spawn(|| "postgres-source-connection", async move {
1456 if let Err(e) = connection.await {
1457 panic!("connection error: {}", e);
1458 }
1459 });
1460
1461 let _ = pg_client
1463 .execute(&format!("DROP TABLE IF EXISTS {table_name};"), &[])
1464 .await
1465 .unwrap();
1466 let _ = pg_client
1467 .execute(&format!("DROP PUBLICATION IF EXISTS {source_name};"), &[])
1468 .await
1469 .unwrap();
1470 let _ = pg_client
1471 .execute(&format!("CREATE TABLE {table_name} {table_schema};"), &[])
1472 .await
1473 .unwrap();
1474 let _ = pg_client
1475 .execute(
1476 &format!("ALTER TABLE {table_name} REPLICA IDENTITY FULL;"),
1477 &[],
1478 )
1479 .await
1480 .unwrap();
1481 let _ = pg_client
1482 .execute(
1483 &format!("CREATE PUBLICATION {source_name} FOR TABLE {table_name};"),
1484 &[],
1485 )
1486 .await
1487 .unwrap();
1488
1489 let mut connection_str = format!("HOST '{host}', PORT {port}, USER {user}, DATABASE {db_name}");
1491 if let Some(password) = password {
1492 let password = std::str::from_utf8(password).unwrap();
1493 mz_client
1494 .batch_execute(&format!("CREATE SECRET s AS '{password}'"))
1495 .await
1496 .unwrap();
1497 connection_str = format!("{connection_str}, PASSWORD SECRET s");
1498 }
1499 mz_client
1500 .batch_execute(&format!(
1501 "CREATE CONNECTION pgconn TO POSTGRES ({connection_str})"
1502 ))
1503 .await
1504 .unwrap();
1505 mz_client
1506 .batch_execute(&format!(
1507 "CREATE SOURCE {source_name}
1508 FROM POSTGRES
1509 CONNECTION pgconn
1510 (PUBLICATION '{source_name}')"
1511 ))
1512 .await
1513 .unwrap();
1514 mz_client
1515 .batch_execute(&format!(
1516 "CREATE TABLE {table_name}
1517 FROM SOURCE {source_name}
1518 (REFERENCE {table_name});"
1519 ))
1520 .await
1521 .unwrap();
1522
1523 let table_name = table_name.to_string();
1524 let source_name = source_name.to_string();
1525 (
1526 pg_client,
1527 move |mz_client: &'a Client, pg_client: &'a Client| {
1528 let f: Pin<Box<dyn Future<Output = ()> + 'a>> = Box::pin(async move {
1529 mz_client
1530 .batch_execute(&format!("DROP SOURCE {source_name} CASCADE;"))
1531 .await
1532 .unwrap();
1533 mz_client
1534 .batch_execute("DROP CONNECTION pgconn;")
1535 .await
1536 .unwrap();
1537
1538 let _ = pg_client
1539 .execute(&format!("DROP PUBLICATION {source_name};"), &[])
1540 .await
1541 .unwrap();
1542 let _ = pg_client
1543 .execute(&format!("DROP TABLE {table_name};"), &[])
1544 .await
1545 .unwrap();
1546 });
1547 f
1548 },
1549 )
1550}
1551
1552pub async fn wait_for_pg_table_population(mz_client: &Client, view_name: &str, source_rows: i64) {
1553 let current_isolation = mz_client
1554 .query_one("SHOW transaction_isolation", &[])
1555 .await
1556 .unwrap()
1557 .get::<_, String>(0);
1558 mz_client
1559 .batch_execute("SET transaction_isolation = SERIALIZABLE")
1560 .await
1561 .unwrap();
1562 Retry::default()
1563 .retry_async(|_| async move {
1564 let rows = mz_client
1565 .query_one(&format!("SELECT COUNT(*) FROM {view_name};"), &[])
1566 .await
1567 .unwrap()
1568 .get::<_, i64>(0);
1569 if rows == source_rows {
1570 Ok(())
1571 } else {
1572 Err(format!(
1573 "Waiting for {source_rows} row to be ingested. Currently at {rows}."
1574 ))
1575 }
1576 })
1577 .await
1578 .unwrap();
1579 mz_client
1580 .batch_execute(&format!(
1581 "SET transaction_isolation = '{current_isolation}'"
1582 ))
1583 .await
1584 .unwrap();
1585}
1586
1587pub fn auth_with_ws(
1589 ws: &mut WebSocket<MaybeTlsStream<TcpStream>>,
1590 mut options: BTreeMap<String, String>,
1591) -> Result<Vec<WebSocketResponse>, anyhow::Error> {
1592 if !options.contains_key("welcome_message") {
1593 options.insert("welcome_message".into(), "off".into());
1594 }
1595 auth_with_ws_impl(
1596 ws,
1597 Message::Text(
1598 serde_json::to_string(&WebSocketAuth::Basic {
1599 user: "materialize".into(),
1600 password: "".into(),
1601 options,
1602 })
1603 .unwrap()
1604 .into(),
1605 ),
1606 )
1607}
1608
1609pub fn auth_with_ws_impl(
1610 ws: &mut WebSocket<MaybeTlsStream<TcpStream>>,
1611 auth_message: Message,
1612) -> Result<Vec<WebSocketResponse>, anyhow::Error> {
1613 ws.send(auth_message)?;
1614
1615 let mut msgs = Vec::new();
1617 loop {
1618 let resp = ws.read()?;
1619 match resp {
1620 Message::Text(msg) => {
1621 let msg: WebSocketResponse = serde_json::from_str(&msg).unwrap();
1622 match msg {
1623 WebSocketResponse::ReadyForQuery(_) => break,
1624 msg => {
1625 msgs.push(msg);
1626 }
1627 }
1628 }
1629 Message::Ping(_) => continue,
1630 Message::Close(None) => return Err(anyhow!("ws closed after auth")),
1631 Message::Close(Some(close_frame)) => {
1632 return Err(anyhow!("ws closed after auth").context(close_frame));
1633 }
1634 _ => panic!("unexpected response: {:?}", resp),
1635 }
1636 }
1637 Ok(msgs)
1638}
1639
1640pub fn make_header<H: Header>(h: H) -> HeaderMap {
1641 let mut map = HeaderMap::new();
1642 map.typed_insert(h);
1643 map
1644}
1645
1646pub fn make_pg_tls<F>(configure: F) -> MakeTlsConnector
1647where
1648 F: FnOnce(&mut SslConnectorBuilder) -> Result<(), ErrorStack>,
1649{
1650 let mut connector_builder = SslConnector::builder(SslMethod::tls()).unwrap();
1651 let options = connector_builder.options() | SslOptions::NO_TLSV1_3;
1672 connector_builder.set_options(options);
1673 configure(&mut connector_builder).unwrap();
1674 MakeTlsConnector::new(connector_builder.build())
1675}
1676
/// A test certificate authority: its certificate, private key, and a temporary directory
/// holding `ca.crt` plus any certificate/key files it issues.
pub struct Ca {
    /// Temporary directory containing `ca.crt` and issued `<name>.crt`/`<name>.key` files.
    pub dir: TempDir,
    /// Subject name of the CA certificate; used as the issuer name of child CAs.
    pub name: X509Name,
    /// The CA's own certificate.
    pub cert: X509,
    /// The CA's private key, used to sign certificates and child CAs it issues.
    pub pkey: PKey<Private>,
}
1684
1685impl Ca {
1686 fn make_ca(name: &str, parent: Option<&Ca>) -> Result<Ca, Box<dyn Error>> {
1687 let dir = tempfile::tempdir()?;
1688 let rsa = Rsa::generate(2048)?;
1689 let pkey = PKey::from_rsa(rsa)?;
1690 let name = {
1691 let mut builder = X509NameBuilder::new()?;
1692 builder.append_entry_by_nid(Nid::COMMONNAME, name)?;
1693 builder.build()
1694 };
1695 let cert = {
1696 let mut builder = X509::builder()?;
1697 builder.set_version(2)?;
1698 builder.set_pubkey(&pkey)?;
1699 builder.set_issuer_name(parent.map(|ca| &ca.name).unwrap_or(&name))?;
1700 builder.set_subject_name(&name)?;
1701 builder.set_not_before(&*Asn1Time::days_from_now(0)?)?;
1702 builder.set_not_after(&*Asn1Time::days_from_now(365)?)?;
1703 builder.append_extension(BasicConstraints::new().critical().ca().build()?)?;
1704 builder.sign(
1705 parent.map(|ca| &ca.pkey).unwrap_or(&pkey),
1706 MessageDigest::sha256(),
1707 )?;
1708 builder.build()
1709 };
1710 fs::write(dir.path().join("ca.crt"), cert.to_pem()?)?;
1711 Ok(Ca {
1712 dir,
1713 name,
1714 cert,
1715 pkey,
1716 })
1717 }
1718
1719 pub fn new_root(name: &str) -> Result<Ca, Box<dyn Error>> {
1721 Ca::make_ca(name, None)
1722 }
1723
1724 pub fn ca_cert_path(&self) -> PathBuf {
1726 self.dir.path().join("ca.crt")
1727 }
1728
1729 pub fn request_ca(&self, name: &str) -> Result<Ca, Box<dyn Error>> {
1731 Ca::make_ca(name, Some(self))
1732 }
1733
1734 pub fn request_client_cert(&self, name: &str) -> Result<(PathBuf, PathBuf), Box<dyn Error>> {
1739 self.request_cert(name, iter::empty())
1740 }
1741
1742 pub fn request_cert<I>(&self, name: &str, ips: I) -> Result<(PathBuf, PathBuf), Box<dyn Error>>
1745 where
1746 I: IntoIterator<Item = IpAddr>,
1747 {
1748 let rsa = Rsa::generate(2048)?;
1749 let pkey = PKey::from_rsa(rsa)?;
1750 let subject_name = {
1751 let mut builder = X509NameBuilder::new()?;
1752 builder.append_entry_by_nid(Nid::COMMONNAME, name)?;
1753 builder.build()
1754 };
1755 let cert = {
1756 let mut builder = X509::builder()?;
1757 builder.set_version(2)?;
1758 builder.set_pubkey(&pkey)?;
1759 builder.set_issuer_name(self.cert.subject_name())?;
1760 builder.set_subject_name(&subject_name)?;
1761 builder.set_not_before(&*Asn1Time::days_from_now(0)?)?;
1762 builder.set_not_after(&*Asn1Time::days_from_now(365)?)?;
1763 for ip in ips {
1764 builder.append_extension(
1765 SubjectAlternativeName::new()
1766 .ip(&ip.to_string())
1767 .build(&builder.x509v3_context(None, None))?,
1768 )?;
1769 }
1770 builder.sign(&self.pkey, MessageDigest::sha256())?;
1771 builder.build()
1772 };
1773 let cert_path = self.dir.path().join(Path::new(name).with_extension("crt"));
1774 let key_path = self.dir.path().join(Path::new(name).with_extension("key"));
1775 fs::write(&cert_path, cert.to_pem()?)?;
1776 fs::write(&key_path, pkey.private_key_to_pem_pkcs8()?)?;
1777 Ok((cert_path, key_path))
1778 }
1779}