mz_storage_types/
connections.rs

1// Copyright Materialize, Inc. and contributors. All rights reserved.
2//
3// Use of this software is governed by the Business Source License
4// included in the LICENSE file.
5//
6// As of the Change Date specified in that file, in accordance with
7// the Business Source License, use of this software will be governed
8// by the Apache License, Version 2.0.
9
10//! Connection types.
11
12use std::borrow::Cow;
13use std::collections::{BTreeMap, BTreeSet};
14use std::net::SocketAddr;
15use std::sync::Arc;
16
17use anyhow::{Context, anyhow};
18use aws_sdk_sts::config::ProvideCredentials;
19use iceberg::Catalog;
20use iceberg::CatalogBuilder;
21use iceberg::io::{S3_ACCESS_KEY_ID, S3_DISABLE_EC2_METADATA, S3_REGION, S3_SECRET_ACCESS_KEY};
22use iceberg_catalog_rest::{
23    REST_CATALOG_PROP_URI, REST_CATALOG_PROP_WAREHOUSE, RestCatalogBuilder,
24};
25use itertools::Itertools;
26use mz_ccsr::tls::{Certificate, Identity};
27use mz_cloud_resources::{AwsExternalIdPrefix, CloudResourceReader, vpc_endpoint_host};
28use mz_dyncfg::ConfigSet;
29use mz_kafka_util::client::{
30    BrokerAddr, BrokerRewrite, MzClientContext, MzKafkaError, TunnelConfig, TunnelingClientContext,
31};
32use mz_mysql_util::{MySqlConn, MySqlError};
33use mz_ore::assert_none;
34use mz_ore::error::ErrorExt;
35use mz_ore::future::{InTask, OreFutureExt};
36use mz_ore::netio::resolve_address;
37use mz_ore::num::NonNeg;
38use mz_repr::{CatalogItemId, GlobalId};
39use mz_secrets::SecretsReader;
40use mz_ssh_util::keys::SshKeyPair;
41use mz_ssh_util::tunnel::SshTunnelConfig;
42use mz_ssh_util::tunnel_manager::{ManagedSshTunnelHandle, SshTunnelManager};
43use mz_tls_util::Pkcs12Archive;
44use mz_tracing::CloneableEnvFilter;
45use rdkafka::ClientContext;
46use rdkafka::config::FromClientConfigAndContext;
47use rdkafka::consumer::{BaseConsumer, Consumer};
48use regex::Regex;
49use serde::{Deserialize, Deserializer, Serialize};
50use tokio::net;
51use tokio::runtime::Handle;
52use tokio_postgres::config::SslMode;
53use tracing::{debug, warn};
54use url::Url;
55
56use crate::AlterCompatible;
57use crate::configuration::StorageConfiguration;
58use crate::connections::aws::{
59    AwsConnection, AwsConnectionReference, AwsConnectionValidationError,
60};
61use crate::connections::string_or_secret::StringOrSecret;
62use crate::controller::AlterError;
63use crate::dyncfgs::{
64    ENFORCE_EXTERNAL_ADDRESSES, KAFKA_CLIENT_ID_ENRICHMENT_RULES,
65    KAFKA_DEFAULT_AWS_PRIVATELINK_ENDPOINT_IDENTIFICATION_ALGORITHM, KAFKA_RECONNECT_BACKOFF,
66    KAFKA_RECONNECT_BACKOFF_MAX, KAFKA_RETRY_BACKOFF, KAFKA_RETRY_BACKOFF_MAX,
67};
68use crate::errors::{ContextCreationError, CsrConnectError};
69
70pub mod aws;
71pub mod inline;
72pub mod string_or_secret;
73
/// Property key under which the OAuth2 scope is handed to the REST catalog
/// builder.
const REST_CATALOG_PROP_SCOPE: &str = "scope";
/// Property key under which the OAuth2 credential is handed to the REST
/// catalog builder.
const REST_CATALOG_PROP_CREDENTIAL: &str = "credential";
76
77/// An extension trait for [`SecretsReader`]
78#[async_trait::async_trait]
79trait SecretsReaderExt {
80    /// `SecretsReader::read`, but optionally run in a task.
81    async fn read_in_task_if(
82        &self,
83        in_task: InTask,
84        id: CatalogItemId,
85    ) -> Result<Vec<u8>, anyhow::Error>;
86
87    /// `SecretsReader::read_string`, but optionally run in a task.
88    async fn read_string_in_task_if(
89        &self,
90        in_task: InTask,
91        id: CatalogItemId,
92    ) -> Result<String, anyhow::Error>;
93}
94
95#[async_trait::async_trait]
96impl SecretsReaderExt for Arc<dyn SecretsReader> {
97    async fn read_in_task_if(
98        &self,
99        in_task: InTask,
100        id: CatalogItemId,
101    ) -> Result<Vec<u8>, anyhow::Error> {
102        let sr = Arc::clone(self);
103        async move { sr.read(id).await }
104            .run_in_task_if(in_task, || "secrets_reader_read".to_string())
105            .await
106    }
107    async fn read_string_in_task_if(
108        &self,
109        in_task: InTask,
110        id: CatalogItemId,
111    ) -> Result<String, anyhow::Error> {
112        let sr = Arc::clone(self);
113        async move { sr.read_string(id).await }
114            .run_in_task_if(in_task, || "secrets_reader_read".to_string())
115            .await
116    }
117}
118
119/// Extra context to pass through when instantiating a connection for a source
120/// or sink.
121///
122/// Should be kept cheaply cloneable.
123#[derive(Debug, Clone)]
124pub struct ConnectionContext {
125    /// An opaque identifier for the environment in which this process is
126    /// running.
127    ///
128    /// The storage layer is intentionally unaware of the structure within this
129    /// identifier. Higher layers of the stack can make use of that structure,
130    /// but the storage layer should be oblivious to it.
131    pub environment_id: String,
132    /// The level for librdkafka's logs.
133    pub librdkafka_log_level: tracing::Level,
134    /// A prefix for an external ID to use for all AWS AssumeRole operations.
135    pub aws_external_id_prefix: Option<AwsExternalIdPrefix>,
136    /// The ARN for a Materialize-controlled role to assume before assuming
137    /// a customer's requested role for an AWS connection.
138    pub aws_connection_role_arn: Option<String>,
139    /// A secrets reader.
140    pub secrets_reader: Arc<dyn SecretsReader>,
141    /// A cloud resource reader, if supported in this configuration.
142    pub cloud_resource_reader: Option<Arc<dyn CloudResourceReader>>,
143    /// A manager for SSH tunnels.
144    pub ssh_tunnel_manager: SshTunnelManager,
145}
146
147impl ConnectionContext {
148    /// Constructs a new connection context from command line arguments.
149    ///
150    /// **WARNING:** it is critical for security that the `aws_external_id` be
151    /// provided by the operator of the Materialize service (i.e., via a CLI
152    /// argument or environment variable) and not the end user of Materialize
153    /// (e.g., via a configuration option in a SQL statement). See
154    /// [`AwsExternalIdPrefix`] for details.
155    pub fn from_cli_args(
156        environment_id: String,
157        startup_log_level: &CloneableEnvFilter,
158        aws_external_id_prefix: Option<AwsExternalIdPrefix>,
159        aws_connection_role_arn: Option<String>,
160        secrets_reader: Arc<dyn SecretsReader>,
161        cloud_resource_reader: Option<Arc<dyn CloudResourceReader>>,
162    ) -> ConnectionContext {
163        ConnectionContext {
164            environment_id,
165            librdkafka_log_level: mz_ore::tracing::crate_level(
166                &startup_log_level.clone().into(),
167                "librdkafka",
168            ),
169            aws_external_id_prefix,
170            aws_connection_role_arn,
171            secrets_reader,
172            cloud_resource_reader,
173            ssh_tunnel_manager: SshTunnelManager::default(),
174        }
175    }
176
177    /// Constructs a new connection context for usage in tests.
178    pub fn for_tests(secrets_reader: Arc<dyn SecretsReader>) -> ConnectionContext {
179        ConnectionContext {
180            environment_id: "test-environment-id".into(),
181            librdkafka_log_level: tracing::Level::INFO,
182            aws_external_id_prefix: Some(
183                AwsExternalIdPrefix::new_from_cli_argument_or_environment_variable(
184                    "test-aws-external-id-prefix",
185                )
186                .expect("infallible"),
187            ),
188            aws_connection_role_arn: Some(
189                "arn:aws:iam::123456789000:role/MaterializeConnection".into(),
190            ),
191            secrets_reader,
192            cloud_resource_reader: None,
193            ssh_tunnel_manager: SshTunnelManager::default(),
194        }
195    }
196}
197
198#[derive(Clone, Debug, Eq, PartialEq, Hash, Serialize, Deserialize)]
199pub enum Connection<C: ConnectionAccess = InlinedConnection> {
200    Kafka(KafkaConnection<C>),
201    Csr(CsrConnection<C>),
202    Postgres(PostgresConnection<C>),
203    Ssh(SshConnection),
204    Aws(AwsConnection),
205    AwsPrivatelink(AwsPrivatelinkConnection),
206    MySql(MySqlConnection<C>),
207    SqlServer(SqlServerConnectionDetails<C>),
208    IcebergCatalog(IcebergCatalogConnection<C>),
209}
210
211impl<R: ConnectionResolver> IntoInlineConnection<Connection, R>
212    for Connection<ReferencedConnection>
213{
214    fn into_inline_connection(self, r: R) -> Connection {
215        match self {
216            Connection::Kafka(kafka) => Connection::Kafka(kafka.into_inline_connection(r)),
217            Connection::Csr(csr) => Connection::Csr(csr.into_inline_connection(r)),
218            Connection::Postgres(pg) => Connection::Postgres(pg.into_inline_connection(r)),
219            Connection::Ssh(ssh) => Connection::Ssh(ssh),
220            Connection::Aws(aws) => Connection::Aws(aws),
221            Connection::AwsPrivatelink(awspl) => Connection::AwsPrivatelink(awspl),
222            Connection::MySql(mysql) => Connection::MySql(mysql.into_inline_connection(r)),
223            Connection::SqlServer(sql_server) => {
224                Connection::SqlServer(sql_server.into_inline_connection(r))
225            }
226            Connection::IcebergCatalog(iceberg) => {
227                Connection::IcebergCatalog(iceberg.into_inline_connection(r))
228            }
229        }
230    }
231}
232
233impl<C: ConnectionAccess> Connection<C> {
234    /// Whether this connection should be validated by default on creation.
235    pub fn validate_by_default(&self) -> bool {
236        match self {
237            Connection::Kafka(conn) => conn.validate_by_default(),
238            Connection::Csr(conn) => conn.validate_by_default(),
239            Connection::Postgres(conn) => conn.validate_by_default(),
240            Connection::Ssh(conn) => conn.validate_by_default(),
241            Connection::Aws(conn) => conn.validate_by_default(),
242            Connection::AwsPrivatelink(conn) => conn.validate_by_default(),
243            Connection::MySql(conn) => conn.validate_by_default(),
244            Connection::SqlServer(conn) => conn.validate_by_default(),
245            Connection::IcebergCatalog(conn) => conn.validate_by_default(),
246        }
247    }
248}
249
250impl Connection<InlinedConnection> {
251    /// Validates this connection by attempting to connect to the upstream system.
252    pub async fn validate(
253        &self,
254        id: CatalogItemId,
255        storage_configuration: &StorageConfiguration,
256    ) -> Result<(), ConnectionValidationError> {
257        match self {
258            Connection::Kafka(conn) => conn.validate(id, storage_configuration).await?,
259            Connection::Csr(conn) => conn.validate(id, storage_configuration).await?,
260            Connection::Postgres(conn) => {
261                conn.validate(id, storage_configuration).await?;
262            }
263            Connection::Ssh(conn) => conn.validate(id, storage_configuration).await?,
264            Connection::Aws(conn) => conn.validate(id, storage_configuration).await?,
265            Connection::AwsPrivatelink(conn) => conn.validate(id, storage_configuration).await?,
266            Connection::MySql(conn) => {
267                conn.validate(id, storage_configuration).await?;
268            }
269            Connection::SqlServer(conn) => {
270                conn.validate(id, storage_configuration).await?;
271            }
272            Connection::IcebergCatalog(conn) => conn.validate(id, storage_configuration).await?,
273        }
274        Ok(())
275    }
276
277    pub fn unwrap_kafka(self) -> <InlinedConnection as ConnectionAccess>::Kafka {
278        match self {
279            Self::Kafka(conn) => conn,
280            o => unreachable!("{o:?} is not a Kafka connection"),
281        }
282    }
283
284    pub fn unwrap_pg(self) -> <InlinedConnection as ConnectionAccess>::Pg {
285        match self {
286            Self::Postgres(conn) => conn,
287            o => unreachable!("{o:?} is not a Postgres connection"),
288        }
289    }
290
291    pub fn unwrap_mysql(self) -> <InlinedConnection as ConnectionAccess>::MySql {
292        match self {
293            Self::MySql(conn) => conn,
294            o => unreachable!("{o:?} is not a MySQL connection"),
295        }
296    }
297
298    pub fn unwrap_sql_server(self) -> <InlinedConnection as ConnectionAccess>::SqlServer {
299        match self {
300            Self::SqlServer(conn) => conn,
301            o => unreachable!("{o:?} is not a SQL Server connection"),
302        }
303    }
304
305    pub fn unwrap_aws(self) -> <InlinedConnection as ConnectionAccess>::Aws {
306        match self {
307            Self::Aws(conn) => conn,
308            o => unreachable!("{o:?} is not an AWS connection"),
309        }
310    }
311
312    pub fn unwrap_ssh(self) -> <InlinedConnection as ConnectionAccess>::Ssh {
313        match self {
314            Self::Ssh(conn) => conn,
315            o => unreachable!("{o:?} is not an SSH connection"),
316        }
317    }
318
319    pub fn unwrap_csr(self) -> <InlinedConnection as ConnectionAccess>::Csr {
320        match self {
321            Self::Csr(conn) => conn,
322            o => unreachable!("{o:?} is not a Kafka connection"),
323        }
324    }
325
326    pub fn unwrap_iceberg_catalog(self) -> <InlinedConnection as ConnectionAccess>::IcebergCatalog {
327        match self {
328            Self::IcebergCatalog(conn) => conn,
329            o => unreachable!("{o:?} is not an Iceberg catalog connection"),
330        }
331    }
332}
333
334/// An error returned by [`Connection::validate`].
335#[derive(thiserror::Error, Debug)]
336pub enum ConnectionValidationError {
337    #[error(transparent)]
338    Postgres(#[from] PostgresConnectionValidationError),
339    #[error(transparent)]
340    MySql(#[from] MySqlConnectionValidationError),
341    #[error(transparent)]
342    SqlServer(#[from] SqlServerConnectionValidationError),
343    #[error(transparent)]
344    Aws(#[from] AwsConnectionValidationError),
345    #[error("{}", .0.display_with_causes())]
346    Other(#[from] anyhow::Error),
347}
348
349impl ConnectionValidationError {
350    /// Reports additional details about the error, if any are available.
351    pub fn detail(&self) -> Option<String> {
352        match self {
353            ConnectionValidationError::Postgres(e) => e.detail(),
354            ConnectionValidationError::MySql(e) => e.detail(),
355            ConnectionValidationError::SqlServer(e) => e.detail(),
356            ConnectionValidationError::Aws(e) => e.detail(),
357            ConnectionValidationError::Other(_) => None,
358        }
359    }
360
361    /// Reports a hint for the user about how the error could be fixed.
362    pub fn hint(&self) -> Option<String> {
363        match self {
364            ConnectionValidationError::Postgres(e) => e.hint(),
365            ConnectionValidationError::MySql(e) => e.hint(),
366            ConnectionValidationError::SqlServer(e) => e.hint(),
367            ConnectionValidationError::Aws(e) => e.hint(),
368            ConnectionValidationError::Other(_) => None,
369        }
370    }
371}
372
373impl<C: ConnectionAccess> AlterCompatible for Connection<C> {
374    fn alter_compatible(&self, id: GlobalId, other: &Self) -> Result<(), AlterError> {
375        match (self, other) {
376            (Self::Aws(s), Self::Aws(o)) => s.alter_compatible(id, o),
377            (Self::AwsPrivatelink(s), Self::AwsPrivatelink(o)) => s.alter_compatible(id, o),
378            (Self::Ssh(s), Self::Ssh(o)) => s.alter_compatible(id, o),
379            (Self::Csr(s), Self::Csr(o)) => s.alter_compatible(id, o),
380            (Self::Kafka(s), Self::Kafka(o)) => s.alter_compatible(id, o),
381            (Self::Postgres(s), Self::Postgres(o)) => s.alter_compatible(id, o),
382            (Self::MySql(s), Self::MySql(o)) => s.alter_compatible(id, o),
383            _ => {
384                tracing::warn!(
385                    "Connection incompatible:\nself:\n{:#?}\n\nother\n{:#?}",
386                    self,
387                    other
388                );
389                Err(AlterError { id })
390            }
391        }
392    }
393}
394
395#[derive(Clone, Debug, Eq, PartialEq, Hash, Serialize, Deserialize)]
396pub struct RestIcebergCatalog {
397    /// For REST catalogs, the oauth2 credential in a `CLIENT_ID:CLIENT_SECRET` format
398    pub credential: StringOrSecret,
399    /// The oauth2 scope for REST catalogs
400    pub scope: Option<String>,
401    /// The warehouse for REST catalogs
402    pub warehouse: Option<String>,
403}
404
405#[derive(Clone, Debug, Eq, PartialEq, Hash, Serialize, Deserialize)]
406pub struct S3TablesRestIcebergCatalog<C: ConnectionAccess = InlinedConnection> {
407    /// The AWS connection details, for s3tables
408    pub aws_connection: AwsConnectionReference<C>,
409    /// The warehouse for s3tables
410    pub warehouse: String,
411}
412
413impl<R: ConnectionResolver> IntoInlineConnection<S3TablesRestIcebergCatalog, R>
414    for S3TablesRestIcebergCatalog<ReferencedConnection>
415{
416    fn into_inline_connection(self, r: R) -> S3TablesRestIcebergCatalog {
417        S3TablesRestIcebergCatalog {
418            aws_connection: self.aws_connection.into_inline_connection(&r),
419            warehouse: self.warehouse,
420        }
421    }
422}
423
424#[derive(Clone, Debug, Eq, PartialEq, Hash, Serialize, Deserialize)]
425pub enum IcebergCatalogType {
426    Rest,
427    S3TablesRest,
428}
429
430#[derive(Clone, Debug, Eq, PartialEq, Hash, Serialize, Deserialize)]
431pub enum IcebergCatalogImpl<C: ConnectionAccess = InlinedConnection> {
432    Rest(RestIcebergCatalog),
433    S3TablesRest(S3TablesRestIcebergCatalog<C>),
434}
435
436impl<R: ConnectionResolver> IntoInlineConnection<IcebergCatalogImpl, R>
437    for IcebergCatalogImpl<ReferencedConnection>
438{
439    fn into_inline_connection(self, r: R) -> IcebergCatalogImpl {
440        match self {
441            IcebergCatalogImpl::Rest(rest) => IcebergCatalogImpl::Rest(rest),
442            IcebergCatalogImpl::S3TablesRest(s3tables) => {
443                IcebergCatalogImpl::S3TablesRest(s3tables.into_inline_connection(r))
444            }
445        }
446    }
447}
448
449#[derive(Clone, Debug, Eq, PartialEq, Hash, Serialize, Deserialize)]
450pub struct IcebergCatalogConnection<C: ConnectionAccess = InlinedConnection> {
451    /// The catalog impl impl of that catalog
452    pub catalog: IcebergCatalogImpl<C>,
453    /// Where the catalog is located
454    pub uri: reqwest::Url,
455}
456
457impl AlterCompatible for IcebergCatalogConnection {
458    fn alter_compatible(&self, id: GlobalId, _other: &Self) -> Result<(), AlterError> {
459        Err(AlterError { id })
460    }
461}
462
463impl<R: ConnectionResolver> IntoInlineConnection<IcebergCatalogConnection, R>
464    for IcebergCatalogConnection<ReferencedConnection>
465{
466    fn into_inline_connection(self, r: R) -> IcebergCatalogConnection {
467        IcebergCatalogConnection {
468            catalog: self.catalog.into_inline_connection(&r),
469            uri: self.uri,
470        }
471    }
472}
473
474impl<C: ConnectionAccess> IcebergCatalogConnection<C> {
475    fn validate_by_default(&self) -> bool {
476        true
477    }
478}
479
480impl IcebergCatalogConnection<InlinedConnection> {
481    pub async fn connect(
482        &self,
483        storage_configuration: &StorageConfiguration,
484        in_task: InTask,
485    ) -> Result<Arc<dyn Catalog>, anyhow::Error> {
486        match self.catalog {
487            IcebergCatalogImpl::S3TablesRest(ref s3tables) => {
488                self.connect_s3tables(s3tables, storage_configuration, in_task)
489                    .await
490            }
491            IcebergCatalogImpl::Rest(ref rest) => {
492                self.connect_rest(rest, storage_configuration, in_task)
493                    .await
494            }
495        }
496    }
497
498    pub fn catalog_type(&self) -> IcebergCatalogType {
499        match self.catalog {
500            IcebergCatalogImpl::S3TablesRest(_) => IcebergCatalogType::S3TablesRest,
501            IcebergCatalogImpl::Rest(_) => IcebergCatalogType::Rest,
502        }
503    }
504
505    pub fn s3tables_catalog(&self) -> Option<&S3TablesRestIcebergCatalog> {
506        match &self.catalog {
507            IcebergCatalogImpl::S3TablesRest(s3tables) => Some(s3tables),
508            _ => None,
509        }
510    }
511
512    pub fn rest_catalog(&self) -> Option<&RestIcebergCatalog> {
513        match &self.catalog {
514            IcebergCatalogImpl::Rest(rest) => Some(rest),
515            _ => None,
516        }
517    }
518
519    async fn connect_s3tables(
520        &self,
521        s3tables: &S3TablesRestIcebergCatalog,
522        storage_configuration: &StorageConfiguration,
523        in_task: InTask,
524    ) -> Result<Arc<dyn Catalog>, anyhow::Error> {
525        let aws_ref = &s3tables.aws_connection;
526        let aws_config = aws_ref
527            .connection
528            .load_sdk_config(
529                &storage_configuration.connection_context,
530                aws_ref.connection_id,
531                in_task,
532            )
533            .await?;
534
535        let provider = aws_config
536            .credentials_provider()
537            .context("No credentials provider in AWS config")?;
538
539        let creds = provider
540            .provide_credentials()
541            .await
542            .context("Failed to get AWS credentials")?;
543
544        let access_key_id = creds.access_key_id();
545        let secret_access_key = creds.secret_access_key();
546
547        let aws_region = aws_config
548            .region()
549            .map(|r| r.as_ref())
550            .unwrap_or("us-east-1");
551
552        let props = vec![
553            (S3_REGION.to_string(), aws_region.to_string()),
554            (S3_DISABLE_EC2_METADATA.to_string(), "true".to_string()),
555            (S3_ACCESS_KEY_ID.to_string(), access_key_id.to_string()),
556            (
557                S3_SECRET_ACCESS_KEY.to_string(),
558                secret_access_key.to_string(),
559            ),
560            (
561                REST_CATALOG_PROP_WAREHOUSE.to_string(),
562                s3tables.warehouse.clone(),
563            ),
564            (
565                REST_CATALOG_PROP_URI.to_string(),
566                self.uri.to_string().clone(),
567            ),
568        ];
569
570        let catalog = RestCatalogBuilder::default()
571            .with_aws_config(aws_config)
572            .load("IcebergCatalog", props.into_iter().collect())
573            .await?;
574
575        Ok(Arc::new(catalog))
576    }
577
578    async fn connect_rest(
579        &self,
580        rest: &RestIcebergCatalog,
581        storage_configuration: &StorageConfiguration,
582        in_task: InTask,
583    ) -> Result<Arc<dyn Catalog>, anyhow::Error> {
584        let mut props = BTreeMap::from([(
585            REST_CATALOG_PROP_URI.to_string(),
586            self.uri.to_string().clone(),
587        )]);
588
589        if let Some(warehouse) = &rest.warehouse {
590            props.insert(REST_CATALOG_PROP_WAREHOUSE.to_string(), warehouse.clone());
591        }
592
593        let credential = rest
594            .credential
595            .get_string(
596                in_task,
597                &storage_configuration.connection_context.secrets_reader,
598            )
599            .await
600            .map_err(|e| anyhow!("failed to read Iceberg catalog credential: {e}"))?;
601        props.insert(REST_CATALOG_PROP_CREDENTIAL.to_string(), credential);
602
603        if let Some(scope) = &rest.scope {
604            props.insert(REST_CATALOG_PROP_SCOPE.to_string(), scope.clone());
605        }
606
607        let catalog = RestCatalogBuilder::default()
608            .load("IcebergCatalog", props.into_iter().collect())
609            .await
610            .map_err(|e| anyhow!("failed to create Iceberg catalog: {e}"))?;
611        Ok(Arc::new(catalog))
612    }
613
614    async fn validate(
615        &self,
616        _id: CatalogItemId,
617        storage_configuration: &StorageConfiguration,
618    ) -> Result<(), ConnectionValidationError> {
619        let catalog = self
620            .connect(storage_configuration, InTask::No)
621            .await
622            .map_err(|e| {
623                ConnectionValidationError::Other(anyhow!("failed to connect to catalog: {e}"))
624            })?;
625
626        // If we can list namespaces, the connection is valid.
627        catalog.list_namespaces(None).await.map_err(|e| {
628            ConnectionValidationError::Other(anyhow!("failed to list namespaces: {e}"))
629        })?;
630
631        Ok(())
632    }
633}
634
635#[derive(Clone, Debug, Eq, PartialEq, Hash, Serialize, Deserialize)]
636pub struct AwsPrivatelinkConnection {
637    pub service_name: String,
638    pub availability_zones: Vec<String>,
639}
640
641impl AlterCompatible for AwsPrivatelinkConnection {
642    fn alter_compatible(&self, _id: GlobalId, _other: &Self) -> Result<(), AlterError> {
643        // Every element of the AwsPrivatelinkConnection connection is configurable.
644        Ok(())
645    }
646}
647
648#[derive(Clone, Debug, Eq, PartialEq, Hash, Serialize, Deserialize)]
649pub struct KafkaTlsConfig {
650    pub identity: Option<TlsIdentity>,
651    pub root_cert: Option<StringOrSecret>,
652}
653
654#[derive(Clone, Debug, Eq, PartialEq, Hash, Serialize, Deserialize)]
655pub struct KafkaSaslConfig<C: ConnectionAccess = InlinedConnection> {
656    pub mechanism: String,
657    pub username: StringOrSecret,
658    pub password: Option<CatalogItemId>,
659    pub aws: Option<AwsConnectionReference<C>>,
660}
661
662impl<R: ConnectionResolver> IntoInlineConnection<KafkaSaslConfig, R>
663    for KafkaSaslConfig<ReferencedConnection>
664{
665    fn into_inline_connection(self, r: R) -> KafkaSaslConfig {
666        KafkaSaslConfig {
667            mechanism: self.mechanism,
668            username: self.username,
669            password: self.password,
670            aws: self.aws.map(|aws| aws.into_inline_connection(&r)),
671        }
672    }
673}
674
675/// Specifies a Kafka broker in a [`KafkaConnection`].
676#[derive(Clone, Debug, Eq, PartialEq, Hash, Serialize, Deserialize)]
677pub struct KafkaBroker<C: ConnectionAccess = InlinedConnection> {
678    /// The address of the Kafka broker.
679    pub address: String,
680    /// An optional tunnel to use when connecting to the broker.
681    pub tunnel: Tunnel<C>,
682}
683
684impl<R: ConnectionResolver> IntoInlineConnection<KafkaBroker, R>
685    for KafkaBroker<ReferencedConnection>
686{
687    fn into_inline_connection(self, r: R) -> KafkaBroker {
688        let KafkaBroker { address, tunnel } = self;
689        KafkaBroker {
690            address,
691            tunnel: tunnel.into_inline_connection(r),
692        }
693    }
694}
695
696#[derive(Clone, Debug, Eq, PartialEq, Hash, Serialize, Deserialize, Default)]
697pub struct KafkaTopicOptions {
698    /// The replication factor for the topic.
699    /// If `None`, the broker default will be used.
700    pub replication_factor: Option<NonNeg<i32>>,
701    /// The number of partitions to create.
702    /// If `None`, the broker default will be used.
703    pub partition_count: Option<NonNeg<i32>>,
704    /// The initial configuration parameters for the topic.
705    pub topic_config: BTreeMap<String, String>,
706}
707
708#[derive(Clone, Debug, Eq, PartialEq, Hash, Serialize, Deserialize)]
709pub struct KafkaConnection<C: ConnectionAccess = InlinedConnection> {
710    pub brokers: Vec<KafkaBroker<C>>,
711    /// A tunnel through which to route traffic,
712    /// that can be overridden for individual brokers
713    /// in `brokers`.
714    pub default_tunnel: Tunnel<C>,
715    pub progress_topic: Option<String>,
716    pub progress_topic_options: KafkaTopicOptions,
717    pub options: BTreeMap<String, StringOrSecret>,
718    pub tls: Option<KafkaTlsConfig>,
719    pub sasl: Option<KafkaSaslConfig<C>>,
720}
721
722impl<R: ConnectionResolver> IntoInlineConnection<KafkaConnection, R>
723    for KafkaConnection<ReferencedConnection>
724{
725    fn into_inline_connection(self, r: R) -> KafkaConnection {
726        let KafkaConnection {
727            brokers,
728            progress_topic,
729            progress_topic_options,
730            default_tunnel,
731            options,
732            tls,
733            sasl,
734        } = self;
735
736        let brokers = brokers
737            .into_iter()
738            .map(|broker| broker.into_inline_connection(&r))
739            .collect();
740
741        KafkaConnection {
742            brokers,
743            progress_topic,
744            progress_topic_options,
745            default_tunnel: default_tunnel.into_inline_connection(&r),
746            options,
747            tls,
748            sasl: sasl.map(|sasl| sasl.into_inline_connection(&r)),
749        }
750    }
751}
752
753impl<C: ConnectionAccess> KafkaConnection<C> {
754    /// Returns the name of the progress topic to use for the connection.
755    ///
756    /// The caller is responsible for providing the connection ID as it is not
757    /// known to `KafkaConnection`.
758    pub fn progress_topic(
759        &self,
760        connection_context: &ConnectionContext,
761        connection_id: CatalogItemId,
762    ) -> Cow<'_, str> {
763        if let Some(progress_topic) = &self.progress_topic {
764            Cow::Borrowed(progress_topic)
765        } else {
766            Cow::Owned(format!(
767                "_materialize-progress-{}-{}",
768                connection_context.environment_id, connection_id,
769            ))
770        }
771    }
772
773    fn validate_by_default(&self) -> bool {
774        true
775    }
776}
777
778impl KafkaConnection {
779    /// Generates a string that can be used as the base for a configuration ID
780    /// (e.g., `client.id`, `group.id`, `transactional.id`) for a Kafka source
781    /// or sink.
782    pub fn id_base(
783        connection_context: &ConnectionContext,
784        connection_id: CatalogItemId,
785        object_id: GlobalId,
786    ) -> String {
787        format!(
788            "materialize-{}-{}-{}",
789            connection_context.environment_id, connection_id, object_id,
790        )
791    }
792
793    /// Enriches the provided `client_id` according to any enrichment rules in
794    /// the `kafka_client_id_enrichment_rules` configuration parameter.
795    pub fn enrich_client_id(&self, configs: &ConfigSet, client_id: &mut String) {
796        #[derive(Debug, Deserialize)]
797        struct EnrichmentRule {
798            #[serde(deserialize_with = "deserialize_regex")]
799            pattern: Regex,
800            payload: String,
801        }
802
803        fn deserialize_regex<'de, D>(deserializer: D) -> Result<Regex, D::Error>
804        where
805            D: Deserializer<'de>,
806        {
807            let buf = String::deserialize(deserializer)?;
808            Regex::new(&buf).map_err(serde::de::Error::custom)
809        }
810
811        let rules = KAFKA_CLIENT_ID_ENRICHMENT_RULES.get(configs);
812        let rules = match serde_json::from_value::<Vec<EnrichmentRule>>(rules) {
813            Ok(rules) => rules,
814            Err(e) => {
815                warn!(%e, "failed to decode kafka_client_id_enrichment_rules");
816                return;
817            }
818        };
819
820        // Check every rule against every broker. Rules are matched in the order
821        // that they are specified. It is usually a configuration error if
822        // multiple rules match the same list of Kafka brokers, but we
823        // nonetheless want to provide well defined semantics.
824        debug!(?self.brokers, "evaluating client ID enrichment rules");
825        for rule in rules {
826            let is_match = self
827                .brokers
828                .iter()
829                .any(|b| rule.pattern.is_match(&b.address));
830            debug!(?rule, is_match, "evaluated client ID enrichment rule");
831            if is_match {
832                client_id.push('-');
833                client_id.push_str(&rule.payload);
834            }
835        }
836    }
837
    /// Creates a Kafka client for the connection.
    ///
    /// The client configuration starts from a copy of the connection's own
    /// `options` and is then extended with:
    ///
    /// * `allow.auto.create.topics = false`,
    /// * `bootstrap.servers`, derived from the default tunnel or the broker
    ///   list,
    /// * `security.protocol` plus the TLS and SASL settings of the connection,
    /// * the retry/reconnect backoff parameters read from the config set of
    ///   `storage_configuration`.
    ///
    /// Entries of `extra_options` are set last, after all of the above.
    ///
    /// The supplied `context` is wrapped in a [`TunnelingClientContext`], on
    /// which the default tunnel and any per-broker PrivateLink rewrites or SSH
    /// tunnels are registered before the client is created.
    ///
    /// # Errors
    ///
    /// Returns a [`ContextCreationError`] if a secret cannot be read, the AWS
    /// SDK configuration cannot be loaded, an SSH bastion address cannot be
    /// resolved, an SSH key pair cannot be parsed, a broker port cannot be
    /// parsed, an SSH tunnel cannot be added, or the client itself cannot be
    /// created.
    ///
    /// # Panics
    ///
    /// Panics if the default tunnel is an AWS PrivateLink tunnel and the
    /// connection nonetheless lists brokers.
    pub async fn create_with_context<C, T>(
        &self,
        storage_configuration: &StorageConfiguration,
        context: C,
        extra_options: &BTreeMap<&str, String>,
        in_task: InTask,
    ) -> Result<T, ContextCreationError>
    where
        C: ClientContext,
        T: FromClientConfigAndContext<TunnelingClientContext<C>>,
    {
        let mut options = self.options.clone();

        // Ensure that Kafka topics are *not* automatically created when
        // consuming, producing, or fetching metadata for a topic. This ensures
        // that we don't accidentally create topics with the wrong number of
        // partitions.
        options.insert("allow.auto.create.topics".into(), "false".into());

        // Compute the value of `bootstrap.servers`.
        let brokers = match &self.default_tunnel {
            Tunnel::AwsPrivatelink(t) => {
                assert!(&self.brokers.is_empty());

                let algo = KAFKA_DEFAULT_AWS_PRIVATELINK_ENDPOINT_IDENTIFICATION_ALGORITHM
                    .get(storage_configuration.config_set());
                options.insert("ssl.endpoint.identification.algorithm".into(), algo.into());

                // When using a default PrivateLink tunnel, brokers cannot be
                // specified; instead the tunnel's connection ID and port are
                // used for the initial connection.
                format!(
                    "{}:{}",
                    vpc_endpoint_host(
                        t.connection_id,
                        None, // Default tunnel does not support availability zones.
                    ),
                    t.port.unwrap_or(9092)
                )
            }
            _ => self.brokers.iter().map(|b| &b.address).join(","),
        };
        options.insert("bootstrap.servers".into(), brokers.into());
        // The security protocol follows from which of TLS and SASL are
        // configured on the connection.
        let security_protocol = match (self.tls.is_some(), self.sasl.is_some()) {
            (false, false) => "PLAINTEXT",
            (true, false) => "SSL",
            (false, true) => "SASL_PLAINTEXT",
            (true, true) => "SASL_SSL",
        };
        options.insert("security.protocol".into(), security_protocol.into());
        if let Some(tls) = &self.tls {
            if let Some(root_cert) = &tls.root_cert {
                options.insert("ssl.ca.pem".into(), root_cert.clone());
            }
            if let Some(identity) = &tls.identity {
                options.insert("ssl.key.pem".into(), StringOrSecret::Secret(identity.key));
                options.insert("ssl.certificate.pem".into(), identity.cert.clone());
            }
        }
        if let Some(sasl) = &self.sasl {
            options.insert("sasl.mechanisms".into(), (&sasl.mechanism).into());
            options.insert("sasl.username".into(), sasl.username.clone());
            if let Some(password) = sasl.password {
                options.insert("sasl.password".into(), StringOrSecret::Secret(password));
            }
        }

        // Backoff settings are taken from the dynamic configuration, in
        // milliseconds.
        options.insert(
            "retry.backoff.ms".into(),
            KAFKA_RETRY_BACKOFF
                .get(storage_configuration.config_set())
                .as_millis()
                .into(),
        );
        options.insert(
            "retry.backoff.max.ms".into(),
            KAFKA_RETRY_BACKOFF_MAX
                .get(storage_configuration.config_set())
                .as_millis()
                .into(),
        );
        options.insert(
            "reconnect.backoff.ms".into(),
            KAFKA_RECONNECT_BACKOFF
                .get(storage_configuration.config_set())
                .as_millis()
                .into(),
        );
        options.insert(
            "reconnect.backoff.max.ms".into(),
            KAFKA_RECONNECT_BACKOFF_MAX
                .get(storage_configuration.config_set())
                .as_millis()
                .into(),
        );

        let mut config = mz_kafka_util::client::create_new_client_config(
            storage_configuration
                .connection_context
                .librdkafka_log_level,
            storage_configuration.parameters.kafka_timeout_config,
        );
        // Resolve each option to a plain string (reading secrets where
        // necessary) and hand it to the client configuration.
        for (k, v) in options {
            config.set(
                k,
                v.get_string(
                    in_task,
                    &storage_configuration.connection_context.secrets_reader,
                )
                .await
                .context("reading kafka secret")?,
            );
        }
        // Caller-provided options are set after the connection's own options.
        for (k, v) in extra_options {
            config.set(*k, v);
        }

        // Load an AWS SDK configuration only if the SASL settings carry an
        // AWS connection.
        let aws_config = match self.sasl.as_ref().and_then(|sasl| sasl.aws.as_ref()) {
            None => None,
            Some(aws) => Some(
                aws.connection
                    .load_sdk_config(
                        &storage_configuration.connection_context,
                        aws.connection_id,
                        in_task,
                    )
                    .await?,
            ),
        };

        // TODO(roshan): Implement enforcement of external address validation once
        // rdkafka client has been updated to support providing multiple resolved
        // addresses for brokers
        let mut context = TunnelingClientContext::new(
            context,
            Handle::current(),
            storage_configuration
                .connection_context
                .ssh_tunnel_manager
                .clone(),
            storage_configuration.parameters.ssh_timeout_config,
            aws_config,
            in_task,
        );

        // Register the default tunnel on the context.
        match &self.default_tunnel {
            Tunnel::Direct => {
                // By default, don't offer a default override for broker address lookup.
            }
            Tunnel::AwsPrivatelink(pl) => {
                context.set_default_tunnel(TunnelConfig::StaticHost(vpc_endpoint_host(
                    pl.connection_id,
                    None, // Default tunnel does not support availability zones.
                )));
            }
            Tunnel::Ssh(ssh_tunnel) => {
                let secret = storage_configuration
                    .connection_context
                    .secrets_reader
                    .read_in_task_if(in_task, ssh_tunnel.connection_id)
                    .await?;
                let key_pair = SshKeyPair::from_bytes(&secret)?;

                // Ensure any ssh-bastion address we connect to is resolved to an external address.
                let resolved = resolve_address(
                    &ssh_tunnel.connection.host,
                    ENFORCE_EXTERNAL_ADDRESSES.get(storage_configuration.config_set()),
                )
                .await?;
                context.set_default_tunnel(TunnelConfig::Ssh(SshTunnelConfig {
                    host: resolved
                        .iter()
                        .map(|a| a.to_string())
                        .collect::<BTreeSet<_>>(),
                    port: ssh_tunnel.connection.port,
                    user: ssh_tunnel.connection.user.clone(),
                    key_pair,
                }));
            }
        }

        // Register per-broker tunnels. Broker addresses are `host[:port]`; the
        // port defaults to 9092 when it is omitted.
        for broker in &self.brokers {
            let mut addr_parts = broker.address.splitn(2, ':');
            let addr = BrokerAddr {
                host: addr_parts
                    .next()
                    .context("BROKER is not address:port")?
                    .into(),
                port: addr_parts
                    .next()
                    .unwrap_or("9092")
                    .parse()
                    .context("parsing BROKER port")?,
            };
            match &broker.tunnel {
                Tunnel::Direct => {
                    // By default, don't override broker address lookup.
                    //
                    // N.B.
                    //
                    // We _could_ pre-setup the default ssh tunnel for all known brokers here, but
                    // we avoid doing so because:
                    // - It's not necessary.
                    // - Not doing so makes it easier to test the `FailedDefaultSshTunnel` path
                    // in the `TunnelingClientContext`.
                }
                Tunnel::AwsPrivatelink(aws_privatelink) => {
                    let host = mz_cloud_resources::vpc_endpoint_host(
                        aws_privatelink.connection_id,
                        aws_privatelink.availability_zone.as_deref(),
                    );
                    let port = aws_privatelink.port;
                    context.add_broker_rewrite(
                        addr,
                        BrokerRewrite {
                            host: host.clone(),
                            port,
                        },
                    );
                }
                Tunnel::Ssh(ssh_tunnel) => {
                    // Ensure any SSH bastion address we connect to is resolved to an external address.
                    let ssh_host_resolved = resolve_address(
                        &ssh_tunnel.connection.host,
                        ENFORCE_EXTERNAL_ADDRESSES.get(storage_configuration.config_set()),
                    )
                    .await?;
                    context
                        .add_ssh_tunnel(
                            addr,
                            SshTunnelConfig {
                                host: ssh_host_resolved
                                    .iter()
                                    .map(|a| a.to_string())
                                    .collect::<BTreeSet<_>>(),
                                port: ssh_tunnel.connection.port,
                                user: ssh_tunnel.connection.user.clone(),
                                key_pair: SshKeyPair::from_bytes(
                                    &storage_configuration
                                        .connection_context
                                        .secrets_reader
                                        .read_in_task_if(in_task, ssh_tunnel.connection_id)
                                        .await?,
                                )?,
                            },
                        )
                        .await
                        .map_err(ContextCreationError::Ssh)?;
                }
            }
        }

        Ok(config.create_with_context(context)?)
    }
1091
1092    async fn validate(
1093        &self,
1094        _id: CatalogItemId,
1095        storage_configuration: &StorageConfiguration,
1096    ) -> Result<(), anyhow::Error> {
1097        let (context, error_rx) = MzClientContext::with_errors();
1098        let consumer: BaseConsumer<_> = self
1099            .create_with_context(
1100                storage_configuration,
1101                context,
1102                &BTreeMap::new(),
1103                // We are in a normal tokio context during validation, already.
1104                InTask::No,
1105            )
1106            .await?;
1107        let consumer = Arc::new(consumer);
1108
1109        let timeout = storage_configuration
1110            .parameters
1111            .kafka_timeout_config
1112            .fetch_metadata_timeout;
1113
1114        // librdkafka doesn't expose an API for determining whether a connection to
1115        // the Kafka cluster has been successfully established. So we make a
1116        // metadata request, though we don't care about the results, so that we can
1117        // report any errors making that request. If the request succeeds, we know
1118        // we were able to contact at least one broker, and that's a good proxy for
1119        // being able to contact all the brokers in the cluster.
1120        //
1121        // The downside of this approach is it produces a generic error message like
1122        // "metadata fetch error" with no additional details. The real networking
1123        // error is buried in the librdkafka logs, which are not visible to users.
1124        let result = mz_ore::task::spawn_blocking(|| "kafka_get_metadata", {
1125            let consumer = Arc::clone(&consumer);
1126            move || consumer.fetch_metadata(None, timeout)
1127        })
1128        .await;
1129        match result {
1130            Ok(_) => Ok(()),
1131            // The error returned by `fetch_metadata` does not provide any details which makes for
1132            // a crappy user facing error message. For this reason we attempt to grab a better
1133            // error message from the client context, which should contain any error logs emitted
1134            // by librdkafka, and fallback to the generic error if there is nothing there.
1135            Err(err) => {
1136                // Multiple errors might have been logged during this validation but some are more
1137                // relevant than others. Specifically, we prefer non-internal errors over internal
1138                // errors since those give much more useful information to the users.
1139                let main_err = error_rx.try_iter().reduce(|cur, new| match cur {
1140                    MzKafkaError::Internal(_) => new,
1141                    _ => cur,
1142                });
1143
1144                // Don't drop the consumer until after we've drained the errors
1145                // channel. Dropping the consumer can introduce spurious errors.
1146                // See database-issues#7432.
1147                drop(consumer);
1148
1149                match main_err {
1150                    Some(err) => Err(err.into()),
1151                    None => Err(err.into()),
1152                }
1153            }
1154        }
1155    }
1156}
1157
1158impl<C: ConnectionAccess> AlterCompatible for KafkaConnection<C> {
1159    fn alter_compatible(&self, id: GlobalId, other: &Self) -> Result<(), AlterError> {
1160        let KafkaConnection {
1161            brokers: _,
1162            default_tunnel: _,
1163            progress_topic,
1164            progress_topic_options,
1165            options: _,
1166            tls: _,
1167            sasl: _,
1168        } = self;
1169
1170        let compatibility_checks = [
1171            (progress_topic == &other.progress_topic, "progress_topic"),
1172            (
1173                progress_topic_options == &other.progress_topic_options,
1174                "progress_topic_options",
1175            ),
1176        ];
1177
1178        for (compatible, field) in compatibility_checks {
1179            if !compatible {
1180                tracing::warn!(
1181                    "KafkaConnection incompatible at {field}:\nself:\n{:#?}\n\nother\n{:#?}",
1182                    self,
1183                    other
1184                );
1185
1186                return Err(AlterError { id });
1187            }
1188        }
1189
1190        Ok(())
1191    }
1192}
1193
/// A connection to a Confluent Schema Registry.
///
/// The type parameter `C` is passed through to the [`Tunnel`] stored in
/// `tunnel` and defaults to [`InlinedConnection`].
#[derive(Clone, Debug, Eq, PartialEq, Hash, Serialize, Deserialize)]
pub struct CsrConnection<C: ConnectionAccess = InlinedConnection> {
    /// The URL of the schema registry.
    pub url: Url,
    /// Trusted root TLS certificate in PEM format.
    pub tls_root_cert: Option<StringOrSecret>,
    /// An optional TLS client certificate for authentication with the schema
    /// registry.
    pub tls_identity: Option<TlsIdentity>,
    /// Optional HTTP authentication credentials for the schema registry.
    pub http_auth: Option<CsrConnectionHttpAuth>,
    /// A tunnel through which to route traffic.
    pub tunnel: Tunnel<C>,
}
1209
1210impl<R: ConnectionResolver> IntoInlineConnection<CsrConnection, R>
1211    for CsrConnection<ReferencedConnection>
1212{
1213    fn into_inline_connection(self, r: R) -> CsrConnection {
1214        let CsrConnection {
1215            url,
1216            tls_root_cert,
1217            tls_identity,
1218            http_auth,
1219            tunnel,
1220        } = self;
1221        CsrConnection {
1222            url,
1223            tls_root_cert,
1224            tls_identity,
1225            http_auth,
1226            tunnel: tunnel.into_inline_connection(r),
1227        }
1228    }
1229}
1230
impl<C: ConnectionAccess> CsrConnection<C> {
    /// Always returns `true`.
    ///
    /// NOTE(review): judging by the name, this presumably tells the caller
    /// that schema registry connections are validated by default; confirm
    /// against the call site.
    fn validate_by_default(&self) -> bool {
        true
    }
}
1236
1237impl CsrConnection {
1238    /// Constructs a schema registry client from the connection.
1239    pub async fn connect(
1240        &self,
1241        storage_configuration: &StorageConfiguration,
1242        in_task: InTask,
1243    ) -> Result<mz_ccsr::Client, CsrConnectError> {
1244        let mut client_config = mz_ccsr::ClientConfig::new(self.url.clone());
1245        if let Some(root_cert) = &self.tls_root_cert {
1246            let root_cert = root_cert
1247                .get_string(
1248                    in_task,
1249                    &storage_configuration.connection_context.secrets_reader,
1250                )
1251                .await?;
1252            let root_cert = Certificate::from_pem(root_cert.as_bytes())?;
1253            client_config = client_config.add_root_certificate(root_cert);
1254        }
1255
1256        if let Some(tls_identity) = &self.tls_identity {
1257            let key = &storage_configuration
1258                .connection_context
1259                .secrets_reader
1260                .read_string_in_task_if(in_task, tls_identity.key)
1261                .await?;
1262            let cert = tls_identity
1263                .cert
1264                .get_string(
1265                    in_task,
1266                    &storage_configuration.connection_context.secrets_reader,
1267                )
1268                .await?;
1269            let ident = Identity::from_pem(key.as_bytes(), cert.as_bytes())?;
1270            client_config = client_config.identity(ident);
1271        }
1272
1273        if let Some(http_auth) = &self.http_auth {
1274            let username = http_auth
1275                .username
1276                .get_string(
1277                    in_task,
1278                    &storage_configuration.connection_context.secrets_reader,
1279                )
1280                .await?;
1281            let password = match http_auth.password {
1282                None => None,
1283                Some(password) => Some(
1284                    storage_configuration
1285                        .connection_context
1286                        .secrets_reader
1287                        .read_string_in_task_if(in_task, password)
1288                        .await?,
1289                ),
1290            };
1291            client_config = client_config.auth(username, password);
1292        }
1293
1294        // TODO: use types to enforce that the URL has a string hostname.
1295        let host = self
1296            .url
1297            .host_str()
1298            .ok_or_else(|| anyhow!("url missing host"))?;
1299        match &self.tunnel {
1300            Tunnel::Direct => {
1301                // Ensure any host we connect to is resolved to an external address.
1302                let resolved = resolve_address(
1303                    host,
1304                    ENFORCE_EXTERNAL_ADDRESSES.get(storage_configuration.config_set()),
1305                )
1306                .await?;
1307                client_config = client_config.resolve_to_addrs(
1308                    host,
1309                    &resolved
1310                        .iter()
1311                        .map(|addr| SocketAddr::new(*addr, 0))
1312                        .collect::<Vec<_>>(),
1313                )
1314            }
1315            Tunnel::Ssh(ssh_tunnel) => {
1316                let ssh_tunnel = ssh_tunnel
1317                    .connect(
1318                        storage_configuration,
1319                        host,
1320                        // Default to the default http port, but this
1321                        // could default to 8081...
1322                        self.url.port().unwrap_or(80),
1323                        in_task,
1324                    )
1325                    .await
1326                    .map_err(CsrConnectError::Ssh)?;
1327
1328                // Carefully inject the SSH tunnel into the client
1329                // configuration. This is delicate because we need TLS
1330                // verification to continue to use the remote hostname rather
1331                // than the tunnel hostname.
1332
1333                client_config = client_config
1334                    // `resolve_to_addrs` allows us to rewrite the hostname
1335                    // at the DNS level, which means the TCP connection is
1336                    // correctly routed through the tunnel, but TLS verification
1337                    // is still performed against the remote hostname.
1338                    // Unfortunately the port here is ignored if the URL also
1339                    // specifies a port...
1340                    .resolve_to_addrs(host, &[SocketAddr::new(ssh_tunnel.local_addr().ip(), 0)])
1341                    // ...so we also dynamically rewrite the URL to use the
1342                    // current port for the SSH tunnel.
1343                    //
1344                    // WARNING: this is brittle, because we only dynamically
1345                    // update the client configuration with the tunnel *port*,
1346                    // and not the hostname This works fine in practice, because
1347                    // only the SSH tunnel port will change if the tunnel fails
1348                    // and has to be restarted (the hostname is always
1349                    // 127.0.0.1)--but this is an an implementation detail of
1350                    // the SSH tunnel code that we're relying on.
1351                    .dynamic_url({
1352                        let remote_url = self.url.clone();
1353                        move || {
1354                            let mut url = remote_url.clone();
1355                            url.set_port(Some(ssh_tunnel.local_addr().port()))
1356                                .expect("cannot fail");
1357                            url
1358                        }
1359                    });
1360            }
1361            Tunnel::AwsPrivatelink(connection) => {
1362                assert_none!(connection.port);
1363
1364                let privatelink_host = mz_cloud_resources::vpc_endpoint_host(
1365                    connection.connection_id,
1366                    connection.availability_zone.as_deref(),
1367                );
1368                let addrs: Vec<_> = net::lookup_host((privatelink_host, 0))
1369                    .await
1370                    .context("resolving PrivateLink host")?
1371                    .collect();
1372                client_config = client_config.resolve_to_addrs(host, &addrs)
1373            }
1374        }
1375
1376        Ok(client_config.build()?)
1377    }
1378
1379    async fn validate(
1380        &self,
1381        _id: CatalogItemId,
1382        storage_configuration: &StorageConfiguration,
1383    ) -> Result<(), anyhow::Error> {
1384        let client = self
1385            .connect(
1386                storage_configuration,
1387                // We are in a normal tokio context during validation, already.
1388                InTask::No,
1389            )
1390            .await?;
1391        client.list_subjects().await?;
1392        Ok(())
1393    }
1394}
1395
1396impl<C: ConnectionAccess> AlterCompatible for CsrConnection<C> {
1397    fn alter_compatible(&self, id: GlobalId, other: &Self) -> Result<(), AlterError> {
1398        let CsrConnection {
1399            tunnel,
1400            // All non-tunnel fields may change
1401            url: _,
1402            tls_root_cert: _,
1403            tls_identity: _,
1404            http_auth: _,
1405        } = self;
1406
1407        let compatibility_checks = [(tunnel.alter_compatible(id, &other.tunnel).is_ok(), "tunnel")];
1408
1409        for (compatible, field) in compatibility_checks {
1410            if !compatible {
1411                tracing::warn!(
1412                    "CsrConnection incompatible at {field}:\nself:\n{:#?}\n\nother\n{:#?}",
1413                    self,
1414                    other
1415                );
1416
1417                return Err(AlterError { id });
1418            }
1419        }
1420        Ok(())
1421    }
1422}
1423
/// A TLS key pair used for client identity.
///
/// The certificate is stored as a [`StringOrSecret`], whereas the private key
/// is only ever referenced by the ID of the secret that holds it.
#[derive(Clone, Debug, Eq, PartialEq, Hash, Serialize, Deserialize)]
pub struct TlsIdentity {
    /// The client's TLS public certificate in PEM format.
    pub cert: StringOrSecret,
    /// The ID of the secret containing the client's TLS private key in PEM
    /// format.
    pub key: CatalogItemId,
}
1433
/// HTTP authentication credentials in a [`CsrConnection`].
///
/// The username is stored as a [`StringOrSecret`]; the password is optional
/// and, when present, is referenced by the ID of the secret that holds it.
#[derive(Clone, Debug, Eq, PartialEq, Hash, Serialize, Deserialize)]
pub struct CsrConnectionHttpAuth {
    /// The username.
    pub username: StringOrSecret,
    /// The ID of the secret containing the password, if any.
    pub password: Option<CatalogItemId>,
}
1442
/// A connection to a PostgreSQL server.
///
/// The type parameter `C` is passed through to the [`Tunnel`] stored in
/// `tunnel` and defaults to [`InlinedConnection`].
#[derive(Clone, Debug, Eq, PartialEq, Hash, Serialize, Deserialize)]
pub struct PostgresConnection<C: ConnectionAccess = InlinedConnection> {
    /// The hostname of the server.
    pub host: String,
    /// The port of the server.
    pub port: u16,
    /// The name of the database to connect to.
    pub database: String,
    /// The username to authenticate as.
    pub user: StringOrSecret,
    /// The ID of the secret containing the password for authentication, if
    /// any.
    pub password: Option<CatalogItemId>,
    /// A tunnel through which to route traffic.
    pub tunnel: Tunnel<C>,
    /// Whether to use TLS for encryption, authentication, or both.
    pub tls_mode: SslMode,
    /// An optional root TLS certificate in PEM format, to verify the server's
    /// identity.
    pub tls_root_cert: Option<StringOrSecret>,
    /// An optional TLS client certificate for authentication.
    pub tls_identity: Option<TlsIdentity>,
}
1466
1467impl<R: ConnectionResolver> IntoInlineConnection<PostgresConnection, R>
1468    for PostgresConnection<ReferencedConnection>
1469{
1470    fn into_inline_connection(self, r: R) -> PostgresConnection {
1471        let PostgresConnection {
1472            host,
1473            port,
1474            database,
1475            user,
1476            password,
1477            tunnel,
1478            tls_mode,
1479            tls_root_cert,
1480            tls_identity,
1481        } = self;
1482
1483        PostgresConnection {
1484            host,
1485            port,
1486            database,
1487            user,
1488            password,
1489            tunnel: tunnel.into_inline_connection(r),
1490            tls_mode,
1491            tls_root_cert,
1492            tls_identity,
1493        }
1494    }
1495}
1496
impl<C: ConnectionAccess> PostgresConnection<C> {
    /// Always returns `true`.
    ///
    /// NOTE(review): judging by the name, this presumably tells the caller
    /// that PostgreSQL connections are validated by default; confirm against
    /// the call site.
    fn validate_by_default(&self) -> bool {
        true
    }
}
1502
1503impl PostgresConnection<InlinedConnection> {
1504    pub async fn config(
1505        &self,
1506        secrets_reader: &Arc<dyn mz_secrets::SecretsReader>,
1507        storage_configuration: &StorageConfiguration,
1508        in_task: InTask,
1509    ) -> Result<mz_postgres_util::Config, anyhow::Error> {
1510        let params = &storage_configuration.parameters;
1511
1512        let mut config = tokio_postgres::Config::new();
1513        config
1514            .host(&self.host)
1515            .port(self.port)
1516            .dbname(&self.database)
1517            .user(&self.user.get_string(in_task, secrets_reader).await?)
1518            .ssl_mode(self.tls_mode);
1519        if let Some(password) = self.password {
1520            let password = secrets_reader
1521                .read_string_in_task_if(in_task, password)
1522                .await?;
1523            config.password(password);
1524        }
1525        if let Some(tls_root_cert) = &self.tls_root_cert {
1526            let tls_root_cert = tls_root_cert.get_string(in_task, secrets_reader).await?;
1527            config.ssl_root_cert(tls_root_cert.as_bytes());
1528        }
1529        if let Some(tls_identity) = &self.tls_identity {
1530            let cert = tls_identity
1531                .cert
1532                .get_string(in_task, secrets_reader)
1533                .await?;
1534            let key = secrets_reader
1535                .read_string_in_task_if(in_task, tls_identity.key)
1536                .await?;
1537            config.ssl_cert(cert.as_bytes()).ssl_key(key.as_bytes());
1538        }
1539
1540        if let Some(connect_timeout) = params.pg_source_connect_timeout {
1541            config.connect_timeout(connect_timeout);
1542        }
1543        if let Some(keepalives_retries) = params.pg_source_tcp_keepalives_retries {
1544            config.keepalives_retries(keepalives_retries);
1545        }
1546        if let Some(keepalives_idle) = params.pg_source_tcp_keepalives_idle {
1547            config.keepalives_idle(keepalives_idle);
1548        }
1549        if let Some(keepalives_interval) = params.pg_source_tcp_keepalives_interval {
1550            config.keepalives_interval(keepalives_interval);
1551        }
1552        if let Some(tcp_user_timeout) = params.pg_source_tcp_user_timeout {
1553            config.tcp_user_timeout(tcp_user_timeout);
1554        }
1555
1556        let mut options = vec![];
1557        if let Some(wal_sender_timeout) = params.pg_source_wal_sender_timeout {
1558            options.push(format!(
1559                "--wal_sender_timeout={}",
1560                wal_sender_timeout.as_millis()
1561            ));
1562        };
1563        if params.pg_source_tcp_configure_server {
1564            if let Some(keepalives_retries) = params.pg_source_tcp_keepalives_retries {
1565                options.push(format!("--tcp_keepalives_count={}", keepalives_retries));
1566            }
1567            if let Some(keepalives_idle) = params.pg_source_tcp_keepalives_idle {
1568                options.push(format!(
1569                    "--tcp_keepalives_idle={}",
1570                    keepalives_idle.as_secs()
1571                ));
1572            }
1573            if let Some(keepalives_interval) = params.pg_source_tcp_keepalives_interval {
1574                options.push(format!(
1575                    "--tcp_keepalives_interval={}",
1576                    keepalives_interval.as_secs()
1577                ));
1578            }
1579            if let Some(tcp_user_timeout) = params.pg_source_tcp_user_timeout {
1580                options.push(format!(
1581                    "--tcp_user_timeout={}",
1582                    tcp_user_timeout.as_millis()
1583                ));
1584            }
1585        }
1586        config.options(options.join(" ").as_str());
1587
1588        let tunnel = match &self.tunnel {
1589            Tunnel::Direct => {
1590                // Ensure any host we connect to is resolved to an external address.
1591                let resolved = resolve_address(
1592                    &self.host,
1593                    ENFORCE_EXTERNAL_ADDRESSES.get(storage_configuration.config_set()),
1594                )
1595                .await?;
1596                mz_postgres_util::TunnelConfig::Direct {
1597                    resolved_ips: Some(resolved),
1598                }
1599            }
1600            Tunnel::Ssh(SshTunnel {
1601                connection_id,
1602                connection,
1603            }) => {
1604                let secret = secrets_reader
1605                    .read_in_task_if(in_task, *connection_id)
1606                    .await?;
1607                let key_pair = SshKeyPair::from_bytes(&secret)?;
1608                // Ensure any ssh-bastion host we connect to is resolved to an external address.
1609                let resolved = resolve_address(
1610                    &connection.host,
1611                    ENFORCE_EXTERNAL_ADDRESSES.get(storage_configuration.config_set()),
1612                )
1613                .await?;
1614                mz_postgres_util::TunnelConfig::Ssh {
1615                    config: SshTunnelConfig {
1616                        host: resolved
1617                            .iter()
1618                            .map(|a| a.to_string())
1619                            .collect::<BTreeSet<_>>(),
1620                        port: connection.port,
1621                        user: connection.user.clone(),
1622                        key_pair,
1623                    },
1624                }
1625            }
1626            Tunnel::AwsPrivatelink(connection) => {
1627                assert_none!(connection.port);
1628                mz_postgres_util::TunnelConfig::AwsPrivatelink {
1629                    connection_id: connection.connection_id,
1630                }
1631            }
1632        };
1633
1634        Ok(mz_postgres_util::Config::new(
1635            config,
1636            tunnel,
1637            params.ssh_timeout_config,
1638            in_task,
1639        )?)
1640    }
1641
1642    pub async fn validate(
1643        &self,
1644        _id: CatalogItemId,
1645        storage_configuration: &StorageConfiguration,
1646    ) -> Result<mz_postgres_util::Client, anyhow::Error> {
1647        let config = self
1648            .config(
1649                &storage_configuration.connection_context.secrets_reader,
1650                storage_configuration,
1651                // We are in a normal tokio context during validation, already.
1652                InTask::No,
1653            )
1654            .await?;
1655        let client = config
1656            .connect(
1657                "connection validation",
1658                &storage_configuration.connection_context.ssh_tunnel_manager,
1659            )
1660            .await?;
1661
1662        let wal_level = mz_postgres_util::get_wal_level(&client).await?;
1663
1664        if wal_level < mz_postgres_util::replication::WalLevel::Logical {
1665            Err(PostgresConnectionValidationError::InsufficientWalLevel { wal_level })?;
1666        }
1667
1668        let max_wal_senders = mz_postgres_util::get_max_wal_senders(&client).await?;
1669
1670        if max_wal_senders < 1 {
1671            Err(PostgresConnectionValidationError::ReplicationDisabled)?;
1672        }
1673
1674        let available_replication_slots =
1675            mz_postgres_util::available_replication_slots(&client).await?;
1676
1677        // We need 1 replication slot for the snapshots and 1 for the continuing replication
1678        if available_replication_slots < 2 {
1679            Err(
1680                PostgresConnectionValidationError::InsufficientReplicationSlotsAvailable {
1681                    count: 2,
1682                },
1683            )?;
1684        }
1685
1686        Ok(client)
1687    }
1688}
1689
/// Reasons an upstream PostgreSQL server can fail connection validation.
#[derive(Debug, Clone, thiserror::Error)]
pub enum PostgresConnectionValidationError {
    /// Too few replication slots are free on the server. `count` is the number
    /// of slots that were required.
    #[error("PostgreSQL server has insufficient number of replication slots available")]
    InsufficientReplicationSlotsAvailable { count: usize },
    /// The server's `wal_level` is below `logical`; `wal_level` is the level
    /// the server reported.
    #[error("server must have wal_level >= logical, but has {wal_level}")]
    InsufficientWalLevel {
        wal_level: mz_postgres_util::replication::WalLevel,
    },
    /// The server reported a `max_wal_senders` value below 1.
    #[error("replication disabled on server")]
    ReplicationDisabled,
}
1701
1702impl PostgresConnectionValidationError {
1703    pub fn detail(&self) -> Option<String> {
1704        match self {
1705            Self::InsufficientReplicationSlotsAvailable { count } => Some(format!(
1706                "executing this statement requires {} replication slot{}",
1707                count,
1708                if *count == 1 { "" } else { "s" }
1709            )),
1710            _ => None,
1711        }
1712    }
1713
1714    pub fn hint(&self) -> Option<String> {
1715        match self {
1716            Self::InsufficientReplicationSlotsAvailable { .. } => Some(
1717                "you might be able to wait for other sources to finish snapshotting and try again"
1718                    .into(),
1719            ),
1720            Self::ReplicationDisabled => Some("set max_wal_senders to a value > 0".into()),
1721            _ => None,
1722        }
1723    }
1724}
1725
1726impl<C: ConnectionAccess> AlterCompatible for PostgresConnection<C> {
1727    fn alter_compatible(&self, id: GlobalId, other: &Self) -> Result<(), AlterError> {
1728        let PostgresConnection {
1729            tunnel,
1730            // All non-tunnel options may change arbitrarily
1731            host: _,
1732            port: _,
1733            database: _,
1734            user: _,
1735            password: _,
1736            tls_mode: _,
1737            tls_root_cert: _,
1738            tls_identity: _,
1739        } = self;
1740
1741        let compatibility_checks = [(tunnel.alter_compatible(id, &other.tunnel).is_ok(), "tunnel")];
1742
1743        for (compatible, field) in compatibility_checks {
1744            if !compatible {
1745                tracing::warn!(
1746                    "PostgresConnection incompatible at {field}:\nself:\n{:#?}\n\nother\n{:#?}",
1747                    self,
1748                    other
1749                );
1750
1751                return Err(AlterError { id });
1752            }
1753        }
1754        Ok(())
1755    }
1756}
1757
/// Specifies how to tunnel a connection.
///
/// The type parameter `C` selects whether connections referenced by the
/// tunnel are still references or have been inlined.
#[derive(Clone, Debug, Eq, PartialEq, Hash, Serialize, Deserialize)]
pub enum Tunnel<C: ConnectionAccess = InlinedConnection> {
    /// No tunneling.
    Direct,
    /// Via the specified SSH tunnel connection.
    ///
    /// The SSH connection supplies the bastion host, port and user; its key
    /// pair is read from the secrets reader using the tunnel's connection ID.
    Ssh(SshTunnel<C>),
    /// Via the specified AWS PrivateLink connection.
    ///
    /// The PostgreSQL, MySQL and SQL Server connections in this file assert
    /// that the PrivateLink `port` is unset before using it.
    AwsPrivatelink(AwsPrivatelink),
}
1768
1769impl<R: ConnectionResolver> IntoInlineConnection<Tunnel, R> for Tunnel<ReferencedConnection> {
1770    fn into_inline_connection(self, r: R) -> Tunnel {
1771        match self {
1772            Tunnel::Direct => Tunnel::Direct,
1773            Tunnel::Ssh(ssh) => Tunnel::Ssh(ssh.into_inline_connection(r)),
1774            Tunnel::AwsPrivatelink(awspl) => Tunnel::AwsPrivatelink(awspl),
1775        }
1776    }
1777}
1778
1779impl<C: ConnectionAccess> AlterCompatible for Tunnel<C> {
1780    fn alter_compatible(&self, id: GlobalId, other: &Self) -> Result<(), AlterError> {
1781        let compatible = match (self, other) {
1782            (Self::Ssh(s), Self::Ssh(o)) => s.alter_compatible(id, o).is_ok(),
1783            (s, o) => s == o,
1784        };
1785
1786        if !compatible {
1787            tracing::warn!(
1788                "Tunnel incompatible:\nself:\n{:#?}\n\nother\n{:#?}",
1789                self,
1790                other
1791            );
1792
1793            return Err(AlterError { id });
1794        }
1795
1796        Ok(())
1797    }
1798}
1799
/// Specifies which MySQL SSL Mode to use:
/// <https://dev.mysql.com/doc/refman/8.0/en/connection-options.html#option_general_ssl-mode>
/// This is not available as an enum in the mysql-async crate, so we define our own.
#[derive(Clone, Debug, Eq, PartialEq, Hash, Serialize, Deserialize)]
pub enum MySqlSslMode {
    /// No TLS options are configured for the connection.
    Disabled,
    /// TLS is used, but invalid certificates are accepted and domain
    /// validation is skipped.
    Required,
    /// TLS is used with certificate verification, but domain validation is
    /// skipped.
    VerifyCa,
    /// TLS is used with the client library's default options, i.e. nothing is
    /// opted out of.
    VerifyIdentity,
}
1810
/// A connection to a MySQL server.
#[derive(Clone, Debug, Eq, PartialEq, Hash, Serialize, Deserialize)]
pub struct MySqlConnection<C: ConnectionAccess = InlinedConnection> {
    /// The hostname of the server.
    pub host: String,
    /// The port of the server.
    pub port: u16,
    /// The username to authenticate as.
    pub user: StringOrSecret,
    /// An optional password for authentication.
    pub password: Option<CatalogItemId>,
    /// A tunnel through which to route traffic.
    pub tunnel: Tunnel<C>,
    /// Whether to use TLS for encryption, verify the server's certificate, and identity.
    pub tls_mode: MySqlSslMode,
    /// An optional root TLS certificate in PEM format, to verify the server's
    /// identity. Only applied when `tls_mode` is `VerifyCa` or
    /// `VerifyIdentity`.
    pub tls_root_cert: Option<StringOrSecret>,
    /// An optional TLS client certificate for authentication. Has no effect
    /// when `tls_mode` is `Disabled`.
    pub tls_identity: Option<TlsIdentity>,
    /// Reference to the AWS connection information to be used for IAM authentication and
    /// assuming AWS roles.
    pub aws_connection: Option<AwsConnectionReference<C>>,
}
1835
1836impl<R: ConnectionResolver> IntoInlineConnection<MySqlConnection, R>
1837    for MySqlConnection<ReferencedConnection>
1838{
1839    fn into_inline_connection(self, r: R) -> MySqlConnection {
1840        let MySqlConnection {
1841            host,
1842            port,
1843            user,
1844            password,
1845            tunnel,
1846            tls_mode,
1847            tls_root_cert,
1848            tls_identity,
1849            aws_connection,
1850        } = self;
1851
1852        MySqlConnection {
1853            host,
1854            port,
1855            user,
1856            password,
1857            tunnel: tunnel.into_inline_connection(&r),
1858            tls_mode,
1859            tls_root_cert,
1860            tls_identity,
1861            aws_connection: aws_connection.map(|aws| aws.into_inline_connection(&r)),
1862        }
1863    }
1864}
1865
impl<C: ConnectionAccess> MySqlConnection<C> {
    /// Reports whether this kind of connection is validated by default.
    ///
    /// Always `true` for MySQL connections.
    // NOTE(review): presumably consulted when a connection is created without
    // an explicit validation option — confirm against the caller.
    fn validate_by_default(&self) -> bool {
        true
    }
}
1871
1872impl MySqlConnection<InlinedConnection> {
1873    pub async fn config(
1874        &self,
1875        secrets_reader: &Arc<dyn mz_secrets::SecretsReader>,
1876        storage_configuration: &StorageConfiguration,
1877        in_task: InTask,
1878    ) -> Result<mz_mysql_util::Config, anyhow::Error> {
1879        // TODO(roshan): Set appropriate connection timeouts
1880        let mut opts = mysql_async::OptsBuilder::default()
1881            .ip_or_hostname(&self.host)
1882            .tcp_port(self.port)
1883            .user(Some(&self.user.get_string(in_task, secrets_reader).await?));
1884
1885        if let Some(password) = self.password {
1886            let password = secrets_reader
1887                .read_string_in_task_if(in_task, password)
1888                .await?;
1889            opts = opts.pass(Some(password));
1890        }
1891
1892        // Our `MySqlSslMode` enum matches the official MySQL Client `--ssl-mode` parameter values
1893        // which uses opt-in security features (SSL, CA verification, & Identity verification).
1894        // The mysql_async crate `SslOpts` struct uses an opt-out mechanism for each of these, so
1895        // we need to appropriately disable features to match the intent of each enum value.
1896        let mut ssl_opts = match self.tls_mode {
1897            MySqlSslMode::Disabled => None,
1898            MySqlSslMode::Required => Some(
1899                mysql_async::SslOpts::default()
1900                    .with_danger_accept_invalid_certs(true)
1901                    .with_danger_skip_domain_validation(true),
1902            ),
1903            MySqlSslMode::VerifyCa => {
1904                Some(mysql_async::SslOpts::default().with_danger_skip_domain_validation(true))
1905            }
1906            MySqlSslMode::VerifyIdentity => Some(mysql_async::SslOpts::default()),
1907        };
1908
1909        if matches!(
1910            self.tls_mode,
1911            MySqlSslMode::VerifyCa | MySqlSslMode::VerifyIdentity
1912        ) {
1913            if let Some(tls_root_cert) = &self.tls_root_cert {
1914                let tls_root_cert = tls_root_cert.get_string(in_task, secrets_reader).await?;
1915                ssl_opts = ssl_opts.map(|opts| {
1916                    opts.with_root_certs(vec![tls_root_cert.as_bytes().to_vec().into()])
1917                });
1918            }
1919        }
1920
1921        if let Some(identity) = &self.tls_identity {
1922            let key = secrets_reader
1923                .read_string_in_task_if(in_task, identity.key)
1924                .await?;
1925            let cert = identity.cert.get_string(in_task, secrets_reader).await?;
1926            let Pkcs12Archive { der, pass } =
1927                mz_tls_util::pkcs12der_from_pem(key.as_bytes(), cert.as_bytes())?;
1928
1929            // Add client identity to SSLOpts
1930            ssl_opts = ssl_opts.map(|opts| {
1931                opts.with_client_identity(Some(
1932                    mysql_async::ClientIdentity::new(der.into()).with_password(pass),
1933                ))
1934            });
1935        }
1936
1937        opts = opts.ssl_opts(ssl_opts);
1938
1939        let tunnel = match &self.tunnel {
1940            Tunnel::Direct => {
1941                // Ensure any host we connect to is resolved to an external address.
1942                let resolved = resolve_address(
1943                    &self.host,
1944                    ENFORCE_EXTERNAL_ADDRESSES.get(storage_configuration.config_set()),
1945                )
1946                .await?;
1947                mz_mysql_util::TunnelConfig::Direct {
1948                    resolved_ips: Some(resolved),
1949                }
1950            }
1951            Tunnel::Ssh(SshTunnel {
1952                connection_id,
1953                connection,
1954            }) => {
1955                let secret = secrets_reader
1956                    .read_in_task_if(in_task, *connection_id)
1957                    .await?;
1958                let key_pair = SshKeyPair::from_bytes(&secret)?;
1959                // Ensure any ssh-bastion host we connect to is resolved to an external address.
1960                let resolved = resolve_address(
1961                    &connection.host,
1962                    ENFORCE_EXTERNAL_ADDRESSES.get(storage_configuration.config_set()),
1963                )
1964                .await?;
1965                mz_mysql_util::TunnelConfig::Ssh {
1966                    config: SshTunnelConfig {
1967                        host: resolved
1968                            .iter()
1969                            .map(|a| a.to_string())
1970                            .collect::<BTreeSet<_>>(),
1971                        port: connection.port,
1972                        user: connection.user.clone(),
1973                        key_pair,
1974                    },
1975                }
1976            }
1977            Tunnel::AwsPrivatelink(connection) => {
1978                assert_none!(connection.port);
1979                mz_mysql_util::TunnelConfig::AwsPrivatelink {
1980                    connection_id: connection.connection_id,
1981                }
1982            }
1983        };
1984
1985        let aws_config = match self.aws_connection.as_ref() {
1986            None => None,
1987            Some(aws_ref) => Some(
1988                aws_ref
1989                    .connection
1990                    .load_sdk_config(
1991                        &storage_configuration.connection_context,
1992                        aws_ref.connection_id,
1993                        in_task,
1994                    )
1995                    .await?,
1996            ),
1997        };
1998
1999        Ok(mz_mysql_util::Config::new(
2000            opts,
2001            tunnel,
2002            storage_configuration.parameters.ssh_timeout_config,
2003            in_task,
2004            storage_configuration
2005                .parameters
2006                .mysql_source_timeouts
2007                .clone(),
2008            aws_config,
2009        )?)
2010    }
2011
2012    pub async fn validate(
2013        &self,
2014        _id: CatalogItemId,
2015        storage_configuration: &StorageConfiguration,
2016    ) -> Result<MySqlConn, MySqlConnectionValidationError> {
2017        let config = self
2018            .config(
2019                &storage_configuration.connection_context.secrets_reader,
2020                storage_configuration,
2021                // We are in a normal tokio context during validation, already.
2022                InTask::No,
2023            )
2024            .await?;
2025        let mut conn = config
2026            .connect(
2027                "connection validation",
2028                &storage_configuration.connection_context.ssh_tunnel_manager,
2029            )
2030            .await?;
2031
2032        // Check if the MySQL database is configured to allow row-based consistent GTID replication
2033        let mut setting_errors = vec![];
2034        let gtid_res = mz_mysql_util::ensure_gtid_consistency(&mut conn).await;
2035        let binlog_res = mz_mysql_util::ensure_full_row_binlog_format(&mut conn).await;
2036        let order_res = mz_mysql_util::ensure_replication_commit_order(&mut conn).await;
2037        for res in [gtid_res, binlog_res, order_res] {
2038            match res {
2039                Err(MySqlError::InvalidSystemSetting {
2040                    setting,
2041                    expected,
2042                    actual,
2043                }) => {
2044                    setting_errors.push((setting, expected, actual));
2045                }
2046                Err(err) => Err(err)?,
2047                Ok(()) => {}
2048            }
2049        }
2050        if !setting_errors.is_empty() {
2051            Err(MySqlConnectionValidationError::ReplicationSettingsError(
2052                setting_errors,
2053            ))?;
2054        }
2055
2056        Ok(conn)
2057    }
2058}
2059
/// Reasons validating a MySQL connection can fail.
#[derive(Debug, thiserror::Error)]
pub enum MySqlConnectionValidationError {
    /// One or more system settings required for replication have the wrong
    /// value. Each tuple is `(setting, expected, actual)`.
    #[error("Invalid MySQL system replication settings")]
    ReplicationSettingsError(Vec<(String, String, String)>),
    /// A [`MySqlError`], displayed transparently.
    #[error(transparent)]
    Client(#[from] MySqlError),
    /// Any other error, displayed together with its chain of causes.
    #[error("{}", .0.display_with_causes())]
    Other(#[from] anyhow::Error),
}
2069
2070impl MySqlConnectionValidationError {
2071    pub fn detail(&self) -> Option<String> {
2072        match self {
2073            Self::ReplicationSettingsError(settings) => Some(format!(
2074                "Invalid MySQL system replication settings: {}",
2075                itertools::join(
2076                    settings.iter().map(|(setting, expected, actual)| format!(
2077                        "{}: expected {}, got {}",
2078                        setting, expected, actual
2079                    )),
2080                    "; "
2081                )
2082            )),
2083            _ => None,
2084        }
2085    }
2086
2087    pub fn hint(&self) -> Option<String> {
2088        match self {
2089            Self::ReplicationSettingsError(_) => {
2090                Some("Set the necessary MySQL database system settings.".into())
2091            }
2092            _ => None,
2093        }
2094    }
2095}
2096
2097impl<C: ConnectionAccess> AlterCompatible for MySqlConnection<C> {
2098    fn alter_compatible(&self, id: GlobalId, other: &Self) -> Result<(), AlterError> {
2099        let MySqlConnection {
2100            tunnel,
2101            // All non-tunnel options may change arbitrarily
2102            host: _,
2103            port: _,
2104            user: _,
2105            password: _,
2106            tls_mode: _,
2107            tls_root_cert: _,
2108            tls_identity: _,
2109            aws_connection: _,
2110        } = self;
2111
2112        let compatibility_checks = [(tunnel.alter_compatible(id, &other.tunnel).is_ok(), "tunnel")];
2113
2114        for (compatible, field) in compatibility_checks {
2115            if !compatible {
2116                tracing::warn!(
2117                    "MySqlConnection incompatible at {field}:\nself:\n{:#?}\n\nother\n{:#?}",
2118                    self,
2119                    other
2120                );
2121
2122                return Err(AlterError { id });
2123            }
2124        }
2125        Ok(())
2126    }
2127}
2128
/// Details how to connect to an instance of Microsoft SQL Server.
///
/// For specifics of connecting to SQL Server for purposes of creating a
/// Materialize Source, see [`SqlServerSourceConnection`] which wraps this type.
///
/// [`SqlServerSourceConnection`]: crate::sources::SqlServerSourceConnection
#[derive(Clone, Debug, Eq, PartialEq, Hash, Serialize, Deserialize)]
pub struct SqlServerConnectionDetails<C: ConnectionAccess = InlinedConnection> {
    /// The hostname of the server.
    pub host: String,
    /// The port of the server.
    pub port: u16,
    /// Database we should connect to.
    pub database: String,
    /// The username to authenticate as.
    pub user: StringOrSecret,
    /// Password used for authentication.
    pub password: CatalogItemId,
    /// A tunnel through which to route traffic.
    pub tunnel: Tunnel<C>,
    /// Level of encryption to use for the connection.
    pub encryption: mz_sql_server_util::config::EncryptionLevel,
    /// Certificate validation policy.
    pub certificate_validation_policy: mz_sql_server_util::config::CertificateValidationPolicy,
    /// TLS CA certificate in PEM format. Read when the certificate validation
    /// policy is `VerifyCA`.
    pub tls_root_cert: Option<StringOrSecret>,
}
2156
impl<C: ConnectionAccess> SqlServerConnectionDetails<C> {
    /// Reports whether this kind of connection is validated by default.
    ///
    /// Always `true` for SQL Server connections.
    // NOTE(review): presumably consulted when a connection is created without
    // an explicit validation option — confirm against the caller.
    fn validate_by_default(&self) -> bool {
        true
    }
}
2162
2163impl SqlServerConnectionDetails<InlinedConnection> {
2164    /// Attempts to open a connection to the upstream SQL Server instance.
2165    pub async fn validate(
2166        &self,
2167        _id: CatalogItemId,
2168        storage_configuration: &StorageConfiguration,
2169    ) -> Result<mz_sql_server_util::Client, anyhow::Error> {
2170        let config = self
2171            .resolve_config(
2172                &storage_configuration.connection_context.secrets_reader,
2173                storage_configuration,
2174                InTask::No,
2175            )
2176            .await?;
2177        tracing::debug!(?config, "Validating SQL Server connection");
2178
2179        let mut client = mz_sql_server_util::Client::connect(config).await?;
2180
2181        // Ensure the upstream SQL Server instance is configured to allow CDC.
2182        //
2183        // Run all of the checks necessary and collect the errors to provide the best
2184        // guidance as to which system settings need to be enabled.
2185        let mut replication_errors = vec![];
2186        for error in [
2187            mz_sql_server_util::inspect::ensure_database_cdc_enabled(&mut client).await,
2188            mz_sql_server_util::inspect::ensure_snapshot_isolation_enabled(&mut client).await,
2189            mz_sql_server_util::inspect::ensure_sql_server_agent_running(&mut client).await,
2190        ] {
2191            match error {
2192                Err(mz_sql_server_util::SqlServerError::InvalidSystemSetting {
2193                    name,
2194                    expected,
2195                    actual,
2196                }) => replication_errors.push((name, expected, actual)),
2197                Err(other) => Err(other)?,
2198                Ok(()) => (),
2199            }
2200        }
2201        if !replication_errors.is_empty() {
2202            Err(SqlServerConnectionValidationError::ReplicationSettingsError(replication_errors))?;
2203        }
2204
2205        Ok(client)
2206    }
2207
2208    /// Resolve all of the connection details (e.g. read from the [`SecretsReader`])
2209    /// so the returned [`Config`] can be used to open a connection with the
2210    /// upstream system.
2211    ///
2212    /// The provided [`InTask`] argument determines whether any I/O is run in an
2213    /// [`mz_ore::task`] (i.e. a different thread) or directly in the returned
2214    /// future. The main goal here is to prevent running I/O in timely threads.
2215    ///
2216    /// [`Config`]: mz_sql_server_util::Config
2217    pub async fn resolve_config(
2218        &self,
2219        secrets_reader: &Arc<dyn mz_secrets::SecretsReader>,
2220        storage_configuration: &StorageConfiguration,
2221        in_task: InTask,
2222    ) -> Result<mz_sql_server_util::Config, anyhow::Error> {
2223        let dyncfg = storage_configuration.config_set();
2224        let mut inner_config = tiberius::Config::new();
2225
2226        // Setup default connection params.
2227        inner_config.host(&self.host);
2228        inner_config.port(self.port);
2229        inner_config.database(self.database.clone());
2230        inner_config.encryption(self.encryption.into());
2231        match self.certificate_validation_policy {
2232            mz_sql_server_util::config::CertificateValidationPolicy::TrustAll => {
2233                inner_config.trust_cert()
2234            }
2235            mz_sql_server_util::config::CertificateValidationPolicy::VerifyCA => {
2236                inner_config.trust_cert_ca_pem(
2237                    self.tls_root_cert
2238                        .as_ref()
2239                        .unwrap()
2240                        .get_string(in_task, secrets_reader)
2241                        .await
2242                        .context("ca certificate")?,
2243                );
2244            }
2245            mz_sql_server_util::config::CertificateValidationPolicy::VerifySystem => (), // no-op
2246        }
2247
2248        inner_config.application_name("materialize");
2249
2250        // Read our auth settings from
2251        let user = self
2252            .user
2253            .get_string(in_task, secrets_reader)
2254            .await
2255            .context("username")?;
2256        let password = secrets_reader
2257            .read_string_in_task_if(in_task, self.password)
2258            .await
2259            .context("password")?;
2260        // TODO(sql_server3): Support other methods of authentication besides
2261        // username and password.
2262        inner_config.authentication(tiberius::AuthMethod::sql_server(user, password));
2263
2264        // Prevent users from probing our internal network ports by trying to
2265        // connect to localhost, or another non-external IP.
2266        let enfoce_external_addresses = ENFORCE_EXTERNAL_ADDRESSES.get(dyncfg);
2267
2268        let tunnel = match &self.tunnel {
2269            Tunnel::Direct => mz_sql_server_util::config::TunnelConfig::Direct,
2270            Tunnel::Ssh(SshTunnel {
2271                connection_id,
2272                connection: ssh_connection,
2273            }) => {
2274                let secret = secrets_reader
2275                    .read_in_task_if(in_task, *connection_id)
2276                    .await
2277                    .context("ssh secret")?;
2278                let key_pair = SshKeyPair::from_bytes(&secret).context("ssh key pair")?;
2279                // Ensure any SSH-bastion host we connect to is resolved to an
2280                // external address.
2281                let addresses = resolve_address(&ssh_connection.host, enfoce_external_addresses)
2282                    .await
2283                    .context("ssh tunnel")?;
2284
2285                let config = SshTunnelConfig {
2286                    host: addresses.into_iter().map(|a| a.to_string()).collect(),
2287                    port: ssh_connection.port,
2288                    user: ssh_connection.user.clone(),
2289                    key_pair,
2290                };
2291                mz_sql_server_util::config::TunnelConfig::Ssh {
2292                    config,
2293                    manager: storage_configuration
2294                        .connection_context
2295                        .ssh_tunnel_manager
2296                        .clone(),
2297                    timeout: storage_configuration.parameters.ssh_timeout_config.clone(),
2298                    host: self.host.clone(),
2299                    port: self.port,
2300                }
2301            }
2302            Tunnel::AwsPrivatelink(private_link_connection) => {
2303                assert_none!(private_link_connection.port);
2304                mz_sql_server_util::config::TunnelConfig::AwsPrivatelink {
2305                    connection_id: private_link_connection.connection_id,
2306                    port: self.port,
2307                }
2308            }
2309        };
2310
2311        Ok(mz_sql_server_util::Config::new(
2312            inner_config,
2313            tunnel,
2314            in_task,
2315        ))
2316    }
2317}
2318
/// Reasons validating a SQL Server connection can fail.
#[derive(Debug, Clone, thiserror::Error)]
pub enum SqlServerConnectionValidationError {
    /// One or more system settings required for replication have the wrong
    /// value. Each tuple is `(name, expected, actual)`.
    #[error("Invalid SQL Server system replication settings")]
    ReplicationSettingsError(Vec<(String, String, String)>),
}
2324
2325impl SqlServerConnectionValidationError {
2326    pub fn detail(&self) -> Option<String> {
2327        match self {
2328            Self::ReplicationSettingsError(settings) => Some(format!(
2329                "Invalid SQL Server system replication settings: {}",
2330                itertools::join(
2331                    settings.iter().map(|(setting, expected, actual)| format!(
2332                        "{}: expected {}, got {}",
2333                        setting, expected, actual
2334                    )),
2335                    "; "
2336                )
2337            )),
2338        }
2339    }
2340
2341    pub fn hint(&self) -> Option<String> {
2342        match self {
2343            _ => None,
2344        }
2345    }
2346}
2347
2348impl<R: ConnectionResolver> IntoInlineConnection<SqlServerConnectionDetails, R>
2349    for SqlServerConnectionDetails<ReferencedConnection>
2350{
2351    fn into_inline_connection(self, r: R) -> SqlServerConnectionDetails {
2352        let SqlServerConnectionDetails {
2353            host,
2354            port,
2355            database,
2356            user,
2357            password,
2358            tunnel,
2359            encryption,
2360            certificate_validation_policy,
2361            tls_root_cert,
2362        } = self;
2363
2364        SqlServerConnectionDetails {
2365            host,
2366            port,
2367            database,
2368            user,
2369            password,
2370            tunnel: tunnel.into_inline_connection(&r),
2371            encryption,
2372            certificate_validation_policy,
2373            tls_root_cert,
2374        }
2375    }
2376}
2377
2378impl<C: ConnectionAccess> AlterCompatible for SqlServerConnectionDetails<C> {
2379    fn alter_compatible(
2380        &self,
2381        id: mz_repr::GlobalId,
2382        other: &Self,
2383    ) -> Result<(), crate::controller::AlterError> {
2384        let SqlServerConnectionDetails {
2385            tunnel,
2386            // TODO(sql_server2): Figure out how these variables are allowed to change.
2387            host: _,
2388            port: _,
2389            database: _,
2390            user: _,
2391            password: _,
2392            encryption: _,
2393            certificate_validation_policy: _,
2394            tls_root_cert: _,
2395        } = self;
2396
2397        let compatibility_checks = [(tunnel.alter_compatible(id, &other.tunnel).is_ok(), "tunnel")];
2398
2399        for (compatible, field) in compatibility_checks {
2400            if !compatible {
2401                tracing::warn!(
2402                    "SqlServerConnectionDetails incompatible at {field}:\nself:\n{:#?}\n\nother\n{:#?}",
2403                    self,
2404                    other
2405                );
2406
2407                return Err(AlterError { id });
2408            }
2409        }
2410        Ok(())
2411    }
2412}
2413
/// A connection to an SSH tunnel.
#[derive(Clone, Debug, Eq, PartialEq, Hash, Serialize, Deserialize)]
pub struct SshConnection {
    /// Host of the SSH bastion. It is passed through `resolve_address` before
    /// any connection attempt is made.
    pub host: String,
    /// Port handed to [`SshTunnelConfig`] when connecting to the bastion.
    pub port: u16,
    /// User name handed to [`SshTunnelConfig`] when connecting to the bastion.
    pub user: String,
}
2421
2422use self::inline::{
2423    ConnectionAccess, ConnectionResolver, InlinedConnection, IntoInlineConnection,
2424    ReferencedConnection,
2425};
2426
2427impl AlterCompatible for SshConnection {
2428    fn alter_compatible(&self, _id: GlobalId, _other: &Self) -> Result<(), AlterError> {
2429        // Every element of the SSH connection is configurable.
2430        Ok(())
2431    }
2432}
2433
/// Specifies an AWS PrivateLink service for a [`Tunnel`].
#[derive(Clone, Debug, Eq, PartialEq, Hash, Serialize, Deserialize)]
pub struct AwsPrivatelink {
    /// The ID of the connection to the AWS PrivateLink service.
    pub connection_id: CatalogItemId,
    /// The availability zone to use when connecting to the AWS PrivateLink
    /// service, if one is specified.
    pub availability_zone: Option<String>,
    /// The port to use when connecting to the AWS PrivateLink service, if
    /// different from the port in [`KafkaBroker::address`].
    pub port: Option<u16>,
}
2445
2446impl AlterCompatible for AwsPrivatelink {
2447    fn alter_compatible(&self, id: GlobalId, other: &Self) -> Result<(), AlterError> {
2448        let AwsPrivatelink {
2449            connection_id,
2450            availability_zone: _,
2451            port: _,
2452        } = self;
2453
2454        let compatibility_checks = [(connection_id == &other.connection_id, "connection_id")];
2455
2456        for (compatible, field) in compatibility_checks {
2457            if !compatible {
2458                tracing::warn!(
2459                    "AwsPrivatelink incompatible at {field}:\nself:\n{:#?}\n\nother\n{:#?}",
2460                    self,
2461                    other
2462                );
2463
2464                return Err(AlterError { id });
2465            }
2466        }
2467
2468        Ok(())
2469    }
2470}
2471
/// Specifies an SSH tunnel connection.
///
/// `C` selects how the underlying SSH connection is represented (`C::Ssh`);
/// it defaults to [`InlinedConnection`].
#[derive(Clone, Debug, Eq, PartialEq, Hash, Serialize, Deserialize)]
pub struct SshTunnel<C: ConnectionAccess = InlinedConnection> {
    /// ID of the SSH connection. When connecting, the SSH key pair is read
    /// from the secrets reader under this same ID.
    pub connection_id: CatalogItemId,
    /// The SSH connection object, in the representation chosen by `C`.
    pub connection: C::Ssh,
}
2480
2481impl<R: ConnectionResolver> IntoInlineConnection<SshTunnel, R> for SshTunnel<ReferencedConnection> {
2482    fn into_inline_connection(self, r: R) -> SshTunnel {
2483        let SshTunnel {
2484            connection,
2485            connection_id,
2486        } = self;
2487
2488        SshTunnel {
2489            connection: r.resolve_connection(connection).unwrap_ssh(),
2490            connection_id,
2491        }
2492    }
2493}
2494
2495impl SshTunnel<InlinedConnection> {
2496    /// Like [`SshTunnelConfig::connect`], but the SSH key is loaded from a
2497    /// secret.
2498    async fn connect(
2499        &self,
2500        storage_configuration: &StorageConfiguration,
2501        remote_host: &str,
2502        remote_port: u16,
2503        in_task: InTask,
2504    ) -> Result<ManagedSshTunnelHandle, anyhow::Error> {
2505        // Ensure any ssh-bastion host we connect to is resolved to an external address.
2506        let resolved = resolve_address(
2507            &self.connection.host,
2508            ENFORCE_EXTERNAL_ADDRESSES.get(storage_configuration.config_set()),
2509        )
2510        .await?;
2511        storage_configuration
2512            .connection_context
2513            .ssh_tunnel_manager
2514            .connect(
2515                SshTunnelConfig {
2516                    host: resolved
2517                        .iter()
2518                        .map(|a| a.to_string())
2519                        .collect::<BTreeSet<_>>(),
2520                    port: self.connection.port,
2521                    user: self.connection.user.clone(),
2522                    key_pair: SshKeyPair::from_bytes(
2523                        &storage_configuration
2524                            .connection_context
2525                            .secrets_reader
2526                            .read_in_task_if(in_task, self.connection_id)
2527                            .await?,
2528                    )?,
2529                },
2530                remote_host,
2531                remote_port,
2532                storage_configuration.parameters.ssh_timeout_config,
2533                in_task,
2534            )
2535            .await
2536    }
2537}
2538
2539impl<C: ConnectionAccess> AlterCompatible for SshTunnel<C> {
2540    fn alter_compatible(&self, id: GlobalId, other: &Self) -> Result<(), AlterError> {
2541        let SshTunnel {
2542            connection_id,
2543            connection,
2544        } = self;
2545
2546        let compatibility_checks = [
2547            (connection_id == &other.connection_id, "connection_id"),
2548            (
2549                connection.alter_compatible(id, &other.connection).is_ok(),
2550                "connection",
2551            ),
2552        ];
2553
2554        for (compatible, field) in compatibility_checks {
2555            if !compatible {
2556                tracing::warn!(
2557                    "SshTunnel incompatible at {field}:\nself:\n{:#?}\n\nother\n{:#?}",
2558                    self,
2559                    other
2560                );
2561
2562                return Err(AlterError { id });
2563            }
2564        }
2565
2566        Ok(())
2567    }
2568}
2569
2570impl SshConnection {
2571    #[allow(clippy::unused_async)]
2572    async fn validate(
2573        &self,
2574        id: CatalogItemId,
2575        storage_configuration: &StorageConfiguration,
2576    ) -> Result<(), anyhow::Error> {
2577        let secret = storage_configuration
2578            .connection_context
2579            .secrets_reader
2580            .read_in_task_if(
2581                // We are in a normal tokio context during validation, already.
2582                InTask::No,
2583                id,
2584            )
2585            .await?;
2586        let key_pair = SshKeyPair::from_bytes(&secret)?;
2587
2588        // Ensure any ssh-bastion host we connect to is resolved to an external address.
2589        let resolved = resolve_address(
2590            &self.host,
2591            ENFORCE_EXTERNAL_ADDRESSES.get(storage_configuration.config_set()),
2592        )
2593        .await?;
2594
2595        let config = SshTunnelConfig {
2596            host: resolved
2597                .iter()
2598                .map(|a| a.to_string())
2599                .collect::<BTreeSet<_>>(),
2600            port: self.port,
2601            user: self.user.clone(),
2602            key_pair,
2603        };
2604        // Note that we do NOT use the `SshTunnelManager` here, as we want to validate that we
2605        // can actually create a new connection to the ssh bastion, without tunneling.
2606        config
2607            .validate(storage_configuration.parameters.ssh_timeout_config)
2608            .await
2609    }
2610
2611    fn validate_by_default(&self) -> bool {
2612        false
2613    }
2614}
2615
2616impl AwsPrivatelinkConnection {
2617    #[allow(clippy::unused_async)]
2618    async fn validate(
2619        &self,
2620        id: CatalogItemId,
2621        storage_configuration: &StorageConfiguration,
2622    ) -> Result<(), anyhow::Error> {
2623        let Some(ref cloud_resource_reader) = storage_configuration
2624            .connection_context
2625            .cloud_resource_reader
2626        else {
2627            return Err(anyhow!("AWS PrivateLink connections are unsupported"));
2628        };
2629
2630        // No need to optionally run this in a task, as we are just validating from envd.
2631        let status = cloud_resource_reader.read(id).await?;
2632
2633        let availability = status
2634            .conditions
2635            .as_ref()
2636            .and_then(|conditions| conditions.iter().find(|c| c.type_ == "Available"));
2637
2638        match availability {
2639            Some(condition) if condition.status == "True" => Ok(()),
2640            Some(condition) => Err(anyhow!("{}", condition.message)),
2641            None => Err(anyhow!("Endpoint availability is unknown")),
2642        }
2643    }
2644
2645    fn validate_by_default(&self) -> bool {
2646        false
2647    }
2648}