1use std::collections::BTreeMap;
11use std::error::Error;
12use std::future::IntoFuture;
13use std::net::{IpAddr, Ipv4Addr, SocketAddr, TcpStream};
14use std::path::{Path, PathBuf};
15use std::pin::Pin;
16use std::str::FromStr;
17use std::sync::Arc;
18use std::sync::LazyLock;
19use std::time::Duration;
20use std::{env, fs, iter};
21
22use anyhow::anyhow;
23use futures::Future;
24use futures::future::{BoxFuture, LocalBoxFuture};
25use headers::{Header, HeaderMapExt};
26use http::Uri;
27use hyper::http::header::HeaderMap;
28use maplit::btreemap;
29use mz_adapter::TimestampExplanation;
30use mz_adapter_types::bootstrap_builtin_cluster_config::{
31 ANALYTICS_CLUSTER_DEFAULT_REPLICATION_FACTOR, BootstrapBuiltinClusterConfig,
32 CATALOG_SERVER_CLUSTER_DEFAULT_REPLICATION_FACTOR, PROBE_CLUSTER_DEFAULT_REPLICATION_FACTOR,
33 SUPPORT_CLUSTER_DEFAULT_REPLICATION_FACTOR, SYSTEM_CLUSTER_DEFAULT_REPLICATION_FACTOR,
34};
35
36use mz_auth::password::Password;
37use mz_catalog::config::ClusterReplicaSizeMap;
38use mz_controller::ControllerConfig;
39use mz_dyncfg::ConfigUpdates;
40use mz_license_keys::ValidatedLicenseKey;
41use mz_orchestrator_process::{ProcessOrchestrator, ProcessOrchestratorConfig};
42use mz_orchestrator_tracing::{TracingCliArgs, TracingOrchestrator};
43use mz_ore::metrics::MetricsRegistry;
44use mz_ore::now::{EpochMillis, NowFn, SYSTEM_TIME};
45use mz_ore::retry::Retry;
46use mz_ore::task;
47use mz_ore::tracing::{
48 OpenTelemetryConfig, StderrLogConfig, StderrLogFormat, TracingConfig, TracingHandle,
49};
50use mz_persist_client::PersistLocation;
51use mz_persist_client::cache::PersistClientCache;
52use mz_persist_client::cfg::{CONSENSUS_CONNECTION_POOL_MAX_SIZE, PersistConfig};
53use mz_persist_client::rpc::PersistGrpcPubSubServer;
54use mz_secrets::SecretsController;
55use mz_server_core::listeners::{
56 AllowedRoles, AuthenticatorKind, BaseListenerConfig, HttpRoutesEnabled,
57};
58use mz_server_core::{ReloadTrigger, TlsCertConfig};
59use mz_sql::catalog::EnvironmentId;
60use mz_storage_types::connections::ConnectionContext;
61use mz_tracing::CloneableEnvFilter;
62use openssl::asn1::Asn1Time;
63use openssl::error::ErrorStack;
64use openssl::hash::MessageDigest;
65use openssl::nid::Nid;
66use openssl::pkey::{PKey, Private};
67use openssl::rsa::Rsa;
68use openssl::ssl::{SslConnector, SslConnectorBuilder, SslMethod, SslOptions};
69use openssl::x509::extension::{BasicConstraints, SubjectAlternativeName};
70use openssl::x509::{X509, X509Name, X509NameBuilder};
71use postgres::error::DbError;
72use postgres::tls::{MakeTlsConnect, TlsConnect};
73use postgres::types::{FromSql, Type};
74use postgres::{NoTls, Socket};
75use postgres_openssl::MakeTlsConnector;
76use tempfile::TempDir;
77use tokio::net::TcpListener;
78use tokio::runtime::Runtime;
79use tokio_postgres::config::{Host, SslMode};
80use tokio_postgres::{AsyncMessage, Client};
81use tokio_stream::wrappers::TcpListenerStream;
82use tower_http::cors::AllowOrigin;
83use tracing::Level;
84use tracing_capture::SharedStorage;
85use tracing_subscriber::EnvFilter;
86use tungstenite::stream::MaybeTlsStream;
87use tungstenite::{Message, WebSocket};
88
89use crate::{
90 CatalogConfig, FronteggAuthenticator, HttpListenerConfig, ListenersConfig, SqlListenerConfig,
91 WebSocketAuth, WebSocketResponse,
92};
93
/// Kafka bootstrap addresses used by tests.
///
/// Taken from the `KAFKA_ADDRS` environment variable when it is set (and valid Unicode);
/// otherwise falls back to `localhost:9092`. Resolved once, on first access.
pub static KAFKA_ADDRS: LazyLock<String> = LazyLock::new(|| match env::var("KAFKA_ADDRS") {
    Ok(addrs) => addrs,
    Err(_) => String::from("localhost:9092"),
});
96
/// Builder-style configuration for starting an environmentd instance in tests.
///
/// Construct with [`Default::default`], adjust with the `with_*` methods, then call
/// [`TestHarness::start`] (or one of its variants).
#[derive(Clone)]
pub struct TestHarness {
    // Directory for persist blob storage and secrets; a temporary directory is created
    // when `None`.
    data_directory: Option<PathBuf>,
    // TLS certificate/key paths handed to the server; set by `with_tls`.
    tls: Option<TlsCertConfig>,
    // Frontegg authenticator handed to the server; set by `with_frontegg_auth`.
    frontegg: Option<FronteggAuthenticator>,
    // Password for `mz_system` logins, set by the password/SASL/OIDC builders.
    external_login_password_mz_system: Option<Password>,
    // SQL and HTTP listeners to bind, keyed by listener name ("external", "internal", ...).
    listeners_config: ListenersConfig,
    unsafe_mode: bool,
    // Set by `workers`. NOTE(review): not read by `Listeners::serve_with_trigger`; confirm
    // whether it is used anywhere.
    workers: usize,
    // Clock handed to the controller and the server.
    now: NowFn,
    // Suffix for the per-server metadata schemas (`consensus_{seed}`, `tsoracle_{seed}`).
    seed: u32,
    storage_usage_collection_interval: Duration,
    storage_usage_retention_period: Option<Duration>,
    // Bootstrap size and replication factor for the default cluster.
    default_cluster_replica_size: String,
    default_cluster_replication_factor: u32,
    // Bootstrap size and replication factor for each builtin cluster.
    builtin_system_cluster_config: BootstrapBuiltinClusterConfig,
    builtin_catalog_server_cluster_config: BootstrapBuiltinClusterConfig,
    builtin_probe_cluster_config: BootstrapBuiltinClusterConfig,
    builtin_support_cluster_config: BootstrapBuiltinClusterConfig,
    builtin_analytics_cluster_config: BootstrapBuiltinClusterConfig,

    // Forwarded to the process orchestrator's `propagate_crashes` option.
    propagate_crashes: bool,
    // When true, tracing is configured with JSON stderr logs and a fake OpenTelemetry
    // endpoint.
    enable_tracing: bool,
    // Tracing arguments handed to the `TracingOrchestrator` wrapping the process
    // orchestrator.
    orchestrator_tracing_cli_args: TracingCliArgs,
    bootstrap_role: Option<String>,
    deploy_generation: u64,
    // System parameter defaults handed to the server (e.g. `log_filter`, OIDC settings).
    system_parameter_defaults: BTreeMap<String, String>,
    internal_console_redirect_url: Option<String>,
    // Registry to register metrics in; a fresh registry is created when `None`.
    metrics_registry: Option<MetricsRegistry>,
    // Build version reported to persist.
    code_version: semver::Version,
    // Tracing capture storage; only used when `enable_tracing` is true.
    capture: Option<SharedStorage>,
    pub environment_id: EnvironmentId,
}
133
134impl Default for TestHarness {
135 fn default() -> TestHarness {
136 TestHarness {
137 data_directory: None,
138 tls: None,
139 frontegg: None,
140 external_login_password_mz_system: None,
141 listeners_config: ListenersConfig {
142 sql: btreemap![
143 "external".to_owned() => SqlListenerConfig {
144 addr: SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), 0),
145 authenticator_kind: AuthenticatorKind::None,
146 allowed_roles: AllowedRoles::Normal,
147 enable_tls: false,
148 },
149 "internal".to_owned() => SqlListenerConfig {
150 addr: SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), 0),
151 authenticator_kind: AuthenticatorKind::None,
152 allowed_roles: AllowedRoles::NormalAndInternal,
153 enable_tls: false,
154 },
155 ],
156 http: btreemap![
157 "external".to_owned() => HttpListenerConfig {
158 base: BaseListenerConfig {
159 addr: SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), 0),
160 authenticator_kind: AuthenticatorKind::None,
161 allowed_roles: AllowedRoles::Normal,
162 enable_tls: false,
163 },
164 routes: HttpRoutesEnabled{
165 base: true,
166 webhook: true,
167 internal: false,
168 metrics: false,
169 profiling: false,
170 mcp_agent: false,
171 mcp_developer: false,
172 console_config: true,
173 },
174 },
175 "internal".to_owned() => HttpListenerConfig {
176 base: BaseListenerConfig {
177 addr: SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), 0),
178 authenticator_kind: AuthenticatorKind::None,
179 allowed_roles: AllowedRoles::NormalAndInternal,
180 enable_tls: false,
181 },
182 routes: HttpRoutesEnabled{
183 base: true,
184 webhook: true,
185 internal: true,
186 metrics: true,
187 profiling: true,
188 mcp_agent: false,
189 mcp_developer: false,
190 console_config: true,
191 },
192 },
193 ],
194 },
195 unsafe_mode: false,
196 workers: 1,
197 now: SYSTEM_TIME.clone(),
198 seed: rand::random(),
199 storage_usage_collection_interval: Duration::from_secs(3600),
200 storage_usage_retention_period: None,
201 default_cluster_replica_size: "scale=1,workers=1".to_string(),
202 default_cluster_replication_factor: 1,
203 builtin_system_cluster_config: BootstrapBuiltinClusterConfig {
204 size: "scale=1,workers=1".to_string(),
205 replication_factor: SYSTEM_CLUSTER_DEFAULT_REPLICATION_FACTOR,
206 },
207 builtin_catalog_server_cluster_config: BootstrapBuiltinClusterConfig {
208 size: "scale=1,workers=1".to_string(),
209 replication_factor: CATALOG_SERVER_CLUSTER_DEFAULT_REPLICATION_FACTOR,
210 },
211 builtin_probe_cluster_config: BootstrapBuiltinClusterConfig {
212 size: "scale=1,workers=1".to_string(),
213 replication_factor: PROBE_CLUSTER_DEFAULT_REPLICATION_FACTOR,
214 },
215 builtin_support_cluster_config: BootstrapBuiltinClusterConfig {
216 size: "scale=1,workers=1".to_string(),
217 replication_factor: SUPPORT_CLUSTER_DEFAULT_REPLICATION_FACTOR,
218 },
219 builtin_analytics_cluster_config: BootstrapBuiltinClusterConfig {
220 size: "scale=1,workers=1".to_string(),
221 replication_factor: ANALYTICS_CLUSTER_DEFAULT_REPLICATION_FACTOR,
222 },
223 propagate_crashes: false,
224 enable_tracing: false,
225 bootstrap_role: Some("materialize".into()),
226 deploy_generation: 0,
227 system_parameter_defaults: BTreeMap::from([(
230 "log_filter".to_string(),
231 "error".to_string(),
232 )]),
233 internal_console_redirect_url: None,
234 metrics_registry: None,
235 orchestrator_tracing_cli_args: TracingCliArgs {
236 startup_log_filter: CloneableEnvFilter::from_str("error").expect("must parse"),
237 ..Default::default()
238 },
239 code_version: crate::BUILD_INFO.semver_version(),
240 environment_id: EnvironmentId::for_tests(),
241 capture: None,
242 }
243 }
244}
245
246impl TestHarness {
247 pub async fn start(self) -> TestServer {
251 self.try_start().await.expect("Failed to start test Server")
252 }
253
254 pub async fn start_with_trigger(self, tls_reload_certs: ReloadTrigger) -> TestServer {
256 self.try_start_with_trigger(tls_reload_certs)
257 .await
258 .expect("Failed to start test Server")
259 }
260
261 pub async fn try_start(self) -> Result<TestServer, anyhow::Error> {
263 self.try_start_with_trigger(mz_server_core::cert_reload_never_reload())
264 .await
265 }
266
267 pub async fn try_start_with_trigger(
269 self,
270 tls_reload_certs: ReloadTrigger,
271 ) -> Result<TestServer, anyhow::Error> {
272 let listeners = Listeners::new(&self).await?;
273 listeners.serve_with_trigger(self, tls_reload_certs).await
274 }
275
276 pub fn start_blocking(self) -> TestServerWithRuntime {
278 let runtime = tokio::runtime::Builder::new_multi_thread()
279 .enable_all()
280 .thread_stack_size(mz_ore::stack::STACK_SIZE)
281 .build()
282 .expect("failed to spawn runtime for test");
283 let runtime = Arc::new(runtime);
284 let server = runtime.block_on(self.start());
285 TestServerWithRuntime { runtime, server }
286 }
287
288 pub fn data_directory(mut self, data_directory: impl Into<PathBuf>) -> Self {
289 self.data_directory = Some(data_directory.into());
290 self
291 }
292
293 pub fn with_tls(mut self, cert_path: impl Into<PathBuf>, key_path: impl Into<PathBuf>) -> Self {
294 self.tls = Some(TlsCertConfig {
295 cert: cert_path.into(),
296 key: key_path.into(),
297 });
298 for (_, listener) in &mut self.listeners_config.sql {
299 listener.enable_tls = true;
300 }
301 for (_, listener) in &mut self.listeners_config.http {
302 listener.base.enable_tls = true;
303 }
304 self
305 }
306
307 pub fn unsafe_mode(mut self) -> Self {
308 self.unsafe_mode = true;
309 self
310 }
311
312 pub fn workers(mut self, workers: usize) -> Self {
313 self.workers = workers;
314 self
315 }
316
317 pub fn with_frontegg_auth(mut self, frontegg: &FronteggAuthenticator) -> Self {
318 self.frontegg = Some(frontegg.clone());
319 let enable_tls = self.tls.is_some();
320 self.listeners_config = ListenersConfig {
321 sql: btreemap! {
322 "external".to_owned() => SqlListenerConfig {
323 addr: SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), 0),
324 authenticator_kind: AuthenticatorKind::Frontegg,
325 allowed_roles: AllowedRoles::Normal,
326 enable_tls,
327 },
328 "internal".to_owned() => SqlListenerConfig {
329 addr: SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), 0),
330 authenticator_kind: AuthenticatorKind::None,
331 allowed_roles: AllowedRoles::NormalAndInternal,
332 enable_tls: false,
333 },
334 },
335 http: btreemap! {
336 "external".to_owned() => HttpListenerConfig {
337 base: BaseListenerConfig {
338 addr: SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), 0),
339 authenticator_kind: AuthenticatorKind::Frontegg,
340 allowed_roles: AllowedRoles::Normal,
341 enable_tls,
342 },
343 routes: HttpRoutesEnabled{
344 base: true,
345 webhook: true,
346 internal: false,
347 metrics: false,
348 profiling: false,
349 mcp_agent: false,
350 mcp_developer: false,
351 console_config: true,
352 },
353 },
354 "internal".to_owned() => HttpListenerConfig {
355 base: BaseListenerConfig {
356 addr: SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), 0),
357 authenticator_kind: AuthenticatorKind::None,
358 allowed_roles: AllowedRoles::NormalAndInternal,
359 enable_tls: false,
360 },
361 routes: HttpRoutesEnabled{
362 base: true,
363 webhook: true,
364 internal: true,
365 metrics: true,
366 profiling: true,
367 mcp_agent: false,
368 mcp_developer: false,
369 console_config: true,
370 },
371 },
372 },
373 };
374 self
375 }
376
377 pub fn with_oidc_auth(
378 mut self,
379 issuer: Option<String>,
380 authentication_claim: Option<String>,
381 expected_audiences: Option<Vec<String>>,
382 external_login_password_mz_system: Option<Password>,
383 ) -> Self {
384 let enable_tls = self.tls.is_some();
385 self.listeners_config = ListenersConfig {
386 sql: btreemap! {
387 "external".to_owned() => SqlListenerConfig {
388 addr: SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), 0),
389 authenticator_kind: AuthenticatorKind::Oidc,
390 allowed_roles: AllowedRoles::NormalAndInternal,
391 enable_tls,
392 },
393 "internal".to_owned() => SqlListenerConfig {
394 addr: SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), 0),
395 authenticator_kind: AuthenticatorKind::None,
396 allowed_roles: AllowedRoles::NormalAndInternal,
397 enable_tls: false,
398 },
399 },
400 http: btreemap! {
401 "external".to_owned() => HttpListenerConfig {
402 base: BaseListenerConfig {
403 addr: SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), 0),
404 authenticator_kind: AuthenticatorKind::Oidc,
405 allowed_roles: AllowedRoles::NormalAndInternal,
406 enable_tls,
407 },
408 routes: HttpRoutesEnabled{
409 base: true,
410 webhook: true,
411 internal: false,
412 metrics: false,
413 profiling: false,
414 mcp_agent: false,
415 mcp_developer: false,
416 console_config: true,
417 },
418 },
419 "internal".to_owned() => HttpListenerConfig {
420 base: BaseListenerConfig {
421 addr: SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), 0),
422 authenticator_kind: AuthenticatorKind::None,
423 allowed_roles: AllowedRoles::NormalAndInternal,
424 enable_tls: false,
425 },
426 routes: HttpRoutesEnabled{
427 base: true,
428 webhook: true,
429 internal: true,
430 metrics: true,
431 profiling: true,
432 mcp_agent: false,
433 mcp_developer: false,
434 console_config: true,
435 },
436 },
437 },
438 };
439
440 if let Some(issuer) = issuer {
441 self.system_parameter_defaults
442 .insert("oidc_issuer".to_string(), issuer);
443 }
444
445 if let Some(authentication_claim) = authentication_claim {
446 self.system_parameter_defaults.insert(
447 "oidc_authentication_claim".to_string(),
448 authentication_claim,
449 );
450 }
451
452 if let Some(expected_audiences) = expected_audiences {
453 self.system_parameter_defaults.insert(
454 "oidc_audience".to_string(),
455 serde_json::to_string(&expected_audiences).unwrap(),
456 );
457 }
458
459 if let Some(external_login_password_mz_system) = external_login_password_mz_system {
460 self.external_login_password_mz_system = Some(external_login_password_mz_system);
461 self.system_parameter_defaults
462 .insert("enable_password_auth".to_string(), "true".to_string());
463 }
464
465 self
466 }
467
468 pub fn with_password_auth(mut self, mz_system_password: Password) -> Self {
469 self.external_login_password_mz_system = Some(mz_system_password);
470 let enable_tls = self.tls.is_some();
471 self.listeners_config = ListenersConfig {
472 sql: btreemap! {
473 "external".to_owned() => SqlListenerConfig {
474 addr: SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), 0),
475 authenticator_kind: AuthenticatorKind::Password,
476 allowed_roles: AllowedRoles::NormalAndInternal,
477 enable_tls,
478 },
479 },
480 http: btreemap! {
481 "external".to_owned() => HttpListenerConfig {
482 base: BaseListenerConfig {
483 addr: SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), 0),
484 authenticator_kind: AuthenticatorKind::Password,
485 allowed_roles: AllowedRoles::NormalAndInternal,
486 enable_tls,
487 },
488 routes: HttpRoutesEnabled{
489 base: true,
490 webhook: true,
491 internal: true,
492 metrics: false,
493 profiling: true,
494 mcp_agent: false,
495 mcp_developer: false,
496 console_config: true,
497 },
498 },
499 "metrics".to_owned() => HttpListenerConfig {
500 base: BaseListenerConfig {
501 addr: SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), 0),
502 authenticator_kind: AuthenticatorKind::None,
503 allowed_roles: AllowedRoles::NormalAndInternal,
504 enable_tls: false,
505 },
506 routes: HttpRoutesEnabled{
507 base: false,
508 webhook: false,
509 internal: false,
510 metrics: true,
511 profiling: false,
512 mcp_agent: false,
513 mcp_developer: false,
514 console_config: true,
515 },
516 },
517 },
518 };
519 self
520 }
521
522 pub fn with_sasl_scram_auth(mut self, mz_system_password: Password) -> Self {
523 self.external_login_password_mz_system = Some(mz_system_password);
524 let enable_tls = self.tls.is_some();
525 self.listeners_config = ListenersConfig {
526 sql: btreemap! {
527 "external".to_owned() => SqlListenerConfig {
528 addr: SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), 0),
529 authenticator_kind: AuthenticatorKind::Sasl,
530 allowed_roles: AllowedRoles::NormalAndInternal,
531 enable_tls,
532 },
533 },
534 http: btreemap! {
535 "external".to_owned() => HttpListenerConfig {
536 base: BaseListenerConfig {
537 addr: SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), 0),
538 authenticator_kind: AuthenticatorKind::Password,
539 allowed_roles: AllowedRoles::NormalAndInternal,
540 enable_tls,
541 },
542 routes: HttpRoutesEnabled{
543 base: true,
544 webhook: true,
545 internal: true,
546 metrics: false,
547 profiling: true,
548 mcp_agent: false,
549 mcp_developer: false,
550 console_config: true,
551 },
552 },
553 "metrics".to_owned() => HttpListenerConfig {
554 base: BaseListenerConfig {
555 addr: SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), 0),
556 authenticator_kind: AuthenticatorKind::None,
557 allowed_roles: AllowedRoles::NormalAndInternal,
558 enable_tls: false,
559 },
560 routes: HttpRoutesEnabled{
561 base: false,
562 webhook: false,
563 internal: false,
564 metrics: true,
565 profiling: false,
566 mcp_agent: false,
567 mcp_developer: false,
568 console_config: true,
569 },
570 },
571 },
572 };
573 self
574 }
575
576 pub fn with_now(mut self, now: NowFn) -> Self {
577 self.now = now;
578 self
579 }
580
581 pub fn with_storage_usage_collection_interval(
582 mut self,
583 storage_usage_collection_interval: Duration,
584 ) -> Self {
585 self.storage_usage_collection_interval = storage_usage_collection_interval;
586 self
587 }
588
589 pub fn with_storage_usage_retention_period(
590 mut self,
591 storage_usage_retention_period: Duration,
592 ) -> Self {
593 self.storage_usage_retention_period = Some(storage_usage_retention_period);
594 self
595 }
596
597 pub fn with_default_cluster_replica_size(
598 mut self,
599 default_cluster_replica_size: String,
600 ) -> Self {
601 self.default_cluster_replica_size = default_cluster_replica_size;
602 self
603 }
604
605 pub fn with_builtin_system_cluster_replica_size(
606 mut self,
607 builtin_system_cluster_replica_size: String,
608 ) -> Self {
609 self.builtin_system_cluster_config.size = builtin_system_cluster_replica_size;
610 self
611 }
612
613 pub fn with_builtin_system_cluster_replication_factor(
614 mut self,
615 builtin_system_cluster_replication_factor: u32,
616 ) -> Self {
617 self.builtin_system_cluster_config.replication_factor =
618 builtin_system_cluster_replication_factor;
619 self
620 }
621
622 pub fn with_builtin_catalog_server_cluster_replica_size(
623 mut self,
624 builtin_catalog_server_cluster_replica_size: String,
625 ) -> Self {
626 self.builtin_catalog_server_cluster_config.size =
627 builtin_catalog_server_cluster_replica_size;
628 self
629 }
630
631 pub fn with_propagate_crashes(mut self, propagate_crashes: bool) -> Self {
632 self.propagate_crashes = propagate_crashes;
633 self
634 }
635
636 pub fn with_enable_tracing(mut self, enable_tracing: bool) -> Self {
637 self.enable_tracing = enable_tracing;
638 self
639 }
640
641 pub fn with_bootstrap_role(mut self, bootstrap_role: Option<String>) -> Self {
642 self.bootstrap_role = bootstrap_role;
643 self
644 }
645
646 pub fn with_deploy_generation(mut self, deploy_generation: u64) -> Self {
647 self.deploy_generation = deploy_generation;
648 self
649 }
650
651 pub fn with_system_parameter_default(mut self, param: String, value: String) -> Self {
652 self.system_parameter_defaults.insert(param, value);
653 self
654 }
655
656 pub fn with_mcp_routes(mut self, agent: bool, developer: bool) -> Self {
657 for config in self.listeners_config.http.values_mut() {
658 config.routes.mcp_agent = agent;
659 config.routes.mcp_developer = developer;
660 }
661 self
662 }
663
664 pub fn with_internal_console_redirect_url(
665 mut self,
666 internal_console_redirect_url: Option<String>,
667 ) -> Self {
668 self.internal_console_redirect_url = internal_console_redirect_url;
669 self
670 }
671
672 pub fn with_metrics_registry(mut self, registry: MetricsRegistry) -> Self {
673 self.metrics_registry = Some(registry);
674 self
675 }
676
677 pub fn with_code_version(mut self, version: semver::Version) -> Self {
678 self.code_version = version;
679 self
680 }
681
682 pub fn with_capture(mut self, storage: SharedStorage) -> Self {
683 self.capture = Some(storage);
684 self
685 }
686}
687
/// Listeners bound by [`Listeners::new`] that are not serving yet.
///
/// Consumed by [`Listeners::serve`] or [`Listeners::serve_with_trigger`].
pub struct Listeners {
    /// The underlying bound listeners.
    pub inner: crate::Listeners,
}
691
692impl Listeners {
693 pub async fn new(config: &TestHarness) -> Result<Listeners, anyhow::Error> {
694 let inner = crate::Listeners::bind(config.listeners_config.clone()).await?;
695 Ok(Listeners { inner })
696 }
697
698 pub async fn serve(self, config: TestHarness) -> Result<TestServer, anyhow::Error> {
699 self.serve_with_trigger(config, mz_server_core::cert_reload_never_reload())
700 .await
701 }
702
703 pub async fn serve_with_trigger(
704 self,
705 config: TestHarness,
706 tls_reload_certs: ReloadTrigger,
707 ) -> Result<TestServer, anyhow::Error> {
708 let (data_directory, temp_dir) = match config.data_directory {
709 None => {
710 let temp_dir = tempfile::tempdir()?;
715 (temp_dir.path().to_path_buf(), Some(temp_dir))
716 }
717 Some(data_directory) => (data_directory, None),
718 };
719 let scratch_dir = tempfile::tempdir()?;
720 let (consensus_uri, timestamp_oracle_url) = {
721 let seed = config.seed;
722 let cockroach_url = env::var("METADATA_BACKEND_URL")
723 .map_err(|_| anyhow!("METADATA_BACKEND_URL environment variable is not set"))?;
724 let (client, conn) = tokio_postgres::connect(&cockroach_url, NoTls).await?;
725 mz_ore::task::spawn(|| "startup-postgres-conn", async move {
726 if let Err(err) = conn.await {
727 panic!("connection error: {}", err);
728 };
729 });
730 client
731 .batch_execute(&format!(
732 "CREATE SCHEMA IF NOT EXISTS consensus_{seed};
733 CREATE SCHEMA IF NOT EXISTS tsoracle_{seed};"
734 ))
735 .await?;
736 (
737 format!("{cockroach_url}?options=--search_path=consensus_{seed}")
738 .parse()
739 .expect("invalid consensus URI"),
740 format!("{cockroach_url}?options=--search_path=tsoracle_{seed}")
741 .parse()
742 .expect("invalid timestamp oracle URI"),
743 )
744 };
745 let metrics_registry = config.metrics_registry.unwrap_or_else(MetricsRegistry::new);
746 let orchestrator = ProcessOrchestrator::new(ProcessOrchestratorConfig {
747 image_dir: env::current_exe()?
748 .parent()
749 .unwrap()
750 .parent()
751 .unwrap()
752 .to_path_buf(),
753 suppress_output: false,
754 environment_id: config.environment_id.to_string(),
755 secrets_dir: data_directory.join("secrets"),
756 command_wrapper: vec![],
757 propagate_crashes: config.propagate_crashes,
758 tcp_proxy: None,
759 scratch_directory: scratch_dir.path().to_path_buf(),
760 })
761 .await?;
762 let orchestrator = Arc::new(orchestrator);
763 let persist_now = SYSTEM_TIME.clone();
766 let dyncfgs = mz_dyncfgs::all_dyncfgs();
767
768 let mut updates = ConfigUpdates::default();
769 updates.add(&CONSENSUS_CONNECTION_POOL_MAX_SIZE, 1);
772 updates.apply(&dyncfgs);
773
774 let mut persist_cfg = PersistConfig::new(&crate::BUILD_INFO, persist_now.clone(), dyncfgs);
775 persist_cfg.build_version = config.code_version;
776 persist_cfg.set_rollup_threshold(5);
778
779 let persist_pubsub_server = PersistGrpcPubSubServer::new(&persist_cfg, &metrics_registry);
780 let persist_pubsub_client = persist_pubsub_server.new_same_process_connection();
781 let persist_pubsub_tcp_listener =
782 TcpListener::bind(SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), 0))
783 .await
784 .expect("pubsub addr binding");
785 let persist_pubsub_server_port = persist_pubsub_tcp_listener
786 .local_addr()
787 .expect("pubsub addr has local addr")
788 .port();
789
790 mz_ore::task::spawn(|| "persist_pubsub_server", async move {
792 persist_pubsub_server
793 .serve_with_stream(TcpListenerStream::new(persist_pubsub_tcp_listener))
794 .await
795 .expect("success")
796 });
797 let persist_clients =
798 PersistClientCache::new(persist_cfg, &metrics_registry, |_, _| persist_pubsub_client);
799 let persist_clients = Arc::new(persist_clients);
800
801 let secrets_controller = Arc::clone(&orchestrator);
802 let connection_context = ConnectionContext::for_tests(orchestrator.reader());
803 let orchestrator = Arc::new(TracingOrchestrator::new(
804 orchestrator,
805 config.orchestrator_tracing_cli_args,
806 ));
807 let tracing_handle = if config.enable_tracing {
808 let config = TracingConfig::<fn(&tracing::Metadata) -> sentry_tracing::EventFilter> {
809 service_name: "environmentd",
810 stderr_log: StderrLogConfig {
811 format: StderrLogFormat::Json,
812 filter: EnvFilter::default(),
813 },
814 opentelemetry: Some(OpenTelemetryConfig {
815 endpoint: "http://fake_address_for_testing:8080".to_string(),
816 headers: http::HeaderMap::new(),
817 filter: EnvFilter::default().add_directive(Level::DEBUG.into()),
818 resource: opentelemetry_sdk::resource::Resource::builder().build(),
819 max_batch_queue_size: 2048,
820 max_export_batch_size: 512,
821 max_concurrent_exports: 1,
822 batch_scheduled_delay: Duration::from_millis(5000),
823 max_export_timeout: Duration::from_secs(30),
824 }),
825 tokio_console: None,
826 sentry: None,
827 build_version: crate::BUILD_INFO.version,
828 build_sha: crate::BUILD_INFO.sha,
829 registry: metrics_registry.clone(),
830 capture: config.capture,
831 };
832 mz_ore::tracing::configure(config).await?
833 } else {
834 TracingHandle::disabled()
835 };
836 let host_name = format!(
837 "localhost:{}",
838 self.inner.http["external"].handle.local_addr.port()
839 );
840 let catalog_config = CatalogConfig {
841 persist_clients: Arc::clone(&persist_clients),
842 metrics: Arc::new(mz_catalog::durable::Metrics::new(&MetricsRegistry::new())),
843 };
844
845 let inner = self
846 .inner
847 .serve(crate::Config {
848 catalog_config,
849 timestamp_oracle_url: Some(timestamp_oracle_url),
850 controller: ControllerConfig {
851 build_info: &crate::BUILD_INFO,
852 orchestrator,
853 clusterd_image: "clusterd".into(),
854 init_container_image: None,
855 deploy_generation: config.deploy_generation,
856 persist_location: PersistLocation {
857 blob_uri: format!("file://{}/persist/blob", data_directory.display())
858 .parse()
859 .expect("invalid blob URI"),
860 consensus_uri,
861 },
862 persist_clients,
863 now: config.now.clone(),
864 metrics_registry: metrics_registry.clone(),
865 persist_pubsub_url: format!("http://localhost:{}", persist_pubsub_server_port),
866 secrets_args: mz_service::secrets::SecretsReaderCliArgs {
867 secrets_reader: mz_service::secrets::SecretsControllerKind::LocalFile,
868 secrets_reader_local_file_dir: Some(data_directory.join("secrets")),
869 secrets_reader_kubernetes_context: None,
870 secrets_reader_aws_prefix: None,
871 secrets_reader_name_prefix: None,
872 },
873 connection_context,
874 replica_http_locator: Default::default(),
875 },
876 secrets_controller,
877 cloud_resource_controller: None,
878 tls: config.tls,
879 frontegg: config.frontegg,
880 unsafe_mode: config.unsafe_mode,
881 all_features: false,
882 metrics_registry: metrics_registry.clone(),
883 now: config.now,
884 environment_id: config.environment_id,
885 cors_allowed_origin: AllowOrigin::list([]),
886 cors_allowed_origin_list: Vec::new(),
887 cluster_replica_sizes: ClusterReplicaSizeMap::for_tests(),
888 bootstrap_default_cluster_replica_size: config.default_cluster_replica_size,
889 bootstrap_default_cluster_replication_factor: config
890 .default_cluster_replication_factor,
891 bootstrap_builtin_system_cluster_config: config.builtin_system_cluster_config,
892 bootstrap_builtin_catalog_server_cluster_config: config
893 .builtin_catalog_server_cluster_config,
894 bootstrap_builtin_probe_cluster_config: config.builtin_probe_cluster_config,
895 bootstrap_builtin_support_cluster_config: config.builtin_support_cluster_config,
896 bootstrap_builtin_analytics_cluster_config: config.builtin_analytics_cluster_config,
897 system_parameter_defaults: config.system_parameter_defaults,
898 availability_zones: Default::default(),
899 tracing_handle,
900 storage_usage_collection_interval: config.storage_usage_collection_interval,
901 storage_usage_retention_period: config.storage_usage_retention_period,
902 segment_api_key: None,
903 segment_client_side: false,
904 test_only_dummy_segment_client: false,
905 egress_addresses: vec![],
906 aws_account_id: None,
907 aws_privatelink_availability_zones: None,
908 launchdarkly_sdk_key: None,
909 launchdarkly_key_map: Default::default(),
910 config_sync_file_path: None,
911 config_sync_timeout: Duration::from_secs(30),
912 config_sync_loop_interval: None,
913 bootstrap_role: config.bootstrap_role,
914 http_host_name: Some(host_name),
915 internal_console_redirect_url: config.internal_console_redirect_url,
916 tls_reload_certs,
917 helm_chart_version: None,
918 license_key: ValidatedLicenseKey::for_tests(),
919 external_login_password_mz_system: config.external_login_password_mz_system,
920 force_builtin_schema_migration: None,
921 })
922 .await?;
923
924 Ok(TestServer {
925 inner,
926 metrics_registry,
927 _temp_dir: temp_dir,
928 _scratch_dir: scratch_dir,
929 })
930 }
931}
932
/// A running environmentd instance started by [`TestHarness`].
pub struct TestServer {
    /// The running server.
    pub inner: crate::Server,
    /// Metrics registry handed to the server at startup.
    pub metrics_registry: MetricsRegistry,
    /// Temporary data directory; `None` when the harness was given an explicit data
    /// directory. Held so the directory is not removed while the server is alive.
    /// Fields drop in declaration order, so this is dropped after `inner`.
    _temp_dir: Option<TempDir>,
    /// Temporary scratch directory handed to the process orchestrator; held for the same
    /// reason as `_temp_dir`.
    _scratch_dir: TempDir,
}
941
942impl TestServer {
943 pub fn connect(&self) -> ConnectBuilder<'_, postgres::NoTls, NoHandle> {
944 ConnectBuilder::new(self).no_tls()
945 }
946
947 pub async fn enable_feature_flags(&self, flags: &[&'static str]) {
948 let internal_client = self.connect().internal().await.unwrap();
949
950 for flag in flags {
951 internal_client
952 .batch_execute(&format!("ALTER SYSTEM SET {} = true;", flag))
953 .await
954 .unwrap();
955 }
956 }
957
958 pub async fn disable_feature_flags(&self, flags: &[&'static str]) {
959 let internal_client = self.connect().internal().await.unwrap();
960
961 for flag in flags {
962 internal_client
963 .batch_execute(&format!("ALTER SYSTEM SET {} = false;", flag))
964 .await
965 .unwrap();
966 }
967 }
968
969 pub fn ws_addr(&self) -> Uri {
970 format!(
971 "ws://{}/api/experimental/sql",
972 self.inner.http_listener_handles["external"].local_addr
973 )
974 .parse()
975 .unwrap()
976 }
977
978 pub fn internal_ws_addr(&self) -> Uri {
979 format!(
980 "ws://{}/api/experimental/sql",
981 self.inner.http_listener_handles["internal"].local_addr
982 )
983 .parse()
984 .unwrap()
985 }
986
987 pub fn http_local_addr(&self) -> SocketAddr {
988 self.inner.http_listener_handles["external"].local_addr
989 }
990
991 pub fn internal_http_local_addr(&self) -> SocketAddr {
992 self.inner.http_listener_handles["internal"].local_addr
993 }
994
995 pub fn sql_local_addr(&self) -> SocketAddr {
996 self.inner.sql_listener_handles["external"].local_addr
997 }
998
999 pub fn internal_sql_local_addr(&self) -> SocketAddr {
1000 self.inner.sql_listener_handles["internal"].local_addr
1001 }
1002}
1003
/// Builder for a SQL connection to a [`TestServer`].
///
/// `T` is the TLS connector type and `H` a marker type parameter (`NoHandle` by default).
pub struct ConnectBuilder<'s, T, H> {
    /// Server to connect to.
    server: &'s TestServer,

    /// Connection parameters. The port is tracked separately in `port`.
    pg_config: tokio_postgres::Config,
    /// SQL port to connect to: the external listener's port by default, the internal
    /// listener's after `internal()`.
    port: u16,
    /// TLS connector used for the connection.
    tls: T,

    /// Optional callback taking a `DbError`; presumably invoked for server notices —
    /// TODO confirm in the connect path.
    notice_callback: Option<Box<dyn FnMut(tokio_postgres::error::DbError) + Send + 'static>>,

    /// Type-level marker carried through the builder. NOTE(review): its meaning is
    /// defined outside this view.
    _with_handle: H,
}
1024
1025impl<'s> ConnectBuilder<'s, (), NoHandle> {
1026 fn new(server: &'s TestServer) -> Self {
1027 let mut pg_config = tokio_postgres::Config::new();
1028 pg_config
1029 .host(&Ipv4Addr::LOCALHOST.to_string())
1030 .user("materialize")
1031 .options("--welcome_message=off")
1032 .application_name("environmentd_test_framework");
1033
1034 ConnectBuilder {
1035 server,
1036 pg_config,
1037 port: server.sql_local_addr().port(),
1038 tls: (),
1039 notice_callback: None,
1040 _with_handle: NoHandle,
1041 }
1042 }
1043}
1044
1045impl<'s, T, H> ConnectBuilder<'s, T, H> {
1046 pub fn no_tls(self) -> ConnectBuilder<'s, postgres::NoTls, H> {
1050 ConnectBuilder {
1051 server: self.server,
1052 pg_config: self.pg_config,
1053 port: self.port,
1054 tls: postgres::NoTls,
1055 notice_callback: self.notice_callback,
1056 _with_handle: self._with_handle,
1057 }
1058 }
1059
1060 pub fn with_tls<Tls>(self, tls: Tls) -> ConnectBuilder<'s, Tls, H>
1062 where
1063 Tls: MakeTlsConnect<Socket> + Send + 'static,
1064 Tls::TlsConnect: Send,
1065 Tls::Stream: Send,
1066 <Tls::TlsConnect as TlsConnect<Socket>>::Future: Send,
1067 {
1068 ConnectBuilder {
1069 server: self.server,
1070 pg_config: self.pg_config,
1071 port: self.port,
1072 tls,
1073 notice_callback: self.notice_callback,
1074 _with_handle: self._with_handle,
1075 }
1076 }
1077
1078 pub fn with_config(mut self, pg_config: tokio_postgres::Config) -> Self {
1080 self.pg_config = pg_config;
1081 self
1082 }
1083
1084 pub fn ssl_mode(mut self, mode: SslMode) -> Self {
1086 self.pg_config.ssl_mode(mode);
1087 self
1088 }
1089
1090 pub fn user(mut self, user: &str) -> Self {
1092 self.pg_config.user(user);
1093 self
1094 }
1095
1096 pub fn password(mut self, password: &str) -> Self {
1098 self.pg_config.password(password);
1099 self
1100 }
1101
1102 pub fn application_name(mut self, application_name: &str) -> Self {
1104 self.pg_config.application_name(application_name);
1105 self
1106 }
1107
1108 pub fn dbname(mut self, dbname: &str) -> Self {
1110 self.pg_config.dbname(dbname);
1111 self
1112 }
1113
1114 pub fn options(mut self, options: &str) -> Self {
1116 self.pg_config.options(options);
1117 self
1118 }
1119
1120 pub fn internal(mut self) -> Self {
1125 self.port = self.server.internal_sql_local_addr().port();
1126 self.pg_config.user(mz_sql::session::user::SYSTEM_USER_NAME);
1127 self
1128 }
1129
1130 pub fn notice_callback(self, callback: impl FnMut(DbError) + Send + 'static) -> Self {
1132 ConnectBuilder {
1133 notice_callback: Some(Box::new(callback)),
1134 ..self
1135 }
1136 }
1137
1138 pub fn with_handle(self) -> ConnectBuilder<'s, T, WithHandle> {
1141 ConnectBuilder {
1142 server: self.server,
1143 pg_config: self.pg_config,
1144 port: self.port,
1145 tls: self.tls,
1146 notice_callback: self.notice_callback,
1147 _with_handle: WithHandle,
1148 }
1149 }
1150
1151 pub fn as_pg_config(&self) -> &tokio_postgres::Config {
1153 &self.pg_config
1154 }
1155}
1156
/// Selects what awaiting a [`ConnectBuilder`] produces on success.
///
/// Implemented by the marker types `NoHandle` (client only) and `WithHandle`
/// (client plus the `JoinHandle` of the spawned connection-driving task).
pub trait IncludeHandle: Send {
    /// Value yielded by the connect future on success.
    type Output;
    /// Builds the output from the connected client and the handle of the task
    /// that drives the connection.
    fn transform_result(
        client: tokio_postgres::Client,
        handle: mz_ore::task::JoinHandle<()>,
    ) -> Self::Output;
}
1166
1167pub struct NoHandle;
1170impl IncludeHandle for NoHandle {
1171 type Output = tokio_postgres::Client;
1172 fn transform_result(
1173 client: tokio_postgres::Client,
1174 _handle: mz_ore::task::JoinHandle<()>,
1175 ) -> Self::Output {
1176 client
1177 }
1178}
1179
1180pub struct WithHandle;
1183impl IncludeHandle for WithHandle {
1184 type Output = (tokio_postgres::Client, mz_ore::task::JoinHandle<()>);
1185 fn transform_result(
1186 client: tokio_postgres::Client,
1187 handle: mz_ore::task::JoinHandle<()>,
1188 ) -> Self::Output {
1189 (client, handle)
1190 }
1191}
1192
1193impl<'s, T, H> IntoFuture for ConnectBuilder<'s, T, H>
1194where
1195 T: MakeTlsConnect<Socket> + Send + 'static,
1196 T::TlsConnect: Send,
1197 T::Stream: Send,
1198 <T::TlsConnect as TlsConnect<Socket>>::Future: Send,
1199 H: IncludeHandle,
1200{
1201 type Output = Result<H::Output, postgres::Error>;
1202 type IntoFuture = BoxFuture<'static, Self::Output>;
1203
1204 fn into_future(mut self) -> Self::IntoFuture {
1205 Box::pin(async move {
1206 assert!(
1207 self.pg_config.get_ports().is_empty(),
1208 "specifying multiple ports is not supported"
1209 );
1210 self.pg_config.port(self.port);
1211
1212 let (client, mut conn) = self.pg_config.connect(self.tls).await?;
1213 let mut notice_callback = self.notice_callback.take();
1214
1215 let handle = task::spawn(|| "connect", async move {
1216 while let Some(msg) = std::future::poll_fn(|cx| conn.poll_message(cx)).await {
1217 match msg {
1218 Ok(AsyncMessage::Notice(notice)) => {
1219 if let Some(callback) = notice_callback.as_mut() {
1220 callback(notice);
1221 }
1222 }
1223 Ok(msg) => {
1224 tracing::debug!(?msg, "Dropping message from database");
1225 }
1226 Err(e) => {
1227 tracing::info!("connection error: {e}");
1232 break;
1233 }
1234 }
1235 }
1236 tracing::info!("connection closed");
1237 });
1238
1239 let output = H::transform_result(client, handle);
1240 Ok(output)
1241 })
1242 }
1243}
1244
/// A [`TestServer`] bundled with a Tokio runtime, exposing blocking (`postgres` crate)
/// client helpers.
///
/// NOTE(review): the runtime is presumably what keeps the server's async machinery
/// running while tests use blocking clients. Confirm where it is constructed.
pub struct TestServerWithRuntime {
    /// The wrapped server.
    server: TestServer,
    /// Runtime handed out by `runtime()`.
    runtime: Arc<Runtime>,
}
1253
1254impl TestServerWithRuntime {
1255 pub fn runtime(&self) -> &Arc<Runtime> {
1259 &self.runtime
1260 }
1261
1262 pub fn inner(&self) -> &crate::Server {
1264 &self.server.inner
1265 }
1266
1267 pub fn connect<T>(&self, tls: T) -> Result<postgres::Client, postgres::Error>
1269 where
1270 T: MakeTlsConnect<Socket> + Send + 'static,
1271 T::TlsConnect: Send,
1272 T::Stream: Send,
1273 <T::TlsConnect as TlsConnect<Socket>>::Future: Send,
1274 {
1275 self.pg_config().connect(tls)
1276 }
1277
1278 pub fn connect_internal<T>(&self, tls: T) -> Result<postgres::Client, anyhow::Error>
1280 where
1281 T: MakeTlsConnect<Socket> + Send + 'static,
1282 T::TlsConnect: Send,
1283 T::Stream: Send,
1284 <T::TlsConnect as TlsConnect<Socket>>::Future: Send,
1285 {
1286 Ok(self.pg_config_internal().connect(tls)?)
1287 }
1288
1289 pub fn enable_feature_flags(&self, flags: &[&'static str]) {
1291 let mut internal_client = self.connect_internal(postgres::NoTls).unwrap();
1292
1293 for flag in flags {
1294 internal_client
1295 .batch_execute(&format!("ALTER SYSTEM SET {} = true;", flag))
1296 .unwrap();
1297 }
1298 }
1299
1300 pub fn disable_feature_flags(&self, flags: &[&'static str]) {
1302 let mut internal_client = self.connect_internal(postgres::NoTls).unwrap();
1303
1304 for flag in flags {
1305 internal_client
1306 .batch_execute(&format!("ALTER SYSTEM SET {} = false;", flag))
1307 .unwrap();
1308 }
1309 }
1310
1311 pub fn pg_config(&self) -> postgres::Config {
1314 let local_addr = self.server.sql_local_addr();
1315 let mut config = postgres::Config::new();
1316 config
1317 .host(&Ipv4Addr::LOCALHOST.to_string())
1318 .port(local_addr.port())
1319 .user("materialize")
1320 .options("--welcome_message=off");
1321 config
1322 }
1323
1324 pub fn pg_config_internal(&self) -> postgres::Config {
1327 let local_addr = self.server.internal_sql_local_addr();
1328 let mut config = postgres::Config::new();
1329 config
1330 .host(&Ipv4Addr::LOCALHOST.to_string())
1331 .port(local_addr.port())
1332 .user("mz_system")
1333 .options("--welcome_message=off");
1334 config
1335 }
1336
1337 pub fn ws_addr(&self) -> Uri {
1338 self.server.ws_addr()
1339 }
1340
1341 pub fn internal_ws_addr(&self) -> Uri {
1342 self.server.internal_ws_addr()
1343 }
1344
1345 pub fn http_local_addr(&self) -> SocketAddr {
1346 self.server.http_local_addr()
1347 }
1348
1349 pub fn internal_http_local_addr(&self) -> SocketAddr {
1350 self.server.internal_http_local_addr()
1351 }
1352
1353 pub fn sql_local_addr(&self) -> SocketAddr {
1354 self.server.sql_local_addr()
1355 }
1356
1357 pub fn internal_sql_local_addr(&self) -> SocketAddr {
1358 self.server.internal_sql_local_addr()
1359 }
1360
1361 pub fn metrics_registry(&self) -> &MetricsRegistry {
1363 &self.server.metrics_registry
1364 }
1365}
1366
/// A timestamp read back from SQL as a plain `u64`.
///
/// Decoded from values whose wire representation is a `numeric` (see its `FromSql`
/// impl); decoding fails if the number does not fit in a `u64`.
#[derive(Debug, Clone, Copy, Eq, PartialEq, Ord, PartialOrd)]
pub struct MzTimestamp(pub u64);
1369
1370impl<'a> FromSql<'a> for MzTimestamp {
1371 fn from_sql(ty: &Type, raw: &'a [u8]) -> Result<MzTimestamp, Box<dyn Error + Sync + Send>> {
1372 let n = mz_pgrepr::Numeric::from_sql(ty, raw)?;
1373 Ok(MzTimestamp(u64::try_from(n.0.0)?))
1374 }
1375
1376 fn accepts(ty: &Type) -> bool {
1377 mz_pgrepr::Numeric::accepts(ty)
1378 }
1379}
1380
/// Test helper for extracting the server-reported [`DbError`] from a Postgres error.
pub trait PostgresErrorExt {
    /// Returns the contained `DbError`.
    ///
    /// # Panics
    ///
    /// Panics if no `DbError` is present (or, for `Result`, if the value is `Ok`).
    fn unwrap_db_error(self) -> DbError;
}
1384
1385impl PostgresErrorExt for postgres::Error {
1386 fn unwrap_db_error(self) -> DbError {
1387 match self.source().and_then(|e| e.downcast_ref::<DbError>()) {
1388 Some(e) => e.clone(),
1389 None => panic!("expected DbError, but got: {:?}", self),
1390 }
1391 }
1392}
1393
1394impl<T, E> PostgresErrorExt for Result<T, E>
1395where
1396 E: PostgresErrorExt,
1397{
1398 fn unwrap_db_error(self) -> DbError {
1399 match self {
1400 Ok(_) => panic!("expected Err(DbError), but got Ok(_)"),
1401 Err(e) => e.unwrap_db_error(),
1402 }
1403 }
1404}
1405
1406pub async fn insert_with_deterministic_timestamps(
1410 table: &'static str,
1411 values: &'static str,
1412 server: &TestServer,
1413 now: Arc<std::sync::Mutex<EpochMillis>>,
1414) -> Result<(), Box<dyn Error>> {
1415 let client_write = server.connect().await?;
1416 let client_read = server.connect().await?;
1417
1418 let mut current_ts = get_explain_timestamp(table, &client_read).await;
1419
1420 let insert_query = format!("INSERT INTO {table} VALUES {values}");
1421
1422 let write_future = client_write.execute(&insert_query, &[]);
1423 let timestamp_interval = tokio::time::interval(Duration::from_millis(1));
1424
1425 let mut write_future = std::pin::pin!(write_future);
1426 let mut timestamp_interval = std::pin::pin!(timestamp_interval);
1427
1428 loop {
1431 tokio::select! {
1432 _ = (&mut write_future) => return Ok(()),
1433 _ = timestamp_interval.tick() => {
1434 current_ts += 1;
1435 *now.lock().expect("lock poisoned") = current_ts;
1436 }
1437 };
1438 }
1439}
1440
1441pub async fn get_explain_timestamp(from_suffix: &str, client: &Client) -> EpochMillis {
1442 try_get_explain_timestamp(from_suffix, client)
1443 .await
1444 .unwrap()
1445}
1446
1447pub async fn try_get_explain_timestamp(
1448 from_suffix: &str,
1449 client: &Client,
1450) -> Result<EpochMillis, anyhow::Error> {
1451 let det = get_explain_timestamp_determination(from_suffix, client).await?;
1452 let ts = det.determination.timestamp_context.timestamp_or_default();
1453 Ok(ts.into())
1454}
1455
1456pub async fn get_explain_timestamp_determination(
1457 from_suffix: &str,
1458 client: &Client,
1459) -> Result<TimestampExplanation, anyhow::Error> {
1460 let row = client
1461 .query_one(
1462 &format!("EXPLAIN TIMESTAMP AS JSON FOR SELECT * FROM {from_suffix}"),
1463 &[],
1464 )
1465 .await?;
1466 let explain: String = row.get(0);
1467 Ok(serde_json::from_str(&explain).unwrap())
1468}
1469
1470pub async fn create_postgres_source_with_table<'a>(
1478 server: &TestServer,
1479 mz_client: &Client,
1480 table_name: &str,
1481 table_schema: &str,
1482 source_name: &str,
1483) -> (
1484 Client,
1485 impl FnOnce(&'a Client, &'a Client) -> LocalBoxFuture<'a, ()>,
1486) {
1487 server
1488 .enable_feature_flags(&["enable_create_table_from_source"])
1489 .await;
1490
1491 let postgres_url = env::var("POSTGRES_URL")
1492 .map_err(|_| anyhow!("POSTGRES_URL environment variable is not set"))
1493 .unwrap();
1494
1495 let (pg_client, connection) = tokio_postgres::connect(&postgres_url, postgres::NoTls)
1496 .await
1497 .unwrap();
1498
1499 let pg_config: tokio_postgres::Config = postgres_url.parse().unwrap();
1500 let user = pg_config.get_user().unwrap_or("postgres");
1501 let db_name = pg_config.get_dbname().unwrap_or(user);
1502 let ports = pg_config.get_ports();
1503 let port = if ports.is_empty() { 5432 } else { ports[0] };
1504 let hosts = pg_config.get_hosts();
1505 let host = if hosts.is_empty() {
1506 "localhost".to_string()
1507 } else {
1508 match &hosts[0] {
1509 Host::Tcp(host) => host.to_string(),
1510 Host::Unix(host) => host.to_str().unwrap().to_string(),
1511 }
1512 };
1513 let password = pg_config.get_password();
1514
1515 mz_ore::task::spawn(|| "postgres-source-connection", async move {
1516 if let Err(e) = connection.await {
1517 panic!("connection error: {}", e);
1518 }
1519 });
1520
1521 let _ = pg_client
1523 .execute(&format!("DROP TABLE IF EXISTS {table_name};"), &[])
1524 .await
1525 .unwrap();
1526 let _ = pg_client
1527 .execute(&format!("DROP PUBLICATION IF EXISTS {source_name};"), &[])
1528 .await
1529 .unwrap();
1530 let _ = pg_client
1531 .execute(&format!("CREATE TABLE {table_name} {table_schema};"), &[])
1532 .await
1533 .unwrap();
1534 let _ = pg_client
1535 .execute(
1536 &format!("ALTER TABLE {table_name} REPLICA IDENTITY FULL;"),
1537 &[],
1538 )
1539 .await
1540 .unwrap();
1541 let _ = pg_client
1542 .execute(
1543 &format!("CREATE PUBLICATION {source_name} FOR TABLE {table_name};"),
1544 &[],
1545 )
1546 .await
1547 .unwrap();
1548
1549 let mut connection_str = format!("HOST '{host}', PORT {port}, USER {user}, DATABASE {db_name}");
1551 if let Some(password) = password {
1552 let password = std::str::from_utf8(password).unwrap();
1553 mz_client
1554 .batch_execute(&format!("CREATE SECRET s AS '{password}'"))
1555 .await
1556 .unwrap();
1557 connection_str = format!("{connection_str}, PASSWORD SECRET s");
1558 }
1559 mz_client
1560 .batch_execute(&format!(
1561 "CREATE CONNECTION pgconn TO POSTGRES ({connection_str})"
1562 ))
1563 .await
1564 .unwrap();
1565 mz_client
1566 .batch_execute(&format!(
1567 "CREATE SOURCE {source_name}
1568 FROM POSTGRES
1569 CONNECTION pgconn
1570 (PUBLICATION '{source_name}')"
1571 ))
1572 .await
1573 .unwrap();
1574 mz_client
1575 .batch_execute(&format!(
1576 "CREATE TABLE {table_name}
1577 FROM SOURCE {source_name}
1578 (REFERENCE {table_name});"
1579 ))
1580 .await
1581 .unwrap();
1582
1583 let table_name = table_name.to_string();
1584 let source_name = source_name.to_string();
1585 (
1586 pg_client,
1587 move |mz_client: &'a Client, pg_client: &'a Client| {
1588 let f: Pin<Box<dyn Future<Output = ()> + 'a>> = Box::pin(async move {
1589 mz_client
1590 .batch_execute(&format!("DROP SOURCE {source_name} CASCADE;"))
1591 .await
1592 .unwrap();
1593 mz_client
1594 .batch_execute("DROP CONNECTION pgconn;")
1595 .await
1596 .unwrap();
1597
1598 let _ = pg_client
1599 .execute(&format!("DROP PUBLICATION {source_name};"), &[])
1600 .await
1601 .unwrap();
1602 let _ = pg_client
1603 .execute(&format!("DROP TABLE {table_name};"), &[])
1604 .await
1605 .unwrap();
1606 });
1607 f
1608 },
1609 )
1610}
1611
1612pub async fn wait_for_pg_table_population(mz_client: &Client, view_name: &str, source_rows: i64) {
1613 let current_isolation = mz_client
1614 .query_one("SHOW transaction_isolation", &[])
1615 .await
1616 .unwrap()
1617 .get::<_, String>(0);
1618 mz_client
1619 .batch_execute("SET transaction_isolation = SERIALIZABLE")
1620 .await
1621 .unwrap();
1622 Retry::default()
1623 .retry_async(|_| async move {
1624 let rows = mz_client
1625 .query_one(&format!("SELECT COUNT(*) FROM {view_name};"), &[])
1626 .await
1627 .unwrap()
1628 .get::<_, i64>(0);
1629 if rows == source_rows {
1630 Ok(())
1631 } else {
1632 Err(format!(
1633 "Waiting for {source_rows} row to be ingested. Currently at {rows}."
1634 ))
1635 }
1636 })
1637 .await
1638 .unwrap();
1639 mz_client
1640 .batch_execute(&format!(
1641 "SET transaction_isolation = '{current_isolation}'"
1642 ))
1643 .await
1644 .unwrap();
1645}
1646
1647pub fn auth_with_ws(
1649 ws: &mut WebSocket<MaybeTlsStream<TcpStream>>,
1650 mut options: BTreeMap<String, String>,
1651) -> Result<Vec<WebSocketResponse>, anyhow::Error> {
1652 if !options.contains_key("welcome_message") {
1653 options.insert("welcome_message".into(), "off".into());
1654 }
1655 auth_with_ws_impl(
1656 ws,
1657 Message::Text(
1658 serde_json::to_string(&WebSocketAuth::Basic {
1659 user: "materialize".into(),
1660 password: "".into(),
1661 options,
1662 })
1663 .unwrap()
1664 .into(),
1665 ),
1666 )
1667}
1668
1669pub fn auth_with_ws_impl(
1670 ws: &mut WebSocket<MaybeTlsStream<TcpStream>>,
1671 auth_message: Message,
1672) -> Result<Vec<WebSocketResponse>, anyhow::Error> {
1673 ws.send(auth_message)?;
1674
1675 let mut msgs = Vec::new();
1677 loop {
1678 let resp = ws.read()?;
1679 match resp {
1680 Message::Text(msg) => {
1681 let msg: WebSocketResponse = serde_json::from_str(&msg).unwrap();
1682 match msg {
1683 WebSocketResponse::ReadyForQuery(_) => break,
1684 msg => {
1685 msgs.push(msg);
1686 }
1687 }
1688 }
1689 Message::Ping(_) => continue,
1690 Message::Close(None) => return Err(anyhow!("ws closed after auth")),
1691 Message::Close(Some(close_frame)) => {
1692 return Err(anyhow!("ws closed after auth").context(close_frame));
1693 }
1694 _ => panic!("unexpected response: {:?}", resp),
1695 }
1696 }
1697 Ok(msgs)
1698}
1699
1700pub fn make_header<H: Header>(h: H) -> HeaderMap {
1701 let mut map = HeaderMap::new();
1702 map.typed_insert(h);
1703 map
1704}
1705
1706pub fn make_pg_tls<F>(configure: F) -> MakeTlsConnector
1707where
1708 F: FnOnce(&mut SslConnectorBuilder) -> Result<(), ErrorStack>,
1709{
1710 let mut connector_builder = SslConnector::builder(SslMethod::tls()).unwrap();
1711 let options = connector_builder.options() | SslOptions::NO_TLSV1_3;
1732 connector_builder.set_options(options);
1733 configure(&mut connector_builder).unwrap();
1734 MakeTlsConnector::new(connector_builder.build())
1735}
1736
/// A throwaway certificate authority for TLS tests.
///
/// Certificates and keys it issues are written as PEM files into `dir`.
pub struct Ca {
    /// Temporary directory holding `ca.crt` and any issued `*.crt` / `*.key` files.
    pub dir: TempDir,
    /// The CA's X.509 subject name (a single common name).
    pub name: X509Name,
    /// The CA certificate.
    pub cert: X509,
    /// The CA's private key, used to sign issued certificates.
    pub pkey: PKey<Private>,
}
1744
1745impl Ca {
1746 fn make_ca(name: &str, parent: Option<&Ca>) -> Result<Ca, Box<dyn Error>> {
1747 let dir = tempfile::tempdir()?;
1748 let rsa = Rsa::generate(2048)?;
1749 let pkey = PKey::from_rsa(rsa)?;
1750 let name = {
1751 let mut builder = X509NameBuilder::new()?;
1752 builder.append_entry_by_nid(Nid::COMMONNAME, name)?;
1753 builder.build()
1754 };
1755 let cert = {
1756 let mut builder = X509::builder()?;
1757 builder.set_version(2)?;
1758 builder.set_pubkey(&pkey)?;
1759 builder.set_issuer_name(parent.map(|ca| &ca.name).unwrap_or(&name))?;
1760 builder.set_subject_name(&name)?;
1761 builder.set_not_before(&*Asn1Time::days_from_now(0)?)?;
1762 builder.set_not_after(&*Asn1Time::days_from_now(365)?)?;
1763 builder.append_extension(BasicConstraints::new().critical().ca().build()?)?;
1764 builder.sign(
1765 parent.map(|ca| &ca.pkey).unwrap_or(&pkey),
1766 MessageDigest::sha256(),
1767 )?;
1768 builder.build()
1769 };
1770 fs::write(dir.path().join("ca.crt"), cert.to_pem()?)?;
1771 Ok(Ca {
1772 dir,
1773 name,
1774 cert,
1775 pkey,
1776 })
1777 }
1778
1779 pub fn new_root(name: &str) -> Result<Ca, Box<dyn Error>> {
1781 Ca::make_ca(name, None)
1782 }
1783
1784 pub fn ca_cert_path(&self) -> PathBuf {
1786 self.dir.path().join("ca.crt")
1787 }
1788
1789 pub fn request_ca(&self, name: &str) -> Result<Ca, Box<dyn Error>> {
1791 Ca::make_ca(name, Some(self))
1792 }
1793
1794 pub fn request_client_cert(&self, name: &str) -> Result<(PathBuf, PathBuf), Box<dyn Error>> {
1799 self.request_cert(name, iter::empty())
1800 }
1801
1802 pub fn request_cert<I>(&self, name: &str, ips: I) -> Result<(PathBuf, PathBuf), Box<dyn Error>>
1805 where
1806 I: IntoIterator<Item = IpAddr>,
1807 {
1808 let rsa = Rsa::generate(2048)?;
1809 let pkey = PKey::from_rsa(rsa)?;
1810 let subject_name = {
1811 let mut builder = X509NameBuilder::new()?;
1812 builder.append_entry_by_nid(Nid::COMMONNAME, name)?;
1813 builder.build()
1814 };
1815 let cert = {
1816 let mut builder = X509::builder()?;
1817 builder.set_version(2)?;
1818 builder.set_pubkey(&pkey)?;
1819 builder.set_issuer_name(self.cert.subject_name())?;
1820 builder.set_subject_name(&subject_name)?;
1821 builder.set_not_before(&*Asn1Time::days_from_now(0)?)?;
1822 builder.set_not_after(&*Asn1Time::days_from_now(365)?)?;
1823 for ip in ips {
1824 builder.append_extension(
1825 SubjectAlternativeName::new()
1826 .ip(&ip.to_string())
1827 .build(&builder.x509v3_context(None, None))?,
1828 )?;
1829 }
1830 builder.sign(&self.pkey, MessageDigest::sha256())?;
1831 builder.build()
1832 };
1833 let cert_path = self.dir.path().join(Path::new(name).with_extension("crt"));
1834 let key_path = self.dir.path().join(Path::new(name).with_extension("key"));
1835 fs::write(&cert_path, cert.to_pem()?)?;
1836 fs::write(&key_path, pkey.private_key_to_pem_pkcs8()?)?;
1837 Ok((cert_path, key_path))
1838 }
1839}