Skip to main content

mz_deploy/client/
connection.rs

1// Copyright Materialize, Inc. and contributors. All rights reserved.
2//
3// Use of this software is governed by the Business Source License
4// included in the LICENSE file.
5//
6// As of the Change Date specified in that file, in accordance with
7// the Business Source License, use of this software will be governed
8// by the Apache License, Version 2.0.
9
10//! Database client for mz-deploy.
11//!
12//! This module provides the main `Client` struct for interacting with Materialize.
13//! The client handles connection management and delegates specialized operations
14//! to domain-specific sub-clients.
15//!
16//! ## Sub-Client Architecture
17//!
18//! Operations are grouped into domain sub-clients accessed via accessor methods
19//! on `Client`. Each sub-client borrows the `Client` and provides a focused API:
20//!
21//! | Sub-client | Accessor | Responsibility |
22//! |------------|----------|---------------|
23//! | `DeploymentsClient` | `.deployments()` | Deployment lifecycle (stage, promote, abort) |
24//! | `DeploymentsClientMut` | `.deployments_mut()` | Mutable deployment ops (SUBSCRIBE cursors) |
25//! | `IntrospectionClient` | `.introspection()` | Read-only catalog metadata queries |
26//! | `ValidationClient` | `.validation()` | Pre-deployment environment checks |
27//! | `TypeInfoClient` | `.types()` | Column/type introspection for type checking |
//! | `ProvisioningClient` | `.provisioning()` | Idempotent DDL for databases, schemas, clusters |
//! | `DevOverlaysClient` | `.dev_overlays()` | Developer overlay manifest operations |
29//!
30//! ## TLS Policy
31//!
32//! Per-profile `sslmode` with libpq semantics (`disable`, `prefer`, `require`,
33//! `verify-ca`, `verify-full`). When unset, loopback hosts default to
34//! `prefer` and everything else defaults to `require`. See the design at
35//! `docs/superpowers/specs/2026-04-22-profile-tls-design.md` for the behavior
36//! table and migration notes.
37
38use crate::client::errors::ConnectionError;
39use crate::config::{Profile, SslMode};
40use crate::info;
41use mz_postgres_util::Sql;
42use std::collections::BTreeMap;
43use tokio_postgres::types::ToSql;
44use tokio_postgres::{Client as PgClient, NoTls, Row, SimpleQueryMessage, Transaction};
45
/// Database client for interacting with Materialize.
///
/// Owns the underlying `tokio_postgres` client together with the [`Profile`]
/// it was created from. The client itself exposes connection setup,
/// transactions, and raw query helpers (`execute`, `query`, `query_one`,
/// `simple_query`, `batch_execute`). Domain operations are reached through
/// the sub-client accessors: `deployments()`, `deployments_mut()`,
/// `introspection()`, `validation()`, `types()`, `provisioning()`, and
/// `dev_overlays()`.
pub struct Client {
    /// Live query handle. The matching `Connection` future is driven by a
    /// background task spawned when the client is created.
    client: PgClient,
    /// Profile the connection was established with; returned by `profile()`.
    profile: Profile,
}
58
/// Domain sub-client for deployment lifecycle operations.
///
/// Obtained via `Client::deployments()`; borrows the client immutably.
pub struct DeploymentsClient<'a> {
    pub(crate) client: &'a Client,
}

/// Domain sub-client for deployment operations that require mutable client access.
///
/// Obtained via `Client::deployments_mut()`; holds an exclusive borrow so
/// operations that need `&mut Client` (for example `begin_transaction`) can
/// be used.
pub struct DeploymentsClientMut<'a> {
    pub(crate) client: &'a mut Client,
}

/// Domain sub-client for metadata and object introspection operations.
///
/// Obtained via `Client::introspection()`.
pub struct IntrospectionClient<'a> {
    pub(crate) client: &'a Client,
}

/// Domain sub-client for project and privilege validation operations.
///
/// Obtained via `Client::validation()`.
pub struct ValidationClient<'a> {
    pub(crate) client: &'a Client,
}

/// Domain sub-client for column/type introspection used by type checking and tests.
///
/// Obtained via `Client::types()`.
pub struct TypeInfoClient<'a> {
    pub(crate) client: &'a Client,
}

/// Domain sub-client for provisioning databases, schemas, and clusters.
///
/// Obtained via `Client::provisioning()`.
pub struct ProvisioningClient<'a> {
    pub(crate) client: &'a Client,
}

/// Domain sub-client for developer overlay manifest operations.
///
/// Obtained via `Client::dev_overlays()`.
pub struct DevOverlaysClient<'a> {
    pub(crate) client: &'a Client,
}

/// Value sent as the `application_name` connection parameter on every session.
const APPLICATION_NAME: &str = "mz-deploy";
95
96impl Client {
97    /// Connect to the database using a Profile directly.
98    ///
99    /// TLS behavior is driven by `profile.sslmode`; when unset, loopback hosts
100    /// default to `prefer` and everything else defaults to `require`. Verification
101    /// (`verify-ca` / `verify-full`) sources CAs from `profile.sslrootcert`, then
102    /// the platform CA hunt, then OpenSSL's compiled-in defaults.
103    ///
104    /// Every connection is pinned to `_mz_deploy_server` via libpq options;
105    /// any user-supplied `cluster` in profile.options is silently overridden.
106    /// The unit-test runtime uses `connect_with_profile_no_pin` instead —
107    /// its ephemeral Docker container has no `_mz_deploy_server` cluster.
108    pub async fn connect_with_profile(profile: Profile) -> Result<Self, ConnectionError> {
109        Self::connect_with_profile_inner(profile, /* pin_server_cluster */ true).await
110    }
111
112    /// Connect without pinning the session cluster to `_mz_deploy_server`.
113    ///
114    /// Used in two places where `_mz_deploy_server` is not yet (or never)
115    /// present:
116    /// - The ephemeral Docker container used by unit-test execution.
117    /// - `setup::run`, which is the command that creates the cluster.
118    ///
119    /// Uses whatever cluster the profile or server default selects.
120    /// Deliberately `pub(crate)` so nothing outside the crate can bypass
121    /// the production session-cluster pin.
122    pub(crate) async fn connect_with_profile_no_pin(
123        profile: Profile,
124    ) -> Result<Self, ConnectionError> {
125        Self::connect_with_profile_inner(profile, /* pin_server_cluster */ false).await
126    }
127
128    async fn connect_with_profile_inner(
129        profile: Profile,
130        pin_server_cluster: bool,
131    ) -> Result<Self, ConnectionError> {
132        let host = profile.require_host()?;
133        let mut config = tokio_postgres::Config::new();
134        config.host(host);
135        config.port(profile.port);
136        config.user(&profile.username);
137        config.dbname("materialize");
138        if let Some(password) = &profile.password {
139            config.password(password.as_str());
140        }
141        config.application_name(APPLICATION_NAME);
142
143        let mut effective_options = profile.options.clone();
144        if pin_server_cluster {
145            effective_options.insert(
146                "cluster".to_string(),
147                crate::client::SERVER_CLUSTER_NAME.to_string(),
148            );
149        }
150        if let Some(inner) = build_options_string(&effective_options) {
151            config.options(&inner);
152        }
153
154        let mode = profile.sslmode.unwrap_or_else(|| default_sslmode(host));
155        let hunt: Vec<&std::path::Path> =
156            DEFAULT_CA_PATHS.iter().map(std::path::Path::new).collect();
157        let spec = plan_connector(mode, profile.sslrootcert.as_deref(), host, &hunt, |p| {
158            p.exists()
159        })?;
160        let connector = build_connector(spec)?;
161
162        config.ssl_mode(tokio_ssl_mode(mode));
163
164        // `config.connect(NoTls)` and `config.connect(tls)` return `Connection`s
165        // parameterized over different TLS stream types that can't unify. We box
166        // both to a common `dyn Future` so there's a single spawn site below.
167        type BoxConnection =
168            Box<dyn Future<Output = Result<(), tokio_postgres::Error>> + Send + Unpin>;
169        let (client, connection): (PgClient, BoxConnection) = match connector {
170            Connector::NoTls => {
171                let (client, connection) = config
172                    .connect(NoTls)
173                    .await
174                    .map_err(|source| classify_connect_error(source, &profile, mode))?;
175                (client, Box::new(connection))
176            }
177            Connector::Tls(tls) => {
178                let (client, connection) = config
179                    .connect(tls)
180                    .await
181                    .map_err(|source| classify_connect_error(source, &profile, mode))?;
182                (client, Box::new(connection))
183            }
184        };
185
186        mz_ore::task::spawn(|| "mz-deploy-connection", async move {
187            if let Err(e) = connection.await {
188                info!("connection error: {}", e);
189            }
190        });
191
192        Ok(Client { client, profile })
193    }
194
195    /// Get the profile used for this connection.
196    pub fn profile(&self) -> &Profile {
197        &self.profile
198    }
199
200    /// Start a transaction on the underlying connection.
201    pub(crate) async fn begin_transaction(&mut self) -> Result<Transaction<'_>, ConnectionError> {
202        self.client
203            .transaction()
204            .await
205            .map_err(ConnectionError::Query)
206    }
207
208    /// Access deployment lifecycle operations.
209    pub fn deployments(&self) -> DeploymentsClient<'_> {
210        DeploymentsClient { client: self }
211    }
212
213    /// Access mutable deployment lifecycle operations.
214    pub fn deployments_mut(&mut self) -> DeploymentsClientMut<'_> {
215        DeploymentsClientMut { client: self }
216    }
217
218    /// Access metadata and object introspection operations.
219    pub fn introspection(&self) -> IntrospectionClient<'_> {
220        IntrospectionClient { client: self }
221    }
222
223    /// Access database validation operations.
224    pub fn validation(&self) -> ValidationClient<'_> {
225        ValidationClient { client: self }
226    }
227
228    /// Access type/column introspection operations.
229    pub fn types(&self) -> TypeInfoClient<'_> {
230        TypeInfoClient { client: self }
231    }
232
233    /// Access provisioning operations for databases, schemas, and clusters.
234    pub fn provisioning(&self) -> ProvisioningClient<'_> {
235        ProvisioningClient { client: self }
236    }
237
238    /// Access developer overlay manifest operations.
239    pub fn dev_overlays(&self) -> DevOverlaysClient<'_> {
240        DevOverlaysClient { client: self }
241    }
242
243    /// Execute a SQL statement that doesn't return rows.
244    pub async fn execute(
245        &self,
246        statement: &str,
247        params: &[&(dyn ToSql + Sync)],
248    ) -> Result<u64, ConnectionError> {
249        mz_postgres_util::execute(
250            &self.client,
251            Sql::raw_unchecked(statement.to_string()),
252            params,
253        )
254        .await
255        .map_err(ConnectionError::from)
256    }
257
258    /// Execute a SQL query and return the resulting rows.
259    pub async fn query_one(
260        &self,
261        statement: &str,
262        params: &[&(dyn ToSql + Sync)],
263    ) -> Result<Row, ConnectionError> {
264        mz_postgres_util::query_one(
265            &self.client,
266            Sql::raw_unchecked(statement.to_string()),
267            params,
268        )
269        .await
270        .map_err(ConnectionError::from)
271    }
272
273    /// Execute a SQL query and return the resulting rows.
274    pub async fn query(
275        &self,
276        statement: &str,
277        params: &[&(dyn ToSql + Sync)],
278    ) -> Result<Vec<Row>, ConnectionError> {
279        mz_postgres_util::query(
280            &self.client,
281            Sql::raw_unchecked(statement.to_string()),
282            params,
283        )
284        .await
285        .map_err(ConnectionError::from)
286    }
287
288    /// Execute a SQL statement using the simple query protocol (text-only, no binary encoding).
289    pub async fn simple_query(
290        &self,
291        query: &str,
292    ) -> Result<Vec<SimpleQueryMessage>, ConnectionError> {
293        mz_postgres_util::simple_query(&self.client, Sql::raw_unchecked(query.to_string()))
294            .await
295            .map_err(ConnectionError::from)
296    }
297
298    /// Execute one or more SQL statements that don't return rows, using the simple query protocol.
299    pub async fn batch_execute(&self, query: &str) -> Result<(), ConnectionError> {
300        mz_postgres_util::batch_execute(&self.client, Sql::raw_unchecked(query.to_string()))
301            .await
302            .map_err(ConnectionError::from)
303    }
304}
305
/// Platform CA bundle candidates, probed in order by `resolve_ca_source`
/// (via `plan_connector`) when `sslmode` resolves to `verify-ca` /
/// `verify-full` and the profile does not set `sslrootcert`. The first path
/// that exists wins; if none exist, OpenSSL's default verify paths are used.
/// The list mirrors where libpq-style installations keep their CA bundle on
/// our supported platforms.
const DEFAULT_CA_PATHS: &[&str] = &[
    "/etc/ssl/cert.pem",                    // macOS system
    "/opt/homebrew/etc/openssl@3/cert.pem", // macOS Homebrew ARM
    "/usr/local/etc/openssl@3/cert.pem",    // macOS Homebrew Intel
    "/opt/homebrew/etc/openssl/cert.pem",   // macOS Homebrew ARM (older)
    "/usr/local/etc/openssl/cert.pem",      // macOS Homebrew Intel (older)
    "/etc/ssl/certs/ca-certificates.crt",   // Debian/Ubuntu
    "/etc/pki/tls/certs/ca-bundle.crt",     // RHEL/CentOS
    "/etc/ssl/ca-bundle.pem",               // OpenSUSE
];
320
321/// The default `SslMode` applied when a profile does not set `sslmode`.
322///
323/// Loopback hosts get `Prefer` so local Mz (which does not offer TLS) works
324/// without explicit config. Everything else gets `Require` — TLS is required
325/// but certificate verification is not. Users who want verification set
326/// `sslmode = "verify-ca"` or `sslmode = "verify-full"` explicitly.
327pub(crate) fn default_sslmode(host: &str) -> SslMode {
328    if is_loopback_host(host) {
329        SslMode::Prefer
330    } else {
331        SslMode::Require
332    }
333}
334
/// Returns `true` if `host` names the loopback interface.
///
/// Recognizes:
/// - `localhost`, compared ASCII case-insensitively because DNS names are
///   case-insensitive (so `LOCALHOST` also counts);
/// - any address in `127.0.0.0/8`;
/// - `::1`, with or without URL-style brackets (`[::1]`);
/// - IPv4-mapped IPv6 forms of `127.0.0.0/8` such as `::ffff:127.0.0.1`.
///
/// Used by the SQL TLS defaults and by `mz-deploy mcp` to pick `http://`
/// vs `https://`.
pub(crate) fn is_loopback_host(host: &str) -> bool {
    if host.eq_ignore_ascii_case("localhost") {
        return true;
    }
    let unbracketed = host
        .strip_prefix('[')
        .and_then(|s| s.strip_suffix(']'))
        .unwrap_or(host);
    unbracketed
        .parse::<std::net::IpAddr>()
        // `to_canonical` folds IPv4-mapped IPv6 addresses to IPv4, so
        // `::ffff:127.0.0.1` is treated the same as `127.0.0.1`.
        .is_ok_and(|ip| ip.to_canonical().is_loopback())
}
353
354fn tokio_ssl_mode(mode: SslMode) -> tokio_postgres::config::SslMode {
355    use tokio_postgres::config::SslMode as TokioMode;
356    match mode {
357        SslMode::Disable => TokioMode::Disable,
358        SslMode::Prefer => TokioMode::Prefer,
359        SslMode::Require | SslMode::VerifyCa | SslMode::VerifyFull => TokioMode::Require,
360    }
361}
362
/// How `verify-full` should match the server certificate's identity.
///
/// Produced by `plan_connector` from the profile host and applied in
/// `build_connector` via the OpenSSL verify parameters (`set_host` for
/// names, `set_ip` for literals).
#[derive(Debug)]
enum HostCheck {
    /// Match a DNS name. Host parsed as a non-IP string.
    Dns(String),
    /// Match an IPv4 or IPv6 literal. Host parsed as `IpAddr`.
    Ip(std::net::IpAddr),
}
371
/// Pure-data representation of the TLS setup for a connection, derived from
/// a profile's effective `SslMode` and `sslrootcert`.
///
/// Built by `plan_connector` (no network I/O) and turned into a runtime
/// `Connector` by `build_connector`.
#[derive(Debug)]
enum ConnectorSpec {
    /// `sslmode = disable`: connect in plaintext.
    NoTls,
    /// Any other mode: negotiate TLS with the settings below.
    Tls {
        /// `NONE` for `prefer`/`require`, `PEER` for `verify-ca`/`verify-full`.
        verify: openssl::ssl::SslVerifyMode,
        /// Set only for `verify-full`.
        host_check: Option<HostCheck>,
        /// Trust-anchor source; `CaSource::None` for non-verifying modes.
        ca_source: CaSource,
    },
}
383
/// Where the CA bundle comes from for verifying the server cert, or the
/// absence thereof for non-verifying modes.
///
/// Resolved by `resolve_ca_source` in precedence order: explicit
/// `sslrootcert`, then the `DEFAULT_CA_PATHS` hunt, then OpenSSL defaults.
#[derive(Debug)]
enum CaSource {
    /// `disable` / `prefer` / `require` — no CA is loaded.
    None,
    /// Explicit path from the profile's `sslrootcert` field.
    Explicit(std::path::PathBuf),
    /// Path discovered by walking `DEFAULT_CA_PATHS`.
    Hunted(std::path::PathBuf),
    /// Fallback to OpenSSL's compiled-in default verify paths
    /// (`set_default_verify_paths`). Used only when the hunt finds nothing
    /// and no explicit path is set.
    DefaultVerifyPaths,
}
399
/// Runtime-ready connector variant handed to `tokio_postgres::Config::connect`.
///
/// `Debug` is implemented by hand and does not print the TLS connector's
/// contents.
enum Connector {
    /// Plaintext connection (`sslmode = disable`).
    NoTls,
    /// OpenSSL-backed connector configured by `build_connector`.
    Tls(postgres_openssl::MakeTlsConnector),
}
405
406impl std::fmt::Debug for Connector {
407    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
408        match self {
409            Connector::NoTls => write!(f, "Connector::NoTls"),
410            Connector::Tls(_) => write!(f, "Connector::Tls(...)"),
411        }
412    }
413}
414
415/// Plan the TLS setup for a connection from the resolved (mode, CA) inputs.
416///
417/// Pure: does no network I/O and — aside from the injected `ca_exists`
418/// predicate — does no filesystem I/O. Returns a [`ConnectorSpec`] that
419/// [`build_connector`] then materializes into an OpenSSL context.
420///
421/// `hunt_candidates` is the ordered list of default CA paths to probe when
422/// `sslrootcert` is not set. In production this is [`DEFAULT_CA_PATHS`];
423/// tests pass their own list plus a stubbed `ca_exists` predicate.
424fn plan_connector(
425    mode: SslMode,
426    sslrootcert: Option<&std::path::Path>,
427    host: &str,
428    hunt_candidates: &[&std::path::Path],
429    ca_exists: impl Fn(&std::path::Path) -> bool,
430) -> Result<ConnectorSpec, ConnectionError> {
431    use openssl::ssl::SslVerifyMode;
432
433    match mode {
434        SslMode::Disable => Ok(ConnectorSpec::NoTls),
435        SslMode::Prefer | SslMode::Require => Ok(ConnectorSpec::Tls {
436            verify: SslVerifyMode::NONE,
437            host_check: None,
438            ca_source: CaSource::None,
439        }),
440        SslMode::VerifyCa | SslMode::VerifyFull => {
441            let ca_source = resolve_ca_source(sslrootcert, hunt_candidates, ca_exists)?;
442            let host_check = if matches!(mode, SslMode::VerifyFull) {
443                Some(match host.parse::<std::net::IpAddr>() {
444                    Ok(ip) => HostCheck::Ip(ip),
445                    Err(_) => HostCheck::Dns(host.to_string()),
446                })
447            } else {
448                None
449            };
450            Ok(ConnectorSpec::Tls {
451                verify: SslVerifyMode::PEER,
452                host_check,
453                ca_source,
454            })
455        }
456    }
457}
458
459fn resolve_ca_source(
460    explicit: Option<&std::path::Path>,
461    hunt_candidates: &[&std::path::Path],
462    ca_exists: impl Fn(&std::path::Path) -> bool,
463) -> Result<CaSource, ConnectionError> {
464    if let Some(path) = explicit {
465        if ca_exists(path) {
466            return Ok(CaSource::Explicit(path.to_path_buf()));
467        } else {
468            return Err(ConnectionError::TlsCaNotFound);
469        }
470    }
471    for candidate in hunt_candidates {
472        if ca_exists(candidate) {
473            return Ok(CaSource::Hunted(candidate.to_path_buf()));
474        }
475    }
476    Ok(CaSource::DefaultVerifyPaths)
477}
478
479/// Convert a [`ConnectorSpec`] into a runtime [`Connector`] by wiring up the
480/// OpenSSL context. All filesystem I/O for CAs happens here.
481fn build_connector(spec: ConnectorSpec) -> Result<Connector, ConnectionError> {
482    use openssl::ssl::{SslConnector, SslMethod};
483
484    match spec {
485        ConnectorSpec::NoTls => Ok(Connector::NoTls),
486        ConnectorSpec::Tls {
487            verify,
488            host_check,
489            ca_source,
490        } => {
491            let mut builder = SslConnector::builder(SslMethod::tls()).map_err(|e| {
492                ConnectionError::Message(format!("Failed to create TLS builder: {}", e))
493            })?;
494
495            match ca_source {
496                CaSource::None => {}
497                CaSource::Explicit(path) | CaSource::Hunted(path) => {
498                    builder
499                        .set_ca_file(&path)
500                        .map_err(|_| ConnectionError::TlsCaNotFound)?;
501                }
502                CaSource::DefaultVerifyPaths => {
503                    builder
504                        .set_default_verify_paths()
505                        .map_err(|_| ConnectionError::TlsCaNotFound)?;
506                }
507            }
508
509            builder.set_verify(verify);
510
511            if let Some(check) = host_check {
512                let param = builder.verify_param_mut();
513                match check {
514                    HostCheck::Dns(name) => {
515                        param
516                            .set_host(&name)
517                            .map_err(|e| ConnectionError::Message(format!("{}", e)))?;
518                    }
519                    HostCheck::Ip(ip) => {
520                        param
521                            .set_ip(ip)
522                            .map_err(|e| ConnectionError::Message(format!("{}", e)))?;
523                    }
524                }
525            }
526
527            Ok(Connector::Tls(postgres_openssl::MakeTlsConnector::new(
528                builder.build(),
529            )))
530        }
531    }
532}
533
534/// Classify a `tokio_postgres::Error` surfaced from `Config::connect(...)`
535/// into the most specific `ConnectionError` variant.
536///
537/// Rules:
538/// - OpenSSL error found in the source chain + `mode` is `verify-*` →
539///   [`ConnectionError::TlsVerification`] (with `hostname_suffix` if the
540///   OpenSSL message names a hostname / IP mismatch).
541/// - `mode` is `require` / `verify-*` and the error message indicates the
542///   server refused TLS → [`ConnectionError::TlsRequiredNotSupported`].
543/// - Otherwise → [`ConnectionError::Connect`].
544fn classify_connect_error(
545    source: tokio_postgres::Error,
546    profile: &Profile,
547    mode: SslMode,
548) -> ConnectionError {
549    // Caller has already gone through `require_host()` to attempt the
550    // connection that produced this error, so `host` must be `Some` here.
551    let host = profile.host.clone().unwrap_or_default();
552    if matches!(mode, SslMode::VerifyCa | SslMode::VerifyFull) {
553        if let Some(ssl_msg) = ssl_error_in_chain(&source) {
554            let hostname_suffix = if ssl_msg.contains("hostname mismatch")
555                || ssl_msg.contains("Hostname mismatch")
556                || ssl_msg.contains("IP address mismatch")
557            {
558                " (hostname mismatch)"
559            } else {
560                ""
561            };
562            return ConnectionError::TlsVerification {
563                host,
564                port: profile.port,
565                hostname_suffix,
566                source,
567            };
568        }
569    }
570
571    if matches!(
572        mode,
573        SslMode::Require | SslMode::VerifyCa | SslMode::VerifyFull
574    ) && message_indicates_tls_refused(&source)
575    {
576        return ConnectionError::TlsRequiredNotSupported {
577            host,
578            port: profile.port,
579            source,
580        };
581    }
582
583    ConnectionError::Connect {
584        host,
585        port: profile.port,
586        source,
587    }
588}
589
590/// Walk the source chain of a `tokio_postgres::Error` and return the string
591/// form of the first `openssl::error::ErrorStack` found.
592fn ssl_error_in_chain(err: &tokio_postgres::Error) -> Option<String> {
593    let mut cur: &(dyn std::error::Error + 'static) = err;
594    while let Some(source) = std::error::Error::source(cur) {
595        if source.is::<openssl::error::ErrorStack>() {
596            return Some(source.to_string());
597        }
598        cur = source;
599    }
600    None
601}
602
603/// Heuristic: does the error look like "server said no to our TLS request"?
604///
605/// `tokio_postgres` surfaces this as an io error or a "server does not
606/// support TLS" message depending on version. We string-match the Display
607/// form because the typed variants are not all public.
608fn message_indicates_tls_refused(err: &tokio_postgres::Error) -> bool {
609    matches_tls_refused_message(&err.to_string())
610}
611
/// Case-sensitive substring test for the messages `tokio_postgres` emits
/// when the server answers the SSL request byte with `'N'`.
///
/// Kept separate from `message_indicates_tls_refused` so the marker list can
/// be unit-tested directly; `tokio_postgres::Error` has no public
/// constructor.
fn matches_tls_refused_message(msg: &str) -> bool {
    const REFUSAL_MARKERS: [&str; 3] = [
        "TLS was required",
        "server does not support TLS",
        "server does not support SSL",
    ];
    REFUSAL_MARKERS.iter().any(|marker| msg.contains(*marker))
}
623
/// Escape `value` so it survives as a single token inside the libpq
/// `options` connection parameter.
///
/// In that string, unescaped spaces separate `-c key=value` tokens and
/// backslash is the escape character. Each space and each backslash is
/// therefore prefixed with a backslash. Every other character passes through
/// unchanged.
fn escape_options_value(value: &str) -> String {
    value
        .chars()
        .fold(String::with_capacity(value.len()), |mut escaped, ch| {
            if ch == '\\' || ch == ' ' {
                escaped.push('\\');
            }
            escaped.push(ch);
            escaped
        })
}
641
642/// Build the inner value of the libpq `options` connection parameter from a
643/// profile's options map.
644///
645/// Produces a space-separated string of `-c key=value` tokens in sorted-key
646/// order, with each value inner-escaped per [`escape_options_value`].
647/// Returns `None` when the map is empty so the caller can omit the fragment.
648pub(crate) fn build_options_string(options: &BTreeMap<String, String>) -> Option<String> {
649    if options.is_empty() {
650        return None;
651    }
652    let joined = options
653        .iter()
654        .map(|(k, v)| format!("-c {k}={}", escape_options_value(v)))
655        .collect::<Vec<_>>()
656        .join(" ");
657    Some(joined)
658}
659
#[cfg(test)]
mod tests {
    //! Unit tests for option encoding, TLS planning/building, error-message
    //! classification, and host-based TLS defaults. None of these tests
    //! open a network connection.

    use super::*;

    // ---- libpq `options` encoding ----

    #[mz_ore::test]
    fn test_escape_options_value_plain() {
        assert_eq!(escape_options_value("prod"), "prod");
    }

    #[mz_ore::test]
    fn test_escape_options_value_space() {
        assert_eq!(escape_options_value("prod cluster"), r"prod\ cluster");
    }

    #[mz_ore::test]
    fn test_escape_options_value_backslash() {
        assert_eq!(escape_options_value(r"a\b"), r"a\\b");
    }

    #[mz_ore::test]
    fn test_escape_options_value_mixed() {
        // Space then backslash
        assert_eq!(escape_options_value(r"a \b"), r"a\ \\b");
    }

    #[mz_ore::test]
    fn test_build_options_string_empty() {
        let options: BTreeMap<String, String> = BTreeMap::new();
        assert_eq!(build_options_string(&options), None);
    }

    #[mz_ore::test]
    fn test_build_options_string_single() {
        let mut options = BTreeMap::new();
        options.insert("cluster".to_string(), "prod".to_string());
        assert_eq!(
            build_options_string(&options),
            Some("-c cluster=prod".to_string())
        );
    }

    #[mz_ore::test]
    fn test_build_options_string_multiple_sorted() {
        let mut options = BTreeMap::new();
        // Insert in reverse order to verify BTreeMap iteration sorts keys.
        options.insert("search_path".to_string(), "public".to_string());
        options.insert("cluster".to_string(), "prod".to_string());
        assert_eq!(
            build_options_string(&options),
            Some("-c cluster=prod -c search_path=public".to_string())
        );
    }

    #[mz_ore::test]
    fn test_build_options_string_escapes_value_space() {
        let mut options = BTreeMap::new();
        options.insert("cluster".to_string(), "prod cluster".to_string());
        assert_eq!(
            build_options_string(&options),
            Some(r"-c cluster=prod\ cluster".to_string())
        );
    }

    #[mz_ore::test]
    fn test_build_options_string_escapes_value_backslash() {
        let mut options = BTreeMap::new();
        options.insert("cluster".to_string(), r"a\b".to_string());
        assert_eq!(
            build_options_string(&options),
            Some(r"-c cluster=a\\b".to_string())
        );
    }

    // ---- TLS planning (pure) ----

    use std::path::Path;

    #[mz_ore::test]
    fn plan_disable_produces_notls() {
        let spec = plan_connector(SslMode::Disable, None, "example.com", &[], |_| false).unwrap();
        assert!(matches!(spec, ConnectorSpec::NoTls));
    }

    #[mz_ore::test]
    fn plan_prefer_and_require_have_verify_none_and_no_ca() {
        for mode in [SslMode::Prefer, SslMode::Require] {
            let spec = plan_connector(mode, None, "example.com", &[], |_| true).unwrap();
            match spec {
                ConnectorSpec::Tls {
                    verify,
                    host_check,
                    ca_source,
                } => {
                    assert_eq!(verify, openssl::ssl::SslVerifyMode::NONE);
                    assert!(host_check.is_none());
                    assert!(matches!(ca_source, CaSource::None));
                }
                ConnectorSpec::NoTls => panic!("expected Tls for {:?}, got NoTls", mode),
            }
        }
    }

    #[mz_ore::test]
    fn plan_verify_ca_has_peer_verify_no_host_check() {
        let spec = plan_connector(
            SslMode::VerifyCa,
            None,
            "example.com",
            &[Path::new("/does/not/exist"), Path::new("/tmp/fake-ca.pem")],
            |p| p == Path::new("/tmp/fake-ca.pem"),
        )
        .unwrap();
        match spec {
            ConnectorSpec::Tls {
                verify,
                host_check,
                ca_source,
            } => {
                assert_eq!(verify, openssl::ssl::SslVerifyMode::PEER);
                assert!(host_check.is_none());
                assert!(
                    matches!(ca_source, CaSource::Hunted(p) if p == Path::new("/tmp/fake-ca.pem"))
                );
            }
            ConnectorSpec::NoTls => panic!("expected Tls, got NoTls"),
        }
    }

    #[mz_ore::test]
    fn plan_verify_full_dns_host_check() {
        let spec = plan_connector(
            SslMode::VerifyFull,
            None,
            "example.com",
            &[Path::new("/tmp/fake-ca.pem")],
            |_| true,
        )
        .unwrap();
        match spec {
            ConnectorSpec::Tls {
                host_check: Some(HostCheck::Dns(ref name)),
                ..
            } => assert_eq!(name, "example.com"),
            other => panic!("expected Tls with Dns host check, got {:?}", other),
        }
    }

    #[mz_ore::test]
    fn plan_verify_full_ip_host_check() {
        let spec = plan_connector(
            SslMode::VerifyFull,
            None,
            "10.0.0.5",
            &[Path::new("/tmp/fake-ca.pem")],
            |_| true,
        )
        .unwrap();
        match spec {
            ConnectorSpec::Tls {
                host_check: Some(HostCheck::Ip(ip)),
                ..
            } => assert_eq!(ip, "10.0.0.5".parse::<std::net::IpAddr>().unwrap()),
            other => panic!("expected Tls with Ip host check, got {:?}", other),
        }
    }

    #[mz_ore::test]
    fn plan_explicit_sslrootcert_wins_over_hunt() {
        let explicit = std::path::PathBuf::from("/my/ca.pem");
        let spec = plan_connector(
            SslMode::VerifyCa,
            Some(&explicit),
            "example.com",
            &[Path::new("/tmp/should-be-ignored.pem")],
            |p| p == explicit.as_path(),
        )
        .unwrap();
        match spec {
            ConnectorSpec::Tls {
                ca_source: CaSource::Explicit(p),
                ..
            } => assert_eq!(p, explicit),
            other => panic!("expected Tls/Explicit, got {:?}", other),
        }
    }

    #[mz_ore::test]
    fn plan_explicit_sslrootcert_missing_is_ca_not_found() {
        let explicit = std::path::PathBuf::from("/no/such/file.pem");
        let err = plan_connector(
            SslMode::VerifyCa,
            Some(&explicit),
            "example.com",
            &[Path::new("/tmp/fake-ca.pem")],
            |_| false,
        )
        .unwrap_err();
        assert!(matches!(err, ConnectionError::TlsCaNotFound));
    }

    #[mz_ore::test]
    fn plan_no_ca_sources_at_all_falls_back_to_default_verify_paths() {
        let spec = plan_connector(
            SslMode::VerifyFull,
            None,
            "example.com",
            &[Path::new("/nope1"), Path::new("/nope2")],
            |_| false,
        )
        .unwrap();
        match spec {
            ConnectorSpec::Tls {
                ca_source: CaSource::DefaultVerifyPaths,
                ..
            } => {}
            other => panic!("expected Tls/DefaultVerifyPaths, got {:?}", other),
        }
    }

    // ---- TLS connector building (touches OpenSSL) ----

    #[mz_ore::test]
    fn build_disable_returns_notls() {
        let connector = build_connector(ConnectorSpec::NoTls).unwrap();
        assert!(matches!(connector, Connector::NoTls));
    }

    #[cfg_attr(miri, ignore)] // unsupported operation: can't call foreign function `rust_psm_stack_pointer` on OS `linux`
    #[mz_ore::test]
    fn build_prefer_returns_tls_no_ca_work() {
        let connector = build_connector(ConnectorSpec::Tls {
            verify: openssl::ssl::SslVerifyMode::NONE,
            host_check: None,
            ca_source: CaSource::None,
        })
        .unwrap();
        assert!(matches!(connector, Connector::Tls(_)));
    }

    #[cfg_attr(miri, ignore)] // unsupported operation: can't call foreign function `rust_psm_stack_pointer` on OS `linux`
    #[mz_ore::test]
    fn build_explicit_missing_ca_returns_ca_not_found() {
        let err = build_connector(ConnectorSpec::Tls {
            verify: openssl::ssl::SslVerifyMode::PEER,
            host_check: None,
            ca_source: CaSource::Explicit(std::path::PathBuf::from("/absolutely/not/a/real/file")),
        })
        .unwrap_err();
        assert!(matches!(err, ConnectionError::TlsCaNotFound));
    }

    // ---- TLS-refused message classification ----

    #[mz_ore::test]
    fn matches_tls_refused_tls_was_required() {
        assert!(matches_tls_refused_message(
            "some prefix: TLS was required but not provided"
        ));
    }

    #[mz_ore::test]
    fn matches_tls_refused_does_not_support_tls() {
        assert!(matches_tls_refused_message(
            "error: server does not support TLS"
        ));
    }

    #[mz_ore::test]
    fn matches_tls_refused_does_not_support_ssl() {
        assert!(matches_tls_refused_message(
            "error: server does not support SSL"
        ));
    }

    #[mz_ore::test]
    fn matches_tls_refused_unrelated_message() {
        assert!(!matches_tls_refused_message("connection refused"));
        assert!(!matches_tls_refused_message("database does not exist"));
        assert!(!matches_tls_refused_message(""));
    }

    // ---- Host-based TLS defaults ----

    #[mz_ore::test]
    fn loopback_hosts_are_recognized() {
        assert!(is_loopback_host("localhost"));
        assert!(is_loopback_host("127.0.0.1"));
        assert!(is_loopback_host("127.1.2.3"));
        assert!(is_loopback_host("::1"));
        assert!(is_loopback_host("[::1]"));
    }

    #[mz_ore::test]
    fn non_loopback_hosts_are_rejected() {
        assert!(!is_loopback_host("example.com"));
        assert!(!is_loopback_host("localhost.example.com"));
        assert!(!is_loopback_host("10.0.0.5"));
        assert!(!is_loopback_host("[::1"));
        assert!(!is_loopback_host(""));
    }

    #[mz_ore::test]
    fn default_sslmode_prefers_loopback_and_requires_elsewhere() {
        assert!(matches!(default_sslmode("localhost"), SslMode::Prefer));
        assert!(matches!(default_sslmode("127.0.0.1"), SslMode::Prefer));
        assert!(matches!(default_sslmode("[::1]"), SslMode::Prefer));
        assert!(matches!(default_sslmode("example.com"), SslMode::Require));
        assert!(matches!(default_sslmode("10.0.0.5"), SslMode::Require));
    }

    #[mz_ore::test]
    fn tokio_ssl_mode_collapses_verify_modes_to_require() {
        use tokio_postgres::config::SslMode as TokioMode;
        assert!(matches!(tokio_ssl_mode(SslMode::Disable), TokioMode::Disable));
        assert!(matches!(tokio_ssl_mode(SslMode::Prefer), TokioMode::Prefer));
        assert!(matches!(tokio_ssl_mode(SslMode::Require), TokioMode::Require));
        assert!(matches!(tokio_ssl_mode(SslMode::VerifyCa), TokioMode::Require));
        assert!(matches!(
            tokio_ssl_mode(SslMode::VerifyFull),
            TokioMode::Require
        ));
    }
}