1use std::collections::BTreeMap;
11use std::error::Error;
12use std::future::IntoFuture;
13use std::net::{IpAddr, Ipv4Addr, SocketAddr, TcpStream};
14use std::path::{Path, PathBuf};
15use std::pin::Pin;
16use std::str::FromStr;
17use std::sync::Arc;
18use std::sync::LazyLock;
19use std::time::Duration;
20use std::{env, fs, iter};
21
22use anyhow::anyhow;
23use futures::Future;
24use futures::future::{BoxFuture, LocalBoxFuture};
25use headers::{Header, HeaderMapExt};
26use http::Uri;
27use hyper::http::header::HeaderMap;
28use maplit::btreemap;
29use mz_adapter::TimestampExplanation;
30use mz_adapter_types::bootstrap_builtin_cluster_config::{
31 ANALYTICS_CLUSTER_DEFAULT_REPLICATION_FACTOR, BootstrapBuiltinClusterConfig,
32 CATALOG_SERVER_CLUSTER_DEFAULT_REPLICATION_FACTOR, PROBE_CLUSTER_DEFAULT_REPLICATION_FACTOR,
33 SUPPORT_CLUSTER_DEFAULT_REPLICATION_FACTOR, SYSTEM_CLUSTER_DEFAULT_REPLICATION_FACTOR,
34};
35
36use mz_auth::password::Password;
37use mz_catalog::config::ClusterReplicaSizeMap;
38use mz_controller::ControllerConfig;
39use mz_dyncfg::ConfigUpdates;
40use mz_license_keys::ValidatedLicenseKey;
41use mz_orchestrator_process::{ProcessOrchestrator, ProcessOrchestratorConfig};
42use mz_orchestrator_tracing::{TracingCliArgs, TracingOrchestrator};
43use mz_ore::metrics::MetricsRegistry;
44use mz_ore::now::{EpochMillis, NowFn, SYSTEM_TIME};
45use mz_ore::retry::Retry;
46use mz_ore::task;
47use mz_ore::tracing::{
48 OpenTelemetryConfig, StderrLogConfig, StderrLogFormat, TracingConfig, TracingHandle,
49};
50use mz_persist_client::PersistLocation;
51use mz_persist_client::cache::PersistClientCache;
52use mz_persist_client::cfg::{CONSENSUS_CONNECTION_POOL_MAX_SIZE, PersistConfig};
53use mz_persist_client::rpc::PersistGrpcPubSubServer;
54use mz_secrets::SecretsController;
55use mz_server_core::listeners::{
56 AllowedRoles, AuthenticatorKind, BaseListenerConfig, HttpRoutesEnabled,
57};
58use mz_server_core::{ReloadTrigger, TlsCertConfig};
59use mz_sql::catalog::EnvironmentId;
60use mz_storage_types::connections::ConnectionContext;
61use mz_tracing::CloneableEnvFilter;
62use openssl::asn1::Asn1Time;
63use openssl::error::ErrorStack;
64use openssl::hash::MessageDigest;
65use openssl::nid::Nid;
66use openssl::pkey::{PKey, Private};
67use openssl::rsa::Rsa;
68use openssl::ssl::{SslConnector, SslConnectorBuilder, SslMethod, SslOptions};
69use openssl::x509::extension::{BasicConstraints, SubjectAlternativeName};
70use openssl::x509::{X509, X509Name, X509NameBuilder};
71use postgres::error::DbError;
72use postgres::tls::{MakeTlsConnect, TlsConnect};
73use postgres::types::{FromSql, Type};
74use postgres::{NoTls, Socket};
75use postgres_openssl::MakeTlsConnector;
76use tempfile::TempDir;
77use tokio::net::TcpListener;
78use tokio::runtime::Runtime;
79use tokio_postgres::config::{Host, SslMode};
80use tokio_postgres::{AsyncMessage, Client};
81use tokio_stream::wrappers::TcpListenerStream;
82use tower_http::cors::AllowOrigin;
83use tracing::Level;
84use tracing_capture::SharedStorage;
85use tracing_subscriber::EnvFilter;
86use tungstenite::stream::MaybeTlsStream;
87use tungstenite::{Message, WebSocket};
88
89use crate::{
90 CatalogConfig, FronteggAuthenticator, HttpListenerConfig, ListenersConfig, SqlListenerConfig,
91 WebSocketAuth, WebSocketResponse,
92};
93
/// Kafka bootstrap addresses used by tests.
///
/// Read once, on first access, from the `KAFKA_ADDRS` environment variable.
/// Falls back to `localhost:9092` when the variable is unset or is not valid
/// Unicode.
pub static KAFKA_ADDRS: LazyLock<String> = LazyLock::new(|| match env::var("KAFKA_ADDRS") {
    Ok(addrs) => addrs,
    Err(_) => String::from("localhost:9092"),
});
96
/// Builder describing an environmentd test server.
///
/// Start from [`Default::default`], adjust with the `with_*` / setter methods,
/// then launch with [`TestHarness::start`] or one of its variants.
#[derive(Clone)]
pub struct TestHarness {
    /// Directory for server state (secrets, persist blobs). When `None`, a
    /// temporary directory is created at startup and kept alive by the
    /// returned `TestServer`.
    data_directory: Option<PathBuf>,
    /// TLS certificate/key paths; set by `with_tls`.
    tls: Option<TlsCertConfig>,
    /// Frontegg authenticator; set by `with_frontegg_auth`.
    frontegg: Option<FronteggAuthenticator>,
    /// Password for `mz_system`; set by `with_password_auth` and
    /// `with_sasl_scram_auth`.
    external_login_password_mz_system: Option<Password>,
    /// SQL and HTTP listeners to bind, keyed by listener name
    /// (e.g. "external", "internal", "metrics").
    listeners_config: ListenersConfig,
    /// Passed to the server as `unsafe_mode`.
    unsafe_mode: bool,
    /// NOTE(review): set by `workers()` but not read anywhere in the visible
    /// part of this file — confirm whether it still has an effect.
    workers: usize,
    /// Clock given to the controller and the server. Persist uses
    /// `SYSTEM_TIME` regardless of this value.
    now: NowFn,
    /// Suffix for the `consensus_<seed>` / `tsoracle_<seed>` metadata schemas.
    /// Random by default.
    seed: u32,
    /// Passed through to the server; defaults to one hour.
    storage_usage_collection_interval: Duration,
    /// Passed through to the server; `None` by default.
    storage_usage_retention_period: Option<Duration>,
    /// Bootstrap replica size for the default cluster.
    default_cluster_replica_size: String,
    /// Bootstrap replication factor for the default cluster.
    default_cluster_replication_factor: u32,
    /// Bootstrap size/replication factor for the builtin system cluster.
    builtin_system_cluster_config: BootstrapBuiltinClusterConfig,
    /// Bootstrap size/replication factor for the builtin catalog server cluster.
    builtin_catalog_server_cluster_config: BootstrapBuiltinClusterConfig,
    /// Bootstrap size/replication factor for the builtin probe cluster.
    builtin_probe_cluster_config: BootstrapBuiltinClusterConfig,
    /// Bootstrap size/replication factor for the builtin support cluster.
    builtin_support_cluster_config: BootstrapBuiltinClusterConfig,
    /// Bootstrap size/replication factor for the builtin analytics cluster.
    builtin_analytics_cluster_config: BootstrapBuiltinClusterConfig,

    /// Passed to the process orchestrator as `propagate_crashes`.
    propagate_crashes: bool,
    /// When true, installs a real tracing configuration (JSON stderr logs plus
    /// an OpenTelemetry exporter); otherwise tracing is disabled.
    enable_tracing: bool,
    /// Tracing arguments for the `TracingOrchestrator` wrapping the process
    /// orchestrator.
    orchestrator_tracing_cli_args: TracingCliArgs,
    /// Bootstrap role; `Some("materialize")` by default.
    bootstrap_role: Option<String>,
    /// Deploy generation given to the controller.
    deploy_generation: u64,
    /// System parameter defaults; `log_filter=error` by default.
    system_parameter_defaults: BTreeMap<String, String>,
    /// Passed through to the server.
    internal_console_redirect_url: Option<String>,
    /// Registry to report metrics into; a fresh one is created when `None`.
    metrics_registry: Option<MetricsRegistry>,
    /// Build version recorded by persist; defaults to the crate's build version.
    code_version: semver::Version,
    /// Tracing capture storage; only consulted when `enable_tracing` is set.
    capture: Option<SharedStorage>,
    /// Environment ID of the server; `EnvironmentId::for_tests()` by default.
    pub environment_id: EnvironmentId,
}
133
134impl Default for TestHarness {
135 fn default() -> TestHarness {
136 TestHarness {
137 data_directory: None,
138 tls: None,
139 frontegg: None,
140 external_login_password_mz_system: None,
141 listeners_config: ListenersConfig {
142 sql: btreemap![
143 "external".to_owned() => SqlListenerConfig {
144 addr: SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), 0),
145 authenticator_kind: AuthenticatorKind::None,
146 allowed_roles: AllowedRoles::Normal,
147 enable_tls: false,
148 },
149 "internal".to_owned() => SqlListenerConfig {
150 addr: SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), 0),
151 authenticator_kind: AuthenticatorKind::None,
152 allowed_roles: AllowedRoles::NormalAndInternal,
153 enable_tls: false,
154 },
155 ],
156 http: btreemap![
157 "external".to_owned() => HttpListenerConfig {
158 base: BaseListenerConfig {
159 addr: SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), 0),
160 authenticator_kind: AuthenticatorKind::None,
161 allowed_roles: AllowedRoles::Normal,
162 enable_tls: false,
163 },
164 routes: HttpRoutesEnabled{
165 base: true,
166 webhook: true,
167 internal: false,
168 metrics: false,
169 profiling: false,
170 mcp_agents: false,
171 mcp_observatory: false,
172 },
173 },
174 "internal".to_owned() => HttpListenerConfig {
175 base: BaseListenerConfig {
176 addr: SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), 0),
177 authenticator_kind: AuthenticatorKind::None,
178 allowed_roles: AllowedRoles::NormalAndInternal,
179 enable_tls: false,
180 },
181 routes: HttpRoutesEnabled{
182 base: true,
183 webhook: true,
184 internal: true,
185 metrics: true,
186 profiling: true,
187 mcp_agents: false,
188 mcp_observatory: false,
189 },
190 },
191 ],
192 },
193 unsafe_mode: false,
194 workers: 1,
195 now: SYSTEM_TIME.clone(),
196 seed: rand::random(),
197 storage_usage_collection_interval: Duration::from_secs(3600),
198 storage_usage_retention_period: None,
199 default_cluster_replica_size: "scale=1,workers=1".to_string(),
200 default_cluster_replication_factor: 1,
201 builtin_system_cluster_config: BootstrapBuiltinClusterConfig {
202 size: "scale=1,workers=1".to_string(),
203 replication_factor: SYSTEM_CLUSTER_DEFAULT_REPLICATION_FACTOR,
204 },
205 builtin_catalog_server_cluster_config: BootstrapBuiltinClusterConfig {
206 size: "scale=1,workers=1".to_string(),
207 replication_factor: CATALOG_SERVER_CLUSTER_DEFAULT_REPLICATION_FACTOR,
208 },
209 builtin_probe_cluster_config: BootstrapBuiltinClusterConfig {
210 size: "scale=1,workers=1".to_string(),
211 replication_factor: PROBE_CLUSTER_DEFAULT_REPLICATION_FACTOR,
212 },
213 builtin_support_cluster_config: BootstrapBuiltinClusterConfig {
214 size: "scale=1,workers=1".to_string(),
215 replication_factor: SUPPORT_CLUSTER_DEFAULT_REPLICATION_FACTOR,
216 },
217 builtin_analytics_cluster_config: BootstrapBuiltinClusterConfig {
218 size: "scale=1,workers=1".to_string(),
219 replication_factor: ANALYTICS_CLUSTER_DEFAULT_REPLICATION_FACTOR,
220 },
221 propagate_crashes: false,
222 enable_tracing: false,
223 bootstrap_role: Some("materialize".into()),
224 deploy_generation: 0,
225 system_parameter_defaults: BTreeMap::from([(
228 "log_filter".to_string(),
229 "error".to_string(),
230 )]),
231 internal_console_redirect_url: None,
232 metrics_registry: None,
233 orchestrator_tracing_cli_args: TracingCliArgs {
234 startup_log_filter: CloneableEnvFilter::from_str("error").expect("must parse"),
235 ..Default::default()
236 },
237 code_version: crate::BUILD_INFO.semver_version(),
238 environment_id: EnvironmentId::for_tests(),
239 capture: None,
240 }
241 }
242}
243
244impl TestHarness {
245 pub async fn start(self) -> TestServer {
249 self.try_start().await.expect("Failed to start test Server")
250 }
251
252 pub async fn start_with_trigger(self, tls_reload_certs: ReloadTrigger) -> TestServer {
254 self.try_start_with_trigger(tls_reload_certs)
255 .await
256 .expect("Failed to start test Server")
257 }
258
259 pub async fn try_start(self) -> Result<TestServer, anyhow::Error> {
261 self.try_start_with_trigger(mz_server_core::cert_reload_never_reload())
262 .await
263 }
264
265 pub async fn try_start_with_trigger(
267 self,
268 tls_reload_certs: ReloadTrigger,
269 ) -> Result<TestServer, anyhow::Error> {
270 let listeners = Listeners::new(&self).await?;
271 listeners.serve_with_trigger(self, tls_reload_certs).await
272 }
273
274 pub fn start_blocking(self) -> TestServerWithRuntime {
276 let runtime = tokio::runtime::Builder::new_multi_thread()
277 .enable_all()
278 .thread_stack_size(mz_ore::stack::STACK_SIZE)
279 .build()
280 .expect("failed to spawn runtime for test");
281 let runtime = Arc::new(runtime);
282 let server = runtime.block_on(self.start());
283 TestServerWithRuntime { runtime, server }
284 }
285
286 pub fn data_directory(mut self, data_directory: impl Into<PathBuf>) -> Self {
287 self.data_directory = Some(data_directory.into());
288 self
289 }
290
291 pub fn with_tls(mut self, cert_path: impl Into<PathBuf>, key_path: impl Into<PathBuf>) -> Self {
292 self.tls = Some(TlsCertConfig {
293 cert: cert_path.into(),
294 key: key_path.into(),
295 });
296 for (_, listener) in &mut self.listeners_config.sql {
297 listener.enable_tls = true;
298 }
299 for (_, listener) in &mut self.listeners_config.http {
300 listener.base.enable_tls = true;
301 }
302 self
303 }
304
305 pub fn unsafe_mode(mut self) -> Self {
306 self.unsafe_mode = true;
307 self
308 }
309
310 pub fn workers(mut self, workers: usize) -> Self {
311 self.workers = workers;
312 self
313 }
314
315 pub fn with_frontegg_auth(mut self, frontegg: &FronteggAuthenticator) -> Self {
316 self.frontegg = Some(frontegg.clone());
317 let enable_tls = self.tls.is_some();
318 self.listeners_config = ListenersConfig {
319 sql: btreemap! {
320 "external".to_owned() => SqlListenerConfig {
321 addr: SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), 0),
322 authenticator_kind: AuthenticatorKind::Frontegg,
323 allowed_roles: AllowedRoles::Normal,
324 enable_tls,
325 },
326 "internal".to_owned() => SqlListenerConfig {
327 addr: SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), 0),
328 authenticator_kind: AuthenticatorKind::None,
329 allowed_roles: AllowedRoles::NormalAndInternal,
330 enable_tls: false,
331 },
332 },
333 http: btreemap! {
334 "external".to_owned() => HttpListenerConfig {
335 base: BaseListenerConfig {
336 addr: SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), 0),
337 authenticator_kind: AuthenticatorKind::Frontegg,
338 allowed_roles: AllowedRoles::Normal,
339 enable_tls,
340 },
341 routes: HttpRoutesEnabled{
342 base: true,
343 webhook: true,
344 internal: false,
345 metrics: false,
346 profiling: false,
347 mcp_agents: false,
348 mcp_observatory: false,
349 },
350 },
351 "internal".to_owned() => HttpListenerConfig {
352 base: BaseListenerConfig {
353 addr: SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), 0),
354 authenticator_kind: AuthenticatorKind::None,
355 allowed_roles: AllowedRoles::NormalAndInternal,
356 enable_tls: false,
357 },
358 routes: HttpRoutesEnabled{
359 base: true,
360 webhook: true,
361 internal: true,
362 metrics: true,
363 profiling: true,
364 mcp_agents: false,
365 mcp_observatory: false,
366 },
367 },
368 },
369 };
370 self
371 }
372
373 pub fn with_oidc_auth(
374 mut self,
375 issuer: Option<String>,
376 authentication_claim: Option<String>,
377 expected_audiences: Option<Vec<String>>,
378 ) -> Self {
379 let enable_tls = self.tls.is_some();
380 self.listeners_config = ListenersConfig {
381 sql: btreemap! {
382 "external".to_owned() => SqlListenerConfig {
383 addr: SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), 0),
384 authenticator_kind: AuthenticatorKind::Oidc,
385 allowed_roles: AllowedRoles::Normal,
386 enable_tls,
387 },
388 "internal".to_owned() => SqlListenerConfig {
389 addr: SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), 0),
390 authenticator_kind: AuthenticatorKind::None,
391 allowed_roles: AllowedRoles::NormalAndInternal,
392 enable_tls: false,
393 },
394 },
395 http: btreemap! {
396 "external".to_owned() => HttpListenerConfig {
397 base: BaseListenerConfig {
398 addr: SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), 0),
399 authenticator_kind: AuthenticatorKind::Oidc,
400 allowed_roles: AllowedRoles::Normal,
401 enable_tls,
402 },
403 routes: HttpRoutesEnabled{
404 base: true,
405 webhook: true,
406 internal: false,
407 metrics: false,
408 profiling: false,
409 mcp_agents: false,
410 mcp_observatory: false,
411 },
412 },
413 "internal".to_owned() => HttpListenerConfig {
414 base: BaseListenerConfig {
415 addr: SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), 0),
416 authenticator_kind: AuthenticatorKind::None,
417 allowed_roles: AllowedRoles::NormalAndInternal,
418 enable_tls: false,
419 },
420 routes: HttpRoutesEnabled{
421 base: true,
422 webhook: true,
423 internal: true,
424 metrics: true,
425 profiling: true,
426 mcp_agents: false,
427 mcp_observatory: false,
428 },
429 },
430 },
431 };
432
433 if let Some(issuer) = issuer {
434 self.system_parameter_defaults
435 .insert("oidc_issuer".to_string(), issuer);
436 }
437
438 if let Some(authentication_claim) = authentication_claim {
439 self.system_parameter_defaults.insert(
440 "oidc_authentication_claim".to_string(),
441 authentication_claim,
442 );
443 }
444
445 if let Some(expected_audiences) = expected_audiences {
446 self.system_parameter_defaults.insert(
447 "oidc_audience".to_string(),
448 serde_json::to_string(&expected_audiences).unwrap(),
449 );
450 }
451
452 self
453 }
454
455 pub fn with_password_auth(mut self, mz_system_password: Password) -> Self {
456 self.external_login_password_mz_system = Some(mz_system_password);
457 let enable_tls = self.tls.is_some();
458 self.listeners_config = ListenersConfig {
459 sql: btreemap! {
460 "external".to_owned() => SqlListenerConfig {
461 addr: SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), 0),
462 authenticator_kind: AuthenticatorKind::Password,
463 allowed_roles: AllowedRoles::NormalAndInternal,
464 enable_tls,
465 },
466 },
467 http: btreemap! {
468 "external".to_owned() => HttpListenerConfig {
469 base: BaseListenerConfig {
470 addr: SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), 0),
471 authenticator_kind: AuthenticatorKind::Password,
472 allowed_roles: AllowedRoles::NormalAndInternal,
473 enable_tls,
474 },
475 routes: HttpRoutesEnabled{
476 base: true,
477 webhook: true,
478 internal: true,
479 metrics: false,
480 profiling: true,
481 mcp_agents: false,
482 mcp_observatory: false,
483 },
484 },
485 "metrics".to_owned() => HttpListenerConfig {
486 base: BaseListenerConfig {
487 addr: SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), 0),
488 authenticator_kind: AuthenticatorKind::None,
489 allowed_roles: AllowedRoles::NormalAndInternal,
490 enable_tls: false,
491 },
492 routes: HttpRoutesEnabled{
493 base: false,
494 webhook: false,
495 internal: false,
496 metrics: true,
497 profiling: false,
498 mcp_agents: false,
499 mcp_observatory: false,
500 },
501 },
502 },
503 };
504 self
505 }
506
507 pub fn with_sasl_scram_auth(mut self, mz_system_password: Password) -> Self {
508 self.external_login_password_mz_system = Some(mz_system_password);
509 let enable_tls = self.tls.is_some();
510 self.listeners_config = ListenersConfig {
511 sql: btreemap! {
512 "external".to_owned() => SqlListenerConfig {
513 addr: SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), 0),
514 authenticator_kind: AuthenticatorKind::Sasl,
515 allowed_roles: AllowedRoles::NormalAndInternal,
516 enable_tls,
517 },
518 },
519 http: btreemap! {
520 "external".to_owned() => HttpListenerConfig {
521 base: BaseListenerConfig {
522 addr: SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), 0),
523 authenticator_kind: AuthenticatorKind::Password,
524 allowed_roles: AllowedRoles::NormalAndInternal,
525 enable_tls,
526 },
527 routes: HttpRoutesEnabled{
528 base: true,
529 webhook: true,
530 internal: true,
531 metrics: false,
532 profiling: true,
533 mcp_agents: false,
534 mcp_observatory: false,
535 },
536 },
537 "metrics".to_owned() => HttpListenerConfig {
538 base: BaseListenerConfig {
539 addr: SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), 0),
540 authenticator_kind: AuthenticatorKind::None,
541 allowed_roles: AllowedRoles::NormalAndInternal,
542 enable_tls: false,
543 },
544 routes: HttpRoutesEnabled{
545 base: false,
546 webhook: false,
547 internal: false,
548 metrics: true,
549 profiling: false,
550 mcp_agents: false,
551 mcp_observatory: false,
552 },
553 },
554 },
555 };
556 self
557 }
558
559 pub fn with_now(mut self, now: NowFn) -> Self {
560 self.now = now;
561 self
562 }
563
564 pub fn with_storage_usage_collection_interval(
565 mut self,
566 storage_usage_collection_interval: Duration,
567 ) -> Self {
568 self.storage_usage_collection_interval = storage_usage_collection_interval;
569 self
570 }
571
572 pub fn with_storage_usage_retention_period(
573 mut self,
574 storage_usage_retention_period: Duration,
575 ) -> Self {
576 self.storage_usage_retention_period = Some(storage_usage_retention_period);
577 self
578 }
579
580 pub fn with_default_cluster_replica_size(
581 mut self,
582 default_cluster_replica_size: String,
583 ) -> Self {
584 self.default_cluster_replica_size = default_cluster_replica_size;
585 self
586 }
587
588 pub fn with_builtin_system_cluster_replica_size(
589 mut self,
590 builtin_system_cluster_replica_size: String,
591 ) -> Self {
592 self.builtin_system_cluster_config.size = builtin_system_cluster_replica_size;
593 self
594 }
595
596 pub fn with_builtin_system_cluster_replication_factor(
597 mut self,
598 builtin_system_cluster_replication_factor: u32,
599 ) -> Self {
600 self.builtin_system_cluster_config.replication_factor =
601 builtin_system_cluster_replication_factor;
602 self
603 }
604
605 pub fn with_builtin_catalog_server_cluster_replica_size(
606 mut self,
607 builtin_catalog_server_cluster_replica_size: String,
608 ) -> Self {
609 self.builtin_catalog_server_cluster_config.size =
610 builtin_catalog_server_cluster_replica_size;
611 self
612 }
613
614 pub fn with_propagate_crashes(mut self, propagate_crashes: bool) -> Self {
615 self.propagate_crashes = propagate_crashes;
616 self
617 }
618
619 pub fn with_enable_tracing(mut self, enable_tracing: bool) -> Self {
620 self.enable_tracing = enable_tracing;
621 self
622 }
623
624 pub fn with_bootstrap_role(mut self, bootstrap_role: Option<String>) -> Self {
625 self.bootstrap_role = bootstrap_role;
626 self
627 }
628
629 pub fn with_deploy_generation(mut self, deploy_generation: u64) -> Self {
630 self.deploy_generation = deploy_generation;
631 self
632 }
633
634 pub fn with_system_parameter_default(mut self, param: String, value: String) -> Self {
635 self.system_parameter_defaults.insert(param, value);
636 self
637 }
638
639 pub fn with_internal_console_redirect_url(
640 mut self,
641 internal_console_redirect_url: Option<String>,
642 ) -> Self {
643 self.internal_console_redirect_url = internal_console_redirect_url;
644 self
645 }
646
647 pub fn with_metrics_registry(mut self, registry: MetricsRegistry) -> Self {
648 self.metrics_registry = Some(registry);
649 self
650 }
651
652 pub fn with_code_version(mut self, version: semver::Version) -> Self {
653 self.code_version = version;
654 self
655 }
656
657 pub fn with_capture(mut self, storage: SharedStorage) -> Self {
658 self.capture = Some(storage);
659 self
660 }
661}
662
/// Listeners bound before the server starts, so their addresses (e.g. the
/// external HTTP port used for the server's host name) are known while the
/// server configuration is assembled.
pub struct Listeners {
    /// The bound listeners handed to the server by `serve_with_trigger`.
    pub inner: crate::Listeners,
}
666
667impl Listeners {
668 pub async fn new(config: &TestHarness) -> Result<Listeners, anyhow::Error> {
669 let inner = crate::Listeners::bind(config.listeners_config.clone()).await?;
670 Ok(Listeners { inner })
671 }
672
673 pub async fn serve(self, config: TestHarness) -> Result<TestServer, anyhow::Error> {
674 self.serve_with_trigger(config, mz_server_core::cert_reload_never_reload())
675 .await
676 }
677
678 pub async fn serve_with_trigger(
679 self,
680 config: TestHarness,
681 tls_reload_certs: ReloadTrigger,
682 ) -> Result<TestServer, anyhow::Error> {
683 let (data_directory, temp_dir) = match config.data_directory {
684 None => {
685 let temp_dir = tempfile::tempdir()?;
690 (temp_dir.path().to_path_buf(), Some(temp_dir))
691 }
692 Some(data_directory) => (data_directory, None),
693 };
694 let scratch_dir = tempfile::tempdir()?;
695 let (consensus_uri, timestamp_oracle_url) = {
696 let seed = config.seed;
697 let cockroach_url = env::var("METADATA_BACKEND_URL")
698 .map_err(|_| anyhow!("METADATA_BACKEND_URL environment variable is not set"))?;
699 let (client, conn) = tokio_postgres::connect(&cockroach_url, NoTls).await?;
700 mz_ore::task::spawn(|| "startup-postgres-conn", async move {
701 if let Err(err) = conn.await {
702 panic!("connection error: {}", err);
703 };
704 });
705 client
706 .batch_execute(&format!(
707 "CREATE SCHEMA IF NOT EXISTS consensus_{seed};
708 CREATE SCHEMA IF NOT EXISTS tsoracle_{seed};"
709 ))
710 .await?;
711 (
712 format!("{cockroach_url}?options=--search_path=consensus_{seed}")
713 .parse()
714 .expect("invalid consensus URI"),
715 format!("{cockroach_url}?options=--search_path=tsoracle_{seed}")
716 .parse()
717 .expect("invalid timestamp oracle URI"),
718 )
719 };
720 let metrics_registry = config.metrics_registry.unwrap_or_else(MetricsRegistry::new);
721 let orchestrator = ProcessOrchestrator::new(ProcessOrchestratorConfig {
722 image_dir: env::current_exe()?
723 .parent()
724 .unwrap()
725 .parent()
726 .unwrap()
727 .to_path_buf(),
728 suppress_output: false,
729 environment_id: config.environment_id.to_string(),
730 secrets_dir: data_directory.join("secrets"),
731 command_wrapper: vec![],
732 propagate_crashes: config.propagate_crashes,
733 tcp_proxy: None,
734 scratch_directory: scratch_dir.path().to_path_buf(),
735 })
736 .await?;
737 let orchestrator = Arc::new(orchestrator);
738 let persist_now = SYSTEM_TIME.clone();
741 let dyncfgs = mz_dyncfgs::all_dyncfgs();
742
743 let mut updates = ConfigUpdates::default();
744 updates.add(&CONSENSUS_CONNECTION_POOL_MAX_SIZE, 1);
747 updates.apply(&dyncfgs);
748
749 let mut persist_cfg = PersistConfig::new(&crate::BUILD_INFO, persist_now.clone(), dyncfgs);
750 persist_cfg.build_version = config.code_version;
751 persist_cfg.set_rollup_threshold(5);
753
754 let persist_pubsub_server = PersistGrpcPubSubServer::new(&persist_cfg, &metrics_registry);
755 let persist_pubsub_client = persist_pubsub_server.new_same_process_connection();
756 let persist_pubsub_tcp_listener =
757 TcpListener::bind(SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), 0))
758 .await
759 .expect("pubsub addr binding");
760 let persist_pubsub_server_port = persist_pubsub_tcp_listener
761 .local_addr()
762 .expect("pubsub addr has local addr")
763 .port();
764
765 mz_ore::task::spawn(|| "persist_pubsub_server", async move {
767 persist_pubsub_server
768 .serve_with_stream(TcpListenerStream::new(persist_pubsub_tcp_listener))
769 .await
770 .expect("success")
771 });
772 let persist_clients =
773 PersistClientCache::new(persist_cfg, &metrics_registry, |_, _| persist_pubsub_client);
774 let persist_clients = Arc::new(persist_clients);
775
776 let secrets_controller = Arc::clone(&orchestrator);
777 let connection_context = ConnectionContext::for_tests(orchestrator.reader());
778 let orchestrator = Arc::new(TracingOrchestrator::new(
779 orchestrator,
780 config.orchestrator_tracing_cli_args,
781 ));
782 let tracing_handle = if config.enable_tracing {
783 let config = TracingConfig::<fn(&tracing::Metadata) -> sentry_tracing::EventFilter> {
784 service_name: "environmentd",
785 stderr_log: StderrLogConfig {
786 format: StderrLogFormat::Json,
787 filter: EnvFilter::default(),
788 },
789 opentelemetry: Some(OpenTelemetryConfig {
790 endpoint: "http://fake_address_for_testing:8080".to_string(),
791 headers: http::HeaderMap::new(),
792 filter: EnvFilter::default().add_directive(Level::DEBUG.into()),
793 resource: opentelemetry_sdk::resource::Resource::builder().build(),
794 max_batch_queue_size: 2048,
795 max_export_batch_size: 512,
796 max_concurrent_exports: 1,
797 batch_scheduled_delay: Duration::from_millis(5000),
798 max_export_timeout: Duration::from_secs(30),
799 }),
800 tokio_console: None,
801 sentry: None,
802 build_version: crate::BUILD_INFO.version,
803 build_sha: crate::BUILD_INFO.sha,
804 registry: metrics_registry.clone(),
805 capture: config.capture,
806 };
807 mz_ore::tracing::configure(config).await?
808 } else {
809 TracingHandle::disabled()
810 };
811 let host_name = format!(
812 "localhost:{}",
813 self.inner.http["external"].handle.local_addr.port()
814 );
815 let catalog_config = CatalogConfig {
816 persist_clients: Arc::clone(&persist_clients),
817 metrics: Arc::new(mz_catalog::durable::Metrics::new(&MetricsRegistry::new())),
818 };
819
820 let inner = self
821 .inner
822 .serve(crate::Config {
823 catalog_config,
824 timestamp_oracle_url: Some(timestamp_oracle_url),
825 controller: ControllerConfig {
826 build_info: &crate::BUILD_INFO,
827 orchestrator,
828 clusterd_image: "clusterd".into(),
829 init_container_image: None,
830 deploy_generation: config.deploy_generation,
831 persist_location: PersistLocation {
832 blob_uri: format!("file://{}/persist/blob", data_directory.display())
833 .parse()
834 .expect("invalid blob URI"),
835 consensus_uri,
836 },
837 persist_clients,
838 now: config.now.clone(),
839 metrics_registry: metrics_registry.clone(),
840 persist_pubsub_url: format!("http://localhost:{}", persist_pubsub_server_port),
841 secrets_args: mz_service::secrets::SecretsReaderCliArgs {
842 secrets_reader: mz_service::secrets::SecretsControllerKind::LocalFile,
843 secrets_reader_local_file_dir: Some(data_directory.join("secrets")),
844 secrets_reader_kubernetes_context: None,
845 secrets_reader_aws_prefix: None,
846 secrets_reader_name_prefix: None,
847 },
848 connection_context,
849 replica_http_locator: Default::default(),
850 },
851 secrets_controller,
852 cloud_resource_controller: None,
853 tls: config.tls,
854 frontegg: config.frontegg,
855 unsafe_mode: config.unsafe_mode,
856 all_features: false,
857 metrics_registry: metrics_registry.clone(),
858 now: config.now,
859 environment_id: config.environment_id,
860 cors_allowed_origin: AllowOrigin::list([]),
861 cluster_replica_sizes: ClusterReplicaSizeMap::for_tests(),
862 bootstrap_default_cluster_replica_size: config.default_cluster_replica_size,
863 bootstrap_default_cluster_replication_factor: config
864 .default_cluster_replication_factor,
865 bootstrap_builtin_system_cluster_config: config.builtin_system_cluster_config,
866 bootstrap_builtin_catalog_server_cluster_config: config
867 .builtin_catalog_server_cluster_config,
868 bootstrap_builtin_probe_cluster_config: config.builtin_probe_cluster_config,
869 bootstrap_builtin_support_cluster_config: config.builtin_support_cluster_config,
870 bootstrap_builtin_analytics_cluster_config: config.builtin_analytics_cluster_config,
871 system_parameter_defaults: config.system_parameter_defaults,
872 availability_zones: Default::default(),
873 tracing_handle,
874 storage_usage_collection_interval: config.storage_usage_collection_interval,
875 storage_usage_retention_period: config.storage_usage_retention_period,
876 segment_api_key: None,
877 segment_client_side: false,
878 test_only_dummy_segment_client: false,
879 egress_addresses: vec![],
880 aws_account_id: None,
881 aws_privatelink_availability_zones: None,
882 launchdarkly_sdk_key: None,
883 launchdarkly_key_map: Default::default(),
884 config_sync_file_path: None,
885 config_sync_timeout: Duration::from_secs(30),
886 config_sync_loop_interval: None,
887 bootstrap_role: config.bootstrap_role,
888 http_host_name: Some(host_name),
889 internal_console_redirect_url: config.internal_console_redirect_url,
890 tls_reload_certs,
891 helm_chart_version: None,
892 license_key: ValidatedLicenseKey::for_tests(),
893 external_login_password_mz_system: config.external_login_password_mz_system,
894 force_builtin_schema_migration: None,
895 })
896 .await?;
897
898 Ok(TestServer {
899 inner,
900 metrics_registry,
901 _temp_dir: temp_dir,
902 _scratch_dir: scratch_dir,
903 })
904 }
905}
906
/// A running test server together with the directories backing it.
///
/// Rust drops struct fields in declaration order. The directory guards are
/// declared after `inner`, so they are removed only after the server has been
/// dropped; keep them declared last.
pub struct TestServer {
    /// The running server.
    pub inner: crate::Server,
    /// Registry passed to the server's configuration as its metrics registry.
    pub metrics_registry: MetricsRegistry,
    /// Temporary data directory, present only when no explicit data directory
    /// was configured. The directory is deleted when this guard drops.
    _temp_dir: Option<TempDir>,
    /// Scratch directory given to the process orchestrator. The directory is
    /// deleted when this guard drops.
    _scratch_dir: TempDir,
}
915
916impl TestServer {
917 pub fn connect(&self) -> ConnectBuilder<'_, postgres::NoTls, NoHandle> {
918 ConnectBuilder::new(self).no_tls()
919 }
920
921 pub async fn enable_feature_flags(&self, flags: &[&'static str]) {
922 let internal_client = self.connect().internal().await.unwrap();
923
924 for flag in flags {
925 internal_client
926 .batch_execute(&format!("ALTER SYSTEM SET {} = true;", flag))
927 .await
928 .unwrap();
929 }
930 }
931
932 pub async fn disable_feature_flags(&self, flags: &[&'static str]) {
933 let internal_client = self.connect().internal().await.unwrap();
934
935 for flag in flags {
936 internal_client
937 .batch_execute(&format!("ALTER SYSTEM SET {} = false;", flag))
938 .await
939 .unwrap();
940 }
941 }
942
943 pub fn ws_addr(&self) -> Uri {
944 format!(
945 "ws://{}/api/experimental/sql",
946 self.inner.http_listener_handles["external"].local_addr
947 )
948 .parse()
949 .unwrap()
950 }
951
952 pub fn internal_ws_addr(&self) -> Uri {
953 format!(
954 "ws://{}/api/experimental/sql",
955 self.inner.http_listener_handles["internal"].local_addr
956 )
957 .parse()
958 .unwrap()
959 }
960
961 pub fn http_local_addr(&self) -> SocketAddr {
962 self.inner.http_listener_handles["external"].local_addr
963 }
964
965 pub fn internal_http_local_addr(&self) -> SocketAddr {
966 self.inner.http_listener_handles["internal"].local_addr
967 }
968
969 pub fn sql_local_addr(&self) -> SocketAddr {
970 self.inner.sql_listener_handles["external"].local_addr
971 }
972
973 pub fn internal_sql_local_addr(&self) -> SocketAddr {
974 self.inner.sql_listener_handles["internal"].local_addr
975 }
976}
977
/// Builder for a SQL client connection to a [`TestServer`].
///
/// Type parameters:
/// - `T`: the TLS connector type, `()` until one is chosen.
/// - `H`: a marker type (`NoHandle` / `WithHandle`) switched by `with_handle`.
///   Presumably it selects whether the connection task's handle is returned —
///   confirm against the `IncludeHandle` impls.
pub struct ConnectBuilder<'s, T, H> {
    /// Server to connect to.
    server: &'s TestServer,

    /// Connection parameters (host, user, options, ...). The port is tracked
    /// separately in `port`.
    pg_config: tokio_postgres::Config,
    /// SQL listener port: the external listener by default, the internal one
    /// after `internal()`.
    port: u16,
    /// TLS connector used to connect.
    tls: T,

    /// Optional callback for server notices. The connect logic is outside
    /// this view — TODO confirm when it is invoked.
    notice_callback: Option<Box<dyn FnMut(tokio_postgres::error::DbError) + Send + 'static>>,

    /// Type-state marker value; see `H`.
    _with_handle: H,
}
998
999impl<'s> ConnectBuilder<'s, (), NoHandle> {
1000 fn new(server: &'s TestServer) -> Self {
1001 let mut pg_config = tokio_postgres::Config::new();
1002 pg_config
1003 .host(&Ipv4Addr::LOCALHOST.to_string())
1004 .user("materialize")
1005 .options("--welcome_message=off")
1006 .application_name("environmentd_test_framework");
1007
1008 ConnectBuilder {
1009 server,
1010 pg_config,
1011 port: server.sql_local_addr().port(),
1012 tls: (),
1013 notice_callback: None,
1014 _with_handle: NoHandle,
1015 }
1016 }
1017}
1018
1019impl<'s, T, H> ConnectBuilder<'s, T, H> {
1020 pub fn no_tls(self) -> ConnectBuilder<'s, postgres::NoTls, H> {
1024 ConnectBuilder {
1025 server: self.server,
1026 pg_config: self.pg_config,
1027 port: self.port,
1028 tls: postgres::NoTls,
1029 notice_callback: self.notice_callback,
1030 _with_handle: self._with_handle,
1031 }
1032 }
1033
1034 pub fn with_tls<Tls>(self, tls: Tls) -> ConnectBuilder<'s, Tls, H>
1036 where
1037 Tls: MakeTlsConnect<Socket> + Send + 'static,
1038 Tls::TlsConnect: Send,
1039 Tls::Stream: Send,
1040 <Tls::TlsConnect as TlsConnect<Socket>>::Future: Send,
1041 {
1042 ConnectBuilder {
1043 server: self.server,
1044 pg_config: self.pg_config,
1045 port: self.port,
1046 tls,
1047 notice_callback: self.notice_callback,
1048 _with_handle: self._with_handle,
1049 }
1050 }
1051
1052 pub fn with_config(mut self, pg_config: tokio_postgres::Config) -> Self {
1054 self.pg_config = pg_config;
1055 self
1056 }
1057
1058 pub fn ssl_mode(mut self, mode: SslMode) -> Self {
1060 self.pg_config.ssl_mode(mode);
1061 self
1062 }
1063
1064 pub fn user(mut self, user: &str) -> Self {
1066 self.pg_config.user(user);
1067 self
1068 }
1069
1070 pub fn password(mut self, password: &str) -> Self {
1072 self.pg_config.password(password);
1073 self
1074 }
1075
1076 pub fn application_name(mut self, application_name: &str) -> Self {
1078 self.pg_config.application_name(application_name);
1079 self
1080 }
1081
1082 pub fn dbname(mut self, dbname: &str) -> Self {
1084 self.pg_config.dbname(dbname);
1085 self
1086 }
1087
1088 pub fn options(mut self, options: &str) -> Self {
1090 self.pg_config.options(options);
1091 self
1092 }
1093
1094 pub fn internal(mut self) -> Self {
1099 self.port = self.server.internal_sql_local_addr().port();
1100 self.pg_config.user(mz_sql::session::user::SYSTEM_USER_NAME);
1101 self
1102 }
1103
1104 pub fn notice_callback(self, callback: impl FnMut(DbError) + Send + 'static) -> Self {
1106 ConnectBuilder {
1107 notice_callback: Some(Box::new(callback)),
1108 ..self
1109 }
1110 }
1111
1112 pub fn with_handle(self) -> ConnectBuilder<'s, T, WithHandle> {
1115 ConnectBuilder {
1116 server: self.server,
1117 pg_config: self.pg_config,
1118 port: self.port,
1119 tls: self.tls,
1120 notice_callback: self.notice_callback,
1121 _with_handle: WithHandle,
1122 }
1123 }
1124
1125 pub fn as_pg_config(&self) -> &tokio_postgres::Config {
1127 &self.pg_config
1128 }
1129}
1130
1131pub trait IncludeHandle: Send {
1134 type Output;
1135 fn transform_result(
1136 client: tokio_postgres::Client,
1137 handle: mz_ore::task::JoinHandle<()>,
1138 ) -> Self::Output;
1139}
1140
1141pub struct NoHandle;
1144impl IncludeHandle for NoHandle {
1145 type Output = tokio_postgres::Client;
1146 fn transform_result(
1147 client: tokio_postgres::Client,
1148 _handle: mz_ore::task::JoinHandle<()>,
1149 ) -> Self::Output {
1150 client
1151 }
1152}
1153
1154pub struct WithHandle;
1157impl IncludeHandle for WithHandle {
1158 type Output = (tokio_postgres::Client, mz_ore::task::JoinHandle<()>);
1159 fn transform_result(
1160 client: tokio_postgres::Client,
1161 handle: mz_ore::task::JoinHandle<()>,
1162 ) -> Self::Output {
1163 (client, handle)
1164 }
1165}
1166
1167impl<'s, T, H> IntoFuture for ConnectBuilder<'s, T, H>
1168where
1169 T: MakeTlsConnect<Socket> + Send + 'static,
1170 T::TlsConnect: Send,
1171 T::Stream: Send,
1172 <T::TlsConnect as TlsConnect<Socket>>::Future: Send,
1173 H: IncludeHandle,
1174{
1175 type Output = Result<H::Output, postgres::Error>;
1176 type IntoFuture = BoxFuture<'static, Self::Output>;
1177
1178 fn into_future(mut self) -> Self::IntoFuture {
1179 Box::pin(async move {
1180 assert!(
1181 self.pg_config.get_ports().is_empty(),
1182 "specifying multiple ports is not supported"
1183 );
1184 self.pg_config.port(self.port);
1185
1186 let (client, mut conn) = self.pg_config.connect(self.tls).await?;
1187 let mut notice_callback = self.notice_callback.take();
1188
1189 let handle = task::spawn(|| "connect", async move {
1190 while let Some(msg) = std::future::poll_fn(|cx| conn.poll_message(cx)).await {
1191 match msg {
1192 Ok(AsyncMessage::Notice(notice)) => {
1193 if let Some(callback) = notice_callback.as_mut() {
1194 callback(notice);
1195 }
1196 }
1197 Ok(msg) => {
1198 tracing::debug!(?msg, "Dropping message from database");
1199 }
1200 Err(e) => {
1201 tracing::info!("connection error: {e}");
1206 break;
1207 }
1208 }
1209 }
1210 tracing::info!("connection closed");
1211 });
1212
1213 let output = H::transform_result(client, handle);
1214 Ok(output)
1215 })
1216 }
1217}
1218
/// A [`TestServer`] paired with a shared `Runtime`, exposing blocking
/// (`postgres`-crate) client helpers.
pub struct TestServerWithRuntime {
    /// The wrapped test server.
    server: TestServer,
    /// The runtime handed out by `runtime()`.
    ///
    /// NOTE(review): struct fields drop in declaration order, so `server` is
    /// dropped before `runtime`. This is presumably intentional, letting the
    /// server shut down while the runtime is still alive — confirm before
    /// reordering.
    runtime: Arc<Runtime>,
}
1227
1228impl TestServerWithRuntime {
1229 pub fn runtime(&self) -> &Arc<Runtime> {
1233 &self.runtime
1234 }
1235
1236 pub fn inner(&self) -> &crate::Server {
1238 &self.server.inner
1239 }
1240
1241 pub fn connect<T>(&self, tls: T) -> Result<postgres::Client, postgres::Error>
1243 where
1244 T: MakeTlsConnect<Socket> + Send + 'static,
1245 T::TlsConnect: Send,
1246 T::Stream: Send,
1247 <T::TlsConnect as TlsConnect<Socket>>::Future: Send,
1248 {
1249 self.pg_config().connect(tls)
1250 }
1251
1252 pub fn connect_internal<T>(&self, tls: T) -> Result<postgres::Client, anyhow::Error>
1254 where
1255 T: MakeTlsConnect<Socket> + Send + 'static,
1256 T::TlsConnect: Send,
1257 T::Stream: Send,
1258 <T::TlsConnect as TlsConnect<Socket>>::Future: Send,
1259 {
1260 Ok(self.pg_config_internal().connect(tls)?)
1261 }
1262
1263 pub fn enable_feature_flags(&self, flags: &[&'static str]) {
1265 let mut internal_client = self.connect_internal(postgres::NoTls).unwrap();
1266
1267 for flag in flags {
1268 internal_client
1269 .batch_execute(&format!("ALTER SYSTEM SET {} = true;", flag))
1270 .unwrap();
1271 }
1272 }
1273
1274 pub fn disable_feature_flags(&self, flags: &[&'static str]) {
1276 let mut internal_client = self.connect_internal(postgres::NoTls).unwrap();
1277
1278 for flag in flags {
1279 internal_client
1280 .batch_execute(&format!("ALTER SYSTEM SET {} = false;", flag))
1281 .unwrap();
1282 }
1283 }
1284
1285 pub fn pg_config(&self) -> postgres::Config {
1288 let local_addr = self.server.sql_local_addr();
1289 let mut config = postgres::Config::new();
1290 config
1291 .host(&Ipv4Addr::LOCALHOST.to_string())
1292 .port(local_addr.port())
1293 .user("materialize")
1294 .options("--welcome_message=off");
1295 config
1296 }
1297
1298 pub fn pg_config_internal(&self) -> postgres::Config {
1301 let local_addr = self.server.internal_sql_local_addr();
1302 let mut config = postgres::Config::new();
1303 config
1304 .host(&Ipv4Addr::LOCALHOST.to_string())
1305 .port(local_addr.port())
1306 .user("mz_system")
1307 .options("--welcome_message=off");
1308 config
1309 }
1310
1311 pub fn ws_addr(&self) -> Uri {
1312 self.server.ws_addr()
1313 }
1314
1315 pub fn internal_ws_addr(&self) -> Uri {
1316 self.server.internal_ws_addr()
1317 }
1318
1319 pub fn http_local_addr(&self) -> SocketAddr {
1320 self.server.http_local_addr()
1321 }
1322
1323 pub fn internal_http_local_addr(&self) -> SocketAddr {
1324 self.server.internal_http_local_addr()
1325 }
1326
1327 pub fn sql_local_addr(&self) -> SocketAddr {
1328 self.server.sql_local_addr()
1329 }
1330
1331 pub fn internal_sql_local_addr(&self) -> SocketAddr {
1332 self.server.internal_sql_local_addr()
1333 }
1334
1335 pub fn metrics_registry(&self) -> &MetricsRegistry {
1337 &self.server.metrics_registry
1338 }
1339}
1340
1341#[derive(Debug, Clone, Copy, Eq, PartialEq, Ord, PartialOrd)]
1342pub struct MzTimestamp(pub u64);
1343
1344impl<'a> FromSql<'a> for MzTimestamp {
1345 fn from_sql(ty: &Type, raw: &'a [u8]) -> Result<MzTimestamp, Box<dyn Error + Sync + Send>> {
1346 let n = mz_pgrepr::Numeric::from_sql(ty, raw)?;
1347 Ok(MzTimestamp(u64::try_from(n.0.0)?))
1348 }
1349
1350 fn accepts(ty: &Type) -> bool {
1351 mz_pgrepr::Numeric::accepts(ty)
1352 }
1353}
1354
1355pub trait PostgresErrorExt {
1356 fn unwrap_db_error(self) -> DbError;
1357}
1358
1359impl PostgresErrorExt for postgres::Error {
1360 fn unwrap_db_error(self) -> DbError {
1361 match self.source().and_then(|e| e.downcast_ref::<DbError>()) {
1362 Some(e) => e.clone(),
1363 None => panic!("expected DbError, but got: {:?}", self),
1364 }
1365 }
1366}
1367
1368impl<T, E> PostgresErrorExt for Result<T, E>
1369where
1370 E: PostgresErrorExt,
1371{
1372 fn unwrap_db_error(self) -> DbError {
1373 match self {
1374 Ok(_) => panic!("expected Err(DbError), but got Ok(_)"),
1375 Err(e) => e.unwrap_db_error(),
1376 }
1377 }
1378}
1379
1380pub async fn insert_with_deterministic_timestamps(
1384 table: &'static str,
1385 values: &'static str,
1386 server: &TestServer,
1387 now: Arc<std::sync::Mutex<EpochMillis>>,
1388) -> Result<(), Box<dyn Error>> {
1389 let client_write = server.connect().await?;
1390 let client_read = server.connect().await?;
1391
1392 let mut current_ts = get_explain_timestamp(table, &client_read).await;
1393
1394 let insert_query = format!("INSERT INTO {table} VALUES {values}");
1395
1396 let write_future = client_write.execute(&insert_query, &[]);
1397 let timestamp_interval = tokio::time::interval(Duration::from_millis(1));
1398
1399 let mut write_future = std::pin::pin!(write_future);
1400 let mut timestamp_interval = std::pin::pin!(timestamp_interval);
1401
1402 loop {
1405 tokio::select! {
1406 _ = (&mut write_future) => return Ok(()),
1407 _ = timestamp_interval.tick() => {
1408 current_ts += 1;
1409 *now.lock().expect("lock poisoned") = current_ts;
1410 }
1411 };
1412 }
1413}
1414
1415pub async fn get_explain_timestamp(from_suffix: &str, client: &Client) -> EpochMillis {
1416 try_get_explain_timestamp(from_suffix, client)
1417 .await
1418 .unwrap()
1419}
1420
1421pub async fn try_get_explain_timestamp(
1422 from_suffix: &str,
1423 client: &Client,
1424) -> Result<EpochMillis, anyhow::Error> {
1425 let det = get_explain_timestamp_determination(from_suffix, client).await?;
1426 let ts = det.determination.timestamp_context.timestamp_or_default();
1427 Ok(ts.into())
1428}
1429
1430pub async fn get_explain_timestamp_determination(
1431 from_suffix: &str,
1432 client: &Client,
1433) -> Result<TimestampExplanation<mz_repr::Timestamp>, anyhow::Error> {
1434 let row = client
1435 .query_one(
1436 &format!("EXPLAIN TIMESTAMP AS JSON FOR SELECT * FROM {from_suffix}"),
1437 &[],
1438 )
1439 .await?;
1440 let explain: String = row.get(0);
1441 Ok(serde_json::from_str(&explain).unwrap())
1442}
1443
1444pub async fn create_postgres_source_with_table<'a>(
1452 server: &TestServer,
1453 mz_client: &Client,
1454 table_name: &str,
1455 table_schema: &str,
1456 source_name: &str,
1457) -> (
1458 Client,
1459 impl FnOnce(&'a Client, &'a Client) -> LocalBoxFuture<'a, ()>,
1460) {
1461 server
1462 .enable_feature_flags(&["enable_create_table_from_source"])
1463 .await;
1464
1465 let postgres_url = env::var("POSTGRES_URL")
1466 .map_err(|_| anyhow!("POSTGRES_URL environment variable is not set"))
1467 .unwrap();
1468
1469 let (pg_client, connection) = tokio_postgres::connect(&postgres_url, postgres::NoTls)
1470 .await
1471 .unwrap();
1472
1473 let pg_config: tokio_postgres::Config = postgres_url.parse().unwrap();
1474 let user = pg_config.get_user().unwrap_or("postgres");
1475 let db_name = pg_config.get_dbname().unwrap_or(user);
1476 let ports = pg_config.get_ports();
1477 let port = if ports.is_empty() { 5432 } else { ports[0] };
1478 let hosts = pg_config.get_hosts();
1479 let host = if hosts.is_empty() {
1480 "localhost".to_string()
1481 } else {
1482 match &hosts[0] {
1483 Host::Tcp(host) => host.to_string(),
1484 Host::Unix(host) => host.to_str().unwrap().to_string(),
1485 }
1486 };
1487 let password = pg_config.get_password();
1488
1489 mz_ore::task::spawn(|| "postgres-source-connection", async move {
1490 if let Err(e) = connection.await {
1491 panic!("connection error: {}", e);
1492 }
1493 });
1494
1495 let _ = pg_client
1497 .execute(&format!("DROP TABLE IF EXISTS {table_name};"), &[])
1498 .await
1499 .unwrap();
1500 let _ = pg_client
1501 .execute(&format!("DROP PUBLICATION IF EXISTS {source_name};"), &[])
1502 .await
1503 .unwrap();
1504 let _ = pg_client
1505 .execute(&format!("CREATE TABLE {table_name} {table_schema};"), &[])
1506 .await
1507 .unwrap();
1508 let _ = pg_client
1509 .execute(
1510 &format!("ALTER TABLE {table_name} REPLICA IDENTITY FULL;"),
1511 &[],
1512 )
1513 .await
1514 .unwrap();
1515 let _ = pg_client
1516 .execute(
1517 &format!("CREATE PUBLICATION {source_name} FOR TABLE {table_name};"),
1518 &[],
1519 )
1520 .await
1521 .unwrap();
1522
1523 let mut connection_str = format!("HOST '{host}', PORT {port}, USER {user}, DATABASE {db_name}");
1525 if let Some(password) = password {
1526 let password = std::str::from_utf8(password).unwrap();
1527 mz_client
1528 .batch_execute(&format!("CREATE SECRET s AS '{password}'"))
1529 .await
1530 .unwrap();
1531 connection_str = format!("{connection_str}, PASSWORD SECRET s");
1532 }
1533 mz_client
1534 .batch_execute(&format!(
1535 "CREATE CONNECTION pgconn TO POSTGRES ({connection_str})"
1536 ))
1537 .await
1538 .unwrap();
1539 mz_client
1540 .batch_execute(&format!(
1541 "CREATE SOURCE {source_name}
1542 FROM POSTGRES
1543 CONNECTION pgconn
1544 (PUBLICATION '{source_name}')"
1545 ))
1546 .await
1547 .unwrap();
1548 mz_client
1549 .batch_execute(&format!(
1550 "CREATE TABLE {table_name}
1551 FROM SOURCE {source_name}
1552 (REFERENCE {table_name});"
1553 ))
1554 .await
1555 .unwrap();
1556
1557 let table_name = table_name.to_string();
1558 let source_name = source_name.to_string();
1559 (
1560 pg_client,
1561 move |mz_client: &'a Client, pg_client: &'a Client| {
1562 let f: Pin<Box<dyn Future<Output = ()> + 'a>> = Box::pin(async move {
1563 mz_client
1564 .batch_execute(&format!("DROP SOURCE {source_name} CASCADE;"))
1565 .await
1566 .unwrap();
1567 mz_client
1568 .batch_execute("DROP CONNECTION pgconn;")
1569 .await
1570 .unwrap();
1571
1572 let _ = pg_client
1573 .execute(&format!("DROP PUBLICATION {source_name};"), &[])
1574 .await
1575 .unwrap();
1576 let _ = pg_client
1577 .execute(&format!("DROP TABLE {table_name};"), &[])
1578 .await
1579 .unwrap();
1580 });
1581 f
1582 },
1583 )
1584}
1585
1586pub async fn wait_for_pg_table_population(mz_client: &Client, view_name: &str, source_rows: i64) {
1587 let current_isolation = mz_client
1588 .query_one("SHOW transaction_isolation", &[])
1589 .await
1590 .unwrap()
1591 .get::<_, String>(0);
1592 mz_client
1593 .batch_execute("SET transaction_isolation = SERIALIZABLE")
1594 .await
1595 .unwrap();
1596 Retry::default()
1597 .retry_async(|_| async move {
1598 let rows = mz_client
1599 .query_one(&format!("SELECT COUNT(*) FROM {view_name};"), &[])
1600 .await
1601 .unwrap()
1602 .get::<_, i64>(0);
1603 if rows == source_rows {
1604 Ok(())
1605 } else {
1606 Err(format!(
1607 "Waiting for {source_rows} row to be ingested. Currently at {rows}."
1608 ))
1609 }
1610 })
1611 .await
1612 .unwrap();
1613 mz_client
1614 .batch_execute(&format!(
1615 "SET transaction_isolation = '{current_isolation}'"
1616 ))
1617 .await
1618 .unwrap();
1619}
1620
1621pub fn auth_with_ws(
1623 ws: &mut WebSocket<MaybeTlsStream<TcpStream>>,
1624 mut options: BTreeMap<String, String>,
1625) -> Result<Vec<WebSocketResponse>, anyhow::Error> {
1626 if !options.contains_key("welcome_message") {
1627 options.insert("welcome_message".into(), "off".into());
1628 }
1629 auth_with_ws_impl(
1630 ws,
1631 Message::Text(
1632 serde_json::to_string(&WebSocketAuth::Basic {
1633 user: "materialize".into(),
1634 password: "".into(),
1635 options,
1636 })
1637 .unwrap()
1638 .into(),
1639 ),
1640 )
1641}
1642
1643pub fn auth_with_ws_impl(
1644 ws: &mut WebSocket<MaybeTlsStream<TcpStream>>,
1645 auth_message: Message,
1646) -> Result<Vec<WebSocketResponse>, anyhow::Error> {
1647 ws.send(auth_message)?;
1648
1649 let mut msgs = Vec::new();
1651 loop {
1652 let resp = ws.read()?;
1653 match resp {
1654 Message::Text(msg) => {
1655 let msg: WebSocketResponse = serde_json::from_str(&msg).unwrap();
1656 match msg {
1657 WebSocketResponse::ReadyForQuery(_) => break,
1658 msg => {
1659 msgs.push(msg);
1660 }
1661 }
1662 }
1663 Message::Ping(_) => continue,
1664 Message::Close(None) => return Err(anyhow!("ws closed after auth")),
1665 Message::Close(Some(close_frame)) => {
1666 return Err(anyhow!("ws closed after auth").context(close_frame));
1667 }
1668 _ => panic!("unexpected response: {:?}", resp),
1669 }
1670 }
1671 Ok(msgs)
1672}
1673
1674pub fn make_header<H: Header>(h: H) -> HeaderMap {
1675 let mut map = HeaderMap::new();
1676 map.typed_insert(h);
1677 map
1678}
1679
1680pub fn make_pg_tls<F>(configure: F) -> MakeTlsConnector
1681where
1682 F: FnOnce(&mut SslConnectorBuilder) -> Result<(), ErrorStack>,
1683{
1684 let mut connector_builder = SslConnector::builder(SslMethod::tls()).unwrap();
1685 let options = connector_builder.options() | SslOptions::NO_TLSV1_3;
1706 connector_builder.set_options(options);
1707 configure(&mut connector_builder).unwrap();
1708 MakeTlsConnector::new(connector_builder.build())
1709}
1710
/// A test certificate authority whose files live in a temporary directory.
pub struct Ca {
    /// Temporary directory holding `ca.crt` and any certificates and keys this
    /// CA issues via `request_cert`. Presumably deleted when dropped, per
    /// tempfile's `TempDir` — confirm.
    pub dir: TempDir,
    /// The CA's subject name.
    pub name: X509Name,
    /// The CA certificate: self-signed for a root, or signed by the parent CA
    /// for an intermediate.
    pub cert: X509,
    /// The CA's private key, used to sign the certificates it issues.
    pub pkey: PKey<Private>,
}
1718
1719impl Ca {
1720 fn make_ca(name: &str, parent: Option<&Ca>) -> Result<Ca, Box<dyn Error>> {
1721 let dir = tempfile::tempdir()?;
1722 let rsa = Rsa::generate(2048)?;
1723 let pkey = PKey::from_rsa(rsa)?;
1724 let name = {
1725 let mut builder = X509NameBuilder::new()?;
1726 builder.append_entry_by_nid(Nid::COMMONNAME, name)?;
1727 builder.build()
1728 };
1729 let cert = {
1730 let mut builder = X509::builder()?;
1731 builder.set_version(2)?;
1732 builder.set_pubkey(&pkey)?;
1733 builder.set_issuer_name(parent.map(|ca| &ca.name).unwrap_or(&name))?;
1734 builder.set_subject_name(&name)?;
1735 builder.set_not_before(&*Asn1Time::days_from_now(0)?)?;
1736 builder.set_not_after(&*Asn1Time::days_from_now(365)?)?;
1737 builder.append_extension(BasicConstraints::new().critical().ca().build()?)?;
1738 builder.sign(
1739 parent.map(|ca| &ca.pkey).unwrap_or(&pkey),
1740 MessageDigest::sha256(),
1741 )?;
1742 builder.build()
1743 };
1744 fs::write(dir.path().join("ca.crt"), cert.to_pem()?)?;
1745 Ok(Ca {
1746 dir,
1747 name,
1748 cert,
1749 pkey,
1750 })
1751 }
1752
1753 pub fn new_root(name: &str) -> Result<Ca, Box<dyn Error>> {
1755 Ca::make_ca(name, None)
1756 }
1757
1758 pub fn ca_cert_path(&self) -> PathBuf {
1760 self.dir.path().join("ca.crt")
1761 }
1762
1763 pub fn request_ca(&self, name: &str) -> Result<Ca, Box<dyn Error>> {
1765 Ca::make_ca(name, Some(self))
1766 }
1767
1768 pub fn request_client_cert(&self, name: &str) -> Result<(PathBuf, PathBuf), Box<dyn Error>> {
1773 self.request_cert(name, iter::empty())
1774 }
1775
1776 pub fn request_cert<I>(&self, name: &str, ips: I) -> Result<(PathBuf, PathBuf), Box<dyn Error>>
1779 where
1780 I: IntoIterator<Item = IpAddr>,
1781 {
1782 let rsa = Rsa::generate(2048)?;
1783 let pkey = PKey::from_rsa(rsa)?;
1784 let subject_name = {
1785 let mut builder = X509NameBuilder::new()?;
1786 builder.append_entry_by_nid(Nid::COMMONNAME, name)?;
1787 builder.build()
1788 };
1789 let cert = {
1790 let mut builder = X509::builder()?;
1791 builder.set_version(2)?;
1792 builder.set_pubkey(&pkey)?;
1793 builder.set_issuer_name(self.cert.subject_name())?;
1794 builder.set_subject_name(&subject_name)?;
1795 builder.set_not_before(&*Asn1Time::days_from_now(0)?)?;
1796 builder.set_not_after(&*Asn1Time::days_from_now(365)?)?;
1797 for ip in ips {
1798 builder.append_extension(
1799 SubjectAlternativeName::new()
1800 .ip(&ip.to_string())
1801 .build(&builder.x509v3_context(None, None))?,
1802 )?;
1803 }
1804 builder.sign(&self.pkey, MessageDigest::sha256())?;
1805 builder.build()
1806 };
1807 let cert_path = self.dir.path().join(Path::new(name).with_extension("crt"));
1808 let key_path = self.dir.path().join(Path::new(name).with_extension("key"));
1809 fs::write(&cert_path, cert.to_pem()?)?;
1810 fs::write(&key_path, pkey.private_key_to_pem_pkcs8()?)?;
1811 Ok((cert_path, key_path))
1812 }
1813}