mz_environmentd/
test_util.rs

1// Copyright Materialize, Inc. and contributors. All rights reserved.
2//
3// Use of this software is governed by the Business Source License
4// included in the LICENSE file.
5//
6// As of the Change Date specified in that file, in accordance with
7// the Business Source License, use of this software will be governed
8// by the Apache License, Version 2.0.
9
10use std::collections::BTreeMap;
11use std::error::Error;
12use std::future::IntoFuture;
13use std::net::{IpAddr, Ipv4Addr, SocketAddr, TcpStream};
14use std::path::{Path, PathBuf};
15use std::pin::Pin;
16use std::str::FromStr;
17use std::sync::Arc;
18use std::sync::LazyLock;
19use std::time::Duration;
20use std::{env, fs, iter};
21
22use anyhow::anyhow;
23use futures::Future;
24use futures::future::{BoxFuture, LocalBoxFuture};
25use headers::{Header, HeaderMapExt};
26use http::Uri;
27use hyper::http::header::HeaderMap;
28use maplit::btreemap;
29use mz_adapter::TimestampExplanation;
30use mz_adapter_types::bootstrap_builtin_cluster_config::{
31    ANALYTICS_CLUSTER_DEFAULT_REPLICATION_FACTOR, BootstrapBuiltinClusterConfig,
32    CATALOG_SERVER_CLUSTER_DEFAULT_REPLICATION_FACTOR, PROBE_CLUSTER_DEFAULT_REPLICATION_FACTOR,
33    SUPPORT_CLUSTER_DEFAULT_REPLICATION_FACTOR, SYSTEM_CLUSTER_DEFAULT_REPLICATION_FACTOR,
34};
35
36use mz_auth::password::Password;
37use mz_catalog::config::ClusterReplicaSizeMap;
38use mz_controller::ControllerConfig;
39use mz_dyncfg::ConfigUpdates;
40use mz_license_keys::ValidatedLicenseKey;
41use mz_orchestrator_process::{ProcessOrchestrator, ProcessOrchestratorConfig};
42use mz_orchestrator_tracing::{TracingCliArgs, TracingOrchestrator};
43use mz_ore::metrics::MetricsRegistry;
44use mz_ore::now::{EpochMillis, NowFn, SYSTEM_TIME};
45use mz_ore::retry::Retry;
46use mz_ore::task;
47use mz_ore::tracing::{
48    OpenTelemetryConfig, StderrLogConfig, StderrLogFormat, TracingConfig, TracingGuard,
49    TracingHandle,
50};
51use mz_persist_client::PersistLocation;
52use mz_persist_client::cache::PersistClientCache;
53use mz_persist_client::cfg::{CONSENSUS_CONNECTION_POOL_MAX_SIZE, PersistConfig};
54use mz_persist_client::rpc::PersistGrpcPubSubServer;
55use mz_secrets::SecretsController;
56use mz_server_core::listeners::{
57    AllowedRoles, AuthenticatorKind, BaseListenerConfig, HttpRoutesEnabled,
58};
59use mz_server_core::{ReloadTrigger, TlsCertConfig};
60use mz_sql::catalog::EnvironmentId;
61use mz_storage_types::connections::ConnectionContext;
62use mz_tracing::CloneableEnvFilter;
63use openssl::asn1::Asn1Time;
64use openssl::error::ErrorStack;
65use openssl::hash::MessageDigest;
66use openssl::nid::Nid;
67use openssl::pkey::{PKey, Private};
68use openssl::rsa::Rsa;
69use openssl::ssl::{SslConnector, SslConnectorBuilder, SslMethod, SslOptions};
70use openssl::x509::extension::{BasicConstraints, SubjectAlternativeName};
71use openssl::x509::{X509, X509Name, X509NameBuilder};
72use postgres::error::DbError;
73use postgres::tls::{MakeTlsConnect, TlsConnect};
74use postgres::types::{FromSql, Type};
75use postgres::{NoTls, Socket};
76use postgres_openssl::MakeTlsConnector;
77use tempfile::TempDir;
78use tokio::net::TcpListener;
79use tokio::runtime::Runtime;
80use tokio_postgres::config::{Host, SslMode};
81use tokio_postgres::{AsyncMessage, Client};
82use tokio_stream::wrappers::TcpListenerStream;
83use tower_http::cors::AllowOrigin;
84use tracing::Level;
85use tracing_capture::SharedStorage;
86use tracing_subscriber::EnvFilter;
87use tungstenite::stream::MaybeTlsStream;
88use tungstenite::{Message, WebSocket};
89
90use crate::{
91    CatalogConfig, FronteggAuthenticator, HttpListenerConfig, ListenersConfig, SqlListenerConfig,
92    WebSocketAuth, WebSocketResponse,
93};
94
/// Kafka address string used by tests.
///
/// Read lazily, on first access, from the `KAFKA_ADDRS` environment variable; when that
/// variable cannot be read, `localhost:9092` is used instead.
pub static KAFKA_ADDRS: LazyLock<String> = LazyLock::new(|| match env::var("KAFKA_ADDRS") {
    Ok(addrs) => addrs,
    Err(_) => "localhost:9092".into(),
});
97
98/// Entry point for creating and configuring an `environmentd` test harness.
99#[derive(Clone)]
100pub struct TestHarness {
101    data_directory: Option<PathBuf>,
102    tls: Option<TlsCertConfig>,
103    frontegg: Option<FronteggAuthenticator>,
104    external_login_password_mz_system: Option<Password>,
105    listeners_config: ListenersConfig,
106    unsafe_mode: bool,
107    workers: usize,
108    now: NowFn,
109    seed: u32,
110    storage_usage_collection_interval: Duration,
111    storage_usage_retention_period: Option<Duration>,
112    default_cluster_replica_size: String,
113    default_cluster_replication_factor: u32,
114    builtin_system_cluster_config: BootstrapBuiltinClusterConfig,
115    builtin_catalog_server_cluster_config: BootstrapBuiltinClusterConfig,
116    builtin_probe_cluster_config: BootstrapBuiltinClusterConfig,
117    builtin_support_cluster_config: BootstrapBuiltinClusterConfig,
118    builtin_analytics_cluster_config: BootstrapBuiltinClusterConfig,
119
120    propagate_crashes: bool,
121    enable_tracing: bool,
122    // This is currently unrelated to enable_tracing, and is used only to disable orchestrator
123    // tracing.
124    orchestrator_tracing_cli_args: TracingCliArgs,
125    bootstrap_role: Option<String>,
126    deploy_generation: u64,
127    system_parameter_defaults: BTreeMap<String, String>,
128    internal_console_redirect_url: Option<String>,
129    metrics_registry: Option<MetricsRegistry>,
130    code_version: semver::Version,
131    capture: Option<SharedStorage>,
132    pub environment_id: EnvironmentId,
133}
134
135impl Default for TestHarness {
136    fn default() -> TestHarness {
137        TestHarness {
138            data_directory: None,
139            tls: None,
140            frontegg: None,
141            external_login_password_mz_system: None,
142            listeners_config: ListenersConfig {
143                sql: btreemap![
144                    "external".to_owned() => SqlListenerConfig {
145                        addr: SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), 0),
146                        authenticator_kind: AuthenticatorKind::None,
147                        allowed_roles: AllowedRoles::Normal,
148                        enable_tls: false,
149                    },
150                    "internal".to_owned() => SqlListenerConfig {
151                        addr: SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), 0),
152                        authenticator_kind: AuthenticatorKind::None,
153                        allowed_roles: AllowedRoles::NormalAndInternal,
154                        enable_tls: false,
155                    },
156                ],
157                http: btreemap![
158                    "external".to_owned() => HttpListenerConfig {
159                        base: BaseListenerConfig {
160                            addr: SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), 0),
161                            authenticator_kind: AuthenticatorKind::None,
162                            allowed_roles: AllowedRoles::Normal,
163                            enable_tls: false,
164                        },
165                        routes: HttpRoutesEnabled{
166                            base: true,
167                            webhook: true,
168                            internal: false,
169                            metrics: false,
170                            profiling: false,
171                        },
172                    },
173                    "internal".to_owned() => HttpListenerConfig {
174                        base: BaseListenerConfig {
175                            addr: SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), 0),
176                            authenticator_kind: AuthenticatorKind::None,
177                            allowed_roles: AllowedRoles::NormalAndInternal,
178                            enable_tls: false,
179                        },
180                        routes: HttpRoutesEnabled{
181                            base: true,
182                            webhook: true,
183                            internal: true,
184                            metrics: true,
185                            profiling: true,
186                        },
187                    },
188                ],
189            },
190            unsafe_mode: false,
191            workers: 1,
192            now: SYSTEM_TIME.clone(),
193            seed: rand::random(),
194            storage_usage_collection_interval: Duration::from_secs(3600),
195            storage_usage_retention_period: None,
196            default_cluster_replica_size: "scale=1,workers=1".to_string(),
197            default_cluster_replication_factor: 1,
198            builtin_system_cluster_config: BootstrapBuiltinClusterConfig {
199                size: "scale=1,workers=1".to_string(),
200                replication_factor: SYSTEM_CLUSTER_DEFAULT_REPLICATION_FACTOR,
201            },
202            builtin_catalog_server_cluster_config: BootstrapBuiltinClusterConfig {
203                size: "scale=1,workers=1".to_string(),
204                replication_factor: CATALOG_SERVER_CLUSTER_DEFAULT_REPLICATION_FACTOR,
205            },
206            builtin_probe_cluster_config: BootstrapBuiltinClusterConfig {
207                size: "scale=1,workers=1".to_string(),
208                replication_factor: PROBE_CLUSTER_DEFAULT_REPLICATION_FACTOR,
209            },
210            builtin_support_cluster_config: BootstrapBuiltinClusterConfig {
211                size: "scale=1,workers=1".to_string(),
212                replication_factor: SUPPORT_CLUSTER_DEFAULT_REPLICATION_FACTOR,
213            },
214            builtin_analytics_cluster_config: BootstrapBuiltinClusterConfig {
215                size: "scale=1,workers=1".to_string(),
216                replication_factor: ANALYTICS_CLUSTER_DEFAULT_REPLICATION_FACTOR,
217            },
218            propagate_crashes: false,
219            enable_tracing: false,
220            bootstrap_role: Some("materialize".into()),
221            deploy_generation: 0,
222            // This and startup_log_filter below are both (?) needed to suppress clusterd messages.
223            // If we need those in the future, we might need to change both.
224            system_parameter_defaults: BTreeMap::from([(
225                "log_filter".to_string(),
226                "error".to_string(),
227            )]),
228            internal_console_redirect_url: None,
229            metrics_registry: None,
230            orchestrator_tracing_cli_args: TracingCliArgs {
231                startup_log_filter: CloneableEnvFilter::from_str("error").expect("must parse"),
232                ..Default::default()
233            },
234            code_version: crate::BUILD_INFO.semver_version(),
235            environment_id: EnvironmentId::for_tests(),
236            capture: None,
237        }
238    }
239}
240
241impl TestHarness {
242    /// Starts a test [`TestServer`], panicking if the server could not be started.
243    ///
244    /// For cases when startup might fail, see [`TestHarness::try_start`].
245    pub async fn start(self) -> TestServer {
246        self.try_start().await.expect("Failed to start test Server")
247    }
248
249    /// Like [`TestHarness::start`] but can specify a cert reload trigger.
250    pub async fn start_with_trigger(self, tls_reload_certs: ReloadTrigger) -> TestServer {
251        self.try_start_with_trigger(tls_reload_certs)
252            .await
253            .expect("Failed to start test Server")
254    }
255
256    /// Starts a test [`TestServer`], returning an error if the server could not be started.
257    pub async fn try_start(self) -> Result<TestServer, anyhow::Error> {
258        self.try_start_with_trigger(mz_server_core::cert_reload_never_reload())
259            .await
260    }
261
262    /// Like [`TestHarness::try_start`] but can specify a cert reload trigger.
263    pub async fn try_start_with_trigger(
264        self,
265        tls_reload_certs: ReloadTrigger,
266    ) -> Result<TestServer, anyhow::Error> {
267        let listeners = Listeners::new(&self).await?;
268        listeners.serve_with_trigger(self, tls_reload_certs).await
269    }
270
271    /// Starts a runtime and returns a [`TestServerWithRuntime`].
272    pub fn start_blocking(self) -> TestServerWithRuntime {
273        stacker::grow(mz_ore::stack::STACK_SIZE, || {
274            let runtime = Runtime::new().expect("failed to spawn runtime for test");
275            let runtime = Arc::new(runtime);
276            let server = runtime.block_on(self.start());
277            TestServerWithRuntime { runtime, server }
278        })
279    }
280
281    pub fn data_directory(mut self, data_directory: impl Into<PathBuf>) -> Self {
282        self.data_directory = Some(data_directory.into());
283        self
284    }
285
286    pub fn with_tls(mut self, cert_path: impl Into<PathBuf>, key_path: impl Into<PathBuf>) -> Self {
287        self.tls = Some(TlsCertConfig {
288            cert: cert_path.into(),
289            key: key_path.into(),
290        });
291        for (_, listener) in &mut self.listeners_config.sql {
292            listener.enable_tls = true;
293        }
294        for (_, listener) in &mut self.listeners_config.http {
295            listener.base.enable_tls = true;
296        }
297        self
298    }
299
300    pub fn unsafe_mode(mut self) -> Self {
301        self.unsafe_mode = true;
302        self
303    }
304
305    pub fn workers(mut self, workers: usize) -> Self {
306        self.workers = workers;
307        self
308    }
309
310    pub fn with_frontegg_auth(mut self, frontegg: &FronteggAuthenticator) -> Self {
311        self.frontegg = Some(frontegg.clone());
312        let enable_tls = self.tls.is_some();
313        self.listeners_config = ListenersConfig {
314            sql: btreemap! {
315                "external".to_owned() => SqlListenerConfig {
316                    addr: SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), 0),
317                    authenticator_kind: AuthenticatorKind::Frontegg,
318                    allowed_roles: AllowedRoles::Normal,
319                    enable_tls,
320                },
321                "internal".to_owned() => SqlListenerConfig {
322                    addr: SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), 0),
323                    authenticator_kind: AuthenticatorKind::None,
324                    allowed_roles: AllowedRoles::NormalAndInternal,
325                    enable_tls: false,
326                },
327            },
328            http: btreemap! {
329                "external".to_owned() => HttpListenerConfig {
330                    base: BaseListenerConfig {
331                        addr: SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), 0),
332                        authenticator_kind: AuthenticatorKind::Frontegg,
333                        allowed_roles: AllowedRoles::Normal,
334                        enable_tls,
335                    },
336                    routes: HttpRoutesEnabled{
337                        base: true,
338                        webhook: true,
339                        internal: false,
340                        metrics: false,
341                        profiling: false,
342                    },
343                },
344                "internal".to_owned() => HttpListenerConfig {
345                    base: BaseListenerConfig {
346                        addr: SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), 0),
347                        authenticator_kind: AuthenticatorKind::None,
348                        allowed_roles: AllowedRoles::NormalAndInternal,
349                        enable_tls: false,
350                    },
351                    routes: HttpRoutesEnabled{
352                        base: true,
353                        webhook: true,
354                        internal: true,
355                        metrics: true,
356                        profiling: true,
357                    },
358                },
359            },
360        };
361        self
362    }
363
364    pub fn with_password_auth(mut self, mz_system_password: Password) -> Self {
365        self.external_login_password_mz_system = Some(mz_system_password);
366        let enable_tls = self.tls.is_some();
367        self.listeners_config = ListenersConfig {
368            sql: btreemap! {
369                "external".to_owned() => SqlListenerConfig {
370                    addr: SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), 0),
371                    authenticator_kind: AuthenticatorKind::Password,
372                    allowed_roles: AllowedRoles::NormalAndInternal,
373                    enable_tls,
374                },
375            },
376            http: btreemap! {
377                "external".to_owned() => HttpListenerConfig {
378                    base: BaseListenerConfig {
379                        addr: SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), 0),
380                        authenticator_kind: AuthenticatorKind::Password,
381                        allowed_roles: AllowedRoles::NormalAndInternal,
382                        enable_tls,
383                    },
384                    routes: HttpRoutesEnabled{
385                        base: true,
386                        webhook: true,
387                        internal: true,
388                        metrics: false,
389                        profiling: true,
390                    },
391                },
392                "metrics".to_owned() => HttpListenerConfig {
393                    base: BaseListenerConfig {
394                        addr: SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), 0),
395                        authenticator_kind: AuthenticatorKind::None,
396                        allowed_roles: AllowedRoles::NormalAndInternal,
397                        enable_tls: false,
398                    },
399                    routes: HttpRoutesEnabled{
400                        base: false,
401                        webhook: false,
402                        internal: false,
403                        metrics: true,
404                        profiling: false,
405                    },
406                },
407            },
408        };
409        self
410    }
411
412    pub fn with_sasl_scram_auth(mut self, mz_system_password: Password) -> Self {
413        self.external_login_password_mz_system = Some(mz_system_password);
414        let enable_tls = self.tls.is_some();
415        self.listeners_config = ListenersConfig {
416            sql: btreemap! {
417                "external".to_owned() => SqlListenerConfig {
418                    addr: SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), 0),
419                    authenticator_kind: AuthenticatorKind::Sasl,
420                    allowed_roles: AllowedRoles::NormalAndInternal,
421                    enable_tls,
422                },
423            },
424            http: btreemap! {
425                "external".to_owned() => HttpListenerConfig {
426                    base: BaseListenerConfig {
427                        addr: SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), 0),
428                        authenticator_kind: AuthenticatorKind::Password,
429                        allowed_roles: AllowedRoles::NormalAndInternal,
430                        enable_tls,
431                    },
432                    routes: HttpRoutesEnabled{
433                        base: true,
434                        webhook: true,
435                        internal: true,
436                        metrics: false,
437                        profiling: true,
438                    },
439                },
440                "metrics".to_owned() => HttpListenerConfig {
441                    base: BaseListenerConfig {
442                        addr: SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), 0),
443                        authenticator_kind: AuthenticatorKind::None,
444                        allowed_roles: AllowedRoles::NormalAndInternal,
445                        enable_tls: false,
446                    },
447                    routes: HttpRoutesEnabled{
448                        base: false,
449                        webhook: false,
450                        internal: false,
451                        metrics: true,
452                        profiling: false,
453                    },
454                },
455            },
456        };
457        self
458    }
459
460    pub fn with_now(mut self, now: NowFn) -> Self {
461        self.now = now;
462        self
463    }
464
465    pub fn with_storage_usage_collection_interval(
466        mut self,
467        storage_usage_collection_interval: Duration,
468    ) -> Self {
469        self.storage_usage_collection_interval = storage_usage_collection_interval;
470        self
471    }
472
473    pub fn with_storage_usage_retention_period(
474        mut self,
475        storage_usage_retention_period: Duration,
476    ) -> Self {
477        self.storage_usage_retention_period = Some(storage_usage_retention_period);
478        self
479    }
480
481    pub fn with_default_cluster_replica_size(
482        mut self,
483        default_cluster_replica_size: String,
484    ) -> Self {
485        self.default_cluster_replica_size = default_cluster_replica_size;
486        self
487    }
488
489    pub fn with_builtin_system_cluster_replica_size(
490        mut self,
491        builtin_system_cluster_replica_size: String,
492    ) -> Self {
493        self.builtin_system_cluster_config.size = builtin_system_cluster_replica_size;
494        self
495    }
496
497    pub fn with_builtin_system_cluster_replication_factor(
498        mut self,
499        builtin_system_cluster_replication_factor: u32,
500    ) -> Self {
501        self.builtin_system_cluster_config.replication_factor =
502            builtin_system_cluster_replication_factor;
503        self
504    }
505
506    pub fn with_builtin_catalog_server_cluster_replica_size(
507        mut self,
508        builtin_catalog_server_cluster_replica_size: String,
509    ) -> Self {
510        self.builtin_catalog_server_cluster_config.size =
511            builtin_catalog_server_cluster_replica_size;
512        self
513    }
514
515    pub fn with_propagate_crashes(mut self, propagate_crashes: bool) -> Self {
516        self.propagate_crashes = propagate_crashes;
517        self
518    }
519
520    pub fn with_enable_tracing(mut self, enable_tracing: bool) -> Self {
521        self.enable_tracing = enable_tracing;
522        self
523    }
524
525    pub fn with_bootstrap_role(mut self, bootstrap_role: Option<String>) -> Self {
526        self.bootstrap_role = bootstrap_role;
527        self
528    }
529
530    pub fn with_deploy_generation(mut self, deploy_generation: u64) -> Self {
531        self.deploy_generation = deploy_generation;
532        self
533    }
534
535    pub fn with_system_parameter_default(mut self, param: String, value: String) -> Self {
536        self.system_parameter_defaults.insert(param, value);
537        self
538    }
539
540    pub fn with_internal_console_redirect_url(
541        mut self,
542        internal_console_redirect_url: Option<String>,
543    ) -> Self {
544        self.internal_console_redirect_url = internal_console_redirect_url;
545        self
546    }
547
548    pub fn with_metrics_registry(mut self, registry: MetricsRegistry) -> Self {
549        self.metrics_registry = Some(registry);
550        self
551    }
552
553    pub fn with_code_version(mut self, version: semver::Version) -> Self {
554        self.code_version = version;
555        self
556    }
557
558    pub fn with_capture(mut self, storage: SharedStorage) -> Self {
559        self.capture = Some(storage);
560        self
561    }
562}
563
564pub struct Listeners {
565    pub inner: crate::Listeners,
566}
567
568impl Listeners {
569    pub async fn new(config: &TestHarness) -> Result<Listeners, anyhow::Error> {
570        let inner = crate::Listeners::bind(config.listeners_config.clone()).await?;
571        Ok(Listeners { inner })
572    }
573
574    pub async fn serve(self, config: TestHarness) -> Result<TestServer, anyhow::Error> {
575        self.serve_with_trigger(config, mz_server_core::cert_reload_never_reload())
576            .await
577    }
578
579    pub async fn serve_with_trigger(
580        self,
581        config: TestHarness,
582        tls_reload_certs: ReloadTrigger,
583    ) -> Result<TestServer, anyhow::Error> {
584        let (data_directory, temp_dir) = match config.data_directory {
585            None => {
586                // If no data directory is provided, we create a temporary
587                // directory. The temporary directory is cleaned up when the
588                // `TempDir` is dropped, so we keep it alive until the `Server` is
589                // dropped.
590                let temp_dir = tempfile::tempdir()?;
591                (temp_dir.path().to_path_buf(), Some(temp_dir))
592            }
593            Some(data_directory) => (data_directory, None),
594        };
595        let scratch_dir = tempfile::tempdir()?;
596        let (consensus_uri, timestamp_oracle_url) = {
597            let seed = config.seed;
598            let cockroach_url = env::var("METADATA_BACKEND_URL")
599                .map_err(|_| anyhow!("METADATA_BACKEND_URL environment variable is not set"))?;
600            let (client, conn) = tokio_postgres::connect(&cockroach_url, NoTls).await?;
601            mz_ore::task::spawn(|| "startup-postgres-conn", async move {
602                if let Err(err) = conn.await {
603                    panic!("connection error: {}", err);
604                };
605            });
606            client
607                .batch_execute(&format!(
608                    "CREATE SCHEMA IF NOT EXISTS consensus_{seed};
609                    CREATE SCHEMA IF NOT EXISTS tsoracle_{seed};"
610                ))
611                .await?;
612            (
613                format!("{cockroach_url}?options=--search_path=consensus_{seed}")
614                    .parse()
615                    .expect("invalid consensus URI"),
616                format!("{cockroach_url}?options=--search_path=tsoracle_{seed}")
617                    .parse()
618                    .expect("invalid timestamp oracle URI"),
619            )
620        };
621        let metrics_registry = config.metrics_registry.unwrap_or_else(MetricsRegistry::new);
622        let orchestrator = ProcessOrchestrator::new(ProcessOrchestratorConfig {
623            image_dir: env::current_exe()?
624                .parent()
625                .unwrap()
626                .parent()
627                .unwrap()
628                .to_path_buf(),
629            suppress_output: false,
630            environment_id: config.environment_id.to_string(),
631            secrets_dir: data_directory.join("secrets"),
632            command_wrapper: vec![],
633            propagate_crashes: config.propagate_crashes,
634            tcp_proxy: None,
635            scratch_directory: scratch_dir.path().to_path_buf(),
636        })
637        .await?;
638        let orchestrator = Arc::new(orchestrator);
639        // Messing with the clock causes persist to expire leases, causing hangs and
640        // panics. Is it possible/desirable to put this back somehow?
641        let persist_now = SYSTEM_TIME.clone();
642        let dyncfgs = mz_dyncfgs::all_dyncfgs();
643
644        let mut updates = ConfigUpdates::default();
645        // Tune down the number of connections to make this all work a little easier
646        // with local postgres.
647        updates.add(&CONSENSUS_CONNECTION_POOL_MAX_SIZE, 1);
648        updates.apply(&dyncfgs);
649
650        let mut persist_cfg = PersistConfig::new(&crate::BUILD_INFO, persist_now.clone(), dyncfgs);
651        persist_cfg.build_version = config.code_version;
652        // Stress persist more by writing rollups frequently
653        persist_cfg.set_rollup_threshold(5);
654
655        let persist_pubsub_server = PersistGrpcPubSubServer::new(&persist_cfg, &metrics_registry);
656        let persist_pubsub_client = persist_pubsub_server.new_same_process_connection();
657        let persist_pubsub_tcp_listener =
658            TcpListener::bind(SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), 0))
659                .await
660                .expect("pubsub addr binding");
661        let persist_pubsub_server_port = persist_pubsub_tcp_listener
662            .local_addr()
663            .expect("pubsub addr has local addr")
664            .port();
665
666        // Spawn the persist pub-sub server.
667        mz_ore::task::spawn(|| "persist_pubsub_server", async move {
668            persist_pubsub_server
669                .serve_with_stream(TcpListenerStream::new(persist_pubsub_tcp_listener))
670                .await
671                .expect("success")
672        });
673        let persist_clients =
674            PersistClientCache::new(persist_cfg, &metrics_registry, |_, _| persist_pubsub_client);
675        let persist_clients = Arc::new(persist_clients);
676
677        let secrets_controller = Arc::clone(&orchestrator);
678        let connection_context = ConnectionContext::for_tests(orchestrator.reader());
679        let orchestrator = Arc::new(TracingOrchestrator::new(
680            orchestrator,
681            config.orchestrator_tracing_cli_args,
682        ));
683        let (tracing_handle, tracing_guard) = if config.enable_tracing {
684            let config = TracingConfig::<fn(&tracing::Metadata) -> sentry_tracing::EventFilter> {
685                service_name: "environmentd",
686                stderr_log: StderrLogConfig {
687                    format: StderrLogFormat::Json,
688                    filter: EnvFilter::default(),
689                },
690                opentelemetry: Some(OpenTelemetryConfig {
691                    endpoint: "http://fake_address_for_testing:8080".to_string(),
692                    headers: http::HeaderMap::new(),
693                    filter: EnvFilter::default().add_directive(Level::DEBUG.into()),
694                    resource: opentelemetry_sdk::resource::Resource::default(),
695                    max_batch_queue_size: 2048,
696                    max_export_batch_size: 512,
697                    max_concurrent_exports: 1,
698                    batch_scheduled_delay: Duration::from_millis(5000),
699                    max_export_timeout: Duration::from_secs(30),
700                }),
701                tokio_console: None,
702                sentry: None,
703                build_version: crate::BUILD_INFO.version,
704                build_sha: crate::BUILD_INFO.sha,
705                registry: metrics_registry.clone(),
706                capture: config.capture,
707            };
708            let (tracing_handle, tracing_guard) = mz_ore::tracing::configure(config).await?;
709            (tracing_handle, Some(tracing_guard))
710        } else {
711            (TracingHandle::disabled(), None)
712        };
713        let host_name = format!(
714            "localhost:{}",
715            self.inner.http["external"].handle.local_addr.port()
716        );
717        let catalog_config = CatalogConfig {
718            persist_clients: Arc::clone(&persist_clients),
719            metrics: Arc::new(mz_catalog::durable::Metrics::new(&MetricsRegistry::new())),
720        };
721
722        let inner = self
723            .inner
724            .serve(crate::Config {
725                catalog_config,
726                timestamp_oracle_url: Some(timestamp_oracle_url),
727                controller: ControllerConfig {
728                    build_info: &crate::BUILD_INFO,
729                    orchestrator,
730                    clusterd_image: "clusterd".into(),
731                    init_container_image: None,
732                    deploy_generation: config.deploy_generation,
733                    persist_location: PersistLocation {
734                        blob_uri: format!("file://{}/persist/blob", data_directory.display())
735                            .parse()
736                            .expect("invalid blob URI"),
737                        consensus_uri,
738                    },
739                    persist_clients,
740                    now: config.now.clone(),
741                    metrics_registry: metrics_registry.clone(),
742                    persist_pubsub_url: format!("http://localhost:{}", persist_pubsub_server_port),
743                    secrets_args: mz_service::secrets::SecretsReaderCliArgs {
744                        secrets_reader: mz_service::secrets::SecretsControllerKind::LocalFile,
745                        secrets_reader_local_file_dir: Some(data_directory.join("secrets")),
746                        secrets_reader_kubernetes_context: None,
747                        secrets_reader_aws_prefix: None,
748                        secrets_reader_name_prefix: None,
749                    },
750                    connection_context,
751                },
752                secrets_controller,
753                cloud_resource_controller: None,
754                tls: config.tls,
755                frontegg: config.frontegg,
756                unsafe_mode: config.unsafe_mode,
757                all_features: false,
758                metrics_registry: metrics_registry.clone(),
759                now: config.now,
760                environment_id: config.environment_id,
761                cors_allowed_origin: AllowOrigin::list([]),
762                cluster_replica_sizes: ClusterReplicaSizeMap::for_tests(),
763                bootstrap_default_cluster_replica_size: config.default_cluster_replica_size,
764                bootstrap_default_cluster_replication_factor: config
765                    .default_cluster_replication_factor,
766                bootstrap_builtin_system_cluster_config: config.builtin_system_cluster_config,
767                bootstrap_builtin_catalog_server_cluster_config: config
768                    .builtin_catalog_server_cluster_config,
769                bootstrap_builtin_probe_cluster_config: config.builtin_probe_cluster_config,
770                bootstrap_builtin_support_cluster_config: config.builtin_support_cluster_config,
771                bootstrap_builtin_analytics_cluster_config: config.builtin_analytics_cluster_config,
772                system_parameter_defaults: config.system_parameter_defaults,
773                availability_zones: Default::default(),
774                tracing_handle,
775                storage_usage_collection_interval: config.storage_usage_collection_interval,
776                storage_usage_retention_period: config.storage_usage_retention_period,
777                segment_api_key: None,
778                segment_client_side: false,
779                test_only_dummy_segment_client: false,
780                egress_addresses: vec![],
781                aws_account_id: None,
782                aws_privatelink_availability_zones: None,
783                launchdarkly_sdk_key: None,
784                launchdarkly_key_map: Default::default(),
785                config_sync_file_path: None,
786                config_sync_timeout: Duration::from_secs(30),
787                config_sync_loop_interval: None,
788                bootstrap_role: config.bootstrap_role,
789                http_host_name: Some(host_name),
790                internal_console_redirect_url: config.internal_console_redirect_url,
791                tls_reload_certs,
792                helm_chart_version: None,
793                license_key: ValidatedLicenseKey::for_tests(),
794                external_login_password_mz_system: config.external_login_password_mz_system,
795            })
796            .await?;
797
798        Ok(TestServer {
799            inner,
800            metrics_registry,
801            _temp_dir: temp_dir,
802            _scratch_dir: scratch_dir,
803            _tracing_guard: tracing_guard,
804        })
805    }
806}
807
808/// A running instance of `environmentd`.
809pub struct TestServer {
810    pub inner: crate::Server,
811    pub metrics_registry: MetricsRegistry,
812    /// The `TempDir`s are saved to prevent them from being dropped, and thus cleaned up too early.
813    _temp_dir: Option<TempDir>,
814    _scratch_dir: TempDir,
815    _tracing_guard: Option<TracingGuard>,
816}
817
818impl TestServer {
819    pub fn connect(&self) -> ConnectBuilder<'_, postgres::NoTls, NoHandle> {
820        ConnectBuilder::new(self).no_tls()
821    }
822
823    pub async fn enable_feature_flags(&self, flags: &[&'static str]) {
824        let internal_client = self.connect().internal().await.unwrap();
825
826        for flag in flags {
827            internal_client
828                .batch_execute(&format!("ALTER SYSTEM SET {} = true;", flag))
829                .await
830                .unwrap();
831        }
832    }
833
834    pub fn ws_addr(&self) -> Uri {
835        format!(
836            "ws://{}/api/experimental/sql",
837            self.inner.http_listener_handles["external"].local_addr
838        )
839        .parse()
840        .unwrap()
841    }
842
843    pub fn internal_ws_addr(&self) -> Uri {
844        format!(
845            "ws://{}/api/experimental/sql",
846            self.inner.http_listener_handles["internal"].local_addr
847        )
848        .parse()
849        .unwrap()
850    }
851
852    pub fn http_local_addr(&self) -> SocketAddr {
853        self.inner.http_listener_handles["external"].local_addr
854    }
855
856    pub fn internal_http_local_addr(&self) -> SocketAddr {
857        self.inner.http_listener_handles["internal"].local_addr
858    }
859
860    pub fn sql_local_addr(&self) -> SocketAddr {
861        self.inner.sql_listener_handles["external"].local_addr
862    }
863
864    pub fn internal_sql_local_addr(&self) -> SocketAddr {
865        self.inner.sql_listener_handles["internal"].local_addr
866    }
867}
868
869/// A builder struct to configure a pgwire connection to a running [`TestServer`].
870///
871/// You can create this struct, and thus open a pgwire connection, using [`TestServer::connect`].
872pub struct ConnectBuilder<'s, T, H> {
873    /// A running `environmentd` test server.
874    server: &'s TestServer,
875
876    /// Postgres configuration for connecting to the test server.
877    pg_config: tokio_postgres::Config,
878    /// Port to use when connecting to the test server.
879    port: u16,
880    /// Tls settings to use.
881    tls: T,
882
883    /// Callback that gets invoked for every notice we receive.
884    notice_callback: Option<Box<dyn FnMut(tokio_postgres::error::DbError) + Send + 'static>>,
885
886    /// Type variable for whether or not we include the handle for the spawned [`tokio::task`].
887    _with_handle: H,
888}
889
890impl<'s> ConnectBuilder<'s, (), NoHandle> {
891    fn new(server: &'s TestServer) -> Self {
892        let mut pg_config = tokio_postgres::Config::new();
893        pg_config
894            .host(&Ipv4Addr::LOCALHOST.to_string())
895            .user("materialize")
896            .options("--welcome_message=off")
897            .application_name("environmentd_test_framework");
898
899        ConnectBuilder {
900            server,
901            pg_config,
902            port: server.sql_local_addr().port(),
903            tls: (),
904            notice_callback: None,
905            _with_handle: NoHandle,
906        }
907    }
908}
909
910impl<'s, T, H> ConnectBuilder<'s, T, H> {
911    /// Create a pgwire connection without using TLS.
912    ///
913    /// Note: this is the default for all connections.
914    pub fn no_tls(self) -> ConnectBuilder<'s, postgres::NoTls, H> {
915        ConnectBuilder {
916            server: self.server,
917            pg_config: self.pg_config,
918            port: self.port,
919            tls: postgres::NoTls,
920            notice_callback: self.notice_callback,
921            _with_handle: self._with_handle,
922        }
923    }
924
925    /// Create a pgwire connection with TLS.
926    pub fn with_tls<Tls>(self, tls: Tls) -> ConnectBuilder<'s, Tls, H>
927    where
928        Tls: MakeTlsConnect<Socket> + Send + 'static,
929        Tls::TlsConnect: Send,
930        Tls::Stream: Send,
931        <Tls::TlsConnect as TlsConnect<Socket>>::Future: Send,
932    {
933        ConnectBuilder {
934            server: self.server,
935            pg_config: self.pg_config,
936            port: self.port,
937            tls,
938            notice_callback: self.notice_callback,
939            _with_handle: self._with_handle,
940        }
941    }
942
943    /// Create a [`ConnectBuilder`] using the provided [`tokio_postgres::Config`].
944    pub fn with_config(mut self, pg_config: tokio_postgres::Config) -> Self {
945        self.pg_config = pg_config;
946        self
947    }
948
949    /// Set the [`SslMode`] to be used with the resulting connection.
950    pub fn ssl_mode(mut self, mode: SslMode) -> Self {
951        self.pg_config.ssl_mode(mode);
952        self
953    }
954
955    /// Set the user for the pgwire connection.
956    pub fn user(mut self, user: &str) -> Self {
957        self.pg_config.user(user);
958        self
959    }
960
961    /// Set the password for the pgwire connection.
962    pub fn password(mut self, password: &str) -> Self {
963        self.pg_config.password(password);
964        self
965    }
966
967    /// Set the application name for the pgwire connection.
968    pub fn application_name(mut self, application_name: &str) -> Self {
969        self.pg_config.application_name(application_name);
970        self
971    }
972
973    /// Set the database name for the pgwire connection.
974    pub fn dbname(mut self, dbname: &str) -> Self {
975        self.pg_config.dbname(dbname);
976        self
977    }
978
979    /// Set the options for the pgwire connection.
980    pub fn options(mut self, options: &str) -> Self {
981        self.pg_config.options(options);
982        self
983    }
984
985    /// Configures this [`ConnectBuilder`] to connect to the __internal__ SQL port of the running
986    /// [`TestServer`].
987    ///
988    /// For example, this will change the port we connect to, and the user we connect as.
989    pub fn internal(mut self) -> Self {
990        self.port = self.server.internal_sql_local_addr().port();
991        self.pg_config.user(mz_sql::session::user::SYSTEM_USER_NAME);
992        self
993    }
994
995    /// Sets a callback for any database notices that are received from the [`TestServer`].
996    pub fn notice_callback(self, callback: impl FnMut(DbError) + Send + 'static) -> Self {
997        ConnectBuilder {
998            notice_callback: Some(Box::new(callback)),
999            ..self
1000        }
1001    }
1002
1003    /// Configures this [`ConnectBuilder`] to return the [`mz_ore::task::JoinHandle`] that is
1004    /// polling the underlying postgres connection, associated with the returned client.
1005    pub fn with_handle(self) -> ConnectBuilder<'s, T, WithHandle> {
1006        ConnectBuilder {
1007            server: self.server,
1008            pg_config: self.pg_config,
1009            port: self.port,
1010            tls: self.tls,
1011            notice_callback: self.notice_callback,
1012            _with_handle: WithHandle,
1013        }
1014    }
1015
1016    /// Returns the [`tokio_postgres::Config`] that will be used to connect.
1017    pub fn as_pg_config(&self) -> &tokio_postgres::Config {
1018        &self.pg_config
1019    }
1020}
1021
1022/// This trait enables us to either include or omit the [`mz_ore::task::JoinHandle`] in the result
1023/// of a client connection.
1024pub trait IncludeHandle: Send {
1025    type Output;
1026    fn transform_result(
1027        client: tokio_postgres::Client,
1028        handle: mz_ore::task::JoinHandle<()>,
1029    ) -> Self::Output;
1030}
1031
1032/// Type parameter that denotes we __will not__ return the [`mz_ore::task::JoinHandle`] in the
1033/// result of a [`ConnectBuilder`].
1034pub struct NoHandle;
1035impl IncludeHandle for NoHandle {
1036    type Output = tokio_postgres::Client;
1037    fn transform_result(
1038        client: tokio_postgres::Client,
1039        _handle: mz_ore::task::JoinHandle<()>,
1040    ) -> Self::Output {
1041        client
1042    }
1043}
1044
1045/// Type parameter that denotes we __will__ return the [`mz_ore::task::JoinHandle`] in the result of
1046/// a [`ConnectBuilder`].
1047pub struct WithHandle;
1048impl IncludeHandle for WithHandle {
1049    type Output = (tokio_postgres::Client, mz_ore::task::JoinHandle<()>);
1050    fn transform_result(
1051        client: tokio_postgres::Client,
1052        handle: mz_ore::task::JoinHandle<()>,
1053    ) -> Self::Output {
1054        (client, handle)
1055    }
1056}
1057
1058impl<'s, T, H> IntoFuture for ConnectBuilder<'s, T, H>
1059where
1060    T: MakeTlsConnect<Socket> + Send + 'static,
1061    T::TlsConnect: Send,
1062    T::Stream: Send,
1063    <T::TlsConnect as TlsConnect<Socket>>::Future: Send,
1064    H: IncludeHandle,
1065{
1066    type Output = Result<H::Output, postgres::Error>;
1067    type IntoFuture = BoxFuture<'static, Self::Output>;
1068
1069    fn into_future(mut self) -> Self::IntoFuture {
1070        Box::pin(async move {
1071            assert!(
1072                self.pg_config.get_ports().is_empty(),
1073                "specifying multiple ports is not supported"
1074            );
1075            self.pg_config.port(self.port);
1076
1077            let (client, mut conn) = self.pg_config.connect(self.tls).await?;
1078            let mut notice_callback = self.notice_callback.take();
1079
1080            let handle = task::spawn(|| "connect", async move {
1081                while let Some(msg) = std::future::poll_fn(|cx| conn.poll_message(cx)).await {
1082                    match msg {
1083                        Ok(AsyncMessage::Notice(notice)) => {
1084                            if let Some(callback) = notice_callback.as_mut() {
1085                                callback(notice);
1086                            }
1087                        }
1088                        Ok(msg) => {
1089                            tracing::debug!(?msg, "Dropping message from database");
1090                        }
1091                        Err(e) => {
1092                            // tokio_postgres::Connection docs say:
1093                            // > Return values of None or Some(Err(_)) are “terminal”; callers
1094                            // > should not invoke this method again after receiving one of those
1095                            // > values.
1096                            tracing::info!("connection error: {e}");
1097                            break;
1098                        }
1099                    }
1100                }
1101                tracing::info!("connection closed");
1102            });
1103
1104            let output = H::transform_result(client, handle);
1105            Ok(output)
1106        })
1107    }
1108}
1109
1110/// A running instance of `environmentd`, that exposes blocking/synchronous test helpers.
1111///
1112/// Note: Ideally you should use a [`TestServer`] which relies on an external runtime, e.g. the
1113/// [`tokio::test`] macro. This struct exists so we can incrementally migrate our existing tests.
1114pub struct TestServerWithRuntime {
1115    server: TestServer,
1116    runtime: Arc<Runtime>,
1117}
1118
1119impl TestServerWithRuntime {
1120    /// Returns the [`Runtime`] owned by this [`TestServerWithRuntime`].
1121    ///
1122    /// Can be used to spawn async tasks.
1123    pub fn runtime(&self) -> &Arc<Runtime> {
1124        &self.runtime
1125    }
1126
1127    /// Returns a referece to the inner running `environmentd` [`crate::Server`]`.
1128    pub fn inner(&self) -> &crate::Server {
1129        &self.server.inner
1130    }
1131
1132    /// Connect to the __public__ SQL port of the running `environmentd` server.
1133    pub fn connect<T>(&self, tls: T) -> Result<postgres::Client, postgres::Error>
1134    where
1135        T: MakeTlsConnect<Socket> + Send + 'static,
1136        T::TlsConnect: Send,
1137        T::Stream: Send,
1138        <T::TlsConnect as TlsConnect<Socket>>::Future: Send,
1139    {
1140        self.pg_config().connect(tls)
1141    }
1142
1143    /// Connect to the __internal__ SQL port of the running `environmentd` server.
1144    pub fn connect_internal<T>(&self, tls: T) -> Result<postgres::Client, anyhow::Error>
1145    where
1146        T: MakeTlsConnect<Socket> + Send + 'static,
1147        T::TlsConnect: Send,
1148        T::Stream: Send,
1149        <T::TlsConnect as TlsConnect<Socket>>::Future: Send,
1150    {
1151        Ok(self.pg_config_internal().connect(tls)?)
1152    }
1153
1154    /// Enable LaunchDarkly feature flags.
1155    pub fn enable_feature_flags(&self, flags: &[&'static str]) {
1156        let mut internal_client = self.connect_internal(postgres::NoTls).unwrap();
1157
1158        for flag in flags {
1159            internal_client
1160                .batch_execute(&format!("ALTER SYSTEM SET {} = true;", flag))
1161                .unwrap();
1162        }
1163    }
1164
1165    /// Return a [`postgres::Config`] for connecting to the __public__ SQL port of the running
1166    /// `environmentd` server.
1167    pub fn pg_config(&self) -> postgres::Config {
1168        let local_addr = self.server.sql_local_addr();
1169        let mut config = postgres::Config::new();
1170        config
1171            .host(&Ipv4Addr::LOCALHOST.to_string())
1172            .port(local_addr.port())
1173            .user("materialize")
1174            .options("--welcome_message=off");
1175        config
1176    }
1177
1178    /// Return a [`postgres::Config`] for connecting to the __internal__ SQL port of the running
1179    /// `environmentd` server.
1180    pub fn pg_config_internal(&self) -> postgres::Config {
1181        let local_addr = self.server.internal_sql_local_addr();
1182        let mut config = postgres::Config::new();
1183        config
1184            .host(&Ipv4Addr::LOCALHOST.to_string())
1185            .port(local_addr.port())
1186            .user("mz_system")
1187            .options("--welcome_message=off");
1188        config
1189    }
1190
1191    pub fn ws_addr(&self) -> Uri {
1192        self.server.ws_addr()
1193    }
1194
1195    pub fn internal_ws_addr(&self) -> Uri {
1196        self.server.internal_ws_addr()
1197    }
1198
1199    pub fn http_local_addr(&self) -> SocketAddr {
1200        self.server.http_local_addr()
1201    }
1202
1203    pub fn internal_http_local_addr(&self) -> SocketAddr {
1204        self.server.internal_http_local_addr()
1205    }
1206
1207    pub fn sql_local_addr(&self) -> SocketAddr {
1208        self.server.sql_local_addr()
1209    }
1210
1211    pub fn internal_sql_local_addr(&self) -> SocketAddr {
1212        self.server.internal_sql_local_addr()
1213    }
1214}
1215
1216#[derive(Debug, Clone, Copy, Eq, PartialEq, Ord, PartialOrd)]
1217pub struct MzTimestamp(pub u64);
1218
1219impl<'a> FromSql<'a> for MzTimestamp {
1220    fn from_sql(ty: &Type, raw: &'a [u8]) -> Result<MzTimestamp, Box<dyn Error + Sync + Send>> {
1221        let n = mz_pgrepr::Numeric::from_sql(ty, raw)?;
1222        Ok(MzTimestamp(u64::try_from(n.0.0)?))
1223    }
1224
1225    fn accepts(ty: &Type) -> bool {
1226        mz_pgrepr::Numeric::accepts(ty)
1227    }
1228}
1229
1230pub trait PostgresErrorExt {
1231    fn unwrap_db_error(self) -> DbError;
1232}
1233
1234impl PostgresErrorExt for postgres::Error {
1235    fn unwrap_db_error(self) -> DbError {
1236        match self.source().and_then(|e| e.downcast_ref::<DbError>()) {
1237            Some(e) => e.clone(),
1238            None => panic!("expected DbError, but got: {:?}", self),
1239        }
1240    }
1241}
1242
1243impl<T, E> PostgresErrorExt for Result<T, E>
1244where
1245    E: PostgresErrorExt,
1246{
1247    fn unwrap_db_error(self) -> DbError {
1248        match self {
1249            Ok(_) => panic!("expected Err(DbError), but got Ok(_)"),
1250            Err(e) => e.unwrap_db_error(),
1251        }
1252    }
1253}
1254
1255/// Group commit will block writes until the current time has advanced. This can make
1256/// performing inserts while using deterministic time difficult. This is a helper
1257/// method to perform writes and advance the current time.
1258pub async fn insert_with_deterministic_timestamps(
1259    table: &'static str,
1260    values: &'static str,
1261    server: &TestServer,
1262    now: Arc<std::sync::Mutex<EpochMillis>>,
1263) -> Result<(), Box<dyn Error>> {
1264    let client_write = server.connect().await?;
1265    let client_read = server.connect().await?;
1266
1267    let mut current_ts = get_explain_timestamp(table, &client_read).await;
1268
1269    let insert_query = format!("INSERT INTO {table} VALUES {values}");
1270
1271    let write_future = client_write.execute(&insert_query, &[]);
1272    let timestamp_interval = tokio::time::interval(Duration::from_millis(1));
1273
1274    let mut write_future = std::pin::pin!(write_future);
1275    let mut timestamp_interval = std::pin::pin!(timestamp_interval);
1276
1277    // Keep increasing `now` until the write has executed succeed. Table advancements may
1278    // have increased the global timestamp by an unknown amount.
1279    loop {
1280        tokio::select! {
1281            _ = (&mut write_future) => return Ok(()),
1282            _ = timestamp_interval.tick() => {
1283                current_ts += 1;
1284                *now.lock().expect("lock poisoned") = current_ts;
1285            }
1286        };
1287    }
1288}
1289
1290pub async fn get_explain_timestamp(from_suffix: &str, client: &Client) -> EpochMillis {
1291    try_get_explain_timestamp(from_suffix, client)
1292        .await
1293        .unwrap()
1294}
1295
1296pub async fn try_get_explain_timestamp(
1297    from_suffix: &str,
1298    client: &Client,
1299) -> Result<EpochMillis, anyhow::Error> {
1300    let det = get_explain_timestamp_determination(from_suffix, client).await?;
1301    let ts = det.determination.timestamp_context.timestamp_or_default();
1302    Ok(ts.into())
1303}
1304
1305pub async fn get_explain_timestamp_determination(
1306    from_suffix: &str,
1307    client: &Client,
1308) -> Result<TimestampExplanation<mz_repr::Timestamp>, anyhow::Error> {
1309    let row = client
1310        .query_one(
1311            &format!("EXPLAIN TIMESTAMP AS JSON FOR SELECT * FROM {from_suffix}"),
1312            &[],
1313        )
1314        .await?;
1315    let explain: String = row.get(0);
1316    Ok(serde_json::from_str(&explain).unwrap())
1317}
1318
1319/// Helper function to create a Postgres source.
1320///
1321/// IMPORTANT: Make sure to call closure that is returned at the end of the test to clean up
1322/// Postgres state.
1323///
1324/// WARNING: If multiple tests use this, and the tests are run in parallel, then make sure the test
1325/// use different postgres tables.
1326pub async fn create_postgres_source_with_table<'a>(
1327    server: &TestServer,
1328    mz_client: &Client,
1329    table_name: &str,
1330    table_schema: &str,
1331    source_name: &str,
1332) -> (
1333    Client,
1334    impl FnOnce(&'a Client, &'a Client) -> LocalBoxFuture<'a, ()>,
1335) {
1336    server
1337        .enable_feature_flags(&["enable_create_table_from_source"])
1338        .await;
1339
1340    let postgres_url = env::var("POSTGRES_URL")
1341        .map_err(|_| anyhow!("POSTGRES_URL environment variable is not set"))
1342        .unwrap();
1343
1344    let (pg_client, connection) = tokio_postgres::connect(&postgres_url, postgres::NoTls)
1345        .await
1346        .unwrap();
1347
1348    let pg_config: tokio_postgres::Config = postgres_url.parse().unwrap();
1349    let user = pg_config.get_user().unwrap_or("postgres");
1350    let db_name = pg_config.get_dbname().unwrap_or(user);
1351    let ports = pg_config.get_ports();
1352    let port = if ports.is_empty() { 5432 } else { ports[0] };
1353    let hosts = pg_config.get_hosts();
1354    let host = if hosts.is_empty() {
1355        "localhost".to_string()
1356    } else {
1357        match &hosts[0] {
1358            Host::Tcp(host) => host.to_string(),
1359            Host::Unix(host) => host.to_str().unwrap().to_string(),
1360        }
1361    };
1362    let password = pg_config.get_password();
1363
1364    mz_ore::task::spawn(|| "postgres-source-connection", async move {
1365        if let Err(e) = connection.await {
1366            panic!("connection error: {}", e);
1367        }
1368    });
1369
1370    // Create table in Postgres with publication.
1371    let _ = pg_client
1372        .execute(&format!("DROP TABLE IF EXISTS {table_name};"), &[])
1373        .await
1374        .unwrap();
1375    let _ = pg_client
1376        .execute(&format!("DROP PUBLICATION IF EXISTS {source_name};"), &[])
1377        .await
1378        .unwrap();
1379    let _ = pg_client
1380        .execute(&format!("CREATE TABLE {table_name} {table_schema};"), &[])
1381        .await
1382        .unwrap();
1383    let _ = pg_client
1384        .execute(
1385            &format!("ALTER TABLE {table_name} REPLICA IDENTITY FULL;"),
1386            &[],
1387        )
1388        .await
1389        .unwrap();
1390    let _ = pg_client
1391        .execute(
1392            &format!("CREATE PUBLICATION {source_name} FOR TABLE {table_name};"),
1393            &[],
1394        )
1395        .await
1396        .unwrap();
1397
1398    // Create postgres source in Materialize.
1399    let mut connection_str = format!("HOST '{host}', PORT {port}, USER {user}, DATABASE {db_name}");
1400    if let Some(password) = password {
1401        let password = std::str::from_utf8(password).unwrap();
1402        mz_client
1403            .batch_execute(&format!("CREATE SECRET s AS '{password}'"))
1404            .await
1405            .unwrap();
1406        connection_str = format!("{connection_str}, PASSWORD SECRET s");
1407    }
1408    mz_client
1409        .batch_execute(&format!(
1410            "CREATE CONNECTION pgconn TO POSTGRES ({connection_str})"
1411        ))
1412        .await
1413        .unwrap();
1414    mz_client
1415        .batch_execute(&format!(
1416            "CREATE SOURCE {source_name}
1417            FROM POSTGRES
1418            CONNECTION pgconn
1419            (PUBLICATION '{source_name}')"
1420        ))
1421        .await
1422        .unwrap();
1423    mz_client
1424        .batch_execute(&format!(
1425            "CREATE TABLE {table_name}
1426            FROM SOURCE {source_name}
1427            (REFERENCE {table_name});"
1428        ))
1429        .await
1430        .unwrap();
1431
1432    let table_name = table_name.to_string();
1433    let source_name = source_name.to_string();
1434    (
1435        pg_client,
1436        move |mz_client: &'a Client, pg_client: &'a Client| {
1437            let f: Pin<Box<dyn Future<Output = ()> + 'a>> = Box::pin(async move {
1438                mz_client
1439                    .batch_execute(&format!("DROP SOURCE {source_name} CASCADE;"))
1440                    .await
1441                    .unwrap();
1442                mz_client
1443                    .batch_execute("DROP CONNECTION pgconn;")
1444                    .await
1445                    .unwrap();
1446
1447                let _ = pg_client
1448                    .execute(&format!("DROP PUBLICATION {source_name};"), &[])
1449                    .await
1450                    .unwrap();
1451                let _ = pg_client
1452                    .execute(&format!("DROP TABLE {table_name};"), &[])
1453                    .await
1454                    .unwrap();
1455            });
1456            f
1457        },
1458    )
1459}
1460
1461pub async fn wait_for_pg_table_population(mz_client: &Client, view_name: &str, source_rows: i64) {
1462    let current_isolation = mz_client
1463        .query_one("SHOW transaction_isolation", &[])
1464        .await
1465        .unwrap()
1466        .get::<_, String>(0);
1467    mz_client
1468        .batch_execute("SET transaction_isolation = SERIALIZABLE")
1469        .await
1470        .unwrap();
1471    Retry::default()
1472        .retry_async(|_| async move {
1473            let rows = mz_client
1474                .query_one(&format!("SELECT COUNT(*) FROM {view_name};"), &[])
1475                .await
1476                .unwrap()
1477                .get::<_, i64>(0);
1478            if rows == source_rows {
1479                Ok(())
1480            } else {
1481                Err(format!(
1482                    "Waiting for {source_rows} row to be ingested. Currently at {rows}."
1483                ))
1484            }
1485        })
1486        .await
1487        .unwrap();
1488    mz_client
1489        .batch_execute(&format!(
1490            "SET transaction_isolation = '{current_isolation}'"
1491        ))
1492        .await
1493        .unwrap();
1494}
1495
1496// Initializes a websocket connection. Returns the init messages before the initial ReadyForQuery.
1497pub fn auth_with_ws(
1498    ws: &mut WebSocket<MaybeTlsStream<TcpStream>>,
1499    mut options: BTreeMap<String, String>,
1500) -> Result<Vec<WebSocketResponse>, anyhow::Error> {
1501    if !options.contains_key("welcome_message") {
1502        options.insert("welcome_message".into(), "off".into());
1503    }
1504    auth_with_ws_impl(
1505        ws,
1506        Message::Text(
1507            serde_json::to_string(&WebSocketAuth::Basic {
1508                user: "materialize".into(),
1509                password: "".into(),
1510                options,
1511            })
1512            .unwrap(),
1513        ),
1514    )
1515}
1516
1517pub fn auth_with_ws_impl(
1518    ws: &mut WebSocket<MaybeTlsStream<TcpStream>>,
1519    auth_message: Message,
1520) -> Result<Vec<WebSocketResponse>, anyhow::Error> {
1521    ws.send(auth_message)?;
1522
1523    // Wait for initial ready response.
1524    let mut msgs = Vec::new();
1525    loop {
1526        let resp = ws.read()?;
1527        match resp {
1528            Message::Text(msg) => {
1529                let msg: WebSocketResponse = serde_json::from_str(&msg).unwrap();
1530                match msg {
1531                    WebSocketResponse::ReadyForQuery(_) => break,
1532                    msg => {
1533                        msgs.push(msg);
1534                    }
1535                }
1536            }
1537            Message::Ping(_) => continue,
1538            Message::Close(None) => return Err(anyhow!("ws closed after auth")),
1539            Message::Close(Some(close_frame)) => {
1540                return Err(anyhow!("ws closed after auth").context(close_frame));
1541            }
1542            _ => panic!("unexpected response: {:?}", resp),
1543        }
1544    }
1545    Ok(msgs)
1546}
1547
1548pub fn make_header<H: Header>(h: H) -> HeaderMap {
1549    let mut map = HeaderMap::new();
1550    map.typed_insert(h);
1551    map
1552}
1553
1554pub fn make_pg_tls<F>(configure: F) -> MakeTlsConnector
1555where
1556    F: FnOnce(&mut SslConnectorBuilder) -> Result<(), ErrorStack>,
1557{
1558    let mut connector_builder = SslConnector::builder(SslMethod::tls()).unwrap();
1559    // Disable TLS v1.3 because `postgres` and `hyper` produce stabler error
1560    // messages with TLS v1.2.
1561    //
1562    // Briefly, in TLS v1.3, failing to present a client certificate does not
1563    // error during the TLS handshake, as it does in TLS v1.2, but on the first
1564    // attempt to read from the stream. But both `postgres` and `hyper` write a
1565    // bunch of data before attempting to read from the stream. With a failed
1566    // TLS v1.3 connection, sometimes `postgres` and `hyper` succeed in writing
1567    // out this data, and then return a nice error message on the call to read.
1568    // But sometimes the connection is closed before they write out the data,
1569    // and so they report "connection closed" before they ever call read, never
1570    // noticing the underlying SSL error.
1571    //
1572    // It's unclear who's bug this is. Is it on `hyper`/`postgres` to call read
1573    // if writing to the stream fails to see if a TLS error occured? Is it on
1574    // OpenSSL to provide a better API [1]? Is it a protocol issue that ought to
1575    // be corrected in TLS v1.4? We don't want to answer these questions, so we
1576    // just avoid TLS v1.3 for now.
1577    //
1578    // [1]: https://github.com/openssl/openssl/issues/11118
1579    let options = connector_builder.options() | SslOptions::NO_TLSV1_3;
1580    connector_builder.set_options(options);
1581    configure(&mut connector_builder).unwrap();
1582    MakeTlsConnector::new(connector_builder.build())
1583}
1584
1585/// A certificate authority for use in tests.
1586pub struct Ca {
1587    pub dir: TempDir,
1588    pub name: X509Name,
1589    pub cert: X509,
1590    pub pkey: PKey<Private>,
1591}
1592
1593impl Ca {
1594    fn make_ca(name: &str, parent: Option<&Ca>) -> Result<Ca, Box<dyn Error>> {
1595        let dir = tempfile::tempdir()?;
1596        let rsa = Rsa::generate(2048)?;
1597        let pkey = PKey::from_rsa(rsa)?;
1598        let name = {
1599            let mut builder = X509NameBuilder::new()?;
1600            builder.append_entry_by_nid(Nid::COMMONNAME, name)?;
1601            builder.build()
1602        };
1603        let cert = {
1604            let mut builder = X509::builder()?;
1605            builder.set_version(2)?;
1606            builder.set_pubkey(&pkey)?;
1607            builder.set_issuer_name(parent.map(|ca| &ca.name).unwrap_or(&name))?;
1608            builder.set_subject_name(&name)?;
1609            builder.set_not_before(&*Asn1Time::days_from_now(0)?)?;
1610            builder.set_not_after(&*Asn1Time::days_from_now(365)?)?;
1611            builder.append_extension(BasicConstraints::new().critical().ca().build()?)?;
1612            builder.sign(
1613                parent.map(|ca| &ca.pkey).unwrap_or(&pkey),
1614                MessageDigest::sha256(),
1615            )?;
1616            builder.build()
1617        };
1618        fs::write(dir.path().join("ca.crt"), cert.to_pem()?)?;
1619        Ok(Ca {
1620            dir,
1621            name,
1622            cert,
1623            pkey,
1624        })
1625    }
1626
1627    /// Creates a new root certificate authority.
1628    pub fn new_root(name: &str) -> Result<Ca, Box<dyn Error>> {
1629        Ca::make_ca(name, None)
1630    }
1631
1632    /// Returns the path to the CA's certificate.
1633    pub fn ca_cert_path(&self) -> PathBuf {
1634        self.dir.path().join("ca.crt")
1635    }
1636
1637    /// Requests a new intermediate certificate authority.
1638    pub fn request_ca(&self, name: &str) -> Result<Ca, Box<dyn Error>> {
1639        Ca::make_ca(name, Some(self))
1640    }
1641
1642    /// Generates a certificate with the specified Common Name (CN) that is
1643    /// signed by the CA.
1644    ///
1645    /// Returns the paths to the certificate and key.
1646    pub fn request_client_cert(&self, name: &str) -> Result<(PathBuf, PathBuf), Box<dyn Error>> {
1647        self.request_cert(name, iter::empty())
1648    }
1649
1650    /// Like `request_client_cert`, but permits specifying additional IP
1651    /// addresses to attach as Subject Alternate Names.
1652    pub fn request_cert<I>(&self, name: &str, ips: I) -> Result<(PathBuf, PathBuf), Box<dyn Error>>
1653    where
1654        I: IntoIterator<Item = IpAddr>,
1655    {
1656        let rsa = Rsa::generate(2048)?;
1657        let pkey = PKey::from_rsa(rsa)?;
1658        let subject_name = {
1659            let mut builder = X509NameBuilder::new()?;
1660            builder.append_entry_by_nid(Nid::COMMONNAME, name)?;
1661            builder.build()
1662        };
1663        let cert = {
1664            let mut builder = X509::builder()?;
1665            builder.set_version(2)?;
1666            builder.set_pubkey(&pkey)?;
1667            builder.set_issuer_name(self.cert.subject_name())?;
1668            builder.set_subject_name(&subject_name)?;
1669            builder.set_not_before(&*Asn1Time::days_from_now(0)?)?;
1670            builder.set_not_after(&*Asn1Time::days_from_now(365)?)?;
1671            for ip in ips {
1672                builder.append_extension(
1673                    SubjectAlternativeName::new()
1674                        .ip(&ip.to_string())
1675                        .build(&builder.x509v3_context(None, None))?,
1676                )?;
1677            }
1678            builder.sign(&self.pkey, MessageDigest::sha256())?;
1679            builder.build()
1680        };
1681        let cert_path = self.dir.path().join(Path::new(name).with_extension("crt"));
1682        let key_path = self.dir.path().join(Path::new(name).with_extension("key"));
1683        fs::write(&cert_path, cert.to_pem()?)?;
1684        fs::write(&key_path, pkey.private_key_to_pem_pkcs8()?)?;
1685        Ok((cert_path, key_path))
1686    }
1687}