Skip to main content

mz_environmentd/
test_util.rs

1// Copyright Materialize, Inc. and contributors. All rights reserved.
2//
3// Use of this software is governed by the Business Source License
4// included in the LICENSE file.
5//
6// As of the Change Date specified in that file, in accordance with
7// the Business Source License, use of this software will be governed
8// by the Apache License, Version 2.0.
9
10use std::collections::BTreeMap;
11use std::error::Error;
12use std::future::IntoFuture;
13use std::net::{IpAddr, Ipv4Addr, SocketAddr, TcpStream};
14use std::path::{Path, PathBuf};
15use std::pin::Pin;
16use std::str::FromStr;
17use std::sync::Arc;
18use std::sync::LazyLock;
19use std::time::Duration;
20use std::{env, fs, iter};
21
22use anyhow::anyhow;
23use futures::Future;
24use futures::future::{BoxFuture, LocalBoxFuture};
25use headers::{Header, HeaderMapExt};
26use http::Uri;
27use hyper::http::header::HeaderMap;
28use maplit::btreemap;
29use mz_adapter::TimestampExplanation;
30use mz_adapter_types::bootstrap_builtin_cluster_config::{
31    ANALYTICS_CLUSTER_DEFAULT_REPLICATION_FACTOR, BootstrapBuiltinClusterConfig,
32    CATALOG_SERVER_CLUSTER_DEFAULT_REPLICATION_FACTOR, PROBE_CLUSTER_DEFAULT_REPLICATION_FACTOR,
33    SUPPORT_CLUSTER_DEFAULT_REPLICATION_FACTOR, SYSTEM_CLUSTER_DEFAULT_REPLICATION_FACTOR,
34};
35
36use mz_auth::password::Password;
37use mz_catalog::config::ClusterReplicaSizeMap;
38use mz_controller::ControllerConfig;
39use mz_dyncfg::ConfigUpdates;
40use mz_license_keys::ValidatedLicenseKey;
41use mz_orchestrator_process::{ProcessOrchestrator, ProcessOrchestratorConfig};
42use mz_orchestrator_tracing::{TracingCliArgs, TracingOrchestrator};
43use mz_ore::metrics::MetricsRegistry;
44use mz_ore::now::{EpochMillis, NowFn, SYSTEM_TIME};
45use mz_ore::retry::Retry;
46use mz_ore::task;
47use mz_ore::tracing::{
48    OpenTelemetryConfig, StderrLogConfig, StderrLogFormat, TracingConfig, TracingHandle,
49};
50use mz_persist_client::PersistLocation;
51use mz_persist_client::cache::PersistClientCache;
52use mz_persist_client::cfg::{CONSENSUS_CONNECTION_POOL_MAX_SIZE, PersistConfig};
53use mz_persist_client::rpc::PersistGrpcPubSubServer;
54use mz_secrets::SecretsController;
55use mz_server_core::listeners::{
56    AllowedRoles, AuthenticatorKind, BaseListenerConfig, HttpRoutesEnabled,
57};
58use mz_server_core::{ReloadTrigger, TlsCertConfig};
59use mz_sql::catalog::EnvironmentId;
60use mz_storage_types::connections::ConnectionContext;
61use mz_tracing::CloneableEnvFilter;
62use openssl::asn1::Asn1Time;
63use openssl::error::ErrorStack;
64use openssl::hash::MessageDigest;
65use openssl::nid::Nid;
66use openssl::pkey::{PKey, Private};
67use openssl::rsa::Rsa;
68use openssl::ssl::{SslConnector, SslConnectorBuilder, SslMethod, SslOptions};
69use openssl::x509::extension::{BasicConstraints, SubjectAlternativeName};
70use openssl::x509::{X509, X509Name, X509NameBuilder};
71use postgres::error::DbError;
72use postgres::tls::{MakeTlsConnect, TlsConnect};
73use postgres::types::{FromSql, Type};
74use postgres::{NoTls, Socket};
75use postgres_openssl::MakeTlsConnector;
76use tempfile::TempDir;
77use tokio::net::TcpListener;
78use tokio::runtime::Runtime;
79use tokio_postgres::config::{Host, SslMode};
80use tokio_postgres::{AsyncMessage, Client};
81use tokio_stream::wrappers::TcpListenerStream;
82use tower_http::cors::AllowOrigin;
83use tracing::Level;
84use tracing_capture::SharedStorage;
85use tracing_subscriber::EnvFilter;
86use tungstenite::stream::MaybeTlsStream;
87use tungstenite::{Message, WebSocket};
88
89use crate::{
90    CatalogConfig, FronteggAuthenticator, HttpListenerConfig, ListenersConfig, SqlListenerConfig,
91    WebSocketAuth, WebSocketResponse,
92};
93
/// Kafka broker address list for tests, read once from the `KAFKA_ADDRS`
/// environment variable and falling back to `localhost:9092` when unset.
pub static KAFKA_ADDRS: LazyLock<String> = LazyLock::new(|| match env::var("KAFKA_ADDRS") {
    Ok(addrs) => addrs,
    Err(_) => "localhost:9092".into(),
});
96
/// Entry point for creating and configuring an `environmentd` test harness.
///
/// Construct with [`Default::default`], customize via the builder-style
/// `with_*` methods, then boot a server with one of the `start*` methods.
#[derive(Clone)]
pub struct TestHarness {
    // On-disk state directory; a fresh temporary directory is created at
    // serve time when this is `None`.
    data_directory: Option<PathBuf>,
    // TLS cert/key pair used by listeners whose `enable_tls` flag is set.
    tls: Option<TlsCertConfig>,
    // Frontegg authenticator; set by `with_frontegg_auth`.
    frontegg: Option<FronteggAuthenticator>,
    // Password for external `mz_system` logins; set by the password/SASL builders.
    external_login_password_mz_system: Option<Password>,
    // SQL and HTTP listener definitions the server binds at startup.
    listeners_config: ListenersConfig,
    unsafe_mode: bool,
    // Worker count handed to the server (default: 1).
    workers: usize,
    // Clock used by the server; defaults to real system time.
    now: NowFn,
    // Random seed used to create per-test schema names in the metadata backend.
    seed: u32,
    storage_usage_collection_interval: Duration,
    storage_usage_retention_period: Option<Duration>,
    default_cluster_replica_size: String,
    default_cluster_replication_factor: u32,
    // Size/replication settings for each builtin cluster.
    builtin_system_cluster_config: BootstrapBuiltinClusterConfig,
    builtin_catalog_server_cluster_config: BootstrapBuiltinClusterConfig,
    builtin_probe_cluster_config: BootstrapBuiltinClusterConfig,
    builtin_support_cluster_config: BootstrapBuiltinClusterConfig,
    builtin_analytics_cluster_config: BootstrapBuiltinClusterConfig,

    propagate_crashes: bool,
    enable_tracing: bool,
    // This is currently unrelated to enable_tracing, and is used only to disable orchestrator
    // tracing.
    orchestrator_tracing_cli_args: TracingCliArgs,
    bootstrap_role: Option<String>,
    deploy_generation: u64,
    system_parameter_defaults: BTreeMap<String, String>,
    internal_console_redirect_url: Option<String>,
    // Custom metrics registry; a fresh one is created at serve time when `None`.
    metrics_registry: Option<MetricsRegistry>,
    // Version assigned to the persist config's `build_version` at serve time.
    code_version: semver::Version,
    // Optional tracing capture sink, wired into the tracing config when set.
    capture: Option<SharedStorage>,
    pub environment_id: EnvironmentId,
}
133
134impl Default for TestHarness {
135    fn default() -> TestHarness {
136        TestHarness {
137            data_directory: None,
138            tls: None,
139            frontegg: None,
140            external_login_password_mz_system: None,
141            listeners_config: ListenersConfig {
142                sql: btreemap![
143                    "external".to_owned() => SqlListenerConfig {
144                        addr: SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), 0),
145                        authenticator_kind: AuthenticatorKind::None,
146                        allowed_roles: AllowedRoles::Normal,
147                        enable_tls: false,
148                    },
149                    "internal".to_owned() => SqlListenerConfig {
150                        addr: SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), 0),
151                        authenticator_kind: AuthenticatorKind::None,
152                        allowed_roles: AllowedRoles::NormalAndInternal,
153                        enable_tls: false,
154                    },
155                ],
156                http: btreemap![
157                    "external".to_owned() => HttpListenerConfig {
158                        base: BaseListenerConfig {
159                            addr: SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), 0),
160                            authenticator_kind: AuthenticatorKind::None,
161                            allowed_roles: AllowedRoles::Normal,
162                            enable_tls: false,
163                        },
164                        routes: HttpRoutesEnabled{
165                            base: true,
166                            webhook: true,
167                            internal: false,
168                            metrics: false,
169                            profiling: false,
170                        },
171                    },
172                    "internal".to_owned() => HttpListenerConfig {
173                        base: BaseListenerConfig {
174                            addr: SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), 0),
175                            authenticator_kind: AuthenticatorKind::None,
176                            allowed_roles: AllowedRoles::NormalAndInternal,
177                            enable_tls: false,
178                        },
179                        routes: HttpRoutesEnabled{
180                            base: true,
181                            webhook: true,
182                            internal: true,
183                            metrics: true,
184                            profiling: true,
185                        },
186                    },
187                ],
188            },
189            unsafe_mode: false,
190            workers: 1,
191            now: SYSTEM_TIME.clone(),
192            seed: rand::random(),
193            storage_usage_collection_interval: Duration::from_secs(3600),
194            storage_usage_retention_period: None,
195            default_cluster_replica_size: "scale=1,workers=1".to_string(),
196            default_cluster_replication_factor: 1,
197            builtin_system_cluster_config: BootstrapBuiltinClusterConfig {
198                size: "scale=1,workers=1".to_string(),
199                replication_factor: SYSTEM_CLUSTER_DEFAULT_REPLICATION_FACTOR,
200            },
201            builtin_catalog_server_cluster_config: BootstrapBuiltinClusterConfig {
202                size: "scale=1,workers=1".to_string(),
203                replication_factor: CATALOG_SERVER_CLUSTER_DEFAULT_REPLICATION_FACTOR,
204            },
205            builtin_probe_cluster_config: BootstrapBuiltinClusterConfig {
206                size: "scale=1,workers=1".to_string(),
207                replication_factor: PROBE_CLUSTER_DEFAULT_REPLICATION_FACTOR,
208            },
209            builtin_support_cluster_config: BootstrapBuiltinClusterConfig {
210                size: "scale=1,workers=1".to_string(),
211                replication_factor: SUPPORT_CLUSTER_DEFAULT_REPLICATION_FACTOR,
212            },
213            builtin_analytics_cluster_config: BootstrapBuiltinClusterConfig {
214                size: "scale=1,workers=1".to_string(),
215                replication_factor: ANALYTICS_CLUSTER_DEFAULT_REPLICATION_FACTOR,
216            },
217            propagate_crashes: false,
218            enable_tracing: false,
219            bootstrap_role: Some("materialize".into()),
220            deploy_generation: 0,
221            // This and startup_log_filter below are both (?) needed to suppress clusterd messages.
222            // If we need those in the future, we might need to change both.
223            system_parameter_defaults: BTreeMap::from([(
224                "log_filter".to_string(),
225                "error".to_string(),
226            )]),
227            internal_console_redirect_url: None,
228            metrics_registry: None,
229            orchestrator_tracing_cli_args: TracingCliArgs {
230                startup_log_filter: CloneableEnvFilter::from_str("error").expect("must parse"),
231                ..Default::default()
232            },
233            code_version: crate::BUILD_INFO.semver_version(),
234            environment_id: EnvironmentId::for_tests(),
235            capture: None,
236        }
237    }
238}
239
240impl TestHarness {
    /// Starts a test [`TestServer`], panicking if the server could not be started.
    ///
    /// For cases when startup might fail, see [`TestHarness::try_start`].
    pub async fn start(self) -> TestServer {
        self.try_start().await.expect("Failed to start test Server")
    }

    /// Like [`TestHarness::start`] but can specify a cert reload trigger.
    pub async fn start_with_trigger(self, tls_reload_certs: ReloadTrigger) -> TestServer {
        self.try_start_with_trigger(tls_reload_certs)
            .await
            .expect("Failed to start test Server")
    }

    /// Starts a test [`TestServer`], returning an error if the server could not be started.
    pub async fn try_start(self) -> Result<TestServer, anyhow::Error> {
        // Certificate reloading is disabled unless a trigger is supplied
        // via `try_start_with_trigger`.
        self.try_start_with_trigger(mz_server_core::cert_reload_never_reload())
            .await
    }

    /// Like [`TestHarness::try_start`] but can specify a cert reload trigger.
    pub async fn try_start_with_trigger(
        self,
        tls_reload_certs: ReloadTrigger,
    ) -> Result<TestServer, anyhow::Error> {
        // Bind all configured listeners first, then serve on them.
        let listeners = Listeners::new(&self).await?;
        listeners.serve_with_trigger(self, tls_reload_certs).await
    }
269
    /// Starts a runtime and returns a [`TestServerWithRuntime`].
    ///
    /// Convenience for non-async tests: spins up a dedicated multi-threaded
    /// Tokio runtime and blocks on [`TestHarness::start`].
    pub fn start_blocking(self) -> TestServerWithRuntime {
        let runtime = tokio::runtime::Builder::new_multi_thread()
            .enable_all()
            // Spawned threads use mz_ore's configured stack size.
            .thread_stack_size(mz_ore::stack::STACK_SIZE)
            .build()
            .expect("failed to spawn runtime for test");
        // The runtime is kept alive inside the returned wrapper so the server
        // outlives this function.
        let runtime = Arc::new(runtime);
        let server = runtime.block_on(self.start());
        TestServerWithRuntime { runtime, server }
    }
281
    /// Uses `data_directory` for on-disk state instead of a fresh temporary directory.
    pub fn data_directory(mut self, data_directory: impl Into<PathBuf>) -> Self {
        self.data_directory = Some(data_directory.into());
        self
    }
286
287    pub fn with_tls(mut self, cert_path: impl Into<PathBuf>, key_path: impl Into<PathBuf>) -> Self {
288        self.tls = Some(TlsCertConfig {
289            cert: cert_path.into(),
290            key: key_path.into(),
291        });
292        for (_, listener) in &mut self.listeners_config.sql {
293            listener.enable_tls = true;
294        }
295        for (_, listener) in &mut self.listeners_config.http {
296            listener.base.enable_tls = true;
297        }
298        self
299    }
300
    /// Enables unsafe mode on the server under test.
    pub fn unsafe_mode(mut self) -> Self {
        self.unsafe_mode = true;
        self
    }

    /// Sets the server's worker count (the default is 1).
    pub fn workers(mut self, workers: usize) -> Self {
        self.workers = workers;
        self
    }
310
    /// Switches the harness to Frontegg authentication.
    ///
    /// Replaces the entire listener configuration: the external SQL/HTTP
    /// listeners authenticate via Frontegg (with TLS enabled iff a TLS config
    /// was already set), while the internal listeners stay unauthenticated
    /// and keep the internal-only HTTP routes.
    pub fn with_frontegg_auth(mut self, frontegg: &FronteggAuthenticator) -> Self {
        self.frontegg = Some(frontegg.clone());
        // External listeners use TLS only if certs were configured beforehand.
        let enable_tls = self.tls.is_some();
        self.listeners_config = ListenersConfig {
            sql: btreemap! {
                "external".to_owned() => SqlListenerConfig {
                    addr: SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), 0),
                    authenticator_kind: AuthenticatorKind::Frontegg,
                    allowed_roles: AllowedRoles::Normal,
                    enable_tls,
                },
                "internal".to_owned() => SqlListenerConfig {
                    addr: SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), 0),
                    authenticator_kind: AuthenticatorKind::None,
                    allowed_roles: AllowedRoles::NormalAndInternal,
                    enable_tls: false,
                },
            },
            http: btreemap! {
                "external".to_owned() => HttpListenerConfig {
                    base: BaseListenerConfig {
                        addr: SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), 0),
                        authenticator_kind: AuthenticatorKind::Frontegg,
                        allowed_roles: AllowedRoles::Normal,
                        enable_tls,
                    },
                    routes: HttpRoutesEnabled{
                        base: true,
                        webhook: true,
                        internal: false,
                        metrics: false,
                        profiling: false,
                    },
                },
                "internal".to_owned() => HttpListenerConfig {
                    base: BaseListenerConfig {
                        addr: SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), 0),
                        authenticator_kind: AuthenticatorKind::None,
                        allowed_roles: AllowedRoles::NormalAndInternal,
                        enable_tls: false,
                    },
                    routes: HttpRoutesEnabled{
                        base: true,
                        webhook: true,
                        internal: true,
                        metrics: true,
                        profiling: true,
                    },
                },
            },
        };
        self
    }
364
    /// Switches the harness to OIDC authentication.
    ///
    /// Mirrors [`TestHarness::with_frontegg_auth`] but uses the `Oidc`
    /// authenticator on the external listeners. The optional `issuer` and
    /// `audience` are stored as the `oidc_issuer` / `oidc_audience` system
    /// parameter defaults when provided.
    pub fn with_oidc_auth(mut self, issuer: Option<String>, audience: Option<String>) -> Self {
        // External listeners use TLS only if certs were configured beforehand.
        let enable_tls = self.tls.is_some();
        self.listeners_config = ListenersConfig {
            sql: btreemap! {
                "external".to_owned() => SqlListenerConfig {
                    addr: SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), 0),
                    authenticator_kind: AuthenticatorKind::Oidc,
                    allowed_roles: AllowedRoles::Normal,
                    enable_tls,
                },
                "internal".to_owned() => SqlListenerConfig {
                    addr: SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), 0),
                    authenticator_kind: AuthenticatorKind::None,
                    allowed_roles: AllowedRoles::NormalAndInternal,
                    enable_tls: false,
                },
            },
            http: btreemap! {
                "external".to_owned() => HttpListenerConfig {
                    base: BaseListenerConfig {
                        addr: SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), 0),
                        authenticator_kind: AuthenticatorKind::Oidc,
                        allowed_roles: AllowedRoles::Normal,
                        enable_tls,
                    },
                    routes: HttpRoutesEnabled{
                        base: true,
                        webhook: true,
                        internal: false,
                        metrics: false,
                        profiling: false,
                    },
                },
                "internal".to_owned() => HttpListenerConfig {
                    base: BaseListenerConfig {
                        addr: SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), 0),
                        authenticator_kind: AuthenticatorKind::None,
                        allowed_roles: AllowedRoles::NormalAndInternal,
                        enable_tls: false,
                    },
                    routes: HttpRoutesEnabled{
                        base: true,
                        webhook: true,
                        internal: true,
                        metrics: true,
                        profiling: true,
                    },
                },
            },
        };

        if let Some(issuer) = issuer {
            self.system_parameter_defaults
                .insert("oidc_issuer".to_string(), issuer);
        }

        if let Some(audience) = audience {
            self.system_parameter_defaults
                .insert("oidc_audience".to_string(), audience);
        }

        self
    }
428
    /// Switches the harness to password authentication, using
    /// `mz_system_password` for the external `mz_system` login.
    ///
    /// Replaces the listener configuration with a single password-protected
    /// external listener (SQL and HTTP, internal routes included) plus a
    /// separate unauthenticated `metrics`-only HTTP listener.
    pub fn with_password_auth(mut self, mz_system_password: Password) -> Self {
        self.external_login_password_mz_system = Some(mz_system_password);
        // External listeners use TLS only if certs were configured beforehand.
        let enable_tls = self.tls.is_some();
        self.listeners_config = ListenersConfig {
            sql: btreemap! {
                "external".to_owned() => SqlListenerConfig {
                    addr: SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), 0),
                    authenticator_kind: AuthenticatorKind::Password,
                    allowed_roles: AllowedRoles::NormalAndInternal,
                    enable_tls,
                },
            },
            http: btreemap! {
                "external".to_owned() => HttpListenerConfig {
                    base: BaseListenerConfig {
                        addr: SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), 0),
                        authenticator_kind: AuthenticatorKind::Password,
                        allowed_roles: AllowedRoles::NormalAndInternal,
                        enable_tls,
                    },
                    // Metrics are served by the dedicated listener below.
                    routes: HttpRoutesEnabled{
                        base: true,
                        webhook: true,
                        internal: true,
                        metrics: false,
                        profiling: true,
                    },
                },
                "metrics".to_owned() => HttpListenerConfig {
                    base: BaseListenerConfig {
                        addr: SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), 0),
                        authenticator_kind: AuthenticatorKind::None,
                        allowed_roles: AllowedRoles::NormalAndInternal,
                        enable_tls: false,
                    },
                    routes: HttpRoutesEnabled{
                        base: false,
                        webhook: false,
                        internal: false,
                        metrics: true,
                        profiling: false,
                    },
                },
            },
        };
        self
    }
476
    /// Switches the harness to SASL/SCRAM authentication for SQL, using
    /// `mz_system_password` for the external `mz_system` login.
    ///
    /// Same listener layout as [`TestHarness::with_password_auth`].
    /// NOTE(review): the external HTTP listener keeps `AuthenticatorKind::Password`
    /// here — presumably SASL applies only to the SQL protocol; confirm intended.
    pub fn with_sasl_scram_auth(mut self, mz_system_password: Password) -> Self {
        self.external_login_password_mz_system = Some(mz_system_password);
        // External listeners use TLS only if certs were configured beforehand.
        let enable_tls = self.tls.is_some();
        self.listeners_config = ListenersConfig {
            sql: btreemap! {
                "external".to_owned() => SqlListenerConfig {
                    addr: SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), 0),
                    authenticator_kind: AuthenticatorKind::Sasl,
                    allowed_roles: AllowedRoles::NormalAndInternal,
                    enable_tls,
                },
            },
            http: btreemap! {
                "external".to_owned() => HttpListenerConfig {
                    base: BaseListenerConfig {
                        addr: SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), 0),
                        authenticator_kind: AuthenticatorKind::Password,
                        allowed_roles: AllowedRoles::NormalAndInternal,
                        enable_tls,
                    },
                    // Metrics are served by the dedicated listener below.
                    routes: HttpRoutesEnabled{
                        base: true,
                        webhook: true,
                        internal: true,
                        metrics: false,
                        profiling: true,
                    },
                },
                "metrics".to_owned() => HttpListenerConfig {
                    base: BaseListenerConfig {
                        addr: SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), 0),
                        authenticator_kind: AuthenticatorKind::None,
                        allowed_roles: AllowedRoles::NormalAndInternal,
                        enable_tls: false,
                    },
                    routes: HttpRoutesEnabled{
                        base: false,
                        webhook: false,
                        internal: false,
                        metrics: true,
                        profiling: false,
                    },
                },
            },
        };
        self
    }
524
    /// Overrides the clock used by the server under test.
    pub fn with_now(mut self, now: NowFn) -> Self {
        self.now = now;
        self
    }

    /// Sets how often storage usage is collected (default: one hour).
    pub fn with_storage_usage_collection_interval(
        mut self,
        storage_usage_collection_interval: Duration,
    ) -> Self {
        self.storage_usage_collection_interval = storage_usage_collection_interval;
        self
    }

    /// Sets a retention period for storage usage records (default: none).
    pub fn with_storage_usage_retention_period(
        mut self,
        storage_usage_retention_period: Duration,
    ) -> Self {
        self.storage_usage_retention_period = Some(storage_usage_retention_period);
        self
    }

    /// Overrides the default cluster replica size.
    pub fn with_default_cluster_replica_size(
        mut self,
        default_cluster_replica_size: String,
    ) -> Self {
        self.default_cluster_replica_size = default_cluster_replica_size;
        self
    }

    /// Overrides the replica size of the builtin system cluster.
    pub fn with_builtin_system_cluster_replica_size(
        mut self,
        builtin_system_cluster_replica_size: String,
    ) -> Self {
        self.builtin_system_cluster_config.size = builtin_system_cluster_replica_size;
        self
    }

    /// Overrides the replication factor of the builtin system cluster.
    pub fn with_builtin_system_cluster_replication_factor(
        mut self,
        builtin_system_cluster_replication_factor: u32,
    ) -> Self {
        self.builtin_system_cluster_config.replication_factor =
            builtin_system_cluster_replication_factor;
        self
    }

    /// Overrides the replica size of the builtin catalog server cluster.
    pub fn with_builtin_catalog_server_cluster_replica_size(
        mut self,
        builtin_catalog_server_cluster_replica_size: String,
    ) -> Self {
        self.builtin_catalog_server_cluster_config.size =
            builtin_catalog_server_cluster_replica_size;
        self
    }

    /// Sets whether orchestrated process crashes propagate to the test.
    pub fn with_propagate_crashes(mut self, propagate_crashes: bool) -> Self {
        self.propagate_crashes = propagate_crashes;
        self
    }

    /// Enables or disables the full tracing stack (stderr + OpenTelemetry).
    pub fn with_enable_tracing(mut self, enable_tracing: bool) -> Self {
        self.enable_tracing = enable_tracing;
        self
    }

    /// Sets the bootstrap role (default: `Some("materialize")`).
    pub fn with_bootstrap_role(mut self, bootstrap_role: Option<String>) -> Self {
        self.bootstrap_role = bootstrap_role;
        self
    }

    /// Sets the deploy generation (default: 0).
    pub fn with_deploy_generation(mut self, deploy_generation: u64) -> Self {
        self.deploy_generation = deploy_generation;
        self
    }

    /// Adds or overrides a single system parameter default.
    pub fn with_system_parameter_default(mut self, param: String, value: String) -> Self {
        self.system_parameter_defaults.insert(param, value);
        self
    }

    /// Sets the redirect URL used by the internal console routes.
    pub fn with_internal_console_redirect_url(
        mut self,
        internal_console_redirect_url: Option<String>,
    ) -> Self {
        self.internal_console_redirect_url = internal_console_redirect_url;
        self
    }

    /// Supplies a metrics registry instead of letting the harness create one.
    pub fn with_metrics_registry(mut self, registry: MetricsRegistry) -> Self {
        self.metrics_registry = Some(registry);
        self
    }

    /// Overrides the build version reported by the server.
    pub fn with_code_version(mut self, version: semver::Version) -> Self {
        self.code_version = version;
        self
    }

    /// Attaches a shared storage sink that captures emitted tracing spans.
    pub fn with_capture(mut self, storage: SharedStorage) -> Self {
        self.capture = Some(storage);
        self
    }
627}
628
/// Pre-bound network listeners for a test server; created by
/// [`Listeners::new`] and consumed by the `serve*` methods.
pub struct Listeners {
    pub inner: crate::Listeners,
}
632
633impl Listeners {
634    pub async fn new(config: &TestHarness) -> Result<Listeners, anyhow::Error> {
635        let inner = crate::Listeners::bind(config.listeners_config.clone()).await?;
636        Ok(Listeners { inner })
637    }
638
639    pub async fn serve(self, config: TestHarness) -> Result<TestServer, anyhow::Error> {
640        self.serve_with_trigger(config, mz_server_core::cert_reload_never_reload())
641            .await
642    }
643
644    pub async fn serve_with_trigger(
645        self,
646        config: TestHarness,
647        tls_reload_certs: ReloadTrigger,
648    ) -> Result<TestServer, anyhow::Error> {
649        let (data_directory, temp_dir) = match config.data_directory {
650            None => {
651                // If no data directory is provided, we create a temporary
652                // directory. The temporary directory is cleaned up when the
653                // `TempDir` is dropped, so we keep it alive until the `Server` is
654                // dropped.
655                let temp_dir = tempfile::tempdir()?;
656                (temp_dir.path().to_path_buf(), Some(temp_dir))
657            }
658            Some(data_directory) => (data_directory, None),
659        };
660        let scratch_dir = tempfile::tempdir()?;
661        let (consensus_uri, timestamp_oracle_url) = {
662            let seed = config.seed;
663            let cockroach_url = env::var("METADATA_BACKEND_URL")
664                .map_err(|_| anyhow!("METADATA_BACKEND_URL environment variable is not set"))?;
665            let (client, conn) = tokio_postgres::connect(&cockroach_url, NoTls).await?;
666            mz_ore::task::spawn(|| "startup-postgres-conn", async move {
667                if let Err(err) = conn.await {
668                    panic!("connection error: {}", err);
669                };
670            });
671            client
672                .batch_execute(&format!(
673                    "CREATE SCHEMA IF NOT EXISTS consensus_{seed};
674                    CREATE SCHEMA IF NOT EXISTS tsoracle_{seed};"
675                ))
676                .await?;
677            (
678                format!("{cockroach_url}?options=--search_path=consensus_{seed}")
679                    .parse()
680                    .expect("invalid consensus URI"),
681                format!("{cockroach_url}?options=--search_path=tsoracle_{seed}")
682                    .parse()
683                    .expect("invalid timestamp oracle URI"),
684            )
685        };
686        let metrics_registry = config.metrics_registry.unwrap_or_else(MetricsRegistry::new);
687        let orchestrator = ProcessOrchestrator::new(ProcessOrchestratorConfig {
688            image_dir: env::current_exe()?
689                .parent()
690                .unwrap()
691                .parent()
692                .unwrap()
693                .to_path_buf(),
694            suppress_output: false,
695            environment_id: config.environment_id.to_string(),
696            secrets_dir: data_directory.join("secrets"),
697            command_wrapper: vec![],
698            propagate_crashes: config.propagate_crashes,
699            tcp_proxy: None,
700            scratch_directory: scratch_dir.path().to_path_buf(),
701        })
702        .await?;
703        let orchestrator = Arc::new(orchestrator);
704        // Messing with the clock causes persist to expire leases, causing hangs and
705        // panics. Is it possible/desirable to put this back somehow?
706        let persist_now = SYSTEM_TIME.clone();
707        let dyncfgs = mz_dyncfgs::all_dyncfgs();
708
709        let mut updates = ConfigUpdates::default();
710        // Tune down the number of connections to make this all work a little easier
711        // with local postgres.
712        updates.add(&CONSENSUS_CONNECTION_POOL_MAX_SIZE, 1);
713        updates.apply(&dyncfgs);
714
715        let mut persist_cfg = PersistConfig::new(&crate::BUILD_INFO, persist_now.clone(), dyncfgs);
716        persist_cfg.build_version = config.code_version;
717        // Stress persist more by writing rollups frequently
718        persist_cfg.set_rollup_threshold(5);
719
720        let persist_pubsub_server = PersistGrpcPubSubServer::new(&persist_cfg, &metrics_registry);
721        let persist_pubsub_client = persist_pubsub_server.new_same_process_connection();
722        let persist_pubsub_tcp_listener =
723            TcpListener::bind(SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), 0))
724                .await
725                .expect("pubsub addr binding");
726        let persist_pubsub_server_port = persist_pubsub_tcp_listener
727            .local_addr()
728            .expect("pubsub addr has local addr")
729            .port();
730
731        // Spawn the persist pub-sub server.
732        mz_ore::task::spawn(|| "persist_pubsub_server", async move {
733            persist_pubsub_server
734                .serve_with_stream(TcpListenerStream::new(persist_pubsub_tcp_listener))
735                .await
736                .expect("success")
737        });
738        let persist_clients =
739            PersistClientCache::new(persist_cfg, &metrics_registry, |_, _| persist_pubsub_client);
740        let persist_clients = Arc::new(persist_clients);
741
742        let secrets_controller = Arc::clone(&orchestrator);
743        let connection_context = ConnectionContext::for_tests(orchestrator.reader());
744        let orchestrator = Arc::new(TracingOrchestrator::new(
745            orchestrator,
746            config.orchestrator_tracing_cli_args,
747        ));
748        let tracing_handle = if config.enable_tracing {
749            let config = TracingConfig::<fn(&tracing::Metadata) -> sentry_tracing::EventFilter> {
750                service_name: "environmentd",
751                stderr_log: StderrLogConfig {
752                    format: StderrLogFormat::Json,
753                    filter: EnvFilter::default(),
754                },
755                opentelemetry: Some(OpenTelemetryConfig {
756                    endpoint: "http://fake_address_for_testing:8080".to_string(),
757                    headers: http::HeaderMap::new(),
758                    filter: EnvFilter::default().add_directive(Level::DEBUG.into()),
759                    resource: opentelemetry_sdk::resource::Resource::builder().build(),
760                    max_batch_queue_size: 2048,
761                    max_export_batch_size: 512,
762                    max_concurrent_exports: 1,
763                    batch_scheduled_delay: Duration::from_millis(5000),
764                    max_export_timeout: Duration::from_secs(30),
765                }),
766                tokio_console: None,
767                sentry: None,
768                build_version: crate::BUILD_INFO.version,
769                build_sha: crate::BUILD_INFO.sha,
770                registry: metrics_registry.clone(),
771                capture: config.capture,
772            };
773            mz_ore::tracing::configure(config).await?
774        } else {
775            TracingHandle::disabled()
776        };
777        let host_name = format!(
778            "localhost:{}",
779            self.inner.http["external"].handle.local_addr.port()
780        );
781        let catalog_config = CatalogConfig {
782            persist_clients: Arc::clone(&persist_clients),
783            metrics: Arc::new(mz_catalog::durable::Metrics::new(&MetricsRegistry::new())),
784        };
785
786        let inner = self
787            .inner
788            .serve(crate::Config {
789                catalog_config,
790                timestamp_oracle_url: Some(timestamp_oracle_url),
791                controller: ControllerConfig {
792                    build_info: &crate::BUILD_INFO,
793                    orchestrator,
794                    clusterd_image: "clusterd".into(),
795                    init_container_image: None,
796                    deploy_generation: config.deploy_generation,
797                    persist_location: PersistLocation {
798                        blob_uri: format!("file://{}/persist/blob", data_directory.display())
799                            .parse()
800                            .expect("invalid blob URI"),
801                        consensus_uri,
802                    },
803                    persist_clients,
804                    now: config.now.clone(),
805                    metrics_registry: metrics_registry.clone(),
806                    persist_pubsub_url: format!("http://localhost:{}", persist_pubsub_server_port),
807                    secrets_args: mz_service::secrets::SecretsReaderCliArgs {
808                        secrets_reader: mz_service::secrets::SecretsControllerKind::LocalFile,
809                        secrets_reader_local_file_dir: Some(data_directory.join("secrets")),
810                        secrets_reader_kubernetes_context: None,
811                        secrets_reader_aws_prefix: None,
812                        secrets_reader_name_prefix: None,
813                    },
814                    connection_context,
815                    replica_http_locator: Default::default(),
816                },
817                secrets_controller,
818                cloud_resource_controller: None,
819                tls: config.tls,
820                frontegg: config.frontegg,
821                unsafe_mode: config.unsafe_mode,
822                all_features: false,
823                metrics_registry: metrics_registry.clone(),
824                now: config.now,
825                environment_id: config.environment_id,
826                cors_allowed_origin: AllowOrigin::list([]),
827                cluster_replica_sizes: ClusterReplicaSizeMap::for_tests(),
828                bootstrap_default_cluster_replica_size: config.default_cluster_replica_size,
829                bootstrap_default_cluster_replication_factor: config
830                    .default_cluster_replication_factor,
831                bootstrap_builtin_system_cluster_config: config.builtin_system_cluster_config,
832                bootstrap_builtin_catalog_server_cluster_config: config
833                    .builtin_catalog_server_cluster_config,
834                bootstrap_builtin_probe_cluster_config: config.builtin_probe_cluster_config,
835                bootstrap_builtin_support_cluster_config: config.builtin_support_cluster_config,
836                bootstrap_builtin_analytics_cluster_config: config.builtin_analytics_cluster_config,
837                system_parameter_defaults: config.system_parameter_defaults,
838                availability_zones: Default::default(),
839                tracing_handle,
840                storage_usage_collection_interval: config.storage_usage_collection_interval,
841                storage_usage_retention_period: config.storage_usage_retention_period,
842                segment_api_key: None,
843                segment_client_side: false,
844                test_only_dummy_segment_client: false,
845                egress_addresses: vec![],
846                aws_account_id: None,
847                aws_privatelink_availability_zones: None,
848                launchdarkly_sdk_key: None,
849                launchdarkly_key_map: Default::default(),
850                config_sync_file_path: None,
851                config_sync_timeout: Duration::from_secs(30),
852                config_sync_loop_interval: None,
853                bootstrap_role: config.bootstrap_role,
854                http_host_name: Some(host_name),
855                internal_console_redirect_url: config.internal_console_redirect_url,
856                tls_reload_certs,
857                helm_chart_version: None,
858                license_key: ValidatedLicenseKey::for_tests(),
859                external_login_password_mz_system: config.external_login_password_mz_system,
860                force_builtin_schema_migration: None,
861            })
862            .await?;
863
864        Ok(TestServer {
865            inner,
866            metrics_registry,
867            _temp_dir: temp_dir,
868            _scratch_dir: scratch_dir,
869        })
870    }
871}
872
/// A running instance of `environmentd`.
pub struct TestServer {
    /// Handle to the running `environmentd` server itself.
    pub inner: crate::Server,
    /// The metrics registry the server was started with; tests can inspect it directly.
    pub metrics_registry: MetricsRegistry,
    /// The `TempDir`s are saved to prevent them from being dropped, and thus cleaned up too early.
    _temp_dir: Option<TempDir>,
    /// Scratch directory kept alive for the lifetime of the server.
    _scratch_dir: TempDir,
}
881
882impl TestServer {
883    pub fn connect(&self) -> ConnectBuilder<'_, postgres::NoTls, NoHandle> {
884        ConnectBuilder::new(self).no_tls()
885    }
886
887    pub async fn enable_feature_flags(&self, flags: &[&'static str]) {
888        let internal_client = self.connect().internal().await.unwrap();
889
890        for flag in flags {
891            internal_client
892                .batch_execute(&format!("ALTER SYSTEM SET {} = true;", flag))
893                .await
894                .unwrap();
895        }
896    }
897
898    pub async fn disable_feature_flags(&self, flags: &[&'static str]) {
899        let internal_client = self.connect().internal().await.unwrap();
900
901        for flag in flags {
902            internal_client
903                .batch_execute(&format!("ALTER SYSTEM SET {} = false;", flag))
904                .await
905                .unwrap();
906        }
907    }
908
909    pub fn ws_addr(&self) -> Uri {
910        format!(
911            "ws://{}/api/experimental/sql",
912            self.inner.http_listener_handles["external"].local_addr
913        )
914        .parse()
915        .unwrap()
916    }
917
918    pub fn internal_ws_addr(&self) -> Uri {
919        format!(
920            "ws://{}/api/experimental/sql",
921            self.inner.http_listener_handles["internal"].local_addr
922        )
923        .parse()
924        .unwrap()
925    }
926
927    pub fn http_local_addr(&self) -> SocketAddr {
928        self.inner.http_listener_handles["external"].local_addr
929    }
930
931    pub fn internal_http_local_addr(&self) -> SocketAddr {
932        self.inner.http_listener_handles["internal"].local_addr
933    }
934
935    pub fn sql_local_addr(&self) -> SocketAddr {
936        self.inner.sql_listener_handles["external"].local_addr
937    }
938
939    pub fn internal_sql_local_addr(&self) -> SocketAddr {
940        self.inner.sql_listener_handles["internal"].local_addr
941    }
942}
943
/// A builder struct to configure a pgwire connection to a running [`TestServer`].
///
/// You can create this struct, and thus open a pgwire connection, using [`TestServer::connect`].
///
/// The type parameter `T` tracks the TLS configuration, and `H` tracks whether the
/// join handle for the spawned connection task is returned alongside the client.
pub struct ConnectBuilder<'s, T, H> {
    /// A running `environmentd` test server.
    server: &'s TestServer,

    /// Postgres configuration for connecting to the test server.
    pg_config: tokio_postgres::Config,
    /// Port to use when connecting to the test server. Stored separately from
    /// `pg_config` so it can be swapped, e.g. by [`ConnectBuilder::internal`].
    port: u16,
    /// Tls settings to use.
    tls: T,

    /// Callback that gets invoked for every notice we receive.
    notice_callback: Option<Box<dyn FnMut(tokio_postgres::error::DbError) + Send + 'static>>,

    /// Type variable for whether or not we include the handle for the spawned [`tokio::task`].
    _with_handle: H,
}
964
965impl<'s> ConnectBuilder<'s, (), NoHandle> {
966    fn new(server: &'s TestServer) -> Self {
967        let mut pg_config = tokio_postgres::Config::new();
968        pg_config
969            .host(&Ipv4Addr::LOCALHOST.to_string())
970            .user("materialize")
971            .options("--welcome_message=off")
972            .application_name("environmentd_test_framework");
973
974        ConnectBuilder {
975            server,
976            pg_config,
977            port: server.sql_local_addr().port(),
978            tls: (),
979            notice_callback: None,
980            _with_handle: NoHandle,
981        }
982    }
983}
984
985impl<'s, T, H> ConnectBuilder<'s, T, H> {
986    /// Create a pgwire connection without using TLS.
987    ///
988    /// Note: this is the default for all connections.
989    pub fn no_tls(self) -> ConnectBuilder<'s, postgres::NoTls, H> {
990        ConnectBuilder {
991            server: self.server,
992            pg_config: self.pg_config,
993            port: self.port,
994            tls: postgres::NoTls,
995            notice_callback: self.notice_callback,
996            _with_handle: self._with_handle,
997        }
998    }
999
1000    /// Create a pgwire connection with TLS.
1001    pub fn with_tls<Tls>(self, tls: Tls) -> ConnectBuilder<'s, Tls, H>
1002    where
1003        Tls: MakeTlsConnect<Socket> + Send + 'static,
1004        Tls::TlsConnect: Send,
1005        Tls::Stream: Send,
1006        <Tls::TlsConnect as TlsConnect<Socket>>::Future: Send,
1007    {
1008        ConnectBuilder {
1009            server: self.server,
1010            pg_config: self.pg_config,
1011            port: self.port,
1012            tls,
1013            notice_callback: self.notice_callback,
1014            _with_handle: self._with_handle,
1015        }
1016    }
1017
1018    /// Create a [`ConnectBuilder`] using the provided [`tokio_postgres::Config`].
1019    pub fn with_config(mut self, pg_config: tokio_postgres::Config) -> Self {
1020        self.pg_config = pg_config;
1021        self
1022    }
1023
1024    /// Set the [`SslMode`] to be used with the resulting connection.
1025    pub fn ssl_mode(mut self, mode: SslMode) -> Self {
1026        self.pg_config.ssl_mode(mode);
1027        self
1028    }
1029
1030    /// Set the user for the pgwire connection.
1031    pub fn user(mut self, user: &str) -> Self {
1032        self.pg_config.user(user);
1033        self
1034    }
1035
1036    /// Set the password for the pgwire connection.
1037    pub fn password(mut self, password: &str) -> Self {
1038        self.pg_config.password(password);
1039        self
1040    }
1041
1042    /// Set the application name for the pgwire connection.
1043    pub fn application_name(mut self, application_name: &str) -> Self {
1044        self.pg_config.application_name(application_name);
1045        self
1046    }
1047
1048    /// Set the database name for the pgwire connection.
1049    pub fn dbname(mut self, dbname: &str) -> Self {
1050        self.pg_config.dbname(dbname);
1051        self
1052    }
1053
1054    /// Set the options for the pgwire connection.
1055    pub fn options(mut self, options: &str) -> Self {
1056        self.pg_config.options(options);
1057        self
1058    }
1059
1060    /// Configures this [`ConnectBuilder`] to connect to the __internal__ SQL port of the running
1061    /// [`TestServer`].
1062    ///
1063    /// For example, this will change the port we connect to, and the user we connect as.
1064    pub fn internal(mut self) -> Self {
1065        self.port = self.server.internal_sql_local_addr().port();
1066        self.pg_config.user(mz_sql::session::user::SYSTEM_USER_NAME);
1067        self
1068    }
1069
1070    /// Sets a callback for any database notices that are received from the [`TestServer`].
1071    pub fn notice_callback(self, callback: impl FnMut(DbError) + Send + 'static) -> Self {
1072        ConnectBuilder {
1073            notice_callback: Some(Box::new(callback)),
1074            ..self
1075        }
1076    }
1077
1078    /// Configures this [`ConnectBuilder`] to return the [`mz_ore::task::JoinHandle`] that is
1079    /// polling the underlying postgres connection, associated with the returned client.
1080    pub fn with_handle(self) -> ConnectBuilder<'s, T, WithHandle> {
1081        ConnectBuilder {
1082            server: self.server,
1083            pg_config: self.pg_config,
1084            port: self.port,
1085            tls: self.tls,
1086            notice_callback: self.notice_callback,
1087            _with_handle: WithHandle,
1088        }
1089    }
1090
1091    /// Returns the [`tokio_postgres::Config`] that will be used to connect.
1092    pub fn as_pg_config(&self) -> &tokio_postgres::Config {
1093        &self.pg_config
1094    }
1095}
1096
/// This trait enables us to either include or omit the [`mz_ore::task::JoinHandle`] in the result
/// of a client connection.
pub trait IncludeHandle: Send {
    /// What a successful connection yields: either just the client, or the client
    /// paired with the join handle of the task driving the connection.
    type Output;
    /// Packages the connected `client` and the connection task's `handle` into
    /// [`Self::Output`].
    fn transform_result(
        client: tokio_postgres::Client,
        handle: mz_ore::task::JoinHandle<()>,
    ) -> Self::Output;
}
1106
/// Type parameter that denotes we __will not__ return the [`mz_ore::task::JoinHandle`] in the
/// result of a [`ConnectBuilder`].
pub struct NoHandle;
impl IncludeHandle for NoHandle {
    type Output = tokio_postgres::Client;
    fn transform_result(
        client: tokio_postgres::Client,
        _handle: mz_ore::task::JoinHandle<()>,
    ) -> Self::Output {
        // NOTE(review): the handle is discarded here; confirm
        // `mz_ore::task::JoinHandle`'s drop behavior (detach vs. abort) matches intent.
        client
    }
}
1119
/// Type parameter that denotes we __will__ return the [`mz_ore::task::JoinHandle`] in the result of
/// a [`ConnectBuilder`].
pub struct WithHandle;
impl IncludeHandle for WithHandle {
    type Output = (tokio_postgres::Client, mz_ore::task::JoinHandle<()>);
    fn transform_result(
        client: tokio_postgres::Client,
        handle: mz_ore::task::JoinHandle<()>,
    ) -> Self::Output {
        (client, handle)
    }
}
1132
1133impl<'s, T, H> IntoFuture for ConnectBuilder<'s, T, H>
1134where
1135    T: MakeTlsConnect<Socket> + Send + 'static,
1136    T::TlsConnect: Send,
1137    T::Stream: Send,
1138    <T::TlsConnect as TlsConnect<Socket>>::Future: Send,
1139    H: IncludeHandle,
1140{
1141    type Output = Result<H::Output, postgres::Error>;
1142    type IntoFuture = BoxFuture<'static, Self::Output>;
1143
1144    fn into_future(mut self) -> Self::IntoFuture {
1145        Box::pin(async move {
1146            assert!(
1147                self.pg_config.get_ports().is_empty(),
1148                "specifying multiple ports is not supported"
1149            );
1150            self.pg_config.port(self.port);
1151
1152            let (client, mut conn) = self.pg_config.connect(self.tls).await?;
1153            let mut notice_callback = self.notice_callback.take();
1154
1155            let handle = task::spawn(|| "connect", async move {
1156                while let Some(msg) = std::future::poll_fn(|cx| conn.poll_message(cx)).await {
1157                    match msg {
1158                        Ok(AsyncMessage::Notice(notice)) => {
1159                            if let Some(callback) = notice_callback.as_mut() {
1160                                callback(notice);
1161                            }
1162                        }
1163                        Ok(msg) => {
1164                            tracing::debug!(?msg, "Dropping message from database");
1165                        }
1166                        Err(e) => {
1167                            // tokio_postgres::Connection docs say:
1168                            // > Return values of None or Some(Err(_)) are “terminal”; callers
1169                            // > should not invoke this method again after receiving one of those
1170                            // > values.
1171                            tracing::info!("connection error: {e}");
1172                            break;
1173                        }
1174                    }
1175                }
1176                tracing::info!("connection closed");
1177            });
1178
1179            let output = H::transform_result(client, handle);
1180            Ok(output)
1181        })
1182    }
1183}
1184
/// A running instance of `environmentd`, that exposes blocking/synchronous test helpers.
///
/// Note: Ideally you should use a [`TestServer`] which relies on an external runtime, e.g. the
/// [`tokio::test`] macro. This struct exists so we can incrementally migrate our existing tests.
pub struct TestServerWithRuntime {
    /// The underlying async test server; most methods delegate to it.
    server: TestServer,
    /// Runtime used to drive the server and any blocking helper calls.
    runtime: Arc<Runtime>,
}
1193
1194impl TestServerWithRuntime {
1195    /// Returns the [`Runtime`] owned by this [`TestServerWithRuntime`].
1196    ///
1197    /// Can be used to spawn async tasks.
1198    pub fn runtime(&self) -> &Arc<Runtime> {
1199        &self.runtime
1200    }
1201
1202    /// Returns a referece to the inner running `environmentd` [`crate::Server`]`.
1203    pub fn inner(&self) -> &crate::Server {
1204        &self.server.inner
1205    }
1206
1207    /// Connect to the __public__ SQL port of the running `environmentd` server.
1208    pub fn connect<T>(&self, tls: T) -> Result<postgres::Client, postgres::Error>
1209    where
1210        T: MakeTlsConnect<Socket> + Send + 'static,
1211        T::TlsConnect: Send,
1212        T::Stream: Send,
1213        <T::TlsConnect as TlsConnect<Socket>>::Future: Send,
1214    {
1215        self.pg_config().connect(tls)
1216    }
1217
1218    /// Connect to the __internal__ SQL port of the running `environmentd` server.
1219    pub fn connect_internal<T>(&self, tls: T) -> Result<postgres::Client, anyhow::Error>
1220    where
1221        T: MakeTlsConnect<Socket> + Send + 'static,
1222        T::TlsConnect: Send,
1223        T::Stream: Send,
1224        <T::TlsConnect as TlsConnect<Socket>>::Future: Send,
1225    {
1226        Ok(self.pg_config_internal().connect(tls)?)
1227    }
1228
1229    /// Enable LaunchDarkly feature flags.
1230    pub fn enable_feature_flags(&self, flags: &[&'static str]) {
1231        let mut internal_client = self.connect_internal(postgres::NoTls).unwrap();
1232
1233        for flag in flags {
1234            internal_client
1235                .batch_execute(&format!("ALTER SYSTEM SET {} = true;", flag))
1236                .unwrap();
1237        }
1238    }
1239
1240    /// Disable LaunchDarkly feature flags.
1241    pub fn disable_feature_flags(&self, flags: &[&'static str]) {
1242        let mut internal_client = self.connect_internal(postgres::NoTls).unwrap();
1243
1244        for flag in flags {
1245            internal_client
1246                .batch_execute(&format!("ALTER SYSTEM SET {} = false;", flag))
1247                .unwrap();
1248        }
1249    }
1250
1251    /// Return a [`postgres::Config`] for connecting to the __public__ SQL port of the running
1252    /// `environmentd` server.
1253    pub fn pg_config(&self) -> postgres::Config {
1254        let local_addr = self.server.sql_local_addr();
1255        let mut config = postgres::Config::new();
1256        config
1257            .host(&Ipv4Addr::LOCALHOST.to_string())
1258            .port(local_addr.port())
1259            .user("materialize")
1260            .options("--welcome_message=off");
1261        config
1262    }
1263
1264    /// Return a [`postgres::Config`] for connecting to the __internal__ SQL port of the running
1265    /// `environmentd` server.
1266    pub fn pg_config_internal(&self) -> postgres::Config {
1267        let local_addr = self.server.internal_sql_local_addr();
1268        let mut config = postgres::Config::new();
1269        config
1270            .host(&Ipv4Addr::LOCALHOST.to_string())
1271            .port(local_addr.port())
1272            .user("mz_system")
1273            .options("--welcome_message=off");
1274        config
1275    }
1276
1277    pub fn ws_addr(&self) -> Uri {
1278        self.server.ws_addr()
1279    }
1280
1281    pub fn internal_ws_addr(&self) -> Uri {
1282        self.server.internal_ws_addr()
1283    }
1284
1285    pub fn http_local_addr(&self) -> SocketAddr {
1286        self.server.http_local_addr()
1287    }
1288
1289    pub fn internal_http_local_addr(&self) -> SocketAddr {
1290        self.server.internal_http_local_addr()
1291    }
1292
1293    pub fn sql_local_addr(&self) -> SocketAddr {
1294        self.server.sql_local_addr()
1295    }
1296
1297    pub fn internal_sql_local_addr(&self) -> SocketAddr {
1298        self.server.internal_sql_local_addr()
1299    }
1300
1301    /// Returns the metrics registry for the test server.
1302    pub fn metrics_registry(&self) -> &MetricsRegistry {
1303        &self.server.metrics_registry
1304    }
1305}
1306
/// A `u64` timestamp decoded from a SQL `numeric` value; see the [`FromSql`]
/// impl below for the conversion.
#[derive(Debug, Clone, Copy, Eq, PartialEq, Ord, PartialOrd)]
pub struct MzTimestamp(pub u64);
1309
1310impl<'a> FromSql<'a> for MzTimestamp {
1311    fn from_sql(ty: &Type, raw: &'a [u8]) -> Result<MzTimestamp, Box<dyn Error + Sync + Send>> {
1312        let n = mz_pgrepr::Numeric::from_sql(ty, raw)?;
1313        Ok(MzTimestamp(u64::try_from(n.0.0)?))
1314    }
1315
1316    fn accepts(ty: &Type) -> bool {
1317        mz_pgrepr::Numeric::accepts(ty)
1318    }
1319}
1320
/// Extension trait for extracting the [`DbError`] from a postgres error or result.
pub trait PostgresErrorExt {
    /// Returns the underlying [`DbError`], panicking if there is none.
    fn unwrap_db_error(self) -> DbError;
}
1324
1325impl PostgresErrorExt for postgres::Error {
1326    fn unwrap_db_error(self) -> DbError {
1327        match self.source().and_then(|e| e.downcast_ref::<DbError>()) {
1328            Some(e) => e.clone(),
1329            None => panic!("expected DbError, but got: {:?}", self),
1330        }
1331    }
1332}
1333
1334impl<T, E> PostgresErrorExt for Result<T, E>
1335where
1336    E: PostgresErrorExt,
1337{
1338    fn unwrap_db_error(self) -> DbError {
1339        match self {
1340            Ok(_) => panic!("expected Err(DbError), but got Ok(_)"),
1341            Err(e) => e.unwrap_db_error(),
1342        }
1343    }
1344}
1345
1346/// Group commit will block writes until the current time has advanced. This can make
1347/// performing inserts while using deterministic time difficult. This is a helper
1348/// method to perform writes and advance the current time.
1349pub async fn insert_with_deterministic_timestamps(
1350    table: &'static str,
1351    values: &'static str,
1352    server: &TestServer,
1353    now: Arc<std::sync::Mutex<EpochMillis>>,
1354) -> Result<(), Box<dyn Error>> {
1355    let client_write = server.connect().await?;
1356    let client_read = server.connect().await?;
1357
1358    let mut current_ts = get_explain_timestamp(table, &client_read).await;
1359
1360    let insert_query = format!("INSERT INTO {table} VALUES {values}");
1361
1362    let write_future = client_write.execute(&insert_query, &[]);
1363    let timestamp_interval = tokio::time::interval(Duration::from_millis(1));
1364
1365    let mut write_future = std::pin::pin!(write_future);
1366    let mut timestamp_interval = std::pin::pin!(timestamp_interval);
1367
1368    // Keep increasing `now` until the write has executed succeed. Table advancements may
1369    // have increased the global timestamp by an unknown amount.
1370    loop {
1371        tokio::select! {
1372            _ = (&mut write_future) => return Ok(()),
1373            _ = timestamp_interval.tick() => {
1374                current_ts += 1;
1375                *now.lock().expect("lock poisoned") = current_ts;
1376            }
1377        };
1378    }
1379}
1380
1381pub async fn get_explain_timestamp(from_suffix: &str, client: &Client) -> EpochMillis {
1382    try_get_explain_timestamp(from_suffix, client)
1383        .await
1384        .unwrap()
1385}
1386
1387pub async fn try_get_explain_timestamp(
1388    from_suffix: &str,
1389    client: &Client,
1390) -> Result<EpochMillis, anyhow::Error> {
1391    let det = get_explain_timestamp_determination(from_suffix, client).await?;
1392    let ts = det.determination.timestamp_context.timestamp_or_default();
1393    Ok(ts.into())
1394}
1395
1396pub async fn get_explain_timestamp_determination(
1397    from_suffix: &str,
1398    client: &Client,
1399) -> Result<TimestampExplanation<mz_repr::Timestamp>, anyhow::Error> {
1400    let row = client
1401        .query_one(
1402            &format!("EXPLAIN TIMESTAMP AS JSON FOR SELECT * FROM {from_suffix}"),
1403            &[],
1404        )
1405        .await?;
1406    let explain: String = row.get(0);
1407    Ok(serde_json::from_str(&explain).unwrap())
1408}
1409
/// Helper function to create a Postgres source.
///
/// Connects to the Postgres instance named by the `POSTGRES_URL` environment variable,
/// creates `table_name` (with `table_schema`) and a publication named `source_name`
/// there, then creates a matching connection, source, and table in Materialize via
/// `mz_client`. Returns the Postgres client and a cleanup closure.
///
/// IMPORTANT: Make sure to call closure that is returned at the end of the test to clean up
/// Postgres state.
///
/// WARNING: If multiple tests use this, and the tests are run in parallel, then make sure the test
/// use different postgres tables.
pub async fn create_postgres_source_with_table<'a>(
    server: &TestServer,
    mz_client: &Client,
    table_name: &str,
    table_schema: &str,
    source_name: &str,
) -> (
    Client,
    impl FnOnce(&'a Client, &'a Client) -> LocalBoxFuture<'a, ()>,
) {
    server
        .enable_feature_flags(&["enable_create_table_from_source"])
        .await;

    // Panics (with a readable message) if POSTGRES_URL is not set.
    let postgres_url = env::var("POSTGRES_URL")
        .map_err(|_| anyhow!("POSTGRES_URL environment variable is not set"))
        .unwrap();

    let (pg_client, connection) = tokio_postgres::connect(&postgres_url, postgres::NoTls)
        .await
        .unwrap();

    // Pull the individual connection parameters back out of the URL so they can be
    // re-assembled into a Materialize `CREATE CONNECTION` statement below.
    let pg_config: tokio_postgres::Config = postgres_url.parse().unwrap();
    let user = pg_config.get_user().unwrap_or("postgres");
    let db_name = pg_config.get_dbname().unwrap_or(user);
    let ports = pg_config.get_ports();
    let port = if ports.is_empty() { 5432 } else { ports[0] };
    let hosts = pg_config.get_hosts();
    let host = if hosts.is_empty() {
        "localhost".to_string()
    } else {
        match &hosts[0] {
            Host::Tcp(host) => host.to_string(),
            Host::Unix(host) => host.to_str().unwrap().to_string(),
        }
    };
    let password = pg_config.get_password();

    // Drive the Postgres connection on a background task for the test's lifetime.
    mz_ore::task::spawn(|| "postgres-source-connection", async move {
        if let Err(e) = connection.await {
            panic!("connection error: {}", e);
        }
    });

    // Create table in Postgres with publication.
    let _ = pg_client
        .execute(&format!("DROP TABLE IF EXISTS {table_name};"), &[])
        .await
        .unwrap();
    let _ = pg_client
        .execute(&format!("DROP PUBLICATION IF EXISTS {source_name};"), &[])
        .await
        .unwrap();
    let _ = pg_client
        .execute(&format!("CREATE TABLE {table_name} {table_schema};"), &[])
        .await
        .unwrap();
    // NOTE(review): presumably REPLICA IDENTITY FULL is required by the Materialize
    // Postgres source so replicated updates/deletes carry full rows — confirm.
    let _ = pg_client
        .execute(
            &format!("ALTER TABLE {table_name} REPLICA IDENTITY FULL;"),
            &[],
        )
        .await
        .unwrap();
    let _ = pg_client
        .execute(
            &format!("CREATE PUBLICATION {source_name} FOR TABLE {table_name};"),
            &[],
        )
        .await
        .unwrap();

    // Create postgres source in Materialize.
    let mut connection_str = format!("HOST '{host}', PORT {port}, USER {user}, DATABASE {db_name}");
    if let Some(password) = password {
        // If the URL carries a password, store it as a Materialize secret and
        // reference it from the connection.
        let password = std::str::from_utf8(password).unwrap();
        mz_client
            .batch_execute(&format!("CREATE SECRET s AS '{password}'"))
            .await
            .unwrap();
        connection_str = format!("{connection_str}, PASSWORD SECRET s");
    }
    mz_client
        .batch_execute(&format!(
            "CREATE CONNECTION pgconn TO POSTGRES ({connection_str})"
        ))
        .await
        .unwrap();
    mz_client
        .batch_execute(&format!(
            "CREATE SOURCE {source_name}
            FROM POSTGRES
            CONNECTION pgconn
            (PUBLICATION '{source_name}')"
        ))
        .await
        .unwrap();
    mz_client
        .batch_execute(&format!(
            "CREATE TABLE {table_name}
            FROM SOURCE {source_name}
            (REFERENCE {table_name});"
        ))
        .await
        .unwrap();

    // Take owned copies of the names so the cleanup closure can move them.
    let table_name = table_name.to_string();
    let source_name = source_name.to_string();
    (
        pg_client,
        // Cleanup closure: drops the Materialize source/connection and the Postgres
        // publication/table. Must be invoked at the end of the test.
        move |mz_client: &'a Client, pg_client: &'a Client| {
            let f: Pin<Box<dyn Future<Output = ()> + 'a>> = Box::pin(async move {
                mz_client
                    .batch_execute(&format!("DROP SOURCE {source_name} CASCADE;"))
                    .await
                    .unwrap();
                mz_client
                    .batch_execute("DROP CONNECTION pgconn;")
                    .await
                    .unwrap();

                let _ = pg_client
                    .execute(&format!("DROP PUBLICATION {source_name};"), &[])
                    .await
                    .unwrap();
                let _ = pg_client
                    .execute(&format!("DROP TABLE {table_name};"), &[])
                    .await
                    .unwrap();
            });
            f
        },
    )
}
1551
/// Blocks until `SELECT COUNT(*) FROM {view_name}` returns `source_rows`.
///
/// Temporarily switches the session to SERIALIZABLE isolation while polling, and
/// restores the previous `transaction_isolation` setting before returning.
///
/// Panics if any query fails or the retry budget is exhausted.
pub async fn wait_for_pg_table_population(mz_client: &Client, view_name: &str, source_rows: i64) {
    // Remember the current isolation level so it can be restored afterwards.
    let current_isolation = mz_client
        .query_one("SHOW transaction_isolation", &[])
        .await
        .unwrap()
        .get::<_, String>(0);
    // NOTE(review): presumably SERIALIZABLE is used so polling reads don't wait on
    // stricter isolation timestamps — confirm.
    mz_client
        .batch_execute("SET transaction_isolation = SERIALIZABLE")
        .await
        .unwrap();
    Retry::default()
        .retry_async(|_| async move {
            let rows = mz_client
                .query_one(&format!("SELECT COUNT(*) FROM {view_name};"), &[])
                .await
                .unwrap()
                .get::<_, i64>(0);
            if rows == source_rows {
                Ok(())
            } else {
                Err(format!(
                    "Waiting for {source_rows} row to be ingested. Currently at {rows}."
                ))
            }
        })
        .await
        .unwrap();
    // Restore the isolation level that was active when we were called.
    mz_client
        .batch_execute(&format!(
            "SET transaction_isolation = '{current_isolation}'"
        ))
        .await
        .unwrap();
}
1586
1587// Initializes a websocket connection. Returns the init messages before the initial ReadyForQuery.
1588pub fn auth_with_ws(
1589    ws: &mut WebSocket<MaybeTlsStream<TcpStream>>,
1590    mut options: BTreeMap<String, String>,
1591) -> Result<Vec<WebSocketResponse>, anyhow::Error> {
1592    if !options.contains_key("welcome_message") {
1593        options.insert("welcome_message".into(), "off".into());
1594    }
1595    auth_with_ws_impl(
1596        ws,
1597        Message::Text(
1598            serde_json::to_string(&WebSocketAuth::Basic {
1599                user: "materialize".into(),
1600                password: "".into(),
1601                options,
1602            })
1603            .unwrap()
1604            .into(),
1605        ),
1606    )
1607}
1608
1609pub fn auth_with_ws_impl(
1610    ws: &mut WebSocket<MaybeTlsStream<TcpStream>>,
1611    auth_message: Message,
1612) -> Result<Vec<WebSocketResponse>, anyhow::Error> {
1613    ws.send(auth_message)?;
1614
1615    // Wait for initial ready response.
1616    let mut msgs = Vec::new();
1617    loop {
1618        let resp = ws.read()?;
1619        match resp {
1620            Message::Text(msg) => {
1621                let msg: WebSocketResponse = serde_json::from_str(&msg).unwrap();
1622                match msg {
1623                    WebSocketResponse::ReadyForQuery(_) => break,
1624                    msg => {
1625                        msgs.push(msg);
1626                    }
1627                }
1628            }
1629            Message::Ping(_) => continue,
1630            Message::Close(None) => return Err(anyhow!("ws closed after auth")),
1631            Message::Close(Some(close_frame)) => {
1632                return Err(anyhow!("ws closed after auth").context(close_frame));
1633            }
1634            _ => panic!("unexpected response: {:?}", resp),
1635        }
1636    }
1637    Ok(msgs)
1638}
1639
1640pub fn make_header<H: Header>(h: H) -> HeaderMap {
1641    let mut map = HeaderMap::new();
1642    map.typed_insert(h);
1643    map
1644}
1645
1646pub fn make_pg_tls<F>(configure: F) -> MakeTlsConnector
1647where
1648    F: FnOnce(&mut SslConnectorBuilder) -> Result<(), ErrorStack>,
1649{
1650    let mut connector_builder = SslConnector::builder(SslMethod::tls()).unwrap();
1651    // Disable TLS v1.3 because `postgres` and `hyper` produce stabler error
1652    // messages with TLS v1.2.
1653    //
1654    // Briefly, in TLS v1.3, failing to present a client certificate does not
1655    // error during the TLS handshake, as it does in TLS v1.2, but on the first
1656    // attempt to read from the stream. But both `postgres` and `hyper` write a
1657    // bunch of data before attempting to read from the stream. With a failed
1658    // TLS v1.3 connection, sometimes `postgres` and `hyper` succeed in writing
1659    // out this data, and then return a nice error message on the call to read.
1660    // But sometimes the connection is closed before they write out the data,
1661    // and so they report "connection closed" before they ever call read, never
1662    // noticing the underlying SSL error.
1663    //
1664    // It's unclear who's bug this is. Is it on `hyper`/`postgres` to call read
1665    // if writing to the stream fails to see if a TLS error occured? Is it on
1666    // OpenSSL to provide a better API [1]? Is it a protocol issue that ought to
1667    // be corrected in TLS v1.4? We don't want to answer these questions, so we
1668    // just avoid TLS v1.3 for now.
1669    //
1670    // [1]: https://github.com/openssl/openssl/issues/11118
1671    let options = connector_builder.options() | SslOptions::NO_TLSV1_3;
1672    connector_builder.set_options(options);
1673    configure(&mut connector_builder).unwrap();
1674    MakeTlsConnector::new(connector_builder.build())
1675}
1676
/// A certificate authority for use in tests.
pub struct Ca {
    /// Temporary directory holding the CA's on-disk artifacts: its `ca.crt`
    /// and any certificate/key files issued via `request_cert`.
    pub dir: TempDir,
    /// The CA's X.509 name; used as the issuer name when this CA signs
    /// intermediate CAs.
    pub name: X509Name,
    /// The CA's certificate (self-signed for a root, parent-signed for an
    /// intermediate).
    pub cert: X509,
    /// The private key this CA signs issued certificates with.
    pub pkey: PKey<Private>,
}
1684
1685impl Ca {
1686    fn make_ca(name: &str, parent: Option<&Ca>) -> Result<Ca, Box<dyn Error>> {
1687        let dir = tempfile::tempdir()?;
1688        let rsa = Rsa::generate(2048)?;
1689        let pkey = PKey::from_rsa(rsa)?;
1690        let name = {
1691            let mut builder = X509NameBuilder::new()?;
1692            builder.append_entry_by_nid(Nid::COMMONNAME, name)?;
1693            builder.build()
1694        };
1695        let cert = {
1696            let mut builder = X509::builder()?;
1697            builder.set_version(2)?;
1698            builder.set_pubkey(&pkey)?;
1699            builder.set_issuer_name(parent.map(|ca| &ca.name).unwrap_or(&name))?;
1700            builder.set_subject_name(&name)?;
1701            builder.set_not_before(&*Asn1Time::days_from_now(0)?)?;
1702            builder.set_not_after(&*Asn1Time::days_from_now(365)?)?;
1703            builder.append_extension(BasicConstraints::new().critical().ca().build()?)?;
1704            builder.sign(
1705                parent.map(|ca| &ca.pkey).unwrap_or(&pkey),
1706                MessageDigest::sha256(),
1707            )?;
1708            builder.build()
1709        };
1710        fs::write(dir.path().join("ca.crt"), cert.to_pem()?)?;
1711        Ok(Ca {
1712            dir,
1713            name,
1714            cert,
1715            pkey,
1716        })
1717    }
1718
1719    /// Creates a new root certificate authority.
1720    pub fn new_root(name: &str) -> Result<Ca, Box<dyn Error>> {
1721        Ca::make_ca(name, None)
1722    }
1723
1724    /// Returns the path to the CA's certificate.
1725    pub fn ca_cert_path(&self) -> PathBuf {
1726        self.dir.path().join("ca.crt")
1727    }
1728
1729    /// Requests a new intermediate certificate authority.
1730    pub fn request_ca(&self, name: &str) -> Result<Ca, Box<dyn Error>> {
1731        Ca::make_ca(name, Some(self))
1732    }
1733
1734    /// Generates a certificate with the specified Common Name (CN) that is
1735    /// signed by the CA.
1736    ///
1737    /// Returns the paths to the certificate and key.
1738    pub fn request_client_cert(&self, name: &str) -> Result<(PathBuf, PathBuf), Box<dyn Error>> {
1739        self.request_cert(name, iter::empty())
1740    }
1741
1742    /// Like `request_client_cert`, but permits specifying additional IP
1743    /// addresses to attach as Subject Alternate Names.
1744    pub fn request_cert<I>(&self, name: &str, ips: I) -> Result<(PathBuf, PathBuf), Box<dyn Error>>
1745    where
1746        I: IntoIterator<Item = IpAddr>,
1747    {
1748        let rsa = Rsa::generate(2048)?;
1749        let pkey = PKey::from_rsa(rsa)?;
1750        let subject_name = {
1751            let mut builder = X509NameBuilder::new()?;
1752            builder.append_entry_by_nid(Nid::COMMONNAME, name)?;
1753            builder.build()
1754        };
1755        let cert = {
1756            let mut builder = X509::builder()?;
1757            builder.set_version(2)?;
1758            builder.set_pubkey(&pkey)?;
1759            builder.set_issuer_name(self.cert.subject_name())?;
1760            builder.set_subject_name(&subject_name)?;
1761            builder.set_not_before(&*Asn1Time::days_from_now(0)?)?;
1762            builder.set_not_after(&*Asn1Time::days_from_now(365)?)?;
1763            for ip in ips {
1764                builder.append_extension(
1765                    SubjectAlternativeName::new()
1766                        .ip(&ip.to_string())
1767                        .build(&builder.x509v3_context(None, None))?,
1768                )?;
1769            }
1770            builder.sign(&self.pkey, MessageDigest::sha256())?;
1771            builder.build()
1772        };
1773        let cert_path = self.dir.path().join(Path::new(name).with_extension("crt"));
1774        let key_path = self.dir.path().join(Path::new(name).with_extension("key"));
1775        fs::write(&cert_path, cert.to_pem()?)?;
1776        fs::write(&key_path, pkey.private_key_to_pem_pkcs8()?)?;
1777        Ok((cert_path, key_path))
1778    }
1779}