// mz_storage_types/connections.rs
1// Copyright Materialize, Inc. and contributors. All rights reserved.
2//
3// Use of this software is governed by the Business Source License
4// included in the LICENSE file.
5//
6// As of the Change Date specified in that file, in accordance with
7// the Business Source License, use of this software will be governed
8// by the Apache License, Version 2.0.
9
10//! Connection types.
11
12use std::borrow::Cow;
13use std::collections::{BTreeMap, BTreeSet};
14use std::net::SocketAddr;
15use std::sync::Arc;
16
17use anyhow::{Context, anyhow};
18use async_trait::async_trait;
19use aws_credential_types::provider::ProvideCredentials;
20use iceberg::Catalog;
21use iceberg::CatalogBuilder;
22use iceberg::io::{
23    AwsCredential, AwsCredentialLoad, CustomAwsCredentialLoader, S3_ACCESS_KEY_ID,
24    S3_DISABLE_EC2_METADATA, S3_REGION, S3_SECRET_ACCESS_KEY,
25};
26use iceberg_catalog_rest::{
27    REST_CATALOG_PROP_URI, REST_CATALOG_PROP_WAREHOUSE, RestCatalogBuilder,
28};
29use itertools::Itertools;
30use mz_ccsr::tls::{Certificate, Identity};
31use mz_cloud_resources::{AwsExternalIdPrefix, CloudResourceReader, vpc_endpoint_host};
32use mz_dyncfg::ConfigSet;
33use mz_kafka_util::client::{
34    BrokerAddr, BrokerRewrite, HostMappingRules, MzClientContext, MzKafkaError, TunnelConfig,
35    TunnelingClientContext,
36};
37use mz_mysql_util::{MySqlConn, MySqlError};
38use mz_ore::assert_none;
39use mz_ore::error::ErrorExt;
40use mz_ore::future::{InTask, OreFutureExt};
41use mz_ore::netio::resolve_address;
42use mz_ore::num::NonNeg;
43use mz_repr::{CatalogItemId, GlobalId};
44use mz_secrets::SecretsReader;
45use mz_sql_parser::ast::ConnectionRulePattern;
46use mz_ssh_util::keys::SshKeyPair;
47use mz_ssh_util::tunnel::SshTunnelConfig;
48use mz_ssh_util::tunnel_manager::{ManagedSshTunnelHandle, SshTunnelManager};
49use mz_tracing::CloneableEnvFilter;
50use rdkafka::ClientContext;
51use rdkafka::config::FromClientConfigAndContext;
52use rdkafka::consumer::{BaseConsumer, Consumer};
53use regex::Regex;
54use serde::{Deserialize, Deserializer, Serialize};
55use tokio::net;
56use tokio::runtime::Handle;
57use tokio_postgres::config::SslMode;
58use tracing::{debug, info, warn};
59use url::Url;
60
61use crate::AlterCompatible;
62use crate::configuration::StorageConfiguration;
63use crate::connections::aws::{
64    AwsAuth, AwsConnection, AwsConnectionReference, AwsConnectionValidationError,
65};
66use crate::connections::string_or_secret::StringOrSecret;
67use crate::controller::AlterError;
68use crate::dyncfgs::{
69    ENFORCE_EXTERNAL_ADDRESSES, KAFKA_CLIENT_ID_ENRICHMENT_RULES,
70    KAFKA_DEFAULT_AWS_PRIVATELINK_ENDPOINT_IDENTIFICATION_ALGORITHM, KAFKA_RECONNECT_BACKOFF,
71    KAFKA_RECONNECT_BACKOFF_MAX, KAFKA_RETRY_BACKOFF, KAFKA_RETRY_BACKOFF_MAX,
72};
73use crate::errors::{ContextCreationError, CsrConnectError};
74
75pub mod aws;
76pub mod inline;
77pub mod string_or_secret;
78
/// REST catalog property key for the OAuth2 scope requested from the catalog server.
const REST_CATALOG_PROP_SCOPE: &str = "scope";
/// REST catalog property key for the OAuth2 credential (`CLIENT_ID:CLIENT_SECRET`).
const REST_CATALOG_PROP_CREDENTIAL: &str = "credential";
81
/// A credential loader that wraps an aws-sdk-rust credentials provider for use with
/// iceberg/OpenDAL. This allows us to provide refreshable credentials from the AWS SDK
/// credential chain (including the full assume role chain) to OpenDAL's S3 implementation.
///
/// We use this instead of OpenDAL's built-in assume role support because Materialize
/// has a runtime-defined credential chain (ambient → jump role → user role with external ID)
/// that can't be expressed via OpenDAL's static configuration properties.
///
/// See the [`AwsCredentialLoad`] impl below for how credentials are surfaced
/// to iceberg's FileIO layer.
struct AwsSdkCredentialLoader {
    /// The underlying AWS SDK credentials provider. For assume role auth, this provider
    /// already handles the full chain: ambient creds -> jump role -> user role.
    provider: aws_credential_types::provider::SharedCredentialsProvider,
}
94
95impl AwsSdkCredentialLoader {
96    fn new(provider: aws_credential_types::provider::SharedCredentialsProvider) -> Self {
97        Self { provider }
98    }
99}
100
101#[async_trait]
102impl AwsCredentialLoad for AwsSdkCredentialLoader {
103    async fn load_credential(
104        &self,
105        _client: reqwest::Client,
106    ) -> anyhow::Result<Option<AwsCredential>> {
107        let creds = self
108            .provider
109            .provide_credentials()
110            .await
111            .map_err(|e| {
112                warn!(
113                    error = %e.display_with_causes(),
114                    "failed to load AWS credentials for Iceberg FileIO from SDK provider"
115                );
116                e
117            })
118            .context(
119                "failed to load AWS credentials from SDK provider for Iceberg FileIO \
120                 (credential source may be temporarily unavailable)",
121            )?;
122
123        Ok(Some(AwsCredential {
124            access_key_id: creds.access_key_id().to_string(),
125            secret_access_key: creds.secret_access_key().to_string(),
126            session_token: creds.session_token().map(|s| s.to_string()),
127            expires_in: creds.expiry().map(|t| t.into()),
128        }))
129    }
130}
131
/// An extension trait for [`SecretsReader`] that allows secret reads to be
/// optionally moved onto a dedicated task (see [`InTask`]).
#[async_trait::async_trait]
trait SecretsReaderExt {
    /// `SecretsReader::read`, but optionally run in a task.
    ///
    /// Returns the raw secret bytes for catalog item `id`.
    async fn read_in_task_if(
        &self,
        in_task: InTask,
        id: CatalogItemId,
    ) -> Result<Vec<u8>, anyhow::Error>;

    /// `SecretsReader::read_string`, but optionally run in a task.
    ///
    /// Returns the secret for catalog item `id` decoded as a string.
    async fn read_string_in_task_if(
        &self,
        in_task: InTask,
        id: CatalogItemId,
    ) -> Result<String, anyhow::Error>;
}
149
150#[async_trait::async_trait]
151impl SecretsReaderExt for Arc<dyn SecretsReader> {
152    async fn read_in_task_if(
153        &self,
154        in_task: InTask,
155        id: CatalogItemId,
156    ) -> Result<Vec<u8>, anyhow::Error> {
157        let sr = Arc::clone(self);
158        async move { sr.read(id).await }
159            .run_in_task_if(in_task, || "secrets_reader_read".to_string())
160            .await
161    }
162    async fn read_string_in_task_if(
163        &self,
164        in_task: InTask,
165        id: CatalogItemId,
166    ) -> Result<String, anyhow::Error> {
167        let sr = Arc::clone(self);
168        async move { sr.read_string(id).await }
169            .run_in_task_if(in_task, || "secrets_reader_read".to_string())
170            .await
171    }
172}
173
/// Extra context to pass through when instantiating a connection for a source
/// or sink.
///
/// Should be kept cheaply cloneable.
#[derive(Debug, Clone)]
pub struct ConnectionContext {
    /// An opaque identifier for the environment in which this process is
    /// running.
    ///
    /// The storage layer is intentionally unaware of the structure within this
    /// identifier. Higher layers of the stack can make use of that structure,
    /// but the storage layer should be oblivious to it.
    pub environment_id: String,
    /// The level for librdkafka's logs.
    pub librdkafka_log_level: tracing::Level,
    /// A prefix for an external ID to use for all AWS AssumeRole operations.
    ///
    /// Must be operator-provided, never end-user-provided; see
    /// [`AwsExternalIdPrefix`] for the security rationale.
    pub aws_external_id_prefix: Option<AwsExternalIdPrefix>,
    /// The ARN for a Materialize-controlled role to assume before assuming
    /// a customer's requested role for an AWS connection.
    pub aws_connection_role_arn: Option<String>,
    /// A secrets reader.
    pub secrets_reader: Arc<dyn SecretsReader>,
    /// A cloud resource reader, if supported in this configuration.
    pub cloud_resource_reader: Option<Arc<dyn CloudResourceReader>>,
    /// A manager for SSH tunnels.
    pub ssh_tunnel_manager: SshTunnelManager,
}
201
202impl ConnectionContext {
203    /// Constructs a new connection context from command line arguments.
204    ///
205    /// **WARNING:** it is critical for security that the `aws_external_id` be
206    /// provided by the operator of the Materialize service (i.e., via a CLI
207    /// argument or environment variable) and not the end user of Materialize
208    /// (e.g., via a configuration option in a SQL statement). See
209    /// [`AwsExternalIdPrefix`] for details.
210    pub fn from_cli_args(
211        environment_id: String,
212        startup_log_level: &CloneableEnvFilter,
213        aws_external_id_prefix: Option<AwsExternalIdPrefix>,
214        aws_connection_role_arn: Option<String>,
215        secrets_reader: Arc<dyn SecretsReader>,
216        cloud_resource_reader: Option<Arc<dyn CloudResourceReader>>,
217    ) -> ConnectionContext {
218        ConnectionContext {
219            environment_id,
220            librdkafka_log_level: mz_ore::tracing::crate_level(
221                &startup_log_level.clone().into(),
222                "librdkafka",
223            ),
224            aws_external_id_prefix,
225            aws_connection_role_arn,
226            secrets_reader,
227            cloud_resource_reader,
228            ssh_tunnel_manager: SshTunnelManager::default(),
229        }
230    }
231
232    /// Constructs a new connection context for usage in tests.
233    pub fn for_tests(secrets_reader: Arc<dyn SecretsReader>) -> ConnectionContext {
234        ConnectionContext {
235            environment_id: "test-environment-id".into(),
236            librdkafka_log_level: tracing::Level::INFO,
237            aws_external_id_prefix: Some(
238                AwsExternalIdPrefix::new_from_cli_argument_or_environment_variable(
239                    "test-aws-external-id-prefix",
240                )
241                .expect("infallible"),
242            ),
243            aws_connection_role_arn: Some(
244                "arn:aws:iam::123456789000:role/MaterializeConnection".into(),
245            ),
246            secrets_reader,
247            cloud_resource_reader: None,
248            ssh_tunnel_manager: SshTunnelManager::default(),
249        }
250    }
251}
252
/// A connection to an external system, parameterized over whether nested
/// connection references are inlined or still referenced (see [`ConnectionAccess`]).
#[derive(Clone, Debug, Eq, PartialEq, Hash, Serialize, Deserialize)]
pub enum Connection<C: ConnectionAccess = InlinedConnection> {
    /// A connection to a Kafka cluster.
    Kafka(KafkaConnection<C>),
    /// A connection to a Confluent Schema Registry.
    Csr(CsrConnection<C>),
    /// A connection to a PostgreSQL server.
    Postgres(PostgresConnection<C>),
    /// A connection to an SSH bastion host.
    Ssh(SshConnection),
    /// A connection to AWS.
    Aws(AwsConnection),
    /// A connection to an AWS PrivateLink endpoint.
    AwsPrivatelink(AwsPrivatelinkConnection),
    /// A connection to a MySQL server.
    MySql(MySqlConnection<C>),
    /// A connection to a SQL Server instance.
    SqlServer(SqlServerConnectionDetails<C>),
    /// A connection to an Iceberg catalog.
    IcebergCatalog(IcebergCatalogConnection<C>),
}
265
266impl<R: ConnectionResolver> IntoInlineConnection<Connection, R>
267    for Connection<ReferencedConnection>
268{
269    fn into_inline_connection(self, r: R) -> Connection {
270        match self {
271            Connection::Kafka(kafka) => Connection::Kafka(kafka.into_inline_connection(r)),
272            Connection::Csr(csr) => Connection::Csr(csr.into_inline_connection(r)),
273            Connection::Postgres(pg) => Connection::Postgres(pg.into_inline_connection(r)),
274            Connection::Ssh(ssh) => Connection::Ssh(ssh),
275            Connection::Aws(aws) => Connection::Aws(aws),
276            Connection::AwsPrivatelink(awspl) => Connection::AwsPrivatelink(awspl),
277            Connection::MySql(mysql) => Connection::MySql(mysql.into_inline_connection(r)),
278            Connection::SqlServer(sql_server) => {
279                Connection::SqlServer(sql_server.into_inline_connection(r))
280            }
281            Connection::IcebergCatalog(iceberg) => {
282                Connection::IcebergCatalog(iceberg.into_inline_connection(r))
283            }
284        }
285    }
286}
287
288impl<C: ConnectionAccess> Connection<C> {
289    /// Whether this connection should be validated by default on creation.
290    pub fn validate_by_default(&self) -> bool {
291        match self {
292            Connection::Kafka(conn) => conn.validate_by_default(),
293            Connection::Csr(conn) => conn.validate_by_default(),
294            Connection::Postgres(conn) => conn.validate_by_default(),
295            Connection::Ssh(conn) => conn.validate_by_default(),
296            Connection::Aws(conn) => conn.validate_by_default(),
297            Connection::AwsPrivatelink(conn) => conn.validate_by_default(),
298            Connection::MySql(conn) => conn.validate_by_default(),
299            Connection::SqlServer(conn) => conn.validate_by_default(),
300            Connection::IcebergCatalog(conn) => conn.validate_by_default(),
301        }
302    }
303}
304
305impl Connection<InlinedConnection> {
306    /// Validates this connection by attempting to connect to the upstream system.
307    pub async fn validate(
308        &self,
309        id: CatalogItemId,
310        storage_configuration: &StorageConfiguration,
311    ) -> Result<(), ConnectionValidationError> {
312        match self {
313            Connection::Kafka(conn) => conn.validate(id, storage_configuration).await?,
314            Connection::Csr(conn) => conn.validate(id, storage_configuration).await?,
315            Connection::Postgres(conn) => {
316                conn.validate(id, storage_configuration).await?;
317            }
318            Connection::Ssh(conn) => conn.validate(id, storage_configuration).await?,
319            Connection::Aws(conn) => conn.validate(id, storage_configuration).await?,
320            Connection::AwsPrivatelink(conn) => conn.validate(id, storage_configuration).await?,
321            Connection::MySql(conn) => {
322                conn.validate(id, storage_configuration).await?;
323            }
324            Connection::SqlServer(conn) => {
325                conn.validate(id, storage_configuration).await?;
326            }
327            Connection::IcebergCatalog(conn) => conn.validate(id, storage_configuration).await?,
328        }
329        Ok(())
330    }
331
332    pub fn unwrap_kafka(self) -> <InlinedConnection as ConnectionAccess>::Kafka {
333        match self {
334            Self::Kafka(conn) => conn,
335            o => unreachable!("{o:?} is not a Kafka connection"),
336        }
337    }
338
339    pub fn unwrap_pg(self) -> <InlinedConnection as ConnectionAccess>::Pg {
340        match self {
341            Self::Postgres(conn) => conn,
342            o => unreachable!("{o:?} is not a Postgres connection"),
343        }
344    }
345
346    pub fn unwrap_mysql(self) -> <InlinedConnection as ConnectionAccess>::MySql {
347        match self {
348            Self::MySql(conn) => conn,
349            o => unreachable!("{o:?} is not a MySQL connection"),
350        }
351    }
352
353    pub fn unwrap_sql_server(self) -> <InlinedConnection as ConnectionAccess>::SqlServer {
354        match self {
355            Self::SqlServer(conn) => conn,
356            o => unreachable!("{o:?} is not a SQL Server connection"),
357        }
358    }
359
360    pub fn unwrap_aws(self) -> <InlinedConnection as ConnectionAccess>::Aws {
361        match self {
362            Self::Aws(conn) => conn,
363            o => unreachable!("{o:?} is not an AWS connection"),
364        }
365    }
366
367    pub fn unwrap_ssh(self) -> <InlinedConnection as ConnectionAccess>::Ssh {
368        match self {
369            Self::Ssh(conn) => conn,
370            o => unreachable!("{o:?} is not an SSH connection"),
371        }
372    }
373
374    pub fn unwrap_csr(self) -> <InlinedConnection as ConnectionAccess>::Csr {
375        match self {
376            Self::Csr(conn) => conn,
377            o => unreachable!("{o:?} is not a Kafka connection"),
378        }
379    }
380
381    pub fn unwrap_iceberg_catalog(self) -> <InlinedConnection as ConnectionAccess>::IcebergCatalog {
382        match self {
383            Self::IcebergCatalog(conn) => conn,
384            o => unreachable!("{o:?} is not an Iceberg catalog connection"),
385        }
386    }
387}
388
/// An error returned by [`Connection::validate`].
#[derive(thiserror::Error, Debug)]
pub enum ConnectionValidationError {
    /// Validation of a PostgreSQL connection failed.
    #[error(transparent)]
    Postgres(#[from] PostgresConnectionValidationError),
    /// Validation of a MySQL connection failed.
    #[error(transparent)]
    MySql(#[from] MySqlConnectionValidationError),
    /// Validation of a SQL Server connection failed.
    #[error(transparent)]
    SqlServer(#[from] SqlServerConnectionValidationError),
    /// Validation of an AWS connection failed.
    #[error(transparent)]
    Aws(#[from] AwsConnectionValidationError),
    /// Any other validation failure; the full cause chain is rendered.
    #[error("{}", .0.display_with_causes())]
    Other(#[from] anyhow::Error),
}
403
404impl ConnectionValidationError {
405    /// Reports additional details about the error, if any are available.
406    pub fn detail(&self) -> Option<String> {
407        match self {
408            ConnectionValidationError::Postgres(e) => e.detail(),
409            ConnectionValidationError::MySql(e) => e.detail(),
410            ConnectionValidationError::SqlServer(e) => e.detail(),
411            ConnectionValidationError::Aws(e) => e.detail(),
412            ConnectionValidationError::Other(_) => None,
413        }
414    }
415
416    /// Reports a hint for the user about how the error could be fixed.
417    pub fn hint(&self) -> Option<String> {
418        match self {
419            ConnectionValidationError::Postgres(e) => e.hint(),
420            ConnectionValidationError::MySql(e) => e.hint(),
421            ConnectionValidationError::SqlServer(e) => e.hint(),
422            ConnectionValidationError::Aws(e) => e.hint(),
423            ConnectionValidationError::Other(_) => None,
424        }
425    }
426}
427
impl<C: ConnectionAccess> AlterCompatible for Connection<C> {
    /// Determines whether `self` may be altered in place into `other`,
    /// delegating to the per-variant `alter_compatible` when both sides are
    /// the same kind of connection. Any cross-variant pair is rejected.
    ///
    /// NOTE(review): there is no same-variant arm for `SqlServer` or
    /// `IcebergCatalog`, so those pairs always fall through to the catch-all
    /// `Err` below (and log the "incompatible" warning even for same-variant
    /// alters) — confirm this is intentional rather than an omission.
    fn alter_compatible(&self, id: GlobalId, other: &Self) -> Result<(), AlterError> {
        match (self, other) {
            (Self::Aws(s), Self::Aws(o)) => s.alter_compatible(id, o),
            (Self::AwsPrivatelink(s), Self::AwsPrivatelink(o)) => s.alter_compatible(id, o),
            (Self::Ssh(s), Self::Ssh(o)) => s.alter_compatible(id, o),
            (Self::Csr(s), Self::Csr(o)) => s.alter_compatible(id, o),
            (Self::Kafka(s), Self::Kafka(o)) => s.alter_compatible(id, o),
            (Self::Postgres(s), Self::Postgres(o)) => s.alter_compatible(id, o),
            (Self::MySql(s), Self::MySql(o)) => s.alter_compatible(id, o),
            _ => {
                tracing::warn!(
                    "Connection incompatible:\nself:\n{:#?}\n\nother\n{:#?}",
                    self,
                    other
                );
                Err(AlterError { id })
            }
        }
    }
}
449
/// Configuration for a plain REST Iceberg catalog that authenticates with
/// OAuth2 client credentials.
#[derive(Clone, Debug, Eq, PartialEq, Hash, Serialize, Deserialize)]
pub struct RestIcebergCatalog {
    /// For REST catalogs, the oauth2 credential in a `CLIENT_ID:CLIENT_SECRET` format
    pub credential: StringOrSecret,
    /// The oauth2 scope for REST catalogs
    pub scope: Option<String>,
    /// The warehouse for REST catalogs
    pub warehouse: Option<String>,
}
459
/// Configuration for an AWS S3 Tables REST Iceberg catalog, which
/// authenticates via an AWS connection rather than OAuth2.
#[derive(Clone, Debug, Eq, PartialEq, Hash, Serialize, Deserialize)]
pub struct S3TablesRestIcebergCatalog<C: ConnectionAccess = InlinedConnection> {
    /// The AWS connection details, for s3tables
    pub aws_connection: AwsConnectionReference<C>,
    /// The warehouse for s3tables
    pub warehouse: String,
}
467
468impl<R: ConnectionResolver> IntoInlineConnection<S3TablesRestIcebergCatalog, R>
469    for S3TablesRestIcebergCatalog<ReferencedConnection>
470{
471    fn into_inline_connection(self, r: R) -> S3TablesRestIcebergCatalog {
472        S3TablesRestIcebergCatalog {
473            aws_connection: self.aws_connection.into_inline_connection(&r),
474            warehouse: self.warehouse,
475        }
476    }
477}
478
/// The kind of Iceberg catalog a connection targets; a data-free mirror of
/// [`IcebergCatalogImpl`].
#[derive(Clone, Debug, Eq, PartialEq, Hash, Serialize, Deserialize)]
pub enum IcebergCatalogType {
    /// A plain REST catalog.
    Rest,
    /// An AWS S3 Tables REST catalog.
    S3TablesRest,
}
484
/// The implementation-specific configuration of an Iceberg catalog.
#[derive(Clone, Debug, Eq, PartialEq, Hash, Serialize, Deserialize)]
pub enum IcebergCatalogImpl<C: ConnectionAccess = InlinedConnection> {
    /// A plain REST catalog with OAuth2 credentials.
    Rest(RestIcebergCatalog),
    /// An AWS S3 Tables REST catalog backed by an AWS connection.
    S3TablesRest(S3TablesRestIcebergCatalog<C>),
}
490
491impl<R: ConnectionResolver> IntoInlineConnection<IcebergCatalogImpl, R>
492    for IcebergCatalogImpl<ReferencedConnection>
493{
494    fn into_inline_connection(self, r: R) -> IcebergCatalogImpl {
495        match self {
496            IcebergCatalogImpl::Rest(rest) => IcebergCatalogImpl::Rest(rest),
497            IcebergCatalogImpl::S3TablesRest(s3tables) => {
498                IcebergCatalogImpl::S3TablesRest(s3tables.into_inline_connection(r))
499            }
500        }
501    }
502}
503
/// A connection to an Iceberg catalog.
#[derive(Clone, Debug, Eq, PartialEq, Hash, Serialize, Deserialize)]
pub struct IcebergCatalogConnection<C: ConnectionAccess = InlinedConnection> {
    /// The catalog implementation (plain REST or S3 Tables REST).
    pub catalog: IcebergCatalogImpl<C>,
    /// Where the catalog is located
    pub uri: reqwest::Url,
}
511
impl AlterCompatible for IcebergCatalogConnection {
    /// Iceberg catalog connections currently do not support in-place
    /// alteration: every requested change is reported as incompatible.
    fn alter_compatible(&self, id: GlobalId, _other: &Self) -> Result<(), AlterError> {
        Err(AlterError { id })
    }
}
517
518impl<R: ConnectionResolver> IntoInlineConnection<IcebergCatalogConnection, R>
519    for IcebergCatalogConnection<ReferencedConnection>
520{
521    fn into_inline_connection(self, r: R) -> IcebergCatalogConnection {
522        IcebergCatalogConnection {
523            catalog: self.catalog.into_inline_connection(&r),
524            uri: self.uri,
525        }
526    }
527}
528
impl<C: ConnectionAccess> IcebergCatalogConnection<C> {
    /// Whether this connection should be validated by default on creation.
    /// Iceberg catalog connections are always validated.
    fn validate_by_default(&self) -> bool {
        true
    }
}
534
535impl IcebergCatalogConnection<InlinedConnection> {
536    pub async fn connect(
537        &self,
538        storage_configuration: &StorageConfiguration,
539        in_task: InTask,
540    ) -> Result<Arc<dyn Catalog>, anyhow::Error> {
541        match self.catalog {
542            IcebergCatalogImpl::S3TablesRest(ref s3tables) => {
543                self.connect_s3tables(s3tables, storage_configuration, in_task)
544                    .await
545            }
546            IcebergCatalogImpl::Rest(ref rest) => {
547                self.connect_rest(rest, storage_configuration, in_task)
548                    .await
549            }
550        }
551    }
552
553    pub fn catalog_type(&self) -> IcebergCatalogType {
554        match self.catalog {
555            IcebergCatalogImpl::S3TablesRest(_) => IcebergCatalogType::S3TablesRest,
556            IcebergCatalogImpl::Rest(_) => IcebergCatalogType::Rest,
557        }
558    }
559
560    pub fn s3tables_catalog(&self) -> Option<&S3TablesRestIcebergCatalog> {
561        match &self.catalog {
562            IcebergCatalogImpl::S3TablesRest(s3tables) => Some(s3tables),
563            IcebergCatalogImpl::Rest(_) => None,
564        }
565    }
566
567    pub fn rest_catalog(&self) -> Option<&RestIcebergCatalog> {
568        match &self.catalog {
569            IcebergCatalogImpl::Rest(rest) => Some(rest),
570            IcebergCatalogImpl::S3TablesRest(_) => None,
571        }
572    }
573
574    async fn connect_s3tables(
575        &self,
576        s3tables: &S3TablesRestIcebergCatalog,
577        storage_configuration: &StorageConfiguration,
578        in_task: InTask,
579    ) -> Result<Arc<dyn Catalog>, anyhow::Error> {
580        let secret_reader = &storage_configuration.connection_context.secrets_reader;
581        let aws_ref = &s3tables.aws_connection;
582        let aws_config = aws_ref
583            .connection
584            .load_sdk_config(
585                &storage_configuration.connection_context,
586                aws_ref.connection_id,
587                in_task,
588                ENFORCE_EXTERNAL_ADDRESSES.get(storage_configuration.config_set()),
589            )
590            .await
591            .with_context(|| {
592                format!(
593                    "failed to load AWS SDK config for S3 Tables Iceberg catalog \
594                     (connection id: {}, auth method: {}, catalog uri: {}, warehouse: {})",
595                    aws_ref.connection_id,
596                    aws_ref.connection.auth_method(),
597                    self.uri,
598                    s3tables.warehouse
599                )
600            })?;
601
602        let aws_region = aws_ref
603            .connection
604            .region
605            .clone()
606            .unwrap_or_else(|| "us-east-1".to_string());
607
608        let mut props = vec![
609            (S3_REGION.to_string(), aws_region),
610            (S3_DISABLE_EC2_METADATA.to_string(), "true".to_string()),
611            (
612                REST_CATALOG_PROP_WAREHOUSE.to_string(),
613                s3tables.warehouse.clone(),
614            ),
615            (REST_CATALOG_PROP_URI.to_string(), self.uri.to_string()),
616        ];
617
618        let aws_auth = aws_ref.connection.auth.clone();
619
620        if let AwsAuth::Credentials(creds) = &aws_auth {
621            props.push((
622                S3_ACCESS_KEY_ID.to_string(),
623                creds
624                    .access_key_id
625                    .get_string(in_task, secret_reader)
626                    .await?,
627            ));
628            props.push((
629                S3_SECRET_ACCESS_KEY.to_string(),
630                secret_reader.read_string(creds.secret_access_key).await?,
631            ));
632        }
633
634        // Build the catalog with aws_config for REST API signing.
635        // For AssumeRole auth, we also add a FileIO extension so OpenDAL can
636        // use our credential chain for S3 object access.
637        let catalog = RestCatalogBuilder::default()
638            .with_aws_config(aws_config.clone())
639            .load("IcebergCatalog", props.into_iter().collect())
640            .await
641            .with_context(|| {
642                format!(
643                    "failed to create S3 Tables Iceberg catalog \
644                     (connection id: {}, catalog uri: {}, warehouse: {})",
645                    aws_ref.connection_id, self.uri, s3tables.warehouse
646                )
647            })?;
648
649        let catalog = if matches!(aws_auth, AwsAuth::AssumeRole(_)) {
650            let credentials_provider = aws_config
651                .credentials_provider()
652                .ok_or_else(|| anyhow!("aws_config missing credentials provider"))?;
653            let file_io_loader = CustomAwsCredentialLoader::new(Arc::new(
654                AwsSdkCredentialLoader::new(credentials_provider),
655            ));
656            catalog.with_file_io_extension(file_io_loader)
657        } else {
658            catalog
659        };
660
661        Ok(Arc::new(catalog))
662    }
663
664    async fn connect_rest(
665        &self,
666        rest: &RestIcebergCatalog,
667        storage_configuration: &StorageConfiguration,
668        in_task: InTask,
669    ) -> Result<Arc<dyn Catalog>, anyhow::Error> {
670        let mut props = BTreeMap::from([(
671            REST_CATALOG_PROP_URI.to_string(),
672            self.uri.to_string().clone(),
673        )]);
674
675        if let Some(warehouse) = &rest.warehouse {
676            props.insert(REST_CATALOG_PROP_WAREHOUSE.to_string(), warehouse.clone());
677        }
678
679        let credential = rest
680            .credential
681            .get_string(
682                in_task,
683                &storage_configuration.connection_context.secrets_reader,
684            )
685            .await
686            .map_err(|e| anyhow!("failed to read Iceberg catalog credential: {e}"))?;
687        props.insert(REST_CATALOG_PROP_CREDENTIAL.to_string(), credential);
688
689        if let Some(scope) = &rest.scope {
690            props.insert(REST_CATALOG_PROP_SCOPE.to_string(), scope.clone());
691        }
692
693        let catalog = RestCatalogBuilder::default()
694            .load("IcebergCatalog", props.into_iter().collect())
695            .await
696            .map_err(|e| anyhow!("failed to create Iceberg catalog: {e}"))?;
697        Ok(Arc::new(catalog))
698    }
699
700    async fn validate(
701        &self,
702        _id: CatalogItemId,
703        storage_configuration: &StorageConfiguration,
704    ) -> Result<(), ConnectionValidationError> {
705        let catalog = self
706            .connect(storage_configuration, InTask::No)
707            .await
708            .map_err(|e| {
709                ConnectionValidationError::Other(anyhow!("failed to connect to catalog: {e}"))
710            })?;
711
712        // If we can list namespaces, the connection is valid.
713        catalog.list_namespaces(None).await.map_err(|e| {
714            ConnectionValidationError::Other(anyhow!("failed to list namespaces: {e}"))
715        })?;
716
717        Ok(())
718    }
719}
720
/// A connection to an AWS PrivateLink endpoint service.
#[derive(Clone, Debug, Eq, PartialEq, Hash, Serialize, Deserialize)]
pub struct AwsPrivatelinkConnection {
    /// The name of the PrivateLink endpoint service.
    pub service_name: String,
    /// The availability zones in which the service is available.
    pub availability_zones: Vec<String>,
}
726
impl AlterCompatible for AwsPrivatelinkConnection {
    /// Always compatible: every element of an `AwsPrivatelinkConnection` may
    /// be reconfigured via `ALTER`.
    fn alter_compatible(&self, _id: GlobalId, _other: &Self) -> Result<(), AlterError> {
        // Every element of the AwsPrivatelinkConnection connection is configurable.
        Ok(())
    }
}
733
/// TLS configuration for a Kafka connection.
#[derive(Clone, Debug, Eq, PartialEq, Hash, Serialize, Deserialize)]
pub struct KafkaTlsConfig {
    /// The client TLS identity (certificate and key), if mutual TLS is used.
    pub identity: Option<TlsIdentity>,
    /// The root certificate to trust for the broker, if not using system roots.
    pub root_cert: Option<StringOrSecret>,
}
739
/// SASL authentication configuration for a Kafka connection.
#[derive(Clone, Debug, Eq, PartialEq, Hash, Serialize, Deserialize)]
pub struct KafkaSaslConfig<C: ConnectionAccess = InlinedConnection> {
    /// The SASL mechanism name (e.g. `PLAIN`, `SCRAM-SHA-256`).
    pub mechanism: String,
    /// The SASL username.
    pub username: StringOrSecret,
    /// The secret holding the SASL password, if any.
    pub password: Option<CatalogItemId>,
    /// An AWS connection reference, for AWS-based SASL auth, if any.
    pub aws: Option<AwsConnectionReference<C>>,
}
747
748impl<R: ConnectionResolver> IntoInlineConnection<KafkaSaslConfig, R>
749    for KafkaSaslConfig<ReferencedConnection>
750{
751    fn into_inline_connection(self, r: R) -> KafkaSaslConfig {
752        KafkaSaslConfig {
753            mechanism: self.mechanism,
754            username: self.username,
755            password: self.password,
756            aws: self.aws.map(|aws| aws.into_inline_connection(&r)),
757        }
758    }
759}
760
/// Specifies a Kafka broker in a [`KafkaConnection`].
#[derive(Clone, Debug, Eq, PartialEq, Hash, Serialize, Deserialize)]
pub struct KafkaBroker<C: ConnectionAccess = InlinedConnection> {
    /// The address of the Kafka broker.
    pub address: String,
    /// An optional tunnel to use when connecting to the broker; overrides the
    /// connection's `default_tunnel`.
    pub tunnel: Tunnel<C>,
}
769
770impl<R: ConnectionResolver> IntoInlineConnection<KafkaBroker, R>
771    for KafkaBroker<ReferencedConnection>
772{
773    fn into_inline_connection(self, r: R) -> KafkaBroker {
774        let KafkaBroker { address, tunnel } = self;
775        KafkaBroker {
776            address,
777            tunnel: tunnel.into_inline_connection(r),
778        }
779    }
780}
781
/// Options used when creating a Kafka topic.
#[derive(Clone, Debug, Eq, PartialEq, Hash, Serialize, Deserialize, Default)]
pub struct KafkaTopicOptions {
    /// The replication factor for the topic.
    /// If `None`, the broker default will be used.
    pub replication_factor: Option<NonNeg<i32>>,
    /// The number of partitions to create.
    /// If `None`, the broker default will be used.
    pub partition_count: Option<NonNeg<i32>>,
    /// The initial configuration parameters for the topic.
    pub topic_config: BTreeMap<String, String>,
}
793
/// A connection to a Kafka cluster.
#[derive(Clone, Debug, Eq, PartialEq, Hash, Serialize, Deserialize)]
pub struct KafkaConnection<C: ConnectionAccess = InlinedConnection> {
    /// The brokers in the cluster, each with an optional per-broker tunnel.
    pub brokers: Vec<KafkaBroker<C>>,
    /// A tunnel through which to route traffic,
    /// that can be overridden for individual brokers
    /// in `brokers`.
    pub default_tunnel: Tunnel<C>,
    /// The name of the progress topic, if explicitly configured; see
    /// [`KafkaConnection::progress_topic`] for the default derivation.
    pub progress_topic: Option<String>,
    /// Options used when creating the progress topic.
    pub progress_topic_options: KafkaTopicOptions,
    /// Additional client options for the connection.
    pub options: BTreeMap<String, StringOrSecret>,
    /// TLS configuration, if any.
    pub tls: Option<KafkaTlsConfig>,
    /// SASL configuration, if any.
    pub sasl: Option<KafkaSaslConfig<C>>,
}
807
808impl<R: ConnectionResolver> IntoInlineConnection<KafkaConnection, R>
809    for KafkaConnection<ReferencedConnection>
810{
811    fn into_inline_connection(self, r: R) -> KafkaConnection {
812        let KafkaConnection {
813            brokers,
814            progress_topic,
815            progress_topic_options,
816            default_tunnel,
817            options,
818            tls,
819            sasl,
820        } = self;
821
822        let brokers = brokers
823            .into_iter()
824            .map(|broker| broker.into_inline_connection(&r))
825            .collect();
826
827        KafkaConnection {
828            brokers,
829            progress_topic,
830            progress_topic_options,
831            default_tunnel: default_tunnel.into_inline_connection(&r),
832            options,
833            tls,
834            sasl: sasl.map(|sasl| sasl.into_inline_connection(&r)),
835        }
836    }
837}
838
839impl<C: ConnectionAccess> KafkaConnection<C> {
840    /// Returns the name of the progress topic to use for the connection.
841    ///
842    /// The caller is responsible for providing the connection ID as it is not
843    /// known to `KafkaConnection`.
844    pub fn progress_topic(
845        &self,
846        connection_context: &ConnectionContext,
847        connection_id: CatalogItemId,
848    ) -> Cow<'_, str> {
849        if let Some(progress_topic) = &self.progress_topic {
850            Cow::Borrowed(progress_topic)
851        } else {
852            Cow::Owned(format!(
853                "_materialize-progress-{}-{}",
854                connection_context.environment_id, connection_id,
855            ))
856        }
857    }
858
859    fn validate_by_default(&self) -> bool {
860        true
861    }
862}
863
impl KafkaConnection {
    /// Generates a string that can be used as the base for a configuration ID
    /// (e.g., `client.id`, `group.id`, `transactional.id`) for a Kafka source
    /// or sink.
    pub fn id_base(
        connection_context: &ConnectionContext,
        connection_id: CatalogItemId,
        object_id: GlobalId,
    ) -> String {
        format!(
            "materialize-{}-{}-{}",
            connection_context.environment_id, connection_id, object_id,
        )
    }

    /// Enriches the provided `client_id` according to any enrichment rules in
    /// the `kafka_client_id_enrichment_rules` configuration parameter.
    ///
    /// Each rule whose pattern matches at least one broker address appends its
    /// payload to `client_id`, in rule order. A malformed rule list is logged
    /// and ignored rather than surfaced as an error.
    pub fn enrich_client_id(&self, configs: &ConfigSet, client_id: &mut String) {
        // Shape of one entry in the JSON-encoded rule list.
        #[derive(Debug, Deserialize)]
        struct EnrichmentRule {
            #[serde(deserialize_with = "deserialize_regex")]
            pattern: Regex,
            payload: String,
        }

        // Serde helper: decode the rule pattern from its string form.
        fn deserialize_regex<'de, D>(deserializer: D) -> Result<Regex, D::Error>
        where
            D: Deserializer<'de>,
        {
            let buf = String::deserialize(deserializer)?;
            Regex::new(&buf).map_err(serde::de::Error::custom)
        }

        let rules = KAFKA_CLIENT_ID_ENRICHMENT_RULES.get(configs);
        let rules = match serde_json::from_value::<Vec<EnrichmentRule>>(rules) {
            Ok(rules) => rules,
            Err(e) => {
                // Bad configuration must not take down the caller; skip
                // enrichment entirely instead.
                warn!(%e, "failed to decode kafka_client_id_enrichment_rules");
                return;
            }
        };

        // Check every rule against every broker. Rules are matched in the order
        // that they are specified. It is usually a configuration error if
        // multiple rules match the same list of Kafka brokers, but we
        // nonetheless want to provide well defined semantics.
        debug!(?self.brokers, "evaluating client ID enrichment rules");
        for rule in rules {
            let is_match = self
                .brokers
                .iter()
                .any(|b| rule.pattern.is_match(&b.address));
            debug!(?rule, is_match, "evaluated client ID enrichment rule");
            if is_match {
                client_id.push('-');
                client_id.push_str(&rule.payload);
            }
        }
    }

    /// Creates a Kafka client for the connection.
    ///
    /// Builds the client configuration (bootstrap servers, security protocol,
    /// TLS/SASL credentials, retry/reconnect backoffs), resolves the
    /// configured tunnels (direct, AWS PrivateLink, or SSH), and constructs a
    /// client of type `T` with a [`TunnelingClientContext`] wrapping
    /// `context`.
    pub async fn create_with_context<C, T>(
        &self,
        storage_configuration: &StorageConfiguration,
        context: C,
        extra_options: &BTreeMap<&str, String>,
        in_task: InTask,
    ) -> Result<T, ContextCreationError>
    where
        C: ClientContext,
        T: FromClientConfigAndContext<TunnelingClientContext<C>>,
    {
        let mut options = self.options.clone();

        // Ensure that Kafka topics are *not* automatically created when
        // consuming, producing, or fetching metadata for a topic. This ensures
        // that we don't accidentally create topics with the wrong number of
        // partitions.
        options.insert("allow.auto.create.topics".into(), "false".into());

        let brokers = match &self.default_tunnel {
            Tunnel::AwsPrivatelink(t) => {
                assert!(&self.brokers.is_empty());

                let algo = KAFKA_DEFAULT_AWS_PRIVATELINK_ENDPOINT_IDENTIFICATION_ALGORITHM
                    .get(storage_configuration.config_set());
                options.insert("ssl.endpoint.identification.algorithm".into(), algo.into());

                // When using a default privatelink tunnel broker/brokers cannot be specified
                // instead the tunnel connection_id and port are used for the initial connection.
                format!(
                    "{}:{}",
                    vpc_endpoint_host(
                        t.connection_id,
                        None, // Default tunnel does not support availability zones.
                    ),
                    t.port.unwrap_or(9092)
                )
            }
            Tunnel::AwsPrivatelinks(_pl) => {
                let algo = KAFKA_DEFAULT_AWS_PRIVATELINK_ENDPOINT_IDENTIFICATION_ALGORITHM
                    .get(storage_configuration.config_set());
                options.insert("ssl.endpoint.identification.algorithm".into(), algo.into());

                if self.brokers.is_empty() {
                    return Err(ContextCreationError::Other(anyhow::anyhow!(
                        "at least one static broker is required when using BROKER or BROKERS"
                    )));
                }
                self.brokers.iter().map(|b| &b.address).join(",")
            }
            _ => self.brokers.iter().map(|b| &b.address).join(","),
        };
        options.insert("bootstrap.servers".into(), brokers.clone().into());
        // Derive the security protocol from which of TLS and SASL are
        // configured.
        let security_protocol = match (self.tls.is_some(), self.sasl.is_some()) {
            (false, false) => "PLAINTEXT",
            (true, false) => "SSL",
            (false, true) => "SASL_PLAINTEXT",
            (true, true) => "SASL_SSL",
        };
        info!(
            "kafka: create_with_context bootstrap.servers={brokers}, security_protocol={security_protocol}"
        );
        options.insert("security.protocol".into(), security_protocol.into());
        if let Some(tls) = &self.tls {
            if let Some(root_cert) = &tls.root_cert {
                options.insert("ssl.ca.pem".into(), root_cert.clone());
            }
            if let Some(identity) = &tls.identity {
                options.insert("ssl.key.pem".into(), StringOrSecret::Secret(identity.key));
                options.insert("ssl.certificate.pem".into(), identity.cert.clone());
            }
        }
        if let Some(sasl) = &self.sasl {
            options.insert("sasl.mechanisms".into(), (&sasl.mechanism).into());
            options.insert("sasl.username".into(), sasl.username.clone());
            if let Some(password) = sasl.password {
                options.insert("sasl.password".into(), StringOrSecret::Secret(password));
            }
        }

        // Apply dyncfg-controlled retry and reconnect backoff settings.
        options.insert(
            "retry.backoff.ms".into(),
            KAFKA_RETRY_BACKOFF
                .get(storage_configuration.config_set())
                .as_millis()
                .into(),
        );
        options.insert(
            "retry.backoff.max.ms".into(),
            KAFKA_RETRY_BACKOFF_MAX
                .get(storage_configuration.config_set())
                .as_millis()
                .into(),
        );
        options.insert(
            "reconnect.backoff.ms".into(),
            KAFKA_RECONNECT_BACKOFF
                .get(storage_configuration.config_set())
                .as_millis()
                .into(),
        );
        options.insert(
            "reconnect.backoff.max.ms".into(),
            KAFKA_RECONNECT_BACKOFF_MAX
                .get(storage_configuration.config_set())
                .as_millis()
                .into(),
        );

        let mut config = mz_kafka_util::client::create_new_client_config(
            storage_configuration
                .connection_context
                .librdkafka_log_level,
            storage_configuration.parameters.kafka_timeout_config,
        );
        // Resolve any secret-valued options before handing them to the client
        // config.
        for (k, v) in options {
            config.set(
                k,
                v.get_string(
                    in_task,
                    &storage_configuration.connection_context.secrets_reader,
                )
                .await
                .context("reading kafka secret")?,
            );
        }
        // Caller-provided options are applied last and so take precedence.
        for (k, v) in extra_options {
            config.set(*k, v);
        }

        // Load AWS credentials only when SASL is backed by an AWS connection.
        let aws_config = match self.sasl.as_ref().and_then(|sasl| sasl.aws.as_ref()) {
            None => None,
            Some(aws) => Some(
                aws.connection
                    .load_sdk_config(
                        &storage_configuration.connection_context,
                        aws.connection_id,
                        in_task,
                        ENFORCE_EXTERNAL_ADDRESSES.get(storage_configuration.config_set()),
                    )
                    .await?,
            ),
        };

        // TODO(roshan): Implement enforcement of external address validation once
        // rdkafka client has been updated to support providing multiple resolved
        // addresses for brokers
        let mut context = TunnelingClientContext::new(
            context,
            Handle::current(),
            storage_configuration
                .connection_context
                .ssh_tunnel_manager
                .clone(),
            storage_configuration.parameters.ssh_timeout_config,
            aws_config,
            in_task,
        );

        // Install the connection-wide default tunnel on the context.
        match &self.default_tunnel {
            Tunnel::Direct => {
                // By default, don't offer a default override for broker address lookup.
            }
            Tunnel::AwsPrivatelink(pl) => {
                context.set_default_tunnel(TunnelConfig::StaticHost(
                    // Possible bug: We have been ignoring the configured port.
                    KafkaConnection::from_default_aws_privatelink(pl).host,
                ));
            }
            Tunnel::AwsPrivatelinks(pl) => {
                context.set_default_tunnel(TunnelConfig::Rules(
                    KafkaConnection::from_aws_privatelinks(pl),
                ));
            }
            Tunnel::Ssh(ssh_tunnel) => {
                let secret = storage_configuration
                    .connection_context
                    .secrets_reader
                    .read_in_task_if(in_task, ssh_tunnel.connection_id)
                    .await?;
                let key_pair = SshKeyPair::from_bytes(&secret)?;

                // Ensure any ssh-bastion address we connect to is resolved to an external address.
                let resolved = resolve_address(
                    &ssh_tunnel.connection.host,
                    ENFORCE_EXTERNAL_ADDRESSES.get(storage_configuration.config_set()),
                )
                .await?;
                context.set_default_tunnel(TunnelConfig::Ssh(SshTunnelConfig {
                    host: resolved
                        .iter()
                        .map(|a| a.to_string())
                        .collect::<BTreeSet<_>>(),
                    port: ssh_tunnel.connection.port,
                    user: ssh_tunnel.connection.user.clone(),
                    key_pair,
                }));
            }
        }
        info!(
            "kafka: tunnel config set to {}",
            match &self.default_tunnel {
                Tunnel::Direct => "Direct".to_string(),
                Tunnel::AwsPrivatelink(_) => "AwsPrivatelink (static host)".to_string(),
                Tunnel::AwsPrivatelinks(pl) =>
                    format!("AwsPrivatelinks ({} rules)", pl.rules.len()),
                Tunnel::Ssh(_) => "Ssh".to_string(),
            }
        );

        // Here, we preemptively rewrite broker addresses.
        // In concept, this overlaps with 'TunnelingClientContext::resolve_broker_addr'.
        for broker in &self.brokers {
            // Split "host:port", defaulting the port to 9092 when absent.
            let mut addr_parts = broker.address.splitn(2, ':');
            let addr = BrokerAddr {
                host: addr_parts
                    .next()
                    .context("BROKER is not address:port")?
                    .into(),
                port: addr_parts
                    .next()
                    .unwrap_or("9092")
                    .parse()
                    .context("parsing BROKER port")?,
            };
            match &broker.tunnel {
                Tunnel::Direct => {
                    // By default, don't override broker address lookup.
                    //
                    // N.B.
                    //
                    // We _could_ pre-setup the default ssh tunnel for all known brokers here, but
                    // we avoid doing because:
                    // - Its not necessary.
                    // - Not doing so makes it easier to test the `FailedDefaultSshTunnel` path
                    // in the `TunnelingClientContext`.
                }
                Tunnel::AwsPrivatelink(aws_privatelink) => {
                    context.add_broker_rewrite(
                        addr,
                        KafkaConnection::from_aws_privatelink(aws_privatelink),
                    );
                }
                Tunnel::AwsPrivatelinks(_) => unreachable!(
                    "Individually predefined brokers do not use rule-based PrivateLinks routing."
                ),
                Tunnel::Ssh(ssh_tunnel) => {
                    // Ensure any SSH bastion address we connect to is resolved to an external address.
                    let ssh_host_resolved = resolve_address(
                        &ssh_tunnel.connection.host,
                        ENFORCE_EXTERNAL_ADDRESSES.get(storage_configuration.config_set()),
                    )
                    .await?;
                    context
                        .add_ssh_tunnel(
                            addr,
                            SshTunnelConfig {
                                host: ssh_host_resolved
                                    .iter()
                                    .map(|a| a.to_string())
                                    .collect::<BTreeSet<_>>(),
                                port: ssh_tunnel.connection.port,
                                user: ssh_tunnel.connection.user.clone(),
                                key_pair: SshKeyPair::from_bytes(
                                    &storage_configuration
                                        .connection_context
                                        .secrets_reader
                                        .read_in_task_if(in_task, ssh_tunnel.connection_id)
                                        .await?,
                                )?,
                            },
                        )
                        .await
                        .map_err(ContextCreationError::Ssh)?;
                }
            }
        }

        Ok(config.create_with_context(context)?)
    }

    /// Validates the connection by creating a consumer and fetching cluster
    /// metadata, which serves as a proxy for broker connectivity (see the
    /// comments in the body).
    async fn validate(
        &self,
        _id: CatalogItemId,
        storage_configuration: &StorageConfiguration,
    ) -> Result<(), anyhow::Error> {
        let (context, error_rx) = MzClientContext::with_errors();
        let consumer: BaseConsumer<_> = self
            .create_with_context(
                storage_configuration,
                context,
                &BTreeMap::new(),
                // We are in a normal tokio context during validation, already.
                InTask::No,
            )
            .await?;
        let consumer = Arc::new(consumer);

        let timeout = storage_configuration
            .parameters
            .kafka_timeout_config
            .fetch_metadata_timeout;

        // librdkafka doesn't expose an API for determining whether a connection to
        // the Kafka cluster has been successfully established. So we make a
        // metadata request, though we don't care about the results, so that we can
        // report any errors making that request. If the request succeeds, we know
        // we were able to contact at least one broker, and that's a good proxy for
        // being able to contact all the brokers in the cluster.
        //
        // The downside of this approach is it produces a generic error message like
        // "metadata fetch error" with no additional details. The real networking
        // error is buried in the librdkafka logs, which are not visible to users.
        info!("kafka: starting connection validation via fetch_metadata (timeout={timeout:?})");
        // `fetch_metadata` blocks, so run it off the async executor.
        let result = mz_ore::task::spawn_blocking(|| "kafka_get_metadata", {
            let consumer = Arc::clone(&consumer);
            move || consumer.fetch_metadata(None, timeout)
        })
        .await;
        info!(
            "kafka: connection validation result: {}",
            if result.is_ok() { "success" } else { "failed" },
        );
        match result {
            Ok(_) => Ok(()),
            // The error returned by `fetch_metadata` does not provide any details which makes for
            // a crappy user facing error message. For this reason we attempt to grab a better
            // error message from the client context, which should contain any error logs emitted
            // by librdkafka, and fallback to the generic error if there is nothing there.
            Err(err) => {
                // Multiple errors might have been logged during this validation but some are more
                // relevant than others. Specifically, we prefer non-internal errors over internal
                // errors since those give much more useful information to the users.
                let main_err = error_rx.try_iter().reduce(|cur, new| match cur {
                    MzKafkaError::Internal(_) => new,
                    _ => cur,
                });

                // Don't drop the consumer until after we've drained the errors
                // channel. Dropping the consumer can introduce spurious errors.
                // See database-issues#7432.
                drop(consumer);

                match main_err {
                    Some(err) => Err(err.into()),
                    None => Err(err.into()),
                }
            }
        }
    }

    /// The "default" PrivateLink connection is used for bootstrapping Kafka.
    fn from_default_aws_privatelink(pl: &AwsPrivatelink) -> BrokerRewrite {
        BrokerRewrite {
            host: vpc_endpoint_host(
                pl.connection_id,
                None, // Default tunnel does not support availability zones.
            ),
            port: pl.port,
        }
    }

    /// The "not default" PrivateLink connections are used for routing to specific Kafka brokers.
    fn from_aws_privatelink(pl: &AwsPrivatelink) -> BrokerRewrite {
        BrokerRewrite {
            host: vpc_endpoint_host(pl.connection_id, pl.availability_zone.as_deref()),
            port: pl.port,
        }
    }

    /// Converts one PrivateLink routing rule into the `mz_kafka_util`
    /// pattern representation paired with its broker rewrite.
    fn from_aws_privatelink_rule(
        AwsPrivatelinkRule { pattern, to }: &AwsPrivatelinkRule,
    ) -> (mz_kafka_util::client::ConnectionRulePattern, BrokerRewrite) {
        (
            mz_kafka_util::client::ConnectionRulePattern {
                prefix_wildcard: pattern.prefix_wildcard,
                literal_match: pattern.literal_match.clone(),
                suffix_wildcard: pattern.suffix_wildcard,
            },
            KafkaConnection::from_aws_privatelink(to),
        )
    }

    /// Converts all PrivateLink routing rules into [`HostMappingRules`] for
    /// the tunneling client context.
    fn from_aws_privatelinks(pl: &AwsPrivatelinks) -> HostMappingRules {
        HostMappingRules {
            rules: pl
                .rules
                .iter()
                .map(KafkaConnection::from_aws_privatelink_rule)
                .collect_vec(),
        }
    }
}
1318
1319impl<C: ConnectionAccess> AlterCompatible for KafkaConnection<C> {
1320    fn alter_compatible(&self, id: GlobalId, other: &Self) -> Result<(), AlterError> {
1321        let KafkaConnection {
1322            brokers: _,
1323            default_tunnel: _,
1324            progress_topic,
1325            progress_topic_options,
1326            options: _,
1327            tls: _,
1328            sasl: _,
1329        } = self;
1330
1331        let compatibility_checks = [
1332            (progress_topic == &other.progress_topic, "progress_topic"),
1333            (
1334                progress_topic_options == &other.progress_topic_options,
1335                "progress_topic_options",
1336            ),
1337        ];
1338
1339        for (compatible, field) in compatibility_checks {
1340            if !compatible {
1341                tracing::warn!(
1342                    "KafkaConnection incompatible at {field}:\nself:\n{:#?}\n\nother\n{:#?}",
1343                    self,
1344                    other
1345                );
1346
1347                return Err(AlterError { id });
1348            }
1349        }
1350
1351        Ok(())
1352    }
1353}
1354
/// A connection to a Confluent Schema Registry.
#[derive(Clone, Debug, Eq, PartialEq, Hash, Serialize, Deserialize)]
pub struct CsrConnection<C: ConnectionAccess = InlinedConnection> {
    /// The URL of the schema registry. Must contain a hostname; connecting
    /// fails otherwise.
    pub url: Url,
    /// Trusted root TLS certificate in PEM format.
    pub tls_root_cert: Option<StringOrSecret>,
    /// An optional TLS client certificate for authentication with the schema
    /// registry.
    pub tls_identity: Option<TlsIdentity>,
    /// Optional HTTP authentication credentials for the schema registry.
    pub http_auth: Option<CsrConnectionHttpAuth>,
    /// A tunnel through which to route traffic.
    pub tunnel: Tunnel<C>,
}
1370
1371impl<R: ConnectionResolver> IntoInlineConnection<CsrConnection, R>
1372    for CsrConnection<ReferencedConnection>
1373{
1374    fn into_inline_connection(self, r: R) -> CsrConnection {
1375        let CsrConnection {
1376            url,
1377            tls_root_cert,
1378            tls_identity,
1379            http_auth,
1380            tunnel,
1381        } = self;
1382        CsrConnection {
1383            url,
1384            tls_root_cert,
1385            tls_identity,
1386            http_auth,
1387            tunnel: tunnel.into_inline_connection(r),
1388        }
1389    }
1390}
1391
impl<C: ConnectionAccess> CsrConnection<C> {
    /// Reports whether connections of this kind should be validated by
    /// default; always `true` for schema registry connections.
    fn validate_by_default(&self) -> bool {
        true
    }
}
1397
impl CsrConnection {
    /// Constructs a schema registry client from the connection.
    ///
    /// Resolves the TLS root certificate, client identity, and HTTP
    /// credentials (reading secrets where configured), then wires the client
    /// through the configured tunnel: direct DNS resolution, an SSH tunnel,
    /// or an AWS PrivateLink endpoint.
    pub async fn connect(
        &self,
        storage_configuration: &StorageConfiguration,
        in_task: InTask,
    ) -> Result<mz_ccsr::Client, CsrConnectError> {
        let mut client_config = mz_ccsr::ClientConfig::new(self.url.clone());
        // Optional trusted root certificate for TLS verification.
        if let Some(root_cert) = &self.tls_root_cert {
            let root_cert = root_cert
                .get_string(
                    in_task,
                    &storage_configuration.connection_context.secrets_reader,
                )
                .await?;
            let root_cert = Certificate::from_pem(root_cert.as_bytes())?;
            client_config = client_config.add_root_certificate(root_cert);
        }

        // Optional TLS client identity; the private key is always a secret.
        if let Some(tls_identity) = &self.tls_identity {
            let key = &storage_configuration
                .connection_context
                .secrets_reader
                .read_string_in_task_if(in_task, tls_identity.key)
                .await?;
            let cert = tls_identity
                .cert
                .get_string(
                    in_task,
                    &storage_configuration.connection_context.secrets_reader,
                )
                .await?;
            let ident = Identity::from_pem(key.as_bytes(), cert.as_bytes())?;
            client_config = client_config.identity(ident);
        }

        // Optional HTTP basic authentication; the password, when present, is
        // a secret.
        if let Some(http_auth) = &self.http_auth {
            let username = http_auth
                .username
                .get_string(
                    in_task,
                    &storage_configuration.connection_context.secrets_reader,
                )
                .await?;
            let password = match http_auth.password {
                None => None,
                Some(password) => Some(
                    storage_configuration
                        .connection_context
                        .secrets_reader
                        .read_string_in_task_if(in_task, password)
                        .await?,
                ),
            };
            client_config = client_config.auth(username, password);
        }

        // TODO: use types to enforce that the URL has a string hostname.
        let host = self
            .url
            .host_str()
            .ok_or_else(|| anyhow!("url missing host"))?;
        match &self.tunnel {
            Tunnel::Direct => {
                // Ensure any host we connect to is resolved to an external address.
                let resolved = resolve_address(
                    host,
                    ENFORCE_EXTERNAL_ADDRESSES.get(storage_configuration.config_set()),
                )
                .await?;
                client_config = client_config.resolve_to_addrs(
                    host,
                    &resolved
                        .iter()
                        .map(|addr| SocketAddr::new(*addr, 0))
                        .collect::<Vec<_>>(),
                )
            }
            Tunnel::Ssh(ssh_tunnel) => {
                let ssh_tunnel = ssh_tunnel
                    .connect(
                        storage_configuration,
                        host,
                        // Default to the default http port, but this
                        // could default to 8081...
                        self.url.port().unwrap_or(80),
                        in_task,
                    )
                    .await
                    .map_err(CsrConnectError::Ssh)?;

                // Carefully inject the SSH tunnel into the client
                // configuration. This is delicate because we need TLS
                // verification to continue to use the remote hostname rather
                // than the tunnel hostname.

                client_config = client_config
                    // `resolve_to_addrs` allows us to rewrite the hostname
                    // at the DNS level, which means the TCP connection is
                    // correctly routed through the tunnel, but TLS verification
                    // is still performed against the remote hostname.
                    // Unfortunately the port here is ignored if the URL also
                    // specifies a port...
                    .resolve_to_addrs(host, &[SocketAddr::new(ssh_tunnel.local_addr().ip(), 0)])
                    // ...so we also dynamically rewrite the URL to use the
                    // current port for the SSH tunnel.
                    //
                    // WARNING: this is brittle, because we only dynamically
                    // update the client configuration with the tunnel *port*,
                    // and not the hostname This works fine in practice, because
                    // only the SSH tunnel port will change if the tunnel fails
                    // and has to be restarted (the hostname is always
                    // 127.0.0.1)--but this is an an implementation detail of
                    // the SSH tunnel code that we're relying on.
                    .dynamic_url({
                        let remote_url = self.url.clone();
                        move || {
                            let mut url = remote_url.clone();
                            url.set_port(Some(ssh_tunnel.local_addr().port()))
                                .expect("cannot fail");
                            url
                        }
                    });
            }
            Tunnel::AwsPrivatelink(connection) => {
                assert_none!(connection.port);

                // Route DNS for the registry host to the PrivateLink endpoint.
                let privatelink_host = mz_cloud_resources::vpc_endpoint_host(
                    connection.connection_id,
                    connection.availability_zone.as_deref(),
                );
                let addrs: Vec<_> = net::lookup_host((privatelink_host, 0))
                    .await
                    .context("resolving PrivateLink host")?
                    .collect();
                client_config = client_config.resolve_to_addrs(host, &addrs)
            }
            Tunnel::AwsPrivatelinks(_) => {
                unreachable!("MATCHING broker rules are only available for Kafka connections.");
            }
        }

        Ok(client_config.build()?)
    }

    /// Validates the connection by building a client and listing the
    /// registry's subjects.
    async fn validate(
        &self,
        _id: CatalogItemId,
        storage_configuration: &StorageConfiguration,
    ) -> Result<(), anyhow::Error> {
        let client = self
            .connect(
                storage_configuration,
                // We are in a normal tokio context during validation, already.
                InTask::No,
            )
            .await?;
        client.list_subjects().await?;
        Ok(())
    }
}
1559
1560impl<C: ConnectionAccess> AlterCompatible for CsrConnection<C> {
1561    fn alter_compatible(&self, id: GlobalId, other: &Self) -> Result<(), AlterError> {
1562        let CsrConnection {
1563            tunnel,
1564            // All non-tunnel fields may change
1565            url: _,
1566            tls_root_cert: _,
1567            tls_identity: _,
1568            http_auth: _,
1569        } = self;
1570
1571        let compatibility_checks = [(tunnel.alter_compatible(id, &other.tunnel).is_ok(), "tunnel")];
1572
1573        for (compatible, field) in compatibility_checks {
1574            if !compatible {
1575                tracing::warn!(
1576                    "CsrConnection incompatible at {field}:\nself:\n{:#?}\n\nother\n{:#?}",
1577                    self,
1578                    other
1579                );
1580
1581                return Err(AlterError { id });
1582            }
1583        }
1584        Ok(())
1585    }
1586}
1587
/// A TLS key pair used for client identity.
#[derive(Clone, Debug, Eq, PartialEq, Hash, Serialize, Deserialize)]
pub struct TlsIdentity {
    /// The client's TLS public certificate in PEM format.
    pub cert: StringOrSecret,
    /// The ID of the secret containing the client's TLS private key in PEM
    /// format.
    ///
    /// Unlike `cert`, the key is always stored as a secret reference, as
    /// private keys are sensitive.
    pub key: CatalogItemId,
}
1597
/// HTTP authentication credentials in a [`CsrConnection`].
#[derive(Clone, Debug, Eq, PartialEq, Hash, Serialize, Deserialize)]
pub struct CsrConnectionHttpAuth {
    /// The username.
    pub username: StringOrSecret,
    /// The ID of the secret containing the password, if any.
    ///
    /// `None` if no password was provided for this connection.
    pub password: Option<CatalogItemId>,
}
1606
/// A connection to a PostgreSQL server.
#[derive(Clone, Debug, Eq, PartialEq, Hash, Serialize, Deserialize)]
pub struct PostgresConnection<C: ConnectionAccess = InlinedConnection> {
    /// The hostname of the server.
    pub host: String,
    /// The port of the server.
    pub port: u16,
    /// The name of the database to connect to.
    pub database: String,
    /// The username to authenticate as.
    pub user: StringOrSecret,
    /// An optional password for authentication, stored as a reference to a
    /// secret.
    pub password: Option<CatalogItemId>,
    /// A tunnel through which to route traffic.
    pub tunnel: Tunnel<C>,
    /// Whether to use TLS for encryption, authentication, or both.
    pub tls_mode: SslMode,
    /// An optional root TLS certificate in PEM format, to verify the server's
    /// identity.
    pub tls_root_cert: Option<StringOrSecret>,
    /// An optional TLS client certificate for authentication.
    pub tls_identity: Option<TlsIdentity>,
}
1630
1631impl<R: ConnectionResolver> IntoInlineConnection<PostgresConnection, R>
1632    for PostgresConnection<ReferencedConnection>
1633{
1634    fn into_inline_connection(self, r: R) -> PostgresConnection {
1635        let PostgresConnection {
1636            host,
1637            port,
1638            database,
1639            user,
1640            password,
1641            tunnel,
1642            tls_mode,
1643            tls_root_cert,
1644            tls_identity,
1645        } = self;
1646
1647        PostgresConnection {
1648            host,
1649            port,
1650            database,
1651            user,
1652            password,
1653            tunnel: tunnel.into_inline_connection(r),
1654            tls_mode,
1655            tls_root_cert,
1656            tls_identity,
1657        }
1658    }
1659}
1660
impl<C: ConnectionAccess> PostgresConnection<C> {
    /// Reports whether connections of this type are validated by default.
    fn validate_by_default(&self) -> bool {
        true
    }
}
1666
impl PostgresConnection<InlinedConnection> {
    /// Resolves this connection into an `mz_postgres_util::Config` that can
    /// be used to open a client connection to the PostgreSQL server.
    ///
    /// Any referenced secrets (password, TLS certificate/key, SSH key pair)
    /// are read through `secrets_reader`; the `in_task` argument controls
    /// whether that I/O runs in an `mz_ore::task` rather than on the calling
    /// thread (to keep blocking work off timely threads).
    ///
    /// # Errors
    ///
    /// Fails if a secret cannot be read, the SSH key pair cannot be parsed,
    /// or hostname resolution fails.
    pub async fn config(
        &self,
        secrets_reader: &Arc<dyn mz_secrets::SecretsReader>,
        storage_configuration: &StorageConfiguration,
        in_task: InTask,
    ) -> Result<mz_postgres_util::Config, anyhow::Error> {
        let params = &storage_configuration.parameters;

        // Base connection details: host, port, database, user, TLS mode, and
        // any TLS material the connection carries.
        let mut config = tokio_postgres::Config::new();
        config
            .host(&self.host)
            .port(self.port)
            .dbname(&self.database)
            .user(&self.user.get_string(in_task, secrets_reader).await?)
            .ssl_mode(self.tls_mode);
        if let Some(password) = self.password {
            let password = secrets_reader
                .read_string_in_task_if(in_task, password)
                .await?;
            config.password(password);
        }
        if let Some(tls_root_cert) = &self.tls_root_cert {
            let tls_root_cert = tls_root_cert.get_string(in_task, secrets_reader).await?;
            config.ssl_root_cert(tls_root_cert.as_bytes());
        }
        if let Some(tls_identity) = &self.tls_identity {
            let cert = tls_identity
                .cert
                .get_string(in_task, secrets_reader)
                .await?;
            let key = secrets_reader
                .read_string_in_task_if(in_task, tls_identity.key)
                .await?;
            config.ssl_cert(cert.as_bytes()).ssl_key(key.as_bytes());
        }

        // Client-side timeout and TCP keepalive settings, sourced from the
        // storage configuration parameters.
        if let Some(connect_timeout) = params.pg_source_connect_timeout {
            config.connect_timeout(connect_timeout);
        }
        if let Some(keepalives_retries) = params.pg_source_tcp_keepalives_retries {
            config.keepalives_retries(keepalives_retries);
        }
        if let Some(keepalives_idle) = params.pg_source_tcp_keepalives_idle {
            config.keepalives_idle(keepalives_idle);
        }
        if let Some(keepalives_interval) = params.pg_source_tcp_keepalives_interval {
            config.keepalives_interval(keepalives_interval);
        }
        if let Some(tcp_user_timeout) = params.pg_source_tcp_user_timeout {
            config.tcp_user_timeout(tcp_user_timeout);
        }

        // Server-side session settings, passed via the `options` startup
        // parameter.
        let mut options = vec![];
        if let Some(wal_sender_timeout) = params.pg_source_wal_sender_timeout {
            options.push(format!(
                "--wal_sender_timeout={}",
                wal_sender_timeout.as_millis()
            ));
        };
        // Optionally mirror the client-side TCP settings on the server side
        // as well. Note the unit differences: keepalive counts/seconds vs.
        // the user timeout in milliseconds.
        if params.pg_source_tcp_configure_server {
            if let Some(keepalives_retries) = params.pg_source_tcp_keepalives_retries {
                options.push(format!("--tcp_keepalives_count={}", keepalives_retries));
            }
            if let Some(keepalives_idle) = params.pg_source_tcp_keepalives_idle {
                options.push(format!(
                    "--tcp_keepalives_idle={}",
                    keepalives_idle.as_secs()
                ));
            }
            if let Some(keepalives_interval) = params.pg_source_tcp_keepalives_interval {
                options.push(format!(
                    "--tcp_keepalives_interval={}",
                    keepalives_interval.as_secs()
                ));
            }
            if let Some(tcp_user_timeout) = params.pg_source_tcp_user_timeout {
                options.push(format!(
                    "--tcp_user_timeout={}",
                    tcp_user_timeout.as_millis()
                ));
            }
        }
        config.options(options.join(" ").as_str());

        // Translate the connection's tunnel configuration into the form that
        // `mz_postgres_util` understands.
        let tunnel = match &self.tunnel {
            Tunnel::Direct => {
                // Ensure any host we connect to is resolved to an external address.
                let resolved = resolve_address(
                    &self.host,
                    ENFORCE_EXTERNAL_ADDRESSES.get(storage_configuration.config_set()),
                )
                .await?;
                mz_postgres_util::TunnelConfig::Direct {
                    resolved_ips: Some(resolved),
                }
            }
            Tunnel::Ssh(SshTunnel {
                connection_id,
                connection,
            }) => {
                // The SSH key pair is stored as a secret referenced by the
                // tunnel connection's ID.
                let secret = secrets_reader
                    .read_in_task_if(in_task, *connection_id)
                    .await?;
                let key_pair = SshKeyPair::from_bytes(&secret)?;
                // Ensure any ssh-bastion host we connect to is resolved to an external address.
                let resolved = resolve_address(
                    &connection.host,
                    ENFORCE_EXTERNAL_ADDRESSES.get(storage_configuration.config_set()),
                )
                .await?;
                mz_postgres_util::TunnelConfig::Ssh {
                    config: SshTunnelConfig {
                        host: resolved
                            .iter()
                            .map(|a| a.to_string())
                            .collect::<BTreeSet<_>>(),
                        port: connection.port,
                        user: connection.user.clone(),
                        key_pair,
                    },
                }
            }
            Tunnel::AwsPrivatelink(connection) => {
                // PrivateLink tunnels for PostgreSQL never carry a port
                // override.
                assert_none!(connection.port);
                mz_postgres_util::TunnelConfig::AwsPrivatelink {
                    connection_id: connection.connection_id,
                }
            }
            Tunnel::AwsPrivatelinks(_) => {
                unreachable!("MATCHING broker rules are only available for Kafka connections.");
            }
        };

        Ok(mz_postgres_util::Config::new(
            config,
            tunnel,
            params.ssh_timeout_config,
            in_task,
        )?)
    }

    /// Verifies that a connection can be established and that the server is
    /// configured for logical replication: `wal_level` must be at least
    /// `logical`, `max_wal_senders` must be positive, and at least two
    /// replication slots must be available (one for snapshots and one for
    /// continuing replication).
    ///
    /// Returns the connected client on success so callers can reuse it.
    pub async fn validate(
        &self,
        _id: CatalogItemId,
        storage_configuration: &StorageConfiguration,
    ) -> Result<mz_postgres_util::Client, anyhow::Error> {
        let config = self
            .config(
                &storage_configuration.connection_context.secrets_reader,
                storage_configuration,
                // We are in a normal tokio context during validation, already.
                InTask::No,
            )
            .await?;
        let client = config
            .connect(
                "connection validation",
                &storage_configuration.connection_context.ssh_tunnel_manager,
            )
            .await?;

        let wal_level = mz_postgres_util::get_wal_level(&client).await?;

        if wal_level < mz_postgres_util::replication::WalLevel::Logical {
            Err(PostgresConnectionValidationError::InsufficientWalLevel { wal_level })?;
        }

        let max_wal_senders = mz_postgres_util::get_max_wal_senders(&client).await?;

        if max_wal_senders < 1 {
            Err(PostgresConnectionValidationError::ReplicationDisabled)?;
        }

        let available_replication_slots =
            mz_postgres_util::available_replication_slots(&client).await?;

        // We need 1 replication slot for the snapshots and 1 for the continuing replication
        if available_replication_slots < 2 {
            Err(
                PostgresConnectionValidationError::InsufficientReplicationSlotsAvailable {
                    count: 2,
                },
            )?;
        }

        Ok(client)
    }
}
1856
/// Errors encountered while validating a [`PostgresConnection`].
#[derive(Debug, Clone, thiserror::Error)]
pub enum PostgresConnectionValidationError {
    /// Fewer replication slots are available than validation requires;
    /// `count` is the number of slots needed.
    #[error("PostgreSQL server has insufficient number of replication slots available")]
    InsufficientReplicationSlotsAvailable { count: usize },
    /// The server's `wal_level` is below `logical`.
    #[error("server must have wal_level >= logical, but has {wal_level}")]
    InsufficientWalLevel {
        wal_level: mz_postgres_util::replication::WalLevel,
    },
    /// The server's `max_wal_senders` setting is less than one.
    #[error("replication disabled on server")]
    ReplicationDisabled,
}
1868
1869impl PostgresConnectionValidationError {
1870    pub fn detail(&self) -> Option<String> {
1871        match self {
1872            Self::InsufficientReplicationSlotsAvailable { count } => Some(format!(
1873                "executing this statement requires {} replication slot{}",
1874                count,
1875                if *count == 1 { "" } else { "s" }
1876            )),
1877            _ => None,
1878        }
1879    }
1880
1881    pub fn hint(&self) -> Option<String> {
1882        match self {
1883            Self::InsufficientReplicationSlotsAvailable { .. } => Some(
1884                "you might be able to wait for other sources to finish snapshotting and try again"
1885                    .into(),
1886            ),
1887            Self::ReplicationDisabled => Some("set max_wal_senders to a value > 0".into()),
1888            Self::InsufficientWalLevel { .. } => None,
1889        }
1890    }
1891}
1892
1893impl<C: ConnectionAccess> AlterCompatible for PostgresConnection<C> {
1894    fn alter_compatible(&self, id: GlobalId, other: &Self) -> Result<(), AlterError> {
1895        let PostgresConnection {
1896            tunnel,
1897            // All non-tunnel options may change arbitrarily
1898            host: _,
1899            port: _,
1900            database: _,
1901            user: _,
1902            password: _,
1903            tls_mode: _,
1904            tls_root_cert: _,
1905            tls_identity: _,
1906        } = self;
1907
1908        let compatibility_checks = [(tunnel.alter_compatible(id, &other.tunnel).is_ok(), "tunnel")];
1909
1910        for (compatible, field) in compatibility_checks {
1911            if !compatible {
1912                tracing::warn!(
1913                    "PostgresConnection incompatible at {field}:\nself:\n{:#?}\n\nother\n{:#?}",
1914                    self,
1915                    other
1916                );
1917
1918                return Err(AlterError { id });
1919            }
1920        }
1921        Ok(())
1922    }
1923}
1924
/// Specifies how to tunnel a connection.
#[derive(Clone, Debug, Eq, PartialEq, Hash, Serialize, Deserialize)]
pub enum Tunnel<C: ConnectionAccess = InlinedConnection> {
    /// No tunneling.
    Direct,
    /// Via the specified SSH tunnel connection.
    Ssh(SshTunnel<C>),
    /// Via the specified AWS PrivateLink connection.
    AwsPrivatelink(AwsPrivatelink),
    /// Via multiple AWS PrivateLink connections selected by matching rules.
    ///
    /// Non-Kafka code paths in this file treat this variant as unreachable,
    /// as MATCHING broker rules are only available for Kafka connections.
    AwsPrivatelinks(AwsPrivatelinks),
}
1936
1937impl<R: ConnectionResolver> IntoInlineConnection<Tunnel, R> for Tunnel<ReferencedConnection> {
1938    fn into_inline_connection(self, r: R) -> Tunnel {
1939        match self {
1940            Tunnel::Direct => Tunnel::Direct,
1941            Tunnel::Ssh(ssh) => Tunnel::Ssh(ssh.into_inline_connection(r)),
1942            Tunnel::AwsPrivatelink(awspl) => Tunnel::AwsPrivatelink(awspl),
1943            Tunnel::AwsPrivatelinks(x) => Tunnel::AwsPrivatelinks(x),
1944        }
1945    }
1946}
1947
1948impl<C: ConnectionAccess> AlterCompatible for Tunnel<C> {
1949    fn alter_compatible(&self, id: GlobalId, other: &Self) -> Result<(), AlterError> {
1950        let compatible = match (self, other) {
1951            (Self::Ssh(s), Self::Ssh(o)) => s.alter_compatible(id, o).is_ok(),
1952            (s, o) => s == o,
1953        };
1954
1955        if !compatible {
1956            tracing::warn!(
1957                "Tunnel incompatible:\nself:\n{:#?}\n\nother\n{:#?}",
1958                self,
1959                other
1960            );
1961
1962            return Err(AlterError { id });
1963        }
1964
1965        Ok(())
1966    }
1967}
1968
/// Specifies which MySQL SSL Mode to use:
/// <https://dev.mysql.com/doc/refman/8.0/en/connection-options.html#option_general_ssl-mode>
/// This is not available as an enum in the mysql-async crate, so we define our own.
#[derive(Clone, Debug, Eq, PartialEq, Hash, Serialize, Deserialize)]
pub enum MySqlSslMode {
    /// Do not use TLS.
    Disabled,
    /// Require TLS, but do not verify the server's certificate or hostname.
    Required,
    /// Require TLS and verify the server's certificate against a CA, but do
    /// not verify the server's hostname.
    VerifyCa,
    /// Require TLS, verify the server's certificate, and verify its hostname.
    VerifyIdentity,
}
1979
/// A connection to a MySQL server.
#[derive(Clone, Debug, Eq, PartialEq, Hash, Serialize, Deserialize)]
pub struct MySqlConnection<C: ConnectionAccess = InlinedConnection> {
    /// The hostname of the server.
    pub host: String,
    /// The port of the server.
    pub port: u16,
    /// The username to authenticate as.
    pub user: StringOrSecret,
    /// An optional password for authentication, stored as a reference to a
    /// secret.
    pub password: Option<CatalogItemId>,
    /// A tunnel through which to route traffic.
    pub tunnel: Tunnel<C>,
    /// Whether to use TLS for encryption, verify the server's certificate, and identity.
    pub tls_mode: MySqlSslMode,
    /// An optional root TLS certificate in PEM format, to verify the server's
    /// identity.
    pub tls_root_cert: Option<StringOrSecret>,
    /// An optional TLS client certificate for authentication.
    pub tls_identity: Option<TlsIdentity>,
    /// Reference to the AWS connection information to be used for IAM authentication and
    /// assuming AWS roles.
    pub aws_connection: Option<AwsConnectionReference<C>>,
}
2004
2005impl<R: ConnectionResolver> IntoInlineConnection<MySqlConnection, R>
2006    for MySqlConnection<ReferencedConnection>
2007{
2008    fn into_inline_connection(self, r: R) -> MySqlConnection {
2009        let MySqlConnection {
2010            host,
2011            port,
2012            user,
2013            password,
2014            tunnel,
2015            tls_mode,
2016            tls_root_cert,
2017            tls_identity,
2018            aws_connection,
2019        } = self;
2020
2021        MySqlConnection {
2022            host,
2023            port,
2024            user,
2025            password,
2026            tunnel: tunnel.into_inline_connection(&r),
2027            tls_mode,
2028            tls_root_cert,
2029            tls_identity,
2030            aws_connection: aws_connection.map(|aws| aws.into_inline_connection(&r)),
2031        }
2032    }
2033}
2034
impl<C: ConnectionAccess> MySqlConnection<C> {
    /// Reports whether connections of this type are validated by default.
    fn validate_by_default(&self) -> bool {
        true
    }
}
2040
impl MySqlConnection<InlinedConnection> {
    /// Resolves this connection into an `mz_mysql_util::Config` that can be
    /// used to open a client connection to the MySQL server.
    ///
    /// Any referenced secrets (password, TLS material, SSH key pair) are read
    /// through `secrets_reader`; the `in_task` argument controls whether that
    /// I/O runs in an `mz_ore::task` rather than on the calling thread (to
    /// keep blocking work off timely threads).
    ///
    /// # Errors
    ///
    /// Fails if a secret cannot be read, TLS material cannot be converted,
    /// hostname resolution fails, or the AWS SDK configuration cannot be
    /// loaded.
    pub async fn config(
        &self,
        secrets_reader: &Arc<dyn mz_secrets::SecretsReader>,
        storage_configuration: &StorageConfiguration,
        in_task: InTask,
    ) -> Result<mz_mysql_util::Config, anyhow::Error> {
        // TODO(roshan): Set appropriate connection timeouts
        let mut opts = mysql_async::OptsBuilder::default()
            .ip_or_hostname(&self.host)
            .tcp_port(self.port)
            .user(Some(&self.user.get_string(in_task, secrets_reader).await?));

        if let Some(password) = self.password {
            let password = secrets_reader
                .read_string_in_task_if(in_task, password)
                .await?;
            opts = opts.pass(Some(password));
        }

        // Our `MySqlSslMode` enum matches the official MySQL Client `--ssl-mode` parameter values
        // which uses opt-in security features (SSL, CA verification, & Identity verification).
        // The mysql_async crate `SslOpts` struct uses an opt-out mechanism for each of these, so
        // we need to appropriately disable features to match the intent of each enum value.
        let mut ssl_opts = match self.tls_mode {
            MySqlSslMode::Disabled => None,
            MySqlSslMode::Required => Some(
                mysql_async::SslOpts::default()
                    .with_danger_accept_invalid_certs(true)
                    .with_danger_skip_domain_validation(true),
            ),
            MySqlSslMode::VerifyCa => {
                Some(mysql_async::SslOpts::default().with_danger_skip_domain_validation(true))
            }
            MySqlSslMode::VerifyIdentity => Some(mysql_async::SslOpts::default()),
        };

        // A root certificate is only relevant for the modes that verify the
        // server's certificate.
        if matches!(
            self.tls_mode,
            MySqlSslMode::VerifyCa | MySqlSslMode::VerifyIdentity
        ) {
            if let Some(tls_root_cert) = &self.tls_root_cert {
                let tls_root_cert = tls_root_cert.get_string(in_task, secrets_reader).await?;
                ssl_opts = ssl_opts.map(|opts| {
                    opts.with_root_certs(vec![tls_root_cert.as_bytes().to_vec().into()])
                });
            }
        }

        if let Some(identity) = &self.tls_identity {
            let key = secrets_reader
                .read_string_in_task_if(in_task, identity.key)
                .await?;
            let cert = identity.cert.get_string(in_task, secrets_reader).await?;
            // mysql_async takes the client identity as a PKCS#12 archive, so
            // convert from the stored PEM pair.
            let (der, pass) =
                mz_tls_util::pkcs12der_from_pem(key.as_bytes(), cert.as_bytes())?.into_parts();

            // Add client identity to SSLOpts
            ssl_opts = ssl_opts.map(|opts| {
                opts.with_client_identity(Some(
                    mysql_async::ClientIdentity::new(der.into()).with_password(pass),
                ))
            });
        }

        opts = opts.ssl_opts(ssl_opts);

        // Translate the connection's tunnel configuration into the form that
        // `mz_mysql_util` understands.
        let tunnel = match &self.tunnel {
            Tunnel::Direct => {
                // Ensure any host we connect to is resolved to an external address.
                let resolved = resolve_address(
                    &self.host,
                    ENFORCE_EXTERNAL_ADDRESSES.get(storage_configuration.config_set()),
                )
                .await?;
                mz_mysql_util::TunnelConfig::Direct {
                    resolved_ips: Some(resolved),
                }
            }
            Tunnel::Ssh(SshTunnel {
                connection_id,
                connection,
            }) => {
                // The SSH key pair is stored as a secret referenced by the
                // tunnel connection's ID.
                let secret = secrets_reader
                    .read_in_task_if(in_task, *connection_id)
                    .await?;
                let key_pair = SshKeyPair::from_bytes(&secret)?;
                // Ensure any ssh-bastion host we connect to is resolved to an external address.
                let resolved = resolve_address(
                    &connection.host,
                    ENFORCE_EXTERNAL_ADDRESSES.get(storage_configuration.config_set()),
                )
                .await?;
                mz_mysql_util::TunnelConfig::Ssh {
                    config: SshTunnelConfig {
                        host: resolved
                            .iter()
                            .map(|a| a.to_string())
                            .collect::<BTreeSet<_>>(),
                        port: connection.port,
                        user: connection.user.clone(),
                        key_pair,
                    },
                }
            }
            Tunnel::AwsPrivatelink(connection) => {
                // PrivateLink tunnels for MySQL never carry a port override.
                assert_none!(connection.port);
                mz_mysql_util::TunnelConfig::AwsPrivatelink {
                    connection_id: connection.connection_id,
                }
            }
            Tunnel::AwsPrivatelinks(_) => {
                unreachable!("MATCHING broker rules are only available for Kafka connections.");
            }
        };

        // Load the AWS SDK configuration when an AWS connection is referenced
        // (per the struct docs, used for IAM authentication / role assumption).
        let aws_config = match self.aws_connection.as_ref() {
            None => None,
            Some(aws_ref) => Some(
                aws_ref
                    .connection
                    .load_sdk_config(
                        &storage_configuration.connection_context,
                        aws_ref.connection_id,
                        in_task,
                        ENFORCE_EXTERNAL_ADDRESSES.get(storage_configuration.config_set()),
                    )
                    .await?,
            ),
        };

        Ok(mz_mysql_util::Config::new(
            opts,
            tunnel,
            storage_configuration.parameters.ssh_timeout_config,
            in_task,
            storage_configuration
                .parameters
                .mysql_source_timeouts
                .clone(),
            aws_config,
        )?)
    }

    /// Verifies that a connection can be established and that the server's
    /// replication-related system settings permit row-based, GTID-consistent
    /// replication.
    ///
    /// All offending settings are collected into a single
    /// [`MySqlConnectionValidationError::ReplicationSettingsError`] so the
    /// user sees every misconfiguration at once. Returns the connected client
    /// on success so callers can reuse it.
    pub async fn validate(
        &self,
        _id: CatalogItemId,
        storage_configuration: &StorageConfiguration,
    ) -> Result<MySqlConn, MySqlConnectionValidationError> {
        let config = self
            .config(
                &storage_configuration.connection_context.secrets_reader,
                storage_configuration,
                // We are in a normal tokio context during validation, already.
                InTask::No,
            )
            .await?;
        let mut conn = config
            .connect(
                "connection validation",
                &storage_configuration.connection_context.ssh_tunnel_manager,
            )
            .await?;

        // Check if the MySQL database is configured to allow row-based consistent GTID replication
        let mut setting_errors = vec![];
        let gtid_res = mz_mysql_util::ensure_gtid_consistency(&mut conn).await;
        let binlog_res = mz_mysql_util::ensure_full_row_binlog_format(&mut conn).await;
        let order_res = mz_mysql_util::ensure_replication_commit_order(&mut conn).await;
        for res in [gtid_res, binlog_res, order_res] {
            match res {
                // Accumulate setting mismatches; any other client error aborts
                // validation immediately.
                Err(MySqlError::InvalidSystemSetting {
                    setting,
                    expected,
                    actual,
                }) => {
                    setting_errors.push((setting, expected, actual));
                }
                Err(err) => Err(err)?,
                Ok(()) => {}
            }
        }
        if !setting_errors.is_empty() {
            Err(MySqlConnectionValidationError::ReplicationSettingsError(
                setting_errors,
            ))?;
        }

        Ok(conn)
    }
}
2232
/// Errors encountered while validating a [`MySqlConnection`].
#[derive(Debug, thiserror::Error)]
pub enum MySqlConnectionValidationError {
    /// One or more MySQL system variables are set to values incompatible with
    /// replication; each entry is `(setting, expected, actual)`.
    #[error("Invalid MySQL system replication settings")]
    ReplicationSettingsError(Vec<(String, String, String)>),
    /// An error from the MySQL client.
    #[error(transparent)]
    Client(#[from] MySqlError),
    /// Any other error.
    #[error("{}", .0.display_with_causes())]
    Other(#[from] anyhow::Error),
}
2242
2243impl MySqlConnectionValidationError {
2244    pub fn detail(&self) -> Option<String> {
2245        match self {
2246            Self::ReplicationSettingsError(settings) => Some(format!(
2247                "Invalid MySQL system replication settings: {}",
2248                itertools::join(
2249                    settings.iter().map(|(setting, expected, actual)| format!(
2250                        "{}: expected {}, got {}",
2251                        setting, expected, actual
2252                    )),
2253                    "; "
2254                )
2255            )),
2256            _ => None,
2257        }
2258    }
2259
2260    pub fn hint(&self) -> Option<String> {
2261        match self {
2262            Self::ReplicationSettingsError(_) => {
2263                Some("Set the necessary MySQL database system settings.".into())
2264            }
2265            _ => None,
2266        }
2267    }
2268}
2269
2270impl<C: ConnectionAccess> AlterCompatible for MySqlConnection<C> {
2271    fn alter_compatible(&self, id: GlobalId, other: &Self) -> Result<(), AlterError> {
2272        let MySqlConnection {
2273            tunnel,
2274            // All non-tunnel options may change arbitrarily
2275            host: _,
2276            port: _,
2277            user: _,
2278            password: _,
2279            tls_mode: _,
2280            tls_root_cert: _,
2281            tls_identity: _,
2282            aws_connection: _,
2283        } = self;
2284
2285        let compatibility_checks = [(tunnel.alter_compatible(id, &other.tunnel).is_ok(), "tunnel")];
2286
2287        for (compatible, field) in compatibility_checks {
2288            if !compatible {
2289                tracing::warn!(
2290                    "MySqlConnection incompatible at {field}:\nself:\n{:#?}\n\nother\n{:#?}",
2291                    self,
2292                    other
2293                );
2294
2295                return Err(AlterError { id });
2296            }
2297        }
2298        Ok(())
2299    }
2300}
2301
/// Details how to connect to an instance of Microsoft SQL Server.
///
/// For specifics of connecting to SQL Server for purposes of creating a
/// Materialize Source, see [`SqlServerSourceConnection`] which wraps this type.
///
/// [`SqlServerSourceConnection`]: crate::sources::SqlServerSourceConnection
#[derive(Clone, Debug, Eq, PartialEq, Hash, Serialize, Deserialize)]
pub struct SqlServerConnectionDetails<C: ConnectionAccess = InlinedConnection> {
    /// The hostname of the server.
    pub host: String,
    /// The port of the server.
    pub port: u16,
    /// Database we should connect to.
    pub database: String,
    /// The username to authenticate as.
    pub user: StringOrSecret,
    /// Password used for authentication, stored as a reference to a secret.
    ///
    /// Note: unlike the PostgreSQL and MySQL connection types, the password
    /// here is not optional.
    pub password: CatalogItemId,
    /// A tunnel through which to route traffic.
    pub tunnel: Tunnel<C>,
    /// Level of encryption to use for the connection.
    pub encryption: mz_sql_server_util::config::EncryptionLevel,
    /// Certificate validation policy.
    pub certificate_validation_policy: mz_sql_server_util::config::CertificateValidationPolicy,
    /// TLS CA certificate in PEM format.
    pub tls_root_cert: Option<StringOrSecret>,
}
2329
impl<C: ConnectionAccess> SqlServerConnectionDetails<C> {
    /// Reports whether connections of this type are validated by default.
    fn validate_by_default(&self) -> bool {
        true
    }
}
2335
2336impl SqlServerConnectionDetails<InlinedConnection> {
2337    /// Attempts to open a connection to the upstream SQL Server instance.
2338    pub async fn validate(
2339        &self,
2340        _id: CatalogItemId,
2341        storage_configuration: &StorageConfiguration,
2342    ) -> Result<mz_sql_server_util::Client, anyhow::Error> {
2343        let config = self
2344            .resolve_config(
2345                &storage_configuration.connection_context.secrets_reader,
2346                storage_configuration,
2347                InTask::No,
2348            )
2349            .await?;
2350        tracing::debug!(?config, "Validating SQL Server connection");
2351
2352        let mut client = mz_sql_server_util::Client::connect(config).await?;
2353
2354        // Ensure the upstream SQL Server instance is configured to allow CDC.
2355        //
2356        // Run all of the checks necessary and collect the errors to provide the best
2357        // guidance as to which system settings need to be enabled.
2358        let mut replication_errors = vec![];
2359        for error in [
2360            mz_sql_server_util::inspect::ensure_database_cdc_enabled(&mut client).await,
2361            mz_sql_server_util::inspect::ensure_snapshot_isolation_enabled(&mut client).await,
2362            mz_sql_server_util::inspect::ensure_sql_server_agent_running(&mut client).await,
2363        ] {
2364            match error {
2365                Err(mz_sql_server_util::SqlServerError::InvalidSystemSetting {
2366                    name,
2367                    expected,
2368                    actual,
2369                }) => replication_errors.push((name, expected, actual)),
2370                Err(other) => Err(other)?,
2371                Ok(()) => (),
2372            }
2373        }
2374        if !replication_errors.is_empty() {
2375            Err(SqlServerConnectionValidationError::ReplicationSettingsError(replication_errors))?;
2376        }
2377
2378        Ok(client)
2379    }
2380
2381    /// Resolve all of the connection details (e.g. read from the [`SecretsReader`])
2382    /// so the returned [`Config`] can be used to open a connection with the
2383    /// upstream system.
2384    ///
2385    /// The provided [`InTask`] argument determines whether any I/O is run in an
2386    /// [`mz_ore::task`] (i.e. a different thread) or directly in the returned
2387    /// future. The main goal here is to prevent running I/O in timely threads.
2388    ///
2389    /// [`Config`]: mz_sql_server_util::Config
2390    pub async fn resolve_config(
2391        &self,
2392        secrets_reader: &Arc<dyn mz_secrets::SecretsReader>,
2393        storage_configuration: &StorageConfiguration,
2394        in_task: InTask,
2395    ) -> Result<mz_sql_server_util::Config, anyhow::Error> {
2396        let dyncfg = storage_configuration.config_set();
2397        let mut inner_config = tiberius::Config::new();
2398
2399        // Setup default connection params.
2400        inner_config.host(&self.host);
2401        inner_config.port(self.port);
2402        inner_config.database(self.database.clone());
2403        inner_config.encryption(self.encryption.into());
2404        match self.certificate_validation_policy {
2405            mz_sql_server_util::config::CertificateValidationPolicy::TrustAll => {
2406                inner_config.trust_cert()
2407            }
2408            mz_sql_server_util::config::CertificateValidationPolicy::VerifyCA => {
2409                inner_config.trust_cert_ca_pem(
2410                    self.tls_root_cert
2411                        .as_ref()
2412                        .unwrap()
2413                        .get_string(in_task, secrets_reader)
2414                        .await
2415                        .context("ca certificate")?,
2416                );
2417            }
2418            mz_sql_server_util::config::CertificateValidationPolicy::VerifySystem => (), // no-op
2419        }
2420
2421        inner_config.application_name("materialize");
2422
2423        // Read our auth settings from
2424        let user = self
2425            .user
2426            .get_string(in_task, secrets_reader)
2427            .await
2428            .context("username")?;
2429        let password = secrets_reader
2430            .read_string_in_task_if(in_task, self.password)
2431            .await
2432            .context("password")?;
2433        // TODO(sql_server3): Support other methods of authentication besides
2434        // username and password.
2435        inner_config.authentication(tiberius::AuthMethod::sql_server(user, password));
2436
2437        // Prevent users from probing our internal network ports by trying to
2438        // connect to localhost, or another non-external IP.
2439        let enforce_external_addresses = ENFORCE_EXTERNAL_ADDRESSES.get(dyncfg);
2440
2441        let tunnel = match &self.tunnel {
2442            Tunnel::Direct => {
2443                let resolved_addresses: Vec<SocketAddr> =
2444                    resolve_address(&self.host, enforce_external_addresses)
2445                        .await?
2446                        .into_iter()
2447                        .map(|ip| SocketAddr::new(ip, self.port))
2448                        .collect();
2449                mz_sql_server_util::config::TunnelConfig::Direct {
2450                    resolved_addresses: resolved_addresses.into_boxed_slice(),
2451                }
2452            }
2453            Tunnel::Ssh(SshTunnel {
2454                connection_id,
2455                connection: ssh_connection,
2456            }) => {
2457                let secret = secrets_reader
2458                    .read_in_task_if(in_task, *connection_id)
2459                    .await
2460                    .context("ssh secret")?;
2461                let key_pair = SshKeyPair::from_bytes(&secret).context("ssh key pair")?;
2462                // Ensure any SSH-bastion host we connect to is resolved to an
2463                // external address.
2464                let addresses = resolve_address(&ssh_connection.host, enforce_external_addresses)
2465                    .await
2466                    .context("ssh tunnel")?;
2467
2468                let config = SshTunnelConfig {
2469                    host: addresses.into_iter().map(|a| a.to_string()).collect(),
2470                    port: ssh_connection.port,
2471                    user: ssh_connection.user.clone(),
2472                    key_pair,
2473                };
2474                mz_sql_server_util::config::TunnelConfig::Ssh {
2475                    config,
2476                    manager: storage_configuration
2477                        .connection_context
2478                        .ssh_tunnel_manager
2479                        .clone(),
2480                    timeout: storage_configuration.parameters.ssh_timeout_config.clone(),
2481                    host: self.host.clone(),
2482                    port: self.port,
2483                }
2484            }
2485            Tunnel::AwsPrivatelink(private_link_connection) => {
2486                assert_none!(private_link_connection.port);
2487                mz_sql_server_util::config::TunnelConfig::AwsPrivatelink {
2488                    connection_id: private_link_connection.connection_id,
2489                    port: self.port,
2490                }
2491            }
2492            Tunnel::AwsPrivatelinks(_) => {
2493                unreachable!("MATCHING broker rules are only available for Kafka connections.");
2494            }
2495        };
2496
2497        Ok(mz_sql_server_util::Config::new(
2498            inner_config,
2499            tunnel,
2500            in_task,
2501        ))
2502    }
2503}
2504
/// Errors detected while validating a SQL Server connection for use with
/// Materialize.
#[derive(Debug, Clone, thiserror::Error)]
pub enum SqlServerConnectionValidationError {
    /// One or more upstream system settings required for CDC-based replication
    /// were misconfigured; each entry is `(name, expected, actual)`.
    #[error("Invalid SQL Server system replication settings")]
    ReplicationSettingsError(Vec<(String, String, String)>),
}
2510
2511impl SqlServerConnectionValidationError {
2512    pub fn detail(&self) -> Option<String> {
2513        match self {
2514            Self::ReplicationSettingsError(settings) => Some(format!(
2515                "Invalid SQL Server system replication settings: {}",
2516                itertools::join(
2517                    settings.iter().map(|(setting, expected, actual)| format!(
2518                        "{}: expected {}, got {}",
2519                        setting, expected, actual
2520                    )),
2521                    "; "
2522                )
2523            )),
2524        }
2525    }
2526
2527    pub fn hint(&self) -> Option<String> {
2528        match self {
2529            _ => None,
2530        }
2531    }
2532}
2533
impl<R: ConnectionResolver> IntoInlineConnection<SqlServerConnectionDetails, R>
    for SqlServerConnectionDetails<ReferencedConnection>
{
    /// Inlines the referenced tunnel connection; every other field is carried
    /// over unchanged.
    fn into_inline_connection(self, r: R) -> SqlServerConnectionDetails {
        // Exhaustive destructure: adding a field to the struct becomes a
        // compile error here until this conversion is updated.
        let SqlServerConnectionDetails {
            host,
            port,
            database,
            user,
            password,
            tunnel,
            encryption,
            certificate_validation_policy,
            tls_root_cert,
        } = self;

        SqlServerConnectionDetails {
            host,
            port,
            database,
            user,
            password,
            tunnel: tunnel.into_inline_connection(&r),
            encryption,
            certificate_validation_policy,
            tls_root_cert,
        }
    }
}
2563
impl<C: ConnectionAccess> AlterCompatible for SqlServerConnectionDetails<C> {
    /// Checks whether `other` is a compatible replacement for `self` under
    /// `ALTER`; currently only the `tunnel` field is checked.
    fn alter_compatible(
        &self,
        id: mz_repr::GlobalId,
        other: &Self,
    ) -> Result<(), crate::controller::AlterError> {
        // Exhaustive destructure so that new fields must be explicitly
        // classified as checked or freely-changeable.
        let SqlServerConnectionDetails {
            tunnel,
            // TODO(sql_server2): Figure out how these variables are allowed to change.
            host: _,
            port: _,
            database: _,
            user: _,
            password: _,
            encryption: _,
            certificate_validation_policy: _,
            tls_root_cert: _,
        } = self;

        let compatibility_checks = [(tunnel.alter_compatible(id, &other.tunnel).is_ok(), "tunnel")];

        for (compatible, field) in compatibility_checks {
            if !compatible {
                tracing::warn!(
                    "SqlServerConnectionDetails incompatible at {field}:\nself:\n{:#?}\n\nother\n{:#?}",
                    self,
                    other
                );

                return Err(AlterError { id });
            }
        }
        Ok(())
    }
}
2599
/// A connection to an SSH tunnel.
#[derive(Clone, Debug, Eq, PartialEq, Hash, Serialize, Deserialize)]
pub struct SshConnection {
    /// Hostname of the SSH bastion server.
    pub host: String,
    /// Port of the SSH bastion server.
    pub port: u16,
    /// The user to connect as.
    pub user: String,
}
2607
2608use self::inline::{
2609    ConnectionAccess, ConnectionResolver, InlinedConnection, IntoInlineConnection,
2610    ReferencedConnection,
2611};
2612
impl AlterCompatible for SshConnection {
    fn alter_compatible(&self, _id: GlobalId, _other: &Self) -> Result<(), AlterError> {
        // Every element of the SSH connection is configurable, so any change
        // is considered compatible.
        Ok(())
    }
}
2619
/// Specifies an AWS PrivateLink service for a [`Tunnel`].
#[derive(Clone, Debug, Eq, PartialEq, Hash, Serialize, Deserialize)]
pub struct AwsPrivatelink {
    /// The ID of the connection to the AWS PrivateLink service.
    pub connection_id: CatalogItemId,
    /// The availability zone to use when connecting to the AWS PrivateLink service.
    pub availability_zone: Option<String>,
    /// The port to use when connecting to the AWS PrivateLink service, if
    /// different from the port in [`KafkaBroker::address`].
    pub port: Option<u16>,
}
2631
impl AlterCompatible for AwsPrivatelink {
    /// Checks whether `other` is a compatible replacement for `self` under
    /// `ALTER`; the PrivateLink `connection_id` must not change.
    fn alter_compatible(&self, id: GlobalId, other: &Self) -> Result<(), AlterError> {
        // Exhaustive destructure; `_` fields may change freely.
        let AwsPrivatelink {
            connection_id,
            availability_zone: _,
            port: _,
        } = self;

        let compatibility_checks = [(connection_id == &other.connection_id, "connection_id")];

        for (compatible, field) in compatibility_checks {
            if !compatible {
                tracing::warn!(
                    "AwsPrivatelink incompatible at {field}:\nself:\n{:#?}\n\nother\n{:#?}",
                    self,
                    other
                );

                return Err(AlterError { id });
            }
        }

        Ok(())
    }
}
2657
/// A set of AWS PrivateLink routing rules for a [`Tunnel`].
#[derive(Clone, Debug, Eq, PartialEq, Hash, Serialize, Deserialize)]
pub struct AwsPrivatelinks {
    /// Route to brokers through PrivateLink connections according to these rules.
    /// Exact-match rules (no wildcards) are used as bootstrap brokers.
    /// Wildcard rules are applied dynamically to discovered brokers.
    pub rules: Vec<AwsPrivatelinkRule>,
}
2665
/// A single broker-routing rule used by [`AwsPrivatelinks`].
#[derive(Clone, Debug, Eq, PartialEq, Hash, Serialize, Deserialize)]
pub struct AwsPrivatelinkRule {
    /// Given a broker's host:port, should we use this route?
    pub pattern: ConnectionRulePattern,
    /// Route to the broker through this PrivateLink connection.
    pub to: AwsPrivatelink,
}
2673
/// Specifies an SSH tunnel connection.
#[derive(Clone, Debug, Eq, PartialEq, Hash, Serialize, Deserialize)]
pub struct SshTunnel<C: ConnectionAccess = InlinedConnection> {
    /// ID of the SSH connection.
    pub connection_id: CatalogItemId,
    /// The SSH connection object.
    pub connection: C::Ssh,
}
2682
impl<R: ConnectionResolver> IntoInlineConnection<SshTunnel, R> for SshTunnel<ReferencedConnection> {
    /// Resolves the referenced SSH connection into its inlined form; the
    /// connection ID is carried over unchanged.
    fn into_inline_connection(self, r: R) -> SshTunnel {
        // Exhaustive destructure: adding a field becomes a compile error here.
        let SshTunnel {
            connection,
            connection_id,
        } = self;

        SshTunnel {
            connection: r.resolve_connection(connection).unwrap_ssh(),
            connection_id,
        }
    }
}
2696
impl SshTunnel<InlinedConnection> {
    /// Like [`SshTunnelConfig::connect`], but the SSH key is loaded from a
    /// secret.
    ///
    /// `in_task` controls whether reading the key-pair secret runs in an
    /// `mz_ore` task or directly in the returned future; it is also forwarded
    /// to the tunnel manager.
    async fn connect(
        &self,
        storage_configuration: &StorageConfiguration,
        remote_host: &str,
        remote_port: u16,
        in_task: InTask,
    ) -> Result<ManagedSshTunnelHandle, anyhow::Error> {
        // Ensure any ssh-bastion host we connect to is resolved to an external address.
        let resolved = resolve_address(
            &self.connection.host,
            ENFORCE_EXTERNAL_ADDRESSES.get(storage_configuration.config_set()),
        )
        .await?;
        storage_configuration
            .connection_context
            .ssh_tunnel_manager
            .connect(
                SshTunnelConfig {
                    host: resolved
                        .iter()
                        .map(|a| a.to_string())
                        .collect::<BTreeSet<_>>(),
                    port: self.connection.port,
                    user: self.connection.user.clone(),
                    // The SSH key pair is stored as a secret keyed by this
                    // tunnel's connection ID.
                    key_pair: SshKeyPair::from_bytes(
                        &storage_configuration
                            .connection_context
                            .secrets_reader
                            .read_in_task_if(in_task, self.connection_id)
                            .await?,
                    )?,
                },
                remote_host,
                remote_port,
                storage_configuration.parameters.ssh_timeout_config,
                in_task,
            )
            .await
    }
}
2740
impl<C: ConnectionAccess> AlterCompatible for SshTunnel<C> {
    /// Checks whether `other` is a compatible replacement for `self` under
    /// `ALTER`; both the connection ID and the underlying SSH connection are
    /// checked.
    fn alter_compatible(&self, id: GlobalId, other: &Self) -> Result<(), AlterError> {
        // Exhaustive destructure so new fields must be explicitly classified.
        let SshTunnel {
            connection_id,
            connection,
        } = self;

        let compatibility_checks = [
            (connection_id == &other.connection_id, "connection_id"),
            (
                connection.alter_compatible(id, &other.connection).is_ok(),
                "connection",
            ),
        ];

        for (compatible, field) in compatibility_checks {
            if !compatible {
                tracing::warn!(
                    "SshTunnel incompatible at {field}:\nself:\n{:#?}\n\nother\n{:#?}",
                    self,
                    other
                );

                return Err(AlterError { id });
            }
        }

        Ok(())
    }
}
2771
2772impl SshConnection {
2773    #[allow(clippy::unused_async)]
2774    async fn validate(
2775        &self,
2776        id: CatalogItemId,
2777        storage_configuration: &StorageConfiguration,
2778    ) -> Result<(), anyhow::Error> {
2779        let secret = storage_configuration
2780            .connection_context
2781            .secrets_reader
2782            .read_in_task_if(
2783                // We are in a normal tokio context during validation, already.
2784                InTask::No,
2785                id,
2786            )
2787            .await?;
2788        let key_pair = SshKeyPair::from_bytes(&secret)?;
2789
2790        // Ensure any ssh-bastion host we connect to is resolved to an external address.
2791        let resolved = resolve_address(
2792            &self.host,
2793            ENFORCE_EXTERNAL_ADDRESSES.get(storage_configuration.config_set()),
2794        )
2795        .await?;
2796
2797        let config = SshTunnelConfig {
2798            host: resolved
2799                .iter()
2800                .map(|a| a.to_string())
2801                .collect::<BTreeSet<_>>(),
2802            port: self.port,
2803            user: self.user.clone(),
2804            key_pair,
2805        };
2806        // Note that we do NOT use the `SshTunnelManager` here, as we want to validate that we
2807        // can actually create a new connection to the ssh bastion, without tunneling.
2808        config
2809            .validate(storage_configuration.parameters.ssh_timeout_config)
2810            .await
2811    }
2812
2813    fn validate_by_default(&self) -> bool {
2814        false
2815    }
2816}
2817
2818impl AwsPrivatelinkConnection {
2819    #[allow(clippy::unused_async)]
2820    async fn validate(
2821        &self,
2822        id: CatalogItemId,
2823        storage_configuration: &StorageConfiguration,
2824    ) -> Result<(), anyhow::Error> {
2825        let Some(ref cloud_resource_reader) = storage_configuration
2826            .connection_context
2827            .cloud_resource_reader
2828        else {
2829            return Err(anyhow!("AWS PrivateLink connections are unsupported"));
2830        };
2831
2832        // No need to optionally run this in a task, as we are just validating from envd.
2833        let status = cloud_resource_reader.read(id).await?;
2834
2835        let availability = status
2836            .conditions
2837            .as_ref()
2838            .and_then(|conditions| conditions.iter().find(|c| c.type_ == "Available"));
2839
2840        match availability {
2841            Some(condition) if condition.status == "True" => Ok(()),
2842            Some(condition) => Err(anyhow!("{}", condition.message)),
2843            None => Err(anyhow!("Endpoint availability is unknown")),
2844        }
2845    }
2846
2847    fn validate_by_default(&self) -> bool {
2848        false
2849    }
2850}