// mz_storage_types/connections.rs

1// Copyright Materialize, Inc. and contributors. All rights reserved.
2//
3// Use of this software is governed by the Business Source License
4// included in the LICENSE file.
5//
6// As of the Change Date specified in that file, in accordance with
7// the Business Source License, use of this software will be governed
8// by the Apache License, Version 2.0.
9
10//! Connection types.
11
12use std::borrow::Cow;
13use std::collections::{BTreeMap, BTreeSet};
14use std::net::SocketAddr;
15use std::sync::Arc;
16
17use anyhow::{Context, anyhow};
18use iceberg::{Catalog, CatalogBuilder};
19use iceberg_catalog_rest::{
20    REST_CATALOG_PROP_URI, REST_CATALOG_PROP_WAREHOUSE, RestCatalogBuilder,
21};
22use itertools::Itertools;
23use mz_ccsr::tls::{Certificate, Identity};
24use mz_cloud_resources::{AwsExternalIdPrefix, CloudResourceReader, vpc_endpoint_host};
25use mz_dyncfg::ConfigSet;
26use mz_kafka_util::client::{
27    BrokerAddr, BrokerRewrite, MzClientContext, MzKafkaError, TunnelConfig, TunnelingClientContext,
28};
29use mz_mysql_util::{MySqlConn, MySqlError};
30use mz_ore::assert_none;
31use mz_ore::error::ErrorExt;
32use mz_ore::future::{InTask, OreFutureExt};
33use mz_ore::netio::DUMMY_DNS_PORT;
34use mz_ore::netio::resolve_address;
35use mz_ore::num::NonNeg;
36use mz_repr::{CatalogItemId, GlobalId};
37use mz_secrets::SecretsReader;
38use mz_ssh_util::keys::SshKeyPair;
39use mz_ssh_util::tunnel::SshTunnelConfig;
40use mz_ssh_util::tunnel_manager::{ManagedSshTunnelHandle, SshTunnelManager};
41use mz_tls_util::Pkcs12Archive;
42use mz_tracing::CloneableEnvFilter;
43use rdkafka::ClientContext;
44use rdkafka::config::FromClientConfigAndContext;
45use rdkafka::consumer::{BaseConsumer, Consumer};
46use regex::Regex;
47use serde::{Deserialize, Deserializer, Serialize};
48use tokio::net;
49use tokio::runtime::Handle;
50use tokio_postgres::config::SslMode;
51use tracing::{debug, warn};
52use url::Url;
53
54use crate::AlterCompatible;
55use crate::configuration::StorageConfiguration;
56use crate::connections::aws::{
57    AwsConnection, AwsConnectionReference, AwsConnectionValidationError,
58};
59use crate::connections::string_or_secret::StringOrSecret;
60use crate::controller::AlterError;
61use crate::dyncfgs::{
62    ENFORCE_EXTERNAL_ADDRESSES, KAFKA_CLIENT_ID_ENRICHMENT_RULES,
63    KAFKA_DEFAULT_AWS_PRIVATELINK_ENDPOINT_IDENTIFICATION_ALGORITHM,
64};
65use crate::errors::{ContextCreationError, CsrConnectError};
66
67pub mod aws;
68pub mod inline;
69pub mod string_or_secret;
70
/// Property name under which the oauth2 scope is passed to REST Iceberg catalogs.
const REST_CATALOG_PROP_SCOPE: &str = "scope";
/// Property name under which the oauth2 credential is passed to REST Iceberg catalogs.
const REST_CATALOG_PROP_CREDENTIAL: &str = "credential";
73
/// An extension trait for [`SecretsReader`] whose methods can optionally run
/// the underlying read in a dedicated task, depending on the [`InTask`]
/// argument.
#[async_trait::async_trait]
trait SecretsReaderExt {
    /// `SecretsReader::read`, but optionally run in a task.
    async fn read_in_task_if(
        &self,
        in_task: InTask,
        id: CatalogItemId,
    ) -> Result<Vec<u8>, anyhow::Error>;

    /// `SecretsReader::read_string`, but optionally run in a task.
    async fn read_string_in_task_if(
        &self,
        in_task: InTask,
        id: CatalogItemId,
    ) -> Result<String, anyhow::Error>;
}
91
92#[async_trait::async_trait]
93impl SecretsReaderExt for Arc<dyn SecretsReader> {
94    async fn read_in_task_if(
95        &self,
96        in_task: InTask,
97        id: CatalogItemId,
98    ) -> Result<Vec<u8>, anyhow::Error> {
99        let sr = Arc::clone(self);
100        async move { sr.read(id).await }
101            .run_in_task_if(in_task, || "secrets_reader_read".to_string())
102            .await
103    }
104    async fn read_string_in_task_if(
105        &self,
106        in_task: InTask,
107        id: CatalogItemId,
108    ) -> Result<String, anyhow::Error> {
109        let sr = Arc::clone(self);
110        async move { sr.read_string(id).await }
111            .run_in_task_if(in_task, || "secrets_reader_read".to_string())
112            .await
113    }
114}
115
/// Extra context to pass through when instantiating a connection for a source
/// or sink.
///
/// Should be kept cheaply cloneable.
///
/// See [`ConnectionContext::from_cli_args`] for construction from command
/// line arguments and [`ConnectionContext::for_tests`] for test usage.
#[derive(Debug, Clone)]
pub struct ConnectionContext {
    /// An opaque identifier for the environment in which this process is
    /// running.
    ///
    /// The storage layer is intentionally unaware of the structure within this
    /// identifier. Higher layers of the stack can make use of that structure,
    /// but the storage layer should be oblivious to it.
    pub environment_id: String,
    /// The level for librdkafka's logs.
    pub librdkafka_log_level: tracing::Level,
    /// A prefix for an external ID to use for all AWS AssumeRole operations.
    pub aws_external_id_prefix: Option<AwsExternalIdPrefix>,
    /// The ARN for a Materialize-controlled role to assume before assuming
    /// a customer's requested role for an AWS connection.
    pub aws_connection_role_arn: Option<String>,
    /// A secrets reader.
    pub secrets_reader: Arc<dyn SecretsReader>,
    /// A cloud resource reader, if supported in this configuration.
    pub cloud_resource_reader: Option<Arc<dyn CloudResourceReader>>,
    /// A manager for SSH tunnels.
    pub ssh_tunnel_manager: SshTunnelManager,
}
143
impl ConnectionContext {
    /// Constructs a new connection context from command line arguments.
    ///
    /// **WARNING:** it is critical for security that the `aws_external_id` be
    /// provided by the operator of the Materialize service (i.e., via a CLI
    /// argument or environment variable) and not the end user of Materialize
    /// (e.g., via a configuration option in a SQL statement). See
    /// [`AwsExternalIdPrefix`] for details.
    pub fn from_cli_args(
        environment_id: String,
        startup_log_level: &CloneableEnvFilter,
        aws_external_id_prefix: Option<AwsExternalIdPrefix>,
        aws_connection_role_arn: Option<String>,
        secrets_reader: Arc<dyn SecretsReader>,
        cloud_resource_reader: Option<Arc<dyn CloudResourceReader>>,
    ) -> ConnectionContext {
        ConnectionContext {
            environment_id,
            // Derive librdkafka's log level from the startup log filter's
            // level for the "librdkafka" target.
            librdkafka_log_level: mz_ore::tracing::crate_level(
                &startup_log_level.clone().into(),
                "librdkafka",
            ),
            aws_external_id_prefix,
            aws_connection_role_arn,
            secrets_reader,
            cloud_resource_reader,
            ssh_tunnel_manager: SshTunnelManager::default(),
        }
    }

    /// Constructs a new connection context for usage in tests.
    ///
    /// Uses fixed placeholder values for the environment ID, the AWS external
    /// ID prefix, and the AWS connection role ARN; no cloud resource reader
    /// is configured.
    pub fn for_tests(secrets_reader: Arc<dyn SecretsReader>) -> ConnectionContext {
        ConnectionContext {
            environment_id: "test-environment-id".into(),
            librdkafka_log_level: tracing::Level::INFO,
            aws_external_id_prefix: Some(
                AwsExternalIdPrefix::new_from_cli_argument_or_environment_variable(
                    "test-aws-external-id-prefix",
                )
                .expect("infallible"),
            ),
            aws_connection_role_arn: Some(
                "arn:aws:iam::123456789000:role/MaterializeConnection".into(),
            ),
            secrets_reader,
            cloud_resource_reader: None,
            ssh_tunnel_manager: SshTunnelManager::default(),
        }
    }
}
194
/// A connection to an external system.
///
/// Generic over `C`, which determines whether connections referenced by this
/// connection are stored inline or by reference.
#[derive(Clone, Debug, Eq, PartialEq, Hash, Serialize, Deserialize)]
pub enum Connection<C: ConnectionAccess = InlinedConnection> {
    /// A connection to a Kafka cluster.
    Kafka(KafkaConnection<C>),
    /// A connection to a Confluent Schema Registry.
    Csr(CsrConnection<C>),
    /// A connection to a PostgreSQL server.
    Postgres(PostgresConnection<C>),
    /// A connection to an SSH bastion host.
    Ssh(SshConnection),
    /// A connection to AWS.
    Aws(AwsConnection),
    /// A connection to an AWS PrivateLink endpoint.
    AwsPrivatelink(AwsPrivatelinkConnection),
    /// A connection to a MySQL server.
    MySql(MySqlConnection<C>),
    /// A connection to a SQL Server instance.
    SqlServer(SqlServerConnectionDetails<C>),
    /// A connection to an Iceberg catalog.
    IcebergCatalog(IcebergCatalogConnection<C>),
}
207
208impl<R: ConnectionResolver> IntoInlineConnection<Connection, R>
209    for Connection<ReferencedConnection>
210{
211    fn into_inline_connection(self, r: R) -> Connection {
212        match self {
213            Connection::Kafka(kafka) => Connection::Kafka(kafka.into_inline_connection(r)),
214            Connection::Csr(csr) => Connection::Csr(csr.into_inline_connection(r)),
215            Connection::Postgres(pg) => Connection::Postgres(pg.into_inline_connection(r)),
216            Connection::Ssh(ssh) => Connection::Ssh(ssh),
217            Connection::Aws(aws) => Connection::Aws(aws),
218            Connection::AwsPrivatelink(awspl) => Connection::AwsPrivatelink(awspl),
219            Connection::MySql(mysql) => Connection::MySql(mysql.into_inline_connection(r)),
220            Connection::SqlServer(sql_server) => {
221                Connection::SqlServer(sql_server.into_inline_connection(r))
222            }
223            Connection::IcebergCatalog(iceberg) => {
224                Connection::IcebergCatalog(iceberg.into_inline_connection(r))
225            }
226        }
227    }
228}
229
230impl<C: ConnectionAccess> Connection<C> {
231    /// Whether this connection should be validated by default on creation.
232    pub fn validate_by_default(&self) -> bool {
233        match self {
234            Connection::Kafka(conn) => conn.validate_by_default(),
235            Connection::Csr(conn) => conn.validate_by_default(),
236            Connection::Postgres(conn) => conn.validate_by_default(),
237            Connection::Ssh(conn) => conn.validate_by_default(),
238            Connection::Aws(conn) => conn.validate_by_default(),
239            Connection::AwsPrivatelink(conn) => conn.validate_by_default(),
240            Connection::MySql(conn) => conn.validate_by_default(),
241            Connection::SqlServer(conn) => conn.validate_by_default(),
242            Connection::IcebergCatalog(conn) => conn.validate_by_default(),
243        }
244    }
245}
246
247impl Connection<InlinedConnection> {
248    /// Validates this connection by attempting to connect to the upstream system.
249    pub async fn validate(
250        &self,
251        id: CatalogItemId,
252        storage_configuration: &StorageConfiguration,
253    ) -> Result<(), ConnectionValidationError> {
254        match self {
255            Connection::Kafka(conn) => conn.validate(id, storage_configuration).await?,
256            Connection::Csr(conn) => conn.validate(id, storage_configuration).await?,
257            Connection::Postgres(conn) => {
258                conn.validate(id, storage_configuration).await?;
259            }
260            Connection::Ssh(conn) => conn.validate(id, storage_configuration).await?,
261            Connection::Aws(conn) => conn.validate(id, storage_configuration).await?,
262            Connection::AwsPrivatelink(conn) => conn.validate(id, storage_configuration).await?,
263            Connection::MySql(conn) => {
264                conn.validate(id, storage_configuration).await?;
265            }
266            Connection::SqlServer(conn) => {
267                conn.validate(id, storage_configuration).await?;
268            }
269            Connection::IcebergCatalog(conn) => conn.validate(id, storage_configuration).await?,
270        }
271        Ok(())
272    }
273
274    pub fn unwrap_kafka(self) -> <InlinedConnection as ConnectionAccess>::Kafka {
275        match self {
276            Self::Kafka(conn) => conn,
277            o => unreachable!("{o:?} is not a Kafka connection"),
278        }
279    }
280
281    pub fn unwrap_pg(self) -> <InlinedConnection as ConnectionAccess>::Pg {
282        match self {
283            Self::Postgres(conn) => conn,
284            o => unreachable!("{o:?} is not a Postgres connection"),
285        }
286    }
287
288    pub fn unwrap_mysql(self) -> <InlinedConnection as ConnectionAccess>::MySql {
289        match self {
290            Self::MySql(conn) => conn,
291            o => unreachable!("{o:?} is not a MySQL connection"),
292        }
293    }
294
295    pub fn unwrap_sql_server(self) -> <InlinedConnection as ConnectionAccess>::SqlServer {
296        match self {
297            Self::SqlServer(conn) => conn,
298            o => unreachable!("{o:?} is not a SQL Server connection"),
299        }
300    }
301
302    pub fn unwrap_aws(self) -> <InlinedConnection as ConnectionAccess>::Aws {
303        match self {
304            Self::Aws(conn) => conn,
305            o => unreachable!("{o:?} is not an AWS connection"),
306        }
307    }
308
309    pub fn unwrap_ssh(self) -> <InlinedConnection as ConnectionAccess>::Ssh {
310        match self {
311            Self::Ssh(conn) => conn,
312            o => unreachable!("{o:?} is not an SSH connection"),
313        }
314    }
315
316    pub fn unwrap_csr(self) -> <InlinedConnection as ConnectionAccess>::Csr {
317        match self {
318            Self::Csr(conn) => conn,
319            o => unreachable!("{o:?} is not a Kafka connection"),
320        }
321    }
322
323    pub fn unwrap_iceberg_catalog(self) -> <InlinedConnection as ConnectionAccess>::IcebergCatalog {
324        match self {
325            Self::IcebergCatalog(conn) => conn,
326            o => unreachable!("{o:?} is not an Iceberg catalog connection"),
327        }
328    }
329}
330
/// An error returned by [`Connection::validate`].
#[derive(thiserror::Error, Debug)]
pub enum ConnectionValidationError {
    /// Validating a Postgres connection failed.
    #[error(transparent)]
    Postgres(#[from] PostgresConnectionValidationError),
    /// Validating a MySQL connection failed.
    #[error(transparent)]
    MySql(#[from] MySqlConnectionValidationError),
    /// Validating a SQL Server connection failed.
    #[error(transparent)]
    SqlServer(#[from] SqlServerConnectionValidationError),
    /// Validating an AWS connection failed.
    #[error(transparent)]
    Aws(#[from] AwsConnectionValidationError),
    /// Any other validation failure.
    #[error("{}", .0.display_with_causes())]
    Other(#[from] anyhow::Error),
}
345
346impl ConnectionValidationError {
347    /// Reports additional details about the error, if any are available.
348    pub fn detail(&self) -> Option<String> {
349        match self {
350            ConnectionValidationError::Postgres(e) => e.detail(),
351            ConnectionValidationError::MySql(e) => e.detail(),
352            ConnectionValidationError::SqlServer(e) => e.detail(),
353            ConnectionValidationError::Aws(e) => e.detail(),
354            ConnectionValidationError::Other(_) => None,
355        }
356    }
357
358    /// Reports a hint for the user about how the error could be fixed.
359    pub fn hint(&self) -> Option<String> {
360        match self {
361            ConnectionValidationError::Postgres(e) => e.hint(),
362            ConnectionValidationError::MySql(e) => e.hint(),
363            ConnectionValidationError::SqlServer(e) => e.hint(),
364            ConnectionValidationError::Aws(e) => e.hint(),
365            ConnectionValidationError::Other(_) => None,
366        }
367    }
368}
369
impl<C: ConnectionAccess> AlterCompatible for Connection<C> {
    // Same-variant pairs delegate to the variant's own `alter_compatible`;
    // everything else is incompatible.
    fn alter_compatible(&self, id: GlobalId, other: &Self) -> Result<(), AlterError> {
        match (self, other) {
            (Self::Aws(s), Self::Aws(o)) => s.alter_compatible(id, o),
            (Self::AwsPrivatelink(s), Self::AwsPrivatelink(o)) => s.alter_compatible(id, o),
            (Self::Ssh(s), Self::Ssh(o)) => s.alter_compatible(id, o),
            (Self::Csr(s), Self::Csr(o)) => s.alter_compatible(id, o),
            (Self::Kafka(s), Self::Kafka(o)) => s.alter_compatible(id, o),
            (Self::Postgres(s), Self::Postgres(o)) => s.alter_compatible(id, o),
            (Self::MySql(s), Self::MySql(o)) => s.alter_compatible(id, o),
            // NOTE(review): `SqlServer` and `IcebergCatalog` have no arm
            // above, so same-variant pairs of those kinds also land here and
            // are always reported incompatible — confirm this is intentional.
            _ => {
                tracing::warn!(
                    "Connection incompatible:\nself:\n{:#?}\n\nother\n{:#?}",
                    self,
                    other
                );
                Err(AlterError { id })
            }
        }
    }
}
391
/// Configuration for a plain Iceberg REST catalog.
#[derive(Clone, Debug, Eq, PartialEq, Hash, Serialize, Deserialize)]
pub struct RestIcebergCatalog {
    /// For REST catalogs, the oauth2 credential in a `CLIENT_ID:CLIENT_SECRET` format
    pub credential: StringOrSecret,
    /// The oauth2 scope for REST catalogs
    pub scope: Option<String>,
    /// The warehouse for REST catalogs
    pub warehouse: Option<String>,
}
401
/// Configuration for an AWS S3 Tables Iceberg REST catalog.
#[derive(Clone, Debug, Eq, PartialEq, Hash, Serialize, Deserialize)]
pub struct S3TablesRestIcebergCatalog<C: ConnectionAccess = InlinedConnection> {
    /// The AWS connection details, for s3tables
    pub aws_connection: AwsConnectionReference<C>,
    /// The warehouse for s3tables
    pub warehouse: String,
}
409
410impl<R: ConnectionResolver> IntoInlineConnection<S3TablesRestIcebergCatalog, R>
411    for S3TablesRestIcebergCatalog<ReferencedConnection>
412{
413    fn into_inline_connection(self, r: R) -> S3TablesRestIcebergCatalog {
414        S3TablesRestIcebergCatalog {
415            aws_connection: self.aws_connection.into_inline_connection(&r),
416            warehouse: self.warehouse,
417        }
418    }
419}
420
/// The kind of Iceberg catalog a connection points at.
#[derive(Clone, Debug, Eq, PartialEq, Hash, Serialize, Deserialize)]
pub enum IcebergCatalogType {
    /// A plain Iceberg REST catalog.
    Rest,
    /// An AWS S3 Tables REST catalog.
    S3TablesRest,
}
426
/// Implementation-specific details for each supported Iceberg catalog kind.
#[derive(Clone, Debug, Eq, PartialEq, Hash, Serialize, Deserialize)]
pub enum IcebergCatalogImpl<C: ConnectionAccess = InlinedConnection> {
    /// A plain REST catalog.
    Rest(RestIcebergCatalog),
    /// An AWS S3 Tables REST catalog.
    S3TablesRest(S3TablesRestIcebergCatalog<C>),
}
432
433impl<R: ConnectionResolver> IntoInlineConnection<IcebergCatalogImpl, R>
434    for IcebergCatalogImpl<ReferencedConnection>
435{
436    fn into_inline_connection(self, r: R) -> IcebergCatalogImpl {
437        match self {
438            IcebergCatalogImpl::Rest(rest) => IcebergCatalogImpl::Rest(rest),
439            IcebergCatalogImpl::S3TablesRest(s3tables) => {
440                IcebergCatalogImpl::S3TablesRest(s3tables.into_inline_connection(r))
441            }
442        }
443    }
444}
445
/// A connection to an Iceberg catalog.
#[derive(Clone, Debug, Eq, PartialEq, Hash, Serialize, Deserialize)]
pub struct IcebergCatalogConnection<C: ConnectionAccess = InlinedConnection> {
    /// The implementation-specific details of the catalog.
    pub catalog: IcebergCatalogImpl<C>,
    /// Where the catalog is located
    pub uri: reqwest::Url,
}
453
impl AlterCompatible for IcebergCatalogConnection {
    /// Iceberg catalog connections report every alteration as incompatible:
    /// this implementation unconditionally returns an error.
    fn alter_compatible(&self, id: GlobalId, _other: &Self) -> Result<(), AlterError> {
        Err(AlterError { id })
    }
}
459
460impl<R: ConnectionResolver> IntoInlineConnection<IcebergCatalogConnection, R>
461    for IcebergCatalogConnection<ReferencedConnection>
462{
463    fn into_inline_connection(self, r: R) -> IcebergCatalogConnection {
464        IcebergCatalogConnection {
465            catalog: self.catalog.into_inline_connection(&r),
466            uri: self.uri,
467        }
468    }
469}
470
impl<C: ConnectionAccess> IcebergCatalogConnection<C> {
    /// Whether this connection should be validated by default on creation.
    fn validate_by_default(&self) -> bool {
        true
    }
}
476
477impl IcebergCatalogConnection<InlinedConnection> {
478    pub async fn connect(
479        &self,
480        storage_configuration: &StorageConfiguration,
481        in_task: InTask,
482    ) -> Result<Arc<dyn Catalog>, anyhow::Error> {
483        match self.catalog {
484            IcebergCatalogImpl::S3TablesRest(ref s3tables) => {
485                self.connect_s3tables(s3tables, storage_configuration, in_task)
486                    .await
487            }
488            IcebergCatalogImpl::Rest(ref rest) => {
489                self.connect_rest(rest, storage_configuration, in_task)
490                    .await
491            }
492        }
493    }
494
495    async fn connect_s3tables(
496        &self,
497        s3tables: &S3TablesRestIcebergCatalog,
498        storage_configuration: &StorageConfiguration,
499        in_task: InTask,
500    ) -> Result<Arc<dyn Catalog>, anyhow::Error> {
501        let mut props = BTreeMap::from([(
502            REST_CATALOG_PROP_URI.to_string(),
503            self.uri.to_string().clone(),
504        )]);
505
506        let aws_ref = &s3tables.aws_connection;
507        let aws_config = aws_ref
508            .connection
509            .load_sdk_config(
510                &storage_configuration.connection_context,
511                aws_ref.connection_id,
512                in_task,
513            )
514            .await?;
515
516        props.insert(
517            REST_CATALOG_PROP_WAREHOUSE.to_string(),
518            s3tables.warehouse.clone(),
519        );
520
521        let catalog = RestCatalogBuilder::default()
522            .with_aws_client(aws_config)
523            .load("IcebergCatalog", props.into_iter().collect())
524            .await
525            .map_err(|e| anyhow!("failed to create Iceberg catalog: {e}"))?;
526        Ok(Arc::new(catalog))
527    }
528
529    async fn connect_rest(
530        &self,
531        rest: &RestIcebergCatalog,
532        storage_configuration: &StorageConfiguration,
533        in_task: InTask,
534    ) -> Result<Arc<dyn Catalog>, anyhow::Error> {
535        let mut props = BTreeMap::from([(
536            REST_CATALOG_PROP_URI.to_string(),
537            self.uri.to_string().clone(),
538        )]);
539
540        if let Some(warehouse) = &rest.warehouse {
541            props.insert(REST_CATALOG_PROP_WAREHOUSE.to_string(), warehouse.clone());
542        }
543
544        let credential = rest
545            .credential
546            .get_string(
547                in_task,
548                &storage_configuration.connection_context.secrets_reader,
549            )
550            .await
551            .map_err(|e| anyhow!("failed to read Iceberg catalog credential: {e}"))?;
552        props.insert(REST_CATALOG_PROP_CREDENTIAL.to_string(), credential);
553
554        if let Some(scope) = &rest.scope {
555            props.insert(REST_CATALOG_PROP_SCOPE.to_string(), scope.clone());
556        }
557
558        let catalog = RestCatalogBuilder::default()
559            .load("IcebergCatalog", props.into_iter().collect())
560            .await
561            .map_err(|e| anyhow!("failed to create Iceberg catalog: {e}"))?;
562        Ok(Arc::new(catalog))
563    }
564
565    async fn validate(
566        &self,
567        _id: CatalogItemId,
568        storage_configuration: &StorageConfiguration,
569    ) -> Result<(), ConnectionValidationError> {
570        let catalog = self
571            .connect(storage_configuration, InTask::No)
572            .await
573            .map_err(|e| {
574                ConnectionValidationError::Other(anyhow!("failed to connect to catalog: {e}"))
575            })?;
576
577        // If we can list namespaces, the connection is valid.
578        catalog.list_namespaces(None).await.map_err(|e| {
579            ConnectionValidationError::Other(anyhow!("failed to list namespaces: {e}"))
580        })?;
581
582        Ok(())
583    }
584}
585
/// A connection to an AWS PrivateLink service.
#[derive(Clone, Debug, Eq, PartialEq, Hash, Serialize, Deserialize)]
pub struct AwsPrivatelinkConnection {
    /// The name of the PrivateLink service.
    pub service_name: String,
    /// The availability zones for the service.
    pub availability_zones: Vec<String>,
}
591
impl AlterCompatible for AwsPrivatelinkConnection {
    /// Always compatible: any field of an AWS PrivateLink connection may be
    /// altered.
    fn alter_compatible(&self, _id: GlobalId, _other: &Self) -> Result<(), AlterError> {
        // Every element of the AwsPrivatelinkConnection connection is configurable.
        Ok(())
    }
}
598
/// TLS configuration for a [`KafkaConnection`].
#[derive(Clone, Debug, Eq, PartialEq, Hash, Serialize, Deserialize)]
pub struct KafkaTlsConfig {
    /// The client TLS identity (certificate and key), if any.
    pub identity: Option<TlsIdentity>,
    /// The CA certificate used to verify the brokers, if any.
    pub root_cert: Option<StringOrSecret>,
}
604
/// SASL configuration for a [`KafkaConnection`].
#[derive(Clone, Debug, Eq, PartialEq, Hash, Serialize, Deserialize)]
pub struct KafkaSaslConfig<C: ConnectionAccess = InlinedConnection> {
    /// The SASL mechanism (passed through as `sasl.mechanisms`).
    pub mechanism: String,
    /// The username for authentication.
    pub username: StringOrSecret,
    /// The secret holding the password, if any.
    pub password: Option<CatalogItemId>,
    /// An AWS connection, when SASL authentication is backed by AWS
    /// credentials.
    pub aws: Option<AwsConnectionReference<C>>,
}
612
613impl<R: ConnectionResolver> IntoInlineConnection<KafkaSaslConfig, R>
614    for KafkaSaslConfig<ReferencedConnection>
615{
616    fn into_inline_connection(self, r: R) -> KafkaSaslConfig {
617        KafkaSaslConfig {
618            mechanism: self.mechanism,
619            username: self.username,
620            password: self.password,
621            aws: self.aws.map(|aws| aws.into_inline_connection(&r)),
622        }
623    }
624}
625
/// Specifies a Kafka broker in a [`KafkaConnection`].
#[derive(Clone, Debug, Eq, PartialEq, Hash, Serialize, Deserialize)]
pub struct KafkaBroker<C: ConnectionAccess = InlinedConnection> {
    /// The address of the Kafka broker.
    pub address: String,
    /// An optional tunnel to use when connecting to the broker.
    /// Overrides the connection's `default_tunnel` for this broker.
    pub tunnel: Tunnel<C>,
}
634
635impl<R: ConnectionResolver> IntoInlineConnection<KafkaBroker, R>
636    for KafkaBroker<ReferencedConnection>
637{
638    fn into_inline_connection(self, r: R) -> KafkaBroker {
639        let KafkaBroker { address, tunnel } = self;
640        KafkaBroker {
641            address,
642            tunnel: tunnel.into_inline_connection(r),
643        }
644    }
645}
646
/// Creation-time options for a Kafka topic.
#[derive(Clone, Debug, Eq, PartialEq, Hash, Serialize, Deserialize, Default)]
pub struct KafkaTopicOptions {
    /// The replication factor for the topic.
    /// If `None`, the broker default will be used.
    pub replication_factor: Option<NonNeg<i32>>,
    /// The number of partitions to create.
    /// If `None`, the broker default will be used.
    pub partition_count: Option<NonNeg<i32>>,
    /// The initial configuration parameters for the topic.
    pub topic_config: BTreeMap<String, String>,
}
658
/// A connection to a Kafka cluster.
#[derive(Clone, Debug, Eq, PartialEq, Hash, Serialize, Deserialize)]
pub struct KafkaConnection<C: ConnectionAccess = InlinedConnection> {
    /// The brokers to bootstrap from.
    pub brokers: Vec<KafkaBroker<C>>,
    /// A tunnel through which to route traffic,
    /// that can be overridden for individual brokers
    /// in `brokers`.
    pub default_tunnel: Tunnel<C>,
    /// The name of the progress topic, if explicitly configured; see
    /// [`KafkaConnection::progress_topic`] for the default.
    pub progress_topic: Option<String>,
    /// Creation options for the progress topic.
    pub progress_topic_options: KafkaTopicOptions,
    /// Additional options to apply to the Kafka client configuration.
    pub options: BTreeMap<String, StringOrSecret>,
    /// TLS configuration, if any.
    pub tls: Option<KafkaTlsConfig>,
    /// SASL configuration, if any.
    pub sasl: Option<KafkaSaslConfig<C>>,
}
672
673impl<R: ConnectionResolver> IntoInlineConnection<KafkaConnection, R>
674    for KafkaConnection<ReferencedConnection>
675{
676    fn into_inline_connection(self, r: R) -> KafkaConnection {
677        let KafkaConnection {
678            brokers,
679            progress_topic,
680            progress_topic_options,
681            default_tunnel,
682            options,
683            tls,
684            sasl,
685        } = self;
686
687        let brokers = brokers
688            .into_iter()
689            .map(|broker| broker.into_inline_connection(&r))
690            .collect();
691
692        KafkaConnection {
693            brokers,
694            progress_topic,
695            progress_topic_options,
696            default_tunnel: default_tunnel.into_inline_connection(&r),
697            options,
698            tls,
699            sasl: sasl.map(|sasl| sasl.into_inline_connection(&r)),
700        }
701    }
702}
703
704impl<C: ConnectionAccess> KafkaConnection<C> {
705    /// Returns the name of the progress topic to use for the connection.
706    ///
707    /// The caller is responsible for providing the connection ID as it is not
708    /// known to `KafkaConnection`.
709    pub fn progress_topic(
710        &self,
711        connection_context: &ConnectionContext,
712        connection_id: CatalogItemId,
713    ) -> Cow<'_, str> {
714        if let Some(progress_topic) = &self.progress_topic {
715            Cow::Borrowed(progress_topic)
716        } else {
717            Cow::Owned(format!(
718                "_materialize-progress-{}-{}",
719                connection_context.environment_id, connection_id,
720            ))
721        }
722    }
723
724    fn validate_by_default(&self) -> bool {
725        true
726    }
727}
728
729impl KafkaConnection {
730    /// Generates a string that can be used as the base for a configuration ID
731    /// (e.g., `client.id`, `group.id`, `transactional.id`) for a Kafka source
732    /// or sink.
733    pub fn id_base(
734        connection_context: &ConnectionContext,
735        connection_id: CatalogItemId,
736        object_id: GlobalId,
737    ) -> String {
738        format!(
739            "materialize-{}-{}-{}",
740            connection_context.environment_id, connection_id, object_id,
741        )
742    }
743
    /// Enriches the provided `client_id` according to any enrichment rules in
    /// the `kafka_client_id_enrichment_rules` configuration parameter.
    ///
    /// The payload of every rule whose pattern matches at least one configured
    /// broker address is appended to `client_id`, separated by `-`, in rule
    /// order.
    pub fn enrich_client_id(&self, configs: &ConfigSet, client_id: &mut String) {
        /// A single enrichment rule, decoded from the JSON config value.
        #[derive(Debug, Deserialize)]
        struct EnrichmentRule {
            /// Regex matched against each broker address.
            #[serde(deserialize_with = "deserialize_regex")]
            pattern: Regex,
            /// Suffix appended to the client ID when the pattern matches.
            payload: String,
        }

        fn deserialize_regex<'de, D>(deserializer: D) -> Result<Regex, D::Error>
        where
            D: Deserializer<'de>,
        {
            let buf = String::deserialize(deserializer)?;
            Regex::new(&buf).map_err(serde::de::Error::custom)
        }

        let rules = KAFKA_CLIENT_ID_ENRICHMENT_RULES.get(configs);
        let rules = match serde_json::from_value::<Vec<EnrichmentRule>>(rules) {
            Ok(rules) => rules,
            Err(e) => {
                // A malformed configuration value is logged and ignored
                // rather than failing the caller.
                warn!(%e, "failed to decode kafka_client_id_enrichment_rules");
                return;
            }
        };

        // Check every rule against every broker. Rules are matched in the order
        // that they are specified. It is usually a configuration error if
        // multiple rules match the same list of Kafka brokers, but we
        // nonetheless want to provide well defined semantics.
        debug!(?self.brokers, "evaluating client ID enrichment rules");
        for rule in rules {
            let is_match = self
                .brokers
                .iter()
                .any(|b| rule.pattern.is_match(&b.address));
            debug!(?rule, is_match, "evaluated client ID enrichment rule");
            if is_match {
                client_id.push('-');
                client_id.push_str(&rule.payload);
            }
        }
    }
788
    /// Creates a Kafka client for the connection.
    ///
    /// Builds an rdkafka client configuration from the connection's stored
    /// options plus `extra_options` (which take precedence), configures TLS
    /// and SASL as specified, and installs a [`TunnelingClientContext`] that
    /// routes broker traffic through any configured SSH or AWS PrivateLink
    /// tunnels. Secrets are resolved through the storage configuration's
    /// secrets reader; `in_task` controls whether those reads happen in a
    /// dedicated task.
    pub async fn create_with_context<C, T>(
        &self,
        storage_configuration: &StorageConfiguration,
        context: C,
        extra_options: &BTreeMap<&str, String>,
        in_task: InTask,
    ) -> Result<T, ContextCreationError>
    where
        C: ClientContext,
        T: FromClientConfigAndContext<TunnelingClientContext<C>>,
    {
        let mut options = self.options.clone();

        // Ensure that Kafka topics are *not* automatically created when
        // consuming, producing, or fetching metadata for a topic. This ensures
        // that we don't accidentally create topics with the wrong number of
        // partitions.
        options.insert("allow.auto.create.topics".into(), "false".into());

        let brokers = match &self.default_tunnel {
            Tunnel::AwsPrivatelink(t) => {
                assert!(&self.brokers.is_empty());

                let algo = KAFKA_DEFAULT_AWS_PRIVATELINK_ENDPOINT_IDENTIFICATION_ALGORITHM
                    .get(storage_configuration.config_set());
                options.insert("ssl.endpoint.identification.algorithm".into(), algo.into());

                // When using a default privatelink tunnel broker/brokers cannot be specified
                // instead the tunnel connection_id and port are used for the initial connection.
                format!(
                    "{}:{}",
                    vpc_endpoint_host(
                        t.connection_id,
                        None, // Default tunnel does not support availability zones.
                    ),
                    t.port.unwrap_or(9092)
                )
            }
            _ => self.brokers.iter().map(|b| &b.address).join(","),
        };
        options.insert("bootstrap.servers".into(), brokers.into());
        // Derive the security protocol from which of TLS and SASL are
        // configured.
        let security_protocol = match (self.tls.is_some(), self.sasl.is_some()) {
            (false, false) => "PLAINTEXT",
            (true, false) => "SSL",
            (false, true) => "SASL_PLAINTEXT",
            (true, true) => "SASL_SSL",
        };
        options.insert("security.protocol".into(), security_protocol.into());
        if let Some(tls) = &self.tls {
            if let Some(root_cert) = &tls.root_cert {
                options.insert("ssl.ca.pem".into(), root_cert.clone());
            }
            if let Some(identity) = &tls.identity {
                options.insert("ssl.key.pem".into(), StringOrSecret::Secret(identity.key));
                options.insert("ssl.certificate.pem".into(), identity.cert.clone());
            }
        }
        if let Some(sasl) = &self.sasl {
            options.insert("sasl.mechanisms".into(), (&sasl.mechanism).into());
            options.insert("sasl.username".into(), sasl.username.clone());
            if let Some(password) = sasl.password {
                options.insert("sasl.password".into(), StringOrSecret::Secret(password));
            }
        }

        let mut config = mz_kafka_util::client::create_new_client_config(
            storage_configuration
                .connection_context
                .librdkafka_log_level,
            storage_configuration.parameters.kafka_timeout_config,
        );
        // Resolve every option to a concrete string, reading secret values
        // through the secrets reader as necessary.
        for (k, v) in options {
            config.set(
                k,
                v.get_string(
                    in_task,
                    &storage_configuration.connection_context.secrets_reader,
                )
                .await
                .context("reading kafka secret")?,
            );
        }
        // `extra_options` are applied last, so they override anything derived
        // from the connection itself.
        for (k, v) in extra_options {
            config.set(*k, v);
        }

        // An AWS SDK config is loaded only when SASL is configured with an
        // AWS connection.
        let aws_config = match self.sasl.as_ref().and_then(|sasl| sasl.aws.as_ref()) {
            None => None,
            Some(aws) => Some(
                aws.connection
                    .load_sdk_config(
                        &storage_configuration.connection_context,
                        aws.connection_id,
                        in_task,
                    )
                    .await?,
            ),
        };

        // TODO(roshan): Implement enforcement of external address validation once
        // rdkafka client has been updated to support providing multiple resolved
        // addresses for brokers
        let mut context = TunnelingClientContext::new(
            context,
            Handle::current(),
            storage_configuration
                .connection_context
                .ssh_tunnel_manager
                .clone(),
            storage_configuration.parameters.ssh_timeout_config,
            aws_config,
            in_task,
        );

        // Configure the context's default tunnel; brokers carrying their own
        // tunnel configuration get per-broker rewrites in the loop below.
        match &self.default_tunnel {
            Tunnel::Direct => {
                // By default, don't offer a default override for broker address lookup.
            }
            Tunnel::AwsPrivatelink(pl) => {
                context.set_default_tunnel(TunnelConfig::StaticHost(vpc_endpoint_host(
                    pl.connection_id,
                    None, // Default tunnel does not support availability zones.
                )));
            }
            Tunnel::Ssh(ssh_tunnel) => {
                let secret = storage_configuration
                    .connection_context
                    .secrets_reader
                    .read_in_task_if(in_task, ssh_tunnel.connection_id)
                    .await?;
                let key_pair = SshKeyPair::from_bytes(&secret)?;

                // Ensure any ssh-bastion address we connect to is resolved to an external address.
                let resolved = resolve_address(
                    &ssh_tunnel.connection.host,
                    ENFORCE_EXTERNAL_ADDRESSES.get(storage_configuration.config_set()),
                )
                .await?;
                context.set_default_tunnel(TunnelConfig::Ssh(SshTunnelConfig {
                    host: resolved
                        .iter()
                        .map(|a| a.to_string())
                        .collect::<BTreeSet<_>>(),
                    port: ssh_tunnel.connection.port,
                    user: ssh_tunnel.connection.user.clone(),
                    key_pair,
                }));
            }
        }

        // Register per-broker tunnels and address rewrites.
        for broker in &self.brokers {
            let mut addr_parts = broker.address.splitn(2, ':');
            let addr = BrokerAddr {
                host: addr_parts
                    .next()
                    .context("BROKER is not address:port")?
                    .into(),
                port: addr_parts
                    .next()
                    .unwrap_or("9092")
                    .parse()
                    .context("parsing BROKER port")?,
            };
            match &broker.tunnel {
                Tunnel::Direct => {
                    // By default, don't override broker address lookup.
                    //
                    // N.B.
                    //
                    // We _could_ pre-setup the default ssh tunnel for all known brokers here, but
                    // we avoid doing because:
                    // - Its not necessary.
                    // - Not doing so makes it easier to test the `FailedDefaultSshTunnel` path
                    // in the `TunnelingClientContext`.
                }
                Tunnel::AwsPrivatelink(aws_privatelink) => {
                    let host = mz_cloud_resources::vpc_endpoint_host(
                        aws_privatelink.connection_id,
                        aws_privatelink.availability_zone.as_deref(),
                    );
                    let port = aws_privatelink.port;
                    context.add_broker_rewrite(
                        addr,
                        BrokerRewrite {
                            host: host.clone(),
                            port,
                        },
                    );
                }
                Tunnel::Ssh(ssh_tunnel) => {
                    // Ensure any SSH bastion address we connect to is resolved to an external address.
                    let ssh_host_resolved = resolve_address(
                        &ssh_tunnel.connection.host,
                        ENFORCE_EXTERNAL_ADDRESSES.get(storage_configuration.config_set()),
                    )
                    .await?;
                    context
                        .add_ssh_tunnel(
                            addr,
                            SshTunnelConfig {
                                host: ssh_host_resolved
                                    .iter()
                                    .map(|a| a.to_string())
                                    .collect::<BTreeSet<_>>(),
                                port: ssh_tunnel.connection.port,
                                user: ssh_tunnel.connection.user.clone(),
                                key_pair: SshKeyPair::from_bytes(
                                    &storage_configuration
                                        .connection_context
                                        .secrets_reader
                                        .read_in_task_if(in_task, ssh_tunnel.connection_id)
                                        .await?,
                                )?,
                            },
                        )
                        .await
                        .map_err(ContextCreationError::Ssh)?;
                }
            }
        }

        Ok(config.create_with_context(context)?)
    }
1013
    /// Validates the connection by issuing a metadata request to the Kafka
    /// cluster, surfacing the most relevant error captured from librdkafka's
    /// logs if the request fails.
    async fn validate(
        &self,
        _id: CatalogItemId,
        storage_configuration: &StorageConfiguration,
    ) -> Result<(), anyhow::Error> {
        // Capture errors logged by librdkafka so we can report something more
        // useful than the generic metadata-fetch error below.
        let (context, error_rx) = MzClientContext::with_errors();
        let consumer: BaseConsumer<_> = self
            .create_with_context(
                storage_configuration,
                context,
                &BTreeMap::new(),
                // We are in a normal tokio context during validation, already.
                InTask::No,
            )
            .await?;
        let consumer = Arc::new(consumer);

        let timeout = storage_configuration
            .parameters
            .kafka_timeout_config
            .fetch_metadata_timeout;

        // librdkafka doesn't expose an API for determining whether a connection to
        // the Kafka cluster has been successfully established. So we make a
        // metadata request, though we don't care about the results, so that we can
        // report any errors making that request. If the request succeeds, we know
        // we were able to contact at least one broker, and that's a good proxy for
        // being able to contact all the brokers in the cluster.
        //
        // The downside of this approach is it produces a generic error message like
        // "metadata fetch error" with no additional details. The real networking
        // error is buried in the librdkafka logs, which are not visible to users.
        let result = mz_ore::task::spawn_blocking(|| "kafka_get_metadata", {
            let consumer = Arc::clone(&consumer);
            move || consumer.fetch_metadata(None, timeout)
        })
        .await?;
        match result {
            Ok(_) => Ok(()),
            // The error returned by `fetch_metadata` does not provide any details which makes for
            // a crappy user facing error message. For this reason we attempt to grab a better
            // error message from the client context, which should contain any error logs emitted
            // by librdkafka, and fallback to the generic error if there is nothing there.
            Err(err) => {
                // Multiple errors might have been logged during this validation but some are more
                // relevant than others. Specifically, we prefer non-internal errors over internal
                // errors since those give much more useful information to the users.
                let main_err = error_rx.try_iter().reduce(|cur, new| match cur {
                    MzKafkaError::Internal(_) => new,
                    _ => cur,
                });

                // Don't drop the consumer until after we've drained the errors
                // channel. Dropping the consumer can introduce spurious errors.
                // See database-issues#7432.
                drop(consumer);

                match main_err {
                    Some(err) => Err(err.into()),
                    None => Err(err.into()),
                }
            }
        }
    }
1078}
1079
1080impl<C: ConnectionAccess> AlterCompatible for KafkaConnection<C> {
1081    fn alter_compatible(&self, id: GlobalId, other: &Self) -> Result<(), AlterError> {
1082        let KafkaConnection {
1083            brokers: _,
1084            default_tunnel: _,
1085            progress_topic,
1086            progress_topic_options,
1087            options: _,
1088            tls: _,
1089            sasl: _,
1090        } = self;
1091
1092        let compatibility_checks = [
1093            (progress_topic == &other.progress_topic, "progress_topic"),
1094            (
1095                progress_topic_options == &other.progress_topic_options,
1096                "progress_topic_options",
1097            ),
1098        ];
1099
1100        for (compatible, field) in compatibility_checks {
1101            if !compatible {
1102                tracing::warn!(
1103                    "KafkaConnection incompatible at {field}:\nself:\n{:#?}\n\nother\n{:#?}",
1104                    self,
1105                    other
1106                );
1107
1108                return Err(AlterError { id });
1109            }
1110        }
1111
1112        Ok(())
1113    }
1114}
1115
/// A connection to a Confluent Schema Registry.
///
/// Generic over [`ConnectionAccess`] so that the tunnel may either reference
/// another catalog connection or be fully inlined.
#[derive(Clone, Debug, Eq, PartialEq, Hash, Serialize, Deserialize)]
pub struct CsrConnection<C: ConnectionAccess = InlinedConnection> {
    /// The URL of the schema registry.
    pub url: Url,
    /// Trusted root TLS certificate in PEM format.
    pub tls_root_cert: Option<StringOrSecret>,
    /// An optional TLS client certificate for authentication with the schema
    /// registry.
    pub tls_identity: Option<TlsIdentity>,
    /// Optional HTTP authentication credentials (username and optional
    /// password) for the schema registry.
    pub http_auth: Option<CsrConnectionHttpAuth>,
    /// A tunnel through which to route traffic: direct, SSH bastion, or AWS
    /// PrivateLink.
    pub tunnel: Tunnel<C>,
}
1131
1132impl<R: ConnectionResolver> IntoInlineConnection<CsrConnection, R>
1133    for CsrConnection<ReferencedConnection>
1134{
1135    fn into_inline_connection(self, r: R) -> CsrConnection {
1136        let CsrConnection {
1137            url,
1138            tls_root_cert,
1139            tls_identity,
1140            http_auth,
1141            tunnel,
1142        } = self;
1143        CsrConnection {
1144            url,
1145            tls_root_cert,
1146            tls_identity,
1147            http_auth,
1148            tunnel: tunnel.into_inline_connection(r),
1149        }
1150    }
1151}
1152
impl<C: ConnectionAccess> CsrConnection<C> {
    /// Reports whether this connection type is validated by default.
    /// Schema registry connections always are.
    fn validate_by_default(&self) -> bool {
        true
    }
}
1158
impl CsrConnection {
    /// Constructs a schema registry client from the connection.
    ///
    /// Reads any referenced secrets (root certificate, client identity key,
    /// HTTP credentials) via the storage configuration's secrets reader, then
    /// routes traffic according to the configured tunnel: directly (with
    /// external address enforcement), over an SSH bastion, or via an AWS
    /// PrivateLink endpoint.
    pub async fn connect(
        &self,
        storage_configuration: &StorageConfiguration,
        in_task: InTask,
    ) -> Result<mz_ccsr::Client, CsrConnectError> {
        let mut client_config = mz_ccsr::ClientConfig::new(self.url.clone());
        if let Some(root_cert) = &self.tls_root_cert {
            let root_cert = root_cert
                .get_string(
                    in_task,
                    &storage_configuration.connection_context.secrets_reader,
                )
                .await?;
            let root_cert = Certificate::from_pem(root_cert.as_bytes())?;
            client_config = client_config.add_root_certificate(root_cert);
        }

        if let Some(tls_identity) = &self.tls_identity {
            // The private key is always a secret; the certificate may be
            // stored inline or as a secret.
            let key = &storage_configuration
                .connection_context
                .secrets_reader
                .read_string_in_task_if(in_task, tls_identity.key)
                .await?;
            let cert = tls_identity
                .cert
                .get_string(
                    in_task,
                    &storage_configuration.connection_context.secrets_reader,
                )
                .await?;
            let ident = Identity::from_pem(key.as_bytes(), cert.as_bytes())?;
            client_config = client_config.identity(ident);
        }

        if let Some(http_auth) = &self.http_auth {
            let username = http_auth
                .username
                .get_string(
                    in_task,
                    &storage_configuration.connection_context.secrets_reader,
                )
                .await?;
            let password = match http_auth.password {
                None => None,
                Some(password) => Some(
                    storage_configuration
                        .connection_context
                        .secrets_reader
                        .read_string_in_task_if(in_task, password)
                        .await?,
                ),
            };
            client_config = client_config.auth(username, password);
        }

        // TODO: use types to enforce that the URL has a string hostname.
        let host = self
            .url
            .host_str()
            .ok_or_else(|| anyhow!("url missing host"))?;
        match &self.tunnel {
            Tunnel::Direct => {
                // Ensure any host we connect to is resolved to an external address.
                let resolved = resolve_address(
                    host,
                    ENFORCE_EXTERNAL_ADDRESSES.get(storage_configuration.config_set()),
                )
                .await?;
                client_config = client_config.resolve_to_addrs(
                    host,
                    &resolved
                        .iter()
                        .map(|addr| SocketAddr::new(*addr, DUMMY_DNS_PORT))
                        .collect::<Vec<_>>(),
                )
            }
            Tunnel::Ssh(ssh_tunnel) => {
                let ssh_tunnel = ssh_tunnel
                    .connect(
                        storage_configuration,
                        host,
                        // Default to the default http port, but this
                        // could default to 8081...
                        self.url.port().unwrap_or(80),
                        in_task,
                    )
                    .await
                    .map_err(CsrConnectError::Ssh)?;

                // Carefully inject the SSH tunnel into the client
                // configuration. This is delicate because we need TLS
                // verification to continue to use the remote hostname rather
                // than the tunnel hostname.

                client_config = client_config
                    // `resolve_to_addrs` allows us to rewrite the hostname
                    // at the DNS level, which means the TCP connection is
                    // correctly routed through the tunnel, but TLS verification
                    // is still performed against the remote hostname.
                    // Unfortunately the port here is ignored...
                    .resolve_to_addrs(
                        host,
                        &[SocketAddr::new(
                            ssh_tunnel.local_addr().ip(),
                            DUMMY_DNS_PORT,
                        )],
                    )
                    // ...so we also dynamically rewrite the URL to use the
                    // current port for the SSH tunnel.
                    //
                    // WARNING: this is brittle, because we only dynamically
                    // update the client configuration with the tunnel *port*,
                    // and not the hostname This works fine in practice, because
                    // only the SSH tunnel port will change if the tunnel fails
                    // and has to be restarted (the hostname is always
                    // 127.0.0.1)--but this is an an implementation detail of
                    // the SSH tunnel code that we're relying on.
                    .dynamic_url({
                        let remote_url = self.url.clone();
                        move || {
                            let mut url = remote_url.clone();
                            url.set_port(Some(ssh_tunnel.local_addr().port()))
                                .expect("cannot fail");
                            url
                        }
                    });
            }
            Tunnel::AwsPrivatelink(connection) => {
                assert_none!(connection.port);

                // Route DNS for the registry host to the PrivateLink VPC
                // endpoint.
                let privatelink_host = mz_cloud_resources::vpc_endpoint_host(
                    connection.connection_id,
                    connection.availability_zone.as_deref(),
                );
                let addrs: Vec<_> = net::lookup_host((privatelink_host, DUMMY_DNS_PORT))
                    .await
                    .context("resolving PrivateLink host")?
                    .collect();
                client_config = client_config.resolve_to_addrs(host, &addrs)
            }
        }

        Ok(client_config.build()?)
    }

    /// Validates the connection by connecting to the registry and listing its
    /// subjects.
    async fn validate(
        &self,
        _id: CatalogItemId,
        storage_configuration: &StorageConfiguration,
    ) -> Result<(), anyhow::Error> {
        let client = self
            .connect(
                storage_configuration,
                // We are in a normal tokio context during validation, already.
                InTask::No,
            )
            .await?;
        client.list_subjects().await?;
        Ok(())
    }
}
1322
1323impl<C: ConnectionAccess> AlterCompatible for CsrConnection<C> {
1324    fn alter_compatible(&self, id: GlobalId, other: &Self) -> Result<(), AlterError> {
1325        let CsrConnection {
1326            tunnel,
1327            // All non-tunnel fields may change
1328            url: _,
1329            tls_root_cert: _,
1330            tls_identity: _,
1331            http_auth: _,
1332        } = self;
1333
1334        let compatibility_checks = [(tunnel.alter_compatible(id, &other.tunnel).is_ok(), "tunnel")];
1335
1336        for (compatible, field) in compatibility_checks {
1337            if !compatible {
1338                tracing::warn!(
1339                    "CsrConnection incompatible at {field}:\nself:\n{:#?}\n\nother\n{:#?}",
1340                    self,
1341                    other
1342                );
1343
1344                return Err(AlterError { id });
1345            }
1346        }
1347        Ok(())
1348    }
1349}
1350
/// A TLS key pair used for client identity.
///
/// The certificate may be stored inline or as a secret, but the private key
/// is always referenced as a secret.
#[derive(Clone, Debug, Eq, PartialEq, Hash, Serialize, Deserialize)]
pub struct TlsIdentity {
    /// The client's TLS public certificate in PEM format.
    pub cert: StringOrSecret,
    /// The ID of the secret containing the client's TLS private key in PEM
    /// format.
    pub key: CatalogItemId,
}
1360
/// HTTP authentication credentials in a [`CsrConnection`].
#[derive(Clone, Debug, Eq, PartialEq, Hash, Serialize, Deserialize)]
pub struct CsrConnectionHttpAuth {
    /// The username, stored inline or as a secret.
    pub username: StringOrSecret,
    /// The ID of the secret containing the password, if any.
    pub password: Option<CatalogItemId>,
}
1369
/// A connection to a PostgreSQL server.
///
/// Generic over [`ConnectionAccess`] so that the tunnel may either reference
/// another catalog connection or be fully inlined.
#[derive(Clone, Debug, Eq, PartialEq, Hash, Serialize, Deserialize)]
pub struct PostgresConnection<C: ConnectionAccess = InlinedConnection> {
    /// The hostname of the server.
    pub host: String,
    /// The port of the server.
    pub port: u16,
    /// The name of the database to connect to.
    pub database: String,
    /// The username to authenticate as, stored inline or as a secret.
    pub user: StringOrSecret,
    /// An optional password for authentication, referenced as a secret.
    pub password: Option<CatalogItemId>,
    /// A tunnel through which to route traffic.
    pub tunnel: Tunnel<C>,
    /// Whether to use TLS for encryption, authentication, or both.
    pub tls_mode: SslMode,
    /// An optional root TLS certificate in PEM format, to verify the server's
    /// identity.
    pub tls_root_cert: Option<StringOrSecret>,
    /// An optional TLS client certificate for authentication.
    pub tls_identity: Option<TlsIdentity>,
}
1393
1394impl<R: ConnectionResolver> IntoInlineConnection<PostgresConnection, R>
1395    for PostgresConnection<ReferencedConnection>
1396{
1397    fn into_inline_connection(self, r: R) -> PostgresConnection {
1398        let PostgresConnection {
1399            host,
1400            port,
1401            database,
1402            user,
1403            password,
1404            tunnel,
1405            tls_mode,
1406            tls_root_cert,
1407            tls_identity,
1408        } = self;
1409
1410        PostgresConnection {
1411            host,
1412            port,
1413            database,
1414            user,
1415            password,
1416            tunnel: tunnel.into_inline_connection(r),
1417            tls_mode,
1418            tls_root_cert,
1419            tls_identity,
1420        }
1421    }
1422}
1423
impl<C: ConnectionAccess> PostgresConnection<C> {
    /// Reports whether this connection type is validated by default.
    /// PostgreSQL connections always are.
    fn validate_by_default(&self) -> bool {
        true
    }
}
1429
impl PostgresConnection<InlinedConnection> {
    /// Builds an [`mz_postgres_util::Config`] for connecting to the server.
    ///
    /// Resolves any referenced secrets (password, TLS key) via
    /// `secrets_reader`, applies connect/keepalive timeout parameters from
    /// the storage configuration, and sets up the configured tunnel (direct
    /// with external address enforcement, SSH bastion, or AWS PrivateLink).
    pub async fn config(
        &self,
        secrets_reader: &Arc<dyn mz_secrets::SecretsReader>,
        storage_configuration: &StorageConfiguration,
        in_task: InTask,
    ) -> Result<mz_postgres_util::Config, anyhow::Error> {
        let params = &storage_configuration.parameters;

        let mut config = tokio_postgres::Config::new();
        config
            .host(&self.host)
            .port(self.port)
            .dbname(&self.database)
            .user(&self.user.get_string(in_task, secrets_reader).await?)
            .ssl_mode(self.tls_mode);
        if let Some(password) = self.password {
            let password = secrets_reader
                .read_string_in_task_if(in_task, password)
                .await?;
            config.password(password);
        }
        if let Some(tls_root_cert) = &self.tls_root_cert {
            let tls_root_cert = tls_root_cert.get_string(in_task, secrets_reader).await?;
            config.ssl_root_cert(tls_root_cert.as_bytes());
        }
        if let Some(tls_identity) = &self.tls_identity {
            let cert = tls_identity
                .cert
                .get_string(in_task, secrets_reader)
                .await?;
            let key = secrets_reader
                .read_string_in_task_if(in_task, tls_identity.key)
                .await?;
            config.ssl_cert(cert.as_bytes()).ssl_key(key.as_bytes());
        }

        // Client-side connection and TCP keepalive tuning from the storage
        // parameters; each is applied only when set.
        if let Some(connect_timeout) = params.pg_source_connect_timeout {
            config.connect_timeout(connect_timeout);
        }
        if let Some(keepalives_retries) = params.pg_source_tcp_keepalives_retries {
            config.keepalives_retries(keepalives_retries);
        }
        if let Some(keepalives_idle) = params.pg_source_tcp_keepalives_idle {
            config.keepalives_idle(keepalives_idle);
        }
        if let Some(keepalives_interval) = params.pg_source_tcp_keepalives_interval {
            config.keepalives_interval(keepalives_interval);
        }
        if let Some(tcp_user_timeout) = params.pg_source_tcp_user_timeout {
            config.tcp_user_timeout(tcp_user_timeout);
        }

        // Server-side session settings, passed via the `options` startup
        // parameter.
        let mut options = vec![];
        if let Some(wal_sender_timeout) = params.pg_source_wal_sender_timeout {
            options.push(format!(
                "--wal_sender_timeout={}",
                wal_sender_timeout.as_millis()
            ));
        };
        if params.pg_source_tcp_configure_server {
            // Mirror the client-side TCP keepalive settings on the server.
            if let Some(keepalives_retries) = params.pg_source_tcp_keepalives_retries {
                options.push(format!("--tcp_keepalives_count={}", keepalives_retries));
            }
            if let Some(keepalives_idle) = params.pg_source_tcp_keepalives_idle {
                options.push(format!(
                    "--tcp_keepalives_idle={}",
                    keepalives_idle.as_secs()
                ));
            }
            if let Some(keepalives_interval) = params.pg_source_tcp_keepalives_interval {
                options.push(format!(
                    "--tcp_keepalives_interval={}",
                    keepalives_interval.as_secs()
                ));
            }
            if let Some(tcp_user_timeout) = params.pg_source_tcp_user_timeout {
                options.push(format!(
                    "--tcp_user_timeout={}",
                    tcp_user_timeout.as_millis()
                ));
            }
        }
        config.options(options.join(" ").as_str());

        let tunnel = match &self.tunnel {
            Tunnel::Direct => {
                // Ensure any host we connect to is resolved to an external address.
                let resolved = resolve_address(
                    &self.host,
                    ENFORCE_EXTERNAL_ADDRESSES.get(storage_configuration.config_set()),
                )
                .await?;
                mz_postgres_util::TunnelConfig::Direct {
                    resolved_ips: Some(resolved),
                }
            }
            Tunnel::Ssh(SshTunnel {
                connection_id,
                connection,
            }) => {
                let secret = secrets_reader
                    .read_in_task_if(in_task, *connection_id)
                    .await?;
                let key_pair = SshKeyPair::from_bytes(&secret)?;
                // Ensure any ssh-bastion host we connect to is resolved to an external address.
                let resolved = resolve_address(
                    &connection.host,
                    ENFORCE_EXTERNAL_ADDRESSES.get(storage_configuration.config_set()),
                )
                .await?;
                mz_postgres_util::TunnelConfig::Ssh {
                    config: SshTunnelConfig {
                        host: resolved
                            .iter()
                            .map(|a| a.to_string())
                            .collect::<BTreeSet<_>>(),
                        port: connection.port,
                        user: connection.user.clone(),
                        key_pair,
                    },
                }
            }
            Tunnel::AwsPrivatelink(connection) => {
                assert_none!(connection.port);
                mz_postgres_util::TunnelConfig::AwsPrivatelink {
                    connection_id: connection.connection_id,
                }
            }
        };

        Ok(mz_postgres_util::Config::new(
            config,
            tunnel,
            params.ssh_timeout_config,
            in_task,
        )?)
    }

    /// Validates the connection by connecting to the server and verifying it
    /// is configured for logical replication: `wal_level` must be at least
    /// `logical`, `max_wal_senders` must be positive, and at least two
    /// replication slots must be available.
    ///
    /// Returns the connected client on success so callers can reuse it.
    pub async fn validate(
        &self,
        _id: CatalogItemId,
        storage_configuration: &StorageConfiguration,
    ) -> Result<mz_postgres_util::Client, anyhow::Error> {
        let config = self
            .config(
                &storage_configuration.connection_context.secrets_reader,
                storage_configuration,
                // We are in a normal tokio context during validation, already.
                InTask::No,
            )
            .await?;
        let client = config
            .connect(
                "connection validation",
                &storage_configuration.connection_context.ssh_tunnel_manager,
            )
            .await?;

        let wal_level = mz_postgres_util::get_wal_level(&client).await?;

        if wal_level < mz_postgres_util::replication::WalLevel::Logical {
            Err(PostgresConnectionValidationError::InsufficientWalLevel { wal_level })?;
        }

        let max_wal_senders = mz_postgres_util::get_max_wal_senders(&client).await?;

        if max_wal_senders < 1 {
            Err(PostgresConnectionValidationError::ReplicationDisabled)?;
        }

        let available_replication_slots =
            mz_postgres_util::available_replication_slots(&client).await?;

        // We need 1 replication slot for the snapshots and 1 for the continuing replication
        if available_replication_slots < 2 {
            Err(
                PostgresConnectionValidationError::InsufficientReplicationSlotsAvailable {
                    count: 2,
                },
            )?;
        }

        Ok(client)
    }
}
1616
/// Errors detected while validating a [`PostgresConnection`].
///
/// Each variant corresponds to a server-side precondition for logical
/// replication that is checked during connection validation.
#[derive(Debug, Clone, thiserror::Error)]
pub enum PostgresConnectionValidationError {
    /// The server has fewer free replication slots than the number required
    /// (reported in `count`).
    #[error("PostgreSQL server has insufficient number of replication slots available")]
    InsufficientReplicationSlotsAvailable { count: usize },
    /// The server's `wal_level` setting is below `logical`.
    #[error("server must have wal_level >= logical, but has {wal_level}")]
    InsufficientWalLevel {
        wal_level: mz_postgres_util::replication::WalLevel,
    },
    /// The server has `max_wal_senders < 1`, i.e. replication is disabled.
    #[error("replication disabled on server")]
    ReplicationDisabled,
}
1628
1629impl PostgresConnectionValidationError {
1630    pub fn detail(&self) -> Option<String> {
1631        match self {
1632            Self::InsufficientReplicationSlotsAvailable { count } => Some(format!(
1633                "executing this statement requires {} replication slot{}",
1634                count,
1635                if *count == 1 { "" } else { "s" }
1636            )),
1637            _ => None,
1638        }
1639    }
1640
1641    pub fn hint(&self) -> Option<String> {
1642        match self {
1643            Self::InsufficientReplicationSlotsAvailable { .. } => Some(
1644                "you might be able to wait for other sources to finish snapshotting and try again"
1645                    .into(),
1646            ),
1647            Self::ReplicationDisabled => Some("set max_wal_senders to a value > 0".into()),
1648            _ => None,
1649        }
1650    }
1651}
1652
1653impl<C: ConnectionAccess> AlterCompatible for PostgresConnection<C> {
1654    fn alter_compatible(&self, id: GlobalId, other: &Self) -> Result<(), AlterError> {
1655        let PostgresConnection {
1656            tunnel,
1657            // All non-tunnel options may change arbitrarily
1658            host: _,
1659            port: _,
1660            database: _,
1661            user: _,
1662            password: _,
1663            tls_mode: _,
1664            tls_root_cert: _,
1665            tls_identity: _,
1666        } = self;
1667
1668        let compatibility_checks = [(tunnel.alter_compatible(id, &other.tunnel).is_ok(), "tunnel")];
1669
1670        for (compatible, field) in compatibility_checks {
1671            if !compatible {
1672                tracing::warn!(
1673                    "PostgresConnection incompatible at {field}:\nself:\n{:#?}\n\nother\n{:#?}",
1674                    self,
1675                    other
1676                );
1677
1678                return Err(AlterError { id });
1679            }
1680        }
1681        Ok(())
1682    }
1683}
1684
/// Specifies how to tunnel a connection.
///
/// Defaults to [`InlinedConnection`] access, in which any referenced
/// connections are stored inline rather than by ID.
#[derive(Clone, Debug, Eq, PartialEq, Hash, Serialize, Deserialize)]
pub enum Tunnel<C: ConnectionAccess = InlinedConnection> {
    /// No tunneling; connect directly to the upstream host.
    Direct,
    /// Via the specified SSH tunnel connection.
    Ssh(SshTunnel<C>),
    /// Via the specified AWS PrivateLink connection.
    AwsPrivatelink(AwsPrivatelink),
}
1695
1696impl<R: ConnectionResolver> IntoInlineConnection<Tunnel, R> for Tunnel<ReferencedConnection> {
1697    fn into_inline_connection(self, r: R) -> Tunnel {
1698        match self {
1699            Tunnel::Direct => Tunnel::Direct,
1700            Tunnel::Ssh(ssh) => Tunnel::Ssh(ssh.into_inline_connection(r)),
1701            Tunnel::AwsPrivatelink(awspl) => Tunnel::AwsPrivatelink(awspl),
1702        }
1703    }
1704}
1705
1706impl<C: ConnectionAccess> AlterCompatible for Tunnel<C> {
1707    fn alter_compatible(&self, id: GlobalId, other: &Self) -> Result<(), AlterError> {
1708        let compatible = match (self, other) {
1709            (Self::Ssh(s), Self::Ssh(o)) => s.alter_compatible(id, o).is_ok(),
1710            (s, o) => s == o,
1711        };
1712
1713        if !compatible {
1714            tracing::warn!(
1715                "Tunnel incompatible:\nself:\n{:#?}\n\nother\n{:#?}",
1716                self,
1717                other
1718            );
1719
1720            return Err(AlterError { id });
1721        }
1722
1723        Ok(())
1724    }
1725}
1726
/// Specifies which MySQL SSL Mode to use:
/// <https://dev.mysql.com/doc/refman/8.0/en/connection-options.html#option_general_ssl-mode>
/// This is not available as an enum in the mysql-async crate, so we define our own.
///
/// See `MySqlConnection::config` for how each mode maps onto
/// `mysql_async::SslOpts`.
#[derive(Clone, Debug, Eq, PartialEq, Hash, Serialize, Deserialize)]
pub enum MySqlSslMode {
    /// Do not use TLS.
    Disabled,
    /// Use TLS, but do not verify the server certificate or its hostname.
    Required,
    /// Use TLS and verify the server certificate, but not its hostname.
    VerifyCa,
    /// Use TLS and verify both the server certificate and its hostname.
    VerifyIdentity,
}
1737
/// A connection to a MySQL server.
#[derive(Clone, Debug, Eq, PartialEq, Hash, Serialize, Deserialize)]
pub struct MySqlConnection<C: ConnectionAccess = InlinedConnection> {
    /// The hostname of the server.
    pub host: String,
    /// The port of the server.
    pub port: u16,
    /// The username to authenticate as.
    pub user: StringOrSecret,
    /// An optional password for authentication.
    pub password: Option<CatalogItemId>,
    /// A tunnel through which to route traffic.
    pub tunnel: Tunnel<C>,
    /// Whether to use TLS for encryption, verify the server's certificate, and identity.
    pub tls_mode: MySqlSslMode,
    /// An optional root TLS certificate in PEM format, to verify the server's
    /// identity.
    pub tls_root_cert: Option<StringOrSecret>,
    /// An optional TLS client certificate for authentication.
    pub tls_identity: Option<TlsIdentity>,
    /// Reference to the AWS connection information to be used for IAM
    /// authentication and assuming AWS roles.
    pub aws_connection: Option<AwsConnectionReference<C>>,
}
1762
1763impl<R: ConnectionResolver> IntoInlineConnection<MySqlConnection, R>
1764    for MySqlConnection<ReferencedConnection>
1765{
1766    fn into_inline_connection(self, r: R) -> MySqlConnection {
1767        let MySqlConnection {
1768            host,
1769            port,
1770            user,
1771            password,
1772            tunnel,
1773            tls_mode,
1774            tls_root_cert,
1775            tls_identity,
1776            aws_connection,
1777        } = self;
1778
1779        MySqlConnection {
1780            host,
1781            port,
1782            user,
1783            password,
1784            tunnel: tunnel.into_inline_connection(&r),
1785            tls_mode,
1786            tls_root_cert,
1787            tls_identity,
1788            aws_connection: aws_connection.map(|aws| aws.into_inline_connection(&r)),
1789        }
1790    }
1791}
1792
impl<C: ConnectionAccess> MySqlConnection<C> {
    /// Reports whether this connection type is validated by default.
    fn validate_by_default(&self) -> bool {
        true
    }
}
1798
impl MySqlConnection<InlinedConnection> {
    /// Builds a `mz_mysql_util::Config` for this connection, reading any
    /// secrets (password, TLS key, certificates) via `secrets_reader` and
    /// resolving tunnel configuration.
    ///
    /// The `in_task` argument determines whether secret reads run in an
    /// `mz_ore::task` (i.e. a different thread) or directly in the returned
    /// future.
    pub async fn config(
        &self,
        secrets_reader: &Arc<dyn mz_secrets::SecretsReader>,
        storage_configuration: &StorageConfiguration,
        in_task: InTask,
    ) -> Result<mz_mysql_util::Config, anyhow::Error> {
        // TODO(roshan): Set appropriate connection timeouts
        let mut opts = mysql_async::OptsBuilder::default()
            .ip_or_hostname(&self.host)
            .tcp_port(self.port)
            .user(Some(&self.user.get_string(in_task, secrets_reader).await?));

        if let Some(password) = self.password {
            let password = secrets_reader
                .read_string_in_task_if(in_task, password)
                .await?;
            opts = opts.pass(Some(password));
        }

        // Our `MySqlSslMode` enum matches the official MySQL Client `--ssl-mode` parameter values
        // which uses opt-in security features (SSL, CA verification, & Identity verification).
        // The mysql_async crate `SslOpts` struct uses an opt-out mechanism for each of these, so
        // we need to appropriately disable features to match the intent of each enum value.
        let mut ssl_opts = match self.tls_mode {
            MySqlSslMode::Disabled => None,
            MySqlSslMode::Required => Some(
                mysql_async::SslOpts::default()
                    .with_danger_accept_invalid_certs(true)
                    .with_danger_skip_domain_validation(true),
            ),
            MySqlSslMode::VerifyCa => {
                Some(mysql_async::SslOpts::default().with_danger_skip_domain_validation(true))
            }
            MySqlSslMode::VerifyIdentity => Some(mysql_async::SslOpts::default()),
        };

        // A root certificate is only relevant in the modes that verify the
        // server certificate.
        if matches!(
            self.tls_mode,
            MySqlSslMode::VerifyCa | MySqlSslMode::VerifyIdentity
        ) {
            if let Some(tls_root_cert) = &self.tls_root_cert {
                let tls_root_cert = tls_root_cert.get_string(in_task, secrets_reader).await?;
                ssl_opts = ssl_opts.map(|opts| {
                    opts.with_root_certs(vec![tls_root_cert.as_bytes().to_vec().into()])
                });
            }
        }

        // An optional client certificate (key + cert converted to a PKCS#12
        // archive) for TLS client authentication.
        if let Some(identity) = &self.tls_identity {
            let key = secrets_reader
                .read_string_in_task_if(in_task, identity.key)
                .await?;
            let cert = identity.cert.get_string(in_task, secrets_reader).await?;
            let Pkcs12Archive { der, pass } =
                mz_tls_util::pkcs12der_from_pem(key.as_bytes(), cert.as_bytes())?;

            // Add client identity to SSLOpts
            ssl_opts = ssl_opts.map(|opts| {
                opts.with_client_identity(Some(
                    mysql_async::ClientIdentity::new(der.into()).with_password(pass),
                ))
            });
        }

        opts = opts.ssl_opts(ssl_opts);

        let tunnel = match &self.tunnel {
            Tunnel::Direct => {
                // Ensure any host we connect to is resolved to an external address.
                let resolved = resolve_address(
                    &self.host,
                    ENFORCE_EXTERNAL_ADDRESSES.get(storage_configuration.config_set()),
                )
                .await?;
                mz_mysql_util::TunnelConfig::Direct {
                    resolved_ips: Some(resolved),
                }
            }
            Tunnel::Ssh(SshTunnel {
                connection_id,
                connection,
            }) => {
                // The SSH private key is stored as a secret keyed by the SSH
                // connection's ID.
                let secret = secrets_reader
                    .read_in_task_if(in_task, *connection_id)
                    .await?;
                let key_pair = SshKeyPair::from_bytes(&secret)?;
                // Ensure any ssh-bastion host we connect to is resolved to an external address.
                let resolved = resolve_address(
                    &connection.host,
                    ENFORCE_EXTERNAL_ADDRESSES.get(storage_configuration.config_set()),
                )
                .await?;
                mz_mysql_util::TunnelConfig::Ssh {
                    config: SshTunnelConfig {
                        host: resolved
                            .iter()
                            .map(|a| a.to_string())
                            .collect::<BTreeSet<_>>(),
                        port: connection.port,
                        user: connection.user.clone(),
                        key_pair,
                    },
                }
            }
            Tunnel::AwsPrivatelink(connection) => {
                assert_none!(connection.port);
                mz_mysql_util::TunnelConfig::AwsPrivatelink {
                    connection_id: connection.connection_id,
                }
            }
        };

        // Load the AWS SDK configuration if an AWS connection is referenced
        // (used for IAM-based authentication).
        let aws_config = match self.aws_connection.as_ref() {
            None => None,
            Some(aws_ref) => Some(
                aws_ref
                    .connection
                    .load_sdk_config(
                        &storage_configuration.connection_context,
                        aws_ref.connection_id,
                        in_task,
                    )
                    .await?,
            ),
        };

        Ok(mz_mysql_util::Config::new(
            opts,
            tunnel,
            storage_configuration.parameters.ssh_timeout_config,
            in_task,
            storage_configuration
                .parameters
                .mysql_source_timeouts
                .clone(),
            aws_config,
        )?)
    }

    /// Establishes a connection to the upstream MySQL server and checks that
    /// its system settings permit GTID-based, row-based replication.
    ///
    /// Returns the open connection on success so callers can reuse it.
    pub async fn validate(
        &self,
        _id: CatalogItemId,
        storage_configuration: &StorageConfiguration,
    ) -> Result<MySqlConn, MySqlConnectionValidationError> {
        let config = self
            .config(
                &storage_configuration.connection_context.secrets_reader,
                storage_configuration,
                // We are in a normal tokio context during validation, already.
                InTask::No,
            )
            .await?;
        let mut conn = config
            .connect(
                "connection validation",
                &storage_configuration.connection_context.ssh_tunnel_manager,
            )
            .await?;

        // Check if the MySQL database is configured to allow row-based consistent GTID replication
        let mut setting_errors = vec![];
        let gtid_res = mz_mysql_util::ensure_gtid_consistency(&mut conn).await;
        let binlog_res = mz_mysql_util::ensure_full_row_binlog_format(&mut conn).await;
        let order_res = mz_mysql_util::ensure_replication_commit_order(&mut conn).await;
        // Collect all invalid-setting errors so the user sees every
        // misconfigured setting at once; any other error aborts immediately.
        for res in [gtid_res, binlog_res, order_res] {
            match res {
                Err(MySqlError::InvalidSystemSetting {
                    setting,
                    expected,
                    actual,
                }) => {
                    setting_errors.push((setting, expected, actual));
                }
                Err(err) => Err(err)?,
                Ok(()) => {}
            }
        }
        if !setting_errors.is_empty() {
            Err(MySqlConnectionValidationError::ReplicationSettingsError(
                setting_errors,
            ))?;
        }

        Ok(conn)
    }
}
1986
/// Errors detected while validating a [`MySqlConnection`].
#[derive(Debug, thiserror::Error)]
pub enum MySqlConnectionValidationError {
    /// One or more MySQL system settings do not have the values required for
    /// replication; each entry is `(setting, expected, actual)`.
    #[error("Invalid MySQL system replication settings")]
    ReplicationSettingsError(Vec<(String, String, String)>),
    /// An error reported by the MySQL client.
    #[error(transparent)]
    Client(#[from] MySqlError),
    /// Any other error encountered during validation.
    #[error("{}", .0.display_with_causes())]
    Other(#[from] anyhow::Error),
}
1996
1997impl MySqlConnectionValidationError {
1998    pub fn detail(&self) -> Option<String> {
1999        match self {
2000            Self::ReplicationSettingsError(settings) => Some(format!(
2001                "Invalid MySQL system replication settings: {}",
2002                itertools::join(
2003                    settings.iter().map(|(setting, expected, actual)| format!(
2004                        "{}: expected {}, got {}",
2005                        setting, expected, actual
2006                    )),
2007                    "; "
2008                )
2009            )),
2010            _ => None,
2011        }
2012    }
2013
2014    pub fn hint(&self) -> Option<String> {
2015        match self {
2016            Self::ReplicationSettingsError(_) => {
2017                Some("Set the necessary MySQL database system settings.".into())
2018            }
2019            _ => None,
2020        }
2021    }
2022}
2023
2024impl<C: ConnectionAccess> AlterCompatible for MySqlConnection<C> {
2025    fn alter_compatible(&self, id: GlobalId, other: &Self) -> Result<(), AlterError> {
2026        let MySqlConnection {
2027            tunnel,
2028            // All non-tunnel options may change arbitrarily
2029            host: _,
2030            port: _,
2031            user: _,
2032            password: _,
2033            tls_mode: _,
2034            tls_root_cert: _,
2035            tls_identity: _,
2036            aws_connection: _,
2037        } = self;
2038
2039        let compatibility_checks = [(tunnel.alter_compatible(id, &other.tunnel).is_ok(), "tunnel")];
2040
2041        for (compatible, field) in compatibility_checks {
2042            if !compatible {
2043                tracing::warn!(
2044                    "MySqlConnection incompatible at {field}:\nself:\n{:#?}\n\nother\n{:#?}",
2045                    self,
2046                    other
2047                );
2048
2049                return Err(AlterError { id });
2050            }
2051        }
2052        Ok(())
2053    }
2054}
2055
/// Details how to connect to an instance of Microsoft SQL Server.
///
/// For specifics of connecting to SQL Server for purposes of creating a
/// Materialize Source, see [`SqlServerSourceConnection`] which wraps this type.
///
/// [`SqlServerSourceConnection`]: crate::sources::SqlServerSourceConnection
#[derive(Clone, Debug, Eq, PartialEq, Hash, Serialize, Deserialize)]
pub struct SqlServerConnectionDetails<C: ConnectionAccess = InlinedConnection> {
    /// The hostname of the server.
    pub host: String,
    /// The port of the server.
    pub port: u16,
    /// Database we should connect to.
    pub database: String,
    /// The username to authenticate as.
    pub user: StringOrSecret,
    /// Password used for authentication.
    pub password: CatalogItemId,
    /// A tunnel through which to route traffic.
    pub tunnel: Tunnel<C>,
    /// Level of encryption to use for the connection.
    pub encryption: mz_sql_server_util::config::EncryptionLevel,
    /// Certificate validation policy.
    pub certificate_validation_policy: mz_sql_server_util::config::CertificateValidationPolicy,
    /// TLS CA certificate in PEM format.
    pub tls_root_cert: Option<StringOrSecret>,
}
2083
impl<C: ConnectionAccess> SqlServerConnectionDetails<C> {
    /// Reports whether this connection type is validated by default.
    fn validate_by_default(&self) -> bool {
        true
    }
}
2089
2090impl SqlServerConnectionDetails<InlinedConnection> {
2091    /// Attempts to open a connection to the upstream SQL Server instance.
2092    pub async fn validate(
2093        &self,
2094        _id: CatalogItemId,
2095        storage_configuration: &StorageConfiguration,
2096    ) -> Result<mz_sql_server_util::Client, anyhow::Error> {
2097        let config = self
2098            .resolve_config(
2099                &storage_configuration.connection_context.secrets_reader,
2100                storage_configuration,
2101                InTask::No,
2102            )
2103            .await?;
2104        tracing::debug!(?config, "Validating SQL Server connection");
2105
2106        let mut client = mz_sql_server_util::Client::connect(config).await?;
2107
2108        // Ensure the upstream SQL Server instance is configured to allow CDC.
2109        //
2110        // Run all of the checks necessary and collect the errors to provide the best
2111        // guidance as to which system settings need to be enabled.
2112        let mut replication_errors = vec![];
2113        for error in [
2114            mz_sql_server_util::inspect::ensure_database_cdc_enabled(&mut client).await,
2115            mz_sql_server_util::inspect::ensure_snapshot_isolation_enabled(&mut client).await,
2116            mz_sql_server_util::inspect::ensure_sql_server_agent_running(&mut client).await,
2117        ] {
2118            match error {
2119                Err(mz_sql_server_util::SqlServerError::InvalidSystemSetting {
2120                    name,
2121                    expected,
2122                    actual,
2123                }) => replication_errors.push((name, expected, actual)),
2124                Err(other) => Err(other)?,
2125                Ok(()) => (),
2126            }
2127        }
2128        if !replication_errors.is_empty() {
2129            Err(SqlServerConnectionValidationError::ReplicationSettingsError(replication_errors))?;
2130        }
2131
2132        Ok(client)
2133    }
2134
2135    /// Resolve all of the connection details (e.g. read from the [`SecretsReader`])
2136    /// so the returned [`Config`] can be used to open a connection with the
2137    /// upstream system.
2138    ///
2139    /// The provided [`InTask`] argument determines whether any I/O is run in an
2140    /// [`mz_ore::task`] (i.e. a different thread) or directly in the returned
2141    /// future. The main goal here is to prevent running I/O in timely threads.
2142    ///
2143    /// [`Config`]: mz_sql_server_util::Config
2144    pub async fn resolve_config(
2145        &self,
2146        secrets_reader: &Arc<dyn mz_secrets::SecretsReader>,
2147        storage_configuration: &StorageConfiguration,
2148        in_task: InTask,
2149    ) -> Result<mz_sql_server_util::Config, anyhow::Error> {
2150        let dyncfg = storage_configuration.config_set();
2151        let mut inner_config = tiberius::Config::new();
2152
2153        // Setup default connection params.
2154        inner_config.host(&self.host);
2155        inner_config.port(self.port);
2156        inner_config.database(self.database.clone());
2157        inner_config.encryption(self.encryption.into());
2158        match self.certificate_validation_policy {
2159            mz_sql_server_util::config::CertificateValidationPolicy::TrustAll => {
2160                inner_config.trust_cert()
2161            }
2162            mz_sql_server_util::config::CertificateValidationPolicy::VerifyCA => {
2163                inner_config.trust_cert_ca_pem(
2164                    self.tls_root_cert
2165                        .as_ref()
2166                        .unwrap()
2167                        .get_string(in_task, secrets_reader)
2168                        .await
2169                        .context("ca certificate")?,
2170                );
2171            }
2172            mz_sql_server_util::config::CertificateValidationPolicy::VerifySystem => (), // no-op
2173        }
2174
2175        inner_config.application_name("materialize");
2176
2177        // Read our auth settings from
2178        let user = self
2179            .user
2180            .get_string(in_task, secrets_reader)
2181            .await
2182            .context("username")?;
2183        let password = secrets_reader
2184            .read_string_in_task_if(in_task, self.password)
2185            .await
2186            .context("password")?;
2187        // TODO(sql_server3): Support other methods of authentication besides
2188        // username and password.
2189        inner_config.authentication(tiberius::AuthMethod::sql_server(user, password));
2190
2191        // Prevent users from probing our internal network ports by trying to
2192        // connect to localhost, or another non-external IP.
2193        let enfoce_external_addresses = ENFORCE_EXTERNAL_ADDRESSES.get(dyncfg);
2194
2195        let tunnel = match &self.tunnel {
2196            Tunnel::Direct => mz_sql_server_util::config::TunnelConfig::Direct,
2197            Tunnel::Ssh(SshTunnel {
2198                connection_id,
2199                connection: ssh_connection,
2200            }) => {
2201                let secret = secrets_reader
2202                    .read_in_task_if(in_task, *connection_id)
2203                    .await
2204                    .context("ssh secret")?;
2205                let key_pair = SshKeyPair::from_bytes(&secret).context("ssh key pair")?;
2206                // Ensure any SSH-bastion host we connect to is resolved to an
2207                // external address.
2208                let addresses = resolve_address(&ssh_connection.host, enfoce_external_addresses)
2209                    .await
2210                    .context("ssh tunnel")?;
2211
2212                let config = SshTunnelConfig {
2213                    host: addresses.into_iter().map(|a| a.to_string()).collect(),
2214                    port: ssh_connection.port,
2215                    user: ssh_connection.user.clone(),
2216                    key_pair,
2217                };
2218                mz_sql_server_util::config::TunnelConfig::Ssh {
2219                    config,
2220                    manager: storage_configuration
2221                        .connection_context
2222                        .ssh_tunnel_manager
2223                        .clone(),
2224                    timeout: storage_configuration.parameters.ssh_timeout_config.clone(),
2225                    host: self.host.clone(),
2226                    port: self.port,
2227                }
2228            }
2229            Tunnel::AwsPrivatelink(private_link_connection) => {
2230                assert_none!(private_link_connection.port);
2231                mz_sql_server_util::config::TunnelConfig::AwsPrivatelink {
2232                    connection_id: private_link_connection.connection_id,
2233                    port: self.port,
2234                }
2235            }
2236        };
2237
2238        Ok(mz_sql_server_util::Config::new(
2239            inner_config,
2240            tunnel,
2241            in_task,
2242        ))
2243    }
2244}
2245
/// Errors detected while validating a [`SqlServerConnectionDetails`].
#[derive(Debug, Clone, thiserror::Error)]
pub enum SqlServerConnectionValidationError {
    /// One or more SQL Server system settings do not have the values required
    /// for CDC; each entry is `(setting, expected, actual)`.
    #[error("Invalid SQL Server system replication settings")]
    ReplicationSettingsError(Vec<(String, String, String)>),
}
2251
2252impl SqlServerConnectionValidationError {
2253    pub fn detail(&self) -> Option<String> {
2254        match self {
2255            Self::ReplicationSettingsError(settings) => Some(format!(
2256                "Invalid SQL Server system replication settings: {}",
2257                itertools::join(
2258                    settings.iter().map(|(setting, expected, actual)| format!(
2259                        "{}: expected {}, got {}",
2260                        setting, expected, actual
2261                    )),
2262                    "; "
2263                )
2264            )),
2265        }
2266    }
2267
2268    pub fn hint(&self) -> Option<String> {
2269        match self {
2270            _ => None,
2271        }
2272    }
2273}
2274
2275impl<R: ConnectionResolver> IntoInlineConnection<SqlServerConnectionDetails, R>
2276    for SqlServerConnectionDetails<ReferencedConnection>
2277{
2278    fn into_inline_connection(self, r: R) -> SqlServerConnectionDetails {
2279        let SqlServerConnectionDetails {
2280            host,
2281            port,
2282            database,
2283            user,
2284            password,
2285            tunnel,
2286            encryption,
2287            certificate_validation_policy,
2288            tls_root_cert,
2289        } = self;
2290
2291        SqlServerConnectionDetails {
2292            host,
2293            port,
2294            database,
2295            user,
2296            password,
2297            tunnel: tunnel.into_inline_connection(&r),
2298            encryption,
2299            certificate_validation_policy,
2300            tls_root_cert,
2301        }
2302    }
2303}
2304
2305impl<C: ConnectionAccess> AlterCompatible for SqlServerConnectionDetails<C> {
2306    fn alter_compatible(
2307        &self,
2308        id: mz_repr::GlobalId,
2309        other: &Self,
2310    ) -> Result<(), crate::controller::AlterError> {
2311        let SqlServerConnectionDetails {
2312            tunnel,
2313            // TODO(sql_server2): Figure out how these variables are allowed to change.
2314            host: _,
2315            port: _,
2316            database: _,
2317            user: _,
2318            password: _,
2319            encryption: _,
2320            certificate_validation_policy: _,
2321            tls_root_cert: _,
2322        } = self;
2323
2324        let compatibility_checks = [(tunnel.alter_compatible(id, &other.tunnel).is_ok(), "tunnel")];
2325
2326        for (compatible, field) in compatibility_checks {
2327            if !compatible {
2328                tracing::warn!(
2329                    "SqlServerConnectionDetails incompatible at {field}:\nself:\n{:#?}\n\nother\n{:#?}",
2330                    self,
2331                    other
2332                );
2333
2334                return Err(AlterError { id });
2335            }
2336        }
2337        Ok(())
2338    }
2339}
2340
/// A connection to an SSH tunnel.
#[derive(Clone, Debug, Eq, PartialEq, Hash, Serialize, Deserialize)]
pub struct SshConnection {
    /// The hostname of the SSH bastion host.
    pub host: String,
    /// The port to connect to on the bastion host.
    pub port: u16,
    /// The username to log in as on the bastion host.
    pub user: String,
}
2348
2349use self::inline::{
2350    ConnectionAccess, ConnectionResolver, InlinedConnection, IntoInlineConnection,
2351    ReferencedConnection,
2352};
2353
impl AlterCompatible for SshConnection {
    fn alter_compatible(&self, _id: GlobalId, _other: &Self) -> Result<(), AlterError> {
        // Every element of the SSH connection is configurable, so any change
        // is compatible.
        Ok(())
    }
}
2360
/// Specifies an AWS PrivateLink service for a [`Tunnel`].
#[derive(Clone, Debug, Eq, PartialEq, Hash, Serialize, Deserialize)]
pub struct AwsPrivatelink {
    /// The ID of the connection to the AWS PrivateLink service.
    pub connection_id: CatalogItemId,
    /// The availability zone to use when connecting to the AWS PrivateLink
    /// service.
    pub availability_zone: Option<String>,
    /// The port to use when connecting to the AWS PrivateLink service, if
    /// different from the port in [`KafkaBroker::address`].
    pub port: Option<u16>,
}
2372
2373impl AlterCompatible for AwsPrivatelink {
2374    fn alter_compatible(&self, id: GlobalId, other: &Self) -> Result<(), AlterError> {
2375        let AwsPrivatelink {
2376            connection_id,
2377            availability_zone: _,
2378            port: _,
2379        } = self;
2380
2381        let compatibility_checks = [(connection_id == &other.connection_id, "connection_id")];
2382
2383        for (compatible, field) in compatibility_checks {
2384            if !compatible {
2385                tracing::warn!(
2386                    "AwsPrivatelink incompatible at {field}:\nself:\n{:#?}\n\nother\n{:#?}",
2387                    self,
2388                    other
2389                );
2390
2391                return Err(AlterError { id });
2392            }
2393        }
2394
2395        Ok(())
2396    }
2397}
2398
/// Specifies an SSH tunnel connection.
#[derive(Clone, Debug, Eq, PartialEq, Hash, Serialize, Deserialize)]
pub struct SshTunnel<C: ConnectionAccess = InlinedConnection> {
    /// The ID of the SSH connection; also identifies the secret from which
    /// the SSH key pair is read.
    pub connection_id: CatalogItemId,
    /// The SSH connection itself, either inlined or referenced per `C`.
    pub connection: C::Ssh,
}
2407
2408impl<R: ConnectionResolver> IntoInlineConnection<SshTunnel, R> for SshTunnel<ReferencedConnection> {
2409    fn into_inline_connection(self, r: R) -> SshTunnel {
2410        let SshTunnel {
2411            connection,
2412            connection_id,
2413        } = self;
2414
2415        SshTunnel {
2416            connection: r.resolve_connection(connection).unwrap_ssh(),
2417            connection_id,
2418        }
2419    }
2420}
2421
2422impl SshTunnel<InlinedConnection> {
2423    /// Like [`SshTunnelConfig::connect`], but the SSH key is loaded from a
2424    /// secret.
2425    async fn connect(
2426        &self,
2427        storage_configuration: &StorageConfiguration,
2428        remote_host: &str,
2429        remote_port: u16,
2430        in_task: InTask,
2431    ) -> Result<ManagedSshTunnelHandle, anyhow::Error> {
2432        // Ensure any ssh-bastion host we connect to is resolved to an external address.
2433        let resolved = resolve_address(
2434            &self.connection.host,
2435            ENFORCE_EXTERNAL_ADDRESSES.get(storage_configuration.config_set()),
2436        )
2437        .await?;
2438        storage_configuration
2439            .connection_context
2440            .ssh_tunnel_manager
2441            .connect(
2442                SshTunnelConfig {
2443                    host: resolved
2444                        .iter()
2445                        .map(|a| a.to_string())
2446                        .collect::<BTreeSet<_>>(),
2447                    port: self.connection.port,
2448                    user: self.connection.user.clone(),
2449                    key_pair: SshKeyPair::from_bytes(
2450                        &storage_configuration
2451                            .connection_context
2452                            .secrets_reader
2453                            .read_in_task_if(in_task, self.connection_id)
2454                            .await?,
2455                    )?,
2456                },
2457                remote_host,
2458                remote_port,
2459                storage_configuration.parameters.ssh_timeout_config,
2460                in_task,
2461            )
2462            .await
2463    }
2464}
2465
2466impl<C: ConnectionAccess> AlterCompatible for SshTunnel<C> {
2467    fn alter_compatible(&self, id: GlobalId, other: &Self) -> Result<(), AlterError> {
2468        let SshTunnel {
2469            connection_id,
2470            connection,
2471        } = self;
2472
2473        let compatibility_checks = [
2474            (connection_id == &other.connection_id, "connection_id"),
2475            (
2476                connection.alter_compatible(id, &other.connection).is_ok(),
2477                "connection",
2478            ),
2479        ];
2480
2481        for (compatible, field) in compatibility_checks {
2482            if !compatible {
2483                tracing::warn!(
2484                    "SshTunnel incompatible at {field}:\nself:\n{:#?}\n\nother\n{:#?}",
2485                    self,
2486                    other
2487                );
2488
2489                return Err(AlterError { id });
2490            }
2491        }
2492
2493        Ok(())
2494    }
2495}
2496
2497impl SshConnection {
2498    #[allow(clippy::unused_async)]
2499    async fn validate(
2500        &self,
2501        id: CatalogItemId,
2502        storage_configuration: &StorageConfiguration,
2503    ) -> Result<(), anyhow::Error> {
2504        let secret = storage_configuration
2505            .connection_context
2506            .secrets_reader
2507            .read_in_task_if(
2508                // We are in a normal tokio context during validation, already.
2509                InTask::No,
2510                id,
2511            )
2512            .await?;
2513        let key_pair = SshKeyPair::from_bytes(&secret)?;
2514
2515        // Ensure any ssh-bastion host we connect to is resolved to an external address.
2516        let resolved = resolve_address(
2517            &self.host,
2518            ENFORCE_EXTERNAL_ADDRESSES.get(storage_configuration.config_set()),
2519        )
2520        .await?;
2521
2522        let config = SshTunnelConfig {
2523            host: resolved
2524                .iter()
2525                .map(|a| a.to_string())
2526                .collect::<BTreeSet<_>>(),
2527            port: self.port,
2528            user: self.user.clone(),
2529            key_pair,
2530        };
2531        // Note that we do NOT use the `SshTunnelManager` here, as we want to validate that we
2532        // can actually create a new connection to the ssh bastion, without tunneling.
2533        config
2534            .validate(storage_configuration.parameters.ssh_timeout_config)
2535            .await
2536    }
2537
2538    fn validate_by_default(&self) -> bool {
2539        false
2540    }
2541}
2542
2543impl AwsPrivatelinkConnection {
2544    #[allow(clippy::unused_async)]
2545    async fn validate(
2546        &self,
2547        id: CatalogItemId,
2548        storage_configuration: &StorageConfiguration,
2549    ) -> Result<(), anyhow::Error> {
2550        let Some(ref cloud_resource_reader) = storage_configuration
2551            .connection_context
2552            .cloud_resource_reader
2553        else {
2554            return Err(anyhow!("AWS PrivateLink connections are unsupported"));
2555        };
2556
2557        // No need to optionally run this in a task, as we are just validating from envd.
2558        let status = cloud_resource_reader.read(id).await?;
2559
2560        let availability = status
2561            .conditions
2562            .as_ref()
2563            .and_then(|conditions| conditions.iter().find(|c| c.type_ == "Available"));
2564
2565        match availability {
2566            Some(condition) if condition.status == "True" => Ok(()),
2567            Some(condition) => Err(anyhow!("{}", condition.message)),
2568            None => Err(anyhow!("Endpoint availability is unknown")),
2569        }
2570    }
2571
2572    fn validate_by_default(&self) -> bool {
2573        false
2574    }
2575}