// mz_environmentd/test_util.rs
1// Copyright Materialize, Inc. and contributors. All rights reserved.
2//
3// Use of this software is governed by the Business Source License
4// included in the LICENSE file.
5//
6// As of the Change Date specified in that file, in accordance with
7// the Business Source License, use of this software will be governed
8// by the Apache License, Version 2.0.
9
10use std::collections::BTreeMap;
11use std::error::Error;
12use std::future::IntoFuture;
13use std::net::{IpAddr, Ipv4Addr, SocketAddr, TcpStream};
14use std::path::{Path, PathBuf};
15use std::pin::Pin;
16use std::str::FromStr;
17use std::sync::Arc;
18use std::sync::LazyLock;
19use std::time::Duration;
20use std::{env, fs, iter};
21
22use anyhow::anyhow;
23use futures::Future;
24use futures::future::{BoxFuture, LocalBoxFuture};
25use headers::{Header, HeaderMapExt};
26use http::Uri;
27use hyper::http::header::HeaderMap;
28use maplit::btreemap;
29use mz_adapter::TimestampExplanation;
30use mz_adapter_types::bootstrap_builtin_cluster_config::{
31    ANALYTICS_CLUSTER_DEFAULT_REPLICATION_FACTOR, BootstrapBuiltinClusterConfig,
32    CATALOG_SERVER_CLUSTER_DEFAULT_REPLICATION_FACTOR, PROBE_CLUSTER_DEFAULT_REPLICATION_FACTOR,
33    SUPPORT_CLUSTER_DEFAULT_REPLICATION_FACTOR, SYSTEM_CLUSTER_DEFAULT_REPLICATION_FACTOR,
34};
35
36use mz_auth::password::Password;
37use mz_catalog::config::ClusterReplicaSizeMap;
38use mz_controller::ControllerConfig;
39use mz_dyncfg::ConfigUpdates;
40use mz_license_keys::ValidatedLicenseKey;
41use mz_orchestrator_process::{ProcessOrchestrator, ProcessOrchestratorConfig};
42use mz_orchestrator_tracing::{TracingCliArgs, TracingOrchestrator};
43use mz_ore::metrics::MetricsRegistry;
44use mz_ore::now::{EpochMillis, NowFn, SYSTEM_TIME};
45use mz_ore::retry::Retry;
46use mz_ore::task;
47use mz_ore::tracing::{
48    OpenTelemetryConfig, StderrLogConfig, StderrLogFormat, TracingConfig, TracingHandle,
49};
50use mz_persist_client::PersistLocation;
51use mz_persist_client::cache::PersistClientCache;
52use mz_persist_client::cfg::{CONSENSUS_CONNECTION_POOL_MAX_SIZE, PersistConfig};
53use mz_persist_client::rpc::PersistGrpcPubSubServer;
54use mz_secrets::SecretsController;
55use mz_server_core::listeners::{
56    AllowedRoles, AuthenticatorKind, BaseListenerConfig, HttpRoutesEnabled,
57};
58use mz_server_core::{ReloadTrigger, TlsCertConfig};
59use mz_sql::catalog::EnvironmentId;
60use mz_storage_types::connections::ConnectionContext;
61use mz_tracing::CloneableEnvFilter;
62use openssl::asn1::Asn1Time;
63use openssl::error::ErrorStack;
64use openssl::hash::MessageDigest;
65use openssl::nid::Nid;
66use openssl::pkey::{PKey, Private};
67use openssl::rsa::Rsa;
68use openssl::ssl::{SslConnector, SslConnectorBuilder, SslMethod, SslOptions};
69use openssl::x509::extension::{BasicConstraints, SubjectAlternativeName};
70use openssl::x509::{X509, X509Name, X509NameBuilder};
71use postgres::error::DbError;
72use postgres::tls::{MakeTlsConnect, TlsConnect};
73use postgres::types::{FromSql, Type};
74use postgres::{NoTls, Socket};
75use postgres_openssl::MakeTlsConnector;
76use tempfile::TempDir;
77use tokio::net::TcpListener;
78use tokio::runtime::Runtime;
79use tokio_postgres::config::{Host, SslMode};
80use tokio_postgres::{AsyncMessage, Client};
81use tokio_stream::wrappers::TcpListenerStream;
82use tower_http::cors::AllowOrigin;
83use tracing::Level;
84use tracing_capture::SharedStorage;
85use tracing_subscriber::EnvFilter;
86use tungstenite::stream::MaybeTlsStream;
87use tungstenite::{Message, WebSocket};
88
89use crate::{
90    CatalogConfig, FronteggAuthenticator, HttpListenerConfig, ListenersConfig, SqlListenerConfig,
91    WebSocketAuth, WebSocketResponse,
92};
93
/// Kafka broker address(es) used by tests, read from the `KAFKA_ADDRS`
/// environment variable and defaulting to a local broker.
pub static KAFKA_ADDRS: LazyLock<String> = LazyLock::new(|| {
    match env::var("KAFKA_ADDRS") {
        Ok(addrs) => addrs,
        Err(_) => "localhost:9092".into(),
    }
});
96
/// Entry point for creating and configuring an `environmentd` test harness.
#[derive(Clone)]
pub struct TestHarness {
    /// On-disk data directory; a fresh temporary directory is created when `None`.
    data_directory: Option<PathBuf>,
    /// TLS certificate/key pair; when set, listeners are started with TLS enabled.
    tls: Option<TlsCertConfig>,
    /// Frontegg authenticator used when testing Frontegg authentication.
    frontegg: Option<FronteggAuthenticator>,
    /// Login password for `mz_system` on external listeners (password/SASL auth).
    external_login_password_mz_system: Option<Password>,
    /// Layout of the SQL and HTTP listeners to bind.
    listeners_config: ListenersConfig,
    unsafe_mode: bool,
    /// Number of workers (see [`TestHarness::workers`]).
    workers: usize,
    /// Clock used by the server; defaults to the system clock.
    now: NowFn,
    /// Per-harness random seed; used to derive unique metadata schema names.
    seed: u32,
    storage_usage_collection_interval: Duration,
    storage_usage_retention_period: Option<Duration>,
    default_cluster_replica_size: String,
    default_cluster_replication_factor: u32,
    // Size/replication settings for each builtin cluster.
    builtin_system_cluster_config: BootstrapBuiltinClusterConfig,
    builtin_catalog_server_cluster_config: BootstrapBuiltinClusterConfig,
    builtin_probe_cluster_config: BootstrapBuiltinClusterConfig,
    builtin_support_cluster_config: BootstrapBuiltinClusterConfig,
    builtin_analytics_cluster_config: BootstrapBuiltinClusterConfig,

    /// Forwarded to the process orchestrator so clusterd crashes fail the test.
    propagate_crashes: bool,
    enable_tracing: bool,
    // This is currently unrelated to enable_tracing, and is used only to disable orchestrator
    // tracing.
    orchestrator_tracing_cli_args: TracingCliArgs,
    bootstrap_role: Option<String>,
    deploy_generation: u64,
    /// System parameter defaults applied at bootstrap.
    system_parameter_defaults: BTreeMap<String, String>,
    internal_console_redirect_url: Option<String>,
    /// Custom metrics registry; a new one is created when `None`.
    metrics_registry: Option<MetricsRegistry>,
    /// Build version reported to persist (see `serve_with_trigger`).
    code_version: semver::Version,
    /// Optional storage for captured tracing spans.
    capture: Option<SharedStorage>,
    pub environment_id: EnvironmentId,
}
133
impl Default for TestHarness {
    // Default harness: two unauthenticated SQL/HTTP listener pairs ("external"
    // for normal roles, "internal" additionally admitting internal roles and
    // the internal/metrics/profiling HTTP routes), one worker, and quiet
    // logging suitable for tests.
    fn default() -> TestHarness {
        TestHarness {
            data_directory: None,
            tls: None,
            frontegg: None,
            external_login_password_mz_system: None,
            listeners_config: ListenersConfig {
                sql: btreemap![
                    // Port 0: ask the OS for an ephemeral port.
                    "external".to_owned() => SqlListenerConfig {
                        addr: SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), 0),
                        authenticator_kind: AuthenticatorKind::None,
                        allowed_roles: AllowedRoles::Normal,
                        enable_tls: false,
                    },
                    "internal".to_owned() => SqlListenerConfig {
                        addr: SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), 0),
                        authenticator_kind: AuthenticatorKind::None,
                        allowed_roles: AllowedRoles::NormalAndInternal,
                        enable_tls: false,
                    },
                ],
                http: btreemap![
                    "external".to_owned() => HttpListenerConfig {
                        base: BaseListenerConfig {
                            addr: SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), 0),
                            authenticator_kind: AuthenticatorKind::None,
                            allowed_roles: AllowedRoles::Normal,
                            enable_tls: false,
                        },
                        routes: HttpRoutesEnabled{
                            base: true,
                            webhook: true,
                            internal: false,
                            metrics: false,
                            profiling: false,
                            mcp_agents: false,
                            mcp_observatory: false,
                        },
                    },
                    // The internal HTTP listener additionally exposes the
                    // internal, metrics, and profiling routes.
                    "internal".to_owned() => HttpListenerConfig {
                        base: BaseListenerConfig {
                            addr: SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), 0),
                            authenticator_kind: AuthenticatorKind::None,
                            allowed_roles: AllowedRoles::NormalAndInternal,
                            enable_tls: false,
                        },
                        routes: HttpRoutesEnabled{
                            base: true,
                            webhook: true,
                            internal: true,
                            metrics: true,
                            profiling: true,
                            mcp_agents: false,
                            mcp_observatory: false,
                        },
                    },
                ],
            },
            unsafe_mode: false,
            workers: 1,
            now: SYSTEM_TIME.clone(),
            // Random seed so concurrent harnesses get distinct metadata schemas.
            seed: rand::random(),
            storage_usage_collection_interval: Duration::from_secs(3600),
            storage_usage_retention_period: None,
            default_cluster_replica_size: "scale=1,workers=1".to_string(),
            default_cluster_replication_factor: 1,
            builtin_system_cluster_config: BootstrapBuiltinClusterConfig {
                size: "scale=1,workers=1".to_string(),
                replication_factor: SYSTEM_CLUSTER_DEFAULT_REPLICATION_FACTOR,
            },
            builtin_catalog_server_cluster_config: BootstrapBuiltinClusterConfig {
                size: "scale=1,workers=1".to_string(),
                replication_factor: CATALOG_SERVER_CLUSTER_DEFAULT_REPLICATION_FACTOR,
            },
            builtin_probe_cluster_config: BootstrapBuiltinClusterConfig {
                size: "scale=1,workers=1".to_string(),
                replication_factor: PROBE_CLUSTER_DEFAULT_REPLICATION_FACTOR,
            },
            builtin_support_cluster_config: BootstrapBuiltinClusterConfig {
                size: "scale=1,workers=1".to_string(),
                replication_factor: SUPPORT_CLUSTER_DEFAULT_REPLICATION_FACTOR,
            },
            builtin_analytics_cluster_config: BootstrapBuiltinClusterConfig {
                size: "scale=1,workers=1".to_string(),
                replication_factor: ANALYTICS_CLUSTER_DEFAULT_REPLICATION_FACTOR,
            },
            propagate_crashes: false,
            enable_tracing: false,
            bootstrap_role: Some("materialize".into()),
            deploy_generation: 0,
            // This and startup_log_filter below are both (?) needed to suppress clusterd messages.
            // If we need those in the future, we might need to change both.
            system_parameter_defaults: BTreeMap::from([(
                "log_filter".to_string(),
                "error".to_string(),
            )]),
            internal_console_redirect_url: None,
            metrics_registry: None,
            orchestrator_tracing_cli_args: TracingCliArgs {
                startup_log_filter: CloneableEnvFilter::from_str("error").expect("must parse"),
                ..Default::default()
            },
            code_version: crate::BUILD_INFO.semver_version(),
            environment_id: EnvironmentId::for_tests(),
            capture: None,
        }
    }
}
243
impl TestHarness {
    /// Starts a test [`TestServer`], panicking if the server could not be started.
    ///
    /// For cases when startup might fail, see [`TestHarness::try_start`].
    pub async fn start(self) -> TestServer {
        self.try_start().await.expect("Failed to start test Server")
    }

    /// Like [`TestHarness::start`] but can specify a cert reload trigger.
    pub async fn start_with_trigger(self, tls_reload_certs: ReloadTrigger) -> TestServer {
        self.try_start_with_trigger(tls_reload_certs)
            .await
            .expect("Failed to start test Server")
    }

    /// Starts a test [`TestServer`], returning an error if the server could not be started.
    pub async fn try_start(self) -> Result<TestServer, anyhow::Error> {
        self.try_start_with_trigger(mz_server_core::cert_reload_never_reload())
            .await
    }

    /// Like [`TestHarness::try_start`] but can specify a cert reload trigger.
    pub async fn try_start_with_trigger(
        self,
        tls_reload_certs: ReloadTrigger,
    ) -> Result<TestServer, anyhow::Error> {
        // Bind the listeners first, then hand them (and this config) to serve.
        let listeners = Listeners::new(&self).await?;
        listeners.serve_with_trigger(self, tls_reload_certs).await
    }

    /// Starts a runtime and returns a [`TestServerWithRuntime`].
    ///
    /// Spawns a dedicated multi-threaded Tokio runtime and blocks on startup,
    /// for tests that are not themselves async.
    pub fn start_blocking(self) -> TestServerWithRuntime {
        let runtime = tokio::runtime::Builder::new_multi_thread()
            .enable_all()
            .thread_stack_size(mz_ore::stack::STACK_SIZE)
            .build()
            .expect("failed to spawn runtime for test");
        let runtime = Arc::new(runtime);
        let server = runtime.block_on(self.start());
        TestServerWithRuntime { runtime, server }
    }

    /// Uses the given on-disk data directory instead of a fresh temporary one.
    pub fn data_directory(mut self, data_directory: impl Into<PathBuf>) -> Self {
        self.data_directory = Some(data_directory.into());
        self
    }

    /// Configures the TLS certificate/key pair and enables TLS on every SQL
    /// and HTTP listener.
    pub fn with_tls(mut self, cert_path: impl Into<PathBuf>, key_path: impl Into<PathBuf>) -> Self {
        self.tls = Some(TlsCertConfig {
            cert: cert_path.into(),
            key: key_path.into(),
        });
        for (_, listener) in &mut self.listeners_config.sql {
            listener.enable_tls = true;
        }
        for (_, listener) in &mut self.listeners_config.http {
            listener.base.enable_tls = true;
        }
        self
    }

    /// Enables unsafe mode on the server.
    pub fn unsafe_mode(mut self) -> Self {
        self.unsafe_mode = true;
        self
    }

    /// Sets the number of workers.
    pub fn workers(mut self, workers: usize) -> Self {
        self.workers = workers;
        self
    }

    /// Enables Frontegg authentication: the external SQL/HTTP listeners
    /// authenticate via Frontegg (with TLS if configured), while the internal
    /// listeners remain unauthenticated. Replaces the whole listener config.
    pub fn with_frontegg_auth(mut self, frontegg: &FronteggAuthenticator) -> Self {
        self.frontegg = Some(frontegg.clone());
        // External listeners inherit TLS only if a cert was configured.
        let enable_tls = self.tls.is_some();
        self.listeners_config = ListenersConfig {
            sql: btreemap! {
                "external".to_owned() => SqlListenerConfig {
                    addr: SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), 0),
                    authenticator_kind: AuthenticatorKind::Frontegg,
                    allowed_roles: AllowedRoles::Normal,
                    enable_tls,
                },
                "internal".to_owned() => SqlListenerConfig {
                    addr: SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), 0),
                    authenticator_kind: AuthenticatorKind::None,
                    allowed_roles: AllowedRoles::NormalAndInternal,
                    enable_tls: false,
                },
            },
            http: btreemap! {
                "external".to_owned() => HttpListenerConfig {
                    base: BaseListenerConfig {
                        addr: SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), 0),
                        authenticator_kind: AuthenticatorKind::Frontegg,
                        allowed_roles: AllowedRoles::Normal,
                        enable_tls,
                    },
                    routes: HttpRoutesEnabled{
                        base: true,
                        webhook: true,
                        internal: false,
                        metrics: false,
                        profiling: false,
                        mcp_agents: false,
                        mcp_observatory: false,
                    },
                },
                "internal".to_owned() => HttpListenerConfig {
                    base: BaseListenerConfig {
                        addr: SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), 0),
                        authenticator_kind: AuthenticatorKind::None,
                        allowed_roles: AllowedRoles::NormalAndInternal,
                        enable_tls: false,
                    },
                    routes: HttpRoutesEnabled{
                        base: true,
                        webhook: true,
                        internal: true,
                        metrics: true,
                        profiling: true,
                        mcp_agents: false,
                        mcp_observatory: false,
                    },
                },
            },
        };
        self
    }

    /// Enables OIDC authentication on the external listeners and seeds the
    /// `oidc_issuer`, `oidc_authentication_claim`, and `oidc_audience` system
    /// parameters from the provided values. Replaces the whole listener config.
    pub fn with_oidc_auth(
        mut self,
        issuer: Option<String>,
        authentication_claim: Option<String>,
        expected_audiences: Option<Vec<String>>,
    ) -> Self {
        let enable_tls = self.tls.is_some();
        self.listeners_config = ListenersConfig {
            sql: btreemap! {
                "external".to_owned() => SqlListenerConfig {
                    addr: SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), 0),
                    authenticator_kind: AuthenticatorKind::Oidc,
                    allowed_roles: AllowedRoles::Normal,
                    enable_tls,
                },
                "internal".to_owned() => SqlListenerConfig {
                    addr: SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), 0),
                    authenticator_kind: AuthenticatorKind::None,
                    allowed_roles: AllowedRoles::NormalAndInternal,
                    enable_tls: false,
                },
            },
            http: btreemap! {
                "external".to_owned() => HttpListenerConfig {
                    base: BaseListenerConfig {
                        addr: SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), 0),
                        authenticator_kind: AuthenticatorKind::Oidc,
                        allowed_roles: AllowedRoles::Normal,
                        enable_tls,
                    },
                    routes: HttpRoutesEnabled{
                        base: true,
                        webhook: true,
                        internal: false,
                        metrics: false,
                        profiling: false,
                        mcp_agents: false,
                        mcp_observatory: false,
                    },
                },
                "internal".to_owned() => HttpListenerConfig {
                    base: BaseListenerConfig {
                        addr: SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), 0),
                        authenticator_kind: AuthenticatorKind::None,
                        allowed_roles: AllowedRoles::NormalAndInternal,
                        enable_tls: false,
                    },
                    routes: HttpRoutesEnabled{
                        base: true,
                        webhook: true,
                        internal: true,
                        metrics: true,
                        profiling: true,
                        mcp_agents: false,
                        mcp_observatory: false,
                    },
                },
            },
        };

        if let Some(issuer) = issuer {
            self.system_parameter_defaults
                .insert("oidc_issuer".to_string(), issuer);
        }

        if let Some(authentication_claim) = authentication_claim {
            self.system_parameter_defaults.insert(
                "oidc_authentication_claim".to_string(),
                authentication_claim,
            );
        }

        if let Some(expected_audiences) = expected_audiences {
            // Audiences are stored as a JSON-encoded list.
            self.system_parameter_defaults.insert(
                "oidc_audience".to_string(),
                serde_json::to_string(&expected_audiences).unwrap(),
            );
        }

        self
    }

    /// Enables password authentication: a single external SQL/HTTP listener
    /// pair (admitting internal roles), plus a separate unauthenticated
    /// HTTP listener that serves only metrics. Records the `mz_system`
    /// password and replaces the whole listener config.
    pub fn with_password_auth(mut self, mz_system_password: Password) -> Self {
        self.external_login_password_mz_system = Some(mz_system_password);
        let enable_tls = self.tls.is_some();
        self.listeners_config = ListenersConfig {
            sql: btreemap! {
                "external".to_owned() => SqlListenerConfig {
                    addr: SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), 0),
                    authenticator_kind: AuthenticatorKind::Password,
                    allowed_roles: AllowedRoles::NormalAndInternal,
                    enable_tls,
                },
            },
            http: btreemap! {
                "external".to_owned() => HttpListenerConfig {
                    base: BaseListenerConfig {
                        addr: SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), 0),
                        authenticator_kind: AuthenticatorKind::Password,
                        allowed_roles: AllowedRoles::NormalAndInternal,
                        enable_tls,
                    },
                    routes: HttpRoutesEnabled{
                        base: true,
                        webhook: true,
                        internal: true,
                        metrics: false,
                        profiling: true,
                        mcp_agents: false,
                        mcp_observatory: false,
                    },
                },
                // Metrics stay on their own unauthenticated listener.
                "metrics".to_owned() => HttpListenerConfig {
                    base: BaseListenerConfig {
                        addr: SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), 0),
                        authenticator_kind: AuthenticatorKind::None,
                        allowed_roles: AllowedRoles::NormalAndInternal,
                        enable_tls: false,
                    },
                    routes: HttpRoutesEnabled{
                        base: false,
                        webhook: false,
                        internal: false,
                        metrics: true,
                        profiling: false,
                        mcp_agents: false,
                        mcp_observatory: false,
                    },
                },
            },
        };
        self
    }

    /// Like [`TestHarness::with_password_auth`], but the external SQL listener
    /// uses SASL authentication. The external HTTP listener still uses
    /// password authentication (NOTE(review): presumably because SASL is a
    /// SQL-wire-protocol mechanism — confirm).
    pub fn with_sasl_scram_auth(mut self, mz_system_password: Password) -> Self {
        self.external_login_password_mz_system = Some(mz_system_password);
        let enable_tls = self.tls.is_some();
        self.listeners_config = ListenersConfig {
            sql: btreemap! {
                "external".to_owned() => SqlListenerConfig {
                    addr: SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), 0),
                    authenticator_kind: AuthenticatorKind::Sasl,
                    allowed_roles: AllowedRoles::NormalAndInternal,
                    enable_tls,
                },
            },
            http: btreemap! {
                "external".to_owned() => HttpListenerConfig {
                    base: BaseListenerConfig {
                        addr: SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), 0),
                        authenticator_kind: AuthenticatorKind::Password,
                        allowed_roles: AllowedRoles::NormalAndInternal,
                        enable_tls,
                    },
                    routes: HttpRoutesEnabled{
                        base: true,
                        webhook: true,
                        internal: true,
                        metrics: false,
                        profiling: true,
                        mcp_agents: false,
                        mcp_observatory: false,
                    },
                },
                "metrics".to_owned() => HttpListenerConfig {
                    base: BaseListenerConfig {
                        addr: SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), 0),
                        authenticator_kind: AuthenticatorKind::None,
                        allowed_roles: AllowedRoles::NormalAndInternal,
                        enable_tls: false,
                    },
                    routes: HttpRoutesEnabled{
                        base: false,
                        webhook: false,
                        internal: false,
                        metrics: true,
                        profiling: false,
                        mcp_agents: false,
                        mcp_observatory: false,
                    },
                },
            },
        };
        self
    }

    /// Sets the clock used by the server.
    pub fn with_now(mut self, now: NowFn) -> Self {
        self.now = now;
        self
    }

    /// Sets how often storage usage is collected.
    pub fn with_storage_usage_collection_interval(
        mut self,
        storage_usage_collection_interval: Duration,
    ) -> Self {
        self.storage_usage_collection_interval = storage_usage_collection_interval;
        self
    }

    /// Sets how long storage usage records are retained.
    pub fn with_storage_usage_retention_period(
        mut self,
        storage_usage_retention_period: Duration,
    ) -> Self {
        self.storage_usage_retention_period = Some(storage_usage_retention_period);
        self
    }

    /// Sets the default cluster replica size.
    pub fn with_default_cluster_replica_size(
        mut self,
        default_cluster_replica_size: String,
    ) -> Self {
        self.default_cluster_replica_size = default_cluster_replica_size;
        self
    }

    /// Sets the replica size of the builtin system cluster.
    pub fn with_builtin_system_cluster_replica_size(
        mut self,
        builtin_system_cluster_replica_size: String,
    ) -> Self {
        self.builtin_system_cluster_config.size = builtin_system_cluster_replica_size;
        self
    }

    /// Sets the replication factor of the builtin system cluster.
    pub fn with_builtin_system_cluster_replication_factor(
        mut self,
        builtin_system_cluster_replication_factor: u32,
    ) -> Self {
        self.builtin_system_cluster_config.replication_factor =
            builtin_system_cluster_replication_factor;
        self
    }

    /// Sets the replica size of the builtin catalog server cluster.
    pub fn with_builtin_catalog_server_cluster_replica_size(
        mut self,
        builtin_catalog_server_cluster_replica_size: String,
    ) -> Self {
        self.builtin_catalog_server_cluster_config.size =
            builtin_catalog_server_cluster_replica_size;
        self
    }

    /// Sets whether clusterd crashes are propagated (failing the test).
    pub fn with_propagate_crashes(mut self, propagate_crashes: bool) -> Self {
        self.propagate_crashes = propagate_crashes;
        self
    }

    /// Sets whether tracing is enabled for the server.
    pub fn with_enable_tracing(mut self, enable_tracing: bool) -> Self {
        self.enable_tracing = enable_tracing;
        self
    }

    /// Sets (or clears) the bootstrap role.
    pub fn with_bootstrap_role(mut self, bootstrap_role: Option<String>) -> Self {
        self.bootstrap_role = bootstrap_role;
        self
    }

    /// Sets the deploy generation.
    pub fn with_deploy_generation(mut self, deploy_generation: u64) -> Self {
        self.deploy_generation = deploy_generation;
        self
    }

    /// Sets the default for a single system parameter.
    pub fn with_system_parameter_default(mut self, param: String, value: String) -> Self {
        self.system_parameter_defaults.insert(param, value);
        self
    }

    /// Sets (or clears) the internal console redirect URL.
    pub fn with_internal_console_redirect_url(
        mut self,
        internal_console_redirect_url: Option<String>,
    ) -> Self {
        self.internal_console_redirect_url = internal_console_redirect_url;
        self
    }

    /// Uses the given metrics registry instead of creating a fresh one.
    pub fn with_metrics_registry(mut self, registry: MetricsRegistry) -> Self {
        self.metrics_registry = Some(registry);
        self
    }

    /// Overrides the build version reported by the server.
    pub fn with_code_version(mut self, version: semver::Version) -> Self {
        self.code_version = version;
        self
    }

    /// Captures tracing spans into the given shared storage.
    pub fn with_capture(mut self, storage: SharedStorage) -> Self {
        self.capture = Some(storage);
        self
    }
}
662
/// A set of bound (but not yet serving) `environmentd` listeners.
pub struct Listeners {
    /// The underlying listener set from the crate under test.
    pub inner: crate::Listeners,
}
666
667impl Listeners {
668    pub async fn new(config: &TestHarness) -> Result<Listeners, anyhow::Error> {
669        let inner = crate::Listeners::bind(config.listeners_config.clone()).await?;
670        Ok(Listeners { inner })
671    }
672
673    pub async fn serve(self, config: TestHarness) -> Result<TestServer, anyhow::Error> {
674        self.serve_with_trigger(config, mz_server_core::cert_reload_never_reload())
675            .await
676    }
677
678    pub async fn serve_with_trigger(
679        self,
680        config: TestHarness,
681        tls_reload_certs: ReloadTrigger,
682    ) -> Result<TestServer, anyhow::Error> {
683        let (data_directory, temp_dir) = match config.data_directory {
684            None => {
685                // If no data directory is provided, we create a temporary
686                // directory. The temporary directory is cleaned up when the
687                // `TempDir` is dropped, so we keep it alive until the `Server` is
688                // dropped.
689                let temp_dir = tempfile::tempdir()?;
690                (temp_dir.path().to_path_buf(), Some(temp_dir))
691            }
692            Some(data_directory) => (data_directory, None),
693        };
694        let scratch_dir = tempfile::tempdir()?;
695        let (consensus_uri, timestamp_oracle_url) = {
696            let seed = config.seed;
697            let cockroach_url = env::var("METADATA_BACKEND_URL")
698                .map_err(|_| anyhow!("METADATA_BACKEND_URL environment variable is not set"))?;
699            let (client, conn) = tokio_postgres::connect(&cockroach_url, NoTls).await?;
700            mz_ore::task::spawn(|| "startup-postgres-conn", async move {
701                if let Err(err) = conn.await {
702                    panic!("connection error: {}", err);
703                };
704            });
705            client
706                .batch_execute(&format!(
707                    "CREATE SCHEMA IF NOT EXISTS consensus_{seed};
708                    CREATE SCHEMA IF NOT EXISTS tsoracle_{seed};"
709                ))
710                .await?;
711            (
712                format!("{cockroach_url}?options=--search_path=consensus_{seed}")
713                    .parse()
714                    .expect("invalid consensus URI"),
715                format!("{cockroach_url}?options=--search_path=tsoracle_{seed}")
716                    .parse()
717                    .expect("invalid timestamp oracle URI"),
718            )
719        };
720        let metrics_registry = config.metrics_registry.unwrap_or_else(MetricsRegistry::new);
721        let orchestrator = ProcessOrchestrator::new(ProcessOrchestratorConfig {
722            image_dir: env::current_exe()?
723                .parent()
724                .unwrap()
725                .parent()
726                .unwrap()
727                .to_path_buf(),
728            suppress_output: false,
729            environment_id: config.environment_id.to_string(),
730            secrets_dir: data_directory.join("secrets"),
731            command_wrapper: vec![],
732            propagate_crashes: config.propagate_crashes,
733            tcp_proxy: None,
734            scratch_directory: scratch_dir.path().to_path_buf(),
735        })
736        .await?;
737        let orchestrator = Arc::new(orchestrator);
738        // Messing with the clock causes persist to expire leases, causing hangs and
739        // panics. Is it possible/desirable to put this back somehow?
740        let persist_now = SYSTEM_TIME.clone();
741        let dyncfgs = mz_dyncfgs::all_dyncfgs();
742
743        let mut updates = ConfigUpdates::default();
744        // Tune down the number of connections to make this all work a little easier
745        // with local postgres.
746        updates.add(&CONSENSUS_CONNECTION_POOL_MAX_SIZE, 1);
747        updates.apply(&dyncfgs);
748
749        let mut persist_cfg = PersistConfig::new(&crate::BUILD_INFO, persist_now.clone(), dyncfgs);
750        persist_cfg.build_version = config.code_version;
751        // Stress persist more by writing rollups frequently
752        persist_cfg.set_rollup_threshold(5);
753
754        let persist_pubsub_server = PersistGrpcPubSubServer::new(&persist_cfg, &metrics_registry);
755        let persist_pubsub_client = persist_pubsub_server.new_same_process_connection();
756        let persist_pubsub_tcp_listener =
757            TcpListener::bind(SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), 0))
758                .await
759                .expect("pubsub addr binding");
760        let persist_pubsub_server_port = persist_pubsub_tcp_listener
761            .local_addr()
762            .expect("pubsub addr has local addr")
763            .port();
764
765        // Spawn the persist pub-sub server.
766        mz_ore::task::spawn(|| "persist_pubsub_server", async move {
767            persist_pubsub_server
768                .serve_with_stream(TcpListenerStream::new(persist_pubsub_tcp_listener))
769                .await
770                .expect("success")
771        });
772        let persist_clients =
773            PersistClientCache::new(persist_cfg, &metrics_registry, |_, _| persist_pubsub_client);
774        let persist_clients = Arc::new(persist_clients);
775
776        let secrets_controller = Arc::clone(&orchestrator);
777        let connection_context = ConnectionContext::for_tests(orchestrator.reader());
778        let orchestrator = Arc::new(TracingOrchestrator::new(
779            orchestrator,
780            config.orchestrator_tracing_cli_args,
781        ));
782        let tracing_handle = if config.enable_tracing {
783            let config = TracingConfig::<fn(&tracing::Metadata) -> sentry_tracing::EventFilter> {
784                service_name: "environmentd",
785                stderr_log: StderrLogConfig {
786                    format: StderrLogFormat::Json,
787                    filter: EnvFilter::default(),
788                },
789                opentelemetry: Some(OpenTelemetryConfig {
790                    endpoint: "http://fake_address_for_testing:8080".to_string(),
791                    headers: http::HeaderMap::new(),
792                    filter: EnvFilter::default().add_directive(Level::DEBUG.into()),
793                    resource: opentelemetry_sdk::resource::Resource::builder().build(),
794                    max_batch_queue_size: 2048,
795                    max_export_batch_size: 512,
796                    max_concurrent_exports: 1,
797                    batch_scheduled_delay: Duration::from_millis(5000),
798                    max_export_timeout: Duration::from_secs(30),
799                }),
800                tokio_console: None,
801                sentry: None,
802                build_version: crate::BUILD_INFO.version,
803                build_sha: crate::BUILD_INFO.sha,
804                registry: metrics_registry.clone(),
805                capture: config.capture,
806            };
807            mz_ore::tracing::configure(config).await?
808        } else {
809            TracingHandle::disabled()
810        };
811        let host_name = format!(
812            "localhost:{}",
813            self.inner.http["external"].handle.local_addr.port()
814        );
815        let catalog_config = CatalogConfig {
816            persist_clients: Arc::clone(&persist_clients),
817            metrics: Arc::new(mz_catalog::durable::Metrics::new(&MetricsRegistry::new())),
818        };
819
820        let inner = self
821            .inner
822            .serve(crate::Config {
823                catalog_config,
824                timestamp_oracle_url: Some(timestamp_oracle_url),
825                controller: ControllerConfig {
826                    build_info: &crate::BUILD_INFO,
827                    orchestrator,
828                    clusterd_image: "clusterd".into(),
829                    init_container_image: None,
830                    deploy_generation: config.deploy_generation,
831                    persist_location: PersistLocation {
832                        blob_uri: format!("file://{}/persist/blob", data_directory.display())
833                            .parse()
834                            .expect("invalid blob URI"),
835                        consensus_uri,
836                    },
837                    persist_clients,
838                    now: config.now.clone(),
839                    metrics_registry: metrics_registry.clone(),
840                    persist_pubsub_url: format!("http://localhost:{}", persist_pubsub_server_port),
841                    secrets_args: mz_service::secrets::SecretsReaderCliArgs {
842                        secrets_reader: mz_service::secrets::SecretsControllerKind::LocalFile,
843                        secrets_reader_local_file_dir: Some(data_directory.join("secrets")),
844                        secrets_reader_kubernetes_context: None,
845                        secrets_reader_aws_prefix: None,
846                        secrets_reader_name_prefix: None,
847                    },
848                    connection_context,
849                    replica_http_locator: Default::default(),
850                },
851                secrets_controller,
852                cloud_resource_controller: None,
853                tls: config.tls,
854                frontegg: config.frontegg,
855                unsafe_mode: config.unsafe_mode,
856                all_features: false,
857                metrics_registry: metrics_registry.clone(),
858                now: config.now,
859                environment_id: config.environment_id,
860                cors_allowed_origin: AllowOrigin::list([]),
861                cluster_replica_sizes: ClusterReplicaSizeMap::for_tests(),
862                bootstrap_default_cluster_replica_size: config.default_cluster_replica_size,
863                bootstrap_default_cluster_replication_factor: config
864                    .default_cluster_replication_factor,
865                bootstrap_builtin_system_cluster_config: config.builtin_system_cluster_config,
866                bootstrap_builtin_catalog_server_cluster_config: config
867                    .builtin_catalog_server_cluster_config,
868                bootstrap_builtin_probe_cluster_config: config.builtin_probe_cluster_config,
869                bootstrap_builtin_support_cluster_config: config.builtin_support_cluster_config,
870                bootstrap_builtin_analytics_cluster_config: config.builtin_analytics_cluster_config,
871                system_parameter_defaults: config.system_parameter_defaults,
872                availability_zones: Default::default(),
873                tracing_handle,
874                storage_usage_collection_interval: config.storage_usage_collection_interval,
875                storage_usage_retention_period: config.storage_usage_retention_period,
876                segment_api_key: None,
877                segment_client_side: false,
878                test_only_dummy_segment_client: false,
879                egress_addresses: vec![],
880                aws_account_id: None,
881                aws_privatelink_availability_zones: None,
882                launchdarkly_sdk_key: None,
883                launchdarkly_key_map: Default::default(),
884                config_sync_file_path: None,
885                config_sync_timeout: Duration::from_secs(30),
886                config_sync_loop_interval: None,
887                bootstrap_role: config.bootstrap_role,
888                http_host_name: Some(host_name),
889                internal_console_redirect_url: config.internal_console_redirect_url,
890                tls_reload_certs,
891                helm_chart_version: None,
892                license_key: ValidatedLicenseKey::for_tests(),
893                external_login_password_mz_system: config.external_login_password_mz_system,
894                force_builtin_schema_migration: None,
895            })
896            .await?;
897
898        Ok(TestServer {
899            inner,
900            metrics_registry,
901            _temp_dir: temp_dir,
902            _scratch_dir: scratch_dir,
903        })
904    }
905}
906
/// A running instance of `environmentd`.
pub struct TestServer {
    /// The running `environmentd` server; exposes listener handles for its SQL and
    /// HTTP endpoints.
    pub inner: crate::Server,
    /// The metrics registry the server was started with; tests can inspect it directly.
    pub metrics_registry: MetricsRegistry,
    /// The `TempDir`s are saved to prevent them from being dropped, and thus cleaned up too early.
    _temp_dir: Option<TempDir>,
    /// Scratch directory handed to the process orchestrator; kept alive for the same reason.
    _scratch_dir: TempDir,
}
915
916impl TestServer {
917    pub fn connect(&self) -> ConnectBuilder<'_, postgres::NoTls, NoHandle> {
918        ConnectBuilder::new(self).no_tls()
919    }
920
921    pub async fn enable_feature_flags(&self, flags: &[&'static str]) {
922        let internal_client = self.connect().internal().await.unwrap();
923
924        for flag in flags {
925            internal_client
926                .batch_execute(&format!("ALTER SYSTEM SET {} = true;", flag))
927                .await
928                .unwrap();
929        }
930    }
931
932    pub async fn disable_feature_flags(&self, flags: &[&'static str]) {
933        let internal_client = self.connect().internal().await.unwrap();
934
935        for flag in flags {
936            internal_client
937                .batch_execute(&format!("ALTER SYSTEM SET {} = false;", flag))
938                .await
939                .unwrap();
940        }
941    }
942
943    pub fn ws_addr(&self) -> Uri {
944        format!(
945            "ws://{}/api/experimental/sql",
946            self.inner.http_listener_handles["external"].local_addr
947        )
948        .parse()
949        .unwrap()
950    }
951
952    pub fn internal_ws_addr(&self) -> Uri {
953        format!(
954            "ws://{}/api/experimental/sql",
955            self.inner.http_listener_handles["internal"].local_addr
956        )
957        .parse()
958        .unwrap()
959    }
960
961    pub fn http_local_addr(&self) -> SocketAddr {
962        self.inner.http_listener_handles["external"].local_addr
963    }
964
965    pub fn internal_http_local_addr(&self) -> SocketAddr {
966        self.inner.http_listener_handles["internal"].local_addr
967    }
968
969    pub fn sql_local_addr(&self) -> SocketAddr {
970        self.inner.sql_listener_handles["external"].local_addr
971    }
972
973    pub fn internal_sql_local_addr(&self) -> SocketAddr {
974        self.inner.sql_listener_handles["internal"].local_addr
975    }
976}
977
/// A builder struct to configure a pgwire connection to a running [`TestServer`].
///
/// You can create this struct, and thus open a pgwire connection, using [`TestServer::connect`].
pub struct ConnectBuilder<'s, T, H> {
    /// A running `environmentd` test server.
    server: &'s TestServer,

    /// Postgres configuration for connecting to the test server.
    pg_config: tokio_postgres::Config,
    /// Port to use when connecting to the test server; defaults to the external SQL
    /// port and is switched by [`ConnectBuilder::internal`].
    port: u16,
    /// TLS settings to use; `()` until one of `no_tls`/`with_tls` is called.
    tls: T,

    /// Callback that gets invoked for every notice we receive.
    notice_callback: Option<Box<dyn FnMut(tokio_postgres::error::DbError) + Send + 'static>>,

    /// Type variable for whether or not we include the handle for the spawned [`tokio::task`].
    _with_handle: H,
}
998
999impl<'s> ConnectBuilder<'s, (), NoHandle> {
1000    fn new(server: &'s TestServer) -> Self {
1001        let mut pg_config = tokio_postgres::Config::new();
1002        pg_config
1003            .host(&Ipv4Addr::LOCALHOST.to_string())
1004            .user("materialize")
1005            .options("--welcome_message=off")
1006            .application_name("environmentd_test_framework");
1007
1008        ConnectBuilder {
1009            server,
1010            pg_config,
1011            port: server.sql_local_addr().port(),
1012            tls: (),
1013            notice_callback: None,
1014            _with_handle: NoHandle,
1015        }
1016    }
1017}
1018
1019impl<'s, T, H> ConnectBuilder<'s, T, H> {
1020    /// Create a pgwire connection without using TLS.
1021    ///
1022    /// Note: this is the default for all connections.
1023    pub fn no_tls(self) -> ConnectBuilder<'s, postgres::NoTls, H> {
1024        ConnectBuilder {
1025            server: self.server,
1026            pg_config: self.pg_config,
1027            port: self.port,
1028            tls: postgres::NoTls,
1029            notice_callback: self.notice_callback,
1030            _with_handle: self._with_handle,
1031        }
1032    }
1033
1034    /// Create a pgwire connection with TLS.
1035    pub fn with_tls<Tls>(self, tls: Tls) -> ConnectBuilder<'s, Tls, H>
1036    where
1037        Tls: MakeTlsConnect<Socket> + Send + 'static,
1038        Tls::TlsConnect: Send,
1039        Tls::Stream: Send,
1040        <Tls::TlsConnect as TlsConnect<Socket>>::Future: Send,
1041    {
1042        ConnectBuilder {
1043            server: self.server,
1044            pg_config: self.pg_config,
1045            port: self.port,
1046            tls,
1047            notice_callback: self.notice_callback,
1048            _with_handle: self._with_handle,
1049        }
1050    }
1051
1052    /// Create a [`ConnectBuilder`] using the provided [`tokio_postgres::Config`].
1053    pub fn with_config(mut self, pg_config: tokio_postgres::Config) -> Self {
1054        self.pg_config = pg_config;
1055        self
1056    }
1057
1058    /// Set the [`SslMode`] to be used with the resulting connection.
1059    pub fn ssl_mode(mut self, mode: SslMode) -> Self {
1060        self.pg_config.ssl_mode(mode);
1061        self
1062    }
1063
1064    /// Set the user for the pgwire connection.
1065    pub fn user(mut self, user: &str) -> Self {
1066        self.pg_config.user(user);
1067        self
1068    }
1069
1070    /// Set the password for the pgwire connection.
1071    pub fn password(mut self, password: &str) -> Self {
1072        self.pg_config.password(password);
1073        self
1074    }
1075
1076    /// Set the application name for the pgwire connection.
1077    pub fn application_name(mut self, application_name: &str) -> Self {
1078        self.pg_config.application_name(application_name);
1079        self
1080    }
1081
1082    /// Set the database name for the pgwire connection.
1083    pub fn dbname(mut self, dbname: &str) -> Self {
1084        self.pg_config.dbname(dbname);
1085        self
1086    }
1087
1088    /// Set the options for the pgwire connection.
1089    pub fn options(mut self, options: &str) -> Self {
1090        self.pg_config.options(options);
1091        self
1092    }
1093
1094    /// Configures this [`ConnectBuilder`] to connect to the __internal__ SQL port of the running
1095    /// [`TestServer`].
1096    ///
1097    /// For example, this will change the port we connect to, and the user we connect as.
1098    pub fn internal(mut self) -> Self {
1099        self.port = self.server.internal_sql_local_addr().port();
1100        self.pg_config.user(mz_sql::session::user::SYSTEM_USER_NAME);
1101        self
1102    }
1103
1104    /// Sets a callback for any database notices that are received from the [`TestServer`].
1105    pub fn notice_callback(self, callback: impl FnMut(DbError) + Send + 'static) -> Self {
1106        ConnectBuilder {
1107            notice_callback: Some(Box::new(callback)),
1108            ..self
1109        }
1110    }
1111
1112    /// Configures this [`ConnectBuilder`] to return the [`mz_ore::task::JoinHandle`] that is
1113    /// polling the underlying postgres connection, associated with the returned client.
1114    pub fn with_handle(self) -> ConnectBuilder<'s, T, WithHandle> {
1115        ConnectBuilder {
1116            server: self.server,
1117            pg_config: self.pg_config,
1118            port: self.port,
1119            tls: self.tls,
1120            notice_callback: self.notice_callback,
1121            _with_handle: WithHandle,
1122        }
1123    }
1124
1125    /// Returns the [`tokio_postgres::Config`] that will be used to connect.
1126    pub fn as_pg_config(&self) -> &tokio_postgres::Config {
1127        &self.pg_config
1128    }
1129}
1130
/// This trait enables us to either include or omit the [`mz_ore::task::JoinHandle`] in the result
/// of a client connection.
pub trait IncludeHandle: Send {
    /// What a successful connection resolves to: the client, optionally paired with
    /// the join handle of the task driving the connection.
    type Output;
    /// Packages the client and the connection task's handle into [`Self::Output`].
    fn transform_result(
        client: tokio_postgres::Client,
        handle: mz_ore::task::JoinHandle<()>,
    ) -> Self::Output;
}
1140
/// Type parameter that denotes we __will not__ return the [`mz_ore::task::JoinHandle`] in the
/// result of a [`ConnectBuilder`].
pub struct NoHandle;
impl IncludeHandle for NoHandle {
    type Output = tokio_postgres::Client;
    /// Returns only the client; the join handle is discarded.
    fn transform_result(
        client: tokio_postgres::Client,
        _handle: mz_ore::task::JoinHandle<()>,
    ) -> Self::Output {
        client
    }
}
1153
/// Type parameter that denotes we __will__ return the [`mz_ore::task::JoinHandle`] in the result of
/// a [`ConnectBuilder`].
pub struct WithHandle;
impl IncludeHandle for WithHandle {
    type Output = (tokio_postgres::Client, mz_ore::task::JoinHandle<()>);
    /// Returns the client together with the join handle of the connection task.
    fn transform_result(
        client: tokio_postgres::Client,
        handle: mz_ore::task::JoinHandle<()>,
    ) -> Self::Output {
        (client, handle)
    }
}
1166
1167impl<'s, T, H> IntoFuture for ConnectBuilder<'s, T, H>
1168where
1169    T: MakeTlsConnect<Socket> + Send + 'static,
1170    T::TlsConnect: Send,
1171    T::Stream: Send,
1172    <T::TlsConnect as TlsConnect<Socket>>::Future: Send,
1173    H: IncludeHandle,
1174{
1175    type Output = Result<H::Output, postgres::Error>;
1176    type IntoFuture = BoxFuture<'static, Self::Output>;
1177
1178    fn into_future(mut self) -> Self::IntoFuture {
1179        Box::pin(async move {
1180            assert!(
1181                self.pg_config.get_ports().is_empty(),
1182                "specifying multiple ports is not supported"
1183            );
1184            self.pg_config.port(self.port);
1185
1186            let (client, mut conn) = self.pg_config.connect(self.tls).await?;
1187            let mut notice_callback = self.notice_callback.take();
1188
1189            let handle = task::spawn(|| "connect", async move {
1190                while let Some(msg) = std::future::poll_fn(|cx| conn.poll_message(cx)).await {
1191                    match msg {
1192                        Ok(AsyncMessage::Notice(notice)) => {
1193                            if let Some(callback) = notice_callback.as_mut() {
1194                                callback(notice);
1195                            }
1196                        }
1197                        Ok(msg) => {
1198                            tracing::debug!(?msg, "Dropping message from database");
1199                        }
1200                        Err(e) => {
1201                            // tokio_postgres::Connection docs say:
1202                            // > Return values of None or Some(Err(_)) are “terminal”; callers
1203                            // > should not invoke this method again after receiving one of those
1204                            // > values.
1205                            tracing::info!("connection error: {e}");
1206                            break;
1207                        }
1208                    }
1209                }
1210                tracing::info!("connection closed");
1211            });
1212
1213            let output = H::transform_result(client, handle);
1214            Ok(output)
1215        })
1216    }
1217}
1218
/// A running instance of `environmentd`, that exposes blocking/synchronous test helpers.
///
/// Note: Ideally you should use a [`TestServer`] which relies on an external runtime, e.g. the
/// [`tokio::test`] macro. This struct exists so we can incrementally migrate our existing tests.
pub struct TestServerWithRuntime {
    /// The wrapped asynchronous test server.
    server: TestServer,
    /// Tokio [`Runtime`] owned by this server; exposed via [`Self::runtime`].
    runtime: Arc<Runtime>,
}
1227
impl TestServerWithRuntime {
    /// Returns the [`Runtime`] owned by this [`TestServerWithRuntime`].
    ///
    /// Can be used to spawn async tasks.
    pub fn runtime(&self) -> &Arc<Runtime> {
        &self.runtime
    }

    /// Returns a reference to the inner running `environmentd` [`crate::Server`].
    pub fn inner(&self) -> &crate::Server {
        &self.server.inner
    }

    /// Connect to the __public__ SQL port of the running `environmentd` server.
    ///
    /// Returns a blocking (synchronous) [`postgres::Client`].
    pub fn connect<T>(&self, tls: T) -> Result<postgres::Client, postgres::Error>
    where
        T: MakeTlsConnect<Socket> + Send + 'static,
        T::TlsConnect: Send,
        T::Stream: Send,
        <T::TlsConnect as TlsConnect<Socket>>::Future: Send,
    {
        self.pg_config().connect(tls)
    }

    /// Connect to the __internal__ SQL port of the running `environmentd` server.
    ///
    /// Returns a blocking (synchronous) [`postgres::Client`].
    pub fn connect_internal<T>(&self, tls: T) -> Result<postgres::Client, anyhow::Error>
    where
        T: MakeTlsConnect<Socket> + Send + 'static,
        T::TlsConnect: Send,
        T::Stream: Send,
        <T::TlsConnect as TlsConnect<Socket>>::Future: Send,
    {
        Ok(self.pg_config_internal().connect(tls)?)
    }

    /// Enable LaunchDarkly feature flags.
    ///
    /// Panics if connecting or any `ALTER SYSTEM` statement fails.
    pub fn enable_feature_flags(&self, flags: &[&'static str]) {
        let mut internal_client = self.connect_internal(postgres::NoTls).unwrap();

        for flag in flags {
            internal_client
                .batch_execute(&format!("ALTER SYSTEM SET {} = true;", flag))
                .unwrap();
        }
    }

    /// Disable LaunchDarkly feature flags.
    ///
    /// Panics if connecting or any `ALTER SYSTEM` statement fails.
    pub fn disable_feature_flags(&self, flags: &[&'static str]) {
        let mut internal_client = self.connect_internal(postgres::NoTls).unwrap();

        for flag in flags {
            internal_client
                .batch_execute(&format!("ALTER SYSTEM SET {} = false;", flag))
                .unwrap();
        }
    }

    /// Return a [`postgres::Config`] for connecting to the __public__ SQL port of the running
    /// `environmentd` server.
    pub fn pg_config(&self) -> postgres::Config {
        let local_addr = self.server.sql_local_addr();
        let mut config = postgres::Config::new();
        config
            .host(&Ipv4Addr::LOCALHOST.to_string())
            .port(local_addr.port())
            .user("materialize")
            .options("--welcome_message=off");
        config
    }

    /// Return a [`postgres::Config`] for connecting to the __internal__ SQL port of the running
    /// `environmentd` server.
    pub fn pg_config_internal(&self) -> postgres::Config {
        let local_addr = self.server.internal_sql_local_addr();
        let mut config = postgres::Config::new();
        config
            .host(&Ipv4Addr::LOCALHOST.to_string())
            .port(local_addr.port())
            .user("mz_system")
            .options("--welcome_message=off");
        config
    }

    /// Returns the WebSocket SQL API endpoint of the external HTTP listener.
    pub fn ws_addr(&self) -> Uri {
        self.server.ws_addr()
    }

    /// Returns the WebSocket SQL API endpoint of the internal HTTP listener.
    pub fn internal_ws_addr(&self) -> Uri {
        self.server.internal_ws_addr()
    }

    /// Returns the local address of the external HTTP listener.
    pub fn http_local_addr(&self) -> SocketAddr {
        self.server.http_local_addr()
    }

    /// Returns the local address of the internal HTTP listener.
    pub fn internal_http_local_addr(&self) -> SocketAddr {
        self.server.internal_http_local_addr()
    }

    /// Returns the local address of the external SQL (pgwire) listener.
    pub fn sql_local_addr(&self) -> SocketAddr {
        self.server.sql_local_addr()
    }

    /// Returns the local address of the internal SQL (pgwire) listener.
    pub fn internal_sql_local_addr(&self) -> SocketAddr {
        self.server.internal_sql_local_addr()
    }

    /// Returns the metrics registry for the test server.
    pub fn metrics_registry(&self) -> &MetricsRegistry {
        &self.server.metrics_registry
    }
}
1340
/// A `u64` timestamp wrapper that can be decoded directly from the SQL `numeric`
/// values Materialize uses for timestamps.
#[derive(Debug, Clone, Copy, Eq, PartialEq, Ord, PartialOrd)]
pub struct MzTimestamp(pub u64);

impl<'a> FromSql<'a> for MzTimestamp {
    /// Decodes a `numeric` value and narrows it to `u64`, erroring if it does not fit.
    fn from_sql(ty: &Type, raw: &'a [u8]) -> Result<MzTimestamp, Box<dyn Error + Sync + Send>> {
        let n = mz_pgrepr::Numeric::from_sql(ty, raw)?;
        Ok(MzTimestamp(u64::try_from(n.0.0)?))
    }

    /// Accepts exactly the types [`mz_pgrepr::Numeric`] accepts.
    fn accepts(ty: &Type) -> bool {
        mz_pgrepr::Numeric::accepts(ty)
    }
}
1354
/// Extension trait for extracting the [`DbError`] from a [`postgres::Error`] (or a
/// `Result` containing one) in tests.
pub trait PostgresErrorExt {
    /// Returns the underlying [`DbError`], panicking if there is none (or, for a
    /// `Result` receiver, if the value is `Ok`).
    fn unwrap_db_error(self) -> DbError;
}
1358
1359impl PostgresErrorExt for postgres::Error {
1360    fn unwrap_db_error(self) -> DbError {
1361        match self.source().and_then(|e| e.downcast_ref::<DbError>()) {
1362            Some(e) => e.clone(),
1363            None => panic!("expected DbError, but got: {:?}", self),
1364        }
1365    }
1366}
1367
1368impl<T, E> PostgresErrorExt for Result<T, E>
1369where
1370    E: PostgresErrorExt,
1371{
1372    fn unwrap_db_error(self) -> DbError {
1373        match self {
1374            Ok(_) => panic!("expected Err(DbError), but got Ok(_)"),
1375            Err(e) => e.unwrap_db_error(),
1376        }
1377    }
1378}
1379
/// Group commit will block writes until the current time has advanced. This can make
/// performing inserts while using deterministic time difficult. This is a helper
/// method to perform writes and advance the current time.
pub async fn insert_with_deterministic_timestamps(
    table: &'static str,
    values: &'static str,
    server: &TestServer,
    now: Arc<std::sync::Mutex<EpochMillis>>,
) -> Result<(), Box<dyn Error>> {
    // Two connections: one issues the (blocked) insert, the other is free to read
    // the current timestamp via EXPLAIN TIMESTAMP.
    let client_write = server.connect().await?;
    let client_read = server.connect().await?;

    let mut current_ts = get_explain_timestamp(table, &client_read).await;

    let insert_query = format!("INSERT INTO {table} VALUES {values}");

    let write_future = client_write.execute(&insert_query, &[]);
    let timestamp_interval = tokio::time::interval(Duration::from_millis(1));

    // Pin both futures on the stack so we can poll them repeatedly in select!.
    let mut write_future = std::pin::pin!(write_future);
    let mut timestamp_interval = std::pin::pin!(timestamp_interval);

    // Keep increasing `now` until the write has succeeded. Table advancements may
    // have increased the global timestamp by an unknown amount, so we cannot compute
    // the target time up front.
    loop {
        tokio::select! {
            _ = (&mut write_future) => return Ok(()),
            _ = timestamp_interval.tick() => {
                current_ts += 1;
                *now.lock().expect("lock poisoned") = current_ts;
            }
        };
    }
}
1414
1415pub async fn get_explain_timestamp(from_suffix: &str, client: &Client) -> EpochMillis {
1416    try_get_explain_timestamp(from_suffix, client)
1417        .await
1418        .unwrap()
1419}
1420
1421pub async fn try_get_explain_timestamp(
1422    from_suffix: &str,
1423    client: &Client,
1424) -> Result<EpochMillis, anyhow::Error> {
1425    let det = get_explain_timestamp_determination(from_suffix, client).await?;
1426    let ts = det.determination.timestamp_context.timestamp_or_default();
1427    Ok(ts.into())
1428}
1429
1430pub async fn get_explain_timestamp_determination(
1431    from_suffix: &str,
1432    client: &Client,
1433) -> Result<TimestampExplanation<mz_repr::Timestamp>, anyhow::Error> {
1434    let row = client
1435        .query_one(
1436            &format!("EXPLAIN TIMESTAMP AS JSON FOR SELECT * FROM {from_suffix}"),
1437            &[],
1438        )
1439        .await?;
1440    let explain: String = row.get(0);
1441    Ok(serde_json::from_str(&explain).unwrap())
1442}
1443
/// Helper function to create a Postgres source.
///
/// IMPORTANT: Make sure to call closure that is returned at the end of the test to clean up
/// Postgres state.
///
/// WARNING: If multiple tests use this, and the tests are run in parallel, then make sure the test
/// use different postgres tables.
pub async fn create_postgres_source_with_table<'a>(
    server: &TestServer,
    mz_client: &Client,
    table_name: &str,
    table_schema: &str,
    source_name: &str,
) -> (
    Client,
    impl FnOnce(&'a Client, &'a Client) -> LocalBoxFuture<'a, ()>,
) {
    server
        .enable_feature_flags(&["enable_create_table_from_source"])
        .await;

    // The upstream Postgres instance to replicate from is located via POSTGRES_URL.
    let postgres_url = env::var("POSTGRES_URL")
        .map_err(|_| anyhow!("POSTGRES_URL environment variable is not set"))
        .unwrap();

    let (pg_client, connection) = tokio_postgres::connect(&postgres_url, postgres::NoTls)
        .await
        .unwrap();

    // Pull the individual connection parameters back out of the URL so they can be
    // re-stated in the CREATE CONNECTION statement below, falling back to the usual
    // Postgres defaults where absent.
    let pg_config: tokio_postgres::Config = postgres_url.parse().unwrap();
    let user = pg_config.get_user().unwrap_or("postgres");
    let db_name = pg_config.get_dbname().unwrap_or(user);
    let ports = pg_config.get_ports();
    let port = if ports.is_empty() { 5432 } else { ports[0] };
    let hosts = pg_config.get_hosts();
    let host = if hosts.is_empty() {
        "localhost".to_string()
    } else {
        match &hosts[0] {
            Host::Tcp(host) => host.to_string(),
            Host::Unix(host) => host.to_str().unwrap().to_string(),
        }
    };
    let password = pg_config.get_password();

    // Drive the upstream Postgres connection in the background; panic (failing the
    // test) if the connection errors out.
    mz_ore::task::spawn(|| "postgres-source-connection", async move {
        if let Err(e) = connection.await {
            panic!("connection error: {}", e);
        }
    });

    // Create table in Postgres with publication. Drop any leftovers from a previous
    // run first, so the helper is idempotent across test re-runs.
    let _ = pg_client
        .execute(&format!("DROP TABLE IF EXISTS {table_name};"), &[])
        .await
        .unwrap();
    let _ = pg_client
        .execute(&format!("DROP PUBLICATION IF EXISTS {source_name};"), &[])
        .await
        .unwrap();
    let _ = pg_client
        .execute(&format!("CREATE TABLE {table_name} {table_schema};"), &[])
        .await
        .unwrap();
    // REPLICA IDENTITY FULL makes Postgres include complete rows in the replication
    // stream for updates and deletes.
    let _ = pg_client
        .execute(
            &format!("ALTER TABLE {table_name} REPLICA IDENTITY FULL;"),
            &[],
        )
        .await
        .unwrap();
    let _ = pg_client
        .execute(
            &format!("CREATE PUBLICATION {source_name} FOR TABLE {table_name};"),
            &[],
        )
        .await
        .unwrap();

    // Create postgres source in Materialize.
    let mut connection_str = format!("HOST '{host}', PORT {port}, USER {user}, DATABASE {db_name}");
    if let Some(password) = password {
        // Passwords must be supplied to Materialize through a SECRET.
        let password = std::str::from_utf8(password).unwrap();
        mz_client
            .batch_execute(&format!("CREATE SECRET s AS '{password}'"))
            .await
            .unwrap();
        connection_str = format!("{connection_str}, PASSWORD SECRET s");
    }
    mz_client
        .batch_execute(&format!(
            "CREATE CONNECTION pgconn TO POSTGRES ({connection_str})"
        ))
        .await
        .unwrap();
    mz_client
        .batch_execute(&format!(
            "CREATE SOURCE {source_name}
            FROM POSTGRES
            CONNECTION pgconn
            (PUBLICATION '{source_name}')"
        ))
        .await
        .unwrap();
    mz_client
        .batch_execute(&format!(
            "CREATE TABLE {table_name}
            FROM SOURCE {source_name}
            (REFERENCE {table_name});"
        ))
        .await
        .unwrap();

    // Return the upstream Postgres client plus a cleanup closure that tears down
    // the Materialize objects and then the upstream publication and table.
    let table_name = table_name.to_string();
    let source_name = source_name.to_string();
    (
        pg_client,
        move |mz_client: &'a Client, pg_client: &'a Client| {
            let f: Pin<Box<dyn Future<Output = ()> + 'a>> = Box::pin(async move {
                mz_client
                    .batch_execute(&format!("DROP SOURCE {source_name} CASCADE;"))
                    .await
                    .unwrap();
                mz_client
                    .batch_execute("DROP CONNECTION pgconn;")
                    .await
                    .unwrap();

                let _ = pg_client
                    .execute(&format!("DROP PUBLICATION {source_name};"), &[])
                    .await
                    .unwrap();
                let _ = pg_client
                    .execute(&format!("DROP TABLE {table_name};"), &[])
                    .await
                    .unwrap();
            });
            f
        },
    )
}
1585
1586pub async fn wait_for_pg_table_population(mz_client: &Client, view_name: &str, source_rows: i64) {
1587    let current_isolation = mz_client
1588        .query_one("SHOW transaction_isolation", &[])
1589        .await
1590        .unwrap()
1591        .get::<_, String>(0);
1592    mz_client
1593        .batch_execute("SET transaction_isolation = SERIALIZABLE")
1594        .await
1595        .unwrap();
1596    Retry::default()
1597        .retry_async(|_| async move {
1598            let rows = mz_client
1599                .query_one(&format!("SELECT COUNT(*) FROM {view_name};"), &[])
1600                .await
1601                .unwrap()
1602                .get::<_, i64>(0);
1603            if rows == source_rows {
1604                Ok(())
1605            } else {
1606                Err(format!(
1607                    "Waiting for {source_rows} row to be ingested. Currently at {rows}."
1608                ))
1609            }
1610        })
1611        .await
1612        .unwrap();
1613    mz_client
1614        .batch_execute(&format!(
1615            "SET transaction_isolation = '{current_isolation}'"
1616        ))
1617        .await
1618        .unwrap();
1619}
1620
1621// Initializes a websocket connection. Returns the init messages before the initial ReadyForQuery.
1622pub fn auth_with_ws(
1623    ws: &mut WebSocket<MaybeTlsStream<TcpStream>>,
1624    mut options: BTreeMap<String, String>,
1625) -> Result<Vec<WebSocketResponse>, anyhow::Error> {
1626    if !options.contains_key("welcome_message") {
1627        options.insert("welcome_message".into(), "off".into());
1628    }
1629    auth_with_ws_impl(
1630        ws,
1631        Message::Text(
1632            serde_json::to_string(&WebSocketAuth::Basic {
1633                user: "materialize".into(),
1634                password: "".into(),
1635                options,
1636            })
1637            .unwrap()
1638            .into(),
1639        ),
1640    )
1641}
1642
1643pub fn auth_with_ws_impl(
1644    ws: &mut WebSocket<MaybeTlsStream<TcpStream>>,
1645    auth_message: Message,
1646) -> Result<Vec<WebSocketResponse>, anyhow::Error> {
1647    ws.send(auth_message)?;
1648
1649    // Wait for initial ready response.
1650    let mut msgs = Vec::new();
1651    loop {
1652        let resp = ws.read()?;
1653        match resp {
1654            Message::Text(msg) => {
1655                let msg: WebSocketResponse = serde_json::from_str(&msg).unwrap();
1656                match msg {
1657                    WebSocketResponse::ReadyForQuery(_) => break,
1658                    msg => {
1659                        msgs.push(msg);
1660                    }
1661                }
1662            }
1663            Message::Ping(_) => continue,
1664            Message::Close(None) => return Err(anyhow!("ws closed after auth")),
1665            Message::Close(Some(close_frame)) => {
1666                return Err(anyhow!("ws closed after auth").context(close_frame));
1667            }
1668            _ => panic!("unexpected response: {:?}", resp),
1669        }
1670    }
1671    Ok(msgs)
1672}
1673
1674pub fn make_header<H: Header>(h: H) -> HeaderMap {
1675    let mut map = HeaderMap::new();
1676    map.typed_insert(h);
1677    map
1678}
1679
1680pub fn make_pg_tls<F>(configure: F) -> MakeTlsConnector
1681where
1682    F: FnOnce(&mut SslConnectorBuilder) -> Result<(), ErrorStack>,
1683{
1684    let mut connector_builder = SslConnector::builder(SslMethod::tls()).unwrap();
1685    // Disable TLS v1.3 because `postgres` and `hyper` produce stabler error
1686    // messages with TLS v1.2.
1687    //
1688    // Briefly, in TLS v1.3, failing to present a client certificate does not
1689    // error during the TLS handshake, as it does in TLS v1.2, but on the first
1690    // attempt to read from the stream. But both `postgres` and `hyper` write a
1691    // bunch of data before attempting to read from the stream. With a failed
1692    // TLS v1.3 connection, sometimes `postgres` and `hyper` succeed in writing
1693    // out this data, and then return a nice error message on the call to read.
1694    // But sometimes the connection is closed before they write out the data,
1695    // and so they report "connection closed" before they ever call read, never
1696    // noticing the underlying SSL error.
1697    //
1698    // It's unclear who's bug this is. Is it on `hyper`/`postgres` to call read
1699    // if writing to the stream fails to see if a TLS error occured? Is it on
1700    // OpenSSL to provide a better API [1]? Is it a protocol issue that ought to
1701    // be corrected in TLS v1.4? We don't want to answer these questions, so we
1702    // just avoid TLS v1.3 for now.
1703    //
1704    // [1]: https://github.com/openssl/openssl/issues/11118
1705    let options = connector_builder.options() | SslOptions::NO_TLSV1_3;
1706    connector_builder.set_options(options);
1707    configure(&mut connector_builder).unwrap();
1708    MakeTlsConnector::new(connector_builder.build())
1709}
1710
/// A certificate authority for use in tests.
pub struct Ca {
    /// Temporary directory holding the CA's on-disk artifacts: `ca.crt` plus
    /// any certificate/key pairs issued via `request_cert`.
    pub dir: TempDir,
    /// The CA's X.509 subject name (a single common-name entry).
    pub name: X509Name,
    /// The CA certificate (parent-signed for intermediates, otherwise
    /// self-signed).
    pub cert: X509,
    /// The CA's RSA private key, used to sign issued certificates.
    pub pkey: PKey<Private>,
}
1718
impl Ca {
    /// Creates a CA named `name`, signed by `parent` when given and
    /// self-signed otherwise. The certificate is written to `ca.crt` in a
    /// fresh temporary directory.
    fn make_ca(name: &str, parent: Option<&Ca>) -> Result<Ca, Box<dyn Error>> {
        let dir = tempfile::tempdir()?;
        // 2048-bit RSA key pair for the CA itself.
        let rsa = Rsa::generate(2048)?;
        let pkey = PKey::from_rsa(rsa)?;
        // Subject name contains just the common name.
        let name = {
            let mut builder = X509NameBuilder::new()?;
            builder.append_entry_by_nid(Nid::COMMONNAME, name)?;
            builder.build()
        };
        let cert = {
            let mut builder = X509::builder()?;
            // Version is zero-indexed, so 2 means an X.509 v3 certificate.
            builder.set_version(2)?;
            builder.set_pubkey(&pkey)?;
            // Issuer is the parent for an intermediate CA; a root issues to
            // itself.
            builder.set_issuer_name(parent.map(|ca| &ca.name).unwrap_or(&name))?;
            builder.set_subject_name(&name)?;
            // Valid from now for one year — plenty for a test run.
            builder.set_not_before(&*Asn1Time::days_from_now(0)?)?;
            builder.set_not_after(&*Asn1Time::days_from_now(365)?)?;
            // Mark this certificate as a CA so it may sign further certs.
            builder.append_extension(BasicConstraints::new().critical().ca().build()?)?;
            // Sign with the parent's key when present, otherwise self-sign.
            builder.sign(
                parent.map(|ca| &ca.pkey).unwrap_or(&pkey),
                MessageDigest::sha256(),
            )?;
            builder.build()
        };
        fs::write(dir.path().join("ca.crt"), cert.to_pem()?)?;
        Ok(Ca {
            dir,
            name,
            cert,
            pkey,
        })
    }

    /// Creates a new root certificate authority.
    pub fn new_root(name: &str) -> Result<Ca, Box<dyn Error>> {
        Ca::make_ca(name, None)
    }

    /// Returns the path to the CA's certificate.
    pub fn ca_cert_path(&self) -> PathBuf {
        self.dir.path().join("ca.crt")
    }

    /// Requests a new intermediate certificate authority.
    pub fn request_ca(&self, name: &str) -> Result<Ca, Box<dyn Error>> {
        Ca::make_ca(name, Some(self))
    }

    /// Generates a certificate with the specified Common Name (CN) that is
    /// signed by the CA.
    ///
    /// Returns the paths to the certificate and key.
    pub fn request_client_cert(&self, name: &str) -> Result<(PathBuf, PathBuf), Box<dyn Error>> {
        self.request_cert(name, iter::empty())
    }

    /// Like `request_client_cert`, but permits specifying additional IP
    /// addresses to attach as Subject Alternate Names.
    pub fn request_cert<I>(&self, name: &str, ips: I) -> Result<(PathBuf, PathBuf), Box<dyn Error>>
    where
        I: IntoIterator<Item = IpAddr>,
    {
        // Fresh 2048-bit RSA key pair for the requested certificate.
        let rsa = Rsa::generate(2048)?;
        let pkey = PKey::from_rsa(rsa)?;
        let subject_name = {
            let mut builder = X509NameBuilder::new()?;
            builder.append_entry_by_nid(Nid::COMMONNAME, name)?;
            builder.build()
        };
        let cert = {
            let mut builder = X509::builder()?;
            // Zero-indexed: 2 means an X.509 v3 certificate.
            builder.set_version(2)?;
            builder.set_pubkey(&pkey)?;
            // Issued by this CA (issuer = the CA certificate's subject).
            builder.set_issuer_name(self.cert.subject_name())?;
            builder.set_subject_name(&subject_name)?;
            // Valid from now for one year.
            builder.set_not_before(&*Asn1Time::days_from_now(0)?)?;
            builder.set_not_after(&*Asn1Time::days_from_now(365)?)?;
            // Attach each IP as a Subject Alternative Name so TLS peers
            // connecting by address can verify the certificate.
            for ip in ips {
                builder.append_extension(
                    SubjectAlternativeName::new()
                        .ip(&ip.to_string())
                        .build(&builder.x509v3_context(None, None))?,
                )?;
            }
            // Signed by the CA's key — not self-signed.
            builder.sign(&self.pkey, MessageDigest::sha256())?;
            builder.build()
        };
        // Write PEM-encoded cert and PKCS#8 key as `<name>.crt` / `<name>.key`
        // alongside the CA's files.
        let cert_path = self.dir.path().join(Path::new(name).with_extension("crt"));
        let key_path = self.dir.path().join(Path::new(name).with_extension("key"));
        fs::write(&cert_path, cert.to_pem()?)?;
        fs::write(&key_path, pkey.private_key_to_pem_pkcs8()?)?;
        Ok((cert_path, key_path))
    }
}
1813}