1use std::collections::BTreeMap;
11use std::error::Error;
12use std::future::IntoFuture;
13use std::net::{IpAddr, Ipv4Addr, SocketAddr, TcpStream};
14use std::path::{Path, PathBuf};
15use std::pin::Pin;
16use std::str::FromStr;
17use std::sync::Arc;
18use std::sync::LazyLock;
19use std::time::Duration;
20use std::{env, fs, iter};
21
22use anyhow::anyhow;
23use futures::Future;
24use futures::future::{BoxFuture, LocalBoxFuture};
25use headers::{Header, HeaderMapExt};
26use http::Uri;
27use hyper::http::header::HeaderMap;
28use maplit::btreemap;
29use mz_adapter::TimestampExplanation;
30use mz_adapter_types::bootstrap_builtin_cluster_config::{
31 ANALYTICS_CLUSTER_DEFAULT_REPLICATION_FACTOR, BootstrapBuiltinClusterConfig,
32 CATALOG_SERVER_CLUSTER_DEFAULT_REPLICATION_FACTOR, PROBE_CLUSTER_DEFAULT_REPLICATION_FACTOR,
33 SUPPORT_CLUSTER_DEFAULT_REPLICATION_FACTOR, SYSTEM_CLUSTER_DEFAULT_REPLICATION_FACTOR,
34};
35
36use mz_auth::password::Password;
37use mz_catalog::config::ClusterReplicaSizeMap;
38use mz_controller::ControllerConfig;
39use mz_dyncfg::ConfigUpdates;
40use mz_license_keys::ValidatedLicenseKey;
41use mz_orchestrator_process::{ProcessOrchestrator, ProcessOrchestratorConfig};
42use mz_orchestrator_tracing::{TracingCliArgs, TracingOrchestrator};
43use mz_ore::metrics::MetricsRegistry;
44use mz_ore::now::{EpochMillis, NowFn, SYSTEM_TIME};
45use mz_ore::retry::Retry;
46use mz_ore::task;
47use mz_ore::tracing::{
48 OpenTelemetryConfig, StderrLogConfig, StderrLogFormat, TracingConfig, TracingGuard,
49 TracingHandle,
50};
51use mz_persist_client::PersistLocation;
52use mz_persist_client::cache::PersistClientCache;
53use mz_persist_client::cfg::{CONSENSUS_CONNECTION_POOL_MAX_SIZE, PersistConfig};
54use mz_persist_client::rpc::PersistGrpcPubSubServer;
55use mz_secrets::SecretsController;
56use mz_server_core::listeners::{
57 AllowedRoles, AuthenticatorKind, BaseListenerConfig, HttpRoutesEnabled,
58};
59use mz_server_core::{ReloadTrigger, TlsCertConfig};
60use mz_sql::catalog::EnvironmentId;
61use mz_storage_types::connections::ConnectionContext;
62use mz_tracing::CloneableEnvFilter;
63use openssl::asn1::Asn1Time;
64use openssl::error::ErrorStack;
65use openssl::hash::MessageDigest;
66use openssl::nid::Nid;
67use openssl::pkey::{PKey, Private};
68use openssl::rsa::Rsa;
69use openssl::ssl::{SslConnector, SslConnectorBuilder, SslMethod, SslOptions};
70use openssl::x509::extension::{BasicConstraints, SubjectAlternativeName};
71use openssl::x509::{X509, X509Name, X509NameBuilder};
72use postgres::error::DbError;
73use postgres::tls::{MakeTlsConnect, TlsConnect};
74use postgres::types::{FromSql, Type};
75use postgres::{NoTls, Socket};
76use postgres_openssl::MakeTlsConnector;
77use tempfile::TempDir;
78use tokio::net::TcpListener;
79use tokio::runtime::Runtime;
80use tokio_postgres::config::{Host, SslMode};
81use tokio_postgres::{AsyncMessage, Client};
82use tokio_stream::wrappers::TcpListenerStream;
83use tower_http::cors::AllowOrigin;
84use tracing::Level;
85use tracing_capture::SharedStorage;
86use tracing_subscriber::EnvFilter;
87use tungstenite::stream::MaybeTlsStream;
88use tungstenite::{Message, WebSocket};
89
90use crate::{
91 CatalogConfig, FronteggAuthenticator, HttpListenerConfig, ListenersConfig, SqlListenerConfig,
92 WebSocketAuth, WebSocketResponse,
93};
94
/// Kafka address string used by tests.
///
/// Read once, on first access, from the `KAFKA_ADDRS` environment variable.
/// Falls back to `localhost:9092` when the variable is unset or is not valid
/// Unicode.
pub static KAFKA_ADDRS: LazyLock<String> = LazyLock::new(|| match env::var("KAFKA_ADDRS") {
    Ok(addrs) => addrs,
    Err(_) => String::from("localhost:9092"),
});
97
/// Builder-style configuration for starting an `environmentd` test server.
///
/// Start from [`TestHarness::default`], adjust it with the builder methods,
/// and then call one of the `start*` / `try_start*` methods.
#[derive(Clone)]
pub struct TestHarness {
    /// Directory for server state. When `None`, a temporary directory is
    /// created at startup.
    data_directory: Option<PathBuf>,
    /// TLS certificate and key paths, set by `with_tls`.
    tls: Option<TlsCertConfig>,
    /// Frontegg authenticator, set by `with_frontegg_auth`.
    frontegg: Option<FronteggAuthenticator>,
    /// Password for `mz_system`, set by `with_password_auth`.
    external_login_password_mz_system: Option<Password>,
    /// SQL and HTTP listeners to bind, keyed by listener name.
    listeners_config: ListenersConfig,
    /// Forwarded to the server config as `unsafe_mode`.
    unsafe_mode: bool,
    /// NOTE(review): not read by `Listeners::serve_with_trigger` in the visible
    /// code; confirm whether it is still consumed anywhere.
    workers: usize,
    /// Clock handed to both the server and the controller.
    now: NowFn,
    /// Random per-harness value used to name this server's consensus and
    /// timestamp-oracle schemas.
    seed: u32,
    storage_usage_collection_interval: Duration,
    storage_usage_retention_period: Option<Duration>,
    default_cluster_replica_size: String,
    default_cluster_replication_factor: u32,
    builtin_system_cluster_config: BootstrapBuiltinClusterConfig,
    builtin_catalog_server_cluster_config: BootstrapBuiltinClusterConfig,
    builtin_probe_cluster_config: BootstrapBuiltinClusterConfig,
    builtin_support_cluster_config: BootstrapBuiltinClusterConfig,
    builtin_analytics_cluster_config: BootstrapBuiltinClusterConfig,

    /// Forwarded to the process orchestrator config.
    propagate_crashes: bool,
    /// When set, tracing is configured with JSON stderr logging and an
    /// OpenTelemetry exporter.
    enable_tracing: bool,
    /// Tracing arguments for the wrapping `TracingOrchestrator`.
    orchestrator_tracing_cli_args: TracingCliArgs,
    bootstrap_role: Option<String>,
    deploy_generation: u64,
    /// System parameter defaults passed to the server.
    system_parameter_defaults: BTreeMap<String, String>,
    internal_console_redirect_url: Option<String>,
    /// Metrics registry to use. A fresh one is created when `None`.
    metrics_registry: Option<MetricsRegistry>,
    /// Build version recorded in the persist config.
    code_version: semver::Version,
    /// Tracing capture storage, only used when `enable_tracing` is set.
    capture: Option<SharedStorage>,
    pub environment_id: EnvironmentId,
}
134
135impl Default for TestHarness {
136 fn default() -> TestHarness {
137 TestHarness {
138 data_directory: None,
139 tls: None,
140 frontegg: None,
141 external_login_password_mz_system: None,
142 listeners_config: ListenersConfig {
143 sql: btreemap![
144 "external".to_owned() => SqlListenerConfig {
145 addr: SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), 0),
146 authenticator_kind: AuthenticatorKind::None,
147 allowed_roles: AllowedRoles::Normal,
148 enable_tls: false,
149 },
150 "internal".to_owned() => SqlListenerConfig {
151 addr: SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), 0),
152 authenticator_kind: AuthenticatorKind::None,
153 allowed_roles: AllowedRoles::NormalAndInternal,
154 enable_tls: false,
155 },
156 ],
157 http: btreemap![
158 "external".to_owned() => HttpListenerConfig {
159 base: BaseListenerConfig {
160 addr: SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), 0),
161 authenticator_kind: AuthenticatorKind::None,
162 allowed_roles: AllowedRoles::Normal,
163 enable_tls: false,
164 },
165 routes: HttpRoutesEnabled{
166 base: true,
167 webhook: true,
168 internal: false,
169 metrics: false,
170 profiling: false,
171 },
172 },
173 "internal".to_owned() => HttpListenerConfig {
174 base: BaseListenerConfig {
175 addr: SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), 0),
176 authenticator_kind: AuthenticatorKind::None,
177 allowed_roles: AllowedRoles::NormalAndInternal,
178 enable_tls: false,
179 },
180 routes: HttpRoutesEnabled{
181 base: true,
182 webhook: true,
183 internal: true,
184 metrics: true,
185 profiling: true,
186 },
187 },
188 ],
189 },
190 unsafe_mode: false,
191 workers: 1,
192 now: SYSTEM_TIME.clone(),
193 seed: rand::random(),
194 storage_usage_collection_interval: Duration::from_secs(3600),
195 storage_usage_retention_period: None,
196 default_cluster_replica_size: "scale=1,workers=1".to_string(),
197 default_cluster_replication_factor: 1,
198 builtin_system_cluster_config: BootstrapBuiltinClusterConfig {
199 size: "scale=1,workers=1".to_string(),
200 replication_factor: SYSTEM_CLUSTER_DEFAULT_REPLICATION_FACTOR,
201 },
202 builtin_catalog_server_cluster_config: BootstrapBuiltinClusterConfig {
203 size: "scale=1,workers=1".to_string(),
204 replication_factor: CATALOG_SERVER_CLUSTER_DEFAULT_REPLICATION_FACTOR,
205 },
206 builtin_probe_cluster_config: BootstrapBuiltinClusterConfig {
207 size: "scale=1,workers=1".to_string(),
208 replication_factor: PROBE_CLUSTER_DEFAULT_REPLICATION_FACTOR,
209 },
210 builtin_support_cluster_config: BootstrapBuiltinClusterConfig {
211 size: "scale=1,workers=1".to_string(),
212 replication_factor: SUPPORT_CLUSTER_DEFAULT_REPLICATION_FACTOR,
213 },
214 builtin_analytics_cluster_config: BootstrapBuiltinClusterConfig {
215 size: "scale=1,workers=1".to_string(),
216 replication_factor: ANALYTICS_CLUSTER_DEFAULT_REPLICATION_FACTOR,
217 },
218 propagate_crashes: false,
219 enable_tracing: false,
220 bootstrap_role: Some("materialize".into()),
221 deploy_generation: 0,
222 system_parameter_defaults: BTreeMap::from([(
225 "log_filter".to_string(),
226 "error".to_string(),
227 )]),
228 internal_console_redirect_url: None,
229 metrics_registry: None,
230 orchestrator_tracing_cli_args: TracingCliArgs {
231 startup_log_filter: CloneableEnvFilter::from_str("error").expect("must parse"),
232 ..Default::default()
233 },
234 code_version: crate::BUILD_INFO.semver_version(),
235 environment_id: EnvironmentId::for_tests(),
236 capture: None,
237 }
238 }
239}
240
241impl TestHarness {
242 pub async fn start(self) -> TestServer {
246 self.try_start().await.expect("Failed to start test Server")
247 }
248
249 pub async fn start_with_trigger(self, tls_reload_certs: ReloadTrigger) -> TestServer {
251 self.try_start_with_trigger(tls_reload_certs)
252 .await
253 .expect("Failed to start test Server")
254 }
255
256 pub async fn try_start(self) -> Result<TestServer, anyhow::Error> {
258 self.try_start_with_trigger(mz_server_core::cert_reload_never_reload())
259 .await
260 }
261
262 pub async fn try_start_with_trigger(
264 self,
265 tls_reload_certs: ReloadTrigger,
266 ) -> Result<TestServer, anyhow::Error> {
267 let listeners = Listeners::new(&self).await?;
268 listeners.serve_with_trigger(self, tls_reload_certs).await
269 }
270
271 pub fn start_blocking(self) -> TestServerWithRuntime {
273 stacker::grow(mz_ore::stack::STACK_SIZE, || {
274 let runtime = Runtime::new().expect("failed to spawn runtime for test");
275 let runtime = Arc::new(runtime);
276 let server = runtime.block_on(self.start());
277 TestServerWithRuntime { runtime, server }
278 })
279 }
280
281 pub fn data_directory(mut self, data_directory: impl Into<PathBuf>) -> Self {
282 self.data_directory = Some(data_directory.into());
283 self
284 }
285
286 pub fn with_tls(mut self, cert_path: impl Into<PathBuf>, key_path: impl Into<PathBuf>) -> Self {
287 self.tls = Some(TlsCertConfig {
288 cert: cert_path.into(),
289 key: key_path.into(),
290 });
291 for (_, listener) in &mut self.listeners_config.sql {
292 listener.enable_tls = true;
293 }
294 for (_, listener) in &mut self.listeners_config.http {
295 listener.base.enable_tls = true;
296 }
297 self
298 }
299
300 pub fn unsafe_mode(mut self) -> Self {
301 self.unsafe_mode = true;
302 self
303 }
304
305 pub fn workers(mut self, workers: usize) -> Self {
306 self.workers = workers;
307 self
308 }
309
310 pub fn with_frontegg_auth(mut self, frontegg: &FronteggAuthenticator) -> Self {
311 self.frontegg = Some(frontegg.clone());
312 let enable_tls = self.tls.is_some();
313 self.listeners_config = ListenersConfig {
314 sql: btreemap! {
315 "external".to_owned() => SqlListenerConfig {
316 addr: SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), 0),
317 authenticator_kind: AuthenticatorKind::Frontegg,
318 allowed_roles: AllowedRoles::Normal,
319 enable_tls,
320 },
321 "internal".to_owned() => SqlListenerConfig {
322 addr: SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), 0),
323 authenticator_kind: AuthenticatorKind::None,
324 allowed_roles: AllowedRoles::NormalAndInternal,
325 enable_tls: false,
326 },
327 },
328 http: btreemap! {
329 "external".to_owned() => HttpListenerConfig {
330 base: BaseListenerConfig {
331 addr: SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), 0),
332 authenticator_kind: AuthenticatorKind::Frontegg,
333 allowed_roles: AllowedRoles::Normal,
334 enable_tls,
335 },
336 routes: HttpRoutesEnabled{
337 base: true,
338 webhook: true,
339 internal: false,
340 metrics: false,
341 profiling: false,
342 },
343 },
344 "internal".to_owned() => HttpListenerConfig {
345 base: BaseListenerConfig {
346 addr: SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), 0),
347 authenticator_kind: AuthenticatorKind::None,
348 allowed_roles: AllowedRoles::NormalAndInternal,
349 enable_tls: false,
350 },
351 routes: HttpRoutesEnabled{
352 base: true,
353 webhook: true,
354 internal: true,
355 metrics: true,
356 profiling: true,
357 },
358 },
359 },
360 };
361 self
362 }
363
364 pub fn with_password_auth(mut self, mz_system_password: Password) -> Self {
365 self.external_login_password_mz_system = Some(mz_system_password);
366 let enable_tls = self.tls.is_some();
367 self.listeners_config = ListenersConfig {
368 sql: btreemap! {
369 "external".to_owned() => SqlListenerConfig {
370 addr: SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), 0),
371 authenticator_kind: AuthenticatorKind::Password,
372 allowed_roles: AllowedRoles::NormalAndInternal,
373 enable_tls,
374 },
375 },
376 http: btreemap! {
377 "external".to_owned() => HttpListenerConfig {
378 base: BaseListenerConfig {
379 addr: SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), 0),
380 authenticator_kind: AuthenticatorKind::Password,
381 allowed_roles: AllowedRoles::NormalAndInternal,
382 enable_tls,
383 },
384 routes: HttpRoutesEnabled{
385 base: true,
386 webhook: true,
387 internal: true,
388 metrics: false,
389 profiling: true,
390 },
391 },
392 "metrics".to_owned() => HttpListenerConfig {
393 base: BaseListenerConfig {
394 addr: SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), 0),
395 authenticator_kind: AuthenticatorKind::None,
396 allowed_roles: AllowedRoles::NormalAndInternal,
397 enable_tls: false,
398 },
399 routes: HttpRoutesEnabled{
400 base: false,
401 webhook: false,
402 internal: false,
403 metrics: true,
404 profiling: false,
405 },
406 },
407 },
408 };
409 self
410 }
411
412 pub fn with_now(mut self, now: NowFn) -> Self {
413 self.now = now;
414 self
415 }
416
417 pub fn with_storage_usage_collection_interval(
418 mut self,
419 storage_usage_collection_interval: Duration,
420 ) -> Self {
421 self.storage_usage_collection_interval = storage_usage_collection_interval;
422 self
423 }
424
425 pub fn with_storage_usage_retention_period(
426 mut self,
427 storage_usage_retention_period: Duration,
428 ) -> Self {
429 self.storage_usage_retention_period = Some(storage_usage_retention_period);
430 self
431 }
432
433 pub fn with_default_cluster_replica_size(
434 mut self,
435 default_cluster_replica_size: String,
436 ) -> Self {
437 self.default_cluster_replica_size = default_cluster_replica_size;
438 self
439 }
440
441 pub fn with_builtin_system_cluster_replica_size(
442 mut self,
443 builtin_system_cluster_replica_size: String,
444 ) -> Self {
445 self.builtin_system_cluster_config.size = builtin_system_cluster_replica_size;
446 self
447 }
448
449 pub fn with_builtin_system_cluster_replication_factor(
450 mut self,
451 builtin_system_cluster_replication_factor: u32,
452 ) -> Self {
453 self.builtin_system_cluster_config.replication_factor =
454 builtin_system_cluster_replication_factor;
455 self
456 }
457
458 pub fn with_builtin_catalog_server_cluster_replica_size(
459 mut self,
460 builtin_catalog_server_cluster_replica_size: String,
461 ) -> Self {
462 self.builtin_catalog_server_cluster_config.size =
463 builtin_catalog_server_cluster_replica_size;
464 self
465 }
466
467 pub fn with_propagate_crashes(mut self, propagate_crashes: bool) -> Self {
468 self.propagate_crashes = propagate_crashes;
469 self
470 }
471
472 pub fn with_enable_tracing(mut self, enable_tracing: bool) -> Self {
473 self.enable_tracing = enable_tracing;
474 self
475 }
476
477 pub fn with_bootstrap_role(mut self, bootstrap_role: Option<String>) -> Self {
478 self.bootstrap_role = bootstrap_role;
479 self
480 }
481
482 pub fn with_deploy_generation(mut self, deploy_generation: u64) -> Self {
483 self.deploy_generation = deploy_generation;
484 self
485 }
486
487 pub fn with_system_parameter_default(mut self, param: String, value: String) -> Self {
488 self.system_parameter_defaults.insert(param, value);
489 self
490 }
491
492 pub fn with_internal_console_redirect_url(
493 mut self,
494 internal_console_redirect_url: Option<String>,
495 ) -> Self {
496 self.internal_console_redirect_url = internal_console_redirect_url;
497 self
498 }
499
500 pub fn with_metrics_registry(mut self, registry: MetricsRegistry) -> Self {
501 self.metrics_registry = Some(registry);
502 self
503 }
504
505 pub fn with_code_version(mut self, version: semver::Version) -> Self {
506 self.code_version = version;
507 self
508 }
509
510 pub fn with_capture(mut self, storage: SharedStorage) -> Self {
511 self.capture = Some(storage);
512 self
513 }
514}
515
/// Listeners that have been bound (see [`Listeners::new`]) but not yet served.
pub struct Listeners {
    /// The underlying bound listeners.
    pub inner: crate::Listeners,
}
519
520impl Listeners {
521 pub async fn new(config: &TestHarness) -> Result<Listeners, anyhow::Error> {
522 let inner = crate::Listeners::bind(config.listeners_config.clone()).await?;
523 Ok(Listeners { inner })
524 }
525
526 pub async fn serve(self, config: TestHarness) -> Result<TestServer, anyhow::Error> {
527 self.serve_with_trigger(config, mz_server_core::cert_reload_never_reload())
528 .await
529 }
530
531 pub async fn serve_with_trigger(
532 self,
533 config: TestHarness,
534 tls_reload_certs: ReloadTrigger,
535 ) -> Result<TestServer, anyhow::Error> {
536 let (data_directory, temp_dir) = match config.data_directory {
537 None => {
538 let temp_dir = tempfile::tempdir()?;
543 (temp_dir.path().to_path_buf(), Some(temp_dir))
544 }
545 Some(data_directory) => (data_directory, None),
546 };
547 let scratch_dir = tempfile::tempdir()?;
548 let (consensus_uri, timestamp_oracle_url) = {
549 let seed = config.seed;
550 let cockroach_url = env::var("METADATA_BACKEND_URL")
551 .map_err(|_| anyhow!("METADATA_BACKEND_URL environment variable is not set"))?;
552 let (client, conn) = tokio_postgres::connect(&cockroach_url, NoTls).await?;
553 mz_ore::task::spawn(|| "startup-postgres-conn", async move {
554 if let Err(err) = conn.await {
555 panic!("connection error: {}", err);
556 };
557 });
558 client
559 .batch_execute(&format!(
560 "CREATE SCHEMA IF NOT EXISTS consensus_{seed};
561 CREATE SCHEMA IF NOT EXISTS tsoracle_{seed};"
562 ))
563 .await?;
564 (
565 format!("{cockroach_url}?options=--search_path=consensus_{seed}")
566 .parse()
567 .expect("invalid consensus URI"),
568 format!("{cockroach_url}?options=--search_path=tsoracle_{seed}")
569 .parse()
570 .expect("invalid timestamp oracle URI"),
571 )
572 };
573 let metrics_registry = config.metrics_registry.unwrap_or_else(MetricsRegistry::new);
574 let orchestrator = ProcessOrchestrator::new(ProcessOrchestratorConfig {
575 image_dir: env::current_exe()?
576 .parent()
577 .unwrap()
578 .parent()
579 .unwrap()
580 .to_path_buf(),
581 suppress_output: false,
582 environment_id: config.environment_id.to_string(),
583 secrets_dir: data_directory.join("secrets"),
584 command_wrapper: vec![],
585 propagate_crashes: config.propagate_crashes,
586 tcp_proxy: None,
587 scratch_directory: scratch_dir.path().to_path_buf(),
588 })
589 .await?;
590 let orchestrator = Arc::new(orchestrator);
591 let persist_now = SYSTEM_TIME.clone();
594 let dyncfgs = mz_dyncfgs::all_dyncfgs();
595
596 let mut updates = ConfigUpdates::default();
597 updates.add(&CONSENSUS_CONNECTION_POOL_MAX_SIZE, 1);
600 updates.apply(&dyncfgs);
601
602 let mut persist_cfg = PersistConfig::new(&crate::BUILD_INFO, persist_now.clone(), dyncfgs);
603 persist_cfg.build_version = config.code_version;
604 persist_cfg.set_rollup_threshold(5);
606
607 let persist_pubsub_server = PersistGrpcPubSubServer::new(&persist_cfg, &metrics_registry);
608 let persist_pubsub_client = persist_pubsub_server.new_same_process_connection();
609 let persist_pubsub_tcp_listener =
610 TcpListener::bind(SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), 0))
611 .await
612 .expect("pubsub addr binding");
613 let persist_pubsub_server_port = persist_pubsub_tcp_listener
614 .local_addr()
615 .expect("pubsub addr has local addr")
616 .port();
617
618 mz_ore::task::spawn(|| "persist_pubsub_server", async move {
620 persist_pubsub_server
621 .serve_with_stream(TcpListenerStream::new(persist_pubsub_tcp_listener))
622 .await
623 .expect("success")
624 });
625 let persist_clients =
626 PersistClientCache::new(persist_cfg, &metrics_registry, |_, _| persist_pubsub_client);
627 let persist_clients = Arc::new(persist_clients);
628
629 let secrets_controller = Arc::clone(&orchestrator);
630 let connection_context = ConnectionContext::for_tests(orchestrator.reader());
631 let orchestrator = Arc::new(TracingOrchestrator::new(
632 orchestrator,
633 config.orchestrator_tracing_cli_args,
634 ));
635 let (tracing_handle, tracing_guard) = if config.enable_tracing {
636 let config = TracingConfig::<fn(&tracing::Metadata) -> sentry_tracing::EventFilter> {
637 service_name: "environmentd",
638 stderr_log: StderrLogConfig {
639 format: StderrLogFormat::Json,
640 filter: EnvFilter::default(),
641 },
642 opentelemetry: Some(OpenTelemetryConfig {
643 endpoint: "http://fake_address_for_testing:8080".to_string(),
644 headers: http::HeaderMap::new(),
645 filter: EnvFilter::default().add_directive(Level::DEBUG.into()),
646 resource: opentelemetry_sdk::resource::Resource::default(),
647 max_batch_queue_size: 2048,
648 max_export_batch_size: 512,
649 max_concurrent_exports: 1,
650 batch_scheduled_delay: Duration::from_millis(5000),
651 max_export_timeout: Duration::from_secs(30),
652 }),
653 tokio_console: None,
654 sentry: None,
655 build_version: crate::BUILD_INFO.version,
656 build_sha: crate::BUILD_INFO.sha,
657 registry: metrics_registry.clone(),
658 capture: config.capture,
659 };
660 let (tracing_handle, tracing_guard) = mz_ore::tracing::configure(config).await?;
661 (tracing_handle, Some(tracing_guard))
662 } else {
663 (TracingHandle::disabled(), None)
664 };
665 let host_name = format!(
666 "localhost:{}",
667 self.inner.http["external"].handle.local_addr.port()
668 );
669 let catalog_config = CatalogConfig {
670 persist_clients: Arc::clone(&persist_clients),
671 metrics: Arc::new(mz_catalog::durable::Metrics::new(&MetricsRegistry::new())),
672 };
673
674 let inner = self
675 .inner
676 .serve(crate::Config {
677 catalog_config,
678 timestamp_oracle_url: Some(timestamp_oracle_url),
679 controller: ControllerConfig {
680 build_info: &crate::BUILD_INFO,
681 orchestrator,
682 clusterd_image: "clusterd".into(),
683 init_container_image: None,
684 deploy_generation: config.deploy_generation,
685 persist_location: PersistLocation {
686 blob_uri: format!("file://{}/persist/blob", data_directory.display())
687 .parse()
688 .expect("invalid blob URI"),
689 consensus_uri,
690 },
691 persist_clients,
692 now: config.now.clone(),
693 metrics_registry: metrics_registry.clone(),
694 persist_pubsub_url: format!("http://localhost:{}", persist_pubsub_server_port),
695 secrets_args: mz_service::secrets::SecretsReaderCliArgs {
696 secrets_reader: mz_service::secrets::SecretsControllerKind::LocalFile,
697 secrets_reader_local_file_dir: Some(data_directory.join("secrets")),
698 secrets_reader_kubernetes_context: None,
699 secrets_reader_aws_prefix: None,
700 secrets_reader_name_prefix: None,
701 },
702 connection_context,
703 },
704 secrets_controller,
705 cloud_resource_controller: None,
706 tls: config.tls,
707 frontegg: config.frontegg,
708 unsafe_mode: config.unsafe_mode,
709 all_features: false,
710 metrics_registry: metrics_registry.clone(),
711 now: config.now,
712 environment_id: config.environment_id,
713 cors_allowed_origin: AllowOrigin::list([]),
714 cluster_replica_sizes: ClusterReplicaSizeMap::for_tests(),
715 bootstrap_default_cluster_replica_size: config.default_cluster_replica_size,
716 bootstrap_default_cluster_replication_factor: config
717 .default_cluster_replication_factor,
718 bootstrap_builtin_system_cluster_config: config.builtin_system_cluster_config,
719 bootstrap_builtin_catalog_server_cluster_config: config
720 .builtin_catalog_server_cluster_config,
721 bootstrap_builtin_probe_cluster_config: config.builtin_probe_cluster_config,
722 bootstrap_builtin_support_cluster_config: config.builtin_support_cluster_config,
723 bootstrap_builtin_analytics_cluster_config: config.builtin_analytics_cluster_config,
724 system_parameter_defaults: config.system_parameter_defaults,
725 availability_zones: Default::default(),
726 tracing_handle,
727 storage_usage_collection_interval: config.storage_usage_collection_interval,
728 storage_usage_retention_period: config.storage_usage_retention_period,
729 segment_api_key: None,
730 segment_client_side: false,
731 test_only_dummy_segment_client: false,
732 egress_addresses: vec![],
733 aws_account_id: None,
734 aws_privatelink_availability_zones: None,
735 launchdarkly_sdk_key: None,
736 launchdarkly_key_map: Default::default(),
737 config_sync_file_path: None,
738 config_sync_timeout: Duration::from_secs(30),
739 config_sync_loop_interval: None,
740 bootstrap_role: config.bootstrap_role,
741 http_host_name: Some(host_name),
742 internal_console_redirect_url: config.internal_console_redirect_url,
743 tls_reload_certs,
744 helm_chart_version: None,
745 license_key: ValidatedLicenseKey::for_tests(),
746 external_login_password_mz_system: config.external_login_password_mz_system,
747 })
748 .await?;
749
750 Ok(TestServer {
751 inner,
752 metrics_registry,
753 _temp_dir: temp_dir,
754 _tracing_guard: tracing_guard,
755 })
756 }
757}
758
759pub struct TestServer {
761 pub inner: crate::Server,
762 pub metrics_registry: MetricsRegistry,
763 _temp_dir: Option<TempDir>,
764 _tracing_guard: Option<TracingGuard>,
765}
766
767impl TestServer {
768 pub fn connect(&self) -> ConnectBuilder<'_, postgres::NoTls, NoHandle> {
769 ConnectBuilder::new(self).no_tls()
770 }
771
772 pub async fn enable_feature_flags(&self, flags: &[&'static str]) {
773 let internal_client = self.connect().internal().await.unwrap();
774
775 for flag in flags {
776 internal_client
777 .batch_execute(&format!("ALTER SYSTEM SET {} = true;", flag))
778 .await
779 .unwrap();
780 }
781 }
782
783 pub fn ws_addr(&self) -> Uri {
784 format!(
785 "ws://{}/api/experimental/sql",
786 self.inner.http_listener_handles["external"].local_addr
787 )
788 .parse()
789 .unwrap()
790 }
791
792 pub fn internal_ws_addr(&self) -> Uri {
793 format!(
794 "ws://{}/api/experimental/sql",
795 self.inner.http_listener_handles["internal"].local_addr
796 )
797 .parse()
798 .unwrap()
799 }
800
801 pub fn http_local_addr(&self) -> SocketAddr {
802 self.inner.http_listener_handles["external"].local_addr
803 }
804
805 pub fn internal_http_local_addr(&self) -> SocketAddr {
806 self.inner.http_listener_handles["internal"].local_addr
807 }
808
809 pub fn sql_local_addr(&self) -> SocketAddr {
810 self.inner.sql_listener_handles["external"].local_addr
811 }
812
813 pub fn internal_sql_local_addr(&self) -> SocketAddr {
814 self.inner.sql_listener_handles["internal"].local_addr
815 }
816}
817
/// Builder for a SQL connection to a [`TestServer`].
///
/// The type parameters are:
/// - `T`: the TLS connector type;
/// - `H`: a marker selecting whether awaiting the builder also yields the
///   connection task's join handle (see [`IncludeHandle`]).
///
/// Connect by awaiting the builder (it implements `IntoFuture`).
pub struct ConnectBuilder<'s, T, H> {
    /// Server whose listener addresses are used.
    server: &'s TestServer,

    /// Connection parameters. The port is applied from `port` at connect time.
    pg_config: tokio_postgres::Config,
    /// Port of the SQL listener to connect to.
    port: u16,
    /// TLS connector passed to `tokio_postgres::Config::connect`.
    tls: T,

    /// Optional callback invoked for each notice received on the connection.
    notice_callback: Option<Box<dyn FnMut(tokio_postgres::error::DbError) + Send + 'static>>,

    /// Type-level marker choosing what awaiting the builder produces.
    _with_handle: H,
}
838
839impl<'s> ConnectBuilder<'s, (), NoHandle> {
840 fn new(server: &'s TestServer) -> Self {
841 let mut pg_config = tokio_postgres::Config::new();
842 pg_config
843 .host(&Ipv4Addr::LOCALHOST.to_string())
844 .user("materialize")
845 .options("--welcome_message=off")
846 .application_name("environmentd_test_framework");
847
848 ConnectBuilder {
849 server,
850 pg_config,
851 port: server.sql_local_addr().port(),
852 tls: (),
853 notice_callback: None,
854 _with_handle: NoHandle,
855 }
856 }
857}
858
859impl<'s, T, H> ConnectBuilder<'s, T, H> {
860 pub fn no_tls(self) -> ConnectBuilder<'s, postgres::NoTls, H> {
864 ConnectBuilder {
865 server: self.server,
866 pg_config: self.pg_config,
867 port: self.port,
868 tls: postgres::NoTls,
869 notice_callback: self.notice_callback,
870 _with_handle: self._with_handle,
871 }
872 }
873
874 pub fn with_tls<Tls>(self, tls: Tls) -> ConnectBuilder<'s, Tls, H>
876 where
877 Tls: MakeTlsConnect<Socket> + Send + 'static,
878 Tls::TlsConnect: Send,
879 Tls::Stream: Send,
880 <Tls::TlsConnect as TlsConnect<Socket>>::Future: Send,
881 {
882 ConnectBuilder {
883 server: self.server,
884 pg_config: self.pg_config,
885 port: self.port,
886 tls,
887 notice_callback: self.notice_callback,
888 _with_handle: self._with_handle,
889 }
890 }
891
892 pub fn with_config(mut self, pg_config: tokio_postgres::Config) -> Self {
894 self.pg_config = pg_config;
895 self
896 }
897
898 pub fn ssl_mode(mut self, mode: SslMode) -> Self {
900 self.pg_config.ssl_mode(mode);
901 self
902 }
903
904 pub fn user(mut self, user: &str) -> Self {
906 self.pg_config.user(user);
907 self
908 }
909
910 pub fn password(mut self, password: &str) -> Self {
912 self.pg_config.password(password);
913 self
914 }
915
916 pub fn application_name(mut self, application_name: &str) -> Self {
918 self.pg_config.application_name(application_name);
919 self
920 }
921
922 pub fn dbname(mut self, dbname: &str) -> Self {
924 self.pg_config.dbname(dbname);
925 self
926 }
927
928 pub fn options(mut self, options: &str) -> Self {
930 self.pg_config.options(options);
931 self
932 }
933
934 pub fn internal(mut self) -> Self {
939 self.port = self.server.internal_sql_local_addr().port();
940 self.pg_config.user(mz_sql::session::user::SYSTEM_USER_NAME);
941 self
942 }
943
944 pub fn notice_callback(self, callback: impl FnMut(DbError) + Send + 'static) -> Self {
946 ConnectBuilder {
947 notice_callback: Some(Box::new(callback)),
948 ..self
949 }
950 }
951
952 pub fn with_handle(self) -> ConnectBuilder<'s, T, WithHandle> {
955 ConnectBuilder {
956 server: self.server,
957 pg_config: self.pg_config,
958 port: self.port,
959 tls: self.tls,
960 notice_callback: self.notice_callback,
961 _with_handle: WithHandle,
962 }
963 }
964
965 pub fn as_pg_config(&self) -> &tokio_postgres::Config {
967 &self.pg_config
968 }
969}
970
/// Selects what awaiting a [`ConnectBuilder`] yields.
///
/// Implemented by [`NoHandle`] (client only) and [`WithHandle`] (client plus
/// the join handle of the task driving the connection).
pub trait IncludeHandle: Send {
    /// Value produced once the connection is established.
    type Output;
    /// Builds the output from the connected `client` and the `handle` of the
    /// spawned task that polls the connection.
    fn transform_result(
        client: tokio_postgres::Client,
        handle: mz_ore::task::JoinHandle<()>,
    ) -> Self::Output;
}
980
981pub struct NoHandle;
984impl IncludeHandle for NoHandle {
985 type Output = tokio_postgres::Client;
986 fn transform_result(
987 client: tokio_postgres::Client,
988 _handle: mz_ore::task::JoinHandle<()>,
989 ) -> Self::Output {
990 client
991 }
992}
993
994pub struct WithHandle;
997impl IncludeHandle for WithHandle {
998 type Output = (tokio_postgres::Client, mz_ore::task::JoinHandle<()>);
999 fn transform_result(
1000 client: tokio_postgres::Client,
1001 handle: mz_ore::task::JoinHandle<()>,
1002 ) -> Self::Output {
1003 (client, handle)
1004 }
1005}
1006
1007impl<'s, T, H> IntoFuture for ConnectBuilder<'s, T, H>
1008where
1009 T: MakeTlsConnect<Socket> + Send + 'static,
1010 T::TlsConnect: Send,
1011 T::Stream: Send,
1012 <T::TlsConnect as TlsConnect<Socket>>::Future: Send,
1013 H: IncludeHandle,
1014{
1015 type Output = Result<H::Output, postgres::Error>;
1016 type IntoFuture = BoxFuture<'static, Self::Output>;
1017
1018 fn into_future(mut self) -> Self::IntoFuture {
1019 Box::pin(async move {
1020 assert!(
1021 self.pg_config.get_ports().is_empty(),
1022 "specifying multiple ports is not supported"
1023 );
1024 self.pg_config.port(self.port);
1025
1026 let (client, mut conn) = self.pg_config.connect(self.tls).await?;
1027 let mut notice_callback = self.notice_callback.take();
1028
1029 let handle = task::spawn(|| "connect", async move {
1030 while let Some(msg) = std::future::poll_fn(|cx| conn.poll_message(cx)).await {
1031 match msg {
1032 Ok(AsyncMessage::Notice(notice)) => {
1033 if let Some(callback) = notice_callback.as_mut() {
1034 callback(notice);
1035 }
1036 }
1037 Ok(msg) => {
1038 tracing::debug!(?msg, "Dropping message from database");
1039 }
1040 Err(e) => {
1041 tracing::info!("connection error: {e}");
1046 break;
1047 }
1048 }
1049 }
1050 tracing::info!("connection closed");
1051 });
1052
1053 let output = H::transform_result(client, handle);
1054 Ok(output)
1055 })
1056 }
1057}
1058
/// A [`TestServer`] bundled with the Tokio runtime it was started on (see
/// `TestHarness::start_blocking`), for tests that drive the server from
/// synchronous code.
///
/// Field order is significant. Fields drop in declaration order, so `server`
/// is dropped before this handle to the runtime it was started on.
pub struct TestServerWithRuntime {
    server: TestServer,
    runtime: Arc<Runtime>,
}
1067
1068impl TestServerWithRuntime {
1069 pub fn runtime(&self) -> &Arc<Runtime> {
1073 &self.runtime
1074 }
1075
1076 pub fn inner(&self) -> &crate::Server {
1078 &self.server.inner
1079 }
1080
1081 pub fn connect<T>(&self, tls: T) -> Result<postgres::Client, postgres::Error>
1083 where
1084 T: MakeTlsConnect<Socket> + Send + 'static,
1085 T::TlsConnect: Send,
1086 T::Stream: Send,
1087 <T::TlsConnect as TlsConnect<Socket>>::Future: Send,
1088 {
1089 self.pg_config().connect(tls)
1090 }
1091
1092 pub fn connect_internal<T>(&self, tls: T) -> Result<postgres::Client, anyhow::Error>
1094 where
1095 T: MakeTlsConnect<Socket> + Send + 'static,
1096 T::TlsConnect: Send,
1097 T::Stream: Send,
1098 <T::TlsConnect as TlsConnect<Socket>>::Future: Send,
1099 {
1100 Ok(self.pg_config_internal().connect(tls)?)
1101 }
1102
1103 pub fn enable_feature_flags(&self, flags: &[&'static str]) {
1105 let mut internal_client = self.connect_internal(postgres::NoTls).unwrap();
1106
1107 for flag in flags {
1108 internal_client
1109 .batch_execute(&format!("ALTER SYSTEM SET {} = true;", flag))
1110 .unwrap();
1111 }
1112 }
1113
1114 pub fn pg_config(&self) -> postgres::Config {
1117 let local_addr = self.server.sql_local_addr();
1118 let mut config = postgres::Config::new();
1119 config
1120 .host(&Ipv4Addr::LOCALHOST.to_string())
1121 .port(local_addr.port())
1122 .user("materialize")
1123 .options("--welcome_message=off");
1124 config
1125 }
1126
1127 pub fn pg_config_internal(&self) -> postgres::Config {
1130 let local_addr = self.server.internal_sql_local_addr();
1131 let mut config = postgres::Config::new();
1132 config
1133 .host(&Ipv4Addr::LOCALHOST.to_string())
1134 .port(local_addr.port())
1135 .user("mz_system")
1136 .options("--welcome_message=off");
1137 config
1138 }
1139
1140 pub fn ws_addr(&self) -> Uri {
1141 self.server.ws_addr()
1142 }
1143
1144 pub fn internal_ws_addr(&self) -> Uri {
1145 self.server.internal_ws_addr()
1146 }
1147
1148 pub fn http_local_addr(&self) -> SocketAddr {
1149 self.server.http_local_addr()
1150 }
1151
1152 pub fn internal_http_local_addr(&self) -> SocketAddr {
1153 self.server.internal_http_local_addr()
1154 }
1155
1156 pub fn sql_local_addr(&self) -> SocketAddr {
1157 self.server.sql_local_addr()
1158 }
1159
1160 pub fn internal_sql_local_addr(&self) -> SocketAddr {
1161 self.server.internal_sql_local_addr()
1162 }
1163}
1164
/// A timestamp read back from the server as a `u64`, decoded from a SQL `numeric` value by its
/// [`FromSql`] implementation.
#[derive(Debug, Clone, Copy, Eq, PartialEq, Ord, PartialOrd)]
pub struct MzTimestamp(pub u64);
1167
1168impl<'a> FromSql<'a> for MzTimestamp {
1169 fn from_sql(ty: &Type, raw: &'a [u8]) -> Result<MzTimestamp, Box<dyn Error + Sync + Send>> {
1170 let n = mz_pgrepr::Numeric::from_sql(ty, raw)?;
1171 Ok(MzTimestamp(u64::try_from(n.0.0)?))
1172 }
1173
1174 fn accepts(ty: &Type) -> bool {
1175 mz_pgrepr::Numeric::accepts(ty)
1176 }
1177}
1178
/// Test helper for extracting the server-reported [`DbError`] from a Postgres client error.
pub trait PostgresErrorExt {
    /// Returns the [`DbError`] carried by `self`.
    ///
    /// # Panics
    ///
    /// The implementations in this module panic when `self` is not an error wrapping a
    /// `DbError` (including when `self` is an `Ok` result).
    fn unwrap_db_error(self) -> DbError;
}
1182
1183impl PostgresErrorExt for postgres::Error {
1184 fn unwrap_db_error(self) -> DbError {
1185 match self.source().and_then(|e| e.downcast_ref::<DbError>()) {
1186 Some(e) => e.clone(),
1187 None => panic!("expected DbError, but got: {:?}", self),
1188 }
1189 }
1190}
1191
1192impl<T, E> PostgresErrorExt for Result<T, E>
1193where
1194 E: PostgresErrorExt,
1195{
1196 fn unwrap_db_error(self) -> DbError {
1197 match self {
1198 Ok(_) => panic!("expected Err(DbError), but got Ok(_)"),
1199 Err(e) => e.unwrap_db_error(),
1200 }
1201 }
1202}
1203
1204pub async fn insert_with_deterministic_timestamps(
1208 table: &'static str,
1209 values: &'static str,
1210 server: &TestServer,
1211 now: Arc<std::sync::Mutex<EpochMillis>>,
1212) -> Result<(), Box<dyn Error>> {
1213 let client_write = server.connect().await?;
1214 let client_read = server.connect().await?;
1215
1216 let mut current_ts = get_explain_timestamp(table, &client_read).await;
1217
1218 let insert_query = format!("INSERT INTO {table} VALUES {values}");
1219
1220 let write_future = client_write.execute(&insert_query, &[]);
1221 let timestamp_interval = tokio::time::interval(Duration::from_millis(1));
1222
1223 let mut write_future = std::pin::pin!(write_future);
1224 let mut timestamp_interval = std::pin::pin!(timestamp_interval);
1225
1226 loop {
1229 tokio::select! {
1230 _ = (&mut write_future) => return Ok(()),
1231 _ = timestamp_interval.tick() => {
1232 current_ts += 1;
1233 *now.lock().expect("lock poisoned") = current_ts;
1234 }
1235 };
1236 }
1237}
1238
1239pub async fn get_explain_timestamp(from_suffix: &str, client: &Client) -> EpochMillis {
1240 try_get_explain_timestamp(from_suffix, client)
1241 .await
1242 .unwrap()
1243}
1244
1245pub async fn try_get_explain_timestamp(
1246 from_suffix: &str,
1247 client: &Client,
1248) -> Result<EpochMillis, anyhow::Error> {
1249 let det = get_explain_timestamp_determination(from_suffix, client).await?;
1250 let ts = det.determination.timestamp_context.timestamp_or_default();
1251 Ok(ts.into())
1252}
1253
1254pub async fn get_explain_timestamp_determination(
1255 from_suffix: &str,
1256 client: &Client,
1257) -> Result<TimestampExplanation<mz_repr::Timestamp>, anyhow::Error> {
1258 let row = client
1259 .query_one(
1260 &format!("EXPLAIN TIMESTAMP AS JSON FOR SELECT * FROM {from_suffix}"),
1261 &[],
1262 )
1263 .await?;
1264 let explain: String = row.get(0);
1265 Ok(serde_json::from_str(&explain).unwrap())
1266}
1267
1268pub async fn create_postgres_source_with_table<'a>(
1276 server: &TestServer,
1277 mz_client: &Client,
1278 table_name: &str,
1279 table_schema: &str,
1280 source_name: &str,
1281) -> (
1282 Client,
1283 impl FnOnce(&'a Client, &'a Client) -> LocalBoxFuture<'a, ()>,
1284) {
1285 server
1286 .enable_feature_flags(&["enable_create_table_from_source"])
1287 .await;
1288
1289 let postgres_url = env::var("POSTGRES_URL")
1290 .map_err(|_| anyhow!("POSTGRES_URL environment variable is not set"))
1291 .unwrap();
1292
1293 let (pg_client, connection) = tokio_postgres::connect(&postgres_url, postgres::NoTls)
1294 .await
1295 .unwrap();
1296
1297 let pg_config: tokio_postgres::Config = postgres_url.parse().unwrap();
1298 let user = pg_config.get_user().unwrap_or("postgres");
1299 let db_name = pg_config.get_dbname().unwrap_or(user);
1300 let ports = pg_config.get_ports();
1301 let port = if ports.is_empty() { 5432 } else { ports[0] };
1302 let hosts = pg_config.get_hosts();
1303 let host = if hosts.is_empty() {
1304 "localhost".to_string()
1305 } else {
1306 match &hosts[0] {
1307 Host::Tcp(host) => host.to_string(),
1308 Host::Unix(host) => host.to_str().unwrap().to_string(),
1309 }
1310 };
1311 let password = pg_config.get_password();
1312
1313 mz_ore::task::spawn(|| "postgres-source-connection", async move {
1314 if let Err(e) = connection.await {
1315 panic!("connection error: {}", e);
1316 }
1317 });
1318
1319 let _ = pg_client
1321 .execute(&format!("DROP TABLE IF EXISTS {table_name};"), &[])
1322 .await
1323 .unwrap();
1324 let _ = pg_client
1325 .execute(&format!("DROP PUBLICATION IF EXISTS {source_name};"), &[])
1326 .await
1327 .unwrap();
1328 let _ = pg_client
1329 .execute(&format!("CREATE TABLE {table_name} {table_schema};"), &[])
1330 .await
1331 .unwrap();
1332 let _ = pg_client
1333 .execute(
1334 &format!("ALTER TABLE {table_name} REPLICA IDENTITY FULL;"),
1335 &[],
1336 )
1337 .await
1338 .unwrap();
1339 let _ = pg_client
1340 .execute(
1341 &format!("CREATE PUBLICATION {source_name} FOR TABLE {table_name};"),
1342 &[],
1343 )
1344 .await
1345 .unwrap();
1346
1347 let mut connection_str = format!("HOST '{host}', PORT {port}, USER {user}, DATABASE {db_name}");
1349 if let Some(password) = password {
1350 let password = std::str::from_utf8(password).unwrap();
1351 mz_client
1352 .batch_execute(&format!("CREATE SECRET s AS '{password}'"))
1353 .await
1354 .unwrap();
1355 connection_str = format!("{connection_str}, PASSWORD SECRET s");
1356 }
1357 mz_client
1358 .batch_execute(&format!(
1359 "CREATE CONNECTION pgconn TO POSTGRES ({connection_str})"
1360 ))
1361 .await
1362 .unwrap();
1363 mz_client
1364 .batch_execute(&format!(
1365 "CREATE SOURCE {source_name}
1366 FROM POSTGRES
1367 CONNECTION pgconn
1368 (PUBLICATION '{source_name}')"
1369 ))
1370 .await
1371 .unwrap();
1372 mz_client
1373 .batch_execute(&format!(
1374 "CREATE TABLE {table_name}
1375 FROM SOURCE {source_name}
1376 (REFERENCE {table_name});"
1377 ))
1378 .await
1379 .unwrap();
1380
1381 let table_name = table_name.to_string();
1382 let source_name = source_name.to_string();
1383 (
1384 pg_client,
1385 move |mz_client: &'a Client, pg_client: &'a Client| {
1386 let f: Pin<Box<dyn Future<Output = ()> + 'a>> = Box::pin(async move {
1387 mz_client
1388 .batch_execute(&format!("DROP SOURCE {source_name} CASCADE;"))
1389 .await
1390 .unwrap();
1391 mz_client
1392 .batch_execute("DROP CONNECTION pgconn;")
1393 .await
1394 .unwrap();
1395
1396 let _ = pg_client
1397 .execute(&format!("DROP PUBLICATION {source_name};"), &[])
1398 .await
1399 .unwrap();
1400 let _ = pg_client
1401 .execute(&format!("DROP TABLE {table_name};"), &[])
1402 .await
1403 .unwrap();
1404 });
1405 f
1406 },
1407 )
1408}
1409
1410pub async fn wait_for_pg_table_population(mz_client: &Client, view_name: &str, source_rows: i64) {
1411 let current_isolation = mz_client
1412 .query_one("SHOW transaction_isolation", &[])
1413 .await
1414 .unwrap()
1415 .get::<_, String>(0);
1416 mz_client
1417 .batch_execute("SET transaction_isolation = SERIALIZABLE")
1418 .await
1419 .unwrap();
1420 Retry::default()
1421 .retry_async(|_| async move {
1422 let rows = mz_client
1423 .query_one(&format!("SELECT COUNT(*) FROM {view_name};"), &[])
1424 .await
1425 .unwrap()
1426 .get::<_, i64>(0);
1427 if rows == source_rows {
1428 Ok(())
1429 } else {
1430 Err(format!(
1431 "Waiting for {source_rows} row to be ingested. Currently at {rows}."
1432 ))
1433 }
1434 })
1435 .await
1436 .unwrap();
1437 mz_client
1438 .batch_execute(&format!(
1439 "SET transaction_isolation = '{current_isolation}'"
1440 ))
1441 .await
1442 .unwrap();
1443}
1444
1445pub fn auth_with_ws(
1447 ws: &mut WebSocket<MaybeTlsStream<TcpStream>>,
1448 mut options: BTreeMap<String, String>,
1449) -> Result<Vec<WebSocketResponse>, anyhow::Error> {
1450 if !options.contains_key("welcome_message") {
1451 options.insert("welcome_message".into(), "off".into());
1452 }
1453 auth_with_ws_impl(
1454 ws,
1455 Message::Text(
1456 serde_json::to_string(&WebSocketAuth::Basic {
1457 user: "materialize".into(),
1458 password: "".into(),
1459 options,
1460 })
1461 .unwrap(),
1462 ),
1463 )
1464}
1465
1466pub fn auth_with_ws_impl(
1467 ws: &mut WebSocket<MaybeTlsStream<TcpStream>>,
1468 auth_message: Message,
1469) -> Result<Vec<WebSocketResponse>, anyhow::Error> {
1470 ws.send(auth_message)?;
1471
1472 let mut msgs = Vec::new();
1474 loop {
1475 let resp = ws.read()?;
1476 match resp {
1477 Message::Text(msg) => {
1478 let msg: WebSocketResponse = serde_json::from_str(&msg).unwrap();
1479 match msg {
1480 WebSocketResponse::ReadyForQuery(_) => break,
1481 msg => {
1482 msgs.push(msg);
1483 }
1484 }
1485 }
1486 Message::Ping(_) => continue,
1487 Message::Close(None) => return Err(anyhow!("ws closed after auth")),
1488 Message::Close(Some(close_frame)) => {
1489 return Err(anyhow!("ws closed after auth").context(close_frame));
1490 }
1491 _ => panic!("unexpected response: {:?}", resp),
1492 }
1493 }
1494 Ok(msgs)
1495}
1496
1497pub fn make_header<H: Header>(h: H) -> HeaderMap {
1498 let mut map = HeaderMap::new();
1499 map.typed_insert(h);
1500 map
1501}
1502
1503pub fn make_pg_tls<F>(configure: F) -> MakeTlsConnector
1504where
1505 F: FnOnce(&mut SslConnectorBuilder) -> Result<(), ErrorStack>,
1506{
1507 let mut connector_builder = SslConnector::builder(SslMethod::tls()).unwrap();
1508 let options = connector_builder.options() | SslOptions::NO_TLSV1_3;
1529 connector_builder.set_options(options);
1530 configure(&mut connector_builder).unwrap();
1531 MakeTlsConnector::new(connector_builder.build())
1532}
1533
/// A test certificate authority: a root (self-signed) or intermediate CA certificate, its
/// private key, and a temporary directory where its own and issued certificates are written.
pub struct Ca {
    /// Temporary directory holding `ca.crt` and the cert/key files issued by this CA.
    pub dir: TempDir,
    /// Subject name of this CA's certificate.
    pub name: X509Name,
    /// This CA's certificate.
    pub cert: X509,
    /// This CA's private key, used to sign the certificates it issues.
    pub pkey: PKey<Private>,
}
1541
1542impl Ca {
1543 fn make_ca(name: &str, parent: Option<&Ca>) -> Result<Ca, Box<dyn Error>> {
1544 let dir = tempfile::tempdir()?;
1545 let rsa = Rsa::generate(2048)?;
1546 let pkey = PKey::from_rsa(rsa)?;
1547 let name = {
1548 let mut builder = X509NameBuilder::new()?;
1549 builder.append_entry_by_nid(Nid::COMMONNAME, name)?;
1550 builder.build()
1551 };
1552 let cert = {
1553 let mut builder = X509::builder()?;
1554 builder.set_version(2)?;
1555 builder.set_pubkey(&pkey)?;
1556 builder.set_issuer_name(parent.map(|ca| &ca.name).unwrap_or(&name))?;
1557 builder.set_subject_name(&name)?;
1558 builder.set_not_before(&*Asn1Time::days_from_now(0)?)?;
1559 builder.set_not_after(&*Asn1Time::days_from_now(365)?)?;
1560 builder.append_extension(BasicConstraints::new().critical().ca().build()?)?;
1561 builder.sign(
1562 parent.map(|ca| &ca.pkey).unwrap_or(&pkey),
1563 MessageDigest::sha256(),
1564 )?;
1565 builder.build()
1566 };
1567 fs::write(dir.path().join("ca.crt"), cert.to_pem()?)?;
1568 Ok(Ca {
1569 dir,
1570 name,
1571 cert,
1572 pkey,
1573 })
1574 }
1575
1576 pub fn new_root(name: &str) -> Result<Ca, Box<dyn Error>> {
1578 Ca::make_ca(name, None)
1579 }
1580
1581 pub fn ca_cert_path(&self) -> PathBuf {
1583 self.dir.path().join("ca.crt")
1584 }
1585
1586 pub fn request_ca(&self, name: &str) -> Result<Ca, Box<dyn Error>> {
1588 Ca::make_ca(name, Some(self))
1589 }
1590
1591 pub fn request_client_cert(&self, name: &str) -> Result<(PathBuf, PathBuf), Box<dyn Error>> {
1596 self.request_cert(name, iter::empty())
1597 }
1598
1599 pub fn request_cert<I>(&self, name: &str, ips: I) -> Result<(PathBuf, PathBuf), Box<dyn Error>>
1602 where
1603 I: IntoIterator<Item = IpAddr>,
1604 {
1605 let rsa = Rsa::generate(2048)?;
1606 let pkey = PKey::from_rsa(rsa)?;
1607 let subject_name = {
1608 let mut builder = X509NameBuilder::new()?;
1609 builder.append_entry_by_nid(Nid::COMMONNAME, name)?;
1610 builder.build()
1611 };
1612 let cert = {
1613 let mut builder = X509::builder()?;
1614 builder.set_version(2)?;
1615 builder.set_pubkey(&pkey)?;
1616 builder.set_issuer_name(self.cert.subject_name())?;
1617 builder.set_subject_name(&subject_name)?;
1618 builder.set_not_before(&*Asn1Time::days_from_now(0)?)?;
1619 builder.set_not_after(&*Asn1Time::days_from_now(365)?)?;
1620 for ip in ips {
1621 builder.append_extension(
1622 SubjectAlternativeName::new()
1623 .ip(&ip.to_string())
1624 .build(&builder.x509v3_context(None, None))?,
1625 )?;
1626 }
1627 builder.sign(&self.pkey, MessageDigest::sha256())?;
1628 builder.build()
1629 };
1630 let cert_path = self.dir.path().join(Path::new(name).with_extension("crt"));
1631 let key_path = self.dir.path().join(Path::new(name).with_extension("key"));
1632 fs::write(&cert_path, cert.to_pem()?)?;
1633 fs::write(&key_path, pkey.private_key_to_pem_pkcs8()?)?;
1634 Ok((cert_path, key_path))
1635 }
1636}