Skip to main content

mz_environmentd/
test_util.rs

1// Copyright Materialize, Inc. and contributors. All rights reserved.
2//
3// Use of this software is governed by the Business Source License
4// included in the LICENSE file.
5//
6// As of the Change Date specified in that file, in accordance with
7// the Business Source License, use of this software will be governed
8// by the Apache License, Version 2.0.
9
10use std::collections::BTreeMap;
11use std::error::Error;
12use std::future::IntoFuture;
13use std::net::{IpAddr, Ipv4Addr, SocketAddr, TcpStream};
14use std::path::{Path, PathBuf};
15use std::pin::Pin;
16use std::str::FromStr;
17use std::sync::Arc;
18use std::sync::LazyLock;
19use std::time::Duration;
20use std::{env, fs, iter};
21
22use anyhow::anyhow;
23use futures::Future;
24use futures::future::{BoxFuture, LocalBoxFuture};
25use headers::{Header, HeaderMapExt};
26use http::Uri;
27use hyper::http::header::HeaderMap;
28use maplit::btreemap;
29use mz_adapter::TimestampExplanation;
30use mz_adapter_types::bootstrap_builtin_cluster_config::{
31    ANALYTICS_CLUSTER_DEFAULT_REPLICATION_FACTOR, BootstrapBuiltinClusterConfig,
32    CATALOG_SERVER_CLUSTER_DEFAULT_REPLICATION_FACTOR, PROBE_CLUSTER_DEFAULT_REPLICATION_FACTOR,
33    SUPPORT_CLUSTER_DEFAULT_REPLICATION_FACTOR, SYSTEM_CLUSTER_DEFAULT_REPLICATION_FACTOR,
34};
35
36use mz_auth::password::Password;
37use mz_catalog::config::ClusterReplicaSizeMap;
38use mz_controller::ControllerConfig;
39use mz_dyncfg::ConfigUpdates;
40use mz_license_keys::ValidatedLicenseKey;
41use mz_orchestrator_process::{ProcessOrchestrator, ProcessOrchestratorConfig};
42use mz_orchestrator_tracing::{TracingCliArgs, TracingOrchestrator};
43use mz_ore::metrics::MetricsRegistry;
44use mz_ore::now::{EpochMillis, NowFn, SYSTEM_TIME};
45use mz_ore::retry::Retry;
46use mz_ore::task;
47use mz_ore::tracing::{
48    OpenTelemetryConfig, StderrLogConfig, StderrLogFormat, TracingConfig, TracingHandle,
49};
50use mz_persist_client::PersistLocation;
51use mz_persist_client::cache::PersistClientCache;
52use mz_persist_client::cfg::{CONSENSUS_CONNECTION_POOL_MAX_SIZE, PersistConfig};
53use mz_persist_client::rpc::PersistGrpcPubSubServer;
54use mz_postgres_util::{
55    Sql, batch_execute as pg_batch_execute, execute as pg_execute, query_one as pg_query_one, sql,
56};
57use mz_secrets::SecretsController;
58use mz_server_core::listeners::{
59    AllowedRoles, AuthenticatorKind, BaseListenerConfig, HttpRoutesEnabled,
60};
61use mz_server_core::{ReloadTrigger, TlsCertConfig};
62use mz_sql::catalog::EnvironmentId;
63use mz_storage_types::connections::ConnectionContext;
64use mz_tracing::CloneableEnvFilter;
65use openssl::asn1::Asn1Time;
66use openssl::error::ErrorStack;
67use openssl::hash::MessageDigest;
68use openssl::nid::Nid;
69use openssl::pkey::{PKey, Private};
70use openssl::rsa::Rsa;
71use openssl::ssl::{SslConnector, SslConnectorBuilder, SslMethod, SslOptions};
72use openssl::x509::extension::{BasicConstraints, SubjectAlternativeName};
73use openssl::x509::{X509, X509Name, X509NameBuilder};
74use postgres::error::DbError;
75use postgres::tls::{MakeTlsConnect, TlsConnect};
76use postgres::types::{FromSql, Type};
77use postgres::{NoTls, Socket};
78use postgres_openssl::MakeTlsConnector;
79use tempfile::TempDir;
80use tokio::net::TcpListener;
81use tokio::runtime::Runtime;
82use tokio_postgres::config::{Host, SslMode};
83use tokio_postgres::{AsyncMessage, Client};
84use tokio_stream::wrappers::TcpListenerStream;
85use tower_http::cors::AllowOrigin;
86use tracing::Level;
87use tracing_capture::SharedStorage;
88use tracing_subscriber::EnvFilter;
89use tungstenite::stream::MaybeTlsStream;
90use tungstenite::{Message, WebSocket};
91
92use crate::{
93    CatalogConfig, FronteggAuthenticator, HttpListenerConfig, ListenersConfig, SqlListenerConfig,
94    WebSocketAuth, WebSocketResponse,
95};
96
/// Kafka address string, resolved lazily on first access.
///
/// The value is taken from the `KAFKA_ADDRS` environment variable; when that
/// variable cannot be read, `localhost:9092` is used instead.
pub static KAFKA_ADDRS: LazyLock<String> = LazyLock::new(|| match env::var("KAFKA_ADDRS") {
    Ok(addrs) => addrs,
    Err(_) => "localhost:9092".into(),
});
99
100/// Entry point for creating and configuring an `environmentd` test harness.
101#[derive(Clone)]
102pub struct TestHarness {
103    data_directory: Option<PathBuf>,
104    tls: Option<TlsCertConfig>,
105    frontegg: Option<FronteggAuthenticator>,
106    external_login_password_mz_system: Option<Password>,
107    listeners_config: ListenersConfig,
108    unsafe_mode: bool,
109    workers: usize,
110    now: NowFn,
111    seed: u32,
112    storage_usage_collection_interval: Duration,
113    storage_usage_retention_period: Option<Duration>,
114    default_cluster_replica_size: String,
115    default_cluster_replication_factor: u32,
116    builtin_system_cluster_config: BootstrapBuiltinClusterConfig,
117    builtin_catalog_server_cluster_config: BootstrapBuiltinClusterConfig,
118    builtin_probe_cluster_config: BootstrapBuiltinClusterConfig,
119    builtin_support_cluster_config: BootstrapBuiltinClusterConfig,
120    builtin_analytics_cluster_config: BootstrapBuiltinClusterConfig,
121
122    propagate_crashes: bool,
123    enable_tracing: bool,
124    // This is currently unrelated to enable_tracing, and is used only to disable orchestrator
125    // tracing.
126    orchestrator_tracing_cli_args: TracingCliArgs,
127    bootstrap_role: Option<String>,
128    deploy_generation: u64,
129    system_parameter_defaults: BTreeMap<String, String>,
130    internal_console_redirect_url: Option<String>,
131    metrics_registry: Option<MetricsRegistry>,
132    code_version: semver::Version,
133    capture: Option<SharedStorage>,
134    pub environment_id: EnvironmentId,
135}
136
137impl Default for TestHarness {
138    fn default() -> TestHarness {
139        TestHarness {
140            data_directory: None,
141            tls: None,
142            frontegg: None,
143            external_login_password_mz_system: None,
144            listeners_config: ListenersConfig {
145                sql: btreemap![
146                    "external".to_owned() => SqlListenerConfig {
147                        addr: SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), 0),
148                        authenticator_kind: AuthenticatorKind::None,
149                        allowed_roles: AllowedRoles::Normal,
150                        enable_tls: false,
151                    },
152                    "internal".to_owned() => SqlListenerConfig {
153                        addr: SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), 0),
154                        authenticator_kind: AuthenticatorKind::None,
155                        allowed_roles: AllowedRoles::NormalAndInternal,
156                        enable_tls: false,
157                    },
158                ],
159                http: btreemap![
160                    "external".to_owned() => HttpListenerConfig {
161                        base: BaseListenerConfig {
162                            addr: SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), 0),
163                            authenticator_kind: AuthenticatorKind::None,
164                            allowed_roles: AllowedRoles::Normal,
165                            enable_tls: false,
166                        },
167                        routes: HttpRoutesEnabled{
168                            base: true,
169                            webhook: true,
170                            internal: false,
171                            metrics: false,
172                            profiling: false,
173                            mcp_agent: false,
174                            mcp_developer: false,
175                            console_config: true,
176                        },
177                    },
178                    "internal".to_owned() => HttpListenerConfig {
179                        base: BaseListenerConfig {
180                            addr: SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), 0),
181                            authenticator_kind: AuthenticatorKind::None,
182                            allowed_roles: AllowedRoles::NormalAndInternal,
183                            enable_tls: false,
184                        },
185                        routes: HttpRoutesEnabled{
186                            base: true,
187                            webhook: true,
188                            internal: true,
189                            metrics: true,
190                            profiling: true,
191                            mcp_agent: false,
192                            mcp_developer: false,
193                            console_config: true,
194                        },
195                    },
196                ],
197            },
198            unsafe_mode: false,
199            workers: 1,
200            now: SYSTEM_TIME.clone(),
201            seed: rand::random(),
202            storage_usage_collection_interval: Duration::from_secs(3600),
203            storage_usage_retention_period: None,
204            default_cluster_replica_size: "scale=1,workers=1".to_string(),
205            default_cluster_replication_factor: 1,
206            builtin_system_cluster_config: BootstrapBuiltinClusterConfig {
207                size: "scale=1,workers=1".to_string(),
208                replication_factor: SYSTEM_CLUSTER_DEFAULT_REPLICATION_FACTOR,
209            },
210            builtin_catalog_server_cluster_config: BootstrapBuiltinClusterConfig {
211                size: "scale=1,workers=1".to_string(),
212                replication_factor: CATALOG_SERVER_CLUSTER_DEFAULT_REPLICATION_FACTOR,
213            },
214            builtin_probe_cluster_config: BootstrapBuiltinClusterConfig {
215                size: "scale=1,workers=1".to_string(),
216                replication_factor: PROBE_CLUSTER_DEFAULT_REPLICATION_FACTOR,
217            },
218            builtin_support_cluster_config: BootstrapBuiltinClusterConfig {
219                size: "scale=1,workers=1".to_string(),
220                replication_factor: SUPPORT_CLUSTER_DEFAULT_REPLICATION_FACTOR,
221            },
222            builtin_analytics_cluster_config: BootstrapBuiltinClusterConfig {
223                size: "scale=1,workers=1".to_string(),
224                replication_factor: ANALYTICS_CLUSTER_DEFAULT_REPLICATION_FACTOR,
225            },
226            propagate_crashes: false,
227            enable_tracing: false,
228            bootstrap_role: Some("materialize".into()),
229            deploy_generation: 0,
230            // This and startup_log_filter below are both (?) needed to suppress clusterd messages.
231            // If we need those in the future, we might need to change both.
232            system_parameter_defaults: BTreeMap::from([(
233                "log_filter".to_string(),
234                "error".to_string(),
235            )]),
236            internal_console_redirect_url: None,
237            metrics_registry: None,
238            orchestrator_tracing_cli_args: TracingCliArgs {
239                startup_log_filter: CloneableEnvFilter::from_str("error").expect("must parse"),
240                ..Default::default()
241            },
242            code_version: crate::BUILD_INFO.semver_version(),
243            environment_id: EnvironmentId::for_tests(),
244            capture: None,
245        }
246    }
247}
248
249impl TestHarness {
250    /// Starts a test [`TestServer`], panicking if the server could not be started.
251    ///
252    /// For cases when startup might fail, see [`TestHarness::try_start`].
253    pub async fn start(self) -> TestServer {
254        self.try_start().await.expect("Failed to start test Server")
255    }
256
257    /// Like [`TestHarness::start`] but can specify a cert reload trigger.
258    pub async fn start_with_trigger(self, tls_reload_certs: ReloadTrigger) -> TestServer {
259        self.try_start_with_trigger(tls_reload_certs)
260            .await
261            .expect("Failed to start test Server")
262    }
263
264    /// Starts a test [`TestServer`], returning an error if the server could not be started.
265    pub async fn try_start(self) -> Result<TestServer, anyhow::Error> {
266        self.try_start_with_trigger(mz_server_core::cert_reload_never_reload())
267            .await
268    }
269
270    /// Like [`TestHarness::try_start`] but can specify a cert reload trigger.
271    pub async fn try_start_with_trigger(
272        self,
273        tls_reload_certs: ReloadTrigger,
274    ) -> Result<TestServer, anyhow::Error> {
275        let listeners = Listeners::new(&self).await?;
276        listeners.serve_with_trigger(self, tls_reload_certs).await
277    }
278
279    /// Starts a runtime and returns a [`TestServerWithRuntime`].
280    pub fn start_blocking(self) -> TestServerWithRuntime {
281        let runtime = tokio::runtime::Builder::new_multi_thread()
282            .enable_all()
283            .thread_stack_size(mz_ore::stack::STACK_SIZE)
284            .build()
285            .expect("failed to spawn runtime for test");
286        let runtime = Arc::new(runtime);
287        let server = runtime.block_on(self.start());
288        TestServerWithRuntime { runtime, server }
289    }
290
291    pub fn data_directory(mut self, data_directory: impl Into<PathBuf>) -> Self {
292        self.data_directory = Some(data_directory.into());
293        self
294    }
295
296    pub fn with_tls(mut self, cert_path: impl Into<PathBuf>, key_path: impl Into<PathBuf>) -> Self {
297        self.tls = Some(TlsCertConfig {
298            cert: cert_path.into(),
299            key: key_path.into(),
300        });
301        for (_, listener) in &mut self.listeners_config.sql {
302            listener.enable_tls = true;
303        }
304        for (_, listener) in &mut self.listeners_config.http {
305            listener.base.enable_tls = true;
306        }
307        self
308    }
309
310    pub fn unsafe_mode(mut self) -> Self {
311        self.unsafe_mode = true;
312        self
313    }
314
315    pub fn workers(mut self, workers: usize) -> Self {
316        self.workers = workers;
317        self
318    }
319
320    pub fn with_frontegg_auth(mut self, frontegg: &FronteggAuthenticator) -> Self {
321        self.frontegg = Some(frontegg.clone());
322        let enable_tls = self.tls.is_some();
323        self.listeners_config = ListenersConfig {
324            sql: btreemap! {
325                "external".to_owned() => SqlListenerConfig {
326                    addr: SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), 0),
327                    authenticator_kind: AuthenticatorKind::Frontegg,
328                    allowed_roles: AllowedRoles::Normal,
329                    enable_tls,
330                },
331                "internal".to_owned() => SqlListenerConfig {
332                    addr: SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), 0),
333                    authenticator_kind: AuthenticatorKind::None,
334                    allowed_roles: AllowedRoles::NormalAndInternal,
335                    enable_tls: false,
336                },
337            },
338            http: btreemap! {
339                "external".to_owned() => HttpListenerConfig {
340                    base: BaseListenerConfig {
341                        addr: SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), 0),
342                        authenticator_kind: AuthenticatorKind::Frontegg,
343                        allowed_roles: AllowedRoles::Normal,
344                        enable_tls,
345                    },
346                    routes: HttpRoutesEnabled{
347                        base: true,
348                        webhook: true,
349                        internal: false,
350                        metrics: false,
351                        profiling: false,
352                        mcp_agent: false,
353                        mcp_developer: false,
354                        console_config: true,
355                    },
356                },
357                "internal".to_owned() => HttpListenerConfig {
358                    base: BaseListenerConfig {
359                        addr: SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), 0),
360                        authenticator_kind: AuthenticatorKind::None,
361                        allowed_roles: AllowedRoles::NormalAndInternal,
362                        enable_tls: false,
363                    },
364                    routes: HttpRoutesEnabled{
365                        base: true,
366                        webhook: true,
367                        internal: true,
368                        metrics: true,
369                        profiling: true,
370                        mcp_agent: false,
371                        mcp_developer: false,
372                        console_config: true,
373                    },
374                },
375            },
376        };
377        self
378    }
379
380    pub fn with_oidc_auth(
381        mut self,
382        issuer: Option<String>,
383        authentication_claim: Option<String>,
384        expected_audiences: Option<Vec<String>>,
385        external_login_password_mz_system: Option<Password>,
386    ) -> Self {
387        let enable_tls = self.tls.is_some();
388        self.listeners_config = ListenersConfig {
389            sql: btreemap! {
390                "external".to_owned() => SqlListenerConfig {
391                    addr: SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), 0),
392                    authenticator_kind: AuthenticatorKind::Oidc,
393                    allowed_roles: AllowedRoles::NormalAndInternal,
394                    enable_tls,
395                },
396                "internal".to_owned() => SqlListenerConfig {
397                    addr: SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), 0),
398                    authenticator_kind: AuthenticatorKind::None,
399                    allowed_roles: AllowedRoles::NormalAndInternal,
400                    enable_tls: false,
401                },
402            },
403            http: btreemap! {
404                "external".to_owned() => HttpListenerConfig {
405                    base: BaseListenerConfig {
406                        addr: SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), 0),
407                        authenticator_kind: AuthenticatorKind::Oidc,
408                        allowed_roles: AllowedRoles::NormalAndInternal,
409                        enable_tls,
410                    },
411                    routes: HttpRoutesEnabled{
412                        base: true,
413                        webhook: true,
414                        internal: false,
415                        metrics: false,
416                        profiling: false,
417                        mcp_agent: false,
418                        mcp_developer: false,
419                        console_config: true,
420                    },
421                },
422                "internal".to_owned() => HttpListenerConfig {
423                    base: BaseListenerConfig {
424                        addr: SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), 0),
425                        authenticator_kind: AuthenticatorKind::None,
426                        allowed_roles: AllowedRoles::NormalAndInternal,
427                        enable_tls: false,
428                    },
429                    routes: HttpRoutesEnabled{
430                        base: true,
431                        webhook: true,
432                        internal: true,
433                        metrics: true,
434                        profiling: true,
435                        mcp_agent: false,
436                        mcp_developer: false,
437                        console_config: true,
438                    },
439                },
440            },
441        };
442
443        if let Some(issuer) = issuer {
444            self.system_parameter_defaults
445                .insert("oidc_issuer".to_string(), issuer);
446        }
447
448        if let Some(authentication_claim) = authentication_claim {
449            self.system_parameter_defaults.insert(
450                "oidc_authentication_claim".to_string(),
451                authentication_claim,
452            );
453        }
454
455        if let Some(expected_audiences) = expected_audiences {
456            self.system_parameter_defaults.insert(
457                "oidc_audience".to_string(),
458                serde_json::to_string(&expected_audiences).unwrap(),
459            );
460        }
461
462        if let Some(external_login_password_mz_system) = external_login_password_mz_system {
463            self.external_login_password_mz_system = Some(external_login_password_mz_system);
464            self.system_parameter_defaults
465                .insert("enable_password_auth".to_string(), "true".to_string());
466        }
467
468        self
469    }
470
471    pub fn with_password_auth(mut self, mz_system_password: Password) -> Self {
472        self.external_login_password_mz_system = Some(mz_system_password);
473        let enable_tls = self.tls.is_some();
474        self.listeners_config = ListenersConfig {
475            sql: btreemap! {
476                "external".to_owned() => SqlListenerConfig {
477                    addr: SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), 0),
478                    authenticator_kind: AuthenticatorKind::Password,
479                    allowed_roles: AllowedRoles::NormalAndInternal,
480                    enable_tls,
481                },
482            },
483            http: btreemap! {
484                "external".to_owned() => HttpListenerConfig {
485                    base: BaseListenerConfig {
486                        addr: SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), 0),
487                        authenticator_kind: AuthenticatorKind::Password,
488                        allowed_roles: AllowedRoles::NormalAndInternal,
489                        enable_tls,
490                    },
491                    routes: HttpRoutesEnabled{
492                        base: true,
493                        webhook: true,
494                        internal: true,
495                        metrics: false,
496                        profiling: true,
497                        mcp_agent: false,
498                        mcp_developer: false,
499                        console_config: true,
500                    },
501                },
502                "metrics".to_owned() => HttpListenerConfig {
503                    base: BaseListenerConfig {
504                        addr: SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), 0),
505                        authenticator_kind: AuthenticatorKind::None,
506                        allowed_roles: AllowedRoles::NormalAndInternal,
507                        enable_tls: false,
508                    },
509                    routes: HttpRoutesEnabled{
510                        base: false,
511                        webhook: false,
512                        internal: false,
513                        metrics: true,
514                        profiling: false,
515                        mcp_agent: false,
516                        mcp_developer: false,
517                        console_config: true,
518                    },
519                },
520            },
521        };
522        self
523    }
524
525    pub fn with_sasl_scram_auth(mut self, mz_system_password: Password) -> Self {
526        self.external_login_password_mz_system = Some(mz_system_password);
527        let enable_tls = self.tls.is_some();
528        self.listeners_config = ListenersConfig {
529            sql: btreemap! {
530                "external".to_owned() => SqlListenerConfig {
531                    addr: SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), 0),
532                    authenticator_kind: AuthenticatorKind::Sasl,
533                    allowed_roles: AllowedRoles::NormalAndInternal,
534                    enable_tls,
535                },
536            },
537            http: btreemap! {
538                "external".to_owned() => HttpListenerConfig {
539                    base: BaseListenerConfig {
540                        addr: SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), 0),
541                        authenticator_kind: AuthenticatorKind::Password,
542                        allowed_roles: AllowedRoles::NormalAndInternal,
543                        enable_tls,
544                    },
545                    routes: HttpRoutesEnabled{
546                        base: true,
547                        webhook: true,
548                        internal: true,
549                        metrics: false,
550                        profiling: true,
551                        mcp_agent: false,
552                        mcp_developer: false,
553                        console_config: true,
554                    },
555                },
556                "metrics".to_owned() => HttpListenerConfig {
557                    base: BaseListenerConfig {
558                        addr: SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), 0),
559                        authenticator_kind: AuthenticatorKind::None,
560                        allowed_roles: AllowedRoles::NormalAndInternal,
561                        enable_tls: false,
562                    },
563                    routes: HttpRoutesEnabled{
564                        base: false,
565                        webhook: false,
566                        internal: false,
567                        metrics: true,
568                        profiling: false,
569                        mcp_agent: false,
570                        mcp_developer: false,
571                        console_config: true,
572                    },
573                },
574            },
575        };
576        self
577    }
578
579    pub fn with_now(mut self, now: NowFn) -> Self {
580        self.now = now;
581        self
582    }
583
584    pub fn with_storage_usage_collection_interval(
585        mut self,
586        storage_usage_collection_interval: Duration,
587    ) -> Self {
588        self.storage_usage_collection_interval = storage_usage_collection_interval;
589        self
590    }
591
592    pub fn with_storage_usage_retention_period(
593        mut self,
594        storage_usage_retention_period: Duration,
595    ) -> Self {
596        self.storage_usage_retention_period = Some(storage_usage_retention_period);
597        self
598    }
599
600    pub fn with_default_cluster_replica_size(
601        mut self,
602        default_cluster_replica_size: String,
603    ) -> Self {
604        self.default_cluster_replica_size = default_cluster_replica_size;
605        self
606    }
607
608    pub fn with_builtin_system_cluster_replica_size(
609        mut self,
610        builtin_system_cluster_replica_size: String,
611    ) -> Self {
612        self.builtin_system_cluster_config.size = builtin_system_cluster_replica_size;
613        self
614    }
615
616    pub fn with_builtin_system_cluster_replication_factor(
617        mut self,
618        builtin_system_cluster_replication_factor: u32,
619    ) -> Self {
620        self.builtin_system_cluster_config.replication_factor =
621            builtin_system_cluster_replication_factor;
622        self
623    }
624
625    pub fn with_builtin_catalog_server_cluster_replica_size(
626        mut self,
627        builtin_catalog_server_cluster_replica_size: String,
628    ) -> Self {
629        self.builtin_catalog_server_cluster_config.size =
630            builtin_catalog_server_cluster_replica_size;
631        self
632    }
633
634    pub fn with_propagate_crashes(mut self, propagate_crashes: bool) -> Self {
635        self.propagate_crashes = propagate_crashes;
636        self
637    }
638
639    pub fn with_enable_tracing(mut self, enable_tracing: bool) -> Self {
640        self.enable_tracing = enable_tracing;
641        self
642    }
643
644    pub fn with_bootstrap_role(mut self, bootstrap_role: Option<String>) -> Self {
645        self.bootstrap_role = bootstrap_role;
646        self
647    }
648
649    pub fn with_deploy_generation(mut self, deploy_generation: u64) -> Self {
650        self.deploy_generation = deploy_generation;
651        self
652    }
653
654    pub fn with_system_parameter_default(mut self, param: String, value: String) -> Self {
655        self.system_parameter_defaults.insert(param, value);
656        self
657    }
658
659    pub fn with_mcp_routes(mut self, agent: bool, developer: bool) -> Self {
660        for config in self.listeners_config.http.values_mut() {
661            config.routes.mcp_agent = agent;
662            config.routes.mcp_developer = developer;
663        }
664        self
665    }
666
667    pub fn with_internal_console_redirect_url(
668        mut self,
669        internal_console_redirect_url: Option<String>,
670    ) -> Self {
671        self.internal_console_redirect_url = internal_console_redirect_url;
672        self
673    }
674
675    pub fn with_metrics_registry(mut self, registry: MetricsRegistry) -> Self {
676        self.metrics_registry = Some(registry);
677        self
678    }
679
680    pub fn with_code_version(mut self, version: semver::Version) -> Self {
681        self.code_version = version;
682        self
683    }
684
685    pub fn with_capture(mut self, storage: SharedStorage) -> Self {
686        self.capture = Some(storage);
687        self
688    }
689}
690
691pub struct Listeners {
692    pub inner: crate::Listeners,
693}
694
695impl Listeners {
696    pub async fn new(config: &TestHarness) -> Result<Listeners, anyhow::Error> {
697        let inner = crate::Listeners::bind(config.listeners_config.clone()).await?;
698        Ok(Listeners { inner })
699    }
700
701    pub async fn serve(self, config: TestHarness) -> Result<TestServer, anyhow::Error> {
702        self.serve_with_trigger(config, mz_server_core::cert_reload_never_reload())
703            .await
704    }
705
706    pub async fn serve_with_trigger(
707        self,
708        config: TestHarness,
709        tls_reload_certs: ReloadTrigger,
710    ) -> Result<TestServer, anyhow::Error> {
711        let (data_directory, temp_dir) = match config.data_directory {
712            None => {
713                // If no data directory is provided, we create a temporary
714                // directory. The temporary directory is cleaned up when the
715                // `TempDir` is dropped, so we keep it alive until the `Server` is
716                // dropped.
717                let temp_dir = tempfile::tempdir()?;
718                (temp_dir.path().to_path_buf(), Some(temp_dir))
719            }
720            Some(data_directory) => (data_directory, None),
721        };
722        let scratch_dir = tempfile::tempdir()?;
723        let (consensus_uri, timestamp_oracle_url) = {
724            let seed = config.seed;
725            let cockroach_url = env::var("METADATA_BACKEND_URL")
726                .map_err(|_| anyhow!("METADATA_BACKEND_URL environment variable is not set"))?;
727            let (client, conn) = tokio_postgres::connect(&cockroach_url, NoTls).await?;
728            mz_ore::task::spawn(|| "startup-postgres-conn", async move {
729                if let Err(err) = conn.await {
730                    panic!("connection error: {}", err);
731                };
732            });
733            let consensus_schema = sql!("consensus_{}", seed);
734            let tsoracle_schema = sql!("tsoracle_{}", seed);
735            pg_batch_execute(
736                &client,
737                sql!(
738                    "CREATE SCHEMA IF NOT EXISTS {};
739                     CREATE SCHEMA IF NOT EXISTS {};",
740                    consensus_schema,
741                    tsoracle_schema,
742                ),
743            )
744            .await?;
745            (
746                format!("{cockroach_url}?options=--search_path=consensus_{seed}")
747                    .parse()
748                    .expect("invalid consensus URI"),
749                format!("{cockroach_url}?options=--search_path=tsoracle_{seed}")
750                    .parse()
751                    .expect("invalid timestamp oracle URI"),
752            )
753        };
754        let metrics_registry = config.metrics_registry.unwrap_or_else(MetricsRegistry::new);
755        let orchestrator = ProcessOrchestrator::new(ProcessOrchestratorConfig {
756            image_dir: env::current_exe()?
757                .parent()
758                .unwrap()
759                .parent()
760                .unwrap()
761                .to_path_buf(),
762            suppress_output: false,
763            environment_id: config.environment_id.to_string(),
764            secrets_dir: data_directory.join("secrets"),
765            command_wrapper: vec![],
766            propagate_crashes: config.propagate_crashes,
767            tcp_proxy: None,
768            scratch_directory: scratch_dir.path().to_path_buf(),
769        })
770        .await?;
771        let orchestrator = Arc::new(orchestrator);
772        // Messing with the clock causes persist to expire leases, causing hangs and
773        // panics. Is it possible/desirable to put this back somehow?
774        let persist_now = SYSTEM_TIME.clone();
775        let dyncfgs = mz_dyncfgs::all_dyncfgs();
776
777        let mut updates = ConfigUpdates::default();
778        // Tune down the number of connections to make this all work a little easier
779        // with local postgres.
780        updates.add(&CONSENSUS_CONNECTION_POOL_MAX_SIZE, 1);
781        updates.apply(&dyncfgs);
782
783        let mut persist_cfg = PersistConfig::new(&crate::BUILD_INFO, persist_now.clone(), dyncfgs);
784        persist_cfg.build_version = config.code_version;
785        // Stress persist more by writing rollups frequently
786        persist_cfg.set_rollup_threshold(5);
787
788        let persist_pubsub_server = PersistGrpcPubSubServer::new(&persist_cfg, &metrics_registry);
789        let persist_pubsub_client = persist_pubsub_server.new_same_process_connection();
790        let persist_pubsub_tcp_listener =
791            TcpListener::bind(SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), 0))
792                .await
793                .expect("pubsub addr binding");
794        let persist_pubsub_server_port = persist_pubsub_tcp_listener
795            .local_addr()
796            .expect("pubsub addr has local addr")
797            .port();
798
799        // Spawn the persist pub-sub server.
800        mz_ore::task::spawn(|| "persist_pubsub_server", async move {
801            persist_pubsub_server
802                .serve_with_stream(TcpListenerStream::new(persist_pubsub_tcp_listener))
803                .await
804                .expect("success")
805        });
806        let persist_clients =
807            PersistClientCache::new(persist_cfg, &metrics_registry, |_, _| persist_pubsub_client);
808        let persist_clients = Arc::new(persist_clients);
809
810        let secrets_controller = Arc::clone(&orchestrator);
811        let connection_context = ConnectionContext::for_tests(orchestrator.reader());
812        let orchestrator = Arc::new(TracingOrchestrator::new(
813            orchestrator,
814            config.orchestrator_tracing_cli_args,
815        ));
816        let tracing_handle = if config.enable_tracing {
817            let config = TracingConfig::<fn(&tracing::Metadata) -> sentry_tracing::EventFilter> {
818                service_name: "environmentd",
819                stderr_log: StderrLogConfig {
820                    format: StderrLogFormat::Json,
821                    filter: EnvFilter::default(),
822                },
823                opentelemetry: Some(OpenTelemetryConfig {
824                    endpoint: "http://fake_address_for_testing:8080".to_string(),
825                    headers: http::HeaderMap::new(),
826                    filter: EnvFilter::default().add_directive(Level::DEBUG.into()),
827                    resource: opentelemetry_sdk::resource::Resource::builder().build(),
828                    max_batch_queue_size: 2048,
829                    max_export_batch_size: 512,
830                    max_concurrent_exports: 1,
831                    batch_scheduled_delay: Duration::from_millis(5000),
832                    max_export_timeout: Duration::from_secs(30),
833                }),
834                tokio_console: None,
835                sentry: None,
836                build_version: crate::BUILD_INFO.version,
837                build_sha: crate::BUILD_INFO.sha,
838                registry: metrics_registry.clone(),
839                capture: config.capture,
840            };
841            mz_ore::tracing::configure(config).await?
842        } else {
843            TracingHandle::disabled()
844        };
845        let host_name = format!(
846            "localhost:{}",
847            self.inner.http["external"].handle.local_addr.port()
848        );
849        let catalog_config = CatalogConfig {
850            persist_clients: Arc::clone(&persist_clients),
851            metrics: Arc::new(mz_catalog::durable::Metrics::new(&MetricsRegistry::new())),
852        };
853
854        let inner = self
855            .inner
856            .serve(crate::Config {
857                catalog_config,
858                timestamp_oracle_url: Some(timestamp_oracle_url),
859                controller: ControllerConfig {
860                    build_info: &crate::BUILD_INFO,
861                    orchestrator,
862                    clusterd_image: "clusterd".into(),
863                    init_container_image: None,
864                    deploy_generation: config.deploy_generation,
865                    persist_location: PersistLocation {
866                        blob_uri: format!("file://{}/persist/blob", data_directory.display())
867                            .parse()
868                            .expect("invalid blob URI"),
869                        consensus_uri,
870                    },
871                    persist_clients,
872                    now: config.now.clone(),
873                    metrics_registry: metrics_registry.clone(),
874                    persist_pubsub_url: format!("http://localhost:{}", persist_pubsub_server_port),
875                    secrets_args: mz_service::secrets::SecretsReaderCliArgs {
876                        secrets_reader: mz_service::secrets::SecretsControllerKind::LocalFile,
877                        secrets_reader_local_file_dir: Some(data_directory.join("secrets")),
878                        secrets_reader_kubernetes_context: None,
879                        secrets_reader_aws_prefix: None,
880                        secrets_reader_name_prefix: None,
881                    },
882                    connection_context,
883                    replica_http_locator: Default::default(),
884                },
885                secrets_controller,
886                cloud_resource_controller: None,
887                tls: config.tls,
888                frontegg: config.frontegg,
889                frontegg_oauth_issuer_url: None,
890                unsafe_mode: config.unsafe_mode,
891                all_features: false,
892                metrics_registry: metrics_registry.clone(),
893                now: config.now,
894                environment_id: config.environment_id,
895                cors_allowed_origin: AllowOrigin::list([]),
896                cors_allowed_origin_list: Vec::new(),
897                cluster_replica_sizes: ClusterReplicaSizeMap::for_tests(),
898                bootstrap_default_cluster_replica_size: config.default_cluster_replica_size,
899                bootstrap_default_cluster_replication_factor: config
900                    .default_cluster_replication_factor,
901                bootstrap_builtin_system_cluster_config: config.builtin_system_cluster_config,
902                bootstrap_builtin_catalog_server_cluster_config: config
903                    .builtin_catalog_server_cluster_config,
904                bootstrap_builtin_probe_cluster_config: config.builtin_probe_cluster_config,
905                bootstrap_builtin_support_cluster_config: config.builtin_support_cluster_config,
906                bootstrap_builtin_analytics_cluster_config: config.builtin_analytics_cluster_config,
907                system_parameter_defaults: config.system_parameter_defaults,
908                availability_zones: Default::default(),
909                tracing_handle,
910                storage_usage_collection_interval: config.storage_usage_collection_interval,
911                storage_usage_retention_period: config.storage_usage_retention_period,
912                segment_api_key: None,
913                segment_client_side: false,
914                test_only_dummy_segment_client: false,
915                egress_addresses: vec![],
916                aws_account_id: None,
917                aws_privatelink_availability_zones: None,
918                launchdarkly_sdk_key: None,
919                launchdarkly_key_map: Default::default(),
920                config_sync_file_path: None,
921                config_sync_timeout: Duration::from_secs(30),
922                config_sync_loop_interval: None,
923                bootstrap_role: config.bootstrap_role,
924                http_host_name: Some(host_name),
925                internal_console_redirect_url: config.internal_console_redirect_url,
926                tls_reload_certs,
927                helm_chart_version: None,
928                license_key: ValidatedLicenseKey::for_tests(),
929                external_login_password_mz_system: config.external_login_password_mz_system,
930                force_builtin_schema_migration: None,
931            })
932            .await?;
933
934        Ok(TestServer {
935            inner,
936            metrics_registry,
937            _temp_dir: temp_dir,
938            _scratch_dir: scratch_dir,
939        })
940    }
941}
942
943/// A running instance of `environmentd`.
944pub struct TestServer {
945    pub inner: crate::Server,
946    pub metrics_registry: MetricsRegistry,
947    /// The `TempDir`s are saved to prevent them from being dropped, and thus cleaned up too early.
948    _temp_dir: Option<TempDir>,
949    _scratch_dir: TempDir,
950}
951
952impl TestServer {
953    pub fn connect(&self) -> ConnectBuilder<'_, postgres::NoTls, NoHandle> {
954        ConnectBuilder::new(self).no_tls()
955    }
956
957    pub async fn enable_feature_flags(&self, flags: &[&'static str]) {
958        let internal_client = self.connect().internal().await.unwrap();
959
960        for flag in flags {
961            let query = sql!("ALTER SYSTEM SET {} = true;", Sql::ident(flag));
962            pg_batch_execute(&internal_client, query).await.unwrap();
963        }
964    }
965
966    pub async fn disable_feature_flags(&self, flags: &[&'static str]) {
967        let internal_client = self.connect().internal().await.unwrap();
968
969        for flag in flags {
970            let query = sql!("ALTER SYSTEM SET {} = false;", Sql::ident(flag));
971            pg_batch_execute(&internal_client, query).await.unwrap();
972        }
973    }
974
975    pub fn ws_addr(&self) -> Uri {
976        format!(
977            "ws://{}/api/experimental/sql",
978            self.inner.http_listener_handles["external"].local_addr
979        )
980        .parse()
981        .unwrap()
982    }
983
984    pub fn internal_ws_addr(&self) -> Uri {
985        format!(
986            "ws://{}/api/experimental/sql",
987            self.inner.http_listener_handles["internal"].local_addr
988        )
989        .parse()
990        .unwrap()
991    }
992
993    pub fn http_local_addr(&self) -> SocketAddr {
994        self.inner.http_listener_handles["external"].local_addr
995    }
996
997    pub fn internal_http_local_addr(&self) -> SocketAddr {
998        self.inner.http_listener_handles["internal"].local_addr
999    }
1000
1001    pub fn sql_local_addr(&self) -> SocketAddr {
1002        self.inner.sql_listener_handles["external"].local_addr
1003    }
1004
1005    pub fn internal_sql_local_addr(&self) -> SocketAddr {
1006        self.inner.sql_listener_handles["internal"].local_addr
1007    }
1008}
1009
1010/// A builder struct to configure a pgwire connection to a running [`TestServer`].
1011///
1012/// You can create this struct, and thus open a pgwire connection, using [`TestServer::connect`].
1013pub struct ConnectBuilder<'s, T, H> {
1014    /// A running `environmentd` test server.
1015    server: &'s TestServer,
1016
1017    /// Postgres configuration for connecting to the test server.
1018    pg_config: tokio_postgres::Config,
1019    /// Port to use when connecting to the test server.
1020    port: u16,
1021    /// Tls settings to use.
1022    tls: T,
1023
1024    /// Callback that gets invoked for every notice we receive.
1025    notice_callback: Option<Box<dyn FnMut(tokio_postgres::error::DbError) + Send + 'static>>,
1026
1027    /// Type variable for whether or not we include the handle for the spawned [`tokio::task`].
1028    _with_handle: H,
1029}
1030
1031impl<'s> ConnectBuilder<'s, (), NoHandle> {
1032    fn new(server: &'s TestServer) -> Self {
1033        let mut pg_config = tokio_postgres::Config::new();
1034        pg_config
1035            .host(&Ipv4Addr::LOCALHOST.to_string())
1036            .user("materialize")
1037            .options("--welcome_message=off")
1038            .application_name("environmentd_test_framework");
1039
1040        ConnectBuilder {
1041            server,
1042            pg_config,
1043            port: server.sql_local_addr().port(),
1044            tls: (),
1045            notice_callback: None,
1046            _with_handle: NoHandle,
1047        }
1048    }
1049}
1050
1051impl<'s, T, H> ConnectBuilder<'s, T, H> {
1052    /// Create a pgwire connection without using TLS.
1053    ///
1054    /// Note: this is the default for all connections.
1055    pub fn no_tls(self) -> ConnectBuilder<'s, postgres::NoTls, H> {
1056        ConnectBuilder {
1057            server: self.server,
1058            pg_config: self.pg_config,
1059            port: self.port,
1060            tls: postgres::NoTls,
1061            notice_callback: self.notice_callback,
1062            _with_handle: self._with_handle,
1063        }
1064    }
1065
1066    /// Create a pgwire connection with TLS.
1067    pub fn with_tls<Tls>(self, tls: Tls) -> ConnectBuilder<'s, Tls, H>
1068    where
1069        Tls: MakeTlsConnect<Socket> + Send + 'static,
1070        Tls::TlsConnect: Send,
1071        Tls::Stream: Send,
1072        <Tls::TlsConnect as TlsConnect<Socket>>::Future: Send,
1073    {
1074        ConnectBuilder {
1075            server: self.server,
1076            pg_config: self.pg_config,
1077            port: self.port,
1078            tls,
1079            notice_callback: self.notice_callback,
1080            _with_handle: self._with_handle,
1081        }
1082    }
1083
1084    /// Create a [`ConnectBuilder`] using the provided [`tokio_postgres::Config`].
1085    pub fn with_config(mut self, pg_config: tokio_postgres::Config) -> Self {
1086        self.pg_config = pg_config;
1087        self
1088    }
1089
1090    /// Set the [`SslMode`] to be used with the resulting connection.
1091    pub fn ssl_mode(mut self, mode: SslMode) -> Self {
1092        self.pg_config.ssl_mode(mode);
1093        self
1094    }
1095
1096    /// Set the user for the pgwire connection.
1097    pub fn user(mut self, user: &str) -> Self {
1098        self.pg_config.user(user);
1099        self
1100    }
1101
1102    /// Set the password for the pgwire connection.
1103    pub fn password(mut self, password: &str) -> Self {
1104        self.pg_config.password(password);
1105        self
1106    }
1107
1108    /// Set the application name for the pgwire connection.
1109    pub fn application_name(mut self, application_name: &str) -> Self {
1110        self.pg_config.application_name(application_name);
1111        self
1112    }
1113
1114    /// Set the database name for the pgwire connection.
1115    pub fn dbname(mut self, dbname: &str) -> Self {
1116        self.pg_config.dbname(dbname);
1117        self
1118    }
1119
1120    /// Set the options for the pgwire connection.
1121    pub fn options(mut self, options: &str) -> Self {
1122        self.pg_config.options(options);
1123        self
1124    }
1125
1126    /// Configures this [`ConnectBuilder`] to connect to the __internal__ SQL port of the running
1127    /// [`TestServer`].
1128    ///
1129    /// For example, this will change the port we connect to, and the user we connect as.
1130    pub fn internal(mut self) -> Self {
1131        self.port = self.server.internal_sql_local_addr().port();
1132        self.pg_config.user(mz_sql::session::user::SYSTEM_USER_NAME);
1133        self
1134    }
1135
1136    /// Sets a callback for any database notices that are received from the [`TestServer`].
1137    pub fn notice_callback(self, callback: impl FnMut(DbError) + Send + 'static) -> Self {
1138        ConnectBuilder {
1139            notice_callback: Some(Box::new(callback)),
1140            ..self
1141        }
1142    }
1143
1144    /// Configures this [`ConnectBuilder`] to return the [`mz_ore::task::JoinHandle`] that is
1145    /// polling the underlying postgres connection, associated with the returned client.
1146    pub fn with_handle(self) -> ConnectBuilder<'s, T, WithHandle> {
1147        ConnectBuilder {
1148            server: self.server,
1149            pg_config: self.pg_config,
1150            port: self.port,
1151            tls: self.tls,
1152            notice_callback: self.notice_callback,
1153            _with_handle: WithHandle,
1154        }
1155    }
1156
1157    /// Returns the [`tokio_postgres::Config`] that will be used to connect.
1158    pub fn as_pg_config(&self) -> &tokio_postgres::Config {
1159        &self.pg_config
1160    }
1161}
1162
1163/// This trait enables us to either include or omit the [`mz_ore::task::JoinHandle`] in the result
1164/// of a client connection.
1165pub trait IncludeHandle: Send {
1166    type Output;
1167    fn transform_result(
1168        client: tokio_postgres::Client,
1169        handle: mz_ore::task::JoinHandle<()>,
1170    ) -> Self::Output;
1171}
1172
1173/// Type parameter that denotes we __will not__ return the [`mz_ore::task::JoinHandle`] in the
1174/// result of a [`ConnectBuilder`].
1175pub struct NoHandle;
1176impl IncludeHandle for NoHandle {
1177    type Output = tokio_postgres::Client;
1178    fn transform_result(
1179        client: tokio_postgres::Client,
1180        _handle: mz_ore::task::JoinHandle<()>,
1181    ) -> Self::Output {
1182        client
1183    }
1184}
1185
1186/// Type parameter that denotes we __will__ return the [`mz_ore::task::JoinHandle`] in the result of
1187/// a [`ConnectBuilder`].
1188pub struct WithHandle;
1189impl IncludeHandle for WithHandle {
1190    type Output = (tokio_postgres::Client, mz_ore::task::JoinHandle<()>);
1191    fn transform_result(
1192        client: tokio_postgres::Client,
1193        handle: mz_ore::task::JoinHandle<()>,
1194    ) -> Self::Output {
1195        (client, handle)
1196    }
1197}
1198
1199impl<'s, T, H> IntoFuture for ConnectBuilder<'s, T, H>
1200where
1201    T: MakeTlsConnect<Socket> + Send + 'static,
1202    T::TlsConnect: Send,
1203    T::Stream: Send,
1204    <T::TlsConnect as TlsConnect<Socket>>::Future: Send,
1205    H: IncludeHandle,
1206{
1207    type Output = Result<H::Output, postgres::Error>;
1208    type IntoFuture = BoxFuture<'static, Self::Output>;
1209
1210    fn into_future(mut self) -> Self::IntoFuture {
1211        Box::pin(async move {
1212            assert!(
1213                self.pg_config.get_ports().is_empty(),
1214                "specifying multiple ports is not supported"
1215            );
1216            self.pg_config.port(self.port);
1217
1218            let (client, mut conn) = self.pg_config.connect(self.tls).await?;
1219            let mut notice_callback = self.notice_callback.take();
1220
1221            let handle = task::spawn(|| "connect", async move {
1222                while let Some(msg) = std::future::poll_fn(|cx| conn.poll_message(cx)).await {
1223                    match msg {
1224                        Ok(AsyncMessage::Notice(notice)) => {
1225                            if let Some(callback) = notice_callback.as_mut() {
1226                                callback(notice);
1227                            }
1228                        }
1229                        Ok(msg) => {
1230                            tracing::debug!(?msg, "Dropping message from database");
1231                        }
1232                        Err(e) => {
1233                            // tokio_postgres::Connection docs say:
1234                            // > Return values of None or Some(Err(_)) are “terminal”; callers
1235                            // > should not invoke this method again after receiving one of those
1236                            // > values.
1237                            tracing::info!("connection error: {e}");
1238                            break;
1239                        }
1240                    }
1241                }
1242                tracing::info!("connection closed");
1243            });
1244
1245            let output = H::transform_result(client, handle);
1246            Ok(output)
1247        })
1248    }
1249}
1250
1251/// A running instance of `environmentd`, that exposes blocking/synchronous test helpers.
1252///
1253/// Note: Ideally you should use a [`TestServer`] which relies on an external runtime, e.g. the
1254/// [`tokio::test`] macro. This struct exists so we can incrementally migrate our existing tests.
1255pub struct TestServerWithRuntime {
1256    server: TestServer,
1257    runtime: Arc<Runtime>,
1258}
1259
1260impl TestServerWithRuntime {
1261    /// Returns the [`Runtime`] owned by this [`TestServerWithRuntime`].
1262    ///
1263    /// Can be used to spawn async tasks.
1264    pub fn runtime(&self) -> &Arc<Runtime> {
1265        &self.runtime
1266    }
1267
1268    /// Returns a referece to the inner running `environmentd` [`crate::Server`]`.
1269    pub fn inner(&self) -> &crate::Server {
1270        &self.server.inner
1271    }
1272
1273    /// Connect to the __public__ SQL port of the running `environmentd` server.
1274    pub fn connect<T>(&self, tls: T) -> Result<postgres::Client, postgres::Error>
1275    where
1276        T: MakeTlsConnect<Socket> + Send + 'static,
1277        T::TlsConnect: Send,
1278        T::Stream: Send,
1279        <T::TlsConnect as TlsConnect<Socket>>::Future: Send,
1280    {
1281        self.pg_config().connect(tls)
1282    }
1283
1284    /// Connect to the __internal__ SQL port of the running `environmentd` server.
1285    pub fn connect_internal<T>(&self, tls: T) -> Result<postgres::Client, anyhow::Error>
1286    where
1287        T: MakeTlsConnect<Socket> + Send + 'static,
1288        T::TlsConnect: Send,
1289        T::Stream: Send,
1290        <T::TlsConnect as TlsConnect<Socket>>::Future: Send,
1291    {
1292        Ok(self.pg_config_internal().connect(tls)?)
1293    }
1294
1295    /// Enable LaunchDarkly feature flags.
1296    pub fn enable_feature_flags(&self, flags: &[&'static str]) {
1297        let mut internal_client = self.connect_internal(postgres::NoTls).unwrap();
1298
1299        for flag in flags {
1300            let query = sql!("ALTER SYSTEM SET {} = true;", Sql::ident(flag));
1301            // This uses the synchronous `postgres::Client`; wrappers are async
1302            // and currently only defined for tokio-postgres clients.
1303            #[allow(clippy::disallowed_methods)]
1304            internal_client.batch_execute(query.as_str()).unwrap();
1305        }
1306    }
1307
1308    /// Disable LaunchDarkly feature flags.
1309    pub fn disable_feature_flags(&self, flags: &[&'static str]) {
1310        let mut internal_client = self.connect_internal(postgres::NoTls).unwrap();
1311
1312        for flag in flags {
1313            let query = sql!("ALTER SYSTEM SET {} = false;", Sql::ident(flag));
1314            // This uses the synchronous `postgres::Client`; wrappers are async
1315            // and currently only defined for tokio-postgres clients.
1316            #[allow(clippy::disallowed_methods)]
1317            internal_client.batch_execute(query.as_str()).unwrap();
1318        }
1319    }
1320
1321    /// Return a [`postgres::Config`] for connecting to the __public__ SQL port of the running
1322    /// `environmentd` server.
1323    pub fn pg_config(&self) -> postgres::Config {
1324        let local_addr = self.server.sql_local_addr();
1325        let mut config = postgres::Config::new();
1326        config
1327            .host(&Ipv4Addr::LOCALHOST.to_string())
1328            .port(local_addr.port())
1329            .user("materialize")
1330            .options("--welcome_message=off");
1331        config
1332    }
1333
1334    /// Return a [`postgres::Config`] for connecting to the __internal__ SQL port of the running
1335    /// `environmentd` server.
1336    pub fn pg_config_internal(&self) -> postgres::Config {
1337        let local_addr = self.server.internal_sql_local_addr();
1338        let mut config = postgres::Config::new();
1339        config
1340            .host(&Ipv4Addr::LOCALHOST.to_string())
1341            .port(local_addr.port())
1342            .user("mz_system")
1343            .options("--welcome_message=off");
1344        config
1345    }
1346
1347    pub fn ws_addr(&self) -> Uri {
1348        self.server.ws_addr()
1349    }
1350
1351    pub fn internal_ws_addr(&self) -> Uri {
1352        self.server.internal_ws_addr()
1353    }
1354
1355    pub fn http_local_addr(&self) -> SocketAddr {
1356        self.server.http_local_addr()
1357    }
1358
1359    pub fn internal_http_local_addr(&self) -> SocketAddr {
1360        self.server.internal_http_local_addr()
1361    }
1362
1363    pub fn sql_local_addr(&self) -> SocketAddr {
1364        self.server.sql_local_addr()
1365    }
1366
1367    pub fn internal_sql_local_addr(&self) -> SocketAddr {
1368        self.server.internal_sql_local_addr()
1369    }
1370
1371    /// Returns the metrics registry for the test server.
1372    pub fn metrics_registry(&self) -> &MetricsRegistry {
1373        &self.server.metrics_registry
1374    }
1375}
1376
/// Newtype wrapping a timestamp as a plain `u64`.
///
/// Values compare and order by the wrapped integer.
#[derive(Clone, Copy, Debug, PartialEq, Eq, PartialOrd, Ord)]
pub struct MzTimestamp(pub u64);
1379
1380impl<'a> FromSql<'a> for MzTimestamp {
1381    fn from_sql(ty: &Type, raw: &'a [u8]) -> Result<MzTimestamp, Box<dyn Error + Sync + Send>> {
1382        let n = mz_pgrepr::Numeric::from_sql(ty, raw)?;
1383        Ok(MzTimestamp(u64::try_from(n.0.0)?))
1384    }
1385
1386    fn accepts(ty: &Type) -> bool {
1387        mz_pgrepr::Numeric::accepts(ty)
1388    }
1389}
1390
1391pub trait PostgresErrorExt {
1392    fn unwrap_db_error(self) -> DbError;
1393}
1394
1395impl PostgresErrorExt for postgres::Error {
1396    fn unwrap_db_error(self) -> DbError {
1397        match self.source().and_then(|e| e.downcast_ref::<DbError>()) {
1398            Some(e) => e.clone(),
1399            None => panic!("expected DbError, but got: {:?}", self),
1400        }
1401    }
1402}
1403
1404impl<T, E> PostgresErrorExt for Result<T, E>
1405where
1406    E: PostgresErrorExt,
1407{
1408    fn unwrap_db_error(self) -> DbError {
1409        match self {
1410            Ok(_) => panic!("expected Err(DbError), but got Ok(_)"),
1411            Err(e) => e.unwrap_db_error(),
1412        }
1413    }
1414}
1415
1416/// Group commit will block writes until the current time has advanced. This can make
1417/// performing inserts while using deterministic time difficult. This is a helper
1418/// method to perform writes and advance the current time.
1419pub async fn insert_with_deterministic_timestamps(
1420    table: &'static str,
1421    values: &'static str,
1422    server: &TestServer,
1423    now: Arc<std::sync::Mutex<EpochMillis>>,
1424) -> Result<(), Box<dyn Error>> {
1425    let client_write = server.connect().await?;
1426    let client_read = server.connect().await?;
1427
1428    let mut current_ts = get_explain_timestamp(table, &client_read).await;
1429
1430    let insert_query = format!("INSERT INTO {} VALUES {values}", Sql::ident(table));
1431
1432    // The `values` fragment is raw SQL text in test code and cannot currently
1433    // be represented as a composable `Sql` fragment.
1434    #[allow(clippy::disallowed_methods)]
1435    let write_future = client_write.execute(&insert_query, &[]);
1436    let timestamp_interval = tokio::time::interval(Duration::from_millis(1));
1437
1438    let mut write_future = std::pin::pin!(write_future);
1439    let mut timestamp_interval = std::pin::pin!(timestamp_interval);
1440
1441    // Keep increasing `now` until the write has executed succeed. Table advancements may
1442    // have increased the global timestamp by an unknown amount.
1443    loop {
1444        tokio::select! {
1445            _ = (&mut write_future) => return Ok(()),
1446            _ = timestamp_interval.tick() => {
1447                current_ts += 1;
1448                *now.lock().expect("lock poisoned") = current_ts;
1449            }
1450        };
1451    }
1452}
1453
1454pub async fn get_explain_timestamp(from_suffix: &str, client: &Client) -> EpochMillis {
1455    try_get_explain_timestamp(from_suffix, client)
1456        .await
1457        .unwrap()
1458}
1459
1460pub async fn try_get_explain_timestamp(
1461    from_suffix: &str,
1462    client: &Client,
1463) -> Result<EpochMillis, anyhow::Error> {
1464    let det = get_explain_timestamp_determination(from_suffix, client).await?;
1465    let ts = det.determination.timestamp_context.timestamp_or_default();
1466    Ok(ts.into())
1467}
1468
1469pub async fn get_explain_timestamp_determination(
1470    from_suffix: &str,
1471    client: &Client,
1472) -> Result<TimestampExplanation, anyhow::Error> {
1473    // `from_suffix` is a raw SQL suffix used by this test helper and cannot
1474    // currently be represented as a composable `Sql` fragment.
1475    #[allow(clippy::disallowed_methods)]
1476    let row = client
1477        .query_one(
1478            &format!("EXPLAIN TIMESTAMP AS JSON FOR SELECT * FROM {from_suffix}"),
1479            &[],
1480        )
1481        .await?;
1482    let explain: String = row.get(0);
1483    Ok(serde_json::from_str(&explain).unwrap())
1484}
1485
1486/// Helper function to create a Postgres source.
1487///
1488/// IMPORTANT: Make sure to call closure that is returned at the end of the test to clean up
1489/// Postgres state.
1490///
1491/// WARNING: If multiple tests use this, and the tests are run in parallel, then make sure the test
1492/// use different postgres tables.
1493pub async fn create_postgres_source_with_table<'a>(
1494    server: &TestServer,
1495    mz_client: &Client,
1496    table_name: &str,
1497    table_schema: &str,
1498    source_name: &str,
1499) -> (
1500    Client,
1501    impl FnOnce(&'a Client, &'a Client) -> LocalBoxFuture<'a, ()>,
1502) {
1503    server
1504        .enable_feature_flags(&["enable_create_table_from_source"])
1505        .await;
1506
1507    let postgres_url = env::var("POSTGRES_URL")
1508        .map_err(|_| anyhow!("POSTGRES_URL environment variable is not set"))
1509        .unwrap();
1510
1511    let (pg_client, connection) = tokio_postgres::connect(&postgres_url, postgres::NoTls)
1512        .await
1513        .unwrap();
1514
1515    let pg_config: tokio_postgres::Config = postgres_url.parse().unwrap();
1516    let user = pg_config.get_user().unwrap_or("postgres");
1517    let db_name = pg_config.get_dbname().unwrap_or(user);
1518    let ports = pg_config.get_ports();
1519    let port = if ports.is_empty() { 5432 } else { ports[0] };
1520    let hosts = pg_config.get_hosts();
1521    let host = if hosts.is_empty() {
1522        "localhost".to_string()
1523    } else {
1524        match &hosts[0] {
1525            Host::Tcp(host) => host.to_string(),
1526            Host::Unix(host) => host.to_str().unwrap().to_string(),
1527        }
1528    };
1529    let password = pg_config.get_password();
1530
1531    mz_ore::task::spawn(|| "postgres-source-connection", async move {
1532        if let Err(e) = connection.await {
1533            panic!("connection error: {}", e);
1534        }
1535    });
1536
1537    // Create table in Postgres with publication.
1538    let _ = pg_execute(
1539        &pg_client,
1540        sql!("DROP TABLE IF EXISTS {};", Sql::ident(table_name)),
1541        &[],
1542    )
1543    .await
1544    .unwrap();
1545    let _ = pg_execute(
1546        &pg_client,
1547        sql!("DROP PUBLICATION IF EXISTS {};", Sql::ident(source_name)),
1548        &[],
1549    )
1550    .await
1551    .unwrap();
1552    // `table_schema` is a raw schema fragment in this test helper and cannot
1553    // currently be represented as a composable `Sql` fragment.
1554    #[allow(clippy::disallowed_methods)]
1555    let _ = pg_client
1556        .execute(
1557            format!("CREATE TABLE {} {table_schema};", Sql::ident(table_name)).as_str(),
1558            &[],
1559        )
1560        .await
1561        .unwrap();
1562    let _ = pg_execute(
1563        &pg_client,
1564        sql!(
1565            "ALTER TABLE {} REPLICA IDENTITY FULL;",
1566            Sql::ident(table_name)
1567        ),
1568        &[],
1569    )
1570    .await
1571    .unwrap();
1572    let _ = pg_execute(
1573        &pg_client,
1574        sql!(
1575            "CREATE PUBLICATION {} FOR TABLE {};",
1576            Sql::ident(source_name),
1577            Sql::ident(table_name)
1578        ),
1579        &[],
1580    )
1581    .await
1582    .unwrap();
1583
1584    // Create postgres source in Materialize.
1585    let mut connection_str = format!("HOST '{host}', PORT {port}, USER {user}, DATABASE {db_name}");
1586    if let Some(password) = password {
1587        let password = std::str::from_utf8(password).unwrap();
1588        pg_batch_execute(
1589            mz_client,
1590            sql!("CREATE SECRET s AS {}", Sql::literal(password)),
1591        )
1592        .await
1593        .unwrap();
1594        connection_str = format!("{connection_str}, PASSWORD SECRET s");
1595    }
1596    // `connection_str` is a raw connection-option fragment generated for tests
1597    // and cannot currently be represented as a composable `Sql` fragment.
1598    #[allow(clippy::disallowed_methods)]
1599    mz_client
1600        .batch_execute(format!("CREATE CONNECTION pgconn TO POSTGRES ({connection_str})").as_str())
1601        .await
1602        .unwrap();
1603    pg_batch_execute(
1604        mz_client,
1605        sql!(
1606            "CREATE SOURCE {} \
1607             FROM POSTGRES \
1608             CONNECTION pgconn \
1609             (PUBLICATION {})",
1610            Sql::ident(source_name),
1611            Sql::literal(source_name),
1612        ),
1613    )
1614    .await
1615    .unwrap();
1616    pg_batch_execute(
1617        mz_client,
1618        sql!(
1619            "CREATE TABLE {} \
1620             FROM SOURCE {} \
1621             (REFERENCE {});",
1622            Sql::ident(table_name),
1623            Sql::ident(source_name),
1624            Sql::ident(table_name),
1625        ),
1626    )
1627    .await
1628    .unwrap();
1629
1630    let table_name = table_name.to_string();
1631    let source_name = source_name.to_string();
1632    (
1633        pg_client,
1634        move |mz_client: &'a Client, pg_client: &'a Client| {
1635            let f: Pin<Box<dyn Future<Output = ()> + 'a>> = Box::pin(async move {
1636                pg_batch_execute(
1637                    mz_client,
1638                    sql!("DROP SOURCE {} CASCADE;", Sql::ident(&source_name)),
1639                )
1640                .await
1641                .unwrap();
1642                pg_batch_execute(mz_client, sql!("DROP CONNECTION pgconn;"))
1643                    .await
1644                    .unwrap();
1645
1646                let _ = pg_execute(
1647                    pg_client,
1648                    sql!("DROP PUBLICATION {};", Sql::ident(&source_name)),
1649                    &[],
1650                )
1651                .await
1652                .unwrap();
1653                let _ = pg_execute(
1654                    pg_client,
1655                    sql!("DROP TABLE {};", Sql::ident(&table_name)),
1656                    &[],
1657                )
1658                .await
1659                .unwrap();
1660            });
1661            f
1662        },
1663    )
1664}
1665
1666pub async fn wait_for_pg_table_population(mz_client: &Client, view_name: &str, source_rows: i64) {
1667    let current_isolation = pg_query_one(mz_client, sql!("SHOW transaction_isolation"), &[])
1668        .await
1669        .unwrap()
1670        .get::<_, String>(0);
1671    pg_batch_execute(mz_client, sql!("SET transaction_isolation = SERIALIZABLE"))
1672        .await
1673        .unwrap();
1674    Retry::default()
1675        .retry_async(|_| async move {
1676            let rows = pg_query_one(
1677                mz_client,
1678                sql!("SELECT COUNT(*) FROM {};", Sql::ident(view_name)),
1679                &[],
1680            )
1681            .await
1682            .unwrap()
1683            .get::<_, i64>(0);
1684            if rows == source_rows {
1685                Ok(())
1686            } else {
1687                Err(format!(
1688                    "Waiting for {source_rows} row to be ingested. Currently at {rows}."
1689                ))
1690            }
1691        })
1692        .await
1693        .unwrap();
1694    pg_batch_execute(
1695        mz_client,
1696        sql!(
1697            "SET transaction_isolation = {}",
1698            Sql::literal(&current_isolation),
1699        ),
1700    )
1701    .await
1702    .unwrap();
1703}
1704
1705// Initializes a websocket connection. Returns the init messages before the initial ReadyForQuery.
1706pub fn auth_with_ws(
1707    ws: &mut WebSocket<MaybeTlsStream<TcpStream>>,
1708    mut options: BTreeMap<String, String>,
1709) -> Result<Vec<WebSocketResponse>, anyhow::Error> {
1710    if !options.contains_key("welcome_message") {
1711        options.insert("welcome_message".into(), "off".into());
1712    }
1713    auth_with_ws_impl(
1714        ws,
1715        Message::Text(
1716            serde_json::to_string(&WebSocketAuth::Basic {
1717                user: "materialize".into(),
1718                password: "".into(),
1719                options,
1720            })
1721            .unwrap()
1722            .into(),
1723        ),
1724    )
1725}
1726
1727pub fn auth_with_ws_impl(
1728    ws: &mut WebSocket<MaybeTlsStream<TcpStream>>,
1729    auth_message: Message,
1730) -> Result<Vec<WebSocketResponse>, anyhow::Error> {
1731    ws.send(auth_message)?;
1732
1733    // Wait for initial ready response.
1734    let mut msgs = Vec::new();
1735    loop {
1736        let resp = ws.read()?;
1737        match resp {
1738            Message::Text(msg) => {
1739                let msg: WebSocketResponse = serde_json::from_str(&msg).unwrap();
1740                match msg {
1741                    WebSocketResponse::ReadyForQuery(_) => break,
1742                    msg => {
1743                        msgs.push(msg);
1744                    }
1745                }
1746            }
1747            Message::Ping(_) => continue,
1748            Message::Close(None) => return Err(anyhow!("ws closed after auth")),
1749            Message::Close(Some(close_frame)) => {
1750                return Err(anyhow!("ws closed after auth").context(close_frame));
1751            }
1752            _ => panic!("unexpected response: {:?}", resp),
1753        }
1754    }
1755    Ok(msgs)
1756}
1757
1758pub fn make_header<H: Header>(h: H) -> HeaderMap {
1759    let mut map = HeaderMap::new();
1760    map.typed_insert(h);
1761    map
1762}
1763
1764pub fn make_pg_tls<F>(configure: F) -> MakeTlsConnector
1765where
1766    F: FnOnce(&mut SslConnectorBuilder) -> Result<(), ErrorStack>,
1767{
1768    let mut connector_builder = SslConnector::builder(SslMethod::tls()).unwrap();
1769    // Disable TLS v1.3 because `postgres` and `hyper` produce stabler error
1770    // messages with TLS v1.2.
1771    //
1772    // Briefly, in TLS v1.3, failing to present a client certificate does not
1773    // error during the TLS handshake, as it does in TLS v1.2, but on the first
1774    // attempt to read from the stream. But both `postgres` and `hyper` write a
1775    // bunch of data before attempting to read from the stream. With a failed
1776    // TLS v1.3 connection, sometimes `postgres` and `hyper` succeed in writing
1777    // out this data, and then return a nice error message on the call to read.
1778    // But sometimes the connection is closed before they write out the data,
1779    // and so they report "connection closed" before they ever call read, never
1780    // noticing the underlying SSL error.
1781    //
1782    // It's unclear who's bug this is. Is it on `hyper`/`postgres` to call read
1783    // if writing to the stream fails to see if a TLS error occured? Is it on
1784    // OpenSSL to provide a better API [1]? Is it a protocol issue that ought to
1785    // be corrected in TLS v1.4? We don't want to answer these questions, so we
1786    // just avoid TLS v1.3 for now.
1787    //
1788    // [1]: https://github.com/openssl/openssl/issues/11118
1789    let options = connector_builder.options() | SslOptions::NO_TLSV1_3;
1790    connector_builder.set_options(options);
1791    configure(&mut connector_builder).unwrap();
1792    MakeTlsConnector::new(connector_builder.build())
1793}
1794
1795/// A certificate authority for use in tests.
1796pub struct Ca {
1797    pub dir: TempDir,
1798    pub name: X509Name,
1799    pub cert: X509,
1800    pub pkey: PKey<Private>,
1801}
1802
1803impl Ca {
1804    fn make_ca(name: &str, parent: Option<&Ca>) -> Result<Ca, Box<dyn Error>> {
1805        let dir = tempfile::tempdir()?;
1806        let rsa = Rsa::generate(2048)?;
1807        let pkey = PKey::from_rsa(rsa)?;
1808        let name = {
1809            let mut builder = X509NameBuilder::new()?;
1810            builder.append_entry_by_nid(Nid::COMMONNAME, name)?;
1811            builder.build()
1812        };
1813        let cert = {
1814            let mut builder = X509::builder()?;
1815            builder.set_version(2)?;
1816            builder.set_pubkey(&pkey)?;
1817            builder.set_issuer_name(parent.map(|ca| &ca.name).unwrap_or(&name))?;
1818            builder.set_subject_name(&name)?;
1819            builder.set_not_before(&*Asn1Time::days_from_now(0)?)?;
1820            builder.set_not_after(&*Asn1Time::days_from_now(365)?)?;
1821            builder.append_extension(BasicConstraints::new().critical().ca().build()?)?;
1822            builder.sign(
1823                parent.map(|ca| &ca.pkey).unwrap_or(&pkey),
1824                MessageDigest::sha256(),
1825            )?;
1826            builder.build()
1827        };
1828        fs::write(dir.path().join("ca.crt"), cert.to_pem()?)?;
1829        Ok(Ca {
1830            dir,
1831            name,
1832            cert,
1833            pkey,
1834        })
1835    }
1836
1837    /// Creates a new root certificate authority.
1838    pub fn new_root(name: &str) -> Result<Ca, Box<dyn Error>> {
1839        Ca::make_ca(name, None)
1840    }
1841
1842    /// Returns the path to the CA's certificate.
1843    pub fn ca_cert_path(&self) -> PathBuf {
1844        self.dir.path().join("ca.crt")
1845    }
1846
1847    /// Requests a new intermediate certificate authority.
1848    pub fn request_ca(&self, name: &str) -> Result<Ca, Box<dyn Error>> {
1849        Ca::make_ca(name, Some(self))
1850    }
1851
1852    /// Generates a certificate with the specified Common Name (CN) that is
1853    /// signed by the CA.
1854    ///
1855    /// Returns the paths to the certificate and key.
1856    pub fn request_client_cert(&self, name: &str) -> Result<(PathBuf, PathBuf), Box<dyn Error>> {
1857        self.request_cert(name, iter::empty())
1858    }
1859
1860    /// Like `request_client_cert`, but permits specifying additional IP
1861    /// addresses to attach as Subject Alternate Names.
1862    pub fn request_cert<I>(&self, name: &str, ips: I) -> Result<(PathBuf, PathBuf), Box<dyn Error>>
1863    where
1864        I: IntoIterator<Item = IpAddr>,
1865    {
1866        let rsa = Rsa::generate(2048)?;
1867        let pkey = PKey::from_rsa(rsa)?;
1868        let subject_name = {
1869            let mut builder = X509NameBuilder::new()?;
1870            builder.append_entry_by_nid(Nid::COMMONNAME, name)?;
1871            builder.build()
1872        };
1873        let cert = {
1874            let mut builder = X509::builder()?;
1875            builder.set_version(2)?;
1876            builder.set_pubkey(&pkey)?;
1877            builder.set_issuer_name(self.cert.subject_name())?;
1878            builder.set_subject_name(&subject_name)?;
1879            builder.set_not_before(&*Asn1Time::days_from_now(0)?)?;
1880            builder.set_not_after(&*Asn1Time::days_from_now(365)?)?;
1881            for ip in ips {
1882                builder.append_extension(
1883                    SubjectAlternativeName::new()
1884                        .ip(&ip.to_string())
1885                        .build(&builder.x509v3_context(None, None))?,
1886                )?;
1887            }
1888            builder.sign(&self.pkey, MessageDigest::sha256())?;
1889            builder.build()
1890        };
1891        let cert_path = self.dir.path().join(Path::new(name).with_extension("crt"));
1892        let key_path = self.dir.path().join(Path::new(name).with_extension("key"));
1893        fs::write(&cert_path, cert.to_pem()?)?;
1894        fs::write(&key_path, pkey.private_key_to_pem_pkcs8()?)?;
1895        Ok((cert_path, key_path))
1896    }
1897}