Skip to main content

mz_environmentd/
test_util.rs

1// Copyright Materialize, Inc. and contributors. All rights reserved.
2//
3// Use of this software is governed by the Business Source License
4// included in the LICENSE file.
5//
6// As of the Change Date specified in that file, in accordance with
7// the Business Source License, use of this software will be governed
8// by the Apache License, Version 2.0.
9
10use std::collections::BTreeMap;
11use std::error::Error;
12use std::future::IntoFuture;
13use std::net::{IpAddr, Ipv4Addr, SocketAddr, TcpStream};
14use std::path::{Path, PathBuf};
15use std::pin::Pin;
16use std::str::FromStr;
17use std::sync::Arc;
18use std::sync::LazyLock;
19use std::time::Duration;
20use std::{env, fs, iter};
21
22use anyhow::anyhow;
23use futures::Future;
24use futures::future::{BoxFuture, LocalBoxFuture};
25use headers::{Header, HeaderMapExt};
26use http::Uri;
27use hyper::http::header::HeaderMap;
28use maplit::btreemap;
29use mz_adapter::TimestampExplanation;
30use mz_adapter_types::bootstrap_builtin_cluster_config::{
31    ANALYTICS_CLUSTER_DEFAULT_REPLICATION_FACTOR, BootstrapBuiltinClusterConfig,
32    CATALOG_SERVER_CLUSTER_DEFAULT_REPLICATION_FACTOR, PROBE_CLUSTER_DEFAULT_REPLICATION_FACTOR,
33    SUPPORT_CLUSTER_DEFAULT_REPLICATION_FACTOR, SYSTEM_CLUSTER_DEFAULT_REPLICATION_FACTOR,
34};
35
36use mz_auth::password::Password;
37use mz_catalog::config::ClusterReplicaSizeMap;
38use mz_controller::ControllerConfig;
39use mz_dyncfg::ConfigUpdates;
40use mz_license_keys::ValidatedLicenseKey;
41use mz_orchestrator_process::{ProcessOrchestrator, ProcessOrchestratorConfig};
42use mz_orchestrator_tracing::{TracingCliArgs, TracingOrchestrator};
43use mz_ore::metrics::MetricsRegistry;
44use mz_ore::now::{EpochMillis, NowFn, SYSTEM_TIME};
45use mz_ore::retry::Retry;
46use mz_ore::task;
47use mz_ore::tracing::{
48    OpenTelemetryConfig, StderrLogConfig, StderrLogFormat, TracingConfig, TracingHandle,
49};
50use mz_persist_client::PersistLocation;
51use mz_persist_client::cache::PersistClientCache;
52use mz_persist_client::cfg::{CONSENSUS_CONNECTION_POOL_MAX_SIZE, PersistConfig};
53use mz_persist_client::rpc::PersistGrpcPubSubServer;
54use mz_secrets::SecretsController;
55use mz_server_core::listeners::{
56    AllowedRoles, AuthenticatorKind, BaseListenerConfig, HttpRoutesEnabled,
57};
58use mz_server_core::{ReloadTrigger, TlsCertConfig};
59use mz_sql::catalog::EnvironmentId;
60use mz_storage_types::connections::ConnectionContext;
61use mz_tracing::CloneableEnvFilter;
62use openssl::asn1::Asn1Time;
63use openssl::error::ErrorStack;
64use openssl::hash::MessageDigest;
65use openssl::nid::Nid;
66use openssl::pkey::{PKey, Private};
67use openssl::rsa::Rsa;
68use openssl::ssl::{SslConnector, SslConnectorBuilder, SslMethod, SslOptions};
69use openssl::x509::extension::{BasicConstraints, SubjectAlternativeName};
70use openssl::x509::{X509, X509Name, X509NameBuilder};
71use postgres::error::DbError;
72use postgres::tls::{MakeTlsConnect, TlsConnect};
73use postgres::types::{FromSql, Type};
74use postgres::{NoTls, Socket};
75use postgres_openssl::MakeTlsConnector;
76use tempfile::TempDir;
77use tokio::net::TcpListener;
78use tokio::runtime::Runtime;
79use tokio_postgres::config::{Host, SslMode};
80use tokio_postgres::{AsyncMessage, Client};
81use tokio_stream::wrappers::TcpListenerStream;
82use tower_http::cors::AllowOrigin;
83use tracing::Level;
84use tracing_capture::SharedStorage;
85use tracing_subscriber::EnvFilter;
86use tungstenite::stream::MaybeTlsStream;
87use tungstenite::{Message, WebSocket};
88
89use crate::{
90    CatalogConfig, FronteggAuthenticator, HttpListenerConfig, ListenersConfig, SqlListenerConfig,
91    WebSocketAuth, WebSocketResponse,
92};
93
/// Kafka address string: the value of the `KAFKA_ADDRS` environment variable, or
/// `localhost:9092` when that variable cannot be read.
pub static KAFKA_ADDRS: LazyLock<String> = LazyLock::new(|| match env::var("KAFKA_ADDRS") {
    Ok(addrs) => addrs,
    Err(_) => String::from("localhost:9092"),
});
96
/// Entry point for creating and configuring an `environmentd` test harness.
///
/// Obtain one via [`Default::default`], adjust it with the builder-style methods
/// (each takes and returns `Self`), then launch a server with one of the `start*`
/// methods.
#[derive(Clone)]
pub struct TestHarness {
    /// Directory for server data. When `None`, a temporary directory is created at
    /// serve time.
    data_directory: Option<PathBuf>,
    /// Certificate and key paths, set by `with_tls`.
    tls: Option<TlsCertConfig>,
    /// Frontegg authenticator, set by `with_frontegg_auth`.
    frontegg: Option<FronteggAuthenticator>,
    /// Password recorded by the password, SASL/SCRAM and OIDC auth builders.
    external_login_password_mz_system: Option<Password>,
    /// SQL and HTTP listener definitions, keyed by listener name.
    listeners_config: ListenersConfig,
    unsafe_mode: bool,
    workers: usize,
    now: NowFn,
    /// Random by default; used as the suffix of the `consensus_{seed}` and
    /// `tsoracle_{seed}` schemas created at serve time.
    seed: u32,
    storage_usage_collection_interval: Duration,
    storage_usage_retention_period: Option<Duration>,
    default_cluster_replica_size: String,
    default_cluster_replication_factor: u32,
    builtin_system_cluster_config: BootstrapBuiltinClusterConfig,
    builtin_catalog_server_cluster_config: BootstrapBuiltinClusterConfig,
    builtin_probe_cluster_config: BootstrapBuiltinClusterConfig,
    builtin_support_cluster_config: BootstrapBuiltinClusterConfig,
    builtin_analytics_cluster_config: BootstrapBuiltinClusterConfig,

    /// Forwarded to the process orchestrator configuration at serve time.
    propagate_crashes: bool,
    enable_tracing: bool,
    // This is currently unrelated to enable_tracing, and is used only to disable orchestrator
    // tracing.
    orchestrator_tracing_cli_args: TracingCliArgs,
    bootstrap_role: Option<String>,
    deploy_generation: u64,
    /// Name/value pairs; the default contains `log_filter = error`.
    system_parameter_defaults: BTreeMap<String, String>,
    internal_console_redirect_url: Option<String>,
    /// Metrics registry to use. When `None`, a fresh registry is created at serve time.
    metrics_registry: Option<MetricsRegistry>,
    /// Defaults to the version reported by `crate::BUILD_INFO`; assigned to the persist
    /// configuration's `build_version` at serve time.
    code_version: semver::Version,
    /// Shared tracing-capture storage, set by `with_capture`.
    capture: Option<SharedStorage>,
    /// Forwarded (as a string) to the process orchestrator configuration at serve time.
    pub environment_id: EnvironmentId,
}
133
134impl Default for TestHarness {
135    fn default() -> TestHarness {
136        TestHarness {
137            data_directory: None,
138            tls: None,
139            frontegg: None,
140            external_login_password_mz_system: None,
141            listeners_config: ListenersConfig {
142                sql: btreemap![
143                    "external".to_owned() => SqlListenerConfig {
144                        addr: SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), 0),
145                        authenticator_kind: AuthenticatorKind::None,
146                        allowed_roles: AllowedRoles::Normal,
147                        enable_tls: false,
148                    },
149                    "internal".to_owned() => SqlListenerConfig {
150                        addr: SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), 0),
151                        authenticator_kind: AuthenticatorKind::None,
152                        allowed_roles: AllowedRoles::NormalAndInternal,
153                        enable_tls: false,
154                    },
155                ],
156                http: btreemap![
157                    "external".to_owned() => HttpListenerConfig {
158                        base: BaseListenerConfig {
159                            addr: SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), 0),
160                            authenticator_kind: AuthenticatorKind::None,
161                            allowed_roles: AllowedRoles::Normal,
162                            enable_tls: false,
163                        },
164                        routes: HttpRoutesEnabled{
165                            base: true,
166                            webhook: true,
167                            internal: false,
168                            metrics: false,
169                            profiling: false,
170                            mcp_agent: false,
171                            mcp_developer: false,
172                            console_config: true,
173                        },
174                    },
175                    "internal".to_owned() => HttpListenerConfig {
176                        base: BaseListenerConfig {
177                            addr: SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), 0),
178                            authenticator_kind: AuthenticatorKind::None,
179                            allowed_roles: AllowedRoles::NormalAndInternal,
180                            enable_tls: false,
181                        },
182                        routes: HttpRoutesEnabled{
183                            base: true,
184                            webhook: true,
185                            internal: true,
186                            metrics: true,
187                            profiling: true,
188                            mcp_agent: false,
189                            mcp_developer: false,
190                            console_config: true,
191                        },
192                    },
193                ],
194            },
195            unsafe_mode: false,
196            workers: 1,
197            now: SYSTEM_TIME.clone(),
198            seed: rand::random(),
199            storage_usage_collection_interval: Duration::from_secs(3600),
200            storage_usage_retention_period: None,
201            default_cluster_replica_size: "scale=1,workers=1".to_string(),
202            default_cluster_replication_factor: 1,
203            builtin_system_cluster_config: BootstrapBuiltinClusterConfig {
204                size: "scale=1,workers=1".to_string(),
205                replication_factor: SYSTEM_CLUSTER_DEFAULT_REPLICATION_FACTOR,
206            },
207            builtin_catalog_server_cluster_config: BootstrapBuiltinClusterConfig {
208                size: "scale=1,workers=1".to_string(),
209                replication_factor: CATALOG_SERVER_CLUSTER_DEFAULT_REPLICATION_FACTOR,
210            },
211            builtin_probe_cluster_config: BootstrapBuiltinClusterConfig {
212                size: "scale=1,workers=1".to_string(),
213                replication_factor: PROBE_CLUSTER_DEFAULT_REPLICATION_FACTOR,
214            },
215            builtin_support_cluster_config: BootstrapBuiltinClusterConfig {
216                size: "scale=1,workers=1".to_string(),
217                replication_factor: SUPPORT_CLUSTER_DEFAULT_REPLICATION_FACTOR,
218            },
219            builtin_analytics_cluster_config: BootstrapBuiltinClusterConfig {
220                size: "scale=1,workers=1".to_string(),
221                replication_factor: ANALYTICS_CLUSTER_DEFAULT_REPLICATION_FACTOR,
222            },
223            propagate_crashes: false,
224            enable_tracing: false,
225            bootstrap_role: Some("materialize".into()),
226            deploy_generation: 0,
227            // This and startup_log_filter below are both (?) needed to suppress clusterd messages.
228            // If we need those in the future, we might need to change both.
229            system_parameter_defaults: BTreeMap::from([(
230                "log_filter".to_string(),
231                "error".to_string(),
232            )]),
233            internal_console_redirect_url: None,
234            metrics_registry: None,
235            orchestrator_tracing_cli_args: TracingCliArgs {
236                startup_log_filter: CloneableEnvFilter::from_str("error").expect("must parse"),
237                ..Default::default()
238            },
239            code_version: crate::BUILD_INFO.semver_version(),
240            environment_id: EnvironmentId::for_tests(),
241            capture: None,
242        }
243    }
244}
245
246impl TestHarness {
247    /// Starts a test [`TestServer`], panicking if the server could not be started.
248    ///
249    /// For cases when startup might fail, see [`TestHarness::try_start`].
250    pub async fn start(self) -> TestServer {
251        self.try_start().await.expect("Failed to start test Server")
252    }
253
254    /// Like [`TestHarness::start`] but can specify a cert reload trigger.
255    pub async fn start_with_trigger(self, tls_reload_certs: ReloadTrigger) -> TestServer {
256        self.try_start_with_trigger(tls_reload_certs)
257            .await
258            .expect("Failed to start test Server")
259    }
260
261    /// Starts a test [`TestServer`], returning an error if the server could not be started.
262    pub async fn try_start(self) -> Result<TestServer, anyhow::Error> {
263        self.try_start_with_trigger(mz_server_core::cert_reload_never_reload())
264            .await
265    }
266
267    /// Like [`TestHarness::try_start`] but can specify a cert reload trigger.
268    pub async fn try_start_with_trigger(
269        self,
270        tls_reload_certs: ReloadTrigger,
271    ) -> Result<TestServer, anyhow::Error> {
272        let listeners = Listeners::new(&self).await?;
273        listeners.serve_with_trigger(self, tls_reload_certs).await
274    }
275
276    /// Starts a runtime and returns a [`TestServerWithRuntime`].
277    pub fn start_blocking(self) -> TestServerWithRuntime {
278        let runtime = tokio::runtime::Builder::new_multi_thread()
279            .enable_all()
280            .thread_stack_size(mz_ore::stack::STACK_SIZE)
281            .build()
282            .expect("failed to spawn runtime for test");
283        let runtime = Arc::new(runtime);
284        let server = runtime.block_on(self.start());
285        TestServerWithRuntime { runtime, server }
286    }
287
288    pub fn data_directory(mut self, data_directory: impl Into<PathBuf>) -> Self {
289        self.data_directory = Some(data_directory.into());
290        self
291    }
292
293    pub fn with_tls(mut self, cert_path: impl Into<PathBuf>, key_path: impl Into<PathBuf>) -> Self {
294        self.tls = Some(TlsCertConfig {
295            cert: cert_path.into(),
296            key: key_path.into(),
297        });
298        for (_, listener) in &mut self.listeners_config.sql {
299            listener.enable_tls = true;
300        }
301        for (_, listener) in &mut self.listeners_config.http {
302            listener.base.enable_tls = true;
303        }
304        self
305    }
306
307    pub fn unsafe_mode(mut self) -> Self {
308        self.unsafe_mode = true;
309        self
310    }
311
312    pub fn workers(mut self, workers: usize) -> Self {
313        self.workers = workers;
314        self
315    }
316
317    pub fn with_frontegg_auth(mut self, frontegg: &FronteggAuthenticator) -> Self {
318        self.frontegg = Some(frontegg.clone());
319        let enable_tls = self.tls.is_some();
320        self.listeners_config = ListenersConfig {
321            sql: btreemap! {
322                "external".to_owned() => SqlListenerConfig {
323                    addr: SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), 0),
324                    authenticator_kind: AuthenticatorKind::Frontegg,
325                    allowed_roles: AllowedRoles::Normal,
326                    enable_tls,
327                },
328                "internal".to_owned() => SqlListenerConfig {
329                    addr: SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), 0),
330                    authenticator_kind: AuthenticatorKind::None,
331                    allowed_roles: AllowedRoles::NormalAndInternal,
332                    enable_tls: false,
333                },
334            },
335            http: btreemap! {
336                "external".to_owned() => HttpListenerConfig {
337                    base: BaseListenerConfig {
338                        addr: SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), 0),
339                        authenticator_kind: AuthenticatorKind::Frontegg,
340                        allowed_roles: AllowedRoles::Normal,
341                        enable_tls,
342                    },
343                    routes: HttpRoutesEnabled{
344                        base: true,
345                        webhook: true,
346                        internal: false,
347                        metrics: false,
348                        profiling: false,
349                        mcp_agent: false,
350                        mcp_developer: false,
351                        console_config: true,
352                    },
353                },
354                "internal".to_owned() => HttpListenerConfig {
355                    base: BaseListenerConfig {
356                        addr: SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), 0),
357                        authenticator_kind: AuthenticatorKind::None,
358                        allowed_roles: AllowedRoles::NormalAndInternal,
359                        enable_tls: false,
360                    },
361                    routes: HttpRoutesEnabled{
362                        base: true,
363                        webhook: true,
364                        internal: true,
365                        metrics: true,
366                        profiling: true,
367                        mcp_agent: false,
368                        mcp_developer: false,
369                        console_config: true,
370                    },
371                },
372            },
373        };
374        self
375    }
376
377    pub fn with_oidc_auth(
378        mut self,
379        issuer: Option<String>,
380        authentication_claim: Option<String>,
381        expected_audiences: Option<Vec<String>>,
382        external_login_password_mz_system: Option<Password>,
383    ) -> Self {
384        let enable_tls = self.tls.is_some();
385        self.listeners_config = ListenersConfig {
386            sql: btreemap! {
387                "external".to_owned() => SqlListenerConfig {
388                    addr: SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), 0),
389                    authenticator_kind: AuthenticatorKind::Oidc,
390                    allowed_roles: AllowedRoles::NormalAndInternal,
391                    enable_tls,
392                },
393                "internal".to_owned() => SqlListenerConfig {
394                    addr: SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), 0),
395                    authenticator_kind: AuthenticatorKind::None,
396                    allowed_roles: AllowedRoles::NormalAndInternal,
397                    enable_tls: false,
398                },
399            },
400            http: btreemap! {
401                "external".to_owned() => HttpListenerConfig {
402                    base: BaseListenerConfig {
403                        addr: SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), 0),
404                        authenticator_kind: AuthenticatorKind::Oidc,
405                        allowed_roles: AllowedRoles::NormalAndInternal,
406                        enable_tls,
407                    },
408                    routes: HttpRoutesEnabled{
409                        base: true,
410                        webhook: true,
411                        internal: false,
412                        metrics: false,
413                        profiling: false,
414                        mcp_agent: false,
415                        mcp_developer: false,
416                        console_config: true,
417                    },
418                },
419                "internal".to_owned() => HttpListenerConfig {
420                    base: BaseListenerConfig {
421                        addr: SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), 0),
422                        authenticator_kind: AuthenticatorKind::None,
423                        allowed_roles: AllowedRoles::NormalAndInternal,
424                        enable_tls: false,
425                    },
426                    routes: HttpRoutesEnabled{
427                        base: true,
428                        webhook: true,
429                        internal: true,
430                        metrics: true,
431                        profiling: true,
432                        mcp_agent: false,
433                        mcp_developer: false,
434                        console_config: true,
435                    },
436                },
437            },
438        };
439
440        if let Some(issuer) = issuer {
441            self.system_parameter_defaults
442                .insert("oidc_issuer".to_string(), issuer);
443        }
444
445        if let Some(authentication_claim) = authentication_claim {
446            self.system_parameter_defaults.insert(
447                "oidc_authentication_claim".to_string(),
448                authentication_claim,
449            );
450        }
451
452        if let Some(expected_audiences) = expected_audiences {
453            self.system_parameter_defaults.insert(
454                "oidc_audience".to_string(),
455                serde_json::to_string(&expected_audiences).unwrap(),
456            );
457        }
458
459        if let Some(external_login_password_mz_system) = external_login_password_mz_system {
460            self.external_login_password_mz_system = Some(external_login_password_mz_system);
461            self.system_parameter_defaults
462                .insert("enable_password_auth".to_string(), "true".to_string());
463        }
464
465        self
466    }
467
468    pub fn with_password_auth(mut self, mz_system_password: Password) -> Self {
469        self.external_login_password_mz_system = Some(mz_system_password);
470        let enable_tls = self.tls.is_some();
471        self.listeners_config = ListenersConfig {
472            sql: btreemap! {
473                "external".to_owned() => SqlListenerConfig {
474                    addr: SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), 0),
475                    authenticator_kind: AuthenticatorKind::Password,
476                    allowed_roles: AllowedRoles::NormalAndInternal,
477                    enable_tls,
478                },
479            },
480            http: btreemap! {
481                "external".to_owned() => HttpListenerConfig {
482                    base: BaseListenerConfig {
483                        addr: SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), 0),
484                        authenticator_kind: AuthenticatorKind::Password,
485                        allowed_roles: AllowedRoles::NormalAndInternal,
486                        enable_tls,
487                    },
488                    routes: HttpRoutesEnabled{
489                        base: true,
490                        webhook: true,
491                        internal: true,
492                        metrics: false,
493                        profiling: true,
494                        mcp_agent: false,
495                        mcp_developer: false,
496                        console_config: true,
497                    },
498                },
499                "metrics".to_owned() => HttpListenerConfig {
500                    base: BaseListenerConfig {
501                        addr: SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), 0),
502                        authenticator_kind: AuthenticatorKind::None,
503                        allowed_roles: AllowedRoles::NormalAndInternal,
504                        enable_tls: false,
505                    },
506                    routes: HttpRoutesEnabled{
507                        base: false,
508                        webhook: false,
509                        internal: false,
510                        metrics: true,
511                        profiling: false,
512                        mcp_agent: false,
513                        mcp_developer: false,
514                        console_config: true,
515                    },
516                },
517            },
518        };
519        self
520    }
521
522    pub fn with_sasl_scram_auth(mut self, mz_system_password: Password) -> Self {
523        self.external_login_password_mz_system = Some(mz_system_password);
524        let enable_tls = self.tls.is_some();
525        self.listeners_config = ListenersConfig {
526            sql: btreemap! {
527                "external".to_owned() => SqlListenerConfig {
528                    addr: SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), 0),
529                    authenticator_kind: AuthenticatorKind::Sasl,
530                    allowed_roles: AllowedRoles::NormalAndInternal,
531                    enable_tls,
532                },
533            },
534            http: btreemap! {
535                "external".to_owned() => HttpListenerConfig {
536                    base: BaseListenerConfig {
537                        addr: SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), 0),
538                        authenticator_kind: AuthenticatorKind::Password,
539                        allowed_roles: AllowedRoles::NormalAndInternal,
540                        enable_tls,
541                    },
542                    routes: HttpRoutesEnabled{
543                        base: true,
544                        webhook: true,
545                        internal: true,
546                        metrics: false,
547                        profiling: true,
548                        mcp_agent: false,
549                        mcp_developer: false,
550                        console_config: true,
551                    },
552                },
553                "metrics".to_owned() => HttpListenerConfig {
554                    base: BaseListenerConfig {
555                        addr: SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), 0),
556                        authenticator_kind: AuthenticatorKind::None,
557                        allowed_roles: AllowedRoles::NormalAndInternal,
558                        enable_tls: false,
559                    },
560                    routes: HttpRoutesEnabled{
561                        base: false,
562                        webhook: false,
563                        internal: false,
564                        metrics: true,
565                        profiling: false,
566                        mcp_agent: false,
567                        mcp_developer: false,
568                        console_config: true,
569                    },
570                },
571            },
572        };
573        self
574    }
575
576    pub fn with_now(mut self, now: NowFn) -> Self {
577        self.now = now;
578        self
579    }
580
581    pub fn with_storage_usage_collection_interval(
582        mut self,
583        storage_usage_collection_interval: Duration,
584    ) -> Self {
585        self.storage_usage_collection_interval = storage_usage_collection_interval;
586        self
587    }
588
589    pub fn with_storage_usage_retention_period(
590        mut self,
591        storage_usage_retention_period: Duration,
592    ) -> Self {
593        self.storage_usage_retention_period = Some(storage_usage_retention_period);
594        self
595    }
596
597    pub fn with_default_cluster_replica_size(
598        mut self,
599        default_cluster_replica_size: String,
600    ) -> Self {
601        self.default_cluster_replica_size = default_cluster_replica_size;
602        self
603    }
604
605    pub fn with_builtin_system_cluster_replica_size(
606        mut self,
607        builtin_system_cluster_replica_size: String,
608    ) -> Self {
609        self.builtin_system_cluster_config.size = builtin_system_cluster_replica_size;
610        self
611    }
612
613    pub fn with_builtin_system_cluster_replication_factor(
614        mut self,
615        builtin_system_cluster_replication_factor: u32,
616    ) -> Self {
617        self.builtin_system_cluster_config.replication_factor =
618            builtin_system_cluster_replication_factor;
619        self
620    }
621
622    pub fn with_builtin_catalog_server_cluster_replica_size(
623        mut self,
624        builtin_catalog_server_cluster_replica_size: String,
625    ) -> Self {
626        self.builtin_catalog_server_cluster_config.size =
627            builtin_catalog_server_cluster_replica_size;
628        self
629    }
630
631    pub fn with_propagate_crashes(mut self, propagate_crashes: bool) -> Self {
632        self.propagate_crashes = propagate_crashes;
633        self
634    }
635
636    pub fn with_enable_tracing(mut self, enable_tracing: bool) -> Self {
637        self.enable_tracing = enable_tracing;
638        self
639    }
640
641    pub fn with_bootstrap_role(mut self, bootstrap_role: Option<String>) -> Self {
642        self.bootstrap_role = bootstrap_role;
643        self
644    }
645
646    pub fn with_deploy_generation(mut self, deploy_generation: u64) -> Self {
647        self.deploy_generation = deploy_generation;
648        self
649    }
650
651    pub fn with_system_parameter_default(mut self, param: String, value: String) -> Self {
652        self.system_parameter_defaults.insert(param, value);
653        self
654    }
655
656    pub fn with_mcp_routes(mut self, agent: bool, developer: bool) -> Self {
657        for config in self.listeners_config.http.values_mut() {
658            config.routes.mcp_agent = agent;
659            config.routes.mcp_developer = developer;
660        }
661        self
662    }
663
664    pub fn with_internal_console_redirect_url(
665        mut self,
666        internal_console_redirect_url: Option<String>,
667    ) -> Self {
668        self.internal_console_redirect_url = internal_console_redirect_url;
669        self
670    }
671
672    pub fn with_metrics_registry(mut self, registry: MetricsRegistry) -> Self {
673        self.metrics_registry = Some(registry);
674        self
675    }
676
677    pub fn with_code_version(mut self, version: semver::Version) -> Self {
678        self.code_version = version;
679        self
680    }
681
682    pub fn with_capture(mut self, storage: SharedStorage) -> Self {
683        self.capture = Some(storage);
684        self
685    }
686}
687
/// Test-side wrapper around [`crate::Listeners`], created from a [`TestHarness`] by
/// [`Listeners::new`] and consumed by the `serve*` methods to produce a [`TestServer`].
pub struct Listeners {
    /// The wrapped listeners, as returned by `crate::Listeners::bind`.
    pub inner: crate::Listeners,
}
691
692impl Listeners {
693    pub async fn new(config: &TestHarness) -> Result<Listeners, anyhow::Error> {
694        let inner = crate::Listeners::bind(config.listeners_config.clone()).await?;
695        Ok(Listeners { inner })
696    }
697
698    pub async fn serve(self, config: TestHarness) -> Result<TestServer, anyhow::Error> {
699        self.serve_with_trigger(config, mz_server_core::cert_reload_never_reload())
700            .await
701    }
702
703    pub async fn serve_with_trigger(
704        self,
705        config: TestHarness,
706        tls_reload_certs: ReloadTrigger,
707    ) -> Result<TestServer, anyhow::Error> {
708        let (data_directory, temp_dir) = match config.data_directory {
709            None => {
710                // If no data directory is provided, we create a temporary
711                // directory. The temporary directory is cleaned up when the
712                // `TempDir` is dropped, so we keep it alive until the `Server` is
713                // dropped.
714                let temp_dir = tempfile::tempdir()?;
715                (temp_dir.path().to_path_buf(), Some(temp_dir))
716            }
717            Some(data_directory) => (data_directory, None),
718        };
719        let scratch_dir = tempfile::tempdir()?;
720        let (consensus_uri, timestamp_oracle_url) = {
721            let seed = config.seed;
722            let cockroach_url = env::var("METADATA_BACKEND_URL")
723                .map_err(|_| anyhow!("METADATA_BACKEND_URL environment variable is not set"))?;
724            let (client, conn) = tokio_postgres::connect(&cockroach_url, NoTls).await?;
725            mz_ore::task::spawn(|| "startup-postgres-conn", async move {
726                if let Err(err) = conn.await {
727                    panic!("connection error: {}", err);
728                };
729            });
730            client
731                .batch_execute(&format!(
732                    "CREATE SCHEMA IF NOT EXISTS consensus_{seed};
733                    CREATE SCHEMA IF NOT EXISTS tsoracle_{seed};"
734                ))
735                .await?;
736            (
737                format!("{cockroach_url}?options=--search_path=consensus_{seed}")
738                    .parse()
739                    .expect("invalid consensus URI"),
740                format!("{cockroach_url}?options=--search_path=tsoracle_{seed}")
741                    .parse()
742                    .expect("invalid timestamp oracle URI"),
743            )
744        };
745        let metrics_registry = config.metrics_registry.unwrap_or_else(MetricsRegistry::new);
746        let orchestrator = ProcessOrchestrator::new(ProcessOrchestratorConfig {
747            image_dir: env::current_exe()?
748                .parent()
749                .unwrap()
750                .parent()
751                .unwrap()
752                .to_path_buf(),
753            suppress_output: false,
754            environment_id: config.environment_id.to_string(),
755            secrets_dir: data_directory.join("secrets"),
756            command_wrapper: vec![],
757            propagate_crashes: config.propagate_crashes,
758            tcp_proxy: None,
759            scratch_directory: scratch_dir.path().to_path_buf(),
760        })
761        .await?;
762        let orchestrator = Arc::new(orchestrator);
763        // Messing with the clock causes persist to expire leases, causing hangs and
764        // panics. Is it possible/desirable to put this back somehow?
765        let persist_now = SYSTEM_TIME.clone();
766        let dyncfgs = mz_dyncfgs::all_dyncfgs();
767
768        let mut updates = ConfigUpdates::default();
769        // Tune down the number of connections to make this all work a little easier
770        // with local postgres.
771        updates.add(&CONSENSUS_CONNECTION_POOL_MAX_SIZE, 1);
772        updates.apply(&dyncfgs);
773
774        let mut persist_cfg = PersistConfig::new(&crate::BUILD_INFO, persist_now.clone(), dyncfgs);
775        persist_cfg.build_version = config.code_version;
776        // Stress persist more by writing rollups frequently
777        persist_cfg.set_rollup_threshold(5);
778
779        let persist_pubsub_server = PersistGrpcPubSubServer::new(&persist_cfg, &metrics_registry);
780        let persist_pubsub_client = persist_pubsub_server.new_same_process_connection();
781        let persist_pubsub_tcp_listener =
782            TcpListener::bind(SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), 0))
783                .await
784                .expect("pubsub addr binding");
785        let persist_pubsub_server_port = persist_pubsub_tcp_listener
786            .local_addr()
787            .expect("pubsub addr has local addr")
788            .port();
789
790        // Spawn the persist pub-sub server.
791        mz_ore::task::spawn(|| "persist_pubsub_server", async move {
792            persist_pubsub_server
793                .serve_with_stream(TcpListenerStream::new(persist_pubsub_tcp_listener))
794                .await
795                .expect("success")
796        });
797        let persist_clients =
798            PersistClientCache::new(persist_cfg, &metrics_registry, |_, _| persist_pubsub_client);
799        let persist_clients = Arc::new(persist_clients);
800
801        let secrets_controller = Arc::clone(&orchestrator);
802        let connection_context = ConnectionContext::for_tests(orchestrator.reader());
803        let orchestrator = Arc::new(TracingOrchestrator::new(
804            orchestrator,
805            config.orchestrator_tracing_cli_args,
806        ));
807        let tracing_handle = if config.enable_tracing {
808            let config = TracingConfig::<fn(&tracing::Metadata) -> sentry_tracing::EventFilter> {
809                service_name: "environmentd",
810                stderr_log: StderrLogConfig {
811                    format: StderrLogFormat::Json,
812                    filter: EnvFilter::default(),
813                },
814                opentelemetry: Some(OpenTelemetryConfig {
815                    endpoint: "http://fake_address_for_testing:8080".to_string(),
816                    headers: http::HeaderMap::new(),
817                    filter: EnvFilter::default().add_directive(Level::DEBUG.into()),
818                    resource: opentelemetry_sdk::resource::Resource::builder().build(),
819                    max_batch_queue_size: 2048,
820                    max_export_batch_size: 512,
821                    max_concurrent_exports: 1,
822                    batch_scheduled_delay: Duration::from_millis(5000),
823                    max_export_timeout: Duration::from_secs(30),
824                }),
825                tokio_console: None,
826                sentry: None,
827                build_version: crate::BUILD_INFO.version,
828                build_sha: crate::BUILD_INFO.sha,
829                registry: metrics_registry.clone(),
830                capture: config.capture,
831            };
832            mz_ore::tracing::configure(config).await?
833        } else {
834            TracingHandle::disabled()
835        };
836        let host_name = format!(
837            "localhost:{}",
838            self.inner.http["external"].handle.local_addr.port()
839        );
840        let catalog_config = CatalogConfig {
841            persist_clients: Arc::clone(&persist_clients),
842            metrics: Arc::new(mz_catalog::durable::Metrics::new(&MetricsRegistry::new())),
843        };
844
845        let inner = self
846            .inner
847            .serve(crate::Config {
848                catalog_config,
849                timestamp_oracle_url: Some(timestamp_oracle_url),
850                controller: ControllerConfig {
851                    build_info: &crate::BUILD_INFO,
852                    orchestrator,
853                    clusterd_image: "clusterd".into(),
854                    init_container_image: None,
855                    deploy_generation: config.deploy_generation,
856                    persist_location: PersistLocation {
857                        blob_uri: format!("file://{}/persist/blob", data_directory.display())
858                            .parse()
859                            .expect("invalid blob URI"),
860                        consensus_uri,
861                    },
862                    persist_clients,
863                    now: config.now.clone(),
864                    metrics_registry: metrics_registry.clone(),
865                    persist_pubsub_url: format!("http://localhost:{}", persist_pubsub_server_port),
866                    secrets_args: mz_service::secrets::SecretsReaderCliArgs {
867                        secrets_reader: mz_service::secrets::SecretsControllerKind::LocalFile,
868                        secrets_reader_local_file_dir: Some(data_directory.join("secrets")),
869                        secrets_reader_kubernetes_context: None,
870                        secrets_reader_aws_prefix: None,
871                        secrets_reader_name_prefix: None,
872                    },
873                    connection_context,
874                    replica_http_locator: Default::default(),
875                },
876                secrets_controller,
877                cloud_resource_controller: None,
878                tls: config.tls,
879                frontegg: config.frontegg,
880                unsafe_mode: config.unsafe_mode,
881                all_features: false,
882                metrics_registry: metrics_registry.clone(),
883                now: config.now,
884                environment_id: config.environment_id,
885                cors_allowed_origin: AllowOrigin::list([]),
886                cors_allowed_origin_list: Vec::new(),
887                cluster_replica_sizes: ClusterReplicaSizeMap::for_tests(),
888                bootstrap_default_cluster_replica_size: config.default_cluster_replica_size,
889                bootstrap_default_cluster_replication_factor: config
890                    .default_cluster_replication_factor,
891                bootstrap_builtin_system_cluster_config: config.builtin_system_cluster_config,
892                bootstrap_builtin_catalog_server_cluster_config: config
893                    .builtin_catalog_server_cluster_config,
894                bootstrap_builtin_probe_cluster_config: config.builtin_probe_cluster_config,
895                bootstrap_builtin_support_cluster_config: config.builtin_support_cluster_config,
896                bootstrap_builtin_analytics_cluster_config: config.builtin_analytics_cluster_config,
897                system_parameter_defaults: config.system_parameter_defaults,
898                availability_zones: Default::default(),
899                tracing_handle,
900                storage_usage_collection_interval: config.storage_usage_collection_interval,
901                storage_usage_retention_period: config.storage_usage_retention_period,
902                segment_api_key: None,
903                segment_client_side: false,
904                test_only_dummy_segment_client: false,
905                egress_addresses: vec![],
906                aws_account_id: None,
907                aws_privatelink_availability_zones: None,
908                launchdarkly_sdk_key: None,
909                launchdarkly_key_map: Default::default(),
910                config_sync_file_path: None,
911                config_sync_timeout: Duration::from_secs(30),
912                config_sync_loop_interval: None,
913                bootstrap_role: config.bootstrap_role,
914                http_host_name: Some(host_name),
915                internal_console_redirect_url: config.internal_console_redirect_url,
916                tls_reload_certs,
917                helm_chart_version: None,
918                license_key: ValidatedLicenseKey::for_tests(),
919                external_login_password_mz_system: config.external_login_password_mz_system,
920                force_builtin_schema_migration: None,
921            })
922            .await?;
923
924        Ok(TestServer {
925            inner,
926            metrics_registry,
927            _temp_dir: temp_dir,
928            _scratch_dir: scratch_dir,
929        })
930    }
931}
932
933/// A running instance of `environmentd`.
934pub struct TestServer {
935    pub inner: crate::Server,
936    pub metrics_registry: MetricsRegistry,
937    /// The `TempDir`s are saved to prevent them from being dropped, and thus cleaned up too early.
938    _temp_dir: Option<TempDir>,
939    _scratch_dir: TempDir,
940}
941
942impl TestServer {
943    pub fn connect(&self) -> ConnectBuilder<'_, postgres::NoTls, NoHandle> {
944        ConnectBuilder::new(self).no_tls()
945    }
946
947    pub async fn enable_feature_flags(&self, flags: &[&'static str]) {
948        let internal_client = self.connect().internal().await.unwrap();
949
950        for flag in flags {
951            internal_client
952                .batch_execute(&format!("ALTER SYSTEM SET {} = true;", flag))
953                .await
954                .unwrap();
955        }
956    }
957
958    pub async fn disable_feature_flags(&self, flags: &[&'static str]) {
959        let internal_client = self.connect().internal().await.unwrap();
960
961        for flag in flags {
962            internal_client
963                .batch_execute(&format!("ALTER SYSTEM SET {} = false;", flag))
964                .await
965                .unwrap();
966        }
967    }
968
969    pub fn ws_addr(&self) -> Uri {
970        format!(
971            "ws://{}/api/experimental/sql",
972            self.inner.http_listener_handles["external"].local_addr
973        )
974        .parse()
975        .unwrap()
976    }
977
978    pub fn internal_ws_addr(&self) -> Uri {
979        format!(
980            "ws://{}/api/experimental/sql",
981            self.inner.http_listener_handles["internal"].local_addr
982        )
983        .parse()
984        .unwrap()
985    }
986
987    pub fn http_local_addr(&self) -> SocketAddr {
988        self.inner.http_listener_handles["external"].local_addr
989    }
990
991    pub fn internal_http_local_addr(&self) -> SocketAddr {
992        self.inner.http_listener_handles["internal"].local_addr
993    }
994
995    pub fn sql_local_addr(&self) -> SocketAddr {
996        self.inner.sql_listener_handles["external"].local_addr
997    }
998
999    pub fn internal_sql_local_addr(&self) -> SocketAddr {
1000        self.inner.sql_listener_handles["internal"].local_addr
1001    }
1002}
1003
1004/// A builder struct to configure a pgwire connection to a running [`TestServer`].
1005///
1006/// You can create this struct, and thus open a pgwire connection, using [`TestServer::connect`].
1007pub struct ConnectBuilder<'s, T, H> {
1008    /// A running `environmentd` test server.
1009    server: &'s TestServer,
1010
1011    /// Postgres configuration for connecting to the test server.
1012    pg_config: tokio_postgres::Config,
1013    /// Port to use when connecting to the test server.
1014    port: u16,
1015    /// Tls settings to use.
1016    tls: T,
1017
1018    /// Callback that gets invoked for every notice we receive.
1019    notice_callback: Option<Box<dyn FnMut(tokio_postgres::error::DbError) + Send + 'static>>,
1020
1021    /// Type variable for whether or not we include the handle for the spawned [`tokio::task`].
1022    _with_handle: H,
1023}
1024
1025impl<'s> ConnectBuilder<'s, (), NoHandle> {
1026    fn new(server: &'s TestServer) -> Self {
1027        let mut pg_config = tokio_postgres::Config::new();
1028        pg_config
1029            .host(&Ipv4Addr::LOCALHOST.to_string())
1030            .user("materialize")
1031            .options("--welcome_message=off")
1032            .application_name("environmentd_test_framework");
1033
1034        ConnectBuilder {
1035            server,
1036            pg_config,
1037            port: server.sql_local_addr().port(),
1038            tls: (),
1039            notice_callback: None,
1040            _with_handle: NoHandle,
1041        }
1042    }
1043}
1044
1045impl<'s, T, H> ConnectBuilder<'s, T, H> {
1046    /// Create a pgwire connection without using TLS.
1047    ///
1048    /// Note: this is the default for all connections.
1049    pub fn no_tls(self) -> ConnectBuilder<'s, postgres::NoTls, H> {
1050        ConnectBuilder {
1051            server: self.server,
1052            pg_config: self.pg_config,
1053            port: self.port,
1054            tls: postgres::NoTls,
1055            notice_callback: self.notice_callback,
1056            _with_handle: self._with_handle,
1057        }
1058    }
1059
1060    /// Create a pgwire connection with TLS.
1061    pub fn with_tls<Tls>(self, tls: Tls) -> ConnectBuilder<'s, Tls, H>
1062    where
1063        Tls: MakeTlsConnect<Socket> + Send + 'static,
1064        Tls::TlsConnect: Send,
1065        Tls::Stream: Send,
1066        <Tls::TlsConnect as TlsConnect<Socket>>::Future: Send,
1067    {
1068        ConnectBuilder {
1069            server: self.server,
1070            pg_config: self.pg_config,
1071            port: self.port,
1072            tls,
1073            notice_callback: self.notice_callback,
1074            _with_handle: self._with_handle,
1075        }
1076    }
1077
1078    /// Create a [`ConnectBuilder`] using the provided [`tokio_postgres::Config`].
1079    pub fn with_config(mut self, pg_config: tokio_postgres::Config) -> Self {
1080        self.pg_config = pg_config;
1081        self
1082    }
1083
1084    /// Set the [`SslMode`] to be used with the resulting connection.
1085    pub fn ssl_mode(mut self, mode: SslMode) -> Self {
1086        self.pg_config.ssl_mode(mode);
1087        self
1088    }
1089
1090    /// Set the user for the pgwire connection.
1091    pub fn user(mut self, user: &str) -> Self {
1092        self.pg_config.user(user);
1093        self
1094    }
1095
1096    /// Set the password for the pgwire connection.
1097    pub fn password(mut self, password: &str) -> Self {
1098        self.pg_config.password(password);
1099        self
1100    }
1101
1102    /// Set the application name for the pgwire connection.
1103    pub fn application_name(mut self, application_name: &str) -> Self {
1104        self.pg_config.application_name(application_name);
1105        self
1106    }
1107
1108    /// Set the database name for the pgwire connection.
1109    pub fn dbname(mut self, dbname: &str) -> Self {
1110        self.pg_config.dbname(dbname);
1111        self
1112    }
1113
1114    /// Set the options for the pgwire connection.
1115    pub fn options(mut self, options: &str) -> Self {
1116        self.pg_config.options(options);
1117        self
1118    }
1119
1120    /// Configures this [`ConnectBuilder`] to connect to the __internal__ SQL port of the running
1121    /// [`TestServer`].
1122    ///
1123    /// For example, this will change the port we connect to, and the user we connect as.
1124    pub fn internal(mut self) -> Self {
1125        self.port = self.server.internal_sql_local_addr().port();
1126        self.pg_config.user(mz_sql::session::user::SYSTEM_USER_NAME);
1127        self
1128    }
1129
1130    /// Sets a callback for any database notices that are received from the [`TestServer`].
1131    pub fn notice_callback(self, callback: impl FnMut(DbError) + Send + 'static) -> Self {
1132        ConnectBuilder {
1133            notice_callback: Some(Box::new(callback)),
1134            ..self
1135        }
1136    }
1137
1138    /// Configures this [`ConnectBuilder`] to return the [`mz_ore::task::JoinHandle`] that is
1139    /// polling the underlying postgres connection, associated with the returned client.
1140    pub fn with_handle(self) -> ConnectBuilder<'s, T, WithHandle> {
1141        ConnectBuilder {
1142            server: self.server,
1143            pg_config: self.pg_config,
1144            port: self.port,
1145            tls: self.tls,
1146            notice_callback: self.notice_callback,
1147            _with_handle: WithHandle,
1148        }
1149    }
1150
1151    /// Returns the [`tokio_postgres::Config`] that will be used to connect.
1152    pub fn as_pg_config(&self) -> &tokio_postgres::Config {
1153        &self.pg_config
1154    }
1155}
1156
1157/// This trait enables us to either include or omit the [`mz_ore::task::JoinHandle`] in the result
1158/// of a client connection.
1159pub trait IncludeHandle: Send {
1160    type Output;
1161    fn transform_result(
1162        client: tokio_postgres::Client,
1163        handle: mz_ore::task::JoinHandle<()>,
1164    ) -> Self::Output;
1165}
1166
1167/// Type parameter that denotes we __will not__ return the [`mz_ore::task::JoinHandle`] in the
1168/// result of a [`ConnectBuilder`].
1169pub struct NoHandle;
1170impl IncludeHandle for NoHandle {
1171    type Output = tokio_postgres::Client;
1172    fn transform_result(
1173        client: tokio_postgres::Client,
1174        _handle: mz_ore::task::JoinHandle<()>,
1175    ) -> Self::Output {
1176        client
1177    }
1178}
1179
1180/// Type parameter that denotes we __will__ return the [`mz_ore::task::JoinHandle`] in the result of
1181/// a [`ConnectBuilder`].
1182pub struct WithHandle;
1183impl IncludeHandle for WithHandle {
1184    type Output = (tokio_postgres::Client, mz_ore::task::JoinHandle<()>);
1185    fn transform_result(
1186        client: tokio_postgres::Client,
1187        handle: mz_ore::task::JoinHandle<()>,
1188    ) -> Self::Output {
1189        (client, handle)
1190    }
1191}
1192
1193impl<'s, T, H> IntoFuture for ConnectBuilder<'s, T, H>
1194where
1195    T: MakeTlsConnect<Socket> + Send + 'static,
1196    T::TlsConnect: Send,
1197    T::Stream: Send,
1198    <T::TlsConnect as TlsConnect<Socket>>::Future: Send,
1199    H: IncludeHandle,
1200{
1201    type Output = Result<H::Output, postgres::Error>;
1202    type IntoFuture = BoxFuture<'static, Self::Output>;
1203
1204    fn into_future(mut self) -> Self::IntoFuture {
1205        Box::pin(async move {
1206            assert!(
1207                self.pg_config.get_ports().is_empty(),
1208                "specifying multiple ports is not supported"
1209            );
1210            self.pg_config.port(self.port);
1211
1212            let (client, mut conn) = self.pg_config.connect(self.tls).await?;
1213            let mut notice_callback = self.notice_callback.take();
1214
1215            let handle = task::spawn(|| "connect", async move {
1216                while let Some(msg) = std::future::poll_fn(|cx| conn.poll_message(cx)).await {
1217                    match msg {
1218                        Ok(AsyncMessage::Notice(notice)) => {
1219                            if let Some(callback) = notice_callback.as_mut() {
1220                                callback(notice);
1221                            }
1222                        }
1223                        Ok(msg) => {
1224                            tracing::debug!(?msg, "Dropping message from database");
1225                        }
1226                        Err(e) => {
1227                            // tokio_postgres::Connection docs say:
1228                            // > Return values of None or Some(Err(_)) are “terminal”; callers
1229                            // > should not invoke this method again after receiving one of those
1230                            // > values.
1231                            tracing::info!("connection error: {e}");
1232                            break;
1233                        }
1234                    }
1235                }
1236                tracing::info!("connection closed");
1237            });
1238
1239            let output = H::transform_result(client, handle);
1240            Ok(output)
1241        })
1242    }
1243}
1244
1245/// A running instance of `environmentd`, that exposes blocking/synchronous test helpers.
1246///
1247/// Note: Ideally you should use a [`TestServer`] which relies on an external runtime, e.g. the
1248/// [`tokio::test`] macro. This struct exists so we can incrementally migrate our existing tests.
1249pub struct TestServerWithRuntime {
1250    server: TestServer,
1251    runtime: Arc<Runtime>,
1252}
1253
1254impl TestServerWithRuntime {
1255    /// Returns the [`Runtime`] owned by this [`TestServerWithRuntime`].
1256    ///
1257    /// Can be used to spawn async tasks.
1258    pub fn runtime(&self) -> &Arc<Runtime> {
1259        &self.runtime
1260    }
1261
1262    /// Returns a referece to the inner running `environmentd` [`crate::Server`]`.
1263    pub fn inner(&self) -> &crate::Server {
1264        &self.server.inner
1265    }
1266
1267    /// Connect to the __public__ SQL port of the running `environmentd` server.
1268    pub fn connect<T>(&self, tls: T) -> Result<postgres::Client, postgres::Error>
1269    where
1270        T: MakeTlsConnect<Socket> + Send + 'static,
1271        T::TlsConnect: Send,
1272        T::Stream: Send,
1273        <T::TlsConnect as TlsConnect<Socket>>::Future: Send,
1274    {
1275        self.pg_config().connect(tls)
1276    }
1277
1278    /// Connect to the __internal__ SQL port of the running `environmentd` server.
1279    pub fn connect_internal<T>(&self, tls: T) -> Result<postgres::Client, anyhow::Error>
1280    where
1281        T: MakeTlsConnect<Socket> + Send + 'static,
1282        T::TlsConnect: Send,
1283        T::Stream: Send,
1284        <T::TlsConnect as TlsConnect<Socket>>::Future: Send,
1285    {
1286        Ok(self.pg_config_internal().connect(tls)?)
1287    }
1288
1289    /// Enable LaunchDarkly feature flags.
1290    pub fn enable_feature_flags(&self, flags: &[&'static str]) {
1291        let mut internal_client = self.connect_internal(postgres::NoTls).unwrap();
1292
1293        for flag in flags {
1294            internal_client
1295                .batch_execute(&format!("ALTER SYSTEM SET {} = true;", flag))
1296                .unwrap();
1297        }
1298    }
1299
1300    /// Disable LaunchDarkly feature flags.
1301    pub fn disable_feature_flags(&self, flags: &[&'static str]) {
1302        let mut internal_client = self.connect_internal(postgres::NoTls).unwrap();
1303
1304        for flag in flags {
1305            internal_client
1306                .batch_execute(&format!("ALTER SYSTEM SET {} = false;", flag))
1307                .unwrap();
1308        }
1309    }
1310
1311    /// Return a [`postgres::Config`] for connecting to the __public__ SQL port of the running
1312    /// `environmentd` server.
1313    pub fn pg_config(&self) -> postgres::Config {
1314        let local_addr = self.server.sql_local_addr();
1315        let mut config = postgres::Config::new();
1316        config
1317            .host(&Ipv4Addr::LOCALHOST.to_string())
1318            .port(local_addr.port())
1319            .user("materialize")
1320            .options("--welcome_message=off");
1321        config
1322    }
1323
1324    /// Return a [`postgres::Config`] for connecting to the __internal__ SQL port of the running
1325    /// `environmentd` server.
1326    pub fn pg_config_internal(&self) -> postgres::Config {
1327        let local_addr = self.server.internal_sql_local_addr();
1328        let mut config = postgres::Config::new();
1329        config
1330            .host(&Ipv4Addr::LOCALHOST.to_string())
1331            .port(local_addr.port())
1332            .user("mz_system")
1333            .options("--welcome_message=off");
1334        config
1335    }
1336
1337    pub fn ws_addr(&self) -> Uri {
1338        self.server.ws_addr()
1339    }
1340
1341    pub fn internal_ws_addr(&self) -> Uri {
1342        self.server.internal_ws_addr()
1343    }
1344
1345    pub fn http_local_addr(&self) -> SocketAddr {
1346        self.server.http_local_addr()
1347    }
1348
1349    pub fn internal_http_local_addr(&self) -> SocketAddr {
1350        self.server.internal_http_local_addr()
1351    }
1352
1353    pub fn sql_local_addr(&self) -> SocketAddr {
1354        self.server.sql_local_addr()
1355    }
1356
1357    pub fn internal_sql_local_addr(&self) -> SocketAddr {
1358        self.server.internal_sql_local_addr()
1359    }
1360
1361    /// Returns the metrics registry for the test server.
1362    pub fn metrics_registry(&self) -> &MetricsRegistry {
1363        &self.server.metrics_registry
1364    }
1365}
1366
1367#[derive(Debug, Clone, Copy, Eq, PartialEq, Ord, PartialOrd)]
1368pub struct MzTimestamp(pub u64);
1369
1370impl<'a> FromSql<'a> for MzTimestamp {
1371    fn from_sql(ty: &Type, raw: &'a [u8]) -> Result<MzTimestamp, Box<dyn Error + Sync + Send>> {
1372        let n = mz_pgrepr::Numeric::from_sql(ty, raw)?;
1373        Ok(MzTimestamp(u64::try_from(n.0.0)?))
1374    }
1375
1376    fn accepts(ty: &Type) -> bool {
1377        mz_pgrepr::Numeric::accepts(ty)
1378    }
1379}
1380
1381pub trait PostgresErrorExt {
1382    fn unwrap_db_error(self) -> DbError;
1383}
1384
1385impl PostgresErrorExt for postgres::Error {
1386    fn unwrap_db_error(self) -> DbError {
1387        match self.source().and_then(|e| e.downcast_ref::<DbError>()) {
1388            Some(e) => e.clone(),
1389            None => panic!("expected DbError, but got: {:?}", self),
1390        }
1391    }
1392}
1393
1394impl<T, E> PostgresErrorExt for Result<T, E>
1395where
1396    E: PostgresErrorExt,
1397{
1398    fn unwrap_db_error(self) -> DbError {
1399        match self {
1400            Ok(_) => panic!("expected Err(DbError), but got Ok(_)"),
1401            Err(e) => e.unwrap_db_error(),
1402        }
1403    }
1404}
1405
1406/// Group commit will block writes until the current time has advanced. This can make
1407/// performing inserts while using deterministic time difficult. This is a helper
1408/// method to perform writes and advance the current time.
1409pub async fn insert_with_deterministic_timestamps(
1410    table: &'static str,
1411    values: &'static str,
1412    server: &TestServer,
1413    now: Arc<std::sync::Mutex<EpochMillis>>,
1414) -> Result<(), Box<dyn Error>> {
1415    let client_write = server.connect().await?;
1416    let client_read = server.connect().await?;
1417
1418    let mut current_ts = get_explain_timestamp(table, &client_read).await;
1419
1420    let insert_query = format!("INSERT INTO {table} VALUES {values}");
1421
1422    let write_future = client_write.execute(&insert_query, &[]);
1423    let timestamp_interval = tokio::time::interval(Duration::from_millis(1));
1424
1425    let mut write_future = std::pin::pin!(write_future);
1426    let mut timestamp_interval = std::pin::pin!(timestamp_interval);
1427
1428    // Keep increasing `now` until the write has executed succeed. Table advancements may
1429    // have increased the global timestamp by an unknown amount.
1430    loop {
1431        tokio::select! {
1432            _ = (&mut write_future) => return Ok(()),
1433            _ = timestamp_interval.tick() => {
1434                current_ts += 1;
1435                *now.lock().expect("lock poisoned") = current_ts;
1436            }
1437        };
1438    }
1439}
1440
1441pub async fn get_explain_timestamp(from_suffix: &str, client: &Client) -> EpochMillis {
1442    try_get_explain_timestamp(from_suffix, client)
1443        .await
1444        .unwrap()
1445}
1446
1447pub async fn try_get_explain_timestamp(
1448    from_suffix: &str,
1449    client: &Client,
1450) -> Result<EpochMillis, anyhow::Error> {
1451    let det = get_explain_timestamp_determination(from_suffix, client).await?;
1452    let ts = det.determination.timestamp_context.timestamp_or_default();
1453    Ok(ts.into())
1454}
1455
1456pub async fn get_explain_timestamp_determination(
1457    from_suffix: &str,
1458    client: &Client,
1459) -> Result<TimestampExplanation, anyhow::Error> {
1460    let row = client
1461        .query_one(
1462            &format!("EXPLAIN TIMESTAMP AS JSON FOR SELECT * FROM {from_suffix}"),
1463            &[],
1464        )
1465        .await?;
1466    let explain: String = row.get(0);
1467    Ok(serde_json::from_str(&explain).unwrap())
1468}
1469
1470/// Helper function to create a Postgres source.
1471///
/// IMPORTANT: Make sure to call the closure that is returned at the end of the test to clean up
/// Postgres state.
///
/// WARNING: If multiple tests use this and the tests are run in parallel, make sure the tests
/// use different Postgres tables.
1477pub async fn create_postgres_source_with_table<'a>(
1478    server: &TestServer,
1479    mz_client: &Client,
1480    table_name: &str,
1481    table_schema: &str,
1482    source_name: &str,
1483) -> (
1484    Client,
1485    impl FnOnce(&'a Client, &'a Client) -> LocalBoxFuture<'a, ()>,
1486) {
1487    server
1488        .enable_feature_flags(&["enable_create_table_from_source"])
1489        .await;
1490
1491    let postgres_url = env::var("POSTGRES_URL")
1492        .map_err(|_| anyhow!("POSTGRES_URL environment variable is not set"))
1493        .unwrap();
1494
1495    let (pg_client, connection) = tokio_postgres::connect(&postgres_url, postgres::NoTls)
1496        .await
1497        .unwrap();
1498
1499    let pg_config: tokio_postgres::Config = postgres_url.parse().unwrap();
1500    let user = pg_config.get_user().unwrap_or("postgres");
1501    let db_name = pg_config.get_dbname().unwrap_or(user);
1502    let ports = pg_config.get_ports();
1503    let port = if ports.is_empty() { 5432 } else { ports[0] };
1504    let hosts = pg_config.get_hosts();
1505    let host = if hosts.is_empty() {
1506        "localhost".to_string()
1507    } else {
1508        match &hosts[0] {
1509            Host::Tcp(host) => host.to_string(),
1510            Host::Unix(host) => host.to_str().unwrap().to_string(),
1511        }
1512    };
1513    let password = pg_config.get_password();
1514
1515    mz_ore::task::spawn(|| "postgres-source-connection", async move {
1516        if let Err(e) = connection.await {
1517            panic!("connection error: {}", e);
1518        }
1519    });
1520
1521    // Create table in Postgres with publication.
1522    let _ = pg_client
1523        .execute(&format!("DROP TABLE IF EXISTS {table_name};"), &[])
1524        .await
1525        .unwrap();
1526    let _ = pg_client
1527        .execute(&format!("DROP PUBLICATION IF EXISTS {source_name};"), &[])
1528        .await
1529        .unwrap();
1530    let _ = pg_client
1531        .execute(&format!("CREATE TABLE {table_name} {table_schema};"), &[])
1532        .await
1533        .unwrap();
1534    let _ = pg_client
1535        .execute(
1536            &format!("ALTER TABLE {table_name} REPLICA IDENTITY FULL;"),
1537            &[],
1538        )
1539        .await
1540        .unwrap();
1541    let _ = pg_client
1542        .execute(
1543            &format!("CREATE PUBLICATION {source_name} FOR TABLE {table_name};"),
1544            &[],
1545        )
1546        .await
1547        .unwrap();
1548
1549    // Create postgres source in Materialize.
1550    let mut connection_str = format!("HOST '{host}', PORT {port}, USER {user}, DATABASE {db_name}");
1551    if let Some(password) = password {
1552        let password = std::str::from_utf8(password).unwrap();
1553        mz_client
1554            .batch_execute(&format!("CREATE SECRET s AS '{password}'"))
1555            .await
1556            .unwrap();
1557        connection_str = format!("{connection_str}, PASSWORD SECRET s");
1558    }
1559    mz_client
1560        .batch_execute(&format!(
1561            "CREATE CONNECTION pgconn TO POSTGRES ({connection_str})"
1562        ))
1563        .await
1564        .unwrap();
1565    mz_client
1566        .batch_execute(&format!(
1567            "CREATE SOURCE {source_name}
1568            FROM POSTGRES
1569            CONNECTION pgconn
1570            (PUBLICATION '{source_name}')"
1571        ))
1572        .await
1573        .unwrap();
1574    mz_client
1575        .batch_execute(&format!(
1576            "CREATE TABLE {table_name}
1577            FROM SOURCE {source_name}
1578            (REFERENCE {table_name});"
1579        ))
1580        .await
1581        .unwrap();
1582
1583    let table_name = table_name.to_string();
1584    let source_name = source_name.to_string();
1585    (
1586        pg_client,
1587        move |mz_client: &'a Client, pg_client: &'a Client| {
1588            let f: Pin<Box<dyn Future<Output = ()> + 'a>> = Box::pin(async move {
1589                mz_client
1590                    .batch_execute(&format!("DROP SOURCE {source_name} CASCADE;"))
1591                    .await
1592                    .unwrap();
1593                mz_client
1594                    .batch_execute("DROP CONNECTION pgconn;")
1595                    .await
1596                    .unwrap();
1597
1598                let _ = pg_client
1599                    .execute(&format!("DROP PUBLICATION {source_name};"), &[])
1600                    .await
1601                    .unwrap();
1602                let _ = pg_client
1603                    .execute(&format!("DROP TABLE {table_name};"), &[])
1604                    .await
1605                    .unwrap();
1606            });
1607            f
1608        },
1609    )
1610}
1611
1612pub async fn wait_for_pg_table_population(mz_client: &Client, view_name: &str, source_rows: i64) {
1613    let current_isolation = mz_client
1614        .query_one("SHOW transaction_isolation", &[])
1615        .await
1616        .unwrap()
1617        .get::<_, String>(0);
1618    mz_client
1619        .batch_execute("SET transaction_isolation = SERIALIZABLE")
1620        .await
1621        .unwrap();
1622    Retry::default()
1623        .retry_async(|_| async move {
1624            let rows = mz_client
1625                .query_one(&format!("SELECT COUNT(*) FROM {view_name};"), &[])
1626                .await
1627                .unwrap()
1628                .get::<_, i64>(0);
1629            if rows == source_rows {
1630                Ok(())
1631            } else {
1632                Err(format!(
1633                    "Waiting for {source_rows} row to be ingested. Currently at {rows}."
1634                ))
1635            }
1636        })
1637        .await
1638        .unwrap();
1639    mz_client
1640        .batch_execute(&format!(
1641            "SET transaction_isolation = '{current_isolation}'"
1642        ))
1643        .await
1644        .unwrap();
1645}
1646
1647// Initializes a websocket connection. Returns the init messages before the initial ReadyForQuery.
1648pub fn auth_with_ws(
1649    ws: &mut WebSocket<MaybeTlsStream<TcpStream>>,
1650    mut options: BTreeMap<String, String>,
1651) -> Result<Vec<WebSocketResponse>, anyhow::Error> {
1652    if !options.contains_key("welcome_message") {
1653        options.insert("welcome_message".into(), "off".into());
1654    }
1655    auth_with_ws_impl(
1656        ws,
1657        Message::Text(
1658            serde_json::to_string(&WebSocketAuth::Basic {
1659                user: "materialize".into(),
1660                password: "".into(),
1661                options,
1662            })
1663            .unwrap()
1664            .into(),
1665        ),
1666    )
1667}
1668
1669pub fn auth_with_ws_impl(
1670    ws: &mut WebSocket<MaybeTlsStream<TcpStream>>,
1671    auth_message: Message,
1672) -> Result<Vec<WebSocketResponse>, anyhow::Error> {
1673    ws.send(auth_message)?;
1674
1675    // Wait for initial ready response.
1676    let mut msgs = Vec::new();
1677    loop {
1678        let resp = ws.read()?;
1679        match resp {
1680            Message::Text(msg) => {
1681                let msg: WebSocketResponse = serde_json::from_str(&msg).unwrap();
1682                match msg {
1683                    WebSocketResponse::ReadyForQuery(_) => break,
1684                    msg => {
1685                        msgs.push(msg);
1686                    }
1687                }
1688            }
1689            Message::Ping(_) => continue,
1690            Message::Close(None) => return Err(anyhow!("ws closed after auth")),
1691            Message::Close(Some(close_frame)) => {
1692                return Err(anyhow!("ws closed after auth").context(close_frame));
1693            }
1694            _ => panic!("unexpected response: {:?}", resp),
1695        }
1696    }
1697    Ok(msgs)
1698}
1699
1700pub fn make_header<H: Header>(h: H) -> HeaderMap {
1701    let mut map = HeaderMap::new();
1702    map.typed_insert(h);
1703    map
1704}
1705
1706pub fn make_pg_tls<F>(configure: F) -> MakeTlsConnector
1707where
1708    F: FnOnce(&mut SslConnectorBuilder) -> Result<(), ErrorStack>,
1709{
1710    let mut connector_builder = SslConnector::builder(SslMethod::tls()).unwrap();
1711    // Disable TLS v1.3 because `postgres` and `hyper` produce stabler error
1712    // messages with TLS v1.2.
1713    //
1714    // Briefly, in TLS v1.3, failing to present a client certificate does not
1715    // error during the TLS handshake, as it does in TLS v1.2, but on the first
1716    // attempt to read from the stream. But both `postgres` and `hyper` write a
1717    // bunch of data before attempting to read from the stream. With a failed
1718    // TLS v1.3 connection, sometimes `postgres` and `hyper` succeed in writing
1719    // out this data, and then return a nice error message on the call to read.
1720    // But sometimes the connection is closed before they write out the data,
1721    // and so they report "connection closed" before they ever call read, never
1722    // noticing the underlying SSL error.
1723    //
1724    // It's unclear who's bug this is. Is it on `hyper`/`postgres` to call read
1725    // if writing to the stream fails to see if a TLS error occured? Is it on
1726    // OpenSSL to provide a better API [1]? Is it a protocol issue that ought to
1727    // be corrected in TLS v1.4? We don't want to answer these questions, so we
1728    // just avoid TLS v1.3 for now.
1729    //
1730    // [1]: https://github.com/openssl/openssl/issues/11118
1731    let options = connector_builder.options() | SslOptions::NO_TLSV1_3;
1732    connector_builder.set_options(options);
1733    configure(&mut connector_builder).unwrap();
1734    MakeTlsConnector::new(connector_builder.build())
1735}
1736
1737/// A certificate authority for use in tests.
1738pub struct Ca {
1739    pub dir: TempDir,
1740    pub name: X509Name,
1741    pub cert: X509,
1742    pub pkey: PKey<Private>,
1743}
1744
1745impl Ca {
1746    fn make_ca(name: &str, parent: Option<&Ca>) -> Result<Ca, Box<dyn Error>> {
1747        let dir = tempfile::tempdir()?;
1748        let rsa = Rsa::generate(2048)?;
1749        let pkey = PKey::from_rsa(rsa)?;
1750        let name = {
1751            let mut builder = X509NameBuilder::new()?;
1752            builder.append_entry_by_nid(Nid::COMMONNAME, name)?;
1753            builder.build()
1754        };
1755        let cert = {
1756            let mut builder = X509::builder()?;
1757            builder.set_version(2)?;
1758            builder.set_pubkey(&pkey)?;
1759            builder.set_issuer_name(parent.map(|ca| &ca.name).unwrap_or(&name))?;
1760            builder.set_subject_name(&name)?;
1761            builder.set_not_before(&*Asn1Time::days_from_now(0)?)?;
1762            builder.set_not_after(&*Asn1Time::days_from_now(365)?)?;
1763            builder.append_extension(BasicConstraints::new().critical().ca().build()?)?;
1764            builder.sign(
1765                parent.map(|ca| &ca.pkey).unwrap_or(&pkey),
1766                MessageDigest::sha256(),
1767            )?;
1768            builder.build()
1769        };
1770        fs::write(dir.path().join("ca.crt"), cert.to_pem()?)?;
1771        Ok(Ca {
1772            dir,
1773            name,
1774            cert,
1775            pkey,
1776        })
1777    }
1778
1779    /// Creates a new root certificate authority.
1780    pub fn new_root(name: &str) -> Result<Ca, Box<dyn Error>> {
1781        Ca::make_ca(name, None)
1782    }
1783
1784    /// Returns the path to the CA's certificate.
1785    pub fn ca_cert_path(&self) -> PathBuf {
1786        self.dir.path().join("ca.crt")
1787    }
1788
1789    /// Requests a new intermediate certificate authority.
1790    pub fn request_ca(&self, name: &str) -> Result<Ca, Box<dyn Error>> {
1791        Ca::make_ca(name, Some(self))
1792    }
1793
1794    /// Generates a certificate with the specified Common Name (CN) that is
1795    /// signed by the CA.
1796    ///
1797    /// Returns the paths to the certificate and key.
1798    pub fn request_client_cert(&self, name: &str) -> Result<(PathBuf, PathBuf), Box<dyn Error>> {
1799        self.request_cert(name, iter::empty())
1800    }
1801
1802    /// Like `request_client_cert`, but permits specifying additional IP
1803    /// addresses to attach as Subject Alternate Names.
1804    pub fn request_cert<I>(&self, name: &str, ips: I) -> Result<(PathBuf, PathBuf), Box<dyn Error>>
1805    where
1806        I: IntoIterator<Item = IpAddr>,
1807    {
1808        let rsa = Rsa::generate(2048)?;
1809        let pkey = PKey::from_rsa(rsa)?;
1810        let subject_name = {
1811            let mut builder = X509NameBuilder::new()?;
1812            builder.append_entry_by_nid(Nid::COMMONNAME, name)?;
1813            builder.build()
1814        };
1815        let cert = {
1816            let mut builder = X509::builder()?;
1817            builder.set_version(2)?;
1818            builder.set_pubkey(&pkey)?;
1819            builder.set_issuer_name(self.cert.subject_name())?;
1820            builder.set_subject_name(&subject_name)?;
1821            builder.set_not_before(&*Asn1Time::days_from_now(0)?)?;
1822            builder.set_not_after(&*Asn1Time::days_from_now(365)?)?;
1823            for ip in ips {
1824                builder.append_extension(
1825                    SubjectAlternativeName::new()
1826                        .ip(&ip.to_string())
1827                        .build(&builder.x509v3_context(None, None))?,
1828                )?;
1829            }
1830            builder.sign(&self.pkey, MessageDigest::sha256())?;
1831            builder.build()
1832        };
1833        let cert_path = self.dir.path().join(Path::new(name).with_extension("crt"));
1834        let key_path = self.dir.path().join(Path::new(name).with_extension("key"));
1835        fs::write(&cert_path, cert.to_pem()?)?;
1836        fs::write(&key_path, pkey.private_key_to_pem_pkcs8()?)?;
1837        Ok((cert_path, key_path))
1838    }
1839}