mz_storage_types/
connections.rs

1// Copyright Materialize, Inc. and contributors. All rights reserved.
2//
3// Use of this software is governed by the Business Source License
4// included in the LICENSE file.
5//
6// As of the Change Date specified in that file, in accordance with
7// the Business Source License, use of this software will be governed
8// by the Apache License, Version 2.0.
9
10//! Connection types.
11
12use std::borrow::Cow;
13use std::collections::{BTreeMap, BTreeSet};
14use std::net::SocketAddr;
15use std::sync::Arc;
16
17use anyhow::{Context, anyhow};
18use iceberg::{Catalog, CatalogBuilder};
19use iceberg_catalog_rest::{
20    REST_CATALOG_PROP_URI, REST_CATALOG_PROP_WAREHOUSE, RestCatalogBuilder,
21};
22use itertools::Itertools;
23use mz_ccsr::tls::{Certificate, Identity};
24use mz_cloud_resources::{AwsExternalIdPrefix, CloudResourceReader, vpc_endpoint_host};
25use mz_dyncfg::ConfigSet;
26use mz_kafka_util::client::{
27    BrokerAddr, BrokerRewrite, MzClientContext, MzKafkaError, TunnelConfig, TunnelingClientContext,
28};
29use mz_mysql_util::{MySqlConn, MySqlError};
30use mz_ore::assert_none;
31use mz_ore::error::ErrorExt;
32use mz_ore::future::{InTask, OreFutureExt};
33use mz_ore::netio::DUMMY_DNS_PORT;
34use mz_ore::netio::resolve_address;
35use mz_ore::num::NonNeg;
36use mz_repr::{CatalogItemId, GlobalId};
37use mz_secrets::SecretsReader;
38use mz_ssh_util::keys::SshKeyPair;
39use mz_ssh_util::tunnel::SshTunnelConfig;
40use mz_ssh_util::tunnel_manager::{ManagedSshTunnelHandle, SshTunnelManager};
41use mz_tls_util::Pkcs12Archive;
42use mz_tracing::CloneableEnvFilter;
43use rdkafka::ClientContext;
44use rdkafka::config::FromClientConfigAndContext;
45use rdkafka::consumer::{BaseConsumer, Consumer};
46use regex::Regex;
47use serde::{Deserialize, Deserializer, Serialize};
48use tokio::net;
49use tokio::runtime::Handle;
50use tokio_postgres::config::SslMode;
51use tracing::{debug, warn};
52use url::Url;
53
54use crate::AlterCompatible;
55use crate::configuration::StorageConfiguration;
56use crate::connections::aws::{
57    AwsConnection, AwsConnectionReference, AwsConnectionValidationError,
58};
59use crate::connections::string_or_secret::StringOrSecret;
60use crate::controller::AlterError;
61use crate::dyncfgs::{
62    ENFORCE_EXTERNAL_ADDRESSES, KAFKA_CLIENT_ID_ENRICHMENT_RULES,
63    KAFKA_DEFAULT_AWS_PRIVATELINK_ENDPOINT_IDENTIFICATION_ALGORITHM,
64};
65use crate::errors::{ContextCreationError, CsrConnectError};
66
67pub mod aws;
68pub mod inline;
69pub mod string_or_secret;
70
71const REST_CATALOG_PROP_SCOPE: &str = "scope";
72const REST_CATALOG_PROP_CREDENTIAL: &str = "credential";
73
/// An extension trait for [`SecretsReader`] that adds variants of its read
/// methods which can optionally be executed on a separate task, as selected by
/// an [`InTask`] argument.
#[async_trait::async_trait]
trait SecretsReaderExt {
    /// `SecretsReader::read`, but optionally run in a task.
    ///
    /// Returns the raw bytes of the secret identified by `id`.
    async fn read_in_task_if(
        &self,
        in_task: InTask,
        id: CatalogItemId,
    ) -> Result<Vec<u8>, anyhow::Error>;

    /// `SecretsReader::read_string`, but optionally run in a task.
    ///
    /// Returns the contents of the secret identified by `id` as a string.
    async fn read_string_in_task_if(
        &self,
        in_task: InTask,
        id: CatalogItemId,
    ) -> Result<String, anyhow::Error>;
}
91
92#[async_trait::async_trait]
93impl SecretsReaderExt for Arc<dyn SecretsReader> {
94    async fn read_in_task_if(
95        &self,
96        in_task: InTask,
97        id: CatalogItemId,
98    ) -> Result<Vec<u8>, anyhow::Error> {
99        let sr = Arc::clone(self);
100        async move { sr.read(id).await }
101            .run_in_task_if(in_task, || "secrets_reader_read".to_string())
102            .await
103    }
104    async fn read_string_in_task_if(
105        &self,
106        in_task: InTask,
107        id: CatalogItemId,
108    ) -> Result<String, anyhow::Error> {
109        let sr = Arc::clone(self);
110        async move { sr.read_string(id).await }
111            .run_in_task_if(in_task, || "secrets_reader_read".to_string())
112            .await
113    }
114}
115
/// Extra context to pass through when instantiating a connection for a source
/// or sink.
///
/// Should be kept cheaply cloneable; the secrets reader and cloud resource
/// reader are already shared behind `Arc`s.
#[derive(Debug, Clone)]
pub struct ConnectionContext {
    /// An opaque identifier for the environment in which this process is
    /// running.
    ///
    /// The storage layer is intentionally unaware of the structure within this
    /// identifier. Higher layers of the stack can make use of that structure,
    /// but the storage layer should be oblivious to it.
    pub environment_id: String,
    /// The level for librdkafka's logs.
    pub librdkafka_log_level: tracing::Level,
    /// A prefix for an external ID to use for all AWS AssumeRole operations.
    pub aws_external_id_prefix: Option<AwsExternalIdPrefix>,
    /// The ARN for a Materialize-controlled role to assume before assuming
    /// a customer's requested role for an AWS connection.
    pub aws_connection_role_arn: Option<String>,
    /// A secrets reader, used to resolve secret-valued connection options.
    pub secrets_reader: Arc<dyn SecretsReader>,
    /// A cloud resource reader, if supported in this configuration.
    pub cloud_resource_reader: Option<Arc<dyn CloudResourceReader>>,
    /// A manager for SSH tunnels.
    pub ssh_tunnel_manager: SshTunnelManager,
}
143
144impl ConnectionContext {
145    /// Constructs a new connection context from command line arguments.
146    ///
147    /// **WARNING:** it is critical for security that the `aws_external_id` be
148    /// provided by the operator of the Materialize service (i.e., via a CLI
149    /// argument or environment variable) and not the end user of Materialize
150    /// (e.g., via a configuration option in a SQL statement). See
151    /// [`AwsExternalIdPrefix`] for details.
152    pub fn from_cli_args(
153        environment_id: String,
154        startup_log_level: &CloneableEnvFilter,
155        aws_external_id_prefix: Option<AwsExternalIdPrefix>,
156        aws_connection_role_arn: Option<String>,
157        secrets_reader: Arc<dyn SecretsReader>,
158        cloud_resource_reader: Option<Arc<dyn CloudResourceReader>>,
159    ) -> ConnectionContext {
160        ConnectionContext {
161            environment_id,
162            librdkafka_log_level: mz_ore::tracing::crate_level(
163                &startup_log_level.clone().into(),
164                "librdkafka",
165            ),
166            aws_external_id_prefix,
167            aws_connection_role_arn,
168            secrets_reader,
169            cloud_resource_reader,
170            ssh_tunnel_manager: SshTunnelManager::default(),
171        }
172    }
173
174    /// Constructs a new connection context for usage in tests.
175    pub fn for_tests(secrets_reader: Arc<dyn SecretsReader>) -> ConnectionContext {
176        ConnectionContext {
177            environment_id: "test-environment-id".into(),
178            librdkafka_log_level: tracing::Level::INFO,
179            aws_external_id_prefix: Some(
180                AwsExternalIdPrefix::new_from_cli_argument_or_environment_variable(
181                    "test-aws-external-id-prefix",
182                )
183                .expect("infallible"),
184            ),
185            aws_connection_role_arn: Some(
186                "arn:aws:iam::123456789000:role/MaterializeConnection".into(),
187            ),
188            secrets_reader,
189            cloud_resource_reader: None,
190            ssh_tunnel_manager: SshTunnelManager::default(),
191        }
192    }
193}
194
/// A connection to an external system.
///
/// `C` determines how references to *other* connections are represented (for
/// example, the tunnel of a Kafka broker or the AWS connection used for Kafka
/// SASL). It defaults to [`InlinedConnection`], the form produced by
/// `into_inline_connection` from a `Connection<ReferencedConnection>`.
#[derive(Clone, Debug, Eq, PartialEq, Hash, Serialize, Deserialize)]
pub enum Connection<C: ConnectionAccess = InlinedConnection> {
    /// A Kafka cluster.
    Kafka(KafkaConnection<C>),
    /// A Confluent-compatible schema registry.
    Csr(CsrConnection<C>),
    /// A PostgreSQL server.
    Postgres(PostgresConnection<C>),
    /// An SSH tunnel. Contains no references to other connections.
    Ssh(SshConnection),
    /// An AWS connection. Contains no references to other connections.
    Aws(AwsConnection),
    /// An AWS PrivateLink endpoint service. Contains no references to other
    /// connections.
    AwsPrivatelink(AwsPrivatelinkConnection),
    /// A MySQL server.
    MySql(MySqlConnection<C>),
    /// A SQL Server instance.
    SqlServer(SqlServerConnectionDetails<C>),
    /// An Iceberg catalog.
    IcebergCatalog(IcebergCatalogConnection<C>),
}
207
208impl<R: ConnectionResolver> IntoInlineConnection<Connection, R>
209    for Connection<ReferencedConnection>
210{
211    fn into_inline_connection(self, r: R) -> Connection {
212        match self {
213            Connection::Kafka(kafka) => Connection::Kafka(kafka.into_inline_connection(r)),
214            Connection::Csr(csr) => Connection::Csr(csr.into_inline_connection(r)),
215            Connection::Postgres(pg) => Connection::Postgres(pg.into_inline_connection(r)),
216            Connection::Ssh(ssh) => Connection::Ssh(ssh),
217            Connection::Aws(aws) => Connection::Aws(aws),
218            Connection::AwsPrivatelink(awspl) => Connection::AwsPrivatelink(awspl),
219            Connection::MySql(mysql) => Connection::MySql(mysql.into_inline_connection(r)),
220            Connection::SqlServer(sql_server) => {
221                Connection::SqlServer(sql_server.into_inline_connection(r))
222            }
223            Connection::IcebergCatalog(iceberg) => {
224                Connection::IcebergCatalog(iceberg.into_inline_connection(r))
225            }
226        }
227    }
228}
229
230impl<C: ConnectionAccess> Connection<C> {
231    /// Whether this connection should be validated by default on creation.
232    pub fn validate_by_default(&self) -> bool {
233        match self {
234            Connection::Kafka(conn) => conn.validate_by_default(),
235            Connection::Csr(conn) => conn.validate_by_default(),
236            Connection::Postgres(conn) => conn.validate_by_default(),
237            Connection::Ssh(conn) => conn.validate_by_default(),
238            Connection::Aws(conn) => conn.validate_by_default(),
239            Connection::AwsPrivatelink(conn) => conn.validate_by_default(),
240            Connection::MySql(conn) => conn.validate_by_default(),
241            Connection::SqlServer(conn) => conn.validate_by_default(),
242            Connection::IcebergCatalog(conn) => conn.validate_by_default(),
243        }
244    }
245}
246
247impl Connection<InlinedConnection> {
248    /// Validates this connection by attempting to connect to the upstream system.
249    pub async fn validate(
250        &self,
251        id: CatalogItemId,
252        storage_configuration: &StorageConfiguration,
253    ) -> Result<(), ConnectionValidationError> {
254        match self {
255            Connection::Kafka(conn) => conn.validate(id, storage_configuration).await?,
256            Connection::Csr(conn) => conn.validate(id, storage_configuration).await?,
257            Connection::Postgres(conn) => {
258                conn.validate(id, storage_configuration).await?;
259            }
260            Connection::Ssh(conn) => conn.validate(id, storage_configuration).await?,
261            Connection::Aws(conn) => conn.validate(id, storage_configuration).await?,
262            Connection::AwsPrivatelink(conn) => conn.validate(id, storage_configuration).await?,
263            Connection::MySql(conn) => {
264                conn.validate(id, storage_configuration).await?;
265            }
266            Connection::SqlServer(conn) => {
267                conn.validate(id, storage_configuration).await?;
268            }
269            Connection::IcebergCatalog(conn) => conn.validate(id, storage_configuration).await?,
270        }
271        Ok(())
272    }
273
274    pub fn unwrap_kafka(self) -> <InlinedConnection as ConnectionAccess>::Kafka {
275        match self {
276            Self::Kafka(conn) => conn,
277            o => unreachable!("{o:?} is not a Kafka connection"),
278        }
279    }
280
281    pub fn unwrap_pg(self) -> <InlinedConnection as ConnectionAccess>::Pg {
282        match self {
283            Self::Postgres(conn) => conn,
284            o => unreachable!("{o:?} is not a Postgres connection"),
285        }
286    }
287
288    pub fn unwrap_mysql(self) -> <InlinedConnection as ConnectionAccess>::MySql {
289        match self {
290            Self::MySql(conn) => conn,
291            o => unreachable!("{o:?} is not a MySQL connection"),
292        }
293    }
294
295    pub fn unwrap_sql_server(self) -> <InlinedConnection as ConnectionAccess>::SqlServer {
296        match self {
297            Self::SqlServer(conn) => conn,
298            o => unreachable!("{o:?} is not a SQL Server connection"),
299        }
300    }
301
302    pub fn unwrap_aws(self) -> <InlinedConnection as ConnectionAccess>::Aws {
303        match self {
304            Self::Aws(conn) => conn,
305            o => unreachable!("{o:?} is not an AWS connection"),
306        }
307    }
308
309    pub fn unwrap_ssh(self) -> <InlinedConnection as ConnectionAccess>::Ssh {
310        match self {
311            Self::Ssh(conn) => conn,
312            o => unreachable!("{o:?} is not an SSH connection"),
313        }
314    }
315
316    pub fn unwrap_csr(self) -> <InlinedConnection as ConnectionAccess>::Csr {
317        match self {
318            Self::Csr(conn) => conn,
319            o => unreachable!("{o:?} is not a Kafka connection"),
320        }
321    }
322}
323
/// An error returned by [`Connection::validate`].
///
/// The typed variants can provide extra detail and hints via
/// [`ConnectionValidationError::detail`] and [`ConnectionValidationError::hint`].
#[derive(thiserror::Error, Debug)]
pub enum ConnectionValidationError {
    /// Validation of a PostgreSQL connection failed.
    #[error(transparent)]
    Postgres(#[from] PostgresConnectionValidationError),
    /// Validation of a MySQL connection failed.
    #[error(transparent)]
    MySql(#[from] MySqlConnectionValidationError),
    /// Validation of a SQL Server connection failed.
    #[error(transparent)]
    SqlServer(#[from] SqlServerConnectionValidationError),
    /// Validation of an AWS connection failed.
    #[error(transparent)]
    Aws(#[from] AwsConnectionValidationError),
    /// Any other validation failure. It is displayed together with its full
    /// chain of causes.
    #[error("{}", .0.display_with_causes())]
    Other(#[from] anyhow::Error),
}
338
339impl ConnectionValidationError {
340    /// Reports additional details about the error, if any are available.
341    pub fn detail(&self) -> Option<String> {
342        match self {
343            ConnectionValidationError::Postgres(e) => e.detail(),
344            ConnectionValidationError::MySql(e) => e.detail(),
345            ConnectionValidationError::SqlServer(e) => e.detail(),
346            ConnectionValidationError::Aws(e) => e.detail(),
347            ConnectionValidationError::Other(_) => None,
348        }
349    }
350
351    /// Reports a hint for the user about how the error could be fixed.
352    pub fn hint(&self) -> Option<String> {
353        match self {
354            ConnectionValidationError::Postgres(e) => e.hint(),
355            ConnectionValidationError::MySql(e) => e.hint(),
356            ConnectionValidationError::SqlServer(e) => e.hint(),
357            ConnectionValidationError::Aws(e) => e.hint(),
358            ConnectionValidationError::Other(_) => None,
359        }
360    }
361}
362
363impl<C: ConnectionAccess> AlterCompatible for Connection<C> {
364    fn alter_compatible(&self, id: GlobalId, other: &Self) -> Result<(), AlterError> {
365        match (self, other) {
366            (Self::Aws(s), Self::Aws(o)) => s.alter_compatible(id, o),
367            (Self::AwsPrivatelink(s), Self::AwsPrivatelink(o)) => s.alter_compatible(id, o),
368            (Self::Ssh(s), Self::Ssh(o)) => s.alter_compatible(id, o),
369            (Self::Csr(s), Self::Csr(o)) => s.alter_compatible(id, o),
370            (Self::Kafka(s), Self::Kafka(o)) => s.alter_compatible(id, o),
371            (Self::Postgres(s), Self::Postgres(o)) => s.alter_compatible(id, o),
372            (Self::MySql(s), Self::MySql(o)) => s.alter_compatible(id, o),
373            _ => {
374                tracing::warn!(
375                    "Connection incompatible:\nself:\n{:#?}\n\nother\n{:#?}",
376                    self,
377                    other
378                );
379                Err(AlterError { id })
380            }
381        }
382    }
383}
384
/// Configuration for a generic Iceberg REST catalog, authenticated with an
/// OAuth2 credential.
#[derive(Clone, Debug, Eq, PartialEq, Hash, Serialize, Deserialize)]
pub struct RestIcebergCatalog {
    /// For REST catalogs, the OAuth2 credential in `CLIENT_ID:CLIENT_SECRET` format.
    pub credential: StringOrSecret,
    /// The OAuth2 scope to request for REST catalogs, if any.
    pub scope: Option<String>,
    /// The warehouse for REST catalogs, if any.
    pub warehouse: Option<String>,
}

/// Configuration for an AWS S3 Tables catalog. The catalog is reached through
/// its Iceberg REST endpoint, and requests are authenticated with the SDK
/// configuration of an AWS connection.
#[derive(Clone, Debug, Eq, PartialEq, Hash, Serialize, Deserialize)]
pub struct S3TablesRestIcebergCatalog<C: ConnectionAccess = InlinedConnection> {
    /// The AWS connection details, for s3tables.
    pub aws_connection: AwsConnectionReference<C>,
    /// The warehouse for s3tables. Unlike the generic REST catalog, it is required.
    pub warehouse: String,
}
402
403impl<R: ConnectionResolver> IntoInlineConnection<S3TablesRestIcebergCatalog, R>
404    for S3TablesRestIcebergCatalog<ReferencedConnection>
405{
406    fn into_inline_connection(self, r: R) -> S3TablesRestIcebergCatalog {
407        S3TablesRestIcebergCatalog {
408            aws_connection: self.aws_connection.into_inline_connection(&r),
409            warehouse: self.warehouse,
410        }
411    }
412}
413
/// Names the supported kinds of Iceberg catalog. Its variants mirror those of
/// [`IcebergCatalogImpl`] but carry no configuration.
#[derive(Clone, Debug, Eq, PartialEq, Hash, Serialize, Deserialize)]
pub enum IcebergCatalogType {
    /// A generic Iceberg REST catalog.
    Rest,
    /// An AWS S3 Tables catalog, accessed via its REST endpoint.
    S3TablesRest,
}

/// A specific Iceberg catalog implementation together with its configuration.
#[derive(Clone, Debug, Eq, PartialEq, Hash, Serialize, Deserialize)]
pub enum IcebergCatalogImpl<C: ConnectionAccess = InlinedConnection> {
    /// A generic Iceberg REST catalog.
    Rest(RestIcebergCatalog),
    /// An AWS S3 Tables catalog, accessed via its REST endpoint.
    S3TablesRest(S3TablesRestIcebergCatalog<C>),
}
425
426impl<R: ConnectionResolver> IntoInlineConnection<IcebergCatalogImpl, R>
427    for IcebergCatalogImpl<ReferencedConnection>
428{
429    fn into_inline_connection(self, r: R) -> IcebergCatalogImpl {
430        match self {
431            IcebergCatalogImpl::Rest(rest) => IcebergCatalogImpl::Rest(rest),
432            IcebergCatalogImpl::S3TablesRest(s3tables) => {
433                IcebergCatalogImpl::S3TablesRest(s3tables.into_inline_connection(r))
434            }
435        }
436    }
437}
438
/// A connection to an Iceberg catalog.
#[derive(Clone, Debug, Eq, PartialEq, Hash, Serialize, Deserialize)]
pub struct IcebergCatalogConnection<C: ConnectionAccess = InlinedConnection> {
    /// The catalog implementation, together with its implementation-specific
    /// configuration.
    pub catalog: IcebergCatalogImpl<C>,
    /// Where the catalog is located.
    pub uri: reqwest::Url,
}
446
447impl AlterCompatible for IcebergCatalogConnection {
448    fn alter_compatible(&self, id: GlobalId, _other: &Self) -> Result<(), AlterError> {
449        Err(AlterError { id })
450    }
451}
452
453impl<R: ConnectionResolver> IntoInlineConnection<IcebergCatalogConnection, R>
454    for IcebergCatalogConnection<ReferencedConnection>
455{
456    fn into_inline_connection(self, r: R) -> IcebergCatalogConnection {
457        IcebergCatalogConnection {
458            catalog: self.catalog.into_inline_connection(&r),
459            uri: self.uri,
460        }
461    }
462}
463
464impl<C: ConnectionAccess> IcebergCatalogConnection<C> {
465    fn validate_by_default(&self) -> bool {
466        true
467    }
468}
469
470impl IcebergCatalogConnection<InlinedConnection> {
471    async fn connect(
472        &self,
473        storage_configuration: &StorageConfiguration,
474        in_task: InTask,
475    ) -> Result<Arc<dyn Catalog>, anyhow::Error> {
476        match self.catalog {
477            IcebergCatalogImpl::S3TablesRest(ref s3tables) => {
478                self.connect_s3tables(s3tables, storage_configuration, in_task)
479                    .await
480            }
481            IcebergCatalogImpl::Rest(ref rest) => {
482                self.connect_rest(rest, storage_configuration, in_task)
483                    .await
484            }
485        }
486    }
487
488    async fn connect_s3tables(
489        &self,
490        s3tables: &S3TablesRestIcebergCatalog,
491        storage_configuration: &StorageConfiguration,
492        in_task: InTask,
493    ) -> Result<Arc<dyn Catalog>, anyhow::Error> {
494        let mut props = BTreeMap::from([(
495            REST_CATALOG_PROP_URI.to_string(),
496            self.uri.to_string().clone(),
497        )]);
498
499        let aws_ref = &s3tables.aws_connection;
500        let aws_config = aws_ref
501            .connection
502            .load_sdk_config(
503                &storage_configuration.connection_context,
504                aws_ref.connection_id,
505                in_task,
506            )
507            .await?;
508
509        props.insert(
510            REST_CATALOG_PROP_WAREHOUSE.to_string(),
511            s3tables.warehouse.clone(),
512        );
513
514        let catalog = RestCatalogBuilder::default()
515            .with_aws_client(aws_config)
516            .load("IcebergCatalog", props.into_iter().collect())
517            .await
518            .map_err(|e| anyhow!("failed to create Iceberg catalog: {e}"))?;
519        Ok(Arc::new(catalog))
520    }
521
522    async fn connect_rest(
523        &self,
524        rest: &RestIcebergCatalog,
525        storage_configuration: &StorageConfiguration,
526        in_task: InTask,
527    ) -> Result<Arc<dyn Catalog>, anyhow::Error> {
528        let mut props = BTreeMap::from([(
529            REST_CATALOG_PROP_URI.to_string(),
530            self.uri.to_string().clone(),
531        )]);
532
533        if let Some(warehouse) = &rest.warehouse {
534            props.insert(REST_CATALOG_PROP_WAREHOUSE.to_string(), warehouse.clone());
535        }
536
537        let credential = rest
538            .credential
539            .get_string(
540                in_task,
541                &storage_configuration.connection_context.secrets_reader,
542            )
543            .await
544            .map_err(|e| anyhow!("failed to read Iceberg catalog credential: {e}"))?;
545        props.insert(REST_CATALOG_PROP_CREDENTIAL.to_string(), credential);
546
547        if let Some(scope) = &rest.scope {
548            props.insert(REST_CATALOG_PROP_SCOPE.to_string(), scope.clone());
549        }
550
551        let catalog = RestCatalogBuilder::default()
552            .load("IcebergCatalog", props.into_iter().collect())
553            .await
554            .map_err(|e| anyhow!("failed to create Iceberg catalog: {e}"))?;
555        Ok(Arc::new(catalog))
556    }
557
558    async fn validate(
559        &self,
560        _id: CatalogItemId,
561        storage_configuration: &StorageConfiguration,
562    ) -> Result<(), ConnectionValidationError> {
563        let catalog = self
564            .connect(storage_configuration, InTask::No)
565            .await
566            .map_err(|e| {
567                ConnectionValidationError::Other(anyhow!("failed to connect to catalog: {e}"))
568            })?;
569
570        // If we can list namespaces, the connection is valid.
571        catalog.list_namespaces(None).await.map_err(|e| {
572            ConnectionValidationError::Other(anyhow!("failed to list namespaces: {e}"))
573        })?;
574
575        Ok(())
576    }
577}
578
/// A connection to an AWS PrivateLink endpoint service.
#[derive(Clone, Debug, Eq, PartialEq, Hash, Serialize, Deserialize)]
pub struct AwsPrivatelinkConnection {
    /// The name of the endpoint service to connect to.
    pub service_name: String,
    /// The availability zones for the connection. Presumably these are the
    /// zones in which the service is reachable (TODO: confirm).
    pub availability_zones: Vec<String>,
}

impl AlterCompatible for AwsPrivatelinkConnection {
    /// Always succeeds: any alteration of a PrivateLink connection is compatible.
    fn alter_compatible(&self, _id: GlobalId, _other: &Self) -> Result<(), AlterError> {
        // Every element of the AwsPrivatelinkConnection connection is configurable.
        Ok(())
    }
}
591
/// TLS settings for a Kafka connection. When present, the connection's security
/// protocol is `SSL` (or `SASL_SSL` when SASL is also configured).
#[derive(Clone, Debug, Eq, PartialEq, Hash, Serialize, Deserialize)]
pub struct KafkaTlsConfig {
    /// The client identity to present. Its key and certificate are passed to
    /// librdkafka as `ssl.key.pem` and `ssl.certificate.pem`.
    pub identity: Option<TlsIdentity>,
    /// The root certificate used to verify brokers, passed as `ssl.ca.pem`.
    pub root_cert: Option<StringOrSecret>,
}

/// SASL settings for a Kafka connection. When present, the connection's
/// security protocol is `SASL_PLAINTEXT` (or `SASL_SSL` when TLS is also
/// configured).
#[derive(Clone, Debug, Eq, PartialEq, Hash, Serialize, Deserialize)]
pub struct KafkaSaslConfig<C: ConnectionAccess = InlinedConnection> {
    /// The SASL mechanism, passed to librdkafka as `sasl.mechanisms`.
    pub mechanism: String,
    /// The SASL username, passed as `sasl.username`.
    pub username: StringOrSecret,
    /// The secret holding the SASL password, if any, passed as `sasl.password`.
    pub password: Option<CatalogItemId>,
    /// An AWS connection to use for SASL authentication, if any.
    pub aws: Option<AwsConnectionReference<C>>,
}
605
606impl<R: ConnectionResolver> IntoInlineConnection<KafkaSaslConfig, R>
607    for KafkaSaslConfig<ReferencedConnection>
608{
609    fn into_inline_connection(self, r: R) -> KafkaSaslConfig {
610        KafkaSaslConfig {
611            mechanism: self.mechanism,
612            username: self.username,
613            password: self.password,
614            aws: self.aws.map(|aws| aws.into_inline_connection(&r)),
615        }
616    }
617}
618
/// Specifies a Kafka broker in a [`KafkaConnection`].
#[derive(Clone, Debug, Eq, PartialEq, Hash, Serialize, Deserialize)]
pub struct KafkaBroker<C: ConnectionAccess = InlinedConnection> {
    /// The address of the Kafka broker.
    pub address: String,
    /// An optional tunnel to use when connecting to the broker. It overrides
    /// the connection's `default_tunnel` for this broker.
    pub tunnel: Tunnel<C>,
}
627
628impl<R: ConnectionResolver> IntoInlineConnection<KafkaBroker, R>
629    for KafkaBroker<ReferencedConnection>
630{
631    fn into_inline_connection(self, r: R) -> KafkaBroker {
632        let KafkaBroker { address, tunnel } = self;
633        KafkaBroker {
634            address,
635            tunnel: tunnel.into_inline_connection(r),
636        }
637    }
638}
639
/// Options for creating a Kafka topic.
#[derive(Clone, Debug, Eq, PartialEq, Hash, Serialize, Deserialize, Default)]
pub struct KafkaTopicOptions {
    /// The replication factor for the topic.
    /// If `None`, the broker default will be used.
    pub replication_factor: Option<NonNeg<i32>>,
    /// The number of partitions to create.
    /// If `None`, the broker default will be used.
    pub partition_count: Option<NonNeg<i32>>,
    /// The initial configuration parameters for the topic.
    pub topic_config: BTreeMap<String, String>,
}

/// A connection to a Kafka cluster.
#[derive(Clone, Debug, Eq, PartialEq, Hash, Serialize, Deserialize)]
pub struct KafkaConnection<C: ConnectionAccess = InlinedConnection> {
    /// The bootstrap brokers. Must be empty when `default_tunnel` is an AWS
    /// PrivateLink tunnel, because the tunnel endpoint is used instead.
    pub brokers: Vec<KafkaBroker<C>>,
    /// A tunnel through which to route traffic,
    /// that can be overridden for individual brokers
    /// in `brokers`.
    pub default_tunnel: Tunnel<C>,
    /// An explicit name for the progress topic. When `None`, a name is derived
    /// by `KafkaConnection::progress_topic`.
    pub progress_topic: Option<String>,
    /// Topic options for the progress topic.
    pub progress_topic_options: KafkaTopicOptions,
    /// Additional librdkafka configuration options, whose values may be secrets.
    pub options: BTreeMap<String, StringOrSecret>,
    /// TLS settings, if TLS is enabled.
    pub tls: Option<KafkaTlsConfig>,
    /// SASL settings, if SASL authentication is enabled.
    pub sasl: Option<KafkaSaslConfig<C>>,
}
665
666impl<R: ConnectionResolver> IntoInlineConnection<KafkaConnection, R>
667    for KafkaConnection<ReferencedConnection>
668{
669    fn into_inline_connection(self, r: R) -> KafkaConnection {
670        let KafkaConnection {
671            brokers,
672            progress_topic,
673            progress_topic_options,
674            default_tunnel,
675            options,
676            tls,
677            sasl,
678        } = self;
679
680        let brokers = brokers
681            .into_iter()
682            .map(|broker| broker.into_inline_connection(&r))
683            .collect();
684
685        KafkaConnection {
686            brokers,
687            progress_topic,
688            progress_topic_options,
689            default_tunnel: default_tunnel.into_inline_connection(&r),
690            options,
691            tls,
692            sasl: sasl.map(|sasl| sasl.into_inline_connection(&r)),
693        }
694    }
695}
696
697impl<C: ConnectionAccess> KafkaConnection<C> {
698    /// Returns the name of the progress topic to use for the connection.
699    ///
700    /// The caller is responsible for providing the connection ID as it is not
701    /// known to `KafkaConnection`.
702    pub fn progress_topic(
703        &self,
704        connection_context: &ConnectionContext,
705        connection_id: CatalogItemId,
706    ) -> Cow<'_, str> {
707        if let Some(progress_topic) = &self.progress_topic {
708            Cow::Borrowed(progress_topic)
709        } else {
710            Cow::Owned(format!(
711                "_materialize-progress-{}-{}",
712                connection_context.environment_id, connection_id,
713            ))
714        }
715    }
716
717    fn validate_by_default(&self) -> bool {
718        true
719    }
720}
721
722impl KafkaConnection {
723    /// Generates a string that can be used as the base for a configuration ID
724    /// (e.g., `client.id`, `group.id`, `transactional.id`) for a Kafka source
725    /// or sink.
726    pub fn id_base(
727        connection_context: &ConnectionContext,
728        connection_id: CatalogItemId,
729        object_id: GlobalId,
730    ) -> String {
731        format!(
732            "materialize-{}-{}-{}",
733            connection_context.environment_id, connection_id, object_id,
734        )
735    }
736
737    /// Enriches the provided `client_id` according to any enrichment rules in
738    /// the `kafka_client_id_enrichment_rules` configuration parameter.
739    pub fn enrich_client_id(&self, configs: &ConfigSet, client_id: &mut String) {
740        #[derive(Debug, Deserialize)]
741        struct EnrichmentRule {
742            #[serde(deserialize_with = "deserialize_regex")]
743            pattern: Regex,
744            payload: String,
745        }
746
747        fn deserialize_regex<'de, D>(deserializer: D) -> Result<Regex, D::Error>
748        where
749            D: Deserializer<'de>,
750        {
751            let buf = String::deserialize(deserializer)?;
752            Regex::new(&buf).map_err(serde::de::Error::custom)
753        }
754
755        let rules = KAFKA_CLIENT_ID_ENRICHMENT_RULES.get(configs);
756        let rules = match serde_json::from_value::<Vec<EnrichmentRule>>(rules) {
757            Ok(rules) => rules,
758            Err(e) => {
759                warn!(%e, "failed to decode kafka_client_id_enrichment_rules");
760                return;
761            }
762        };
763
764        // Check every rule against every broker. Rules are matched in the order
765        // that they are specified. It is usually a configuration error if
766        // multiple rules match the same list of Kafka brokers, but we
767        // nonetheless want to provide well defined semantics.
768        debug!(?self.brokers, "evaluating client ID enrichment rules");
769        for rule in rules {
770            let is_match = self
771                .brokers
772                .iter()
773                .any(|b| rule.pattern.is_match(&b.address));
774            debug!(?rule, is_match, "evaluated client ID enrichment rule");
775            if is_match {
776                client_id.push('-');
777                client_id.push_str(&rule.payload);
778            }
779        }
780    }
781
782    /// Creates a Kafka client for the connection.
783    pub async fn create_with_context<C, T>(
784        &self,
785        storage_configuration: &StorageConfiguration,
786        context: C,
787        extra_options: &BTreeMap<&str, String>,
788        in_task: InTask,
789    ) -> Result<T, ContextCreationError>
790    where
791        C: ClientContext,
792        T: FromClientConfigAndContext<TunnelingClientContext<C>>,
793    {
794        let mut options = self.options.clone();
795
796        // Ensure that Kafka topics are *not* automatically created when
797        // consuming, producing, or fetching metadata for a topic. This ensures
798        // that we don't accidentally create topics with the wrong number of
799        // partitions.
800        options.insert("allow.auto.create.topics".into(), "false".into());
801
802        let brokers = match &self.default_tunnel {
803            Tunnel::AwsPrivatelink(t) => {
804                assert!(&self.brokers.is_empty());
805
806                let algo = KAFKA_DEFAULT_AWS_PRIVATELINK_ENDPOINT_IDENTIFICATION_ALGORITHM
807                    .get(storage_configuration.config_set());
808                options.insert("ssl.endpoint.identification.algorithm".into(), algo.into());
809
810                // When using a default privatelink tunnel broker/brokers cannot be specified
811                // instead the tunnel connection_id and port are used for the initial connection.
812                format!(
813                    "{}:{}",
814                    vpc_endpoint_host(
815                        t.connection_id,
816                        None, // Default tunnel does not support availability zones.
817                    ),
818                    t.port.unwrap_or(9092)
819                )
820            }
821            _ => self.brokers.iter().map(|b| &b.address).join(","),
822        };
823        options.insert("bootstrap.servers".into(), brokers.into());
824        let security_protocol = match (self.tls.is_some(), self.sasl.is_some()) {
825            (false, false) => "PLAINTEXT",
826            (true, false) => "SSL",
827            (false, true) => "SASL_PLAINTEXT",
828            (true, true) => "SASL_SSL",
829        };
830        options.insert("security.protocol".into(), security_protocol.into());
831        if let Some(tls) = &self.tls {
832            if let Some(root_cert) = &tls.root_cert {
833                options.insert("ssl.ca.pem".into(), root_cert.clone());
834            }
835            if let Some(identity) = &tls.identity {
836                options.insert("ssl.key.pem".into(), StringOrSecret::Secret(identity.key));
837                options.insert("ssl.certificate.pem".into(), identity.cert.clone());
838            }
839        }
840        if let Some(sasl) = &self.sasl {
841            options.insert("sasl.mechanisms".into(), (&sasl.mechanism).into());
842            options.insert("sasl.username".into(), sasl.username.clone());
843            if let Some(password) = sasl.password {
844                options.insert("sasl.password".into(), StringOrSecret::Secret(password));
845            }
846        }
847
848        let mut config = mz_kafka_util::client::create_new_client_config(
849            storage_configuration
850                .connection_context
851                .librdkafka_log_level,
852            storage_configuration.parameters.kafka_timeout_config,
853        );
854        for (k, v) in options {
855            config.set(
856                k,
857                v.get_string(
858                    in_task,
859                    &storage_configuration.connection_context.secrets_reader,
860                )
861                .await
862                .context("reading kafka secret")?,
863            );
864        }
865        for (k, v) in extra_options {
866            config.set(*k, v);
867        }
868
869        let aws_config = match self.sasl.as_ref().and_then(|sasl| sasl.aws.as_ref()) {
870            None => None,
871            Some(aws) => Some(
872                aws.connection
873                    .load_sdk_config(
874                        &storage_configuration.connection_context,
875                        aws.connection_id,
876                        in_task,
877                    )
878                    .await?,
879            ),
880        };
881
882        // TODO(roshan): Implement enforcement of external address validation once
883        // rdkafka client has been updated to support providing multiple resolved
884        // addresses for brokers
885        let mut context = TunnelingClientContext::new(
886            context,
887            Handle::current(),
888            storage_configuration
889                .connection_context
890                .ssh_tunnel_manager
891                .clone(),
892            storage_configuration.parameters.ssh_timeout_config,
893            aws_config,
894            in_task,
895        );
896
897        match &self.default_tunnel {
898            Tunnel::Direct => {
899                // By default, don't offer a default override for broker address lookup.
900            }
901            Tunnel::AwsPrivatelink(pl) => {
902                context.set_default_tunnel(TunnelConfig::StaticHost(vpc_endpoint_host(
903                    pl.connection_id,
904                    None, // Default tunnel does not support availability zones.
905                )));
906            }
907            Tunnel::Ssh(ssh_tunnel) => {
908                let secret = storage_configuration
909                    .connection_context
910                    .secrets_reader
911                    .read_in_task_if(in_task, ssh_tunnel.connection_id)
912                    .await?;
913                let key_pair = SshKeyPair::from_bytes(&secret)?;
914
915                // Ensure any ssh-bastion address we connect to is resolved to an external address.
916                let resolved = resolve_address(
917                    &ssh_tunnel.connection.host,
918                    ENFORCE_EXTERNAL_ADDRESSES.get(storage_configuration.config_set()),
919                )
920                .await?;
921                context.set_default_tunnel(TunnelConfig::Ssh(SshTunnelConfig {
922                    host: resolved
923                        .iter()
924                        .map(|a| a.to_string())
925                        .collect::<BTreeSet<_>>(),
926                    port: ssh_tunnel.connection.port,
927                    user: ssh_tunnel.connection.user.clone(),
928                    key_pair,
929                }));
930            }
931        }
932
933        for broker in &self.brokers {
934            let mut addr_parts = broker.address.splitn(2, ':');
935            let addr = BrokerAddr {
936                host: addr_parts
937                    .next()
938                    .context("BROKER is not address:port")?
939                    .into(),
940                port: addr_parts
941                    .next()
942                    .unwrap_or("9092")
943                    .parse()
944                    .context("parsing BROKER port")?,
945            };
946            match &broker.tunnel {
947                Tunnel::Direct => {
948                    // By default, don't override broker address lookup.
949                    //
950                    // N.B.
951                    //
952                    // We _could_ pre-setup the default ssh tunnel for all known brokers here, but
953                    // we avoid doing because:
954                    // - Its not necessary.
955                    // - Not doing so makes it easier to test the `FailedDefaultSshTunnel` path
956                    // in the `TunnelingClientContext`.
957                }
958                Tunnel::AwsPrivatelink(aws_privatelink) => {
959                    let host = mz_cloud_resources::vpc_endpoint_host(
960                        aws_privatelink.connection_id,
961                        aws_privatelink.availability_zone.as_deref(),
962                    );
963                    let port = aws_privatelink.port;
964                    context.add_broker_rewrite(
965                        addr,
966                        BrokerRewrite {
967                            host: host.clone(),
968                            port,
969                        },
970                    );
971                }
972                Tunnel::Ssh(ssh_tunnel) => {
973                    // Ensure any SSH bastion address we connect to is resolved to an external address.
974                    let ssh_host_resolved = resolve_address(
975                        &ssh_tunnel.connection.host,
976                        ENFORCE_EXTERNAL_ADDRESSES.get(storage_configuration.config_set()),
977                    )
978                    .await?;
979                    context
980                        .add_ssh_tunnel(
981                            addr,
982                            SshTunnelConfig {
983                                host: ssh_host_resolved
984                                    .iter()
985                                    .map(|a| a.to_string())
986                                    .collect::<BTreeSet<_>>(),
987                                port: ssh_tunnel.connection.port,
988                                user: ssh_tunnel.connection.user.clone(),
989                                key_pair: SshKeyPair::from_bytes(
990                                    &storage_configuration
991                                        .connection_context
992                                        .secrets_reader
993                                        .read_in_task_if(in_task, ssh_tunnel.connection_id)
994                                        .await?,
995                                )?,
996                            },
997                        )
998                        .await
999                        .map_err(ContextCreationError::Ssh)?;
1000                }
1001            }
1002        }
1003
1004        Ok(config.create_with_context(context)?)
1005    }
1006
1007    async fn validate(
1008        &self,
1009        _id: CatalogItemId,
1010        storage_configuration: &StorageConfiguration,
1011    ) -> Result<(), anyhow::Error> {
1012        let (context, error_rx) = MzClientContext::with_errors();
1013        let consumer: BaseConsumer<_> = self
1014            .create_with_context(
1015                storage_configuration,
1016                context,
1017                &BTreeMap::new(),
1018                // We are in a normal tokio context during validation, already.
1019                InTask::No,
1020            )
1021            .await?;
1022        let consumer = Arc::new(consumer);
1023
1024        let timeout = storage_configuration
1025            .parameters
1026            .kafka_timeout_config
1027            .fetch_metadata_timeout;
1028
1029        // librdkafka doesn't expose an API for determining whether a connection to
1030        // the Kafka cluster has been successfully established. So we make a
1031        // metadata request, though we don't care about the results, so that we can
1032        // report any errors making that request. If the request succeeds, we know
1033        // we were able to contact at least one broker, and that's a good proxy for
1034        // being able to contact all the brokers in the cluster.
1035        //
1036        // The downside of this approach is it produces a generic error message like
1037        // "metadata fetch error" with no additional details. The real networking
1038        // error is buried in the librdkafka logs, which are not visible to users.
1039        let result = mz_ore::task::spawn_blocking(|| "kafka_get_metadata", {
1040            let consumer = Arc::clone(&consumer);
1041            move || consumer.fetch_metadata(None, timeout)
1042        })
1043        .await?;
1044        match result {
1045            Ok(_) => Ok(()),
1046            // The error returned by `fetch_metadata` does not provide any details which makes for
1047            // a crappy user facing error message. For this reason we attempt to grab a better
1048            // error message from the client context, which should contain any error logs emitted
1049            // by librdkafka, and fallback to the generic error if there is nothing there.
1050            Err(err) => {
1051                // Multiple errors might have been logged during this validation but some are more
1052                // relevant than others. Specifically, we prefer non-internal errors over internal
1053                // errors since those give much more useful information to the users.
1054                let main_err = error_rx.try_iter().reduce(|cur, new| match cur {
1055                    MzKafkaError::Internal(_) => new,
1056                    _ => cur,
1057                });
1058
1059                // Don't drop the consumer until after we've drained the errors
1060                // channel. Dropping the consumer can introduce spurious errors.
1061                // See database-issues#7432.
1062                drop(consumer);
1063
1064                match main_err {
1065                    Some(err) => Err(err.into()),
1066                    None => Err(err.into()),
1067                }
1068            }
1069        }
1070    }
1071}
1072
1073impl<C: ConnectionAccess> AlterCompatible for KafkaConnection<C> {
1074    fn alter_compatible(&self, id: GlobalId, other: &Self) -> Result<(), AlterError> {
1075        let KafkaConnection {
1076            brokers: _,
1077            default_tunnel: _,
1078            progress_topic,
1079            progress_topic_options,
1080            options: _,
1081            tls: _,
1082            sasl: _,
1083        } = self;
1084
1085        let compatibility_checks = [
1086            (progress_topic == &other.progress_topic, "progress_topic"),
1087            (
1088                progress_topic_options == &other.progress_topic_options,
1089                "progress_topic_options",
1090            ),
1091        ];
1092
1093        for (compatible, field) in compatibility_checks {
1094            if !compatible {
1095                tracing::warn!(
1096                    "KafkaConnection incompatible at {field}:\nself:\n{:#?}\n\nother\n{:#?}",
1097                    self,
1098                    other
1099                );
1100
1101                return Err(AlterError { id });
1102            }
1103        }
1104
1105        Ok(())
1106    }
1107}
1108
/// A connection to a Confluent Schema Registry.
///
/// `C` selects whether the tunnel refers to its SSH/PrivateLink connection by
/// reference or holds it inline; `InlinedConnection` is the default.
#[derive(Clone, Debug, Eq, PartialEq, Hash, Serialize, Deserialize)]
pub struct CsrConnection<C: ConnectionAccess = InlinedConnection> {
    /// The URL of the schema registry. Connecting fails if it has no host.
    pub url: Url,
    /// Trusted root TLS certificate in PEM format.
    pub tls_root_cert: Option<StringOrSecret>,
    /// An optional TLS client certificate for authentication with the schema
    /// registry.
    pub tls_identity: Option<TlsIdentity>,
    /// Optional HTTP authentication credentials for the schema registry.
    pub http_auth: Option<CsrConnectionHttpAuth>,
    /// A tunnel through which to route traffic. This is the only field that
    /// must remain alter-compatible when the connection is altered.
    pub tunnel: Tunnel<C>,
}
1124
1125impl<R: ConnectionResolver> IntoInlineConnection<CsrConnection, R>
1126    for CsrConnection<ReferencedConnection>
1127{
1128    fn into_inline_connection(self, r: R) -> CsrConnection {
1129        let CsrConnection {
1130            url,
1131            tls_root_cert,
1132            tls_identity,
1133            http_auth,
1134            tunnel,
1135        } = self;
1136        CsrConnection {
1137            url,
1138            tls_root_cert,
1139            tls_identity,
1140            http_auth,
1141            tunnel: tunnel.into_inline_connection(r),
1142        }
1143    }
1144}
1145
impl<C: ConnectionAccess> CsrConnection<C> {
    /// Always returns `true` for schema registry connections.
    ///
    /// NOTE(review): presumably consulted by callers to decide whether to
    /// validate when the user makes no explicit choice — confirm at call sites.
    fn validate_by_default(&self) -> bool {
        true
    }
}
1151
1152impl CsrConnection {
1153    /// Constructs a schema registry client from the connection.
1154    pub async fn connect(
1155        &self,
1156        storage_configuration: &StorageConfiguration,
1157        in_task: InTask,
1158    ) -> Result<mz_ccsr::Client, CsrConnectError> {
1159        let mut client_config = mz_ccsr::ClientConfig::new(self.url.clone());
1160        if let Some(root_cert) = &self.tls_root_cert {
1161            let root_cert = root_cert
1162                .get_string(
1163                    in_task,
1164                    &storage_configuration.connection_context.secrets_reader,
1165                )
1166                .await?;
1167            let root_cert = Certificate::from_pem(root_cert.as_bytes())?;
1168            client_config = client_config.add_root_certificate(root_cert);
1169        }
1170
1171        if let Some(tls_identity) = &self.tls_identity {
1172            let key = &storage_configuration
1173                .connection_context
1174                .secrets_reader
1175                .read_string_in_task_if(in_task, tls_identity.key)
1176                .await?;
1177            let cert = tls_identity
1178                .cert
1179                .get_string(
1180                    in_task,
1181                    &storage_configuration.connection_context.secrets_reader,
1182                )
1183                .await?;
1184            let ident = Identity::from_pem(key.as_bytes(), cert.as_bytes())?;
1185            client_config = client_config.identity(ident);
1186        }
1187
1188        if let Some(http_auth) = &self.http_auth {
1189            let username = http_auth
1190                .username
1191                .get_string(
1192                    in_task,
1193                    &storage_configuration.connection_context.secrets_reader,
1194                )
1195                .await?;
1196            let password = match http_auth.password {
1197                None => None,
1198                Some(password) => Some(
1199                    storage_configuration
1200                        .connection_context
1201                        .secrets_reader
1202                        .read_string_in_task_if(in_task, password)
1203                        .await?,
1204                ),
1205            };
1206            client_config = client_config.auth(username, password);
1207        }
1208
1209        // TODO: use types to enforce that the URL has a string hostname.
1210        let host = self
1211            .url
1212            .host_str()
1213            .ok_or_else(|| anyhow!("url missing host"))?;
1214        match &self.tunnel {
1215            Tunnel::Direct => {
1216                // Ensure any host we connect to is resolved to an external address.
1217                let resolved = resolve_address(
1218                    host,
1219                    ENFORCE_EXTERNAL_ADDRESSES.get(storage_configuration.config_set()),
1220                )
1221                .await?;
1222                client_config = client_config.resolve_to_addrs(
1223                    host,
1224                    &resolved
1225                        .iter()
1226                        .map(|addr| SocketAddr::new(*addr, DUMMY_DNS_PORT))
1227                        .collect::<Vec<_>>(),
1228                )
1229            }
1230            Tunnel::Ssh(ssh_tunnel) => {
1231                let ssh_tunnel = ssh_tunnel
1232                    .connect(
1233                        storage_configuration,
1234                        host,
1235                        // Default to the default http port, but this
1236                        // could default to 8081...
1237                        self.url.port().unwrap_or(80),
1238                        in_task,
1239                    )
1240                    .await
1241                    .map_err(CsrConnectError::Ssh)?;
1242
1243                // Carefully inject the SSH tunnel into the client
1244                // configuration. This is delicate because we need TLS
1245                // verification to continue to use the remote hostname rather
1246                // than the tunnel hostname.
1247
1248                client_config = client_config
1249                    // `resolve_to_addrs` allows us to rewrite the hostname
1250                    // at the DNS level, which means the TCP connection is
1251                    // correctly routed through the tunnel, but TLS verification
1252                    // is still performed against the remote hostname.
1253                    // Unfortunately the port here is ignored...
1254                    .resolve_to_addrs(
1255                        host,
1256                        &[SocketAddr::new(
1257                            ssh_tunnel.local_addr().ip(),
1258                            DUMMY_DNS_PORT,
1259                        )],
1260                    )
1261                    // ...so we also dynamically rewrite the URL to use the
1262                    // current port for the SSH tunnel.
1263                    //
1264                    // WARNING: this is brittle, because we only dynamically
1265                    // update the client configuration with the tunnel *port*,
1266                    // and not the hostname This works fine in practice, because
1267                    // only the SSH tunnel port will change if the tunnel fails
1268                    // and has to be restarted (the hostname is always
1269                    // 127.0.0.1)--but this is an an implementation detail of
1270                    // the SSH tunnel code that we're relying on.
1271                    .dynamic_url({
1272                        let remote_url = self.url.clone();
1273                        move || {
1274                            let mut url = remote_url.clone();
1275                            url.set_port(Some(ssh_tunnel.local_addr().port()))
1276                                .expect("cannot fail");
1277                            url
1278                        }
1279                    });
1280            }
1281            Tunnel::AwsPrivatelink(connection) => {
1282                assert_none!(connection.port);
1283
1284                let privatelink_host = mz_cloud_resources::vpc_endpoint_host(
1285                    connection.connection_id,
1286                    connection.availability_zone.as_deref(),
1287                );
1288                let addrs: Vec<_> = net::lookup_host((privatelink_host, DUMMY_DNS_PORT))
1289                    .await
1290                    .context("resolving PrivateLink host")?
1291                    .collect();
1292                client_config = client_config.resolve_to_addrs(host, &addrs)
1293            }
1294        }
1295
1296        Ok(client_config.build()?)
1297    }
1298
1299    async fn validate(
1300        &self,
1301        _id: CatalogItemId,
1302        storage_configuration: &StorageConfiguration,
1303    ) -> Result<(), anyhow::Error> {
1304        let client = self
1305            .connect(
1306                storage_configuration,
1307                // We are in a normal tokio context during validation, already.
1308                InTask::No,
1309            )
1310            .await?;
1311        client.list_subjects().await?;
1312        Ok(())
1313    }
1314}
1315
1316impl<C: ConnectionAccess> AlterCompatible for CsrConnection<C> {
1317    fn alter_compatible(&self, id: GlobalId, other: &Self) -> Result<(), AlterError> {
1318        let CsrConnection {
1319            tunnel,
1320            // All non-tunnel fields may change
1321            url: _,
1322            tls_root_cert: _,
1323            tls_identity: _,
1324            http_auth: _,
1325        } = self;
1326
1327        let compatibility_checks = [(tunnel.alter_compatible(id, &other.tunnel).is_ok(), "tunnel")];
1328
1329        for (compatible, field) in compatibility_checks {
1330            if !compatible {
1331                tracing::warn!(
1332                    "CsrConnection incompatible at {field}:\nself:\n{:#?}\n\nother\n{:#?}",
1333                    self,
1334                    other
1335                );
1336
1337                return Err(AlterError { id });
1338            }
1339        }
1340        Ok(())
1341    }
1342}
1343
/// A TLS key pair used for client identity.
///
/// The certificate may be stored inline or as a secret, whereas the private
/// key is always stored as a secret.
#[derive(Clone, Debug, Eq, PartialEq, Hash, Serialize, Deserialize)]
pub struct TlsIdentity {
    /// The client's TLS public certificate in PEM format.
    pub cert: StringOrSecret,
    /// The ID of the secret containing the client's TLS private key in PEM
    /// format.
    pub key: CatalogItemId,
}
1353
/// HTTP authentication credentials in a [`CsrConnection`].
///
/// Both values are passed to the schema registry client's `auth`
/// configuration when connecting.
#[derive(Clone, Debug, Eq, PartialEq, Hash, Serialize, Deserialize)]
pub struct CsrConnectionHttpAuth {
    /// The username, stored either inline or as a secret.
    pub username: StringOrSecret,
    /// The ID of the secret containing the password, if any.
    pub password: Option<CatalogItemId>,
}
1362
/// A connection to a PostgreSQL server.
///
/// `C` selects whether the tunnel refers to its SSH/PrivateLink connection by
/// reference or holds it inline; `InlinedConnection` is the default.
#[derive(Clone, Debug, Eq, PartialEq, Hash, Serialize, Deserialize)]
pub struct PostgresConnection<C: ConnectionAccess = InlinedConnection> {
    /// The hostname of the server.
    pub host: String,
    /// The port of the server.
    pub port: u16,
    /// The name of the database to connect to.
    pub database: String,
    /// The username to authenticate as, stored inline or as a secret.
    pub user: StringOrSecret,
    /// The ID of the secret holding an optional password for authentication.
    pub password: Option<CatalogItemId>,
    /// A tunnel through which to route traffic.
    pub tunnel: Tunnel<C>,
    /// Whether to use TLS for encryption, authentication, or both.
    pub tls_mode: SslMode,
    /// An optional root TLS certificate in PEM format, to verify the server's
    /// identity.
    pub tls_root_cert: Option<StringOrSecret>,
    /// An optional TLS client certificate for authentication.
    pub tls_identity: Option<TlsIdentity>,
}
1386
1387impl<R: ConnectionResolver> IntoInlineConnection<PostgresConnection, R>
1388    for PostgresConnection<ReferencedConnection>
1389{
1390    fn into_inline_connection(self, r: R) -> PostgresConnection {
1391        let PostgresConnection {
1392            host,
1393            port,
1394            database,
1395            user,
1396            password,
1397            tunnel,
1398            tls_mode,
1399            tls_root_cert,
1400            tls_identity,
1401        } = self;
1402
1403        PostgresConnection {
1404            host,
1405            port,
1406            database,
1407            user,
1408            password,
1409            tunnel: tunnel.into_inline_connection(r),
1410            tls_mode,
1411            tls_root_cert,
1412            tls_identity,
1413        }
1414    }
1415}
1416
impl<C: ConnectionAccess> PostgresConnection<C> {
    /// Always returns `true` for PostgreSQL connections.
    ///
    /// NOTE(review): presumably consulted by callers to decide whether to
    /// validate when the user makes no explicit choice — confirm at call sites.
    fn validate_by_default(&self) -> bool {
        true
    }
}
1422
impl PostgresConnection<InlinedConnection> {
    /// Builds a [`mz_postgres_util::Config`] for connecting to this server.
    ///
    /// The user, password, root certificate, and client identity are read
    /// through `secrets_reader`. The connect timeout and TCP keepalive/user
    /// timeout settings from the storage parameters are applied client-side.
    /// Server startup `options` are also set:
    /// - `wal_sender_timeout` whenever it is configured;
    /// - the TCP settings only when `pg_source_tcp_configure_server` is true.
    ///
    /// Tunnel handling:
    /// - `Direct` pre-resolves `self.host` (subject to the
    ///   `ENFORCE_EXTERNAL_ADDRESSES` check).
    /// - `Ssh` reads the tunnel's key pair and pre-resolves the bastion host.
    /// - `AwsPrivatelink` passes the connection ID through.
    ///
    /// # Panics
    ///
    /// Panics if a PrivateLink tunnel specifies a port.
    pub async fn config(
        &self,
        secrets_reader: &Arc<dyn mz_secrets::SecretsReader>,
        storage_configuration: &StorageConfiguration,
        in_task: InTask,
    ) -> Result<mz_postgres_util::Config, anyhow::Error> {
        let params = &storage_configuration.parameters;

        let mut config = tokio_postgres::Config::new();
        config
            .host(&self.host)
            .port(self.port)
            .dbname(&self.database)
            .user(&self.user.get_string(in_task, secrets_reader).await?)
            .ssl_mode(self.tls_mode);
        if let Some(password) = self.password {
            let password = secrets_reader
                .read_string_in_task_if(in_task, password)
                .await?;
            config.password(password);
        }
        if let Some(tls_root_cert) = &self.tls_root_cert {
            let tls_root_cert = tls_root_cert.get_string(in_task, secrets_reader).await?;
            config.ssl_root_cert(tls_root_cert.as_bytes());
        }
        if let Some(tls_identity) = &self.tls_identity {
            let cert = tls_identity
                .cert
                .get_string(in_task, secrets_reader)
                .await?;
            let key = secrets_reader
                .read_string_in_task_if(in_task, tls_identity.key)
                .await?;
            config.ssl_cert(cert.as_bytes()).ssl_key(key.as_bytes());
        }

        // Client-side connection and TCP settings; each is only applied when
        // the corresponding parameter is configured.
        if let Some(connect_timeout) = params.pg_source_connect_timeout {
            config.connect_timeout(connect_timeout);
        }
        if let Some(keepalives_retries) = params.pg_source_tcp_keepalives_retries {
            config.keepalives_retries(keepalives_retries);
        }
        if let Some(keepalives_idle) = params.pg_source_tcp_keepalives_idle {
            config.keepalives_idle(keepalives_idle);
        }
        if let Some(keepalives_interval) = params.pg_source_tcp_keepalives_interval {
            config.keepalives_interval(keepalives_interval);
        }
        if let Some(tcp_user_timeout) = params.pg_source_tcp_user_timeout {
            config.tcp_user_timeout(tcp_user_timeout);
        }

        // Server-side settings, sent as `--name=value` startup options.
        // `wal_sender_timeout` and `tcp_user_timeout` are sent in
        // milliseconds; the keepalive idle/interval values in whole seconds.
        // NOTE(review): `as_secs()` truncates sub-second durations (a value
        // under 1s is sent as 0) — confirm that is acceptable.
        let mut options = vec![];
        if let Some(wal_sender_timeout) = params.pg_source_wal_sender_timeout {
            options.push(format!(
                "--wal_sender_timeout={}",
                wal_sender_timeout.as_millis()
            ));
        };
        if params.pg_source_tcp_configure_server {
            if let Some(keepalives_retries) = params.pg_source_tcp_keepalives_retries {
                options.push(format!("--tcp_keepalives_count={}", keepalives_retries));
            }
            if let Some(keepalives_idle) = params.pg_source_tcp_keepalives_idle {
                options.push(format!(
                    "--tcp_keepalives_idle={}",
                    keepalives_idle.as_secs()
                ));
            }
            if let Some(keepalives_interval) = params.pg_source_tcp_keepalives_interval {
                options.push(format!(
                    "--tcp_keepalives_interval={}",
                    keepalives_interval.as_secs()
                ));
            }
            if let Some(tcp_user_timeout) = params.pg_source_tcp_user_timeout {
                options.push(format!(
                    "--tcp_user_timeout={}",
                    tcp_user_timeout.as_millis()
                ));
            }
        }
        config.options(options.join(" ").as_str());

        let tunnel = match &self.tunnel {
            Tunnel::Direct => {
                // Ensure any host we connect to is resolved to an external address.
                let resolved = resolve_address(
                    &self.host,
                    ENFORCE_EXTERNAL_ADDRESSES.get(storage_configuration.config_set()),
                )
                .await?;
                mz_postgres_util::TunnelConfig::Direct {
                    resolved_ips: Some(resolved),
                }
            }
            Tunnel::Ssh(SshTunnel {
                connection_id,
                connection,
            }) => {
                let secret = secrets_reader
                    .read_in_task_if(in_task, *connection_id)
                    .await?;
                let key_pair = SshKeyPair::from_bytes(&secret)?;
                // Ensure any ssh-bastion host we connect to is resolved to an external address.
                let resolved = resolve_address(
                    &connection.host,
                    ENFORCE_EXTERNAL_ADDRESSES.get(storage_configuration.config_set()),
                )
                .await?;
                mz_postgres_util::TunnelConfig::Ssh {
                    config: SshTunnelConfig {
                        host: resolved
                            .iter()
                            .map(|a| a.to_string())
                            .collect::<BTreeSet<_>>(),
                        port: connection.port,
                        user: connection.user.clone(),
                        key_pair,
                    },
                }
            }
            Tunnel::AwsPrivatelink(connection) => {
                assert_none!(connection.port);
                mz_postgres_util::TunnelConfig::AwsPrivatelink {
                    connection_id: connection.connection_id,
                }
            }
        };

        Ok(mz_postgres_util::Config::new(
            config,
            tunnel,
            params.ssh_timeout_config,
            in_task,
        )?)
    }

    /// Connects to the server and checks that it is set up for logical
    /// replication, returning the connected client on success.
    ///
    /// # Errors
    ///
    /// Fails if configuration or connection fails. Otherwise returns a
    /// [`PostgresConnectionValidationError`] if:
    /// - `wal_level` is below `logical`;
    /// - `max_wal_senders` is less than 1;
    /// - fewer than 2 replication slots are available.
    pub async fn validate(
        &self,
        _id: CatalogItemId,
        storage_configuration: &StorageConfiguration,
    ) -> Result<mz_postgres_util::Client, anyhow::Error> {
        let config = self
            .config(
                &storage_configuration.connection_context.secrets_reader,
                storage_configuration,
                // We are in a normal tokio context during validation, already.
                InTask::No,
            )
            .await?;
        let client = config
            .connect(
                "connection validation",
                &storage_configuration.connection_context.ssh_tunnel_manager,
            )
            .await?;

        let wal_level = mz_postgres_util::get_wal_level(&client).await?;

        if wal_level < mz_postgres_util::replication::WalLevel::Logical {
            Err(PostgresConnectionValidationError::InsufficientWalLevel { wal_level })?;
        }

        let max_wal_senders = mz_postgres_util::get_max_wal_senders(&client).await?;

        if max_wal_senders < 1 {
            Err(PostgresConnectionValidationError::ReplicationDisabled)?;
        }

        let available_replication_slots =
            mz_postgres_util::available_replication_slots(&client).await?;

        // We need 1 replication slot for the snapshots and 1 for the continuing replication
        if available_replication_slots < 2 {
            Err(
                PostgresConnectionValidationError::InsufficientReplicationSlotsAvailable {
                    count: 2,
                },
            )?;
        }

        Ok(client)
    }
}
1609
/// Ways in which an upstream PostgreSQL server can fail connection
/// validation because of how it is configured.
#[derive(Debug, Clone, thiserror::Error)]
pub enum PostgresConnectionValidationError {
    /// Fewer replication slots are available than required; `count` is the
    /// number of slots that are required.
    #[error("PostgreSQL server has insufficient number of replication slots available")]
    InsufficientReplicationSlotsAvailable { count: usize },
    /// The server's `wal_level` is below `logical`.
    #[error("server must have wal_level >= logical, but has {wal_level}")]
    InsufficientWalLevel {
        /// The `wal_level` the server reported.
        wal_level: mz_postgres_util::replication::WalLevel,
    },
    /// The server reports `max_wal_senders` below 1.
    #[error("replication disabled on server")]
    ReplicationDisabled,
}
1621
1622impl PostgresConnectionValidationError {
1623    pub fn detail(&self) -> Option<String> {
1624        match self {
1625            Self::InsufficientReplicationSlotsAvailable { count } => Some(format!(
1626                "executing this statement requires {} replication slot{}",
1627                count,
1628                if *count == 1 { "" } else { "s" }
1629            )),
1630            _ => None,
1631        }
1632    }
1633
1634    pub fn hint(&self) -> Option<String> {
1635        match self {
1636            Self::InsufficientReplicationSlotsAvailable { .. } => Some(
1637                "you might be able to wait for other sources to finish snapshotting and try again"
1638                    .into(),
1639            ),
1640            Self::ReplicationDisabled => Some("set max_wal_senders to a value > 0".into()),
1641            _ => None,
1642        }
1643    }
1644}
1645
1646impl<C: ConnectionAccess> AlterCompatible for PostgresConnection<C> {
1647    fn alter_compatible(&self, id: GlobalId, other: &Self) -> Result<(), AlterError> {
1648        let PostgresConnection {
1649            tunnel,
1650            // All non-tunnel options may change arbitrarily
1651            host: _,
1652            port: _,
1653            database: _,
1654            user: _,
1655            password: _,
1656            tls_mode: _,
1657            tls_root_cert: _,
1658            tls_identity: _,
1659        } = self;
1660
1661        let compatibility_checks = [(tunnel.alter_compatible(id, &other.tunnel).is_ok(), "tunnel")];
1662
1663        for (compatible, field) in compatibility_checks {
1664            if !compatible {
1665                tracing::warn!(
1666                    "PostgresConnection incompatible at {field}:\nself:\n{:#?}\n\nother\n{:#?}",
1667                    self,
1668                    other
1669                );
1670
1671                return Err(AlterError { id });
1672            }
1673        }
1674        Ok(())
1675    }
1676}
1677
/// Specifies how to tunnel a connection.
///
/// With `C = ReferencedConnection`, an SSH tunnel refers to its SSH connection
/// by reference; `into_inline_connection` resolves it into the default,
/// inlined form.
#[derive(Clone, Debug, Eq, PartialEq, Hash, Serialize, Deserialize)]
pub enum Tunnel<C: ConnectionAccess = InlinedConnection> {
    /// No tunneling.
    Direct,
    /// Via the specified SSH tunnel connection.
    Ssh(SshTunnel<C>),
    /// Via the specified AWS PrivateLink connection.
    AwsPrivatelink(AwsPrivatelink),
}
1688
1689impl<R: ConnectionResolver> IntoInlineConnection<Tunnel, R> for Tunnel<ReferencedConnection> {
1690    fn into_inline_connection(self, r: R) -> Tunnel {
1691        match self {
1692            Tunnel::Direct => Tunnel::Direct,
1693            Tunnel::Ssh(ssh) => Tunnel::Ssh(ssh.into_inline_connection(r)),
1694            Tunnel::AwsPrivatelink(awspl) => Tunnel::AwsPrivatelink(awspl),
1695        }
1696    }
1697}
1698
1699impl<C: ConnectionAccess> AlterCompatible for Tunnel<C> {
1700    fn alter_compatible(&self, id: GlobalId, other: &Self) -> Result<(), AlterError> {
1701        let compatible = match (self, other) {
1702            (Self::Ssh(s), Self::Ssh(o)) => s.alter_compatible(id, o).is_ok(),
1703            (s, o) => s == o,
1704        };
1705
1706        if !compatible {
1707            tracing::warn!(
1708                "Tunnel incompatible:\nself:\n{:#?}\n\nother\n{:#?}",
1709                self,
1710                other
1711            );
1712
1713            return Err(AlterError { id });
1714        }
1715
1716        Ok(())
1717    }
1718}
1719
/// Specifies which MySQL SSL Mode to use:
/// <https://dev.mysql.com/doc/refman/8.0/en/connection-options.html#option_general_ssl-mode>
/// This is not available as an enum in the mysql-async crate, so we define our own.
#[derive(Clone, Debug, Eq, PartialEq, Hash, Serialize, Deserialize)]
pub enum MySqlSslMode {
    /// Do not use TLS.
    Disabled,
    /// Use TLS, but accept invalid server certificates and skip hostname
    /// verification.
    Required,
    /// Use TLS and verify the server certificate (against the configured root
    /// certificate, if any), but skip hostname verification.
    VerifyCa,
    /// Use TLS and verify both the server certificate (against the configured
    /// root certificate, if any) and the server hostname.
    VerifyIdentity,
}
1730
1731/// A connection to a MySQL server.
1732#[derive(Clone, Debug, Eq, PartialEq, Hash, Serialize, Deserialize)]
1733pub struct MySqlConnection<C: ConnectionAccess = InlinedConnection> {
1734    /// The hostname of the server.
1735    pub host: String,
1736    /// The port of the server.
1737    pub port: u16,
1738    /// The username to authenticate as.
1739    pub user: StringOrSecret,
1740    /// An optional password for authentication.
1741    pub password: Option<CatalogItemId>,
1742    /// A tunnel through which to route traffic.
1743    pub tunnel: Tunnel<C>,
1744    /// Whether to use TLS for encryption, verify the server's certificate, and identity.
1745    pub tls_mode: MySqlSslMode,
1746    /// An optional root TLS certificate in PEM format, to verify the server's
1747    /// identity.
1748    pub tls_root_cert: Option<StringOrSecret>,
1749    /// An optional TLS client certificate for authentication.
1750    pub tls_identity: Option<TlsIdentity>,
1751    /// Reference to the AWS connection information to be used for IAM authenitcation and
1752    /// assuming AWS roles.
1753    pub aws_connection: Option<AwsConnectionReference<C>>,
1754}
1755
1756impl<R: ConnectionResolver> IntoInlineConnection<MySqlConnection, R>
1757    for MySqlConnection<ReferencedConnection>
1758{
1759    fn into_inline_connection(self, r: R) -> MySqlConnection {
1760        let MySqlConnection {
1761            host,
1762            port,
1763            user,
1764            password,
1765            tunnel,
1766            tls_mode,
1767            tls_root_cert,
1768            tls_identity,
1769            aws_connection,
1770        } = self;
1771
1772        MySqlConnection {
1773            host,
1774            port,
1775            user,
1776            password,
1777            tunnel: tunnel.into_inline_connection(&r),
1778            tls_mode,
1779            tls_root_cert,
1780            tls_identity,
1781            aws_connection: aws_connection.map(|aws| aws.into_inline_connection(&r)),
1782        }
1783    }
1784}
1785
1786impl<C: ConnectionAccess> MySqlConnection<C> {
1787    fn validate_by_default(&self) -> bool {
1788        true
1789    }
1790}
1791
impl MySqlConnection<InlinedConnection> {
    /// Resolves this connection into an [`mz_mysql_util::Config`].
    ///
    /// This performs the following steps:
    /// - Reads the user, the password, the TLS root certificate and the
    ///   client identity through `secrets_reader` where they are secrets.
    /// - Resolves the server host (for direct connections) or the SSH bastion
    ///   host via `resolve_address`, honoring `ENFORCE_EXTERNAL_ADDRESSES`.
    /// - Loads the AWS SDK config when an AWS connection is attached.
    ///
    /// `in_task` is forwarded to every secret read, to the AWS config load and
    /// to the returned config.
    ///
    /// # Errors
    ///
    /// Returns an error if any secret read, PKCS#12 conversion, address
    /// resolution, AWS config load, or `mz_mysql_util::Config::new` fails.
    pub async fn config(
        &self,
        secrets_reader: &Arc<dyn mz_secrets::SecretsReader>,
        storage_configuration: &StorageConfiguration,
        in_task: InTask,
    ) -> Result<mz_mysql_util::Config, anyhow::Error> {
        // TODO(roshan): Set appropriate connection timeouts
        let mut opts = mysql_async::OptsBuilder::default()
            .ip_or_hostname(&self.host)
            .tcp_port(self.port)
            .user(Some(&self.user.get_string(in_task, secrets_reader).await?));

        if let Some(password) = self.password {
            let password = secrets_reader
                .read_string_in_task_if(in_task, password)
                .await?;
            opts = opts.pass(Some(password));
        }

        // Our `MySqlSslMode` enum matches the official MySQL Client `--ssl-mode` parameter values
        // which uses opt-in security features (SSL, CA verification, & Identity verification).
        // The mysql_async crate `SslOpts` struct uses an opt-out mechanism for each of these, so
        // we need to appropriately disable features to match the intent of each enum value.
        let mut ssl_opts = match self.tls_mode {
            MySqlSslMode::Disabled => None,
            MySqlSslMode::Required => Some(
                mysql_async::SslOpts::default()
                    .with_danger_accept_invalid_certs(true)
                    .with_danger_skip_domain_validation(true),
            ),
            MySqlSslMode::VerifyCa => {
                Some(mysql_async::SslOpts::default().with_danger_skip_domain_validation(true))
            }
            MySqlSslMode::VerifyIdentity => Some(mysql_async::SslOpts::default()),
        };

        // A custom root certificate only matters in the modes that verify the
        // server certificate; in `Required` mode it is deliberately ignored.
        if matches!(
            self.tls_mode,
            MySqlSslMode::VerifyCa | MySqlSslMode::VerifyIdentity
        ) {
            if let Some(tls_root_cert) = &self.tls_root_cert {
                let tls_root_cert = tls_root_cert.get_string(in_task, secrets_reader).await?;
                ssl_opts = ssl_opts.map(|opts| {
                    opts.with_root_certs(vec![tls_root_cert.as_bytes().to_vec().into()])
                });
            }
        }

        // Note: the identity secrets are read even in `Disabled` mode, but
        // because `ssl_opts` is `None` there, the `map` below is then a no-op.
        if let Some(identity) = &self.tls_identity {
            let key = secrets_reader
                .read_string_in_task_if(in_task, identity.key)
                .await?;
            let cert = identity.cert.get_string(in_task, secrets_reader).await?;
            let Pkcs12Archive { der, pass } =
                mz_tls_util::pkcs12der_from_pem(key.as_bytes(), cert.as_bytes())?;

            // Add client identity to SSLOpts
            ssl_opts = ssl_opts.map(|opts| {
                opts.with_client_identity(Some(
                    mysql_async::ClientIdentity::new(der.into()).with_password(pass),
                ))
            });
        }

        opts = opts.ssl_opts(ssl_opts);

        let tunnel = match &self.tunnel {
            Tunnel::Direct => {
                // Ensure any host we connect to is resolved to an external address.
                let resolved = resolve_address(
                    &self.host,
                    ENFORCE_EXTERNAL_ADDRESSES.get(storage_configuration.config_set()),
                )
                .await?;
                // Hand the already-resolved IPs to the client.
                // NOTE(review): presumably so it connects to exactly these
                // addresses rather than re-resolving; confirm in mz_mysql_util.
                mz_mysql_util::TunnelConfig::Direct {
                    resolved_ips: Some(resolved),
                }
            }
            Tunnel::Ssh(SshTunnel {
                connection_id,
                connection,
            }) => {
                // The SSH key pair is stored in the secret keyed by the SSH
                // connection's own ID.
                let secret = secrets_reader
                    .read_in_task_if(in_task, *connection_id)
                    .await?;
                let key_pair = SshKeyPair::from_bytes(&secret)?;
                // Ensure any ssh-bastion host we connect to is resolved to an external address.
                let resolved = resolve_address(
                    &connection.host,
                    ENFORCE_EXTERNAL_ADDRESSES.get(storage_configuration.config_set()),
                )
                .await?;
                mz_mysql_util::TunnelConfig::Ssh {
                    config: SshTunnelConfig {
                        host: resolved
                            .iter()
                            .map(|a| a.to_string())
                            .collect::<BTreeSet<_>>(),
                        port: connection.port,
                        user: connection.user.clone(),
                        key_pair,
                    },
                }
            }
            Tunnel::AwsPrivatelink(connection) => {
                // A PrivateLink port override is not expected for MySQL; this
                // panics if one is set (presumably rejected earlier — TODO
                // confirm where).
                assert_none!(connection.port);
                mz_mysql_util::TunnelConfig::AwsPrivatelink {
                    connection_id: connection.connection_id,
                }
            }
        };

        let aws_config = match self.aws_connection.as_ref() {
            None => None,
            Some(aws_ref) => Some(
                aws_ref
                    .connection
                    .load_sdk_config(
                        &storage_configuration.connection_context,
                        aws_ref.connection_id,
                        in_task,
                    )
                    .await?,
            ),
        };

        Ok(mz_mysql_util::Config::new(
            opts,
            tunnel,
            storage_configuration.parameters.ssh_timeout_config,
            in_task,
            storage_configuration
                .parameters
                .mysql_source_timeouts
                .clone(),
            aws_config,
        )?)
    }

    /// Opens a connection to the upstream MySQL server and checks that it is
    /// configured for replication.
    ///
    /// Three checks are run: GTID consistency, full row-based binlog format,
    /// and replication commit order. Every check is run before reporting, so
    /// all misconfigured settings are reported together.
    ///
    /// # Errors
    ///
    /// - [`MySqlConnectionValidationError::ReplicationSettingsError`] listing
    ///   each `(setting, expected, actual)` that failed.
    /// - [`MySqlConnectionValidationError::Client`] for any other error a check
    ///   reports.
    /// - [`MySqlConnectionValidationError::Other`] if the config cannot be
    ///   resolved.
    ///
    /// Errors from opening the connection are also propagated.
    pub async fn validate(
        &self,
        _id: CatalogItemId,
        storage_configuration: &StorageConfiguration,
    ) -> Result<MySqlConn, MySqlConnectionValidationError> {
        let config = self
            .config(
                &storage_configuration.connection_context.secrets_reader,
                storage_configuration,
                // We are in a normal tokio context during validation, already.
                InTask::No,
            )
            .await?;
        let mut conn = config
            .connect(
                "connection validation",
                &storage_configuration.connection_context.ssh_tunnel_manager,
            )
            .await?;

        // Check if the MySQL database is configured to allow row-based consistent GTID replication
        let mut setting_errors = vec![];
        let gtid_res = mz_mysql_util::ensure_gtid_consistency(&mut conn).await;
        let binlog_res = mz_mysql_util::ensure_full_row_binlog_format(&mut conn).await;
        let order_res = mz_mysql_util::ensure_replication_commit_order(&mut conn).await;
        for res in [gtid_res, binlog_res, order_res] {
            match res {
                // Collect setting mismatches so they can be reported together.
                Err(MySqlError::InvalidSystemSetting {
                    setting,
                    expected,
                    actual,
                }) => {
                    setting_errors.push((setting, expected, actual));
                }
                // Any other failure is returned immediately.
                Err(err) => Err(err)?,
                Ok(()) => {}
            }
        }
        if !setting_errors.is_empty() {
            Err(MySqlConnectionValidationError::ReplicationSettingsError(
                setting_errors,
            ))?;
        }

        Ok(conn)
    }
}
1979
/// Errors returned when validating a MySQL connection.
#[derive(Debug, thiserror::Error)]
pub enum MySqlConnectionValidationError {
    /// One or more replication-related system settings have unexpected
    /// values; each entry is `(setting, expected, actual)`.
    #[error("Invalid MySQL system replication settings")]
    ReplicationSettingsError(Vec<(String, String, String)>),
    /// An error reported by the MySQL client utilities.
    #[error(transparent)]
    Client(#[from] MySqlError),
    /// Any other error, e.g. failing to resolve the connection configuration.
    #[error("{}", .0.display_with_causes())]
    Other(#[from] anyhow::Error),
}
1989
1990impl MySqlConnectionValidationError {
1991    pub fn detail(&self) -> Option<String> {
1992        match self {
1993            Self::ReplicationSettingsError(settings) => Some(format!(
1994                "Invalid MySQL system replication settings: {}",
1995                itertools::join(
1996                    settings.iter().map(|(setting, expected, actual)| format!(
1997                        "{}: expected {}, got {}",
1998                        setting, expected, actual
1999                    )),
2000                    "; "
2001                )
2002            )),
2003            _ => None,
2004        }
2005    }
2006
2007    pub fn hint(&self) -> Option<String> {
2008        match self {
2009            Self::ReplicationSettingsError(_) => {
2010                Some("Set the necessary MySQL database system settings.".into())
2011            }
2012            _ => None,
2013        }
2014    }
2015}
2016
2017impl<C: ConnectionAccess> AlterCompatible for MySqlConnection<C> {
2018    fn alter_compatible(&self, id: GlobalId, other: &Self) -> Result<(), AlterError> {
2019        let MySqlConnection {
2020            tunnel,
2021            // All non-tunnel options may change arbitrarily
2022            host: _,
2023            port: _,
2024            user: _,
2025            password: _,
2026            tls_mode: _,
2027            tls_root_cert: _,
2028            tls_identity: _,
2029            aws_connection: _,
2030        } = self;
2031
2032        let compatibility_checks = [(tunnel.alter_compatible(id, &other.tunnel).is_ok(), "tunnel")];
2033
2034        for (compatible, field) in compatibility_checks {
2035            if !compatible {
2036                tracing::warn!(
2037                    "MySqlConnection incompatible at {field}:\nself:\n{:#?}\n\nother\n{:#?}",
2038                    self,
2039                    other
2040                );
2041
2042                return Err(AlterError { id });
2043            }
2044        }
2045        Ok(())
2046    }
2047}
2048
2049/// Details how to connect to an instance of Microsoft SQL Server.
2050///
2051/// For specifics of connecting to SQL Server for purposes of creating a
2052/// Materialize Source, see [`SqlServerSourceConnection`] which wraps this type.
2053///
2054/// [`SqlServerSourceConnection`]: crate::sources::SqlServerSourceConnection
2055#[derive(Clone, Debug, Eq, PartialEq, Hash, Serialize, Deserialize)]
2056pub struct SqlServerConnectionDetails<C: ConnectionAccess = InlinedConnection> {
2057    /// The hostname of the server.
2058    pub host: String,
2059    /// The port of the server.
2060    pub port: u16,
2061    /// Database we should connect to.
2062    pub database: String,
2063    /// The username to authenticate as.
2064    pub user: StringOrSecret,
2065    /// Password used for authentication.
2066    pub password: CatalogItemId,
2067    /// A tunnel through which to route traffic.
2068    pub tunnel: Tunnel<C>,
2069    /// Level of encryption to use for the connection.
2070    pub encryption: mz_sql_server_util::config::EncryptionLevel,
2071    /// Certificate validation policy
2072    pub certificate_validation_policy: mz_sql_server_util::config::CertificateValidationPolicy,
2073    /// TLS CA Certifiecate in PEM format
2074    pub tls_root_cert: Option<StringOrSecret>,
2075}
2076
2077impl<C: ConnectionAccess> SqlServerConnectionDetails<C> {
2078    fn validate_by_default(&self) -> bool {
2079        true
2080    }
2081}
2082
2083impl SqlServerConnectionDetails<InlinedConnection> {
2084    /// Attempts to open a connection to the upstream SQL Server instance.
2085    pub async fn validate(
2086        &self,
2087        _id: CatalogItemId,
2088        storage_configuration: &StorageConfiguration,
2089    ) -> Result<mz_sql_server_util::Client, anyhow::Error> {
2090        let config = self
2091            .resolve_config(
2092                &storage_configuration.connection_context.secrets_reader,
2093                storage_configuration,
2094                InTask::No,
2095            )
2096            .await?;
2097        tracing::debug!(?config, "Validating SQL Server connection");
2098
2099        let mut client = mz_sql_server_util::Client::connect(config).await?;
2100
2101        // Ensure the upstream SQL Server instance is configured to allow CDC.
2102        //
2103        // Run all of the checks necessary and collect the errors to provide the best
2104        // guidance as to which system settings need to be enabled.
2105        let mut replication_errors = vec![];
2106        for error in [
2107            mz_sql_server_util::inspect::ensure_database_cdc_enabled(&mut client).await,
2108            mz_sql_server_util::inspect::ensure_snapshot_isolation_enabled(&mut client).await,
2109        ] {
2110            match error {
2111                Err(mz_sql_server_util::SqlServerError::InvalidSystemSetting {
2112                    name,
2113                    expected,
2114                    actual,
2115                }) => replication_errors.push((name, expected, actual)),
2116                Err(other) => Err(other)?,
2117                Ok(()) => (),
2118            }
2119        }
2120        if !replication_errors.is_empty() {
2121            Err(SqlServerConnectionValidationError::ReplicationSettingsError(replication_errors))?;
2122        }
2123
2124        Ok(client)
2125    }
2126
2127    /// Resolve all of the connection details (e.g. read from the [`SecretsReader`])
2128    /// so the returned [`Config`] can be used to open a connection with the
2129    /// upstream system.
2130    ///
2131    /// The provided [`InTask`] argument determines whether any I/O is run in an
2132    /// [`mz_ore::task`] (i.e. a different thread) or directly in the returned
2133    /// future. The main goal here is to prevent running I/O in timely threads.
2134    ///
2135    /// [`Config`]: mz_sql_server_util::Config
2136    pub async fn resolve_config(
2137        &self,
2138        secrets_reader: &Arc<dyn mz_secrets::SecretsReader>,
2139        storage_configuration: &StorageConfiguration,
2140        in_task: InTask,
2141    ) -> Result<mz_sql_server_util::Config, anyhow::Error> {
2142        let dyncfg = storage_configuration.config_set();
2143        let mut inner_config = tiberius::Config::new();
2144
2145        // Setup default connection params.
2146        inner_config.host(&self.host);
2147        inner_config.port(self.port);
2148        inner_config.database(self.database.clone());
2149        inner_config.encryption(self.encryption.into());
2150        match self.certificate_validation_policy {
2151            mz_sql_server_util::config::CertificateValidationPolicy::TrustAll => {
2152                inner_config.trust_cert()
2153            }
2154            mz_sql_server_util::config::CertificateValidationPolicy::VerifyCA => {
2155                inner_config.trust_cert_ca_pem(
2156                    self.tls_root_cert
2157                        .as_ref()
2158                        .unwrap()
2159                        .get_string(in_task, secrets_reader)
2160                        .await
2161                        .context("ca certificate")?,
2162                );
2163            }
2164            mz_sql_server_util::config::CertificateValidationPolicy::VerifySystem => (), // no-op
2165        }
2166
2167        inner_config.application_name("materialize");
2168
2169        // Read our auth settings from
2170        let user = self
2171            .user
2172            .get_string(in_task, secrets_reader)
2173            .await
2174            .context("username")?;
2175        let password = secrets_reader
2176            .read_string_in_task_if(in_task, self.password)
2177            .await
2178            .context("password")?;
2179        // TODO(sql_server3): Support other methods of authentication besides
2180        // username and password.
2181        inner_config.authentication(tiberius::AuthMethod::sql_server(user, password));
2182
2183        // Prevent users from probing our internal network ports by trying to
2184        // connect to localhost, or another non-external IP.
2185        let enfoce_external_addresses = ENFORCE_EXTERNAL_ADDRESSES.get(dyncfg);
2186
2187        let tunnel = match &self.tunnel {
2188            Tunnel::Direct => mz_sql_server_util::config::TunnelConfig::Direct,
2189            Tunnel::Ssh(SshTunnel {
2190                connection_id,
2191                connection: ssh_connection,
2192            }) => {
2193                let secret = secrets_reader
2194                    .read_in_task_if(in_task, *connection_id)
2195                    .await
2196                    .context("ssh secret")?;
2197                let key_pair = SshKeyPair::from_bytes(&secret).context("ssh key pair")?;
2198                // Ensure any SSH-bastion host we connect to is resolved to an
2199                // external address.
2200                let addresses = resolve_address(&ssh_connection.host, enfoce_external_addresses)
2201                    .await
2202                    .context("ssh tunnel")?;
2203
2204                let config = SshTunnelConfig {
2205                    host: addresses.into_iter().map(|a| a.to_string()).collect(),
2206                    port: ssh_connection.port,
2207                    user: ssh_connection.user.clone(),
2208                    key_pair,
2209                };
2210                mz_sql_server_util::config::TunnelConfig::Ssh {
2211                    config,
2212                    manager: storage_configuration
2213                        .connection_context
2214                        .ssh_tunnel_manager
2215                        .clone(),
2216                    timeout: storage_configuration.parameters.ssh_timeout_config.clone(),
2217                    host: self.host.clone(),
2218                    port: self.port,
2219                }
2220            }
2221            Tunnel::AwsPrivatelink(private_link_connection) => {
2222                assert_none!(private_link_connection.port);
2223                mz_sql_server_util::config::TunnelConfig::AwsPrivatelink {
2224                    connection_id: private_link_connection.connection_id,
2225                    port: self.port,
2226                }
2227            }
2228        };
2229
2230        Ok(mz_sql_server_util::Config::new(
2231            inner_config,
2232            tunnel,
2233            in_task,
2234        ))
2235    }
2236}
2237
/// Errors returned when an upstream SQL Server instance is not configured for
/// replication.
#[derive(Debug, Clone, thiserror::Error)]
pub enum SqlServerConnectionValidationError {
    /// One or more system settings have unexpected values; each entry is
    /// `(setting, expected, actual)`.
    #[error("Invalid SQL Server system replication settings")]
    ReplicationSettingsError(Vec<(String, String, String)>),
}
2243
2244impl SqlServerConnectionValidationError {
2245    pub fn detail(&self) -> Option<String> {
2246        match self {
2247            Self::ReplicationSettingsError(settings) => Some(format!(
2248                "Invalid SQL Server system replication settings: {}",
2249                itertools::join(
2250                    settings.iter().map(|(setting, expected, actual)| format!(
2251                        "{}: expected {}, got {}",
2252                        setting, expected, actual
2253                    )),
2254                    "; "
2255                )
2256            )),
2257        }
2258    }
2259
2260    pub fn hint(&self) -> Option<String> {
2261        match self {
2262            _ => None,
2263        }
2264    }
2265}
2266
2267impl<R: ConnectionResolver> IntoInlineConnection<SqlServerConnectionDetails, R>
2268    for SqlServerConnectionDetails<ReferencedConnection>
2269{
2270    fn into_inline_connection(self, r: R) -> SqlServerConnectionDetails {
2271        let SqlServerConnectionDetails {
2272            host,
2273            port,
2274            database,
2275            user,
2276            password,
2277            tunnel,
2278            encryption,
2279            certificate_validation_policy,
2280            tls_root_cert,
2281        } = self;
2282
2283        SqlServerConnectionDetails {
2284            host,
2285            port,
2286            database,
2287            user,
2288            password,
2289            tunnel: tunnel.into_inline_connection(&r),
2290            encryption,
2291            certificate_validation_policy,
2292            tls_root_cert,
2293        }
2294    }
2295}
2296
2297impl<C: ConnectionAccess> AlterCompatible for SqlServerConnectionDetails<C> {
2298    fn alter_compatible(
2299        &self,
2300        id: mz_repr::GlobalId,
2301        other: &Self,
2302    ) -> Result<(), crate::controller::AlterError> {
2303        let SqlServerConnectionDetails {
2304            tunnel,
2305            // TODO(sql_server2): Figure out how these variables are allowed to change.
2306            host: _,
2307            port: _,
2308            database: _,
2309            user: _,
2310            password: _,
2311            encryption: _,
2312            certificate_validation_policy: _,
2313            tls_root_cert: _,
2314        } = self;
2315
2316        let compatibility_checks = [(tunnel.alter_compatible(id, &other.tunnel).is_ok(), "tunnel")];
2317
2318        for (compatible, field) in compatibility_checks {
2319            if !compatible {
2320                tracing::warn!(
2321                    "SqlServerConnectionDetails incompatible at {field}:\nself:\n{:#?}\n\nother\n{:#?}",
2322                    self,
2323                    other
2324                );
2325
2326                return Err(AlterError { id });
2327            }
2328        }
2329        Ok(())
2330    }
2331}
2332
/// A connection to an SSH tunnel.
#[derive(Clone, Debug, Eq, PartialEq, Hash, Serialize, Deserialize)]
pub struct SshConnection {
    /// Hostname of the SSH bastion; callers in this module pass it through
    /// `resolve_address` before connecting.
    pub host: String,
    /// Port of the SSH server.
    pub port: u16,
    /// User to authenticate as on the SSH server.
    pub user: String,
}
2340
2341use self::inline::{
2342    ConnectionAccess, ConnectionResolver, InlinedConnection, IntoInlineConnection,
2343    ReferencedConnection,
2344};
2345
2346impl AlterCompatible for SshConnection {
2347    fn alter_compatible(&self, _id: GlobalId, _other: &Self) -> Result<(), AlterError> {
2348        // Every element of the SSH connection is configurable.
2349        Ok(())
2350    }
2351}
2352
2353/// Specifies an AWS PrivateLink service for a [`Tunnel`].
2354#[derive(Clone, Debug, Eq, PartialEq, Hash, Serialize, Deserialize)]
2355pub struct AwsPrivatelink {
2356    /// The ID of the connection to the AWS PrivateLink service.
2357    pub connection_id: CatalogItemId,
2358    // The availability zone to use when connecting to the AWS PrivateLink service.
2359    pub availability_zone: Option<String>,
2360    /// The port to use when connecting to the AWS PrivateLink service, if
2361    /// different from the port in [`KafkaBroker::address`].
2362    pub port: Option<u16>,
2363}
2364
2365impl AlterCompatible for AwsPrivatelink {
2366    fn alter_compatible(&self, id: GlobalId, other: &Self) -> Result<(), AlterError> {
2367        let AwsPrivatelink {
2368            connection_id,
2369            availability_zone: _,
2370            port: _,
2371        } = self;
2372
2373        let compatibility_checks = [(connection_id == &other.connection_id, "connection_id")];
2374
2375        for (compatible, field) in compatibility_checks {
2376            if !compatible {
2377                tracing::warn!(
2378                    "AwsPrivatelink incompatible at {field}:\nself:\n{:#?}\n\nother\n{:#?}",
2379                    self,
2380                    other
2381                );
2382
2383                return Err(AlterError { id });
2384            }
2385        }
2386
2387        Ok(())
2388    }
2389}
2390
/// Specifies an SSH tunnel connection.
///
/// `C` selects how the SSH connection itself is represented. For
/// [`ReferencedConnection`] it is a reference that is resolved through
/// [`IntoInlineConnection::into_inline_connection`]. For the default,
/// [`InlinedConnection`], it is the resolved connection whose host, port and
/// user are used directly when connecting.
#[derive(Clone, Debug, Eq, PartialEq, Hash, Serialize, Deserialize)]
pub struct SshTunnel<C: ConnectionAccess = InlinedConnection> {
    /// Catalog ID of the SSH connection. When connecting, this ID is also the
    /// key under which the SSH key pair is read from the secrets reader.
    pub connection_id: CatalogItemId,
    /// The SSH connection object, in the representation selected by `C`.
    pub connection: C::Ssh,
}
2399
2400impl<R: ConnectionResolver> IntoInlineConnection<SshTunnel, R> for SshTunnel<ReferencedConnection> {
2401    fn into_inline_connection(self, r: R) -> SshTunnel {
2402        let SshTunnel {
2403            connection,
2404            connection_id,
2405        } = self;
2406
2407        SshTunnel {
2408            connection: r.resolve_connection(connection).unwrap_ssh(),
2409            connection_id,
2410        }
2411    }
2412}
2413
2414impl SshTunnel<InlinedConnection> {
2415    /// Like [`SshTunnelConfig::connect`], but the SSH key is loaded from a
2416    /// secret.
2417    async fn connect(
2418        &self,
2419        storage_configuration: &StorageConfiguration,
2420        remote_host: &str,
2421        remote_port: u16,
2422        in_task: InTask,
2423    ) -> Result<ManagedSshTunnelHandle, anyhow::Error> {
2424        // Ensure any ssh-bastion host we connect to is resolved to an external address.
2425        let resolved = resolve_address(
2426            &self.connection.host,
2427            ENFORCE_EXTERNAL_ADDRESSES.get(storage_configuration.config_set()),
2428        )
2429        .await?;
2430        storage_configuration
2431            .connection_context
2432            .ssh_tunnel_manager
2433            .connect(
2434                SshTunnelConfig {
2435                    host: resolved
2436                        .iter()
2437                        .map(|a| a.to_string())
2438                        .collect::<BTreeSet<_>>(),
2439                    port: self.connection.port,
2440                    user: self.connection.user.clone(),
2441                    key_pair: SshKeyPair::from_bytes(
2442                        &storage_configuration
2443                            .connection_context
2444                            .secrets_reader
2445                            .read_in_task_if(in_task, self.connection_id)
2446                            .await?,
2447                    )?,
2448                },
2449                remote_host,
2450                remote_port,
2451                storage_configuration.parameters.ssh_timeout_config,
2452                in_task,
2453            )
2454            .await
2455    }
2456}
2457
2458impl<C: ConnectionAccess> AlterCompatible for SshTunnel<C> {
2459    fn alter_compatible(&self, id: GlobalId, other: &Self) -> Result<(), AlterError> {
2460        let SshTunnel {
2461            connection_id,
2462            connection,
2463        } = self;
2464
2465        let compatibility_checks = [
2466            (connection_id == &other.connection_id, "connection_id"),
2467            (
2468                connection.alter_compatible(id, &other.connection).is_ok(),
2469                "connection",
2470            ),
2471        ];
2472
2473        for (compatible, field) in compatibility_checks {
2474            if !compatible {
2475                tracing::warn!(
2476                    "SshTunnel incompatible at {field}:\nself:\n{:#?}\n\nother\n{:#?}",
2477                    self,
2478                    other
2479                );
2480
2481                return Err(AlterError { id });
2482            }
2483        }
2484
2485        Ok(())
2486    }
2487}
2488
2489impl SshConnection {
2490    #[allow(clippy::unused_async)]
2491    async fn validate(
2492        &self,
2493        id: CatalogItemId,
2494        storage_configuration: &StorageConfiguration,
2495    ) -> Result<(), anyhow::Error> {
2496        let secret = storage_configuration
2497            .connection_context
2498            .secrets_reader
2499            .read_in_task_if(
2500                // We are in a normal tokio context during validation, already.
2501                InTask::No,
2502                id,
2503            )
2504            .await?;
2505        let key_pair = SshKeyPair::from_bytes(&secret)?;
2506
2507        // Ensure any ssh-bastion host we connect to is resolved to an external address.
2508        let resolved = resolve_address(
2509            &self.host,
2510            ENFORCE_EXTERNAL_ADDRESSES.get(storage_configuration.config_set()),
2511        )
2512        .await?;
2513
2514        let config = SshTunnelConfig {
2515            host: resolved
2516                .iter()
2517                .map(|a| a.to_string())
2518                .collect::<BTreeSet<_>>(),
2519            port: self.port,
2520            user: self.user.clone(),
2521            key_pair,
2522        };
2523        // Note that we do NOT use the `SshTunnelManager` here, as we want to validate that we
2524        // can actually create a new connection to the ssh bastion, without tunneling.
2525        config
2526            .validate(storage_configuration.parameters.ssh_timeout_config)
2527            .await
2528    }
2529
2530    fn validate_by_default(&self) -> bool {
2531        false
2532    }
2533}
2534
2535impl AwsPrivatelinkConnection {
2536    #[allow(clippy::unused_async)]
2537    async fn validate(
2538        &self,
2539        id: CatalogItemId,
2540        storage_configuration: &StorageConfiguration,
2541    ) -> Result<(), anyhow::Error> {
2542        let Some(ref cloud_resource_reader) = storage_configuration
2543            .connection_context
2544            .cloud_resource_reader
2545        else {
2546            return Err(anyhow!("AWS PrivateLink connections are unsupported"));
2547        };
2548
2549        // No need to optionally run this in a task, as we are just validating from envd.
2550        let status = cloud_resource_reader.read(id).await?;
2551
2552        let availability = status
2553            .conditions
2554            .as_ref()
2555            .and_then(|conditions| conditions.iter().find(|c| c.type_ == "Available"));
2556
2557        match availability {
2558            Some(condition) if condition.status == "True" => Ok(()),
2559            Some(condition) => Err(anyhow!("{}", condition.message)),
2560            None => Err(anyhow!("Endpoint availability is unknown")),
2561        }
2562    }
2563
2564    fn validate_by_default(&self) -> bool {
2565        false
2566    }
2567}