mz_environmentd/
test_util.rs

1// Copyright Materialize, Inc. and contributors. All rights reserved.
2//
3// Use of this software is governed by the Business Source License
4// included in the LICENSE file.
5//
6// As of the Change Date specified in that file, in accordance with
7// the Business Source License, use of this software will be governed
8// by the Apache License, Version 2.0.
9
10use std::collections::BTreeMap;
11use std::error::Error;
12use std::future::IntoFuture;
13use std::net::{IpAddr, Ipv4Addr, SocketAddr, TcpStream};
14use std::path::{Path, PathBuf};
15use std::pin::Pin;
16use std::str::FromStr;
17use std::sync::Arc;
18use std::sync::LazyLock;
19use std::time::Duration;
20use std::{env, fs, iter};
21
22use anyhow::anyhow;
23use futures::Future;
24use futures::future::{BoxFuture, LocalBoxFuture};
25use headers::{Header, HeaderMapExt};
26use hyper::http::header::HeaderMap;
27use mz_adapter::TimestampExplanation;
28use mz_adapter_types::bootstrap_builtin_cluster_config::{
29    ANALYTICS_CLUSTER_DEFAULT_REPLICATION_FACTOR, BootstrapBuiltinClusterConfig,
30    CATALOG_SERVER_CLUSTER_DEFAULT_REPLICATION_FACTOR, PROBE_CLUSTER_DEFAULT_REPLICATION_FACTOR,
31    SUPPORT_CLUSTER_DEFAULT_REPLICATION_FACTOR, SYSTEM_CLUSTER_DEFAULT_REPLICATION_FACTOR,
32};
33
34use mz_catalog::config::ClusterReplicaSizeMap;
35use mz_controller::ControllerConfig;
36use mz_dyncfg::ConfigUpdates;
37use mz_license_keys::ValidatedLicenseKey;
38use mz_orchestrator_process::{ProcessOrchestrator, ProcessOrchestratorConfig};
39use mz_orchestrator_tracing::{TracingCliArgs, TracingOrchestrator};
40use mz_ore::metrics::MetricsRegistry;
41use mz_ore::now::{EpochMillis, NowFn, SYSTEM_TIME};
42use mz_ore::retry::Retry;
43use mz_ore::task;
44use mz_ore::tracing::{
45    OpenTelemetryConfig, StderrLogConfig, StderrLogFormat, TracingConfig, TracingGuard,
46    TracingHandle,
47};
48use mz_persist_client::PersistLocation;
49use mz_persist_client::cache::PersistClientCache;
50use mz_persist_client::cfg::{CONSENSUS_CONNECTION_POOL_MAX_SIZE, PersistConfig};
51use mz_persist_client::rpc::PersistGrpcPubSubServer;
52use mz_secrets::SecretsController;
53use mz_server_core::{ReloadTrigger, TlsCertConfig};
54use mz_sql::catalog::EnvironmentId;
55use mz_storage_types::connections::ConnectionContext;
56use mz_tracing::CloneableEnvFilter;
57use openssl::asn1::Asn1Time;
58use openssl::error::ErrorStack;
59use openssl::hash::MessageDigest;
60use openssl::nid::Nid;
61use openssl::pkey::{PKey, Private};
62use openssl::rsa::Rsa;
63use openssl::ssl::{SslConnector, SslConnectorBuilder, SslMethod, SslOptions};
64use openssl::x509::extension::{BasicConstraints, SubjectAlternativeName};
65use openssl::x509::{X509, X509Name, X509NameBuilder};
66use postgres::error::DbError;
67use postgres::tls::{MakeTlsConnect, TlsConnect};
68use postgres::types::{FromSql, Type};
69use postgres::{NoTls, Socket};
70use postgres_openssl::MakeTlsConnector;
71use tempfile::TempDir;
72use tokio::net::TcpListener;
73use tokio::runtime::Runtime;
74use tokio_postgres::config::{Host, SslMode};
75use tokio_postgres::{AsyncMessage, Client};
76use tokio_stream::wrappers::TcpListenerStream;
77use tower_http::cors::AllowOrigin;
78use tracing::Level;
79use tracing_capture::SharedStorage;
80use tracing_subscriber::EnvFilter;
81use tungstenite::stream::MaybeTlsStream;
82use tungstenite::{Message, WebSocket};
83use url::Url;
84
85use crate::{CatalogConfig, FronteggAuthentication, WebSocketAuth, WebSocketResponse};
86
/// Kafka broker address list for tests: the value of the `KAFKA_ADDRS` environment variable, or
/// `localhost:9092` when that variable cannot be read.
pub static KAFKA_ADDRS: LazyLock<String> = LazyLock::new(|| match env::var("KAFKA_ADDRS") {
    Ok(addrs) => addrs,
    Err(_) => String::from("localhost:9092"),
});
89
90/// Entry point for creating and configuring an `environmentd` test harness.
91#[derive(Clone)]
92pub struct TestHarness {
93    data_directory: Option<PathBuf>,
94    tls: Option<TlsCertConfig>,
95    frontegg: Option<FronteggAuthentication>,
96    self_hosted_auth: bool,
97    self_hosted_auth_internal: bool,
98    unsafe_mode: bool,
99    workers: usize,
100    now: NowFn,
101    seed: u32,
102    storage_usage_collection_interval: Duration,
103    storage_usage_retention_period: Option<Duration>,
104    default_cluster_replica_size: String,
105    default_cluster_replication_factor: u32,
106    builtin_system_cluster_config: BootstrapBuiltinClusterConfig,
107    builtin_catalog_server_cluster_config: BootstrapBuiltinClusterConfig,
108    builtin_probe_cluster_config: BootstrapBuiltinClusterConfig,
109    builtin_support_cluster_config: BootstrapBuiltinClusterConfig,
110    builtin_analytics_cluster_config: BootstrapBuiltinClusterConfig,
111
112    propagate_crashes: bool,
113    enable_tracing: bool,
114    // This is currently unrelated to enable_tracing, and is used only to disable orchestrator
115    // tracing.
116    orchestrator_tracing_cli_args: TracingCliArgs,
117    bootstrap_role: Option<String>,
118    deploy_generation: u64,
119    system_parameter_defaults: BTreeMap<String, String>,
120    internal_console_redirect_url: Option<String>,
121    metrics_registry: Option<MetricsRegistry>,
122    code_version: semver::Version,
123    capture: Option<SharedStorage>,
124    pub environment_id: EnvironmentId,
125}
126
127impl Default for TestHarness {
128    fn default() -> TestHarness {
129        TestHarness {
130            data_directory: None,
131            tls: None,
132            frontegg: None,
133            self_hosted_auth: false,
134            self_hosted_auth_internal: false,
135            unsafe_mode: false,
136            workers: 1,
137            now: SYSTEM_TIME.clone(),
138            seed: rand::random(),
139            storage_usage_collection_interval: Duration::from_secs(3600),
140            storage_usage_retention_period: None,
141            default_cluster_replica_size: "1".to_string(),
142            default_cluster_replication_factor: 2,
143            builtin_system_cluster_config: BootstrapBuiltinClusterConfig {
144                size: "1".to_string(),
145                replication_factor: SYSTEM_CLUSTER_DEFAULT_REPLICATION_FACTOR,
146            },
147            builtin_catalog_server_cluster_config: BootstrapBuiltinClusterConfig {
148                size: "1".to_string(),
149                replication_factor: CATALOG_SERVER_CLUSTER_DEFAULT_REPLICATION_FACTOR,
150            },
151            builtin_probe_cluster_config: BootstrapBuiltinClusterConfig {
152                size: "1".to_string(),
153                replication_factor: PROBE_CLUSTER_DEFAULT_REPLICATION_FACTOR,
154            },
155            builtin_support_cluster_config: BootstrapBuiltinClusterConfig {
156                size: "1".to_string(),
157                replication_factor: SUPPORT_CLUSTER_DEFAULT_REPLICATION_FACTOR,
158            },
159            builtin_analytics_cluster_config: BootstrapBuiltinClusterConfig {
160                size: "1".to_string(),
161                replication_factor: ANALYTICS_CLUSTER_DEFAULT_REPLICATION_FACTOR,
162            },
163            propagate_crashes: false,
164            enable_tracing: false,
165            bootstrap_role: Some("materialize".into()),
166            deploy_generation: 0,
167            // This and startup_log_filter below are both (?) needed to suppress clusterd messages.
168            // If we need those in the future, we might need to change both.
169            system_parameter_defaults: BTreeMap::from([(
170                "log_filter".to_string(),
171                "error".to_string(),
172            )]),
173            internal_console_redirect_url: None,
174            metrics_registry: None,
175            orchestrator_tracing_cli_args: TracingCliArgs {
176                startup_log_filter: CloneableEnvFilter::from_str("error").expect("must parse"),
177                ..Default::default()
178            },
179            code_version: crate::BUILD_INFO.semver_version(),
180            environment_id: EnvironmentId::for_tests(),
181            capture: None,
182        }
183    }
184}
185
186impl TestHarness {
187    /// Starts a test [`TestServer`], panicking if the server could not be started.
188    ///
189    /// For cases when startup might fail, see [`TestHarness::try_start`].
190    pub async fn start(self) -> TestServer {
191        self.try_start().await.expect("Failed to start test Server")
192    }
193
194    /// Like [`TestHarness::start`] but can specify a cert reload trigger.
195    pub async fn start_with_trigger(self, tls_reload_certs: ReloadTrigger) -> TestServer {
196        self.try_start_with_trigger(tls_reload_certs)
197            .await
198            .expect("Failed to start test Server")
199    }
200
201    /// Starts a test [`TestServer`], returning an error if the server could not be started.
202    pub async fn try_start(self) -> Result<TestServer, anyhow::Error> {
203        self.try_start_with_trigger(mz_server_core::cert_reload_never_reload())
204            .await
205    }
206
207    /// Like [`TestHarness::try_start`] but can specify a cert reload trigger.
208    pub async fn try_start_with_trigger(
209        self,
210        tls_reload_certs: ReloadTrigger,
211    ) -> Result<TestServer, anyhow::Error> {
212        let listeners = Listeners::new().await?;
213        listeners.serve_with_trigger(self, tls_reload_certs).await
214    }
215
216    /// Starts a runtime and returns a [`TestServerWithRuntime`].
217    pub fn start_blocking(self) -> TestServerWithRuntime {
218        stacker::grow(mz_ore::stack::STACK_SIZE, || {
219            let runtime = Runtime::new().expect("failed to spawn runtime for test");
220            let runtime = Arc::new(runtime);
221            let server = runtime.block_on(self.start());
222            TestServerWithRuntime { runtime, server }
223        })
224    }
225
226    pub fn data_directory(mut self, data_directory: impl Into<PathBuf>) -> Self {
227        self.data_directory = Some(data_directory.into());
228        self
229    }
230
231    pub fn with_tls(mut self, cert_path: impl Into<PathBuf>, key_path: impl Into<PathBuf>) -> Self {
232        self.tls = Some(TlsCertConfig {
233            cert: cert_path.into(),
234            key: key_path.into(),
235        });
236        self
237    }
238
239    pub fn unsafe_mode(mut self) -> Self {
240        self.unsafe_mode = true;
241        self
242    }
243
244    pub fn workers(mut self, workers: usize) -> Self {
245        self.workers = workers;
246        self
247    }
248
249    pub fn with_frontegg(mut self, frontegg: &FronteggAuthentication) -> Self {
250        self.frontegg = Some(frontegg.clone());
251        self
252    }
253
254    pub fn with_self_hosted_auth(mut self, self_hosted_auth: bool) -> Self {
255        self.self_hosted_auth = self_hosted_auth;
256        self
257    }
258
259    pub fn with_self_hosted_auth_internal(mut self, self_hosted_auth_internal: bool) -> Self {
260        self.self_hosted_auth_internal = self_hosted_auth_internal;
261        self
262    }
263
264    pub fn with_now(mut self, now: NowFn) -> Self {
265        self.now = now;
266        self
267    }
268
269    pub fn with_storage_usage_collection_interval(
270        mut self,
271        storage_usage_collection_interval: Duration,
272    ) -> Self {
273        self.storage_usage_collection_interval = storage_usage_collection_interval;
274        self
275    }
276
277    pub fn with_storage_usage_retention_period(
278        mut self,
279        storage_usage_retention_period: Duration,
280    ) -> Self {
281        self.storage_usage_retention_period = Some(storage_usage_retention_period);
282        self
283    }
284
285    pub fn with_default_cluster_replica_size(
286        mut self,
287        default_cluster_replica_size: String,
288    ) -> Self {
289        self.default_cluster_replica_size = default_cluster_replica_size;
290        self
291    }
292
293    pub fn with_builtin_system_cluster_replica_size(
294        mut self,
295        builtin_system_cluster_replica_size: String,
296    ) -> Self {
297        self.builtin_system_cluster_config.size = builtin_system_cluster_replica_size;
298        self
299    }
300
301    pub fn with_builtin_system_cluster_replication_factor(
302        mut self,
303        builtin_system_cluster_replication_factor: u32,
304    ) -> Self {
305        self.builtin_system_cluster_config.replication_factor =
306            builtin_system_cluster_replication_factor;
307        self
308    }
309
310    pub fn with_builtin_catalog_server_cluster_replica_size(
311        mut self,
312        builtin_catalog_server_cluster_replica_size: String,
313    ) -> Self {
314        self.builtin_catalog_server_cluster_config.size =
315            builtin_catalog_server_cluster_replica_size;
316        self
317    }
318
319    pub fn with_propagate_crashes(mut self, propagate_crashes: bool) -> Self {
320        self.propagate_crashes = propagate_crashes;
321        self
322    }
323
324    pub fn with_enable_tracing(mut self, enable_tracing: bool) -> Self {
325        self.enable_tracing = enable_tracing;
326        self
327    }
328
329    pub fn with_bootstrap_role(mut self, bootstrap_role: Option<String>) -> Self {
330        self.bootstrap_role = bootstrap_role;
331        self
332    }
333
334    pub fn with_deploy_generation(mut self, deploy_generation: u64) -> Self {
335        self.deploy_generation = deploy_generation;
336        self
337    }
338
339    pub fn with_system_parameter_default(mut self, param: String, value: String) -> Self {
340        self.system_parameter_defaults.insert(param, value);
341        self
342    }
343
344    pub fn with_internal_console_redirect_url(
345        mut self,
346        internal_console_redirect_url: Option<String>,
347    ) -> Self {
348        self.internal_console_redirect_url = internal_console_redirect_url;
349        self
350    }
351
352    pub fn with_metrics_registry(mut self, registry: MetricsRegistry) -> Self {
353        self.metrics_registry = Some(registry);
354        self
355    }
356
357    pub fn with_code_version(mut self, version: semver::Version) -> Self {
358        self.code_version = version;
359        self
360    }
361
362    pub fn with_capture(mut self, storage: SharedStorage) -> Self {
363        self.capture = Some(storage);
364        self
365    }
366}
367
368pub struct Listeners {
369    pub inner: crate::Listeners,
370}
371
372impl Listeners {
373    pub async fn new() -> Result<Listeners, anyhow::Error> {
374        let inner = crate::Listeners::bind_any_local().await?;
375        Ok(Listeners { inner })
376    }
377
378    pub async fn serve(self, config: TestHarness) -> Result<TestServer, anyhow::Error> {
379        self.serve_with_trigger(config, mz_server_core::cert_reload_never_reload())
380            .await
381    }
382
383    pub async fn serve_with_trigger(
384        self,
385        config: TestHarness,
386        tls_reload_certs: ReloadTrigger,
387    ) -> Result<TestServer, anyhow::Error> {
388        let (data_directory, temp_dir) = match config.data_directory {
389            None => {
390                // If no data directory is provided, we create a temporary
391                // directory. The temporary directory is cleaned up when the
392                // `TempDir` is dropped, so we keep it alive until the `Server` is
393                // dropped.
394                let temp_dir = tempfile::tempdir()?;
395                (temp_dir.path().to_path_buf(), Some(temp_dir))
396            }
397            Some(data_directory) => (data_directory, None),
398        };
399        let scratch_dir = tempfile::tempdir()?;
400        let (consensus_uri, timestamp_oracle_url) = {
401            let seed = config.seed;
402            let cockroach_url = env::var("COCKROACH_URL")
403                .map_err(|_| anyhow!("COCKROACH_URL environment variable is not set"))?;
404            let (client, conn) = tokio_postgres::connect(&cockroach_url, NoTls).await?;
405            mz_ore::task::spawn(|| "startup-postgres-conn", async move {
406                if let Err(err) = conn.await {
407                    panic!("connection error: {}", err);
408                };
409            });
410            client
411                .batch_execute(&format!(
412                    "CREATE SCHEMA IF NOT EXISTS consensus_{seed};
413                    CREATE SCHEMA IF NOT EXISTS tsoracle_{seed};"
414                ))
415                .await?;
416            (
417                format!("{cockroach_url}?options=--search_path=consensus_{seed}")
418                    .parse()
419                    .expect("invalid consensus URI"),
420                format!("{cockroach_url}?options=--search_path=tsoracle_{seed}")
421                    .parse()
422                    .expect("invalid timestamp oracle URI"),
423            )
424        };
425        let metrics_registry = config.metrics_registry.unwrap_or_else(MetricsRegistry::new);
426        let orchestrator = ProcessOrchestrator::new(ProcessOrchestratorConfig {
427            image_dir: env::current_exe()?
428                .parent()
429                .unwrap()
430                .parent()
431                .unwrap()
432                .to_path_buf(),
433            suppress_output: false,
434            environment_id: config.environment_id.to_string(),
435            secrets_dir: data_directory.join("secrets"),
436            command_wrapper: vec![],
437            propagate_crashes: config.propagate_crashes,
438            tcp_proxy: None,
439            scratch_directory: scratch_dir.path().to_path_buf(),
440        })
441        .await?;
442        let orchestrator = Arc::new(orchestrator);
443        // Messing with the clock causes persist to expire leases, causing hangs and
444        // panics. Is it possible/desirable to put this back somehow?
445        let persist_now = SYSTEM_TIME.clone();
446        let dyncfgs = mz_dyncfgs::all_dyncfgs();
447
448        let mut updates = ConfigUpdates::default();
449        // Tune down the number of connections to make this all work a little easier
450        // with local postgres.
451        updates.add(&CONSENSUS_CONNECTION_POOL_MAX_SIZE, 1);
452        updates.apply(&dyncfgs);
453
454        let mut persist_cfg = PersistConfig::new(&crate::BUILD_INFO, persist_now.clone(), dyncfgs);
455        persist_cfg.build_version = config.code_version;
456        // Stress persist more by writing rollups frequently
457        persist_cfg.set_rollup_threshold(5);
458
459        let persist_pubsub_server = PersistGrpcPubSubServer::new(&persist_cfg, &metrics_registry);
460        let persist_pubsub_client = persist_pubsub_server.new_same_process_connection();
461        let persist_pubsub_tcp_listener =
462            TcpListener::bind(SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), 0))
463                .await
464                .expect("pubsub addr binding");
465        let persist_pubsub_server_port = persist_pubsub_tcp_listener
466            .local_addr()
467            .expect("pubsub addr has local addr")
468            .port();
469
470        // Spawn the persist pub-sub server.
471        mz_ore::task::spawn(|| "persist_pubsub_server", async move {
472            persist_pubsub_server
473                .serve_with_stream(TcpListenerStream::new(persist_pubsub_tcp_listener))
474                .await
475                .expect("success")
476        });
477        let persist_clients =
478            PersistClientCache::new(persist_cfg, &metrics_registry, |_, _| persist_pubsub_client);
479        let persist_clients = Arc::new(persist_clients);
480
481        let secrets_controller = Arc::clone(&orchestrator);
482        let connection_context = ConnectionContext::for_tests(orchestrator.reader());
483        let orchestrator = Arc::new(TracingOrchestrator::new(
484            orchestrator,
485            config.orchestrator_tracing_cli_args,
486        ));
487        let (tracing_handle, tracing_guard) = if config.enable_tracing {
488            let config = TracingConfig::<fn(&tracing::Metadata) -> sentry_tracing::EventFilter> {
489                service_name: "environmentd",
490                stderr_log: StderrLogConfig {
491                    format: StderrLogFormat::Json,
492                    filter: EnvFilter::default(),
493                },
494                opentelemetry: Some(OpenTelemetryConfig {
495                    endpoint: "http://fake_address_for_testing:8080".to_string(),
496                    headers: http::HeaderMap::new(),
497                    filter: EnvFilter::default().add_directive(Level::DEBUG.into()),
498                    resource: opentelemetry_sdk::resource::Resource::default(),
499                    max_batch_queue_size: 2048,
500                    max_export_batch_size: 512,
501                    max_concurrent_exports: 1,
502                    batch_scheduled_delay: Duration::from_millis(5000),
503                    max_export_timeout: Duration::from_secs(30),
504                }),
505                tokio_console: None,
506                sentry: None,
507                build_version: crate::BUILD_INFO.version,
508                build_sha: crate::BUILD_INFO.sha,
509                registry: metrics_registry.clone(),
510                capture: config.capture,
511            };
512            let (tracing_handle, tracing_guard) = mz_ore::tracing::configure(config).await?;
513            (tracing_handle, Some(tracing_guard))
514        } else {
515            (TracingHandle::disabled(), None)
516        };
517        let host_name = format!("localhost:{}", self.inner.http_local_addr().port());
518        let catalog_config = CatalogConfig {
519            persist_clients: Arc::clone(&persist_clients),
520            metrics: Arc::new(mz_catalog::durable::Metrics::new(&MetricsRegistry::new())),
521        };
522
523        let inner = self
524            .inner
525            .serve(crate::Config {
526                catalog_config,
527                timestamp_oracle_url: Some(timestamp_oracle_url),
528                controller: ControllerConfig {
529                    build_info: &crate::BUILD_INFO,
530                    orchestrator,
531                    clusterd_image: "clusterd".into(),
532                    init_container_image: None,
533                    deploy_generation: config.deploy_generation,
534                    persist_location: PersistLocation {
535                        blob_uri: format!("file://{}/persist/blob", data_directory.display())
536                            .parse()
537                            .expect("invalid blob URI"),
538                        consensus_uri,
539                    },
540                    persist_clients,
541                    now: config.now.clone(),
542                    metrics_registry: metrics_registry.clone(),
543                    persist_pubsub_url: format!("http://localhost:{}", persist_pubsub_server_port),
544                    secrets_args: mz_service::secrets::SecretsReaderCliArgs {
545                        secrets_reader: mz_service::secrets::SecretsControllerKind::LocalFile,
546                        secrets_reader_local_file_dir: Some(data_directory.join("secrets")),
547                        secrets_reader_kubernetes_context: None,
548                        secrets_reader_aws_prefix: None,
549                        secrets_reader_name_prefix: None,
550                    },
551                    connection_context,
552                },
553                secrets_controller,
554                cloud_resource_controller: None,
555                tls: config.tls,
556                frontegg: config.frontegg,
557                self_hosted_auth: config.self_hosted_auth,
558                self_hosted_auth_internal: config.self_hosted_auth_internal,
559                unsafe_mode: config.unsafe_mode,
560                all_features: false,
561                metrics_registry: metrics_registry.clone(),
562                now: config.now,
563                environment_id: config.environment_id,
564                cors_allowed_origin: AllowOrigin::list([]),
565                cluster_replica_sizes: ClusterReplicaSizeMap::for_tests(),
566                bootstrap_default_cluster_replica_size: config.default_cluster_replica_size,
567                bootstrap_default_cluster_replication_factor: config
568                    .default_cluster_replication_factor,
569                bootstrap_builtin_system_cluster_config: config.builtin_system_cluster_config,
570                bootstrap_builtin_catalog_server_cluster_config: config
571                    .builtin_catalog_server_cluster_config,
572                bootstrap_builtin_probe_cluster_config: config.builtin_probe_cluster_config,
573                bootstrap_builtin_support_cluster_config: config.builtin_support_cluster_config,
574                bootstrap_builtin_analytics_cluster_config: config.builtin_analytics_cluster_config,
575                system_parameter_defaults: config.system_parameter_defaults,
576                availability_zones: Default::default(),
577                tracing_handle,
578                storage_usage_collection_interval: config.storage_usage_collection_interval,
579                storage_usage_retention_period: config.storage_usage_retention_period,
580                segment_api_key: None,
581                segment_client_side: false,
582                egress_addresses: vec![],
583                aws_account_id: None,
584                aws_privatelink_availability_zones: None,
585                launchdarkly_sdk_key: None,
586                launchdarkly_key_map: Default::default(),
587                config_sync_timeout: Duration::from_secs(30),
588                config_sync_loop_interval: None,
589                bootstrap_role: config.bootstrap_role,
590                http_host_name: Some(host_name),
591                internal_console_redirect_url: config.internal_console_redirect_url,
592                tls_reload_certs,
593                helm_chart_version: None,
594                license_key: ValidatedLicenseKey::for_tests(),
595            })
596            .await?;
597
598        Ok(TestServer {
599            inner,
600            metrics_registry,
601            _temp_dir: temp_dir,
602            _tracing_guard: tracing_guard,
603        })
604    }
605}
606
607/// A running instance of `environmentd`.
608pub struct TestServer {
609    pub inner: crate::Server,
610    pub metrics_registry: MetricsRegistry,
611    _temp_dir: Option<TempDir>,
612    _tracing_guard: Option<TracingGuard>,
613}
614
615impl TestServer {
616    pub fn connect(&self) -> ConnectBuilder<'_, postgres::NoTls, NoHandle> {
617        ConnectBuilder::new(self).no_tls()
618    }
619
620    pub async fn enable_feature_flags(&self, flags: &[&'static str]) {
621        let internal_client = self.connect().internal().await.unwrap();
622
623        for flag in flags {
624            internal_client
625                .batch_execute(&format!("ALTER SYSTEM SET {} = true;", flag))
626                .await
627                .unwrap();
628        }
629    }
630
631    pub fn ws_addr(&self) -> Url {
632        Url::parse(&format!(
633            "ws://{}/api/experimental/sql",
634            self.inner.http_local_addr()
635        ))
636        .unwrap()
637    }
638
639    pub fn internal_ws_addr(&self) -> Url {
640        Url::parse(&format!(
641            "ws://{}/api/experimental/sql",
642            self.inner.internal_http_local_addr()
643        ))
644        .unwrap()
645    }
646}
647
648/// A builder struct to configure a pgwire connection to a running [`TestServer`].
649///
650/// You can create this struct, and thus open a pgwire connection, using [`TestServer::connect`].
651pub struct ConnectBuilder<'s, T, H> {
652    /// A running `environmentd` test server.
653    server: &'s TestServer,
654
655    /// Postgres configuration for connecting to the test server.
656    pg_config: tokio_postgres::Config,
657    /// Port to use when connecting to the test server.
658    port: u16,
659    /// Tls settings to use.
660    tls: T,
661
662    /// Callback that gets invoked for every notice we receive.
663    notice_callback: Option<Box<dyn FnMut(tokio_postgres::error::DbError) + Send + 'static>>,
664
665    /// Type variable for whether or not we include the handle for the spawned [`tokio::task`].
666    _with_handle: H,
667}
668
669impl<'s> ConnectBuilder<'s, (), NoHandle> {
670    fn new(server: &'s TestServer) -> Self {
671        let mut pg_config = tokio_postgres::Config::new();
672        pg_config
673            .host(&Ipv4Addr::LOCALHOST.to_string())
674            .user("materialize")
675            .options("--welcome_message=off")
676            .application_name("environmentd_test_framework");
677
678        ConnectBuilder {
679            server,
680            pg_config,
681            port: server.inner.sql_local_addr().port(),
682            tls: (),
683            notice_callback: None,
684            _with_handle: NoHandle,
685        }
686    }
687}
688
689impl<'s, T, H> ConnectBuilder<'s, T, H> {
690    /// Create a pgwire connection without using TLS.
691    ///
692    /// Note: this is the default for all connections.
693    pub fn no_tls(self) -> ConnectBuilder<'s, postgres::NoTls, H> {
694        ConnectBuilder {
695            server: self.server,
696            pg_config: self.pg_config,
697            port: self.port,
698            tls: postgres::NoTls,
699            notice_callback: self.notice_callback,
700            _with_handle: self._with_handle,
701        }
702    }
703
704    /// Create a pgwire connection with TLS.
705    pub fn with_tls<Tls>(self, tls: Tls) -> ConnectBuilder<'s, Tls, H>
706    where
707        Tls: MakeTlsConnect<Socket> + Send + 'static,
708        Tls::TlsConnect: Send,
709        Tls::Stream: Send,
710        <Tls::TlsConnect as TlsConnect<Socket>>::Future: Send,
711    {
712        ConnectBuilder {
713            server: self.server,
714            pg_config: self.pg_config,
715            port: self.port,
716            tls,
717            notice_callback: self.notice_callback,
718            _with_handle: self._with_handle,
719        }
720    }
721
722    /// Create a [`ConnectBuilder`] using the provided [`tokio_postgres::Config`].
723    pub fn with_config(mut self, pg_config: tokio_postgres::Config) -> Self {
724        self.pg_config = pg_config;
725        self
726    }
727
728    /// Set the [`SslMode`] to be used with the resulting connection.
729    pub fn ssl_mode(mut self, mode: SslMode) -> Self {
730        self.pg_config.ssl_mode(mode);
731        self
732    }
733
734    /// Set the user for the pgwire connection.
735    pub fn user(mut self, user: &str) -> Self {
736        self.pg_config.user(user);
737        self
738    }
739
740    /// Set the password for the pgwire connection.
741    pub fn password(mut self, password: &str) -> Self {
742        self.pg_config.password(password);
743        self
744    }
745
746    /// Set the application name for the pgwire connection.
747    pub fn application_name(mut self, application_name: &str) -> Self {
748        self.pg_config.application_name(application_name);
749        self
750    }
751
752    /// Set the database name for the pgwire connection.
753    pub fn dbname(mut self, dbname: &str) -> Self {
754        self.pg_config.dbname(dbname);
755        self
756    }
757
758    /// Set the options for the pgwire connection.
759    pub fn options(mut self, options: &str) -> Self {
760        self.pg_config.options(options);
761        self
762    }
763
764    /// Configures this [`ConnectBuilder`] to connect to the __internal__ SQL port of the running
765    /// [`TestServer`].
766    ///
767    /// For example, this will change the port we connect to, and the user we connect as.
768    pub fn internal(mut self) -> Self {
769        self.port = self.server.inner.internal_sql_local_addr().port();
770        self.pg_config.user(mz_sql::session::user::SYSTEM_USER_NAME);
771        self
772    }
773
774    /// Configures this [`ConnectBuilder`] to connect to the __balancer__ SQL port of the running
775    /// [`TestServer`].
776    ///
777    /// For example, this will change the port we connect to, and the user we connect as.
778    pub fn balancer(mut self) -> Self {
779        self.port = self.server.inner.sql_local_addr().port();
780        self.pg_config.user("materialize");
781        self
782    }
783
784    /// Sets a callback for any database notices that are received from the [`TestServer`].
785    pub fn notice_callback(self, callback: impl FnMut(DbError) + Send + 'static) -> Self {
786        ConnectBuilder {
787            notice_callback: Some(Box::new(callback)),
788            ..self
789        }
790    }
791
792    /// Configures this [`ConnectBuilder`] to return the [`mz_ore::task::JoinHandle`] that is
793    /// polling the underlying postgres connection, associated with the returned client.
794    pub fn with_handle(self) -> ConnectBuilder<'s, T, WithHandle> {
795        ConnectBuilder {
796            server: self.server,
797            pg_config: self.pg_config,
798            port: self.port,
799            tls: self.tls,
800            notice_callback: self.notice_callback,
801            _with_handle: WithHandle,
802        }
803    }
804
805    /// Returns the [`tokio_postgres::Config`] that will be used to connect.
806    pub fn as_pg_config(&self) -> &tokio_postgres::Config {
807        &self.pg_config
808    }
809}
810
811/// This trait enables us to either include or omit the [`mz_ore::task::JoinHandle`] in the result
812/// of a client connection.
813pub trait IncludeHandle: Send {
814    type Output;
815    fn transform_result(
816        client: tokio_postgres::Client,
817        handle: mz_ore::task::JoinHandle<()>,
818    ) -> Self::Output;
819}
820
821/// Type parameter that denotes we __will not__ return the [`mz_ore::task::JoinHandle`] in the
822/// result of a [`ConnectBuilder`].
823pub struct NoHandle;
824impl IncludeHandle for NoHandle {
825    type Output = tokio_postgres::Client;
826    fn transform_result(
827        client: tokio_postgres::Client,
828        _handle: mz_ore::task::JoinHandle<()>,
829    ) -> Self::Output {
830        client
831    }
832}
833
834/// Type parameter that denotes we __will__ return the [`mz_ore::task::JoinHandle`] in the result of
835/// a [`ConnectBuilder`].
836pub struct WithHandle;
837impl IncludeHandle for WithHandle {
838    type Output = (tokio_postgres::Client, mz_ore::task::JoinHandle<()>);
839    fn transform_result(
840        client: tokio_postgres::Client,
841        handle: mz_ore::task::JoinHandle<()>,
842    ) -> Self::Output {
843        (client, handle)
844    }
845}
846
847impl<'s, T, H> IntoFuture for ConnectBuilder<'s, T, H>
848where
849    T: MakeTlsConnect<Socket> + Send + 'static,
850    T::TlsConnect: Send,
851    T::Stream: Send,
852    <T::TlsConnect as TlsConnect<Socket>>::Future: Send,
853    H: IncludeHandle,
854{
855    type Output = Result<H::Output, postgres::Error>;
856    type IntoFuture = BoxFuture<'static, Self::Output>;
857
858    fn into_future(mut self) -> Self::IntoFuture {
859        Box::pin(async move {
860            assert!(
861                self.pg_config.get_ports().is_empty(),
862                "specifying multiple ports is not supported"
863            );
864            self.pg_config.port(self.port);
865
866            let (client, mut conn) = self.pg_config.connect(self.tls).await?;
867            let mut notice_callback = self.notice_callback.take();
868
869            let handle = task::spawn(|| "connect", async move {
870                while let Some(msg) = std::future::poll_fn(|cx| conn.poll_message(cx)).await {
871                    match msg {
872                        Ok(AsyncMessage::Notice(notice)) => {
873                            if let Some(callback) = notice_callback.as_mut() {
874                                callback(notice);
875                            }
876                        }
877                        Ok(msg) => {
878                            tracing::debug!(?msg, "Dropping message from database");
879                        }
880                        Err(e) => {
881                            // tokio_postgres::Connection docs say:
882                            // > Return values of None or Some(Err(_)) are “terminal”; callers
883                            // > should not invoke this method again after receiving one of those
884                            // > values.
885                            tracing::info!("connection error: {e}");
886                            break;
887                        }
888                    }
889                }
890                tracing::info!("connection closed");
891            });
892
893            let output = H::transform_result(client, handle);
894            Ok(output)
895        })
896    }
897}
898
899/// A running instance of `environmentd`, that exposes blocking/synchronous test helpers.
900///
901/// Note: Ideally you should use a [`TestServer`] which relies on an external runtime, e.g. the
902/// [`tokio::test`] macro. This struct exists so we can incrementally migrate our existing tests.
903pub struct TestServerWithRuntime {
904    server: TestServer,
905    runtime: Arc<Runtime>,
906}
907
908impl TestServerWithRuntime {
909    /// Returns the [`Runtime`] owned by this [`TestServerWithRuntime`].
910    ///
911    /// Can be used to spawn async tasks.
912    pub fn runtime(&self) -> &Arc<Runtime> {
913        &self.runtime
914    }
915
916    /// Returns a referece to the inner running `environmentd` [`crate::Server`]`.
917    pub fn inner(&self) -> &crate::Server {
918        &self.server.inner
919    }
920
921    /// Connect to the __public__ SQL port of the running `environmentd` server.
922    pub fn connect<T>(&self, tls: T) -> Result<postgres::Client, postgres::Error>
923    where
924        T: MakeTlsConnect<Socket> + Send + 'static,
925        T::TlsConnect: Send,
926        T::Stream: Send,
927        <T::TlsConnect as TlsConnect<Socket>>::Future: Send,
928    {
929        self.pg_config().connect(tls)
930    }
931
932    /// Connect to the __internal__ SQL port of the running `environmentd` server.
933    pub fn connect_internal<T>(&self, tls: T) -> Result<postgres::Client, anyhow::Error>
934    where
935        T: MakeTlsConnect<Socket> + Send + 'static,
936        T::TlsConnect: Send,
937        T::Stream: Send,
938        <T::TlsConnect as TlsConnect<Socket>>::Future: Send,
939    {
940        Ok(self.pg_config_internal().connect(tls)?)
941    }
942
943    /// Enable LaunchDarkly feature flags.
944    pub fn enable_feature_flags(&self, flags: &[&'static str]) {
945        let mut internal_client = self.connect_internal(postgres::NoTls).unwrap();
946
947        for flag in flags {
948            internal_client
949                .batch_execute(&format!("ALTER SYSTEM SET {} = true;", flag))
950                .unwrap();
951        }
952    }
953
954    /// Return a [`postgres::Config`] for connecting to the __public__ SQL port of the running
955    /// `environmentd` server.
956    pub fn pg_config(&self) -> postgres::Config {
957        let local_addr = self.server.inner.sql_local_addr();
958        let mut config = postgres::Config::new();
959        config
960            .host(&Ipv4Addr::LOCALHOST.to_string())
961            .port(local_addr.port())
962            .user("materialize")
963            .options("--welcome_message=off");
964        config
965    }
966
967    /// Return a [`postgres::Config`] for connecting to the __internal__ SQL port of the running
968    /// `environmentd` server.
969    pub fn pg_config_internal(&self) -> postgres::Config {
970        let local_addr = self.server.inner.internal_sql_local_addr();
971        let mut config = postgres::Config::new();
972        config
973            .host(&Ipv4Addr::LOCALHOST.to_string())
974            .port(local_addr.port())
975            .user("mz_system")
976            .options("--welcome_message=off");
977        config
978    }
979
980    /// Return a [`postgres::Config`] for connecting to the __balancer__ SQL port of the running
981    /// `environmentd` server.
982    pub fn pg_config_balancer(&self) -> postgres::Config {
983        let local_addr = self.server.inner.sql_local_addr();
984        let mut config = postgres::Config::new();
985        config
986            .host(&Ipv4Addr::LOCALHOST.to_string())
987            .port(local_addr.port())
988            .user("materialize")
989            .options("--welcome_message=off")
990            .ssl_mode(tokio_postgres::config::SslMode::Disable);
991        config
992    }
993
994    pub fn ws_addr(&self) -> Url {
995        self.server.ws_addr()
996    }
997
998    pub fn internal_ws_addr(&self) -> Url {
999        self.server.internal_ws_addr()
1000    }
1001}
1002
1003#[derive(Debug, Clone, Copy, Eq, PartialEq, Ord, PartialOrd)]
1004pub struct MzTimestamp(pub u64);
1005
1006impl<'a> FromSql<'a> for MzTimestamp {
1007    fn from_sql(ty: &Type, raw: &'a [u8]) -> Result<MzTimestamp, Box<dyn Error + Sync + Send>> {
1008        let n = mz_pgrepr::Numeric::from_sql(ty, raw)?;
1009        Ok(MzTimestamp(u64::try_from(n.0.0)?))
1010    }
1011
1012    fn accepts(ty: &Type) -> bool {
1013        mz_pgrepr::Numeric::accepts(ty)
1014    }
1015}
1016
1017pub trait PostgresErrorExt {
1018    fn unwrap_db_error(self) -> DbError;
1019}
1020
1021impl PostgresErrorExt for postgres::Error {
1022    fn unwrap_db_error(self) -> DbError {
1023        match self.source().and_then(|e| e.downcast_ref::<DbError>()) {
1024            Some(e) => e.clone(),
1025            None => panic!("expected DbError, but got: {:?}", self),
1026        }
1027    }
1028}
1029
1030impl<T, E> PostgresErrorExt for Result<T, E>
1031where
1032    E: PostgresErrorExt,
1033{
1034    fn unwrap_db_error(self) -> DbError {
1035        match self {
1036            Ok(_) => panic!("expected Err(DbError), but got Ok(_)"),
1037            Err(e) => e.unwrap_db_error(),
1038        }
1039    }
1040}
1041
1042/// Group commit will block writes until the current time has advanced. This can make
1043/// performing inserts while using deterministic time difficult. This is a helper
1044/// method to perform writes and advance the current time.
1045pub async fn insert_with_deterministic_timestamps(
1046    table: &'static str,
1047    values: &'static str,
1048    server: &TestServer,
1049    now: Arc<std::sync::Mutex<EpochMillis>>,
1050) -> Result<(), Box<dyn Error>> {
1051    let client_write = server.connect().await?;
1052    let client_read = server.connect().await?;
1053
1054    let mut current_ts = get_explain_timestamp(table, &client_read).await;
1055
1056    let insert_query = format!("INSERT INTO {table} VALUES {values}");
1057
1058    let write_future = client_write.execute(&insert_query, &[]);
1059    let timestamp_interval = tokio::time::interval(Duration::from_millis(1));
1060
1061    let mut write_future = std::pin::pin!(write_future);
1062    let mut timestamp_interval = std::pin::pin!(timestamp_interval);
1063
1064    // Keep increasing `now` until the write has executed succeed. Table advancements may
1065    // have increased the global timestamp by an unknown amount.
1066    loop {
1067        tokio::select! {
1068            _ = (&mut write_future) => return Ok(()),
1069            _ = timestamp_interval.tick() => {
1070                current_ts += 1;
1071                *now.lock().expect("lock poisoned") = current_ts;
1072            }
1073        };
1074    }
1075}
1076
1077pub async fn get_explain_timestamp(from_suffix: &str, client: &Client) -> EpochMillis {
1078    try_get_explain_timestamp(from_suffix, client)
1079        .await
1080        .unwrap()
1081}
1082
1083pub async fn try_get_explain_timestamp(
1084    from_suffix: &str,
1085    client: &Client,
1086) -> Result<EpochMillis, anyhow::Error> {
1087    let det = get_explain_timestamp_determination(from_suffix, client).await?;
1088    let ts = det.determination.timestamp_context.timestamp_or_default();
1089    Ok(ts.into())
1090}
1091
1092pub async fn get_explain_timestamp_determination(
1093    from_suffix: &str,
1094    client: &Client,
1095) -> Result<TimestampExplanation<mz_repr::Timestamp>, anyhow::Error> {
1096    let row = client
1097        .query_one(
1098            &format!("EXPLAIN TIMESTAMP AS JSON FOR SELECT * FROM {from_suffix}"),
1099            &[],
1100        )
1101        .await?;
1102    let explain: String = row.get(0);
1103    Ok(serde_json::from_str(&explain).unwrap())
1104}
1105
1106/// Helper function to create a Postgres source.
1107///
1108/// IMPORTANT: Make sure to call closure that is returned at the end of the test to clean up
1109/// Postgres state.
1110///
1111/// WARNING: If multiple tests use this, and the tests are run in parallel, then make sure the test
1112/// use different postgres tables.
1113pub async fn create_postgres_source_with_table<'a>(
1114    server: &TestServer,
1115    mz_client: &Client,
1116    table_name: &str,
1117    table_schema: &str,
1118    source_name: &str,
1119) -> (
1120    Client,
1121    impl FnOnce(&'a Client, &'a Client) -> LocalBoxFuture<'a, ()>,
1122) {
1123    server
1124        .enable_feature_flags(&["enable_create_table_from_source"])
1125        .await;
1126
1127    let postgres_url = env::var("POSTGRES_URL")
1128        .map_err(|_| anyhow!("POSTGRES_URL environment variable is not set"))
1129        .unwrap();
1130
1131    let (pg_client, connection) = tokio_postgres::connect(&postgres_url, postgres::NoTls)
1132        .await
1133        .unwrap();
1134
1135    let pg_config: tokio_postgres::Config = postgres_url.parse().unwrap();
1136    let user = pg_config.get_user().unwrap_or("postgres");
1137    let db_name = pg_config.get_dbname().unwrap_or(user);
1138    let ports = pg_config.get_ports();
1139    let port = if ports.is_empty() { 5432 } else { ports[0] };
1140    let hosts = pg_config.get_hosts();
1141    let host = if hosts.is_empty() {
1142        "localhost".to_string()
1143    } else {
1144        match &hosts[0] {
1145            Host::Tcp(host) => host.to_string(),
1146            Host::Unix(host) => host.to_str().unwrap().to_string(),
1147        }
1148    };
1149    let password = pg_config.get_password();
1150
1151    mz_ore::task::spawn(|| "postgres-source-connection", async move {
1152        if let Err(e) = connection.await {
1153            panic!("connection error: {}", e);
1154        }
1155    });
1156
1157    // Create table in Postgres with publication.
1158    let _ = pg_client
1159        .execute(&format!("DROP TABLE IF EXISTS {table_name};"), &[])
1160        .await
1161        .unwrap();
1162    let _ = pg_client
1163        .execute(&format!("DROP PUBLICATION IF EXISTS {source_name};"), &[])
1164        .await
1165        .unwrap();
1166    let _ = pg_client
1167        .execute(&format!("CREATE TABLE {table_name} {table_schema};"), &[])
1168        .await
1169        .unwrap();
1170    let _ = pg_client
1171        .execute(
1172            &format!("ALTER TABLE {table_name} REPLICA IDENTITY FULL;"),
1173            &[],
1174        )
1175        .await
1176        .unwrap();
1177    let _ = pg_client
1178        .execute(
1179            &format!("CREATE PUBLICATION {source_name} FOR TABLE {table_name};"),
1180            &[],
1181        )
1182        .await
1183        .unwrap();
1184
1185    // Create postgres source in Materialize.
1186    let mut connection_str = format!("HOST '{host}', PORT {port}, USER {user}, DATABASE {db_name}");
1187    if let Some(password) = password {
1188        let password = std::str::from_utf8(password).unwrap();
1189        mz_client
1190            .batch_execute(&format!("CREATE SECRET s AS '{password}'"))
1191            .await
1192            .unwrap();
1193        connection_str = format!("{connection_str}, PASSWORD SECRET s");
1194    }
1195    mz_client
1196        .batch_execute(&format!(
1197            "CREATE CONNECTION pgconn TO POSTGRES ({connection_str})"
1198        ))
1199        .await
1200        .unwrap();
1201    mz_client
1202        .batch_execute(&format!(
1203            "CREATE SOURCE {source_name}
1204            FROM POSTGRES
1205            CONNECTION pgconn
1206            (PUBLICATION '{source_name}')"
1207        ))
1208        .await
1209        .unwrap();
1210    mz_client
1211        .batch_execute(&format!(
1212            "CREATE TABLE {table_name}
1213            FROM SOURCE {source_name}
1214            (REFERENCE {table_name});"
1215        ))
1216        .await
1217        .unwrap();
1218
1219    let table_name = table_name.to_string();
1220    let source_name = source_name.to_string();
1221    (
1222        pg_client,
1223        move |mz_client: &'a Client, pg_client: &'a Client| {
1224            let f: Pin<Box<dyn Future<Output = ()> + 'a>> = Box::pin(async move {
1225                mz_client
1226                    .batch_execute(&format!("DROP SOURCE {source_name} CASCADE;"))
1227                    .await
1228                    .unwrap();
1229                mz_client
1230                    .batch_execute("DROP CONNECTION pgconn;")
1231                    .await
1232                    .unwrap();
1233
1234                let _ = pg_client
1235                    .execute(&format!("DROP PUBLICATION {source_name};"), &[])
1236                    .await
1237                    .unwrap();
1238                let _ = pg_client
1239                    .execute(&format!("DROP TABLE {table_name};"), &[])
1240                    .await
1241                    .unwrap();
1242            });
1243            f
1244        },
1245    )
1246}
1247
1248pub async fn wait_for_pg_table_population(mz_client: &Client, view_name: &str, source_rows: i64) {
1249    let current_isolation = mz_client
1250        .query_one("SHOW transaction_isolation", &[])
1251        .await
1252        .unwrap()
1253        .get::<_, String>(0);
1254    mz_client
1255        .batch_execute("SET transaction_isolation = SERIALIZABLE")
1256        .await
1257        .unwrap();
1258    Retry::default()
1259        .retry_async(|_| async move {
1260            let rows = mz_client
1261                .query_one(&format!("SELECT COUNT(*) FROM {view_name};"), &[])
1262                .await
1263                .unwrap()
1264                .get::<_, i64>(0);
1265            if rows == source_rows {
1266                Ok(())
1267            } else {
1268                Err(format!(
1269                    "Waiting for {source_rows} row to be ingested. Currently at {rows}."
1270                ))
1271            }
1272        })
1273        .await
1274        .unwrap();
1275    mz_client
1276        .batch_execute(&format!(
1277            "SET transaction_isolation = '{current_isolation}'"
1278        ))
1279        .await
1280        .unwrap();
1281}
1282
1283// Initializes a websocket connection. Returns the init messages before the initial ReadyForQuery.
1284pub fn auth_with_ws(
1285    ws: &mut WebSocket<MaybeTlsStream<TcpStream>>,
1286    mut options: BTreeMap<String, String>,
1287) -> Result<Vec<WebSocketResponse>, anyhow::Error> {
1288    if !options.contains_key("welcome_message") {
1289        options.insert("welcome_message".into(), "off".into());
1290    }
1291    auth_with_ws_impl(
1292        ws,
1293        Message::Text(
1294            serde_json::to_string(&WebSocketAuth::Basic {
1295                user: "materialize".into(),
1296                password: "".into(),
1297                options,
1298            })
1299            .unwrap(),
1300        ),
1301    )
1302}
1303
1304pub fn auth_with_ws_impl(
1305    ws: &mut WebSocket<MaybeTlsStream<TcpStream>>,
1306    auth_message: Message,
1307) -> Result<Vec<WebSocketResponse>, anyhow::Error> {
1308    ws.send(auth_message)?;
1309
1310    // Wait for initial ready response.
1311    let mut msgs = Vec::new();
1312    loop {
1313        let resp = ws.read()?;
1314        match resp {
1315            Message::Text(msg) => {
1316                let msg: WebSocketResponse = serde_json::from_str(&msg).unwrap();
1317                match msg {
1318                    WebSocketResponse::ReadyForQuery(_) => break,
1319                    msg => {
1320                        msgs.push(msg);
1321                    }
1322                }
1323            }
1324            Message::Ping(_) => continue,
1325            Message::Close(None) => return Err(anyhow!("ws closed after auth")),
1326            Message::Close(Some(close_frame)) => {
1327                return Err(anyhow!("ws closed after auth").context(close_frame));
1328            }
1329            _ => panic!("unexpected response: {:?}", resp),
1330        }
1331    }
1332    Ok(msgs)
1333}
1334
1335pub fn make_header<H: Header>(h: H) -> HeaderMap {
1336    let mut map = HeaderMap::new();
1337    map.typed_insert(h);
1338    map
1339}
1340
1341pub fn make_pg_tls<F>(configure: F) -> MakeTlsConnector
1342where
1343    F: FnOnce(&mut SslConnectorBuilder) -> Result<(), ErrorStack>,
1344{
1345    let mut connector_builder = SslConnector::builder(SslMethod::tls()).unwrap();
1346    // Disable TLS v1.3 because `postgres` and `hyper` produce stabler error
1347    // messages with TLS v1.2.
1348    //
1349    // Briefly, in TLS v1.3, failing to present a client certificate does not
1350    // error during the TLS handshake, as it does in TLS v1.2, but on the first
1351    // attempt to read from the stream. But both `postgres` and `hyper` write a
1352    // bunch of data before attempting to read from the stream. With a failed
1353    // TLS v1.3 connection, sometimes `postgres` and `hyper` succeed in writing
1354    // out this data, and then return a nice error message on the call to read.
1355    // But sometimes the connection is closed before they write out the data,
1356    // and so they report "connection closed" before they ever call read, never
1357    // noticing the underlying SSL error.
1358    //
1359    // It's unclear who's bug this is. Is it on `hyper`/`postgres` to call read
1360    // if writing to the stream fails to see if a TLS error occured? Is it on
1361    // OpenSSL to provide a better API [1]? Is it a protocol issue that ought to
1362    // be corrected in TLS v1.4? We don't want to answer these questions, so we
1363    // just avoid TLS v1.3 for now.
1364    //
1365    // [1]: https://github.com/openssl/openssl/issues/11118
1366    let options = connector_builder.options() | SslOptions::NO_TLSV1_3;
1367    connector_builder.set_options(options);
1368    configure(&mut connector_builder).unwrap();
1369    MakeTlsConnector::new(connector_builder.build())
1370}
1371
1372/// A certificate authority for use in tests.
1373pub struct Ca {
1374    pub dir: TempDir,
1375    pub name: X509Name,
1376    pub cert: X509,
1377    pub pkey: PKey<Private>,
1378}
1379
1380impl Ca {
1381    fn make_ca(name: &str, parent: Option<&Ca>) -> Result<Ca, Box<dyn Error>> {
1382        let dir = tempfile::tempdir()?;
1383        let rsa = Rsa::generate(2048)?;
1384        let pkey = PKey::from_rsa(rsa)?;
1385        let name = {
1386            let mut builder = X509NameBuilder::new()?;
1387            builder.append_entry_by_nid(Nid::COMMONNAME, name)?;
1388            builder.build()
1389        };
1390        let cert = {
1391            let mut builder = X509::builder()?;
1392            builder.set_version(2)?;
1393            builder.set_pubkey(&pkey)?;
1394            builder.set_issuer_name(parent.map(|ca| &ca.name).unwrap_or(&name))?;
1395            builder.set_subject_name(&name)?;
1396            builder.set_not_before(&*Asn1Time::days_from_now(0)?)?;
1397            builder.set_not_after(&*Asn1Time::days_from_now(365)?)?;
1398            builder.append_extension(BasicConstraints::new().critical().ca().build()?)?;
1399            builder.sign(
1400                parent.map(|ca| &ca.pkey).unwrap_or(&pkey),
1401                MessageDigest::sha256(),
1402            )?;
1403            builder.build()
1404        };
1405        fs::write(dir.path().join("ca.crt"), cert.to_pem()?)?;
1406        Ok(Ca {
1407            dir,
1408            name,
1409            cert,
1410            pkey,
1411        })
1412    }
1413
1414    /// Creates a new root certificate authority.
1415    pub fn new_root(name: &str) -> Result<Ca, Box<dyn Error>> {
1416        Ca::make_ca(name, None)
1417    }
1418
1419    /// Returns the path to the CA's certificate.
1420    pub fn ca_cert_path(&self) -> PathBuf {
1421        self.dir.path().join("ca.crt")
1422    }
1423
1424    /// Requests a new intermediate certificate authority.
1425    pub fn request_ca(&self, name: &str) -> Result<Ca, Box<dyn Error>> {
1426        Ca::make_ca(name, Some(self))
1427    }
1428
1429    /// Generates a certificate with the specified Common Name (CN) that is
1430    /// signed by the CA.
1431    ///
1432    /// Returns the paths to the certificate and key.
1433    pub fn request_client_cert(&self, name: &str) -> Result<(PathBuf, PathBuf), Box<dyn Error>> {
1434        self.request_cert(name, iter::empty())
1435    }
1436
1437    /// Like `request_client_cert`, but permits specifying additional IP
1438    /// addresses to attach as Subject Alternate Names.
1439    pub fn request_cert<I>(&self, name: &str, ips: I) -> Result<(PathBuf, PathBuf), Box<dyn Error>>
1440    where
1441        I: IntoIterator<Item = IpAddr>,
1442    {
1443        let rsa = Rsa::generate(2048)?;
1444        let pkey = PKey::from_rsa(rsa)?;
1445        let subject_name = {
1446            let mut builder = X509NameBuilder::new()?;
1447            builder.append_entry_by_nid(Nid::COMMONNAME, name)?;
1448            builder.build()
1449        };
1450        let cert = {
1451            let mut builder = X509::builder()?;
1452            builder.set_version(2)?;
1453            builder.set_pubkey(&pkey)?;
1454            builder.set_issuer_name(self.cert.subject_name())?;
1455            builder.set_subject_name(&subject_name)?;
1456            builder.set_not_before(&*Asn1Time::days_from_now(0)?)?;
1457            builder.set_not_after(&*Asn1Time::days_from_now(365)?)?;
1458            for ip in ips {
1459                builder.append_extension(
1460                    SubjectAlternativeName::new()
1461                        .ip(&ip.to_string())
1462                        .build(&builder.x509v3_context(None, None))?,
1463                )?;
1464            }
1465            builder.sign(&self.pkey, MessageDigest::sha256())?;
1466            builder.build()
1467        };
1468        let cert_path = self.dir.path().join(Path::new(name).with_extension("crt"));
1469        let key_path = self.dir.path().join(Path::new(name).with_extension("key"));
1470        fs::write(&cert_path, cert.to_pem()?)?;
1471        fs::write(&key_path, pkey.private_key_to_pem_pkcs8()?)?;
1472        Ok((cert_path, key_path))
1473    }
1474}