mz_storage_types/
connections.rs

1// Copyright Materialize, Inc. and contributors. All rights reserved.
2//
3// Use of this software is governed by the Business Source License
4// included in the LICENSE file.
5//
6// As of the Change Date specified in that file, in accordance with
7// the Business Source License, use of this software will be governed
8// by the Apache License, Version 2.0.
9
10//! Connection types.
11
12use std::borrow::Cow;
13use std::collections::{BTreeMap, BTreeSet};
14use std::net::SocketAddr;
15use std::sync::Arc;
16
17use anyhow::{Context, anyhow, bail};
18use iceberg::{Catalog, CatalogBuilder};
19use iceberg_catalog_rest::{
20    REST_CATALOG_PROP_URI, REST_CATALOG_PROP_WAREHOUSE, RestCatalogBuilder,
21};
22use itertools::Itertools;
23use mz_ccsr::tls::{Certificate, Identity};
24use mz_cloud_resources::{AwsExternalIdPrefix, CloudResourceReader, vpc_endpoint_host};
25use mz_dyncfg::ConfigSet;
26use mz_kafka_util::client::{
27    BrokerAddr, BrokerRewrite, MzClientContext, MzKafkaError, TunnelConfig, TunnelingClientContext,
28};
29use mz_ore::assert_none;
30use mz_ore::error::ErrorExt;
31use mz_ore::future::{InTask, OreFutureExt};
32use mz_ore::netio::DUMMY_DNS_PORT;
33use mz_ore::netio::resolve_address;
34use mz_ore::num::NonNeg;
35use mz_postgres_util::tunnel::PostgresFlavor;
36use mz_proto::tokio_postgres::any_ssl_mode;
37use mz_proto::{IntoRustIfSome, ProtoType, RustType, TryFromProtoError};
38use mz_repr::url::any_url;
39use mz_repr::{CatalogItemId, GlobalId};
40use mz_secrets::SecretsReader;
41use mz_ssh_util::keys::SshKeyPair;
42use mz_ssh_util::tunnel::SshTunnelConfig;
43use mz_ssh_util::tunnel_manager::{ManagedSshTunnelHandle, SshTunnelManager};
44use mz_tls_util::Pkcs12Archive;
45use mz_tracing::CloneableEnvFilter;
46use proptest::prelude::{Arbitrary, BoxedStrategy};
47use proptest::strategy::Strategy;
48use proptest_derive::Arbitrary;
49use rdkafka::ClientContext;
50use rdkafka::config::FromClientConfigAndContext;
51use rdkafka::consumer::{BaseConsumer, Consumer};
52use regex::Regex;
53use serde::{Deserialize, Deserializer, Serialize};
54use tokio::net;
55use tokio::runtime::Handle;
56use tokio_postgres::config::SslMode;
57use tracing::{debug, warn};
58use url::Url;
59
60use crate::AlterCompatible;
61use crate::configuration::StorageConfiguration;
62use crate::connections::aws::{
63    AwsConnection, AwsConnectionReference, AwsConnectionValidationError,
64};
65use crate::connections::string_or_secret::StringOrSecret;
66use crate::controller::AlterError;
67use crate::dyncfgs::{
68    ENFORCE_EXTERNAL_ADDRESSES, KAFKA_CLIENT_ID_ENRICHMENT_RULES,
69    KAFKA_DEFAULT_AWS_PRIVATELINK_ENDPOINT_IDENTIFICATION_ALGORITHM,
70};
71use crate::errors::{ContextCreationError, CsrConnectError};
72
73pub mod aws;
74pub mod inline;
75pub mod string_or_secret;
76
77include!(concat!(env!("OUT_DIR"), "/mz_storage_types.connections.rs"));
78
/// Property key under which the OAuth2 scope is handed to the REST catalog
/// builder.
const REST_CATALOG_PROP_SCOPE: &str = "scope";
/// Property key under which the OAuth2 credential is handed to the REST
/// catalog builder.
const REST_CATALOG_PROP_CREDENTIAL: &str = "credential";
81
/// An extension trait for [`SecretsReader`]
///
/// Adds variants of the read methods that take an [`InTask`] flag, so the
/// caller decides whether the read is run in a task.
#[async_trait::async_trait]
trait SecretsReaderExt {
    /// `SecretsReader::read`, but optionally run in a task.
    ///
    /// Returns the raw bytes of the secret identified by `id`.
    async fn read_in_task_if(
        &self,
        in_task: InTask,
        id: CatalogItemId,
    ) -> Result<Vec<u8>, anyhow::Error>;

    /// `SecretsReader::read_string`, but optionally run in a task.
    ///
    /// Returns the secret identified by `id` as a `String`.
    async fn read_string_in_task_if(
        &self,
        in_task: InTask,
        id: CatalogItemId,
    ) -> Result<String, anyhow::Error>;
}
99
100#[async_trait::async_trait]
101impl SecretsReaderExt for Arc<dyn SecretsReader> {
102    async fn read_in_task_if(
103        &self,
104        in_task: InTask,
105        id: CatalogItemId,
106    ) -> Result<Vec<u8>, anyhow::Error> {
107        let sr = Arc::clone(self);
108        async move { sr.read(id).await }
109            .run_in_task_if(in_task, || "secrets_reader_read".to_string())
110            .await
111    }
112    async fn read_string_in_task_if(
113        &self,
114        in_task: InTask,
115        id: CatalogItemId,
116    ) -> Result<String, anyhow::Error> {
117        let sr = Arc::clone(self);
118        async move { sr.read_string(id).await }
119            .run_in_task_if(in_task, || "secrets_reader_read".to_string())
120            .await
121    }
122}
123
/// Extra context to pass through when instantiating a connection for a source
/// or sink.
///
/// Should be kept cheaply cloneable.
#[derive(Debug, Clone)]
pub struct ConnectionContext {
    /// An opaque identifier for the environment in which this process is
    /// running.
    ///
    /// The storage layer is intentionally unaware of the structure within this
    /// identifier. Higher layers of the stack can make use of that structure,
    /// but the storage layer should be oblivious to it.
    pub environment_id: String,
    /// The level for librdkafka's logs.
    pub librdkafka_log_level: tracing::Level,
    /// A prefix for an external ID to use for all AWS AssumeRole operations.
    pub aws_external_id_prefix: Option<AwsExternalIdPrefix>,
    /// The ARN for a Materialize-controlled role to assume before assuming
    /// a customer's requested role for an AWS connection.
    pub aws_connection_role_arn: Option<String>,
    /// A secrets reader, used to fetch the secret values that connections
    /// refer to.
    pub secrets_reader: Arc<dyn SecretsReader>,
    /// A cloud resource reader, if supported in this configuration.
    pub cloud_resource_reader: Option<Arc<dyn CloudResourceReader>>,
    /// A manager for SSH tunnels.
    pub ssh_tunnel_manager: SshTunnelManager,
}
151
152impl ConnectionContext {
153    /// Constructs a new connection context from command line arguments.
154    ///
155    /// **WARNING:** it is critical for security that the `aws_external_id` be
156    /// provided by the operator of the Materialize service (i.e., via a CLI
157    /// argument or environment variable) and not the end user of Materialize
158    /// (e.g., via a configuration option in a SQL statement). See
159    /// [`AwsExternalIdPrefix`] for details.
160    pub fn from_cli_args(
161        environment_id: String,
162        startup_log_level: &CloneableEnvFilter,
163        aws_external_id_prefix: Option<AwsExternalIdPrefix>,
164        aws_connection_role_arn: Option<String>,
165        secrets_reader: Arc<dyn SecretsReader>,
166        cloud_resource_reader: Option<Arc<dyn CloudResourceReader>>,
167    ) -> ConnectionContext {
168        ConnectionContext {
169            environment_id,
170            librdkafka_log_level: mz_ore::tracing::crate_level(
171                &startup_log_level.clone().into(),
172                "librdkafka",
173            ),
174            aws_external_id_prefix,
175            aws_connection_role_arn,
176            secrets_reader,
177            cloud_resource_reader,
178            ssh_tunnel_manager: SshTunnelManager::default(),
179        }
180    }
181
182    /// Constructs a new connection context for usage in tests.
183    pub fn for_tests(secrets_reader: Arc<dyn SecretsReader>) -> ConnectionContext {
184        ConnectionContext {
185            environment_id: "test-environment-id".into(),
186            librdkafka_log_level: tracing::Level::INFO,
187            aws_external_id_prefix: Some(
188                AwsExternalIdPrefix::new_from_cli_argument_or_environment_variable(
189                    "test-aws-external-id-prefix",
190                )
191                .expect("infallible"),
192            ),
193            aws_connection_role_arn: Some(
194                "arn:aws:iam::123456789000:role/MaterializeConnection".into(),
195            ),
196            secrets_reader,
197            cloud_resource_reader: None,
198            ssh_tunnel_manager: SshTunnelManager::default(),
199        }
200    }
201}
202
/// A connection to an external system, one variant per supported kind.
///
/// The type parameter `C` determines how nested connections are represented;
/// it defaults to [`InlinedConnection`].
#[derive(Arbitrary, Clone, Debug, Eq, PartialEq, Hash, Serialize, Deserialize)]
pub enum Connection<C: ConnectionAccess = InlinedConnection> {
    /// A Kafka connection.
    Kafka(KafkaConnection<C>),
    /// A CSR connection.
    Csr(CsrConnection<C>),
    /// A PostgreSQL connection.
    Postgres(PostgresConnection<C>),
    /// An SSH connection.
    Ssh(SshConnection),
    /// An AWS connection.
    Aws(AwsConnection),
    /// An AWS PrivateLink connection.
    AwsPrivatelink(AwsPrivatelinkConnection),
    /// A MySQL connection.
    MySql(MySqlConnection<C>),
    /// A SQL Server connection.
    SqlServer(SqlServerConnectionDetails<C>),
    /// An Iceberg catalog connection.
    IcebergCatalog(IcebergCatalogConnection<C>),
}
215
216impl<R: ConnectionResolver> IntoInlineConnection<Connection, R>
217    for Connection<ReferencedConnection>
218{
219    fn into_inline_connection(self, r: R) -> Connection {
220        match self {
221            Connection::Kafka(kafka) => Connection::Kafka(kafka.into_inline_connection(r)),
222            Connection::Csr(csr) => Connection::Csr(csr.into_inline_connection(r)),
223            Connection::Postgres(pg) => Connection::Postgres(pg.into_inline_connection(r)),
224            Connection::Ssh(ssh) => Connection::Ssh(ssh),
225            Connection::Aws(aws) => Connection::Aws(aws),
226            Connection::AwsPrivatelink(awspl) => Connection::AwsPrivatelink(awspl),
227            Connection::MySql(mysql) => Connection::MySql(mysql.into_inline_connection(r)),
228            Connection::SqlServer(sql_server) => {
229                Connection::SqlServer(sql_server.into_inline_connection(r))
230            }
231            Connection::IcebergCatalog(iceberg) => {
232                Connection::IcebergCatalog(iceberg.into_inline_connection(r))
233            }
234        }
235    }
236}
237
238impl<C: ConnectionAccess> Connection<C> {
239    /// Whether this connection should be validated by default on creation.
240    pub fn validate_by_default(&self) -> bool {
241        match self {
242            Connection::Kafka(conn) => conn.validate_by_default(),
243            Connection::Csr(conn) => conn.validate_by_default(),
244            Connection::Postgres(conn) => conn.validate_by_default(),
245            Connection::Ssh(conn) => conn.validate_by_default(),
246            Connection::Aws(conn) => conn.validate_by_default(),
247            Connection::AwsPrivatelink(conn) => conn.validate_by_default(),
248            Connection::MySql(conn) => conn.validate_by_default(),
249            Connection::SqlServer(conn) => conn.validate_by_default(),
250            Connection::IcebergCatalog(conn) => conn.validate_by_default(),
251        }
252    }
253}
254
255impl Connection<InlinedConnection> {
256    /// Validates this connection by attempting to connect to the upstream system.
257    pub async fn validate(
258        &self,
259        id: CatalogItemId,
260        storage_configuration: &StorageConfiguration,
261    ) -> Result<(), ConnectionValidationError> {
262        match self {
263            Connection::Kafka(conn) => conn.validate(id, storage_configuration).await?,
264            Connection::Csr(conn) => conn.validate(id, storage_configuration).await?,
265            Connection::Postgres(conn) => conn.validate(id, storage_configuration).await?,
266            Connection::Ssh(conn) => conn.validate(id, storage_configuration).await?,
267            Connection::Aws(conn) => conn.validate(id, storage_configuration).await?,
268            Connection::AwsPrivatelink(conn) => conn.validate(id, storage_configuration).await?,
269            Connection::MySql(conn) => conn.validate(id, storage_configuration).await?,
270            Connection::SqlServer(conn) => conn.validate(id, storage_configuration).await?,
271            Connection::IcebergCatalog(conn) => conn.validate(id, storage_configuration).await?,
272        }
273        Ok(())
274    }
275
276    pub fn unwrap_kafka(self) -> <InlinedConnection as ConnectionAccess>::Kafka {
277        match self {
278            Self::Kafka(conn) => conn,
279            o => unreachable!("{o:?} is not a Kafka connection"),
280        }
281    }
282
283    pub fn unwrap_pg(self) -> <InlinedConnection as ConnectionAccess>::Pg {
284        match self {
285            Self::Postgres(conn) => conn,
286            o => unreachable!("{o:?} is not a Postgres connection"),
287        }
288    }
289
290    pub fn unwrap_mysql(self) -> <InlinedConnection as ConnectionAccess>::MySql {
291        match self {
292            Self::MySql(conn) => conn,
293            o => unreachable!("{o:?} is not a MySQL connection"),
294        }
295    }
296
297    pub fn unwrap_sql_server(self) -> <InlinedConnection as ConnectionAccess>::SqlServer {
298        match self {
299            Self::SqlServer(conn) => conn,
300            o => unreachable!("{o:?} is not a SQL Server connection"),
301        }
302    }
303
304    pub fn unwrap_aws(self) -> <InlinedConnection as ConnectionAccess>::Aws {
305        match self {
306            Self::Aws(conn) => conn,
307            o => unreachable!("{o:?} is not an AWS connection"),
308        }
309    }
310
311    pub fn unwrap_ssh(self) -> <InlinedConnection as ConnectionAccess>::Ssh {
312        match self {
313            Self::Ssh(conn) => conn,
314            o => unreachable!("{o:?} is not an SSH connection"),
315        }
316    }
317
318    pub fn unwrap_csr(self) -> <InlinedConnection as ConnectionAccess>::Csr {
319        match self {
320            Self::Csr(conn) => conn,
321            o => unreachable!("{o:?} is not a Kafka connection"),
322        }
323    }
324}
325
/// An error returned by [`Connection::validate`].
#[derive(thiserror::Error, Debug)]
pub enum ConnectionValidationError {
    /// An AWS connection validation error; displayed exactly as the wrapped
    /// error displays.
    #[error(transparent)]
    Aws(#[from] AwsConnectionValidationError),
    /// Any other validation failure; displayed together with its causes.
    #[error("{}", .0.display_with_causes())]
    Other(#[from] anyhow::Error),
}
334
335impl ConnectionValidationError {
336    /// Reports additional details about the error, if any are available.
337    pub fn detail(&self) -> Option<String> {
338        match self {
339            ConnectionValidationError::Aws(e) => e.detail(),
340            ConnectionValidationError::Other(_) => None,
341        }
342    }
343
344    /// Reports a hint for the user about how the error could be fixed.
345    pub fn hint(&self) -> Option<String> {
346        match self {
347            ConnectionValidationError::Aws(e) => e.hint(),
348            ConnectionValidationError::Other(_) => None,
349        }
350    }
351}
352
353impl<C: ConnectionAccess> AlterCompatible for Connection<C> {
354    fn alter_compatible(&self, id: GlobalId, other: &Self) -> Result<(), AlterError> {
355        match (self, other) {
356            (Self::Aws(s), Self::Aws(o)) => s.alter_compatible(id, o),
357            (Self::AwsPrivatelink(s), Self::AwsPrivatelink(o)) => s.alter_compatible(id, o),
358            (Self::Ssh(s), Self::Ssh(o)) => s.alter_compatible(id, o),
359            (Self::Csr(s), Self::Csr(o)) => s.alter_compatible(id, o),
360            (Self::Kafka(s), Self::Kafka(o)) => s.alter_compatible(id, o),
361            (Self::Postgres(s), Self::Postgres(o)) => s.alter_compatible(id, o),
362            (Self::MySql(s), Self::MySql(o)) => s.alter_compatible(id, o),
363            _ => {
364                tracing::warn!(
365                    "Connection incompatible:\nself:\n{:#?}\n\nother\n{:#?}",
366                    self,
367                    other
368                );
369                Err(AlterError { id })
370            }
371        }
372    }
373}
374
/// Settings specific to an Iceberg REST catalog.
#[derive(Arbitrary, Clone, Debug, Eq, PartialEq, Hash, Serialize, Deserialize)]
pub struct RestIcebergCatalog {
    /// For REST catalogs, the oauth2 credential in a `CLIENT_ID:CLIENT_SECRET` format
    pub credential: StringOrSecret,
    /// The oauth2 scope for REST catalogs, if any
    pub scope: Option<String>,
    /// The warehouse for REST catalogs, if any
    pub warehouse: Option<String>,
}
384
/// Settings specific to an s3tables Iceberg catalog, which is reached through
/// an AWS connection.
#[derive(Arbitrary, Clone, Debug, Eq, PartialEq, Hash, Serialize, Deserialize)]
pub struct S3TablesRestIcebergCatalog<C: ConnectionAccess = InlinedConnection> {
    /// The AWS connection details, for s3tables
    pub aws_connection: AwsConnectionReference<C>,
    /// The warehouse for s3tables
    pub warehouse: String,
}
392
393impl<R: ConnectionResolver> IntoInlineConnection<S3TablesRestIcebergCatalog, R>
394    for S3TablesRestIcebergCatalog<ReferencedConnection>
395{
396    fn into_inline_connection(self, r: R) -> S3TablesRestIcebergCatalog {
397        S3TablesRestIcebergCatalog {
398            aws_connection: self.aws_connection.into_inline_connection(&r),
399            warehouse: self.warehouse,
400        }
401    }
402}
403
/// The kinds of Iceberg catalog; one variant per [`IcebergCatalogImpl`]
/// variant, without the associated settings.
#[derive(Arbitrary, Clone, Debug, Eq, PartialEq, Hash, Serialize, Deserialize)]
pub enum IcebergCatalogType {
    /// A REST catalog.
    Rest,
    /// An s3tables REST catalog.
    S3TablesRest,
}
409
/// The kind of Iceberg catalog together with its kind-specific settings.
#[derive(Arbitrary, Clone, Debug, Eq, PartialEq, Hash, Serialize, Deserialize)]
pub enum IcebergCatalogImpl<C: ConnectionAccess = InlinedConnection> {
    /// A REST catalog authenticated with an oauth2 credential.
    Rest(RestIcebergCatalog),
    /// An s3tables REST catalog reached through an AWS connection.
    S3TablesRest(S3TablesRestIcebergCatalog<C>),
}
415
416impl<R: ConnectionResolver> IntoInlineConnection<IcebergCatalogImpl, R>
417    for IcebergCatalogImpl<ReferencedConnection>
418{
419    fn into_inline_connection(self, r: R) -> IcebergCatalogImpl {
420        match self {
421            IcebergCatalogImpl::Rest(rest) => IcebergCatalogImpl::Rest(rest),
422            IcebergCatalogImpl::S3TablesRest(s3tables) => {
423                IcebergCatalogImpl::S3TablesRest(s3tables.into_inline_connection(r))
424            }
425        }
426    }
427}
428
/// A connection to an Iceberg catalog.
#[derive(Clone, Debug, Eq, PartialEq, Hash, Serialize, Deserialize)]
pub struct IcebergCatalogConnection<C: ConnectionAccess = InlinedConnection> {
    /// The kind of catalog and its kind-specific settings
    pub catalog: IcebergCatalogImpl<C>,
    /// Where the catalog is located
    pub uri: reqwest::Url,
}
436
437impl<C: ConnectionAccess> Arbitrary for IcebergCatalogConnection<C> {
438    type Parameters = ();
439    type Strategy = BoxedStrategy<Self>;
440
441    fn arbitrary_with((): Self::Parameters) -> Self::Strategy {
442        (
443            IcebergCatalogImpl::arbitrary(),
444            proptest::sample::select(vec![
445                reqwest::Url::parse("https://example.com/catalog").unwrap(),
446                reqwest::Url::parse("https://catalog.example.org").unwrap(),
447            ]),
448        )
449            .prop_map(|(catalog, uri)| IcebergCatalogConnection { catalog, uri })
450            .boxed()
451    }
452}
453
454impl AlterCompatible for IcebergCatalogConnection {
455    fn alter_compatible(&self, id: GlobalId, _other: &Self) -> Result<(), AlterError> {
456        Err(AlterError { id })
457    }
458}
459
460impl<R: ConnectionResolver> IntoInlineConnection<IcebergCatalogConnection, R>
461    for IcebergCatalogConnection<ReferencedConnection>
462{
463    fn into_inline_connection(self, r: R) -> IcebergCatalogConnection {
464        IcebergCatalogConnection {
465            catalog: self.catalog.into_inline_connection(&r),
466            uri: self.uri,
467        }
468    }
469}
470
471impl<C: ConnectionAccess> IcebergCatalogConnection<C> {
472    fn validate_by_default(&self) -> bool {
473        true
474    }
475}
476
477impl IcebergCatalogConnection<InlinedConnection> {
478    async fn connect(
479        &self,
480        storage_configuration: &StorageConfiguration,
481        in_task: InTask,
482    ) -> Result<Arc<dyn Catalog>, anyhow::Error> {
483        match self.catalog {
484            IcebergCatalogImpl::S3TablesRest(ref s3tables) => {
485                self.connect_s3tables(s3tables, storage_configuration, in_task)
486                    .await
487            }
488            IcebergCatalogImpl::Rest(ref rest) => {
489                self.connect_rest(rest, storage_configuration, in_task)
490                    .await
491            }
492        }
493    }
494
495    async fn connect_s3tables(
496        &self,
497        s3tables: &S3TablesRestIcebergCatalog,
498        storage_configuration: &StorageConfiguration,
499        in_task: InTask,
500    ) -> Result<Arc<dyn Catalog>, anyhow::Error> {
501        let mut props = BTreeMap::from([(
502            REST_CATALOG_PROP_URI.to_string(),
503            self.uri.to_string().clone(),
504        )]);
505
506        let aws_ref = &s3tables.aws_connection;
507        let aws_config = aws_ref
508            .connection
509            .load_sdk_config(
510                &storage_configuration.connection_context,
511                aws_ref.connection_id,
512                in_task,
513            )
514            .await?;
515
516        props.insert(
517            REST_CATALOG_PROP_WAREHOUSE.to_string(),
518            s3tables.warehouse.clone(),
519        );
520
521        let catalog = RestCatalogBuilder::default()
522            .with_aws_client(aws_config)
523            .load("IcebergCatalog", props.into_iter().collect())
524            .await
525            .map_err(|e| anyhow!("failed to create Iceberg catalog: {e}"))?;
526        Ok(Arc::new(catalog))
527    }
528
529    async fn connect_rest(
530        &self,
531        rest: &RestIcebergCatalog,
532        storage_configuration: &StorageConfiguration,
533        in_task: InTask,
534    ) -> Result<Arc<dyn Catalog>, anyhow::Error> {
535        let mut props = BTreeMap::from([(
536            REST_CATALOG_PROP_URI.to_string(),
537            self.uri.to_string().clone(),
538        )]);
539
540        if let Some(warehouse) = &rest.warehouse {
541            props.insert(REST_CATALOG_PROP_WAREHOUSE.to_string(), warehouse.clone());
542        }
543
544        let credential = rest
545            .credential
546            .get_string(
547                in_task,
548                &storage_configuration.connection_context.secrets_reader,
549            )
550            .await
551            .map_err(|e| anyhow!("failed to read Iceberg catalog credential: {e}"))?;
552        props.insert(REST_CATALOG_PROP_CREDENTIAL.to_string(), credential);
553
554        if let Some(scope) = &rest.scope {
555            props.insert(REST_CATALOG_PROP_SCOPE.to_string(), scope.clone());
556        }
557
558        let catalog = RestCatalogBuilder::default()
559            .load("IcebergCatalog", props.into_iter().collect())
560            .await
561            .map_err(|e| anyhow!("failed to create Iceberg catalog: {e}"))?;
562        Ok(Arc::new(catalog))
563    }
564
565    async fn validate(
566        &self,
567        _id: CatalogItemId,
568        storage_configuration: &StorageConfiguration,
569    ) -> Result<(), ConnectionValidationError> {
570        let catalog = self
571            .connect(storage_configuration, InTask::No)
572            .await
573            .map_err(|e| {
574                ConnectionValidationError::Other(anyhow!("failed to connect to catalog: {e}"))
575            })?;
576
577        // If we can list namespaces, the connection is valid.
578        catalog.list_namespaces(None).await.map_err(|e| {
579            ConnectionValidationError::Other(anyhow!("failed to list namespaces: {e}"))
580        })?;
581
582        Ok(())
583    }
584}
585
/// An AWS PrivateLink connection.
#[derive(Arbitrary, Clone, Debug, Eq, PartialEq, Hash, Serialize, Deserialize)]
pub struct AwsPrivatelinkConnection {
    /// The service name. NOTE(review): presumably the name of the VPC endpoint
    /// service to connect to; confirm.
    pub service_name: String,
    /// The availability zones associated with the connection.
    pub availability_zones: Vec<String>,
}
591
592impl AlterCompatible for AwsPrivatelinkConnection {
593    fn alter_compatible(&self, _id: GlobalId, _other: &Self) -> Result<(), AlterError> {
594        // Every element of the AwsPrivatelinkConnection connection is configurable.
595        Ok(())
596    }
597}
598
/// TLS settings for a [`KafkaConnection`].
#[derive(Arbitrary, Clone, Debug, Eq, PartialEq, Hash, Serialize, Deserialize)]
pub struct KafkaTlsConfig {
    /// The TLS identity to use, if any. NOTE(review): presumably the client
    /// certificate and key; confirm.
    pub identity: Option<TlsIdentity>,
    /// The root certificate to use, if any.
    pub root_cert: Option<StringOrSecret>,
}
604
/// SASL settings for a [`KafkaConnection`].
#[derive(Clone, Debug, Eq, PartialEq, Hash, Serialize, Deserialize, Arbitrary)]
pub struct KafkaSaslConfig<C: ConnectionAccess = InlinedConnection> {
    /// The name of the SASL mechanism.
    pub mechanism: String,
    /// The SASL username, given inline or as a secret.
    pub username: StringOrSecret,
    /// The catalog item ID for the password, if any. NOTE(review): presumably
    /// the ID of a secret; confirm.
    pub password: Option<CatalogItemId>,
    /// A reference to an AWS connection, if any.
    pub aws: Option<AwsConnectionReference<C>>,
}
612
613impl<R: ConnectionResolver> IntoInlineConnection<KafkaSaslConfig, R>
614    for KafkaSaslConfig<ReferencedConnection>
615{
616    fn into_inline_connection(self, r: R) -> KafkaSaslConfig {
617        KafkaSaslConfig {
618            mechanism: self.mechanism,
619            username: self.username,
620            password: self.password,
621            aws: self.aws.map(|aws| aws.into_inline_connection(&r)),
622        }
623    }
624}
625
/// Specifies a Kafka broker in a [`KafkaConnection`].
#[derive(Arbitrary, Clone, Debug, Eq, PartialEq, Hash, Serialize, Deserialize)]
pub struct KafkaBroker<C: ConnectionAccess = InlinedConnection> {
    /// The address of the Kafka broker.
    pub address: String,
    /// The tunnel to use when connecting to the broker.
    ///
    /// NOTE(review): the field is not an `Option`, so "no tunnel" is
    /// presumably expressed by a [`Tunnel`] variant; confirm.
    pub tunnel: Tunnel<C>,
}
634
635impl<R: ConnectionResolver> IntoInlineConnection<KafkaBroker, R>
636    for KafkaBroker<ReferencedConnection>
637{
638    fn into_inline_connection(self, r: R) -> KafkaBroker {
639        let KafkaBroker { address, tunnel } = self;
640        KafkaBroker {
641            address,
642            tunnel: tunnel.into_inline_connection(r),
643        }
644    }
645}
646
647impl RustType<ProtoKafkaBroker> for KafkaBroker {
648    fn into_proto(&self) -> ProtoKafkaBroker {
649        ProtoKafkaBroker {
650            address: self.address.into_proto(),
651            tunnel: Some(self.tunnel.into_proto()),
652        }
653    }
654
655    fn from_proto(proto: ProtoKafkaBroker) -> Result<Self, TryFromProtoError> {
656        Ok(KafkaBroker {
657            address: proto.address.into_rust()?,
658            tunnel: proto
659                .tunnel
660                .into_rust_if_some("ProtoKafkaConnection::tunnel")?,
661        })
662    }
663}
664
/// Options for creating a Kafka topic.
///
/// The `Default` value leaves both counts unset and the topic configuration
/// empty.
#[derive(Arbitrary, Clone, Debug, Eq, PartialEq, Hash, Serialize, Deserialize, Default)]
pub struct KafkaTopicOptions {
    /// The replication factor for the topic.
    /// If `None`, the broker default will be used.
    pub replication_factor: Option<NonNeg<i32>>,
    /// The number of partitions to create.
    /// If `None`, the broker default will be used.
    pub partition_count: Option<NonNeg<i32>>,
    /// The initial configuration parameters for the topic.
    pub topic_config: BTreeMap<String, String>,
}
676
677impl RustType<ProtoKafkaTopicOptions> for KafkaTopicOptions {
678    fn into_proto(&self) -> ProtoKafkaTopicOptions {
679        ProtoKafkaTopicOptions {
680            replication_factor: self.replication_factor.map(|f| *f),
681            partition_count: self.partition_count.map(|f| *f),
682            topic_config: self.topic_config.clone(),
683        }
684    }
685
686    fn from_proto(proto: ProtoKafkaTopicOptions) -> Result<Self, TryFromProtoError> {
687        Ok(KafkaTopicOptions {
688            replication_factor: proto.replication_factor.map(NonNeg::try_from).transpose()?,
689            partition_count: proto.partition_count.map(NonNeg::try_from).transpose()?,
690            topic_config: proto.topic_config,
691        })
692    }
693}
694
/// A connection to a Kafka cluster.
#[derive(Arbitrary, Clone, Debug, Eq, PartialEq, Hash, Serialize, Deserialize)]
pub struct KafkaConnection<C: ConnectionAccess = InlinedConnection> {
    /// The brokers to connect to.
    pub brokers: Vec<KafkaBroker<C>>,
    /// A tunnel through which to route traffic,
    /// that can be overridden for individual brokers
    /// in `brokers`.
    pub default_tunnel: Tunnel<C>,
    /// The name of the progress topic. If `None`, a name is derived from the
    /// environment ID and the connection ID.
    pub progress_topic: Option<String>,
    /// The topic options for the progress topic.
    pub progress_topic_options: KafkaTopicOptions,
    /// Additional client options; they form the starting set of options when
    /// a client is created.
    pub options: BTreeMap<String, StringOrSecret>,
    /// TLS settings, if any.
    pub tls: Option<KafkaTlsConfig>,
    /// SASL settings, if any.
    pub sasl: Option<KafkaSaslConfig<C>>,
}
708
709impl<R: ConnectionResolver> IntoInlineConnection<KafkaConnection, R>
710    for KafkaConnection<ReferencedConnection>
711{
712    fn into_inline_connection(self, r: R) -> KafkaConnection {
713        let KafkaConnection {
714            brokers,
715            progress_topic,
716            progress_topic_options,
717            default_tunnel,
718            options,
719            tls,
720            sasl,
721        } = self;
722
723        let brokers = brokers
724            .into_iter()
725            .map(|broker| broker.into_inline_connection(&r))
726            .collect();
727
728        KafkaConnection {
729            brokers,
730            progress_topic,
731            progress_topic_options,
732            default_tunnel: default_tunnel.into_inline_connection(&r),
733            options,
734            tls,
735            sasl: sasl.map(|sasl| sasl.into_inline_connection(&r)),
736        }
737    }
738}
739
740impl<C: ConnectionAccess> KafkaConnection<C> {
741    /// Returns the name of the progress topic to use for the connection.
742    ///
743    /// The caller is responsible for providing the connection ID as it is not
744    /// known to `KafkaConnection`.
745    pub fn progress_topic(
746        &self,
747        connection_context: &ConnectionContext,
748        connection_id: CatalogItemId,
749    ) -> Cow<'_, str> {
750        if let Some(progress_topic) = &self.progress_topic {
751            Cow::Borrowed(progress_topic)
752        } else {
753            Cow::Owned(format!(
754                "_materialize-progress-{}-{}",
755                connection_context.environment_id, connection_id,
756            ))
757        }
758    }
759
760    fn validate_by_default(&self) -> bool {
761        true
762    }
763}
764
765impl KafkaConnection {
766    /// Generates a string that can be used as the base for a configuration ID
767    /// (e.g., `client.id`, `group.id`, `transactional.id`) for a Kafka source
768    /// or sink.
769    pub fn id_base(
770        connection_context: &ConnectionContext,
771        connection_id: CatalogItemId,
772        object_id: GlobalId,
773    ) -> String {
774        format!(
775            "materialize-{}-{}-{}",
776            connection_context.environment_id, connection_id, object_id,
777        )
778    }
779
780    /// Enriches the provided `client_id` according to any enrichment rules in
781    /// the `kafka_client_id_enrichment_rules` configuration parameter.
782    pub fn enrich_client_id(&self, configs: &ConfigSet, client_id: &mut String) {
783        #[derive(Debug, Deserialize)]
784        struct EnrichmentRule {
785            #[serde(deserialize_with = "deserialize_regex")]
786            pattern: Regex,
787            payload: String,
788        }
789
790        fn deserialize_regex<'de, D>(deserializer: D) -> Result<Regex, D::Error>
791        where
792            D: Deserializer<'de>,
793        {
794            let buf = String::deserialize(deserializer)?;
795            Regex::new(&buf).map_err(serde::de::Error::custom)
796        }
797
798        let rules = KAFKA_CLIENT_ID_ENRICHMENT_RULES.get(configs);
799        let rules = match serde_json::from_value::<Vec<EnrichmentRule>>(rules) {
800            Ok(rules) => rules,
801            Err(e) => {
802                warn!(%e, "failed to decode kafka_client_id_enrichment_rules");
803                return;
804            }
805        };
806
807        // Check every rule against every broker. Rules are matched in the order
808        // that they are specified. It is usually a configuration error if
809        // multiple rules match the same list of Kafka brokers, but we
810        // nonetheless want to provide well defined semantics.
811        debug!(?self.brokers, "evaluating client ID enrichment rules");
812        for rule in rules {
813            let is_match = self
814                .brokers
815                .iter()
816                .any(|b| rule.pattern.is_match(&b.address));
817            debug!(?rule, is_match, "evaluated client ID enrichment rule");
818            if is_match {
819                client_id.push('-');
820                client_id.push_str(&rule.payload);
821            }
822        }
823    }
824
825    /// Creates a Kafka client for the connection.
826    pub async fn create_with_context<C, T>(
827        &self,
828        storage_configuration: &StorageConfiguration,
829        context: C,
830        extra_options: &BTreeMap<&str, String>,
831        in_task: InTask,
832    ) -> Result<T, ContextCreationError>
833    where
834        C: ClientContext,
835        T: FromClientConfigAndContext<TunnelingClientContext<C>>,
836    {
837        let mut options = self.options.clone();
838
839        // Ensure that Kafka topics are *not* automatically created when
840        // consuming, producing, or fetching metadata for a topic. This ensures
841        // that we don't accidentally create topics with the wrong number of
842        // partitions.
843        options.insert("allow.auto.create.topics".into(), "false".into());
844
845        let brokers = match &self.default_tunnel {
846            Tunnel::AwsPrivatelink(t) => {
847                assert!(&self.brokers.is_empty());
848
849                let algo = KAFKA_DEFAULT_AWS_PRIVATELINK_ENDPOINT_IDENTIFICATION_ALGORITHM
850                    .get(storage_configuration.config_set());
851                options.insert("ssl.endpoint.identification.algorithm".into(), algo.into());
852
853                // When using a default privatelink tunnel broker/brokers cannot be specified
854                // instead the tunnel connection_id and port are used for the initial connection.
855                format!(
856                    "{}:{}",
857                    vpc_endpoint_host(
858                        t.connection_id,
859                        None, // Default tunnel does not support availability zones.
860                    ),
861                    t.port.unwrap_or(9092)
862                )
863            }
864            _ => self.brokers.iter().map(|b| &b.address).join(","),
865        };
866        options.insert("bootstrap.servers".into(), brokers.into());
867        let security_protocol = match (self.tls.is_some(), self.sasl.is_some()) {
868            (false, false) => "PLAINTEXT",
869            (true, false) => "SSL",
870            (false, true) => "SASL_PLAINTEXT",
871            (true, true) => "SASL_SSL",
872        };
873        options.insert("security.protocol".into(), security_protocol.into());
874        if let Some(tls) = &self.tls {
875            if let Some(root_cert) = &tls.root_cert {
876                options.insert("ssl.ca.pem".into(), root_cert.clone());
877            }
878            if let Some(identity) = &tls.identity {
879                options.insert("ssl.key.pem".into(), StringOrSecret::Secret(identity.key));
880                options.insert("ssl.certificate.pem".into(), identity.cert.clone());
881            }
882        }
883        if let Some(sasl) = &self.sasl {
884            options.insert("sasl.mechanisms".into(), (&sasl.mechanism).into());
885            options.insert("sasl.username".into(), sasl.username.clone());
886            if let Some(password) = sasl.password {
887                options.insert("sasl.password".into(), StringOrSecret::Secret(password));
888            }
889        }
890
891        let mut config = mz_kafka_util::client::create_new_client_config(
892            storage_configuration
893                .connection_context
894                .librdkafka_log_level,
895            storage_configuration.parameters.kafka_timeout_config,
896        );
897        for (k, v) in options {
898            config.set(
899                k,
900                v.get_string(
901                    in_task,
902                    &storage_configuration.connection_context.secrets_reader,
903                )
904                .await
905                .context("reading kafka secret")?,
906            );
907        }
908        for (k, v) in extra_options {
909            config.set(*k, v);
910        }
911
912        let aws_config = match self.sasl.as_ref().and_then(|sasl| sasl.aws.as_ref()) {
913            None => None,
914            Some(aws) => Some(
915                aws.connection
916                    .load_sdk_config(
917                        &storage_configuration.connection_context,
918                        aws.connection_id,
919                        in_task,
920                    )
921                    .await?,
922            ),
923        };
924
925        // TODO(roshan): Implement enforcement of external address validation once
926        // rdkafka client has been updated to support providing multiple resolved
927        // addresses for brokers
928        let mut context = TunnelingClientContext::new(
929            context,
930            Handle::current(),
931            storage_configuration
932                .connection_context
933                .ssh_tunnel_manager
934                .clone(),
935            storage_configuration.parameters.ssh_timeout_config,
936            aws_config,
937            in_task,
938        );
939
940        match &self.default_tunnel {
941            Tunnel::Direct => {
942                // By default, don't offer a default override for broker address lookup.
943            }
944            Tunnel::AwsPrivatelink(pl) => {
945                context.set_default_tunnel(TunnelConfig::StaticHost(vpc_endpoint_host(
946                    pl.connection_id,
947                    None, // Default tunnel does not support availability zones.
948                )));
949            }
950            Tunnel::Ssh(ssh_tunnel) => {
951                let secret = storage_configuration
952                    .connection_context
953                    .secrets_reader
954                    .read_in_task_if(in_task, ssh_tunnel.connection_id)
955                    .await?;
956                let key_pair = SshKeyPair::from_bytes(&secret)?;
957
958                // Ensure any ssh-bastion address we connect to is resolved to an external address.
959                let resolved = resolve_address(
960                    &ssh_tunnel.connection.host,
961                    ENFORCE_EXTERNAL_ADDRESSES.get(storage_configuration.config_set()),
962                )
963                .await?;
964                context.set_default_tunnel(TunnelConfig::Ssh(SshTunnelConfig {
965                    host: resolved
966                        .iter()
967                        .map(|a| a.to_string())
968                        .collect::<BTreeSet<_>>(),
969                    port: ssh_tunnel.connection.port,
970                    user: ssh_tunnel.connection.user.clone(),
971                    key_pair,
972                }));
973            }
974        }
975
976        for broker in &self.brokers {
977            let mut addr_parts = broker.address.splitn(2, ':');
978            let addr = BrokerAddr {
979                host: addr_parts
980                    .next()
981                    .context("BROKER is not address:port")?
982                    .into(),
983                port: addr_parts
984                    .next()
985                    .unwrap_or("9092")
986                    .parse()
987                    .context("parsing BROKER port")?,
988            };
989            match &broker.tunnel {
990                Tunnel::Direct => {
991                    // By default, don't override broker address lookup.
992                    //
993                    // N.B.
994                    //
995                    // We _could_ pre-setup the default ssh tunnel for all known brokers here, but
996                    // we avoid doing because:
997                    // - Its not necessary.
998                    // - Not doing so makes it easier to test the `FailedDefaultSshTunnel` path
999                    // in the `TunnelingClientContext`.
1000                }
1001                Tunnel::AwsPrivatelink(aws_privatelink) => {
1002                    let host = mz_cloud_resources::vpc_endpoint_host(
1003                        aws_privatelink.connection_id,
1004                        aws_privatelink.availability_zone.as_deref(),
1005                    );
1006                    let port = aws_privatelink.port;
1007                    context.add_broker_rewrite(
1008                        addr,
1009                        BrokerRewrite {
1010                            host: host.clone(),
1011                            port,
1012                        },
1013                    );
1014                }
1015                Tunnel::Ssh(ssh_tunnel) => {
1016                    // Ensure any SSH bastion address we connect to is resolved to an external address.
1017                    let ssh_host_resolved = resolve_address(
1018                        &ssh_tunnel.connection.host,
1019                        ENFORCE_EXTERNAL_ADDRESSES.get(storage_configuration.config_set()),
1020                    )
1021                    .await?;
1022                    context
1023                        .add_ssh_tunnel(
1024                            addr,
1025                            SshTunnelConfig {
1026                                host: ssh_host_resolved
1027                                    .iter()
1028                                    .map(|a| a.to_string())
1029                                    .collect::<BTreeSet<_>>(),
1030                                port: ssh_tunnel.connection.port,
1031                                user: ssh_tunnel.connection.user.clone(),
1032                                key_pair: SshKeyPair::from_bytes(
1033                                    &storage_configuration
1034                                        .connection_context
1035                                        .secrets_reader
1036                                        .read_in_task_if(in_task, ssh_tunnel.connection_id)
1037                                        .await?,
1038                                )?,
1039                            },
1040                        )
1041                        .await
1042                        .map_err(ContextCreationError::Ssh)?;
1043                }
1044            }
1045        }
1046
1047        Ok(config.create_with_context(context)?)
1048    }
1049
1050    async fn validate(
1051        &self,
1052        _id: CatalogItemId,
1053        storage_configuration: &StorageConfiguration,
1054    ) -> Result<(), anyhow::Error> {
1055        let (context, error_rx) = MzClientContext::with_errors();
1056        let consumer: BaseConsumer<_> = self
1057            .create_with_context(
1058                storage_configuration,
1059                context,
1060                &BTreeMap::new(),
1061                // We are in a normal tokio context during validation, already.
1062                InTask::No,
1063            )
1064            .await?;
1065        let consumer = Arc::new(consumer);
1066
1067        let timeout = storage_configuration
1068            .parameters
1069            .kafka_timeout_config
1070            .fetch_metadata_timeout;
1071
1072        // librdkafka doesn't expose an API for determining whether a connection to
1073        // the Kafka cluster has been successfully established. So we make a
1074        // metadata request, though we don't care about the results, so that we can
1075        // report any errors making that request. If the request succeeds, we know
1076        // we were able to contact at least one broker, and that's a good proxy for
1077        // being able to contact all the brokers in the cluster.
1078        //
1079        // The downside of this approach is it produces a generic error message like
1080        // "metadata fetch error" with no additional details. The real networking
1081        // error is buried in the librdkafka logs, which are not visible to users.
1082        let result = mz_ore::task::spawn_blocking(|| "kafka_get_metadata", {
1083            let consumer = Arc::clone(&consumer);
1084            move || consumer.fetch_metadata(None, timeout)
1085        })
1086        .await?;
1087        match result {
1088            Ok(_) => Ok(()),
1089            // The error returned by `fetch_metadata` does not provide any details which makes for
1090            // a crappy user facing error message. For this reason we attempt to grab a better
1091            // error message from the client context, which should contain any error logs emitted
1092            // by librdkafka, and fallback to the generic error if there is nothing there.
1093            Err(err) => {
1094                // Multiple errors might have been logged during this validation but some are more
1095                // relevant than others. Specifically, we prefer non-internal errors over internal
1096                // errors since those give much more useful information to the users.
1097                let main_err = error_rx.try_iter().reduce(|cur, new| match cur {
1098                    MzKafkaError::Internal(_) => new,
1099                    _ => cur,
1100                });
1101
1102                // Don't drop the consumer until after we've drained the errors
1103                // channel. Dropping the consumer can introduce spurious errors.
1104                // See database-issues#7432.
1105                drop(consumer);
1106
1107                match main_err {
1108                    Some(err) => Err(err.into()),
1109                    None => Err(err.into()),
1110                }
1111            }
1112        }
1113    }
1114}
1115
1116impl<C: ConnectionAccess> AlterCompatible for KafkaConnection<C> {
1117    fn alter_compatible(&self, id: GlobalId, other: &Self) -> Result<(), AlterError> {
1118        let KafkaConnection {
1119            brokers: _,
1120            default_tunnel: _,
1121            progress_topic,
1122            progress_topic_options,
1123            options: _,
1124            tls: _,
1125            sasl: _,
1126        } = self;
1127
1128        let compatibility_checks = [
1129            (progress_topic == &other.progress_topic, "progress_topic"),
1130            (
1131                progress_topic_options == &other.progress_topic_options,
1132                "progress_topic_options",
1133            ),
1134        ];
1135
1136        for (compatible, field) in compatibility_checks {
1137            if !compatible {
1138                tracing::warn!(
1139                    "KafkaConnection incompatible at {field}:\nself:\n{:#?}\n\nother\n{:#?}",
1140                    self,
1141                    other
1142                );
1143
1144                return Err(AlterError { id });
1145            }
1146        }
1147
1148        Ok(())
1149    }
1150}
1151
1152impl RustType<ProtoKafkaConnectionTlsConfig> for KafkaTlsConfig {
1153    fn into_proto(&self) -> ProtoKafkaConnectionTlsConfig {
1154        ProtoKafkaConnectionTlsConfig {
1155            identity: self.identity.into_proto(),
1156            root_cert: self.root_cert.into_proto(),
1157        }
1158    }
1159
1160    fn from_proto(proto: ProtoKafkaConnectionTlsConfig) -> Result<Self, TryFromProtoError> {
1161        Ok(KafkaTlsConfig {
1162            root_cert: proto.root_cert.into_rust()?,
1163            identity: proto.identity.into_rust()?,
1164        })
1165    }
1166}
1167
1168impl RustType<ProtoKafkaConnectionSaslConfig> for KafkaSaslConfig {
1169    fn into_proto(&self) -> ProtoKafkaConnectionSaslConfig {
1170        ProtoKafkaConnectionSaslConfig {
1171            mechanism: self.mechanism.into_proto(),
1172            username: Some(self.username.into_proto()),
1173            password: self.password.into_proto(),
1174            aws: self.aws.into_proto(),
1175        }
1176    }
1177
1178    fn from_proto(proto: ProtoKafkaConnectionSaslConfig) -> Result<Self, TryFromProtoError> {
1179        Ok(KafkaSaslConfig {
1180            mechanism: proto.mechanism,
1181            username: proto
1182                .username
1183                .into_rust_if_some("ProtoKafkaConnectionSaslConfig::username")?,
1184            password: proto.password.into_rust()?,
1185            aws: proto.aws.into_rust()?,
1186        })
1187    }
1188}
1189
1190impl RustType<ProtoKafkaConnection> for KafkaConnection {
1191    fn into_proto(&self) -> ProtoKafkaConnection {
1192        ProtoKafkaConnection {
1193            brokers: self.brokers.into_proto(),
1194            default_tunnel: Some(self.default_tunnel.into_proto()),
1195            progress_topic: self.progress_topic.into_proto(),
1196            progress_topic_options: Some(self.progress_topic_options.into_proto()),
1197            options: self
1198                .options
1199                .iter()
1200                .map(|(k, v)| (k.clone(), v.into_proto()))
1201                .collect(),
1202            tls: self.tls.into_proto(),
1203            sasl: self.sasl.into_proto(),
1204        }
1205    }
1206
1207    fn from_proto(proto: ProtoKafkaConnection) -> Result<Self, TryFromProtoError> {
1208        Ok(KafkaConnection {
1209            brokers: proto.brokers.into_rust()?,
1210            default_tunnel: proto
1211                .default_tunnel
1212                .into_rust_if_some("ProtoKafkaConnection::default_tunnel")?,
1213            progress_topic: proto.progress_topic,
1214            progress_topic_options: match proto.progress_topic_options {
1215                Some(progress_topic_options) => progress_topic_options.into_rust()?,
1216                None => Default::default(),
1217            },
1218            options: proto
1219                .options
1220                .into_iter()
1221                .map(|(k, v)| StringOrSecret::from_proto(v).map(|v| (k, v)))
1222                .collect::<Result<_, _>>()?,
1223            tls: proto.tls.into_rust()?,
1224            sasl: proto.sasl.into_rust()?,
1225        })
1226    }
1227}
1228
/// A connection to a Confluent Schema Registry.
///
/// The type parameter `C` (default `InlinedConnection`) parameterizes only
/// the `tunnel` field.
#[derive(Clone, Debug, Eq, PartialEq, Hash, Serialize, Deserialize, Arbitrary)]
pub struct CsrConnection<C: ConnectionAccess = InlinedConnection> {
    /// The URL of the schema registry.
    // Property tests generate this field with the `any_url()` strategy.
    #[proptest(strategy = "any_url()")]
    pub url: Url,
    /// Trusted root TLS certificate in PEM format.
    pub tls_root_cert: Option<StringOrSecret>,
    /// An optional TLS client certificate for authentication with the schema
    /// registry.
    pub tls_identity: Option<TlsIdentity>,
    /// Optional HTTP authentication credentials for the schema registry.
    pub http_auth: Option<CsrConnectionHttpAuth>,
    /// A tunnel through which to route traffic.
    pub tunnel: Tunnel<C>,
}
1245
1246impl<R: ConnectionResolver> IntoInlineConnection<CsrConnection, R>
1247    for CsrConnection<ReferencedConnection>
1248{
1249    fn into_inline_connection(self, r: R) -> CsrConnection {
1250        let CsrConnection {
1251            url,
1252            tls_root_cert,
1253            tls_identity,
1254            http_auth,
1255            tunnel,
1256        } = self;
1257        CsrConnection {
1258            url,
1259            tls_root_cert,
1260            tls_identity,
1261            http_auth,
1262            tunnel: tunnel.into_inline_connection(r),
1263        }
1264    }
1265}
1266
impl<C: ConnectionAccess> CsrConnection<C> {
    /// Always returns `true` for schema registry connections.
    // NOTE(review): presumably consulted to decide whether the connection is
    // validated when the user does not say otherwise; confirm against callers.
    fn validate_by_default(&self) -> bool {
        true
    }
}
1272
1273impl CsrConnection {
1274    /// Constructs a schema registry client from the connection.
1275    pub async fn connect(
1276        &self,
1277        storage_configuration: &StorageConfiguration,
1278        in_task: InTask,
1279    ) -> Result<mz_ccsr::Client, CsrConnectError> {
1280        let mut client_config = mz_ccsr::ClientConfig::new(self.url.clone());
1281        if let Some(root_cert) = &self.tls_root_cert {
1282            let root_cert = root_cert
1283                .get_string(
1284                    in_task,
1285                    &storage_configuration.connection_context.secrets_reader,
1286                )
1287                .await?;
1288            let root_cert = Certificate::from_pem(root_cert.as_bytes())?;
1289            client_config = client_config.add_root_certificate(root_cert);
1290        }
1291
1292        if let Some(tls_identity) = &self.tls_identity {
1293            let key = &storage_configuration
1294                .connection_context
1295                .secrets_reader
1296                .read_string_in_task_if(in_task, tls_identity.key)
1297                .await?;
1298            let cert = tls_identity
1299                .cert
1300                .get_string(
1301                    in_task,
1302                    &storage_configuration.connection_context.secrets_reader,
1303                )
1304                .await?;
1305            let ident = Identity::from_pem(key.as_bytes(), cert.as_bytes())?;
1306            client_config = client_config.identity(ident);
1307        }
1308
1309        if let Some(http_auth) = &self.http_auth {
1310            let username = http_auth
1311                .username
1312                .get_string(
1313                    in_task,
1314                    &storage_configuration.connection_context.secrets_reader,
1315                )
1316                .await?;
1317            let password = match http_auth.password {
1318                None => None,
1319                Some(password) => Some(
1320                    storage_configuration
1321                        .connection_context
1322                        .secrets_reader
1323                        .read_string_in_task_if(in_task, password)
1324                        .await?,
1325                ),
1326            };
1327            client_config = client_config.auth(username, password);
1328        }
1329
1330        // TODO: use types to enforce that the URL has a string hostname.
1331        let host = self
1332            .url
1333            .host_str()
1334            .ok_or_else(|| anyhow!("url missing host"))?;
1335        match &self.tunnel {
1336            Tunnel::Direct => {
1337                // Ensure any host we connect to is resolved to an external address.
1338                let resolved = resolve_address(
1339                    host,
1340                    ENFORCE_EXTERNAL_ADDRESSES.get(storage_configuration.config_set()),
1341                )
1342                .await?;
1343                client_config = client_config.resolve_to_addrs(
1344                    host,
1345                    &resolved
1346                        .iter()
1347                        .map(|addr| SocketAddr::new(*addr, DUMMY_DNS_PORT))
1348                        .collect::<Vec<_>>(),
1349                )
1350            }
1351            Tunnel::Ssh(ssh_tunnel) => {
1352                let ssh_tunnel = ssh_tunnel
1353                    .connect(
1354                        storage_configuration,
1355                        host,
1356                        // Default to the default http port, but this
1357                        // could default to 8081...
1358                        self.url.port().unwrap_or(80),
1359                        in_task,
1360                    )
1361                    .await
1362                    .map_err(CsrConnectError::Ssh)?;
1363
1364                // Carefully inject the SSH tunnel into the client
1365                // configuration. This is delicate because we need TLS
1366                // verification to continue to use the remote hostname rather
1367                // than the tunnel hostname.
1368
1369                client_config = client_config
1370                    // `resolve_to_addrs` allows us to rewrite the hostname
1371                    // at the DNS level, which means the TCP connection is
1372                    // correctly routed through the tunnel, but TLS verification
1373                    // is still performed against the remote hostname.
1374                    // Unfortunately the port here is ignored...
1375                    .resolve_to_addrs(
1376                        host,
1377                        &[SocketAddr::new(
1378                            ssh_tunnel.local_addr().ip(),
1379                            DUMMY_DNS_PORT,
1380                        )],
1381                    )
1382                    // ...so we also dynamically rewrite the URL to use the
1383                    // current port for the SSH tunnel.
1384                    //
1385                    // WARNING: this is brittle, because we only dynamically
1386                    // update the client configuration with the tunnel *port*,
1387                    // and not the hostname This works fine in practice, because
1388                    // only the SSH tunnel port will change if the tunnel fails
1389                    // and has to be restarted (the hostname is always
1390                    // 127.0.0.1)--but this is an an implementation detail of
1391                    // the SSH tunnel code that we're relying on.
1392                    .dynamic_url({
1393                        let remote_url = self.url.clone();
1394                        move || {
1395                            let mut url = remote_url.clone();
1396                            url.set_port(Some(ssh_tunnel.local_addr().port()))
1397                                .expect("cannot fail");
1398                            url
1399                        }
1400                    });
1401            }
1402            Tunnel::AwsPrivatelink(connection) => {
1403                assert_none!(connection.port);
1404
1405                let privatelink_host = mz_cloud_resources::vpc_endpoint_host(
1406                    connection.connection_id,
1407                    connection.availability_zone.as_deref(),
1408                );
1409                let addrs: Vec<_> = net::lookup_host((privatelink_host, DUMMY_DNS_PORT))
1410                    .await
1411                    .context("resolving PrivateLink host")?
1412                    .collect();
1413                client_config = client_config.resolve_to_addrs(host, &addrs)
1414            }
1415        }
1416
1417        Ok(client_config.build()?)
1418    }
1419
1420    async fn validate(
1421        &self,
1422        _id: CatalogItemId,
1423        storage_configuration: &StorageConfiguration,
1424    ) -> Result<(), anyhow::Error> {
1425        let client = self
1426            .connect(
1427                storage_configuration,
1428                // We are in a normal tokio context during validation, already.
1429                InTask::No,
1430            )
1431            .await?;
1432        client.list_subjects().await?;
1433        Ok(())
1434    }
1435}
1436
1437impl RustType<ProtoCsrConnection> for CsrConnection {
1438    fn into_proto(&self) -> ProtoCsrConnection {
1439        ProtoCsrConnection {
1440            url: Some(self.url.into_proto()),
1441            tls_root_cert: self.tls_root_cert.into_proto(),
1442            tls_identity: self.tls_identity.into_proto(),
1443            http_auth: self.http_auth.into_proto(),
1444            tunnel: Some(self.tunnel.into_proto()),
1445        }
1446    }
1447
1448    fn from_proto(proto: ProtoCsrConnection) -> Result<Self, TryFromProtoError> {
1449        Ok(CsrConnection {
1450            url: proto.url.into_rust_if_some("ProtoCsrConnection::url")?,
1451            tls_root_cert: proto.tls_root_cert.into_rust()?,
1452            tls_identity: proto.tls_identity.into_rust()?,
1453            http_auth: proto.http_auth.into_rust()?,
1454            tunnel: proto
1455                .tunnel
1456                .into_rust_if_some("ProtoCsrConnection::tunnel")?,
1457        })
1458    }
1459}
1460
1461impl<C: ConnectionAccess> AlterCompatible for CsrConnection<C> {
1462    fn alter_compatible(&self, id: GlobalId, other: &Self) -> Result<(), AlterError> {
1463        let CsrConnection {
1464            tunnel,
1465            // All non-tunnel fields may change
1466            url: _,
1467            tls_root_cert: _,
1468            tls_identity: _,
1469            http_auth: _,
1470        } = self;
1471
1472        let compatibility_checks = [(tunnel.alter_compatible(id, &other.tunnel).is_ok(), "tunnel")];
1473
1474        for (compatible, field) in compatibility_checks {
1475            if !compatible {
1476                tracing::warn!(
1477                    "CsrConnection incompatible at {field}:\nself:\n{:#?}\n\nother\n{:#?}",
1478                    self,
1479                    other
1480                );
1481
1482                return Err(AlterError { id });
1483            }
1484        }
1485        Ok(())
1486    }
1487}
1488
/// A TLS key pair used for client identity.
///
/// The certificate is a `StringOrSecret`, whereas the private key is only
/// ever referenced by the ID of the secret that holds it.
#[derive(Arbitrary, Clone, Debug, Eq, PartialEq, Hash, Serialize, Deserialize)]
pub struct TlsIdentity {
    /// The client's TLS public certificate in PEM format.
    pub cert: StringOrSecret,
    /// The ID of the secret containing the client's TLS private key in PEM
    /// format.
    pub key: CatalogItemId,
}
1498
1499impl RustType<ProtoTlsIdentity> for TlsIdentity {
1500    fn into_proto(&self) -> ProtoTlsIdentity {
1501        ProtoTlsIdentity {
1502            cert: Some(self.cert.into_proto()),
1503            key: Some(self.key.into_proto()),
1504        }
1505    }
1506
1507    fn from_proto(proto: ProtoTlsIdentity) -> Result<Self, TryFromProtoError> {
1508        Ok(TlsIdentity {
1509            cert: proto.cert.into_rust_if_some("ProtoTlsIdentity::cert")?,
1510            key: proto.key.into_rust_if_some("ProtoTlsIdentity::key")?,
1511        })
1512    }
1513}
1514
/// HTTP authentication credentials in a [`CsrConnection`].
///
/// The username is a `StringOrSecret`, whereas the optional password is only
/// ever referenced by the ID of the secret that holds it.
#[derive(Arbitrary, Clone, Debug, Eq, PartialEq, Hash, Serialize, Deserialize)]
pub struct CsrConnectionHttpAuth {
    /// The username.
    pub username: StringOrSecret,
    /// The ID of the secret containing the password, if any.
    pub password: Option<CatalogItemId>,
}
1523
1524impl RustType<ProtoCsrConnectionHttpAuth> for CsrConnectionHttpAuth {
1525    fn into_proto(&self) -> ProtoCsrConnectionHttpAuth {
1526        ProtoCsrConnectionHttpAuth {
1527            username: Some(self.username.into_proto()),
1528            password: self.password.into_proto(),
1529        }
1530    }
1531
1532    fn from_proto(proto: ProtoCsrConnectionHttpAuth) -> Result<Self, TryFromProtoError> {
1533        Ok(CsrConnectionHttpAuth {
1534            username: proto
1535                .username
1536                .into_rust_if_some("ProtoCsrConnectionHttpAuth::username")?,
1537            password: proto.password.into_rust()?,
1538        })
1539    }
1540}
1541
/// A connection to a PostgreSQL server.
///
/// The type parameter `C` (default `InlinedConnection`) parameterizes only
/// the `tunnel` field.
#[derive(Clone, Debug, Eq, PartialEq, Hash, Serialize, Deserialize, Arbitrary)]
pub struct PostgresConnection<C: ConnectionAccess = InlinedConnection> {
    /// The hostname of the server.
    pub host: String,
    /// The port of the server.
    pub port: u16,
    /// The name of the database to connect to.
    pub database: String,
    /// The username to authenticate as.
    pub user: StringOrSecret,
    /// An optional password for authentication.
    pub password: Option<CatalogItemId>,
    /// A tunnel through which to route traffic.
    pub tunnel: Tunnel<C>,
    /// Whether to use TLS for encryption, authentication, or both.
    // Property tests generate this field with the `any_ssl_mode()` strategy.
    #[proptest(strategy = "any_ssl_mode()")]
    pub tls_mode: SslMode,
    /// An optional root TLS certificate in PEM format, to verify the server's
    /// identity.
    pub tls_root_cert: Option<StringOrSecret>,
    /// An optional TLS client certificate for authentication.
    pub tls_identity: Option<TlsIdentity>,
    /// The kind of postgres server we are connecting to. This can be vanilla, for a normal
    /// postgres server or some other system that is pg compatible, like Yugabyte, Aurora, etc.
    pub flavor: PostgresFlavor,
}
1569
1570impl<R: ConnectionResolver> IntoInlineConnection<PostgresConnection, R>
1571    for PostgresConnection<ReferencedConnection>
1572{
1573    fn into_inline_connection(self, r: R) -> PostgresConnection {
1574        let PostgresConnection {
1575            host,
1576            port,
1577            database,
1578            user,
1579            password,
1580            tunnel,
1581            tls_mode,
1582            tls_root_cert,
1583            tls_identity,
1584            flavor,
1585        } = self;
1586
1587        PostgresConnection {
1588            host,
1589            port,
1590            database,
1591            user,
1592            password,
1593            tunnel: tunnel.into_inline_connection(r),
1594            tls_mode,
1595            tls_root_cert,
1596            tls_identity,
1597            flavor,
1598        }
1599    }
1600}
1601
impl<C: ConnectionAccess> PostgresConnection<C> {
    /// Whether PostgreSQL connections are validated by default. Always `true`.
    fn validate_by_default(&self) -> bool {
        true
    }
}
1607
1608impl PostgresConnection<InlinedConnection> {
1609    pub async fn config(
1610        &self,
1611        secrets_reader: &Arc<dyn mz_secrets::SecretsReader>,
1612        storage_configuration: &StorageConfiguration,
1613        in_task: InTask,
1614    ) -> Result<mz_postgres_util::Config, anyhow::Error> {
1615        let params = &storage_configuration.parameters;
1616
1617        let mut config = tokio_postgres::Config::new();
1618        config
1619            .host(&self.host)
1620            .port(self.port)
1621            .dbname(&self.database)
1622            .user(&self.user.get_string(in_task, secrets_reader).await?)
1623            .ssl_mode(self.tls_mode);
1624        if let Some(password) = self.password {
1625            let password = secrets_reader
1626                .read_string_in_task_if(in_task, password)
1627                .await?;
1628            config.password(password);
1629        }
1630        if let Some(tls_root_cert) = &self.tls_root_cert {
1631            let tls_root_cert = tls_root_cert.get_string(in_task, secrets_reader).await?;
1632            config.ssl_root_cert(tls_root_cert.as_bytes());
1633        }
1634        if let Some(tls_identity) = &self.tls_identity {
1635            let cert = tls_identity
1636                .cert
1637                .get_string(in_task, secrets_reader)
1638                .await?;
1639            let key = secrets_reader
1640                .read_string_in_task_if(in_task, tls_identity.key)
1641                .await?;
1642            config.ssl_cert(cert.as_bytes()).ssl_key(key.as_bytes());
1643        }
1644
1645        if let Some(connect_timeout) = params.pg_source_connect_timeout {
1646            config.connect_timeout(connect_timeout);
1647        }
1648        if let Some(keepalives_retries) = params.pg_source_tcp_keepalives_retries {
1649            config.keepalives_retries(keepalives_retries);
1650        }
1651        if let Some(keepalives_idle) = params.pg_source_tcp_keepalives_idle {
1652            config.keepalives_idle(keepalives_idle);
1653        }
1654        if let Some(keepalives_interval) = params.pg_source_tcp_keepalives_interval {
1655            config.keepalives_interval(keepalives_interval);
1656        }
1657        if let Some(tcp_user_timeout) = params.pg_source_tcp_user_timeout {
1658            config.tcp_user_timeout(tcp_user_timeout);
1659        }
1660
1661        let mut options = vec![];
1662        if let Some(wal_sender_timeout) = params.pg_source_wal_sender_timeout {
1663            options.push(format!(
1664                "--wal_sender_timeout={}",
1665                wal_sender_timeout.as_millis()
1666            ));
1667        };
1668        if params.pg_source_tcp_configure_server {
1669            if let Some(keepalives_retries) = params.pg_source_tcp_keepalives_retries {
1670                options.push(format!("--tcp_keepalives_count={}", keepalives_retries));
1671            }
1672            if let Some(keepalives_idle) = params.pg_source_tcp_keepalives_idle {
1673                options.push(format!(
1674                    "--tcp_keepalives_idle={}",
1675                    keepalives_idle.as_secs()
1676                ));
1677            }
1678            if let Some(keepalives_interval) = params.pg_source_tcp_keepalives_interval {
1679                options.push(format!(
1680                    "--tcp_keepalives_interval={}",
1681                    keepalives_interval.as_secs()
1682                ));
1683            }
1684            if let Some(tcp_user_timeout) = params.pg_source_tcp_user_timeout {
1685                options.push(format!(
1686                    "--tcp_user_timeout={}",
1687                    tcp_user_timeout.as_millis()
1688                ));
1689            }
1690        }
1691        config.options(options.join(" ").as_str());
1692
1693        let tunnel = match &self.tunnel {
1694            Tunnel::Direct => {
1695                // Ensure any host we connect to is resolved to an external address.
1696                let resolved = resolve_address(
1697                    &self.host,
1698                    ENFORCE_EXTERNAL_ADDRESSES.get(storage_configuration.config_set()),
1699                )
1700                .await?;
1701                mz_postgres_util::TunnelConfig::Direct {
1702                    resolved_ips: Some(resolved),
1703                }
1704            }
1705            Tunnel::Ssh(SshTunnel {
1706                connection_id,
1707                connection,
1708            }) => {
1709                let secret = secrets_reader
1710                    .read_in_task_if(in_task, *connection_id)
1711                    .await?;
1712                let key_pair = SshKeyPair::from_bytes(&secret)?;
1713                // Ensure any ssh-bastion host we connect to is resolved to an external address.
1714                let resolved = resolve_address(
1715                    &connection.host,
1716                    ENFORCE_EXTERNAL_ADDRESSES.get(storage_configuration.config_set()),
1717                )
1718                .await?;
1719                mz_postgres_util::TunnelConfig::Ssh {
1720                    config: SshTunnelConfig {
1721                        host: resolved
1722                            .iter()
1723                            .map(|a| a.to_string())
1724                            .collect::<BTreeSet<_>>(),
1725                        port: connection.port,
1726                        user: connection.user.clone(),
1727                        key_pair,
1728                    },
1729                }
1730            }
1731            Tunnel::AwsPrivatelink(connection) => {
1732                assert_none!(connection.port);
1733                mz_postgres_util::TunnelConfig::AwsPrivatelink {
1734                    connection_id: connection.connection_id,
1735                }
1736            }
1737        };
1738
1739        Ok(mz_postgres_util::Config::new(
1740            config,
1741            tunnel,
1742            params.ssh_timeout_config,
1743            in_task,
1744        )?)
1745    }
1746
1747    async fn validate(
1748        &self,
1749        _id: CatalogItemId,
1750        storage_configuration: &StorageConfiguration,
1751    ) -> Result<(), anyhow::Error> {
1752        let config = self
1753            .config(
1754                &storage_configuration.connection_context.secrets_reader,
1755                storage_configuration,
1756                // We are in a normal tokio context during validation, already.
1757                InTask::No,
1758            )
1759            .await?;
1760        let client = config
1761            .connect(
1762                "connection validation",
1763                &storage_configuration.connection_context.ssh_tunnel_manager,
1764            )
1765            .await?;
1766        use PostgresFlavor::*;
1767        match (client.server_flavor(), &self.flavor) {
1768            (Vanilla, Yugabyte) => bail!("Expected to find PostgreSQL server, found Yugabyte."),
1769            (Yugabyte, Vanilla) => bail!("Expected to find Yugabyte server, found PostgreSQL."),
1770            (Vanilla, Vanilla) | (Yugabyte, Yugabyte) => {}
1771        }
1772        Ok(())
1773    }
1774}
1775
1776impl<C: ConnectionAccess> AlterCompatible for PostgresConnection<C> {
1777    fn alter_compatible(&self, id: GlobalId, other: &Self) -> Result<(), AlterError> {
1778        let PostgresConnection {
1779            tunnel,
1780            flavor,
1781            // All non-tunnel options may change arbitrarily
1782            host: _,
1783            port: _,
1784            database: _,
1785            user: _,
1786            password: _,
1787            tls_mode: _,
1788            tls_root_cert: _,
1789            tls_identity: _,
1790        } = self;
1791
1792        let compatibility_checks = [
1793            (tunnel.alter_compatible(id, &other.tunnel).is_ok(), "tunnel"),
1794            (flavor == &other.flavor, "flavor"),
1795        ];
1796
1797        for (compatible, field) in compatibility_checks {
1798            if !compatible {
1799                tracing::warn!(
1800                    "PostgresConnection incompatible at {field}:\nself:\n{:#?}\n\nother\n{:#?}",
1801                    self,
1802                    other
1803                );
1804
1805                return Err(AlterError { id });
1806            }
1807        }
1808        Ok(())
1809    }
1810}
1811
1812impl RustType<ProtoPostgresConnection> for PostgresConnection {
1813    fn into_proto(&self) -> ProtoPostgresConnection {
1814        ProtoPostgresConnection {
1815            host: self.host.into_proto(),
1816            port: self.port.into_proto(),
1817            database: self.database.into_proto(),
1818            user: Some(self.user.into_proto()),
1819            password: self.password.into_proto(),
1820            tls_mode: Some(self.tls_mode.into_proto()),
1821            tls_root_cert: self.tls_root_cert.into_proto(),
1822            tls_identity: self.tls_identity.into_proto(),
1823            tunnel: Some(self.tunnel.into_proto()),
1824            flavor: Some(self.flavor.into_proto()),
1825        }
1826    }
1827
1828    fn from_proto(proto: ProtoPostgresConnection) -> Result<Self, TryFromProtoError> {
1829        Ok(PostgresConnection {
1830            host: proto.host,
1831            port: proto.port.into_rust()?,
1832            database: proto.database,
1833            user: proto
1834                .user
1835                .into_rust_if_some("ProtoPostgresConnection::user")?,
1836            password: proto.password.into_rust()?,
1837            tunnel: proto
1838                .tunnel
1839                .into_rust_if_some("ProtoPostgresConnection::tunnel")?,
1840            tls_mode: proto
1841                .tls_mode
1842                .into_rust_if_some("ProtoPostgresConnection::tls_mode")?,
1843            tls_root_cert: proto.tls_root_cert.into_rust()?,
1844            tls_identity: proto.tls_identity.into_rust()?,
1845            flavor: proto
1846                .flavor
1847                .into_rust_if_some("ProtoPostgresConnection::flavor")?,
1848        })
1849    }
1850}
1851
/// Specifies how to tunnel a connection.
///
/// The type parameter `C` selects how the SSH tunnel's underlying connection
/// is represented; it defaults to [`InlinedConnection`].
#[derive(Arbitrary, Clone, Debug, Eq, PartialEq, Hash, Serialize, Deserialize)]
pub enum Tunnel<C: ConnectionAccess = InlinedConnection> {
    /// No tunneling.
    Direct,
    /// Via the specified SSH tunnel connection.
    Ssh(SshTunnel<C>),
    /// Via the specified AWS PrivateLink connection.
    AwsPrivatelink(AwsPrivatelink),
}
1862
1863impl<R: ConnectionResolver> IntoInlineConnection<Tunnel, R> for Tunnel<ReferencedConnection> {
1864    fn into_inline_connection(self, r: R) -> Tunnel {
1865        match self {
1866            Tunnel::Direct => Tunnel::Direct,
1867            Tunnel::Ssh(ssh) => Tunnel::Ssh(ssh.into_inline_connection(r)),
1868            Tunnel::AwsPrivatelink(awspl) => Tunnel::AwsPrivatelink(awspl),
1869        }
1870    }
1871}
1872
1873impl RustType<ProtoTunnel> for Tunnel<InlinedConnection> {
1874    fn into_proto(&self) -> ProtoTunnel {
1875        use proto_tunnel::Tunnel as ProtoTunnelField;
1876        ProtoTunnel {
1877            tunnel: Some(match &self {
1878                Tunnel::Direct => ProtoTunnelField::Direct(()),
1879                Tunnel::Ssh(ssh) => ProtoTunnelField::Ssh(ssh.into_proto()),
1880                Tunnel::AwsPrivatelink(aws) => ProtoTunnelField::AwsPrivatelink(aws.into_proto()),
1881            }),
1882        }
1883    }
1884
1885    fn from_proto(proto: ProtoTunnel) -> Result<Self, TryFromProtoError> {
1886        use proto_tunnel::Tunnel as ProtoTunnelField;
1887        Ok(match proto.tunnel {
1888            None => return Err(TryFromProtoError::missing_field("ProtoTunnel::tunnel")),
1889            Some(ProtoTunnelField::Direct(())) => Tunnel::Direct,
1890            Some(ProtoTunnelField::Ssh(ssh)) => Tunnel::Ssh(ssh.into_rust()?),
1891            Some(ProtoTunnelField::AwsPrivatelink(aws)) => Tunnel::AwsPrivatelink(aws.into_rust()?),
1892        })
1893    }
1894}
1895
1896impl<C: ConnectionAccess> AlterCompatible for Tunnel<C> {
1897    fn alter_compatible(&self, id: GlobalId, other: &Self) -> Result<(), AlterError> {
1898        let compatible = match (self, other) {
1899            (Self::Ssh(s), Self::Ssh(o)) => s.alter_compatible(id, o).is_ok(),
1900            (s, o) => s == o,
1901        };
1902
1903        if !compatible {
1904            tracing::warn!(
1905                "Tunnel incompatible:\nself:\n{:#?}\n\nother\n{:#?}",
1906                self,
1907                other
1908            );
1909
1910            return Err(AlterError { id });
1911        }
1912
1913        Ok(())
1914    }
1915}
1916
/// Specifies which MySQL SSL Mode to use:
/// <https://dev.mysql.com/doc/refman/8.0/en/connection-options.html#option_general_ssl-mode>
/// This is not available as an enum in the mysql-async crate, so we define our own.
///
/// The per-variant notes below describe how `MySqlConnection::config` maps each
/// mode onto `mysql_async::SslOpts`.
#[derive(Arbitrary, Clone, Debug, Eq, PartialEq, Hash, Serialize, Deserialize)]
pub enum MySqlSslMode {
    /// No SSL options are configured.
    Disabled,
    /// SSL is used, but invalid certificates are accepted and domain validation
    /// is skipped.
    Required,
    /// SSL is used and the certificate is checked, but domain validation is
    /// skipped.
    VerifyCa,
    /// SSL is used with the default `SslOpts`, i.e. nothing is opted out of.
    VerifyIdentity,
}
1927
1928impl RustType<i32> for MySqlSslMode {
1929    fn into_proto(&self) -> i32 {
1930        match self {
1931            MySqlSslMode::Disabled => ProtoMySqlSslMode::Disabled.into(),
1932            MySqlSslMode::Required => ProtoMySqlSslMode::Required.into(),
1933            MySqlSslMode::VerifyCa => ProtoMySqlSslMode::VerifyCa.into(),
1934            MySqlSslMode::VerifyIdentity => ProtoMySqlSslMode::VerifyIdentity.into(),
1935        }
1936    }
1937
1938    fn from_proto(proto: i32) -> Result<Self, TryFromProtoError> {
1939        Ok(match ProtoMySqlSslMode::try_from(proto) {
1940            Ok(ProtoMySqlSslMode::Disabled) => MySqlSslMode::Disabled,
1941            Ok(ProtoMySqlSslMode::Required) => MySqlSslMode::Required,
1942            Ok(ProtoMySqlSslMode::VerifyCa) => MySqlSslMode::VerifyCa,
1943            Ok(ProtoMySqlSslMode::VerifyIdentity) => MySqlSslMode::VerifyIdentity,
1944            Err(_) => {
1945                return Err(TryFromProtoError::UnknownEnumVariant(
1946                    "tls_mode".to_string(),
1947                ));
1948            }
1949        })
1950    }
1951}
1952
1953pub fn any_mysql_ssl_mode() -> impl Strategy<Value = MySqlSslMode> {
1954    proptest::sample::select(vec![
1955        MySqlSslMode::Disabled,
1956        MySqlSslMode::Required,
1957        MySqlSslMode::VerifyCa,
1958        MySqlSslMode::VerifyIdentity,
1959    ])
1960}
1961
/// A connection to a MySQL server.
///
/// The type parameter `C` selects how nested connections (the [`Tunnel`] and
/// the optional AWS connection) are represented; it defaults to
/// [`InlinedConnection`].
#[derive(Clone, Debug, Eq, PartialEq, Hash, Serialize, Deserialize, Arbitrary)]
pub struct MySqlConnection<C: ConnectionAccess = InlinedConnection> {
    /// The hostname of the server.
    pub host: String,
    /// The port of the server.
    pub port: u16,
    /// The username to authenticate as.
    pub user: StringOrSecret,
    /// An optional password for authentication, given as the ID of the secret
    /// that holds it.
    pub password: Option<CatalogItemId>,
    /// A tunnel through which to route traffic.
    pub tunnel: Tunnel<C>,
    /// Whether to use TLS for encryption, verify the server's certificate, and identity.
    #[proptest(strategy = "any_mysql_ssl_mode()")]
    pub tls_mode: MySqlSslMode,
    /// An optional root TLS certificate in PEM format, to verify the server's
    /// identity.
    pub tls_root_cert: Option<StringOrSecret>,
    /// An optional TLS client certificate for authentication.
    pub tls_identity: Option<TlsIdentity>,
    /// Reference to the AWS connection information to be used for IAM authentication and
    /// assuming AWS roles.
    pub aws_connection: Option<AwsConnectionReference<C>>,
}
1987
1988impl<R: ConnectionResolver> IntoInlineConnection<MySqlConnection, R>
1989    for MySqlConnection<ReferencedConnection>
1990{
1991    fn into_inline_connection(self, r: R) -> MySqlConnection {
1992        let MySqlConnection {
1993            host,
1994            port,
1995            user,
1996            password,
1997            tunnel,
1998            tls_mode,
1999            tls_root_cert,
2000            tls_identity,
2001            aws_connection,
2002        } = self;
2003
2004        MySqlConnection {
2005            host,
2006            port,
2007            user,
2008            password,
2009            tunnel: tunnel.into_inline_connection(&r),
2010            tls_mode,
2011            tls_root_cert,
2012            tls_identity,
2013            aws_connection: aws_connection.map(|aws| aws.into_inline_connection(&r)),
2014        }
2015    }
2016}
2017
impl<C: ConnectionAccess> MySqlConnection<C> {
    /// Whether MySQL connections are validated by default. Always `true`.
    fn validate_by_default(&self) -> bool {
        true
    }
}
2023
2024impl MySqlConnection<InlinedConnection> {
2025    pub async fn config(
2026        &self,
2027        secrets_reader: &Arc<dyn mz_secrets::SecretsReader>,
2028        storage_configuration: &StorageConfiguration,
2029        in_task: InTask,
2030    ) -> Result<mz_mysql_util::Config, anyhow::Error> {
2031        // TODO(roshan): Set appropriate connection timeouts
2032        let mut opts = mysql_async::OptsBuilder::default()
2033            .ip_or_hostname(&self.host)
2034            .tcp_port(self.port)
2035            .user(Some(&self.user.get_string(in_task, secrets_reader).await?));
2036
2037        if let Some(password) = self.password {
2038            let password = secrets_reader
2039                .read_string_in_task_if(in_task, password)
2040                .await?;
2041            opts = opts.pass(Some(password));
2042        }
2043
2044        // Our `MySqlSslMode` enum matches the official MySQL Client `--ssl-mode` parameter values
2045        // which uses opt-in security features (SSL, CA verification, & Identity verification).
2046        // The mysql_async crate `SslOpts` struct uses an opt-out mechanism for each of these, so
2047        // we need to appropriately disable features to match the intent of each enum value.
2048        let mut ssl_opts = match self.tls_mode {
2049            MySqlSslMode::Disabled => None,
2050            MySqlSslMode::Required => Some(
2051                mysql_async::SslOpts::default()
2052                    .with_danger_accept_invalid_certs(true)
2053                    .with_danger_skip_domain_validation(true),
2054            ),
2055            MySqlSslMode::VerifyCa => {
2056                Some(mysql_async::SslOpts::default().with_danger_skip_domain_validation(true))
2057            }
2058            MySqlSslMode::VerifyIdentity => Some(mysql_async::SslOpts::default()),
2059        };
2060
2061        if matches!(
2062            self.tls_mode,
2063            MySqlSslMode::VerifyCa | MySqlSslMode::VerifyIdentity
2064        ) {
2065            if let Some(tls_root_cert) = &self.tls_root_cert {
2066                let tls_root_cert = tls_root_cert.get_string(in_task, secrets_reader).await?;
2067                ssl_opts = ssl_opts.map(|opts| {
2068                    opts.with_root_certs(vec![tls_root_cert.as_bytes().to_vec().into()])
2069                });
2070            }
2071        }
2072
2073        if let Some(identity) = &self.tls_identity {
2074            let key = secrets_reader
2075                .read_string_in_task_if(in_task, identity.key)
2076                .await?;
2077            let cert = identity.cert.get_string(in_task, secrets_reader).await?;
2078            let Pkcs12Archive { der, pass } =
2079                mz_tls_util::pkcs12der_from_pem(key.as_bytes(), cert.as_bytes())?;
2080
2081            // Add client identity to SSLOpts
2082            ssl_opts = ssl_opts.map(|opts| {
2083                opts.with_client_identity(Some(
2084                    mysql_async::ClientIdentity::new(der.into()).with_password(pass),
2085                ))
2086            });
2087        }
2088
2089        opts = opts.ssl_opts(ssl_opts);
2090
2091        let tunnel = match &self.tunnel {
2092            Tunnel::Direct => {
2093                // Ensure any host we connect to is resolved to an external address.
2094                let resolved = resolve_address(
2095                    &self.host,
2096                    ENFORCE_EXTERNAL_ADDRESSES.get(storage_configuration.config_set()),
2097                )
2098                .await?;
2099                mz_mysql_util::TunnelConfig::Direct {
2100                    resolved_ips: Some(resolved),
2101                }
2102            }
2103            Tunnel::Ssh(SshTunnel {
2104                connection_id,
2105                connection,
2106            }) => {
2107                let secret = secrets_reader
2108                    .read_in_task_if(in_task, *connection_id)
2109                    .await?;
2110                let key_pair = SshKeyPair::from_bytes(&secret)?;
2111                // Ensure any ssh-bastion host we connect to is resolved to an external address.
2112                let resolved = resolve_address(
2113                    &connection.host,
2114                    ENFORCE_EXTERNAL_ADDRESSES.get(storage_configuration.config_set()),
2115                )
2116                .await?;
2117                mz_mysql_util::TunnelConfig::Ssh {
2118                    config: SshTunnelConfig {
2119                        host: resolved
2120                            .iter()
2121                            .map(|a| a.to_string())
2122                            .collect::<BTreeSet<_>>(),
2123                        port: connection.port,
2124                        user: connection.user.clone(),
2125                        key_pair,
2126                    },
2127                }
2128            }
2129            Tunnel::AwsPrivatelink(connection) => {
2130                assert_none!(connection.port);
2131                mz_mysql_util::TunnelConfig::AwsPrivatelink {
2132                    connection_id: connection.connection_id,
2133                }
2134            }
2135        };
2136
2137        let aws_config = match self.aws_connection.as_ref() {
2138            None => None,
2139            Some(aws_ref) => Some(
2140                aws_ref
2141                    .connection
2142                    .load_sdk_config(
2143                        &storage_configuration.connection_context,
2144                        aws_ref.connection_id,
2145                        in_task,
2146                    )
2147                    .await?,
2148            ),
2149        };
2150
2151        Ok(mz_mysql_util::Config::new(
2152            opts,
2153            tunnel,
2154            storage_configuration.parameters.ssh_timeout_config,
2155            in_task,
2156            storage_configuration
2157                .parameters
2158                .mysql_source_timeouts
2159                .clone(),
2160            aws_config,
2161        )?)
2162    }
2163
2164    async fn validate(
2165        &self,
2166        _id: CatalogItemId,
2167        storage_configuration: &StorageConfiguration,
2168    ) -> Result<(), anyhow::Error> {
2169        let config = self
2170            .config(
2171                &storage_configuration.connection_context.secrets_reader,
2172                storage_configuration,
2173                // We are in a normal tokio context during validation, already.
2174                InTask::No,
2175            )
2176            .await?;
2177        let conn = config
2178            .connect(
2179                "connection validation",
2180                &storage_configuration.connection_context.ssh_tunnel_manager,
2181            )
2182            .await?;
2183        conn.disconnect().await?;
2184        Ok(())
2185    }
2186}
2187
2188impl RustType<ProtoMySqlConnection> for MySqlConnection {
2189    fn into_proto(&self) -> ProtoMySqlConnection {
2190        ProtoMySqlConnection {
2191            host: self.host.into_proto(),
2192            port: self.port.into_proto(),
2193            user: Some(self.user.into_proto()),
2194            password: self.password.into_proto(),
2195            tls_mode: self.tls_mode.into_proto(),
2196            tls_root_cert: self.tls_root_cert.into_proto(),
2197            tls_identity: self.tls_identity.into_proto(),
2198            tunnel: Some(self.tunnel.into_proto()),
2199            aws_connection: self.aws_connection.into_proto(),
2200        }
2201    }
2202
2203    fn from_proto(proto: ProtoMySqlConnection) -> Result<Self, TryFromProtoError> {
2204        Ok(MySqlConnection {
2205            host: proto.host,
2206            port: proto.port.into_rust()?,
2207            user: proto.user.into_rust_if_some("ProtoMySqlConnection::user")?,
2208            password: proto.password.into_rust()?,
2209            tunnel: proto
2210                .tunnel
2211                .into_rust_if_some("ProtoMySqlConnection::tunnel")?,
2212            tls_mode: proto.tls_mode.into_rust()?,
2213            tls_root_cert: proto.tls_root_cert.into_rust()?,
2214            tls_identity: proto.tls_identity.into_rust()?,
2215            aws_connection: proto.aws_connection.into_rust()?,
2216        })
2217    }
2218}
2219
2220impl<C: ConnectionAccess> AlterCompatible for MySqlConnection<C> {
2221    fn alter_compatible(&self, id: GlobalId, other: &Self) -> Result<(), AlterError> {
2222        let MySqlConnection {
2223            tunnel,
2224            // All non-tunnel options may change arbitrarily
2225            host: _,
2226            port: _,
2227            user: _,
2228            password: _,
2229            tls_mode: _,
2230            tls_root_cert: _,
2231            tls_identity: _,
2232            aws_connection: _,
2233        } = self;
2234
2235        let compatibility_checks = [(tunnel.alter_compatible(id, &other.tunnel).is_ok(), "tunnel")];
2236
2237        for (compatible, field) in compatibility_checks {
2238            if !compatible {
2239                tracing::warn!(
2240                    "MySqlConnection incompatible at {field}:\nself:\n{:#?}\n\nother\n{:#?}",
2241                    self,
2242                    other
2243                );
2244
2245                return Err(AlterError { id });
2246            }
2247        }
2248        Ok(())
2249    }
2250}
2251
/// Details how to connect to an instance of Microsoft SQL Server.
///
/// For specifics of connecting to SQL Server for purposes of creating a
/// Materialize Source, see [`SqlServerSource`] which wraps this type.
///
/// [`SqlServerSource`]: crate::sources::SqlServerSource
#[derive(Arbitrary, Clone, Debug, Eq, PartialEq, Hash, Serialize, Deserialize)]
pub struct SqlServerConnectionDetails<C: ConnectionAccess = InlinedConnection> {
    /// The hostname of the server.
    pub host: String,
    /// The port of the server.
    pub port: u16,
    /// Database we should connect to.
    pub database: String,
    /// The username to authenticate as.
    pub user: StringOrSecret,
    /// Password used for authentication, given as the ID of the secret that
    /// holds it.
    pub password: CatalogItemId,
    /// A tunnel through which to route traffic.
    pub tunnel: Tunnel<C>,
    /// Level of encryption to use for the connection.
    pub encryption: mz_sql_server_util::config::EncryptionLevel,
    /// Certificate validation policy.
    pub certificate_validation_policy: mz_sql_server_util::config::CertificateValidationPolicy,
    /// TLS CA certificate in PEM format.
    ///
    /// NOTE(review): `resolve_config` unwraps this when the policy is
    /// `VerifyCA`, so it appears to be required in that case — confirm that
    /// this is enforced when the connection is created.
    pub tls_root_cert: Option<StringOrSecret>,
}
2279
impl<C: ConnectionAccess> SqlServerConnectionDetails<C> {
    /// Whether SQL Server connections are validated by default. Always `true`.
    fn validate_by_default(&self) -> bool {
        true
    }
}
2285
2286impl SqlServerConnectionDetails<InlinedConnection> {
2287    /// Attempts to open a connection to the upstream SQL Server instance.
2288    async fn validate(
2289        &self,
2290        _id: CatalogItemId,
2291        storage_configuration: &StorageConfiguration,
2292    ) -> Result<(), anyhow::Error> {
2293        let config = self
2294            .resolve_config(
2295                &storage_configuration.connection_context.secrets_reader,
2296                storage_configuration,
2297                InTask::No,
2298            )
2299            .await?;
2300        tracing::debug!(?config, "Validating SQL Server connection");
2301
2302        // Just connecting is enough to validate, no need to send any queries.
2303        let _client = mz_sql_server_util::Client::connect(config).await?;
2304
2305        Ok(())
2306    }
2307
2308    /// Resolve all of the connection details (e.g. read from the [`SecretsReader`])
2309    /// so the returned [`Config`] can be used to open a connection with the
2310    /// upstream system.
2311    ///
2312    /// The provided [`InTask`] argument determines whether any I/O is run in an
2313    /// [`mz_ore::task`] (i.e. a different thread) or directly in the returned
2314    /// future. The main goal here is to prevent running I/O in timely threads.
2315    ///
2316    /// [`Config`]: mz_sql_server_util::Config
2317    pub async fn resolve_config(
2318        &self,
2319        secrets_reader: &Arc<dyn mz_secrets::SecretsReader>,
2320        storage_configuration: &StorageConfiguration,
2321        in_task: InTask,
2322    ) -> Result<mz_sql_server_util::Config, anyhow::Error> {
2323        let dyncfg = storage_configuration.config_set();
2324        let mut inner_config = tiberius::Config::new();
2325
2326        // Setup default connection params.
2327        inner_config.host(&self.host);
2328        inner_config.port(self.port);
2329        inner_config.database(self.database.clone());
2330        inner_config.encryption(self.encryption.into());
2331        match self.certificate_validation_policy {
2332            mz_sql_server_util::config::CertificateValidationPolicy::TrustAll => {
2333                inner_config.trust_cert()
2334            }
2335            mz_sql_server_util::config::CertificateValidationPolicy::VerifyCA => {
2336                inner_config.trust_cert_ca_pem(
2337                    self.tls_root_cert
2338                        .as_ref()
2339                        .unwrap()
2340                        .get_string(in_task, secrets_reader)
2341                        .await
2342                        .context("ca certificate")?,
2343                );
2344            }
2345            mz_sql_server_util::config::CertificateValidationPolicy::VerifySystem => (), // no-op
2346        }
2347
2348        inner_config.application_name("materialize");
2349
2350        // Read our auth settings from
2351        let user = self
2352            .user
2353            .get_string(in_task, secrets_reader)
2354            .await
2355            .context("username")?;
2356        let password = secrets_reader
2357            .read_string_in_task_if(in_task, self.password)
2358            .await
2359            .context("password")?;
2360        // TODO(sql_server3): Support other methods of authentication besides
2361        // username and password.
2362        inner_config.authentication(tiberius::AuthMethod::sql_server(user, password));
2363
2364        // Prevent users from probing our internal network ports by trying to
2365        // connect to localhost, or another non-external IP.
2366        let enfoce_external_addresses = ENFORCE_EXTERNAL_ADDRESSES.get(dyncfg);
2367
2368        let tunnel = match &self.tunnel {
2369            Tunnel::Direct => mz_sql_server_util::config::TunnelConfig::Direct,
2370            Tunnel::Ssh(SshTunnel {
2371                connection_id,
2372                connection: ssh_connection,
2373            }) => {
2374                let secret = secrets_reader
2375                    .read_in_task_if(in_task, *connection_id)
2376                    .await
2377                    .context("ssh secret")?;
2378                let key_pair = SshKeyPair::from_bytes(&secret).context("ssh key pair")?;
2379                // Ensure any SSH-bastion host we connect to is resolved to an
2380                // external address.
2381                let addresses = resolve_address(&ssh_connection.host, enfoce_external_addresses)
2382                    .await
2383                    .context("ssh tunnel")?;
2384
2385                let config = SshTunnelConfig {
2386                    host: addresses.into_iter().map(|a| a.to_string()).collect(),
2387                    port: ssh_connection.port,
2388                    user: ssh_connection.user.clone(),
2389                    key_pair,
2390                };
2391                mz_sql_server_util::config::TunnelConfig::Ssh {
2392                    config,
2393                    manager: storage_configuration
2394                        .connection_context
2395                        .ssh_tunnel_manager
2396                        .clone(),
2397                    timeout: storage_configuration.parameters.ssh_timeout_config.clone(),
2398                    host: self.host.clone(),
2399                    port: self.port,
2400                }
2401            }
2402            Tunnel::AwsPrivatelink(private_link_connection) => {
2403                assert_none!(private_link_connection.port);
2404                mz_sql_server_util::config::TunnelConfig::AwsPrivatelink {
2405                    connection_id: private_link_connection.connection_id,
2406                    port: self.port,
2407                }
2408            }
2409        };
2410
2411        Ok(mz_sql_server_util::Config::new(
2412            inner_config,
2413            tunnel,
2414            in_task,
2415        ))
2416    }
2417}
2418
2419impl<R: ConnectionResolver> IntoInlineConnection<SqlServerConnectionDetails, R>
2420    for SqlServerConnectionDetails<ReferencedConnection>
2421{
2422    fn into_inline_connection(self, r: R) -> SqlServerConnectionDetails {
2423        let SqlServerConnectionDetails {
2424            host,
2425            port,
2426            database,
2427            user,
2428            password,
2429            tunnel,
2430            encryption,
2431            certificate_validation_policy,
2432            tls_root_cert,
2433        } = self;
2434
2435        SqlServerConnectionDetails {
2436            host,
2437            port,
2438            database,
2439            user,
2440            password,
2441            tunnel: tunnel.into_inline_connection(&r),
2442            encryption,
2443            certificate_validation_policy,
2444            tls_root_cert,
2445        }
2446    }
2447}
2448
2449impl<C: ConnectionAccess> AlterCompatible for SqlServerConnectionDetails<C> {
2450    fn alter_compatible(
2451        &self,
2452        id: mz_repr::GlobalId,
2453        other: &Self,
2454    ) -> Result<(), crate::controller::AlterError> {
2455        let SqlServerConnectionDetails {
2456            tunnel,
2457            // TODO(sql_server2): Figure out how these variables are allowed to change.
2458            host: _,
2459            port: _,
2460            database: _,
2461            user: _,
2462            password: _,
2463            encryption: _,
2464            certificate_validation_policy: _,
2465            tls_root_cert: _,
2466        } = self;
2467
2468        let compatibility_checks = [(tunnel.alter_compatible(id, &other.tunnel).is_ok(), "tunnel")];
2469
2470        for (compatible, field) in compatibility_checks {
2471            if !compatible {
2472                tracing::warn!(
2473                    "SqlServerConnectionDetails incompatible at {field}:\nself:\n{:#?}\n\nother\n{:#?}",
2474                    self,
2475                    other
2476                );
2477
2478                return Err(AlterError { id });
2479            }
2480        }
2481        Ok(())
2482    }
2483}
2484
2485impl RustType<ProtoSqlServerConnectionDetails> for SqlServerConnectionDetails {
2486    fn into_proto(&self) -> ProtoSqlServerConnectionDetails {
2487        ProtoSqlServerConnectionDetails {
2488            host: self.host.into_proto(),
2489            port: self.port.into_proto(),
2490            database: self.database.into_proto(),
2491            user: Some(self.user.into_proto()),
2492            password: Some(self.password.into_proto()),
2493            tunnel: Some(self.tunnel.into_proto()),
2494            encryption: self.encryption.into_proto().into(),
2495            certificate_validation_policy: self.certificate_validation_policy.into_proto().into(),
2496            tls_root_cert: self.tls_root_cert.into_proto(),
2497        }
2498    }
2499
2500    fn from_proto(proto: ProtoSqlServerConnectionDetails) -> Result<Self, TryFromProtoError> {
2501        Ok(SqlServerConnectionDetails {
2502            host: proto.host,
2503            port: proto.port.into_rust()?,
2504            database: proto.database.into_rust()?,
2505            user: proto
2506                .user
2507                .into_rust_if_some("ProtoSqlServerConnectionDetails::user")?,
2508            password: proto
2509                .password
2510                .into_rust_if_some("ProtoSqlServerConnectionDetails::password")?,
2511            tunnel: proto
2512                .tunnel
2513                .into_rust_if_some("ProtoSqlServerConnectionDetails::tunnel")?,
2514            encryption: ProtoSqlServerEncryptionLevel::try_from(proto.encryption)?.into_rust()?,
2515            certificate_validation_policy: ProtoSqlServerCertificateValidationPolicy::try_from(
2516                proto.certificate_validation_policy,
2517            )?
2518            .into_rust()?,
2519            tls_root_cert: proto.tls_root_cert.into_rust()?,
2520        })
2521    }
2522}
2523
2524impl RustType<ProtoSqlServerEncryptionLevel> for mz_sql_server_util::config::EncryptionLevel {
2525    fn into_proto(&self) -> ProtoSqlServerEncryptionLevel {
2526        match self {
2527            Self::None => ProtoSqlServerEncryptionLevel::SqlServerNone,
2528            Self::Login => ProtoSqlServerEncryptionLevel::SqlServerLogin,
2529            Self::Preferred => ProtoSqlServerEncryptionLevel::SqlServerPreferred,
2530            Self::Required => ProtoSqlServerEncryptionLevel::SqlServerRequired,
2531        }
2532    }
2533
2534    fn from_proto(proto: ProtoSqlServerEncryptionLevel) -> Result<Self, TryFromProtoError> {
2535        Ok(match proto {
2536            ProtoSqlServerEncryptionLevel::SqlServerNone => {
2537                mz_sql_server_util::config::EncryptionLevel::None
2538            }
2539            ProtoSqlServerEncryptionLevel::SqlServerLogin => {
2540                mz_sql_server_util::config::EncryptionLevel::Login
2541            }
2542            ProtoSqlServerEncryptionLevel::SqlServerPreferred => {
2543                mz_sql_server_util::config::EncryptionLevel::Preferred
2544            }
2545            ProtoSqlServerEncryptionLevel::SqlServerRequired => {
2546                mz_sql_server_util::config::EncryptionLevel::Required
2547            }
2548        })
2549    }
2550}
2551
2552impl RustType<ProtoSqlServerCertificateValidationPolicy>
2553    for mz_sql_server_util::config::CertificateValidationPolicy
2554{
2555    fn into_proto(&self) -> ProtoSqlServerCertificateValidationPolicy {
2556        match self {
2557            mz_sql_server_util::config::CertificateValidationPolicy::TrustAll => {
2558                ProtoSqlServerCertificateValidationPolicy::SqlServerTrustAll
2559            }
2560            mz_sql_server_util::config::CertificateValidationPolicy::VerifySystem => {
2561                ProtoSqlServerCertificateValidationPolicy::SqlServerVerifySystem
2562            }
2563            mz_sql_server_util::config::CertificateValidationPolicy::VerifyCA => {
2564                ProtoSqlServerCertificateValidationPolicy::SqlServerVerifyCa
2565            }
2566        }
2567    }
2568
2569    fn from_proto(
2570        proto: ProtoSqlServerCertificateValidationPolicy,
2571    ) -> Result<Self, TryFromProtoError> {
2572        Ok(match proto {
2573            ProtoSqlServerCertificateValidationPolicy::SqlServerTrustAll => {
2574                mz_sql_server_util::config::CertificateValidationPolicy::TrustAll
2575            }
2576            ProtoSqlServerCertificateValidationPolicy::SqlServerVerifySystem => {
2577                mz_sql_server_util::config::CertificateValidationPolicy::VerifySystem
2578            }
2579            ProtoSqlServerCertificateValidationPolicy::SqlServerVerifyCa => {
2580                mz_sql_server_util::config::CertificateValidationPolicy::VerifyCA
2581            }
2582        })
2583    }
2584}
2585
/// A connection to an SSH tunnel.
#[derive(Arbitrary, Clone, Debug, Eq, PartialEq, Hash, Serialize, Deserialize)]
pub struct SshConnection {
    /// Host of the SSH bastion. It is resolved to network addresses before a
    /// connection is attempted.
    pub host: String,
    /// Port passed through as the `port` of the resulting `SshTunnelConfig`.
    pub port: u16,
    /// User passed through as the `user` of the resulting `SshTunnelConfig`.
    pub user: String,
}
2593
2594use self::inline::{
2595    ConnectionAccess, ConnectionResolver, InlinedConnection, IntoInlineConnection,
2596    ReferencedConnection,
2597};
2598
2599impl RustType<ProtoSshConnection> for SshConnection {
2600    fn into_proto(&self) -> ProtoSshConnection {
2601        ProtoSshConnection {
2602            host: self.host.into_proto(),
2603            port: self.port.into_proto(),
2604            user: self.user.into_proto(),
2605        }
2606    }
2607
2608    fn from_proto(proto: ProtoSshConnection) -> Result<Self, TryFromProtoError> {
2609        Ok(SshConnection {
2610            host: proto.host,
2611            port: proto.port.into_rust()?,
2612            user: proto.user,
2613        })
2614    }
2615}
2616
impl AlterCompatible for SshConnection {
    /// Always succeeds: no field of an SSH connection restricts what it may
    /// be altered to, so both arguments are ignored.
    fn alter_compatible(&self, _id: GlobalId, _other: &Self) -> Result<(), AlterError> {
        // Every element of the SSH connection is configurable.
        Ok(())
    }
}
2623
/// Specifies an AWS PrivateLink service for a [`Tunnel`].
#[derive(Arbitrary, Clone, Debug, Eq, PartialEq, Hash, Serialize, Deserialize)]
pub struct AwsPrivatelink {
    /// The ID of the connection to the AWS PrivateLink service.
    pub connection_id: CatalogItemId,
    /// The availability zone to use when connecting to the AWS PrivateLink service.
    pub availability_zone: Option<String>,
    /// The port to use when connecting to the AWS PrivateLink service, if
    /// different from the port in [`KafkaBroker::address`].
    pub port: Option<u16>,
}
2635
2636impl RustType<ProtoAwsPrivatelink> for AwsPrivatelink {
2637    fn into_proto(&self) -> ProtoAwsPrivatelink {
2638        ProtoAwsPrivatelink {
2639            connection_id: Some(self.connection_id.into_proto()),
2640            availability_zone: self.availability_zone.into_proto(),
2641            port: self.port.into_proto(),
2642        }
2643    }
2644
2645    fn from_proto(proto: ProtoAwsPrivatelink) -> Result<Self, TryFromProtoError> {
2646        Ok(AwsPrivatelink {
2647            connection_id: proto
2648                .connection_id
2649                .into_rust_if_some("ProtoAwsPrivatelink::connection_id")?,
2650            availability_zone: proto.availability_zone.into_rust()?,
2651            port: proto.port.into_rust()?,
2652        })
2653    }
2654}
2655
2656impl AlterCompatible for AwsPrivatelink {
2657    fn alter_compatible(&self, id: GlobalId, other: &Self) -> Result<(), AlterError> {
2658        let AwsPrivatelink {
2659            connection_id,
2660            availability_zone: _,
2661            port: _,
2662        } = self;
2663
2664        let compatibility_checks = [(connection_id == &other.connection_id, "connection_id")];
2665
2666        for (compatible, field) in compatibility_checks {
2667            if !compatible {
2668                tracing::warn!(
2669                    "AwsPrivatelink incompatible at {field}:\nself:\n{:#?}\n\nother\n{:#?}",
2670                    self,
2671                    other
2672                );
2673
2674                return Err(AlterError { id });
2675            }
2676        }
2677
2678        Ok(())
2679    }
2680}
2681
/// Specifies an SSH tunnel connection.
#[derive(Arbitrary, Clone, Debug, Eq, PartialEq, Hash, Serialize, Deserialize)]
pub struct SshTunnel<C: ConnectionAccess = InlinedConnection> {
    /// ID of the SSH connection. The secret holding the connection's SSH key
    /// pair is read using this ID.
    pub connection_id: CatalogItemId,
    /// The SSH connection object, in the representation selected by `C`.
    pub connection: C::Ssh,
}
2690
2691impl<R: ConnectionResolver> IntoInlineConnection<SshTunnel, R> for SshTunnel<ReferencedConnection> {
2692    fn into_inline_connection(self, r: R) -> SshTunnel {
2693        let SshTunnel {
2694            connection,
2695            connection_id,
2696        } = self;
2697
2698        SshTunnel {
2699            connection: r.resolve_connection(connection).unwrap_ssh(),
2700            connection_id,
2701        }
2702    }
2703}
2704
2705impl RustType<ProtoSshTunnel> for SshTunnel<InlinedConnection> {
2706    fn into_proto(&self) -> ProtoSshTunnel {
2707        ProtoSshTunnel {
2708            connection_id: Some(self.connection_id.into_proto()),
2709            connection: Some(self.connection.into_proto()),
2710        }
2711    }
2712
2713    fn from_proto(proto: ProtoSshTunnel) -> Result<Self, TryFromProtoError> {
2714        Ok(SshTunnel {
2715            connection_id: proto
2716                .connection_id
2717                .into_rust_if_some("ProtoSshTunnel::connection_id")?,
2718            connection: proto
2719                .connection
2720                .into_rust_if_some("ProtoSshTunnel::connection")?,
2721        })
2722    }
2723}
2724
2725impl SshTunnel<InlinedConnection> {
2726    /// Like [`SshTunnelConfig::connect`], but the SSH key is loaded from a
2727    /// secret.
2728    async fn connect(
2729        &self,
2730        storage_configuration: &StorageConfiguration,
2731        remote_host: &str,
2732        remote_port: u16,
2733        in_task: InTask,
2734    ) -> Result<ManagedSshTunnelHandle, anyhow::Error> {
2735        // Ensure any ssh-bastion host we connect to is resolved to an external address.
2736        let resolved = resolve_address(
2737            &self.connection.host,
2738            ENFORCE_EXTERNAL_ADDRESSES.get(storage_configuration.config_set()),
2739        )
2740        .await?;
2741        storage_configuration
2742            .connection_context
2743            .ssh_tunnel_manager
2744            .connect(
2745                SshTunnelConfig {
2746                    host: resolved
2747                        .iter()
2748                        .map(|a| a.to_string())
2749                        .collect::<BTreeSet<_>>(),
2750                    port: self.connection.port,
2751                    user: self.connection.user.clone(),
2752                    key_pair: SshKeyPair::from_bytes(
2753                        &storage_configuration
2754                            .connection_context
2755                            .secrets_reader
2756                            .read_in_task_if(in_task, self.connection_id)
2757                            .await?,
2758                    )?,
2759                },
2760                remote_host,
2761                remote_port,
2762                storage_configuration.parameters.ssh_timeout_config,
2763                in_task,
2764            )
2765            .await
2766    }
2767}
2768
2769impl<C: ConnectionAccess> AlterCompatible for SshTunnel<C> {
2770    fn alter_compatible(&self, id: GlobalId, other: &Self) -> Result<(), AlterError> {
2771        let SshTunnel {
2772            connection_id,
2773            connection,
2774        } = self;
2775
2776        let compatibility_checks = [
2777            (connection_id == &other.connection_id, "connection_id"),
2778            (
2779                connection.alter_compatible(id, &other.connection).is_ok(),
2780                "connection",
2781            ),
2782        ];
2783
2784        for (compatible, field) in compatibility_checks {
2785            if !compatible {
2786                tracing::warn!(
2787                    "SshTunnel incompatible at {field}:\nself:\n{:#?}\n\nother\n{:#?}",
2788                    self,
2789                    other
2790                );
2791
2792                return Err(AlterError { id });
2793            }
2794        }
2795
2796        Ok(())
2797    }
2798}
2799
2800impl SshConnection {
2801    #[allow(clippy::unused_async)]
2802    async fn validate(
2803        &self,
2804        id: CatalogItemId,
2805        storage_configuration: &StorageConfiguration,
2806    ) -> Result<(), anyhow::Error> {
2807        let secret = storage_configuration
2808            .connection_context
2809            .secrets_reader
2810            .read_in_task_if(
2811                // We are in a normal tokio context during validation, already.
2812                InTask::No,
2813                id,
2814            )
2815            .await?;
2816        let key_pair = SshKeyPair::from_bytes(&secret)?;
2817
2818        // Ensure any ssh-bastion host we connect to is resolved to an external address.
2819        let resolved = resolve_address(
2820            &self.host,
2821            ENFORCE_EXTERNAL_ADDRESSES.get(storage_configuration.config_set()),
2822        )
2823        .await?;
2824
2825        let config = SshTunnelConfig {
2826            host: resolved
2827                .iter()
2828                .map(|a| a.to_string())
2829                .collect::<BTreeSet<_>>(),
2830            port: self.port,
2831            user: self.user.clone(),
2832            key_pair,
2833        };
2834        // Note that we do NOT use the `SshTunnelManager` here, as we want to validate that we
2835        // can actually create a new connection to the ssh bastion, without tunneling.
2836        config
2837            .validate(storage_configuration.parameters.ssh_timeout_config)
2838            .await
2839    }
2840
2841    fn validate_by_default(&self) -> bool {
2842        false
2843    }
2844}
2845
2846impl AwsPrivatelinkConnection {
2847    #[allow(clippy::unused_async)]
2848    async fn validate(
2849        &self,
2850        id: CatalogItemId,
2851        storage_configuration: &StorageConfiguration,
2852    ) -> Result<(), anyhow::Error> {
2853        let Some(ref cloud_resource_reader) = storage_configuration
2854            .connection_context
2855            .cloud_resource_reader
2856        else {
2857            return Err(anyhow!("AWS PrivateLink connections are unsupported"));
2858        };
2859
2860        // No need to optionally run this in a task, as we are just validating from envd.
2861        let status = cloud_resource_reader.read(id).await?;
2862
2863        let availability = status
2864            .conditions
2865            .as_ref()
2866            .and_then(|conditions| conditions.iter().find(|c| c.type_ == "Available"));
2867
2868        match availability {
2869            Some(condition) if condition.status == "True" => Ok(()),
2870            Some(condition) => Err(anyhow!("{}", condition.message)),
2871            None => Err(anyhow!("Endpoint availability is unknown")),
2872        }
2873    }
2874
2875    fn validate_by_default(&self) -> bool {
2876        false
2877    }
2878}