mz_storage_types/
connections.rs

1// Copyright Materialize, Inc. and contributors. All rights reserved.
2//
3// Use of this software is governed by the Business Source License
4// included in the LICENSE file.
5//
6// As of the Change Date specified in that file, in accordance with
7// the Business Source License, use of this software will be governed
8// by the Apache License, Version 2.0.
9
10//! Connection types.
11
12use std::borrow::Cow;
13use std::collections::{BTreeMap, BTreeSet};
14use std::net::SocketAddr;
15use std::sync::Arc;
16
17use anyhow::{Context, anyhow, bail};
18use itertools::Itertools;
19use mz_ccsr::tls::{Certificate, Identity};
20use mz_cloud_resources::{AwsExternalIdPrefix, CloudResourceReader, vpc_endpoint_host};
21use mz_dyncfg::ConfigSet;
22use mz_kafka_util::client::{
23    BrokerAddr, BrokerRewrite, MzClientContext, MzKafkaError, TunnelConfig, TunnelingClientContext,
24};
25use mz_ore::assert_none;
26use mz_ore::error::ErrorExt;
27use mz_ore::future::{InTask, OreFutureExt};
28use mz_ore::netio::resolve_address;
29use mz_ore::num::NonNeg;
30use mz_postgres_util::tunnel::PostgresFlavor;
31use mz_proto::tokio_postgres::any_ssl_mode;
32use mz_proto::{IntoRustIfSome, ProtoType, RustType, TryFromProtoError};
33use mz_repr::url::any_url;
34use mz_repr::{CatalogItemId, GlobalId};
35use mz_secrets::SecretsReader;
36use mz_ssh_util::keys::SshKeyPair;
37use mz_ssh_util::tunnel::SshTunnelConfig;
38use mz_ssh_util::tunnel_manager::{ManagedSshTunnelHandle, SshTunnelManager};
39use mz_tls_util::Pkcs12Archive;
40use mz_tracing::CloneableEnvFilter;
41use proptest::strategy::Strategy;
42use proptest_derive::Arbitrary;
43use rdkafka::ClientContext;
44use rdkafka::config::FromClientConfigAndContext;
45use rdkafka::consumer::{BaseConsumer, Consumer};
46use regex::Regex;
47use serde::{Deserialize, Deserializer, Serialize};
48use tokio::net;
49use tokio::runtime::Handle;
50use tokio_postgres::config::SslMode;
51use tracing::{debug, warn};
52use url::Url;
53
54use crate::AlterCompatible;
55use crate::configuration::StorageConfiguration;
56use crate::connections::aws::{
57    AwsConnection, AwsConnectionReference, AwsConnectionValidationError,
58};
59use crate::connections::string_or_secret::StringOrSecret;
60use crate::controller::AlterError;
61use crate::dyncfgs::{
62    ENFORCE_EXTERNAL_ADDRESSES, KAFKA_CLIENT_ID_ENRICHMENT_RULES,
63    KAFKA_DEFAULT_AWS_PRIVATELINK_ENDPOINT_IDENTIFICATION_ALGORITHM,
64};
65use crate::errors::{ContextCreationError, CsrConnectError};
66
67pub mod aws;
68pub mod inline;
69pub mod string_or_secret;
70
71include!(concat!(env!("OUT_DIR"), "/mz_storage_types.connections.rs"));
72
/// An extension trait for [`SecretsReader`].
///
/// Each method mirrors a `SecretsReader` method but takes an additional
/// [`InTask`] argument that controls whether the read is run in a task.
#[async_trait::async_trait]
trait SecretsReaderExt {
    /// `SecretsReader::read`, but optionally run in a task.
    ///
    /// `id` identifies the secret to read. `in_task` selects whether the read
    /// is run in a task.
    async fn read_in_task_if(
        &self,
        in_task: InTask,
        id: CatalogItemId,
    ) -> Result<Vec<u8>, anyhow::Error>;

    /// `SecretsReader::read_string`, but optionally run in a task.
    ///
    /// `id` identifies the secret to read. `in_task` selects whether the read
    /// is run in a task.
    async fn read_string_in_task_if(
        &self,
        in_task: InTask,
        id: CatalogItemId,
    ) -> Result<String, anyhow::Error>;
}
90
91#[async_trait::async_trait]
92impl SecretsReaderExt for Arc<dyn SecretsReader> {
93    async fn read_in_task_if(
94        &self,
95        in_task: InTask,
96        id: CatalogItemId,
97    ) -> Result<Vec<u8>, anyhow::Error> {
98        let sr = Arc::clone(self);
99        async move { sr.read(id).await }
100            .run_in_task_if(in_task, || "secrets_reader_read".to_string())
101            .await
102    }
103    async fn read_string_in_task_if(
104        &self,
105        in_task: InTask,
106        id: CatalogItemId,
107    ) -> Result<String, anyhow::Error> {
108        let sr = Arc::clone(self);
109        async move { sr.read_string(id).await }
110            .run_in_task_if(in_task, || "secrets_reader_read".to_string())
111            .await
112    }
113}
114
/// Extra context to pass through when instantiating a connection for a source
/// or sink.
///
/// Should be kept cheaply cloneable.
#[derive(Debug, Clone)]
pub struct ConnectionContext {
    /// An opaque identifier for the environment in which this process is
    /// running.
    ///
    /// The storage layer is intentionally unaware of the structure within this
    /// identifier. Higher layers of the stack can make use of that structure,
    /// but the storage layer should be oblivious to it.
    pub environment_id: String,
    /// The level for librdkafka's logs.
    ///
    /// This is passed along when a Kafka client configuration is created.
    pub librdkafka_log_level: tracing::Level,
    /// A prefix for an external ID to use for all AWS AssumeRole operations.
    pub aws_external_id_prefix: Option<AwsExternalIdPrefix>,
    /// The ARN for a Materialize-controlled role to assume before assuming
    /// a customer's requested role for an AWS connection.
    pub aws_connection_role_arn: Option<String>,
    /// A secrets reader, used to fetch the contents of secrets that
    /// connections refer to by ID.
    pub secrets_reader: Arc<dyn SecretsReader>,
    /// A cloud resource reader, if supported in this configuration.
    pub cloud_resource_reader: Option<Arc<dyn CloudResourceReader>>,
    /// A manager for SSH tunnels.
    pub ssh_tunnel_manager: SshTunnelManager,
}
142
143impl ConnectionContext {
144    /// Constructs a new connection context from command line arguments.
145    ///
146    /// **WARNING:** it is critical for security that the `aws_external_id` be
147    /// provided by the operator of the Materialize service (i.e., via a CLI
148    /// argument or environment variable) and not the end user of Materialize
149    /// (e.g., via a configuration option in a SQL statement). See
150    /// [`AwsExternalIdPrefix`] for details.
151    pub fn from_cli_args(
152        environment_id: String,
153        startup_log_level: &CloneableEnvFilter,
154        aws_external_id_prefix: Option<AwsExternalIdPrefix>,
155        aws_connection_role_arn: Option<String>,
156        secrets_reader: Arc<dyn SecretsReader>,
157        cloud_resource_reader: Option<Arc<dyn CloudResourceReader>>,
158    ) -> ConnectionContext {
159        ConnectionContext {
160            environment_id,
161            librdkafka_log_level: mz_ore::tracing::crate_level(
162                &startup_log_level.clone().into(),
163                "librdkafka",
164            ),
165            aws_external_id_prefix,
166            aws_connection_role_arn,
167            secrets_reader,
168            cloud_resource_reader,
169            ssh_tunnel_manager: SshTunnelManager::default(),
170        }
171    }
172
173    /// Constructs a new connection context for usage in tests.
174    pub fn for_tests(secrets_reader: Arc<dyn SecretsReader>) -> ConnectionContext {
175        ConnectionContext {
176            environment_id: "test-environment-id".into(),
177            librdkafka_log_level: tracing::Level::INFO,
178            aws_external_id_prefix: Some(
179                AwsExternalIdPrefix::new_from_cli_argument_or_environment_variable(
180                    "test-aws-external-id-prefix",
181                )
182                .expect("infallible"),
183            ),
184            aws_connection_role_arn: Some(
185                "arn:aws:iam::123456789000:role/MaterializeConnection".into(),
186            ),
187            secrets_reader,
188            cloud_resource_reader: None,
189            ssh_tunnel_manager: SshTunnelManager::default(),
190        }
191    }
192}
193
/// A connection to an external system.
///
/// `C` selects the [`ConnectionAccess`] implementation used by the variants
/// that are generic over it; it defaults to [`InlinedConnection`].
#[derive(Arbitrary, Clone, Debug, Eq, PartialEq, Hash, Serialize, Deserialize)]
pub enum Connection<C: ConnectionAccess = InlinedConnection> {
    /// A Kafka connection.
    Kafka(KafkaConnection<C>),
    /// A CSR connection.
    Csr(CsrConnection<C>),
    /// A PostgreSQL connection.
    Postgres(PostgresConnection<C>),
    /// An SSH connection.
    Ssh(SshConnection),
    /// An AWS connection.
    Aws(AwsConnection),
    /// An AWS PrivateLink connection.
    AwsPrivatelink(AwsPrivatelinkConnection),
    /// A MySQL connection.
    MySql(MySqlConnection<C>),
    /// A SQL Server connection.
    SqlServer(SqlServerConnectionDetails<C>),
}
205
206impl<R: ConnectionResolver> IntoInlineConnection<Connection, R>
207    for Connection<ReferencedConnection>
208{
209    fn into_inline_connection(self, r: R) -> Connection {
210        match self {
211            Connection::Kafka(kafka) => Connection::Kafka(kafka.into_inline_connection(r)),
212            Connection::Csr(csr) => Connection::Csr(csr.into_inline_connection(r)),
213            Connection::Postgres(pg) => Connection::Postgres(pg.into_inline_connection(r)),
214            Connection::Ssh(ssh) => Connection::Ssh(ssh),
215            Connection::Aws(aws) => Connection::Aws(aws),
216            Connection::AwsPrivatelink(awspl) => Connection::AwsPrivatelink(awspl),
217            Connection::MySql(mysql) => Connection::MySql(mysql.into_inline_connection(r)),
218            Connection::SqlServer(sql_server) => {
219                Connection::SqlServer(sql_server.into_inline_connection(r))
220            }
221        }
222    }
223}
224
225impl<C: ConnectionAccess> Connection<C> {
226    /// Whether this connection should be validated by default on creation.
227    pub fn validate_by_default(&self) -> bool {
228        match self {
229            Connection::Kafka(conn) => conn.validate_by_default(),
230            Connection::Csr(conn) => conn.validate_by_default(),
231            Connection::Postgres(conn) => conn.validate_by_default(),
232            Connection::Ssh(conn) => conn.validate_by_default(),
233            Connection::Aws(conn) => conn.validate_by_default(),
234            Connection::AwsPrivatelink(conn) => conn.validate_by_default(),
235            Connection::MySql(conn) => conn.validate_by_default(),
236            Connection::SqlServer(conn) => conn.validate_by_default(),
237        }
238    }
239}
240
241impl Connection<InlinedConnection> {
242    /// Validates this connection by attempting to connect to the upstream system.
243    pub async fn validate(
244        &self,
245        id: CatalogItemId,
246        storage_configuration: &StorageConfiguration,
247    ) -> Result<(), ConnectionValidationError> {
248        match self {
249            Connection::Kafka(conn) => conn.validate(id, storage_configuration).await?,
250            Connection::Csr(conn) => conn.validate(id, storage_configuration).await?,
251            Connection::Postgres(conn) => conn.validate(id, storage_configuration).await?,
252            Connection::Ssh(conn) => conn.validate(id, storage_configuration).await?,
253            Connection::Aws(conn) => conn.validate(id, storage_configuration).await?,
254            Connection::AwsPrivatelink(conn) => conn.validate(id, storage_configuration).await?,
255            Connection::MySql(conn) => conn.validate(id, storage_configuration).await?,
256            Connection::SqlServer(conn) => conn.validate(id, storage_configuration).await?,
257        }
258        Ok(())
259    }
260
261    pub fn unwrap_kafka(self) -> <InlinedConnection as ConnectionAccess>::Kafka {
262        match self {
263            Self::Kafka(conn) => conn,
264            o => unreachable!("{o:?} is not a Kafka connection"),
265        }
266    }
267
268    pub fn unwrap_pg(self) -> <InlinedConnection as ConnectionAccess>::Pg {
269        match self {
270            Self::Postgres(conn) => conn,
271            o => unreachable!("{o:?} is not a Postgres connection"),
272        }
273    }
274
275    pub fn unwrap_mysql(self) -> <InlinedConnection as ConnectionAccess>::MySql {
276        match self {
277            Self::MySql(conn) => conn,
278            o => unreachable!("{o:?} is not a MySQL connection"),
279        }
280    }
281
282    pub fn unwrap_sql_server(self) -> <InlinedConnection as ConnectionAccess>::SqlServer {
283        match self {
284            Self::SqlServer(conn) => conn,
285            o => unreachable!("{o:?} is not a SQL Server connection"),
286        }
287    }
288
289    pub fn unwrap_aws(self) -> <InlinedConnection as ConnectionAccess>::Aws {
290        match self {
291            Self::Aws(conn) => conn,
292            o => unreachable!("{o:?} is not an AWS connection"),
293        }
294    }
295
296    pub fn unwrap_ssh(self) -> <InlinedConnection as ConnectionAccess>::Ssh {
297        match self {
298            Self::Ssh(conn) => conn,
299            o => unreachable!("{o:?} is not an SSH connection"),
300        }
301    }
302
303    pub fn unwrap_csr(self) -> <InlinedConnection as ConnectionAccess>::Csr {
304        match self {
305            Self::Csr(conn) => conn,
306            o => unreachable!("{o:?} is not a Kafka connection"),
307        }
308    }
309}
310
/// An error returned by [`Connection::validate`].
#[derive(thiserror::Error, Debug)]
pub enum ConnectionValidationError {
    /// An AWS-specific validation error, displayed exactly as the wrapped
    /// error displays itself.
    #[error(transparent)]
    Aws(#[from] AwsConnectionValidationError),
    /// Any other validation error, displayed together with its chain of
    /// causes.
    #[error("{}", .0.display_with_causes())]
    Other(#[from] anyhow::Error),
}
319
320impl ConnectionValidationError {
321    /// Reports additional details about the error, if any are available.
322    pub fn detail(&self) -> Option<String> {
323        match self {
324            ConnectionValidationError::Aws(e) => e.detail(),
325            ConnectionValidationError::Other(_) => None,
326        }
327    }
328
329    /// Reports a hint for the user about how the error could be fixed.
330    pub fn hint(&self) -> Option<String> {
331        match self {
332            ConnectionValidationError::Aws(e) => e.hint(),
333            ConnectionValidationError::Other(_) => None,
334        }
335    }
336}
337
338impl<C: ConnectionAccess> AlterCompatible for Connection<C> {
339    fn alter_compatible(&self, id: GlobalId, other: &Self) -> Result<(), AlterError> {
340        match (self, other) {
341            (Self::Aws(s), Self::Aws(o)) => s.alter_compatible(id, o),
342            (Self::AwsPrivatelink(s), Self::AwsPrivatelink(o)) => s.alter_compatible(id, o),
343            (Self::Ssh(s), Self::Ssh(o)) => s.alter_compatible(id, o),
344            (Self::Csr(s), Self::Csr(o)) => s.alter_compatible(id, o),
345            (Self::Kafka(s), Self::Kafka(o)) => s.alter_compatible(id, o),
346            (Self::Postgres(s), Self::Postgres(o)) => s.alter_compatible(id, o),
347            (Self::MySql(s), Self::MySql(o)) => s.alter_compatible(id, o),
348            _ => {
349                tracing::warn!(
350                    "Connection incompatible:\nself:\n{:#?}\n\nother\n{:#?}",
351                    self,
352                    other
353                );
354                Err(AlterError { id })
355            }
356        }
357    }
358}
359
/// The description of an AWS PrivateLink connection.
#[derive(Arbitrary, Clone, Debug, Eq, PartialEq, Hash, Serialize, Deserialize)]
pub struct AwsPrivatelinkConnection {
    /// The service name.
    // NOTE(review): presumably the name of the VPC endpoint service to
    // connect to; confirm against the code that provisions the endpoint.
    pub service_name: String,
    /// The availability zones.
    // NOTE(review): presumably the zones in which the endpoint is made
    // available; confirm against the code that provisions the endpoint.
    pub availability_zones: Vec<String>,
}
365
impl AlterCompatible for AwsPrivatelinkConnection {
    /// Always succeeds: any `AwsPrivatelinkConnection` is considered
    /// compatible with any other, so neither argument is inspected.
    fn alter_compatible(&self, _id: GlobalId, _other: &Self) -> Result<(), AlterError> {
        // Every element of the AwsPrivatelinkConnection connection is configurable.
        Ok(())
    }
}
372
/// TLS settings for a [`KafkaConnection`].
#[derive(Arbitrary, Clone, Debug, Eq, PartialEq, Hash, Serialize, Deserialize)]
pub struct KafkaTlsConfig {
    /// The client identity. When present, its key and certificate are
    /// supplied to the Kafka client as `ssl.key.pem` and
    /// `ssl.certificate.pem`.
    pub identity: Option<TlsIdentity>,
    /// The root certificate. When present, it is supplied to the Kafka
    /// client as `ssl.ca.pem`.
    pub root_cert: Option<StringOrSecret>,
}
378
/// SASL settings for a [`KafkaConnection`].
#[derive(Clone, Debug, Eq, PartialEq, Hash, Serialize, Deserialize, Arbitrary)]
pub struct KafkaSaslConfig<C: ConnectionAccess = InlinedConnection> {
    /// The SASL mechanism, supplied to the Kafka client as `sasl.mechanisms`.
    pub mechanism: String,
    /// The username, supplied to the Kafka client as `sasl.username`.
    pub username: StringOrSecret,
    /// The ID of the secret holding the password. When present, the secret is
    /// supplied to the Kafka client as `sasl.password`.
    pub password: Option<CatalogItemId>,
    /// A reference to an AWS connection. When present, an AWS SDK
    /// configuration is loaded from it while the Kafka client is created.
    pub aws: Option<AwsConnectionReference<C>>,
}
386
387impl<R: ConnectionResolver> IntoInlineConnection<KafkaSaslConfig, R>
388    for KafkaSaslConfig<ReferencedConnection>
389{
390    fn into_inline_connection(self, r: R) -> KafkaSaslConfig {
391        KafkaSaslConfig {
392            mechanism: self.mechanism,
393            username: self.username,
394            password: self.password,
395            aws: self.aws.map(|aws| aws.into_inline_connection(&r)),
396        }
397    }
398}
399
/// Specifies a Kafka broker in a [`KafkaConnection`].
#[derive(Arbitrary, Clone, Debug, Eq, PartialEq, Hash, Serialize, Deserialize)]
pub struct KafkaBroker<C: ConnectionAccess = InlinedConnection> {
    /// The address of the Kafka broker, as `host` or `host:port`.
    ///
    /// When a Kafka client is created, a missing port is treated as 9092.
    pub address: String,
    /// The tunnel to use when connecting to the broker.
    ///
    /// With `Tunnel::Direct`, no broker-specific tunnel or address rewrite is
    /// set up when a Kafka client is created.
    pub tunnel: Tunnel<C>,
}
408
409impl<R: ConnectionResolver> IntoInlineConnection<KafkaBroker, R>
410    for KafkaBroker<ReferencedConnection>
411{
412    fn into_inline_connection(self, r: R) -> KafkaBroker {
413        let KafkaBroker { address, tunnel } = self;
414        KafkaBroker {
415            address,
416            tunnel: tunnel.into_inline_connection(r),
417        }
418    }
419}
420
421impl RustType<ProtoKafkaBroker> for KafkaBroker {
422    fn into_proto(&self) -> ProtoKafkaBroker {
423        ProtoKafkaBroker {
424            address: self.address.into_proto(),
425            tunnel: Some(self.tunnel.into_proto()),
426        }
427    }
428
429    fn from_proto(proto: ProtoKafkaBroker) -> Result<Self, TryFromProtoError> {
430        Ok(KafkaBroker {
431            address: proto.address.into_rust()?,
432            tunnel: proto
433                .tunnel
434                .into_rust_if_some("ProtoKafkaConnection::tunnel")?,
435        })
436    }
437}
438
/// Options for a Kafka topic.
///
/// A [`KafkaConnection`] carries one of these for its progress topic. The
/// `Default` value leaves both counts unset and the configuration empty.
#[derive(Arbitrary, Clone, Debug, Eq, PartialEq, Hash, Serialize, Deserialize, Default)]
pub struct KafkaTopicOptions {
    /// The replication factor for the topic.
    /// If `None`, the broker default will be used.
    pub replication_factor: Option<NonNeg<i32>>,
    /// The number of partitions to create.
    /// If `None`, the broker default will be used.
    pub partition_count: Option<NonNeg<i32>>,
    /// The initial configuration parameters for the topic.
    pub topic_config: BTreeMap<String, String>,
}
450
451impl RustType<ProtoKafkaTopicOptions> for KafkaTopicOptions {
452    fn into_proto(&self) -> ProtoKafkaTopicOptions {
453        ProtoKafkaTopicOptions {
454            replication_factor: self.replication_factor.map(|f| *f),
455            partition_count: self.partition_count.map(|f| *f),
456            topic_config: self.topic_config.clone(),
457        }
458    }
459
460    fn from_proto(proto: ProtoKafkaTopicOptions) -> Result<Self, TryFromProtoError> {
461        Ok(KafkaTopicOptions {
462            replication_factor: proto.replication_factor.map(NonNeg::try_from).transpose()?,
463            partition_count: proto.partition_count.map(NonNeg::try_from).transpose()?,
464            topic_config: proto.topic_config,
465        })
466    }
467}
468
/// The description of a connection to a Kafka cluster.
#[derive(Arbitrary, Clone, Debug, Eq, PartialEq, Hash, Serialize, Deserialize)]
pub struct KafkaConnection<C: ConnectionAccess = InlinedConnection> {
    /// The brokers to connect to.
    ///
    /// Must be empty when `default_tunnel` is an AWS PrivateLink tunnel;
    /// creating a Kafka client asserts this.
    pub brokers: Vec<KafkaBroker<C>>,
    /// A tunnel through which to route traffic,
    /// that can be overridden for individual brokers
    /// in `brokers`.
    pub default_tunnel: Tunnel<C>,
    /// An explicit name for the progress topic. If `None`, a name is
    /// generated from the environment ID and the connection ID.
    pub progress_topic: Option<String>,
    /// Options for the progress topic.
    pub progress_topic_options: KafkaTopicOptions,
    /// Additional client configuration options, keyed by option name. Values
    /// may be literal strings or references to secrets.
    pub options: BTreeMap<String, StringOrSecret>,
    /// TLS settings, if TLS is in use.
    pub tls: Option<KafkaTlsConfig>,
    /// SASL settings, if SASL is in use.
    pub sasl: Option<KafkaSaslConfig<C>>,
}
482
483impl<R: ConnectionResolver> IntoInlineConnection<KafkaConnection, R>
484    for KafkaConnection<ReferencedConnection>
485{
486    fn into_inline_connection(self, r: R) -> KafkaConnection {
487        let KafkaConnection {
488            brokers,
489            progress_topic,
490            progress_topic_options,
491            default_tunnel,
492            options,
493            tls,
494            sasl,
495        } = self;
496
497        let brokers = brokers
498            .into_iter()
499            .map(|broker| broker.into_inline_connection(&r))
500            .collect();
501
502        KafkaConnection {
503            brokers,
504            progress_topic,
505            progress_topic_options,
506            default_tunnel: default_tunnel.into_inline_connection(&r),
507            options,
508            tls,
509            sasl: sasl.map(|sasl| sasl.into_inline_connection(&r)),
510        }
511    }
512}
513
514impl<C: ConnectionAccess> KafkaConnection<C> {
515    /// Returns the name of the progress topic to use for the connection.
516    ///
517    /// The caller is responsible for providing the connection ID as it is not
518    /// known to `KafkaConnection`.
519    pub fn progress_topic(
520        &self,
521        connection_context: &ConnectionContext,
522        connection_id: CatalogItemId,
523    ) -> Cow<str> {
524        if let Some(progress_topic) = &self.progress_topic {
525            Cow::Borrowed(progress_topic)
526        } else {
527            Cow::Owned(format!(
528                "_materialize-progress-{}-{}",
529                connection_context.environment_id, connection_id,
530            ))
531        }
532    }
533
534    fn validate_by_default(&self) -> bool {
535        true
536    }
537}
538
539impl KafkaConnection {
540    /// Generates a string that can be used as the base for a configuration ID
541    /// (e.g., `client.id`, `group.id`, `transactional.id`) for a Kafka source
542    /// or sink.
543    pub fn id_base(
544        connection_context: &ConnectionContext,
545        connection_id: CatalogItemId,
546        object_id: GlobalId,
547    ) -> String {
548        format!(
549            "materialize-{}-{}-{}",
550            connection_context.environment_id, connection_id, object_id,
551        )
552    }
553
554    /// Enriches the provided `client_id` according to any enrichment rules in
555    /// the `kafka_client_id_enrichment_rules` configuration parameter.
556    pub fn enrich_client_id(&self, configs: &ConfigSet, client_id: &mut String) {
557        #[derive(Debug, Deserialize)]
558        struct EnrichmentRule {
559            #[serde(deserialize_with = "deserialize_regex")]
560            pattern: Regex,
561            payload: String,
562        }
563
564        fn deserialize_regex<'de, D>(deserializer: D) -> Result<Regex, D::Error>
565        where
566            D: Deserializer<'de>,
567        {
568            let buf = String::deserialize(deserializer)?;
569            Regex::new(&buf).map_err(serde::de::Error::custom)
570        }
571
572        let rules = KAFKA_CLIENT_ID_ENRICHMENT_RULES.get(configs);
573        let rules = match serde_json::from_value::<Vec<EnrichmentRule>>(rules) {
574            Ok(rules) => rules,
575            Err(e) => {
576                warn!(%e, "failed to decode kafka_client_id_enrichment_rules");
577                return;
578            }
579        };
580
581        // Check every rule against every broker. Rules are matched in the order
582        // that they are specified. It is usually a configuration error if
583        // multiple rules match the same list of Kafka brokers, but we
584        // nonetheless want to provide well defined semantics.
585        debug!(?self.brokers, "evaluating client ID enrichment rules");
586        for rule in rules {
587            let is_match = self
588                .brokers
589                .iter()
590                .any(|b| rule.pattern.is_match(&b.address));
591            debug!(?rule, is_match, "evaluated client ID enrichment rule");
592            if is_match {
593                client_id.push('-');
594                client_id.push_str(&rule.payload);
595            }
596        }
597    }
598
    /// Creates a Kafka client for the connection.
    ///
    /// The client configuration starts from the connection's own `options`.
    /// The following keys are then set by this method, replacing any value
    /// already present under the same key:
    ///
    /// * `allow.auto.create.topics` is set to `false`;
    /// * `bootstrap.servers` is set to the broker list, or to the VPC
    ///   endpoint host and port when the default tunnel is AWS PrivateLink;
    /// * `security.protocol` is derived from whether TLS and/or SASL are
    ///   configured;
    /// * the `ssl.*` and `sasl.*` keys are set from the TLS and SASL
    ///   settings, where present.
    ///
    /// Option values that refer to secrets are read through the secrets
    /// reader of `storage_configuration`. `extra_options` are applied to the
    /// configuration after all of the above.
    ///
    /// `context` is wrapped in a [`TunnelingClientContext`], which is set up
    /// with the default tunnel and with any per-broker tunnels or address
    /// rewrites before the client is created.
    ///
    /// # Panics
    ///
    /// Panics if the default tunnel is an AWS PrivateLink tunnel and
    /// `brokers` is not empty.
    ///
    /// # Errors
    ///
    /// Returns a [`ContextCreationError`] if a secret cannot be read, an SSH
    /// key pair cannot be parsed, an SSH bastion address cannot be resolved,
    /// a broker port cannot be parsed, the AWS SDK configuration cannot be
    /// loaded, an SSH tunnel cannot be set up, or the client itself cannot be
    /// created.
    pub async fn create_with_context<C, T>(
        &self,
        storage_configuration: &StorageConfiguration,
        context: C,
        extra_options: &BTreeMap<&str, String>,
        in_task: InTask,
    ) -> Result<T, ContextCreationError>
    where
        C: ClientContext,
        T: FromClientConfigAndContext<TunnelingClientContext<C>>,
    {
        let mut options = self.options.clone();

        // Ensure that Kafka topics are *not* automatically created when
        // consuming, producing, or fetching metadata for a topic. This ensures
        // that we don't accidentally create topics with the wrong number of
        // partitions.
        options.insert("allow.auto.create.topics".into(), "false".into());

        // Determine the value of `bootstrap.servers`.
        let brokers = match &self.default_tunnel {
            Tunnel::AwsPrivatelink(t) => {
                assert!(&self.brokers.is_empty());

                let algo = KAFKA_DEFAULT_AWS_PRIVATELINK_ENDPOINT_IDENTIFICATION_ALGORITHM
                    .get(storage_configuration.config_set());
                options.insert("ssl.endpoint.identification.algorithm".into(), algo.into());

                // When using a default PrivateLink tunnel, brokers cannot be
                // specified. Instead, the tunnel's connection ID and port are
                // used for the initial connection.
                format!(
                    "{}:{}",
                    vpc_endpoint_host(
                        t.connection_id,
                        None, // Default tunnel does not support availability zones.
                    ),
                    t.port.unwrap_or(9092)
                )
            }
            _ => self.brokers.iter().map(|b| &b.address).join(","),
        };
        options.insert("bootstrap.servers".into(), brokers.into());
        // The security protocol follows from which of TLS and SASL are
        // configured.
        let security_protocol = match (self.tls.is_some(), self.sasl.is_some()) {
            (false, false) => "PLAINTEXT",
            (true, false) => "SSL",
            (false, true) => "SASL_PLAINTEXT",
            (true, true) => "SASL_SSL",
        };
        options.insert("security.protocol".into(), security_protocol.into());
        if let Some(tls) = &self.tls {
            if let Some(root_cert) = &tls.root_cert {
                options.insert("ssl.ca.pem".into(), root_cert.clone());
            }
            if let Some(identity) = &tls.identity {
                options.insert("ssl.key.pem".into(), StringOrSecret::Secret(identity.key));
                options.insert("ssl.certificate.pem".into(), identity.cert.clone());
            }
        }
        if let Some(sasl) = &self.sasl {
            options.insert("sasl.mechanisms".into(), (&sasl.mechanism).into());
            options.insert("sasl.username".into(), sasl.username.clone());
            if let Some(password) = sasl.password {
                options.insert("sasl.password".into(), StringOrSecret::Secret(password));
            }
        }

        let mut config = mz_kafka_util::client::create_new_client_config(
            storage_configuration
                .connection_context
                .librdkafka_log_level,
            storage_configuration.parameters.kafka_timeout_config,
        );
        // Resolve every option to a plain string, reading secrets where
        // necessary, and apply it to the client configuration.
        for (k, v) in options {
            config.set(
                k,
                v.get_string(
                    in_task,
                    &storage_configuration.connection_context.secrets_reader,
                )
                .await
                .context("reading kafka secret")?,
            );
        }
        // Caller-supplied options are applied after the connection's own.
        for (k, v) in extra_options {
            config.set(*k, v);
        }

        // Load an AWS SDK configuration only if the SASL settings refer to an
        // AWS connection.
        let aws_config = match self.sasl.as_ref().and_then(|sasl| sasl.aws.as_ref()) {
            None => None,
            Some(aws) => Some(
                aws.connection
                    .load_sdk_config(
                        &storage_configuration.connection_context,
                        aws.connection_id,
                        in_task,
                    )
                    .await?,
            ),
        };

        // TODO(roshan): Implement enforcement of external address validation once
        // rdkafka client has been updated to support providing multiple resolved
        // addresses for brokers
        let mut context = TunnelingClientContext::new(
            context,
            Handle::current(),
            storage_configuration
                .connection_context
                .ssh_tunnel_manager
                .clone(),
            storage_configuration.parameters.ssh_timeout_config,
            aws_config,
            in_task,
        );

        // Configure the tunnel that applies to brokers without a tunnel of
        // their own.
        match &self.default_tunnel {
            Tunnel::Direct => {
                // By default, don't offer a default override for broker address lookup.
            }
            Tunnel::AwsPrivatelink(pl) => {
                context.set_default_tunnel(TunnelConfig::StaticHost(vpc_endpoint_host(
                    pl.connection_id,
                    None, // Default tunnel does not support availability zones.
                )));
            }
            Tunnel::Ssh(ssh_tunnel) => {
                // The SSH key pair is stored as a secret under the ID of the
                // SSH connection.
                let secret = storage_configuration
                    .connection_context
                    .secrets_reader
                    .read_in_task_if(in_task, ssh_tunnel.connection_id)
                    .await?;
                let key_pair = SshKeyPair::from_bytes(&secret)?;

                // Ensure any SSH bastion address we connect to is resolved to an external address.
                let resolved = resolve_address(
                    &ssh_tunnel.connection.host,
                    ENFORCE_EXTERNAL_ADDRESSES.get(storage_configuration.config_set()),
                )
                .await?;
                context.set_default_tunnel(TunnelConfig::Ssh(SshTunnelConfig {
                    host: resolved
                        .iter()
                        .map(|a| a.to_string())
                        .collect::<BTreeSet<_>>(),
                    port: ssh_tunnel.connection.port,
                    user: ssh_tunnel.connection.user.clone(),
                    key_pair,
                }));
            }
        }

        // Configure per-broker tunnels and address rewrites.
        for broker in &self.brokers {
            // Split `host[:port]` at the first colon; the port defaults to
            // 9092 when it is absent.
            let mut addr_parts = broker.address.splitn(2, ':');
            let addr = BrokerAddr {
                host: addr_parts
                    .next()
                    .context("BROKER is not address:port")?
                    .into(),
                port: addr_parts
                    .next()
                    .unwrap_or("9092")
                    .parse()
                    .context("parsing BROKER port")?,
            };
            match &broker.tunnel {
                Tunnel::Direct => {
                    // By default, don't override broker address lookup.
                    //
                    // N.B.
                    //
                    // We _could_ pre-setup the default SSH tunnel for all known brokers here, but
                    // we avoid doing so because:
                    // - It's not necessary.
                    // - Not doing so makes it easier to test the `FailedDefaultSshTunnel` path
                    // in the `TunnelingClientContext`.
                }
                Tunnel::AwsPrivatelink(aws_privatelink) => {
                    // Rewrite the broker's address to the VPC endpoint host,
                    // which may be specific to an availability zone.
                    let host = mz_cloud_resources::vpc_endpoint_host(
                        aws_privatelink.connection_id,
                        aws_privatelink.availability_zone.as_deref(),
                    );
                    let port = aws_privatelink.port;
                    context.add_broker_rewrite(
                        addr,
                        BrokerRewrite {
                            host: host.clone(),
                            port,
                        },
                    );
                }
                Tunnel::Ssh(ssh_tunnel) => {
                    // Ensure any SSH bastion address we connect to is resolved to an external address.
                    let ssh_host_resolved = resolve_address(
                        &ssh_tunnel.connection.host,
                        ENFORCE_EXTERNAL_ADDRESSES.get(storage_configuration.config_set()),
                    )
                    .await?;
                    context
                        .add_ssh_tunnel(
                            addr,
                            SshTunnelConfig {
                                host: ssh_host_resolved
                                    .iter()
                                    .map(|a| a.to_string())
                                    .collect::<BTreeSet<_>>(),
                                port: ssh_tunnel.connection.port,
                                user: ssh_tunnel.connection.user.clone(),
                                key_pair: SshKeyPair::from_bytes(
                                    &storage_configuration
                                        .connection_context
                                        .secrets_reader
                                        .read_in_task_if(in_task, ssh_tunnel.connection_id)
                                        .await?,
                                )?,
                            },
                        )
                        .await
                        .map_err(ContextCreationError::Ssh)?;
                }
            }
        }

        Ok(config.create_with_context(context)?)
    }
823
824    async fn validate(
825        &self,
826        _id: CatalogItemId,
827        storage_configuration: &StorageConfiguration,
828    ) -> Result<(), anyhow::Error> {
829        let (context, error_rx) = MzClientContext::with_errors();
830        let consumer: BaseConsumer<_> = self
831            .create_with_context(
832                storage_configuration,
833                context,
834                &BTreeMap::new(),
835                // We are in a normal tokio context during validation, already.
836                InTask::No,
837            )
838            .await?;
839        let consumer = Arc::new(consumer);
840
841        let timeout = storage_configuration
842            .parameters
843            .kafka_timeout_config
844            .fetch_metadata_timeout;
845
846        // librdkafka doesn't expose an API for determining whether a connection to
847        // the Kafka cluster has been successfully established. So we make a
848        // metadata request, though we don't care about the results, so that we can
849        // report any errors making that request. If the request succeeds, we know
850        // we were able to contact at least one broker, and that's a good proxy for
851        // being able to contact all the brokers in the cluster.
852        //
853        // The downside of this approach is it produces a generic error message like
854        // "metadata fetch error" with no additional details. The real networking
855        // error is buried in the librdkafka logs, which are not visible to users.
856        let result = mz_ore::task::spawn_blocking(|| "kafka_get_metadata", {
857            let consumer = Arc::clone(&consumer);
858            move || consumer.fetch_metadata(None, timeout)
859        })
860        .await?;
861        match result {
862            Ok(_) => Ok(()),
863            // The error returned by `fetch_metadata` does not provide any details which makes for
864            // a crappy user facing error message. For this reason we attempt to grab a better
865            // error message from the client context, which should contain any error logs emitted
866            // by librdkafka, and fallback to the generic error if there is nothing there.
867            Err(err) => {
868                // Multiple errors might have been logged during this validation but some are more
869                // relevant than others. Specifically, we prefer non-internal errors over internal
870                // errors since those give much more useful information to the users.
871                let main_err = error_rx.try_iter().reduce(|cur, new| match cur {
872                    MzKafkaError::Internal(_) => new,
873                    _ => cur,
874                });
875
876                // Don't drop the consumer until after we've drained the errors
877                // channel. Dropping the consumer can introduce spurious errors.
878                // See database-issues#7432.
879                drop(consumer);
880
881                match main_err {
882                    Some(err) => Err(err.into()),
883                    None => Err(err.into()),
884                }
885            }
886        }
887    }
888}
889
890impl<C: ConnectionAccess> AlterCompatible for KafkaConnection<C> {
891    fn alter_compatible(&self, id: GlobalId, other: &Self) -> Result<(), AlterError> {
892        let KafkaConnection {
893            brokers: _,
894            default_tunnel: _,
895            progress_topic,
896            progress_topic_options,
897            options: _,
898            tls: _,
899            sasl: _,
900        } = self;
901
902        let compatibility_checks = [
903            (progress_topic == &other.progress_topic, "progress_topic"),
904            (
905                progress_topic_options == &other.progress_topic_options,
906                "progress_topic_options",
907            ),
908        ];
909
910        for (compatible, field) in compatibility_checks {
911            if !compatible {
912                tracing::warn!(
913                    "KafkaConnection incompatible at {field}:\nself:\n{:#?}\n\nother\n{:#?}",
914                    self,
915                    other
916                );
917
918                return Err(AlterError { id });
919            }
920        }
921
922        Ok(())
923    }
924}
925
926impl RustType<ProtoKafkaConnectionTlsConfig> for KafkaTlsConfig {
927    fn into_proto(&self) -> ProtoKafkaConnectionTlsConfig {
928        ProtoKafkaConnectionTlsConfig {
929            identity: self.identity.into_proto(),
930            root_cert: self.root_cert.into_proto(),
931        }
932    }
933
934    fn from_proto(proto: ProtoKafkaConnectionTlsConfig) -> Result<Self, TryFromProtoError> {
935        Ok(KafkaTlsConfig {
936            root_cert: proto.root_cert.into_rust()?,
937            identity: proto.identity.into_rust()?,
938        })
939    }
940}
941
942impl RustType<ProtoKafkaConnectionSaslConfig> for KafkaSaslConfig {
943    fn into_proto(&self) -> ProtoKafkaConnectionSaslConfig {
944        ProtoKafkaConnectionSaslConfig {
945            mechanism: self.mechanism.into_proto(),
946            username: Some(self.username.into_proto()),
947            password: self.password.into_proto(),
948            aws: self.aws.into_proto(),
949        }
950    }
951
952    fn from_proto(proto: ProtoKafkaConnectionSaslConfig) -> Result<Self, TryFromProtoError> {
953        Ok(KafkaSaslConfig {
954            mechanism: proto.mechanism,
955            username: proto
956                .username
957                .into_rust_if_some("ProtoKafkaConnectionSaslConfig::username")?,
958            password: proto.password.into_rust()?,
959            aws: proto.aws.into_rust()?,
960        })
961    }
962}
963
964impl RustType<ProtoKafkaConnection> for KafkaConnection {
965    fn into_proto(&self) -> ProtoKafkaConnection {
966        ProtoKafkaConnection {
967            brokers: self.brokers.into_proto(),
968            default_tunnel: Some(self.default_tunnel.into_proto()),
969            progress_topic: self.progress_topic.into_proto(),
970            progress_topic_options: Some(self.progress_topic_options.into_proto()),
971            options: self
972                .options
973                .iter()
974                .map(|(k, v)| (k.clone(), v.into_proto()))
975                .collect(),
976            tls: self.tls.into_proto(),
977            sasl: self.sasl.into_proto(),
978        }
979    }
980
981    fn from_proto(proto: ProtoKafkaConnection) -> Result<Self, TryFromProtoError> {
982        Ok(KafkaConnection {
983            brokers: proto.brokers.into_rust()?,
984            default_tunnel: proto
985                .default_tunnel
986                .into_rust_if_some("ProtoKafkaConnection::default_tunnel")?,
987            progress_topic: proto.progress_topic,
988            progress_topic_options: match proto.progress_topic_options {
989                Some(progress_topic_options) => progress_topic_options.into_rust()?,
990                None => Default::default(),
991            },
992            options: proto
993                .options
994                .into_iter()
995                .map(|(k, v)| StringOrSecret::from_proto(v).map(|v| (k, v)))
996                .collect::<Result<_, _>>()?,
997            tls: proto.tls.into_rust()?,
998            sasl: proto.sasl.into_rust()?,
999        })
1000    }
1001}
1002
1003/// A connection to a Confluent Schema Registry.
1004#[derive(Clone, Debug, Eq, PartialEq, Hash, Serialize, Deserialize, Arbitrary)]
1005pub struct CsrConnection<C: ConnectionAccess = InlinedConnection> {
1006    /// The URL of the schema registry.
1007    #[proptest(strategy = "any_url()")]
1008    pub url: Url,
1009    /// Trusted root TLS certificate in PEM format.
1010    pub tls_root_cert: Option<StringOrSecret>,
1011    /// An optional TLS client certificate for authentication with the schema
1012    /// registry.
1013    pub tls_identity: Option<TlsIdentity>,
1014    /// Optional HTTP authentication credentials for the schema registry.
1015    pub http_auth: Option<CsrConnectionHttpAuth>,
1016    /// A tunnel through which to route traffic.
1017    pub tunnel: Tunnel<C>,
1018}
1019
1020impl<R: ConnectionResolver> IntoInlineConnection<CsrConnection, R>
1021    for CsrConnection<ReferencedConnection>
1022{
1023    fn into_inline_connection(self, r: R) -> CsrConnection {
1024        let CsrConnection {
1025            url,
1026            tls_root_cert,
1027            tls_identity,
1028            http_auth,
1029            tunnel,
1030        } = self;
1031        CsrConnection {
1032            url,
1033            tls_root_cert,
1034            tls_identity,
1035            http_auth,
1036            tunnel: tunnel.into_inline_connection(r),
1037        }
1038    }
1039}
1040
1041impl<C: ConnectionAccess> CsrConnection<C> {
1042    fn validate_by_default(&self) -> bool {
1043        true
1044    }
1045}
1046
1047impl CsrConnection {
1048    /// Constructs a schema registry client from the connection.
1049    pub async fn connect(
1050        &self,
1051        storage_configuration: &StorageConfiguration,
1052        in_task: InTask,
1053    ) -> Result<mz_ccsr::Client, CsrConnectError> {
1054        let mut client_config = mz_ccsr::ClientConfig::new(self.url.clone());
1055        if let Some(root_cert) = &self.tls_root_cert {
1056            let root_cert = root_cert
1057                .get_string(
1058                    in_task,
1059                    &storage_configuration.connection_context.secrets_reader,
1060                )
1061                .await?;
1062            let root_cert = Certificate::from_pem(root_cert.as_bytes())?;
1063            client_config = client_config.add_root_certificate(root_cert);
1064        }
1065
1066        if let Some(tls_identity) = &self.tls_identity {
1067            let key = &storage_configuration
1068                .connection_context
1069                .secrets_reader
1070                .read_string_in_task_if(in_task, tls_identity.key)
1071                .await?;
1072            let cert = tls_identity
1073                .cert
1074                .get_string(
1075                    in_task,
1076                    &storage_configuration.connection_context.secrets_reader,
1077                )
1078                .await?;
1079            let ident = Identity::from_pem(key.as_bytes(), cert.as_bytes())?;
1080            client_config = client_config.identity(ident);
1081        }
1082
1083        if let Some(http_auth) = &self.http_auth {
1084            let username = http_auth
1085                .username
1086                .get_string(
1087                    in_task,
1088                    &storage_configuration.connection_context.secrets_reader,
1089                )
1090                .await?;
1091            let password = match http_auth.password {
1092                None => None,
1093                Some(password) => Some(
1094                    storage_configuration
1095                        .connection_context
1096                        .secrets_reader
1097                        .read_string_in_task_if(in_task, password)
1098                        .await?,
1099                ),
1100            };
1101            client_config = client_config.auth(username, password);
1102        }
1103
1104        // `net::lookup_host` requires a port but the port will be ignored when
1105        // passed to `resolve_to_addrs`. We use a dummy port that will be easy
1106        // to spot in the logs to make it obvious if some component downstream
1107        // incorrectly starts using this port.
1108        const DUMMY_PORT: u16 = 11111;
1109
1110        // TODO: use types to enforce that the URL has a string hostname.
1111        let host = self
1112            .url
1113            .host_str()
1114            .ok_or_else(|| anyhow!("url missing host"))?;
1115        match &self.tunnel {
1116            Tunnel::Direct => {
1117                // Ensure any host we connect to is resolved to an external address.
1118                let resolved = resolve_address(
1119                    host,
1120                    ENFORCE_EXTERNAL_ADDRESSES.get(storage_configuration.config_set()),
1121                )
1122                .await?;
1123                client_config = client_config.resolve_to_addrs(
1124                    host,
1125                    &resolved
1126                        .iter()
1127                        .map(|addr| SocketAddr::new(*addr, DUMMY_PORT))
1128                        .collect::<Vec<_>>(),
1129                )
1130            }
1131            Tunnel::Ssh(ssh_tunnel) => {
1132                let ssh_tunnel = ssh_tunnel
1133                    .connect(
1134                        storage_configuration,
1135                        host,
1136                        // Default to the default http port, but this
1137                        // could default to 8081...
1138                        self.url.port().unwrap_or(80),
1139                        in_task,
1140                    )
1141                    .await
1142                    .map_err(CsrConnectError::Ssh)?;
1143
1144                // Carefully inject the SSH tunnel into the client
1145                // configuration. This is delicate because we need TLS
1146                // verification to continue to use the remote hostname rather
1147                // than the tunnel hostname.
1148
1149                client_config = client_config
1150                    // `resolve_to_addrs` allows us to rewrite the hostname
1151                    // at the DNS level, which means the TCP connection is
1152                    // correctly routed through the tunnel, but TLS verification
1153                    // is still performed against the remote hostname.
1154                    // Unfortunately the port here is ignored...
1155                    .resolve_to_addrs(
1156                        host,
1157                        &[SocketAddr::new(ssh_tunnel.local_addr().ip(), DUMMY_PORT)],
1158                    )
1159                    // ...so we also dynamically rewrite the URL to use the
1160                    // current port for the SSH tunnel.
1161                    //
1162                    // WARNING: this is brittle, because we only dynamically
1163                    // update the client configuration with the tunnel *port*,
1164                    // and not the hostname This works fine in practice, because
1165                    // only the SSH tunnel port will change if the tunnel fails
1166                    // and has to be restarted (the hostname is always
1167                    // 127.0.0.1)--but this is an an implementation detail of
1168                    // the SSH tunnel code that we're relying on.
1169                    .dynamic_url({
1170                        let remote_url = self.url.clone();
1171                        move || {
1172                            let mut url = remote_url.clone();
1173                            url.set_port(Some(ssh_tunnel.local_addr().port()))
1174                                .expect("cannot fail");
1175                            url
1176                        }
1177                    });
1178            }
1179            Tunnel::AwsPrivatelink(connection) => {
1180                assert_none!(connection.port);
1181
1182                let privatelink_host = mz_cloud_resources::vpc_endpoint_host(
1183                    connection.connection_id,
1184                    connection.availability_zone.as_deref(),
1185                );
1186                let addrs: Vec<_> = net::lookup_host((privatelink_host, DUMMY_PORT))
1187                    .await
1188                    .context("resolving PrivateLink host")?
1189                    .collect();
1190                client_config = client_config.resolve_to_addrs(host, &addrs)
1191            }
1192        }
1193
1194        Ok(client_config.build()?)
1195    }
1196
1197    async fn validate(
1198        &self,
1199        _id: CatalogItemId,
1200        storage_configuration: &StorageConfiguration,
1201    ) -> Result<(), anyhow::Error> {
1202        let client = self
1203            .connect(
1204                storage_configuration,
1205                // We are in a normal tokio context during validation, already.
1206                InTask::No,
1207            )
1208            .await?;
1209        client.list_subjects().await?;
1210        Ok(())
1211    }
1212}
1213
1214impl RustType<ProtoCsrConnection> for CsrConnection {
1215    fn into_proto(&self) -> ProtoCsrConnection {
1216        ProtoCsrConnection {
1217            url: Some(self.url.into_proto()),
1218            tls_root_cert: self.tls_root_cert.into_proto(),
1219            tls_identity: self.tls_identity.into_proto(),
1220            http_auth: self.http_auth.into_proto(),
1221            tunnel: Some(self.tunnel.into_proto()),
1222        }
1223    }
1224
1225    fn from_proto(proto: ProtoCsrConnection) -> Result<Self, TryFromProtoError> {
1226        Ok(CsrConnection {
1227            url: proto.url.into_rust_if_some("ProtoCsrConnection::url")?,
1228            tls_root_cert: proto.tls_root_cert.into_rust()?,
1229            tls_identity: proto.tls_identity.into_rust()?,
1230            http_auth: proto.http_auth.into_rust()?,
1231            tunnel: proto
1232                .tunnel
1233                .into_rust_if_some("ProtoCsrConnection::tunnel")?,
1234        })
1235    }
1236}
1237
1238impl<C: ConnectionAccess> AlterCompatible for CsrConnection<C> {
1239    fn alter_compatible(&self, id: GlobalId, other: &Self) -> Result<(), AlterError> {
1240        let CsrConnection {
1241            tunnel,
1242            // All non-tunnel fields may change
1243            url: _,
1244            tls_root_cert: _,
1245            tls_identity: _,
1246            http_auth: _,
1247        } = self;
1248
1249        let compatibility_checks = [(tunnel.alter_compatible(id, &other.tunnel).is_ok(), "tunnel")];
1250
1251        for (compatible, field) in compatibility_checks {
1252            if !compatible {
1253                tracing::warn!(
1254                    "CsrConnection incompatible at {field}:\nself:\n{:#?}\n\nother\n{:#?}",
1255                    self,
1256                    other
1257                );
1258
1259                return Err(AlterError { id });
1260            }
1261        }
1262        Ok(())
1263    }
1264}
1265
1266/// A TLS key pair used for client identity.
1267#[derive(Arbitrary, Clone, Debug, Eq, PartialEq, Hash, Serialize, Deserialize)]
1268pub struct TlsIdentity {
1269    /// The client's TLS public certificate in PEM format.
1270    pub cert: StringOrSecret,
1271    /// The ID of the secret containing the client's TLS private key in PEM
1272    /// format.
1273    pub key: CatalogItemId,
1274}
1275
1276impl RustType<ProtoTlsIdentity> for TlsIdentity {
1277    fn into_proto(&self) -> ProtoTlsIdentity {
1278        ProtoTlsIdentity {
1279            cert: Some(self.cert.into_proto()),
1280            key: Some(self.key.into_proto()),
1281        }
1282    }
1283
1284    fn from_proto(proto: ProtoTlsIdentity) -> Result<Self, TryFromProtoError> {
1285        Ok(TlsIdentity {
1286            cert: proto.cert.into_rust_if_some("ProtoTlsIdentity::cert")?,
1287            key: proto.key.into_rust_if_some("ProtoTlsIdentity::key")?,
1288        })
1289    }
1290}
1291
1292/// HTTP authentication credentials in a [`CsrConnection`].
1293#[derive(Arbitrary, Clone, Debug, Eq, PartialEq, Hash, Serialize, Deserialize)]
1294pub struct CsrConnectionHttpAuth {
1295    /// The username.
1296    pub username: StringOrSecret,
1297    /// The ID of the secret containing the password, if any.
1298    pub password: Option<CatalogItemId>,
1299}
1300
1301impl RustType<ProtoCsrConnectionHttpAuth> for CsrConnectionHttpAuth {
1302    fn into_proto(&self) -> ProtoCsrConnectionHttpAuth {
1303        ProtoCsrConnectionHttpAuth {
1304            username: Some(self.username.into_proto()),
1305            password: self.password.into_proto(),
1306        }
1307    }
1308
1309    fn from_proto(proto: ProtoCsrConnectionHttpAuth) -> Result<Self, TryFromProtoError> {
1310        Ok(CsrConnectionHttpAuth {
1311            username: proto
1312                .username
1313                .into_rust_if_some("ProtoCsrConnectionHttpAuth::username")?,
1314            password: proto.password.into_rust()?,
1315        })
1316    }
1317}
1318
1319/// A connection to a PostgreSQL server.
1320#[derive(Clone, Debug, Eq, PartialEq, Hash, Serialize, Deserialize, Arbitrary)]
1321pub struct PostgresConnection<C: ConnectionAccess = InlinedConnection> {
1322    /// The hostname of the server.
1323    pub host: String,
1324    /// The port of the server.
1325    pub port: u16,
1326    /// The name of the database to connect to.
1327    pub database: String,
1328    /// The username to authenticate as.
1329    pub user: StringOrSecret,
1330    /// An optional password for authentication.
1331    pub password: Option<CatalogItemId>,
1332    /// A tunnel through which to route traffic.
1333    pub tunnel: Tunnel<C>,
1334    /// Whether to use TLS for encryption, authentication, or both.
1335    #[proptest(strategy = "any_ssl_mode()")]
1336    pub tls_mode: SslMode,
1337    /// An optional root TLS certificate in PEM format, to verify the server's
1338    /// identity.
1339    pub tls_root_cert: Option<StringOrSecret>,
1340    /// An optional TLS client certificate for authentication.
1341    pub tls_identity: Option<TlsIdentity>,
1342    /// The kind of postgres server we are connecting to. This can be vanilla, for a normal
1343    /// postgres server or some other system that is pg compatible, like Yugabyte, Aurora, etc.
1344    pub flavor: PostgresFlavor,
1345}
1346
1347impl<R: ConnectionResolver> IntoInlineConnection<PostgresConnection, R>
1348    for PostgresConnection<ReferencedConnection>
1349{
1350    fn into_inline_connection(self, r: R) -> PostgresConnection {
1351        let PostgresConnection {
1352            host,
1353            port,
1354            database,
1355            user,
1356            password,
1357            tunnel,
1358            tls_mode,
1359            tls_root_cert,
1360            tls_identity,
1361            flavor,
1362        } = self;
1363
1364        PostgresConnection {
1365            host,
1366            port,
1367            database,
1368            user,
1369            password,
1370            tunnel: tunnel.into_inline_connection(r),
1371            tls_mode,
1372            tls_root_cert,
1373            tls_identity,
1374            flavor,
1375        }
1376    }
1377}
1378
1379impl<C: ConnectionAccess> PostgresConnection<C> {
1380    fn validate_by_default(&self) -> bool {
1381        true
1382    }
1383}
1384
1385impl PostgresConnection<InlinedConnection> {
1386    pub async fn config(
1387        &self,
1388        secrets_reader: &Arc<dyn mz_secrets::SecretsReader>,
1389        storage_configuration: &StorageConfiguration,
1390        in_task: InTask,
1391    ) -> Result<mz_postgres_util::Config, anyhow::Error> {
1392        let params = &storage_configuration.parameters;
1393
1394        let mut config = tokio_postgres::Config::new();
1395        config
1396            .host(&self.host)
1397            .port(self.port)
1398            .dbname(&self.database)
1399            .user(&self.user.get_string(in_task, secrets_reader).await?)
1400            .ssl_mode(self.tls_mode);
1401        if let Some(password) = self.password {
1402            let password = secrets_reader
1403                .read_string_in_task_if(in_task, password)
1404                .await?;
1405            config.password(password);
1406        }
1407        if let Some(tls_root_cert) = &self.tls_root_cert {
1408            let tls_root_cert = tls_root_cert.get_string(in_task, secrets_reader).await?;
1409            config.ssl_root_cert(tls_root_cert.as_bytes());
1410        }
1411        if let Some(tls_identity) = &self.tls_identity {
1412            let cert = tls_identity
1413                .cert
1414                .get_string(in_task, secrets_reader)
1415                .await?;
1416            let key = secrets_reader
1417                .read_string_in_task_if(in_task, tls_identity.key)
1418                .await?;
1419            config.ssl_cert(cert.as_bytes()).ssl_key(key.as_bytes());
1420        }
1421
1422        if let Some(connect_timeout) = params.pg_source_connect_timeout {
1423            config.connect_timeout(connect_timeout);
1424        }
1425        if let Some(keepalives_retries) = params.pg_source_tcp_keepalives_retries {
1426            config.keepalives_retries(keepalives_retries);
1427        }
1428        if let Some(keepalives_idle) = params.pg_source_tcp_keepalives_idle {
1429            config.keepalives_idle(keepalives_idle);
1430        }
1431        if let Some(keepalives_interval) = params.pg_source_tcp_keepalives_interval {
1432            config.keepalives_interval(keepalives_interval);
1433        }
1434        if let Some(tcp_user_timeout) = params.pg_source_tcp_user_timeout {
1435            config.tcp_user_timeout(tcp_user_timeout);
1436        }
1437
1438        let mut options = vec![];
1439        if let Some(wal_sender_timeout) = params.pg_source_wal_sender_timeout {
1440            options.push(format!(
1441                "--wal_sender_timeout={}",
1442                wal_sender_timeout.as_millis()
1443            ));
1444        };
1445        if params.pg_source_tcp_configure_server {
1446            if let Some(keepalives_retries) = params.pg_source_tcp_keepalives_retries {
1447                options.push(format!("--tcp_keepalives_count={}", keepalives_retries));
1448            }
1449            if let Some(keepalives_idle) = params.pg_source_tcp_keepalives_idle {
1450                options.push(format!(
1451                    "--tcp_keepalives_idle={}",
1452                    keepalives_idle.as_secs()
1453                ));
1454            }
1455            if let Some(keepalives_interval) = params.pg_source_tcp_keepalives_interval {
1456                options.push(format!(
1457                    "--tcp_keepalives_interval={}",
1458                    keepalives_interval.as_secs()
1459                ));
1460            }
1461            if let Some(tcp_user_timeout) = params.pg_source_tcp_user_timeout {
1462                options.push(format!(
1463                    "--tcp_user_timeout={}",
1464                    tcp_user_timeout.as_millis()
1465                ));
1466            }
1467        }
1468        config.options(options.join(" ").as_str());
1469
1470        let tunnel = match &self.tunnel {
1471            Tunnel::Direct => {
1472                // Ensure any host we connect to is resolved to an external address.
1473                let resolved = resolve_address(
1474                    &self.host,
1475                    ENFORCE_EXTERNAL_ADDRESSES.get(storage_configuration.config_set()),
1476                )
1477                .await?;
1478                mz_postgres_util::TunnelConfig::Direct {
1479                    resolved_ips: Some(resolved),
1480                }
1481            }
1482            Tunnel::Ssh(SshTunnel {
1483                connection_id,
1484                connection,
1485            }) => {
1486                let secret = secrets_reader
1487                    .read_in_task_if(in_task, *connection_id)
1488                    .await?;
1489                let key_pair = SshKeyPair::from_bytes(&secret)?;
1490                // Ensure any ssh-bastion host we connect to is resolved to an external address.
1491                let resolved = resolve_address(
1492                    &connection.host,
1493                    ENFORCE_EXTERNAL_ADDRESSES.get(storage_configuration.config_set()),
1494                )
1495                .await?;
1496                mz_postgres_util::TunnelConfig::Ssh {
1497                    config: SshTunnelConfig {
1498                        host: resolved
1499                            .iter()
1500                            .map(|a| a.to_string())
1501                            .collect::<BTreeSet<_>>(),
1502                        port: connection.port,
1503                        user: connection.user.clone(),
1504                        key_pair,
1505                    },
1506                }
1507            }
1508            Tunnel::AwsPrivatelink(connection) => {
1509                assert_none!(connection.port);
1510                mz_postgres_util::TunnelConfig::AwsPrivatelink {
1511                    connection_id: connection.connection_id,
1512                }
1513            }
1514        };
1515
1516        Ok(mz_postgres_util::Config::new(
1517            config,
1518            tunnel,
1519            params.ssh_timeout_config,
1520            in_task,
1521        )?)
1522    }
1523
1524    async fn validate(
1525        &self,
1526        _id: CatalogItemId,
1527        storage_configuration: &StorageConfiguration,
1528    ) -> Result<(), anyhow::Error> {
1529        let config = self
1530            .config(
1531                &storage_configuration.connection_context.secrets_reader,
1532                storage_configuration,
1533                // We are in a normal tokio context during validation, already.
1534                InTask::No,
1535            )
1536            .await?;
1537        let client = config
1538            .connect(
1539                "connection validation",
1540                &storage_configuration.connection_context.ssh_tunnel_manager,
1541            )
1542            .await?;
1543        use PostgresFlavor::*;
1544        match (client.server_flavor(), &self.flavor) {
1545            (Vanilla, Yugabyte) => bail!("Expected to find PostgreSQL server, found Yugabyte."),
1546            (Yugabyte, Vanilla) => bail!("Expected to find Yugabyte server, found PostgreSQL."),
1547            (Vanilla, Vanilla) | (Yugabyte, Yugabyte) => {}
1548        }
1549        Ok(())
1550    }
1551}
1552
1553impl<C: ConnectionAccess> AlterCompatible for PostgresConnection<C> {
1554    fn alter_compatible(&self, id: GlobalId, other: &Self) -> Result<(), AlterError> {
1555        let PostgresConnection {
1556            tunnel,
1557            flavor,
1558            // All non-tunnel options may change arbitrarily
1559            host: _,
1560            port: _,
1561            database: _,
1562            user: _,
1563            password: _,
1564            tls_mode: _,
1565            tls_root_cert: _,
1566            tls_identity: _,
1567        } = self;
1568
1569        let compatibility_checks = [
1570            (tunnel.alter_compatible(id, &other.tunnel).is_ok(), "tunnel"),
1571            (flavor == &other.flavor, "flavor"),
1572        ];
1573
1574        for (compatible, field) in compatibility_checks {
1575            if !compatible {
1576                tracing::warn!(
1577                    "PostgresConnection incompatible at {field}:\nself:\n{:#?}\n\nother\n{:#?}",
1578                    self,
1579                    other
1580                );
1581
1582                return Err(AlterError { id });
1583            }
1584        }
1585        Ok(())
1586    }
1587}
1588
1589impl RustType<ProtoPostgresConnection> for PostgresConnection {
1590    fn into_proto(&self) -> ProtoPostgresConnection {
1591        ProtoPostgresConnection {
1592            host: self.host.into_proto(),
1593            port: self.port.into_proto(),
1594            database: self.database.into_proto(),
1595            user: Some(self.user.into_proto()),
1596            password: self.password.into_proto(),
1597            tls_mode: Some(self.tls_mode.into_proto()),
1598            tls_root_cert: self.tls_root_cert.into_proto(),
1599            tls_identity: self.tls_identity.into_proto(),
1600            tunnel: Some(self.tunnel.into_proto()),
1601            flavor: Some(self.flavor.into_proto()),
1602        }
1603    }
1604
1605    fn from_proto(proto: ProtoPostgresConnection) -> Result<Self, TryFromProtoError> {
1606        Ok(PostgresConnection {
1607            host: proto.host,
1608            port: proto.port.into_rust()?,
1609            database: proto.database,
1610            user: proto
1611                .user
1612                .into_rust_if_some("ProtoPostgresConnection::user")?,
1613            password: proto.password.into_rust()?,
1614            tunnel: proto
1615                .tunnel
1616                .into_rust_if_some("ProtoPostgresConnection::tunnel")?,
1617            tls_mode: proto
1618                .tls_mode
1619                .into_rust_if_some("ProtoPostgresConnection::tls_mode")?,
1620            tls_root_cert: proto.tls_root_cert.into_rust()?,
1621            tls_identity: proto.tls_identity.into_rust()?,
1622            flavor: proto
1623                .flavor
1624                .into_rust_if_some("ProtoPostgresConnection::flavor")?,
1625        })
1626    }
1627}
1628
/// Specifies how to tunnel a connection.
///
/// The type parameter `C` is only forwarded to the [`Tunnel::Ssh`] variant;
/// the other variants do not depend on it.
#[derive(Arbitrary, Clone, Debug, Eq, PartialEq, Hash, Serialize, Deserialize)]
pub enum Tunnel<C: ConnectionAccess = InlinedConnection> {
    /// No tunneling.
    Direct,
    /// Via the specified SSH tunnel connection.
    Ssh(SshTunnel<C>),
    /// Via the specified AWS PrivateLink connection.
    AwsPrivatelink(AwsPrivatelink),
}
1639
1640impl<R: ConnectionResolver> IntoInlineConnection<Tunnel, R> for Tunnel<ReferencedConnection> {
1641    fn into_inline_connection(self, r: R) -> Tunnel {
1642        match self {
1643            Tunnel::Direct => Tunnel::Direct,
1644            Tunnel::Ssh(ssh) => Tunnel::Ssh(ssh.into_inline_connection(r)),
1645            Tunnel::AwsPrivatelink(awspl) => Tunnel::AwsPrivatelink(awspl),
1646        }
1647    }
1648}
1649
1650impl RustType<ProtoTunnel> for Tunnel<InlinedConnection> {
1651    fn into_proto(&self) -> ProtoTunnel {
1652        use proto_tunnel::Tunnel as ProtoTunnelField;
1653        ProtoTunnel {
1654            tunnel: Some(match &self {
1655                Tunnel::Direct => ProtoTunnelField::Direct(()),
1656                Tunnel::Ssh(ssh) => ProtoTunnelField::Ssh(ssh.into_proto()),
1657                Tunnel::AwsPrivatelink(aws) => ProtoTunnelField::AwsPrivatelink(aws.into_proto()),
1658            }),
1659        }
1660    }
1661
1662    fn from_proto(proto: ProtoTunnel) -> Result<Self, TryFromProtoError> {
1663        use proto_tunnel::Tunnel as ProtoTunnelField;
1664        Ok(match proto.tunnel {
1665            None => return Err(TryFromProtoError::missing_field("ProtoTunnel::tunnel")),
1666            Some(ProtoTunnelField::Direct(())) => Tunnel::Direct,
1667            Some(ProtoTunnelField::Ssh(ssh)) => Tunnel::Ssh(ssh.into_rust()?),
1668            Some(ProtoTunnelField::AwsPrivatelink(aws)) => Tunnel::AwsPrivatelink(aws.into_rust()?),
1669        })
1670    }
1671}
1672
1673impl<C: ConnectionAccess> AlterCompatible for Tunnel<C> {
1674    fn alter_compatible(&self, id: GlobalId, other: &Self) -> Result<(), AlterError> {
1675        let compatible = match (self, other) {
1676            (Self::Ssh(s), Self::Ssh(o)) => s.alter_compatible(id, o).is_ok(),
1677            (s, o) => s == o,
1678        };
1679
1680        if !compatible {
1681            tracing::warn!(
1682                "Tunnel incompatible:\nself:\n{:#?}\n\nother\n{:#?}",
1683                self,
1684                other
1685            );
1686
1687            return Err(AlterError { id });
1688        }
1689
1690        Ok(())
1691    }
1692}
1693
/// Specifies which MySQL SSL Mode to use:
/// <https://dev.mysql.com/doc/refman/8.0/en/connection-options.html#option_general_ssl-mode>
/// This is not available as an enum in the mysql-async crate, so we define our own.
///
/// The per-variant notes below describe how `MySqlConnection::config` in this
/// file translates each mode into `mysql_async::SslOpts`.
#[derive(Arbitrary, Clone, Debug, Eq, PartialEq, Hash, Serialize, Deserialize)]
pub enum MySqlSslMode {
    /// No SSL options are configured at all.
    Disabled,
    /// SSL options are configured, but invalid certificates are accepted and
    /// domain validation is skipped.
    Required,
    /// SSL options are configured with certificate checking left enabled;
    /// only domain validation is skipped.
    VerifyCa,
    /// SSL options are configured with the `mysql_async` defaults, i.e. no
    /// check is opted out of.
    VerifyIdentity,
}
1704
1705impl RustType<i32> for MySqlSslMode {
1706    fn into_proto(&self) -> i32 {
1707        match self {
1708            MySqlSslMode::Disabled => ProtoMySqlSslMode::Disabled.into(),
1709            MySqlSslMode::Required => ProtoMySqlSslMode::Required.into(),
1710            MySqlSslMode::VerifyCa => ProtoMySqlSslMode::VerifyCa.into(),
1711            MySqlSslMode::VerifyIdentity => ProtoMySqlSslMode::VerifyIdentity.into(),
1712        }
1713    }
1714
1715    fn from_proto(proto: i32) -> Result<Self, TryFromProtoError> {
1716        Ok(match ProtoMySqlSslMode::try_from(proto) {
1717            Ok(ProtoMySqlSslMode::Disabled) => MySqlSslMode::Disabled,
1718            Ok(ProtoMySqlSslMode::Required) => MySqlSslMode::Required,
1719            Ok(ProtoMySqlSslMode::VerifyCa) => MySqlSslMode::VerifyCa,
1720            Ok(ProtoMySqlSslMode::VerifyIdentity) => MySqlSslMode::VerifyIdentity,
1721            Err(_) => {
1722                return Err(TryFromProtoError::UnknownEnumVariant(
1723                    "tls_mode".to_string(),
1724                ));
1725            }
1726        })
1727    }
1728}
1729
1730pub fn any_mysql_ssl_mode() -> impl Strategy<Value = MySqlSslMode> {
1731    proptest::sample::select(vec![
1732        MySqlSslMode::Disabled,
1733        MySqlSslMode::Required,
1734        MySqlSslMode::VerifyCa,
1735        MySqlSslMode::VerifyIdentity,
1736    ])
1737}
1738
/// A connection to a MySQL server.
#[derive(Clone, Debug, Eq, PartialEq, Hash, Serialize, Deserialize, Arbitrary)]
pub struct MySqlConnection<C: ConnectionAccess = InlinedConnection> {
    /// The hostname of the server.
    pub host: String,
    /// The port of the server.
    pub port: u16,
    /// The username to authenticate as.
    pub user: StringOrSecret,
    /// An optional password for authentication.
    pub password: Option<CatalogItemId>,
    /// A tunnel through which to route traffic.
    pub tunnel: Tunnel<C>,
    /// Whether to use TLS for encryption, verify the server's certificate, and identity.
    #[proptest(strategy = "any_mysql_ssl_mode()")]
    pub tls_mode: MySqlSslMode,
    /// An optional root TLS certificate in PEM format, to verify the server's
    /// identity.
    ///
    /// `MySqlConnection::config` only applies it when `tls_mode` is
    /// `VerifyCa` or `VerifyIdentity`.
    pub tls_root_cert: Option<StringOrSecret>,
    /// An optional TLS client certificate for authentication.
    pub tls_identity: Option<TlsIdentity>,
    /// Reference to the AWS connection information to be used for IAM authentication and
    /// assuming AWS roles.
    pub aws_connection: Option<AwsConnectionReference<C>>,
}
1764
1765impl<R: ConnectionResolver> IntoInlineConnection<MySqlConnection, R>
1766    for MySqlConnection<ReferencedConnection>
1767{
1768    fn into_inline_connection(self, r: R) -> MySqlConnection {
1769        let MySqlConnection {
1770            host,
1771            port,
1772            user,
1773            password,
1774            tunnel,
1775            tls_mode,
1776            tls_root_cert,
1777            tls_identity,
1778            aws_connection,
1779        } = self;
1780
1781        MySqlConnection {
1782            host,
1783            port,
1784            user,
1785            password,
1786            tunnel: tunnel.into_inline_connection(&r),
1787            tls_mode,
1788            tls_root_cert,
1789            tls_identity,
1790            aws_connection: aws_connection.map(|aws| aws.into_inline_connection(&r)),
1791        }
1792    }
1793}
1794
impl<C: ConnectionAccess> MySqlConnection<C> {
    /// Always returns `true`.
    ///
    /// NOTE(review): presumably consulted by the caller to decide whether a
    /// MySQL connection is validated when no explicit choice was made --
    /// confirm against the call site.
    fn validate_by_default(&self) -> bool {
        true
    }
}
1800
impl MySqlConnection<InlinedConnection> {
    /// Resolves this connection into an [`mz_mysql_util::Config`].
    ///
    /// In order, this reads the user name and optional password, derives the
    /// SSL options from `tls_mode`, attaches the optional root certificate and
    /// client identity, builds the tunnel configuration, and loads the AWS SDK
    /// configuration if an AWS connection is referenced.
    ///
    /// `secrets_reader` is used for every secret lookup and `in_task` is
    /// forwarded to each of those lookups and to the resulting config.
    ///
    /// # Errors
    ///
    /// Returns the first error hit while reading a secret, converting the
    /// client key/certificate, resolving a host address, loading the AWS SDK
    /// configuration, or constructing the final config.
    pub async fn config(
        &self,
        secrets_reader: &Arc<dyn mz_secrets::SecretsReader>,
        storage_configuration: &StorageConfiguration,
        in_task: InTask,
    ) -> Result<mz_mysql_util::Config, anyhow::Error> {
        // TODO(roshan): Set appropriate connection timeouts
        let mut opts = mysql_async::OptsBuilder::default()
            .ip_or_hostname(&self.host)
            .tcp_port(self.port)
            .user(Some(&self.user.get_string(in_task, secrets_reader).await?));

        // The password is optional; it is only read when one is configured.
        if let Some(password) = self.password {
            let password = secrets_reader
                .read_string_in_task_if(in_task, password)
                .await?;
            opts = opts.pass(Some(password));
        }

        // Our `MySqlSslMode` enum matches the official MySQL Client `--ssl-mode` parameter values
        // which uses opt-in security features (SSL, CA verification, & Identity verification).
        // The mysql_async crate `SslOpts` struct uses an opt-out mechanism for each of these, so
        // we need to appropriately disable features to match the intent of each enum value.
        let mut ssl_opts = match self.tls_mode {
            MySqlSslMode::Disabled => None,
            MySqlSslMode::Required => Some(
                mysql_async::SslOpts::default()
                    .with_danger_accept_invalid_certs(true)
                    .with_danger_skip_domain_validation(true),
            ),
            MySqlSslMode::VerifyCa => {
                Some(mysql_async::SslOpts::default().with_danger_skip_domain_validation(true))
            }
            MySqlSslMode::VerifyIdentity => Some(mysql_async::SslOpts::default()),
        };

        // A configured root certificate is only read and applied for the two
        // verifying modes; for the other modes it is ignored entirely.
        if matches!(
            self.tls_mode,
            MySqlSslMode::VerifyCa | MySqlSslMode::VerifyIdentity
        ) {
            if let Some(tls_root_cert) = &self.tls_root_cert {
                let tls_root_cert = tls_root_cert.get_string(in_task, secrets_reader).await?;
                ssl_opts = ssl_opts.map(|opts| {
                    opts.with_root_certs(vec![tls_root_cert.as_bytes().to_vec().into()])
                });
            }
        }

        // The client identity is read and converted regardless of `tls_mode`,
        // so failures here surface even when SSL is disabled. In that case
        // `ssl_opts` is `None` and the `map` below attaches nothing.
        if let Some(identity) = &self.tls_identity {
            let key = secrets_reader
                .read_string_in_task_if(in_task, identity.key)
                .await?;
            let cert = identity.cert.get_string(in_task, secrets_reader).await?;
            let Pkcs12Archive { der, pass } =
                mz_tls_util::pkcs12der_from_pem(key.as_bytes(), cert.as_bytes())?;

            // Add client identity to SSLOpts
            ssl_opts = ssl_opts.map(|opts| {
                opts.with_client_identity(Some(
                    mysql_async::ClientIdentity::new(der.into()).with_password(pass),
                ))
            });
        }

        opts = opts.ssl_opts(ssl_opts);

        let tunnel = match &self.tunnel {
            Tunnel::Direct => {
                // Ensure any host we connect to is resolved to an external address.
                let resolved = resolve_address(
                    &self.host,
                    ENFORCE_EXTERNAL_ADDRESSES.get(storage_configuration.config_set()),
                )
                .await?;
                mz_mysql_util::TunnelConfig::Direct {
                    resolved_ips: Some(resolved),
                }
            }
            Tunnel::Ssh(SshTunnel {
                connection_id,
                connection,
            }) => {
                // The SSH key pair is stored as a secret keyed by the SSH
                // connection's id.
                let secret = secrets_reader
                    .read_in_task_if(in_task, *connection_id)
                    .await?;
                let key_pair = SshKeyPair::from_bytes(&secret)?;
                // Ensure any ssh-bastion host we connect to is resolved to an external address.
                let resolved = resolve_address(
                    &connection.host,
                    ENFORCE_EXTERNAL_ADDRESSES.get(storage_configuration.config_set()),
                )
                .await?;
                mz_mysql_util::TunnelConfig::Ssh {
                    config: SshTunnelConfig {
                        host: resolved
                            .iter()
                            .map(|a| a.to_string())
                            .collect::<BTreeSet<_>>(),
                        port: connection.port,
                        user: connection.user.clone(),
                        key_pair,
                    },
                }
            }
            Tunnel::AwsPrivatelink(connection) => {
                // A port override on the PrivateLink tunnel is not expected
                // for MySQL connections; this panics if one is present.
                assert_none!(connection.port);
                mz_mysql_util::TunnelConfig::AwsPrivatelink {
                    connection_id: connection.connection_id,
                }
            }
        };

        // Only load the AWS SDK configuration when an AWS connection is
        // referenced.
        let aws_config = match self.aws_connection.as_ref() {
            None => None,
            Some(aws_ref) => Some(
                aws_ref
                    .connection
                    .load_sdk_config(
                        &storage_configuration.connection_context,
                        aws_ref.connection_id,
                        in_task,
                    )
                    .await?,
            ),
        };

        Ok(mz_mysql_util::Config::new(
            opts,
            tunnel,
            storage_configuration.parameters.ssh_timeout_config,
            in_task,
            storage_configuration
                .parameters
                .mysql_source_timeouts
                .clone(),
            aws_config,
        )?)
    }

    /// Validates this connection by building its configuration, connecting to
    /// the server, and disconnecting again.
    ///
    /// # Errors
    ///
    /// Returns an error if building the configuration, connecting, or
    /// disconnecting fails.
    async fn validate(
        &self,
        _id: CatalogItemId,
        storage_configuration: &StorageConfiguration,
    ) -> Result<(), anyhow::Error> {
        let config = self
            .config(
                &storage_configuration.connection_context.secrets_reader,
                storage_configuration,
                // We are in a normal tokio context during validation, already.
                InTask::No,
            )
            .await?;
        let conn = config
            .connect(
                "connection validation",
                &storage_configuration.connection_context.ssh_tunnel_manager,
            )
            .await?;
        conn.disconnect().await?;
        Ok(())
    }
}
1964
1965impl RustType<ProtoMySqlConnection> for MySqlConnection {
1966    fn into_proto(&self) -> ProtoMySqlConnection {
1967        ProtoMySqlConnection {
1968            host: self.host.into_proto(),
1969            port: self.port.into_proto(),
1970            user: Some(self.user.into_proto()),
1971            password: self.password.into_proto(),
1972            tls_mode: self.tls_mode.into_proto(),
1973            tls_root_cert: self.tls_root_cert.into_proto(),
1974            tls_identity: self.tls_identity.into_proto(),
1975            tunnel: Some(self.tunnel.into_proto()),
1976            aws_connection: self.aws_connection.into_proto(),
1977        }
1978    }
1979
1980    fn from_proto(proto: ProtoMySqlConnection) -> Result<Self, TryFromProtoError> {
1981        Ok(MySqlConnection {
1982            host: proto.host,
1983            port: proto.port.into_rust()?,
1984            user: proto.user.into_rust_if_some("ProtoMySqlConnection::user")?,
1985            password: proto.password.into_rust()?,
1986            tunnel: proto
1987                .tunnel
1988                .into_rust_if_some("ProtoMySqlConnection::tunnel")?,
1989            tls_mode: proto.tls_mode.into_rust()?,
1990            tls_root_cert: proto.tls_root_cert.into_rust()?,
1991            tls_identity: proto.tls_identity.into_rust()?,
1992            aws_connection: proto.aws_connection.into_rust()?,
1993        })
1994    }
1995}
1996
1997impl<C: ConnectionAccess> AlterCompatible for MySqlConnection<C> {
1998    fn alter_compatible(&self, id: GlobalId, other: &Self) -> Result<(), AlterError> {
1999        let MySqlConnection {
2000            tunnel,
2001            // All non-tunnel options may change arbitrarily
2002            host: _,
2003            port: _,
2004            user: _,
2005            password: _,
2006            tls_mode: _,
2007            tls_root_cert: _,
2008            tls_identity: _,
2009            aws_connection: _,
2010        } = self;
2011
2012        let compatibility_checks = [(tunnel.alter_compatible(id, &other.tunnel).is_ok(), "tunnel")];
2013
2014        for (compatible, field) in compatibility_checks {
2015            if !compatible {
2016                tracing::warn!(
2017                    "MySqlConnection incompatible at {field}:\nself:\n{:#?}\n\nother\n{:#?}",
2018                    self,
2019                    other
2020                );
2021
2022                return Err(AlterError { id });
2023            }
2024        }
2025        Ok(())
2026    }
2027}
2028
/// Details how to connect to an instance of Microsoft SQL Server.
///
/// For specifics of connecting to SQL Server for purposes of creating a
/// Materialize Source, see [`SqlServerSource`] which wraps this type.
///
/// [`SqlServerSource`]: crate::sources::SqlServerSource
#[derive(Arbitrary, Clone, Debug, Eq, PartialEq, Hash, Serialize, Deserialize)]
pub struct SqlServerConnectionDetails<C: ConnectionAccess = InlinedConnection> {
    /// The hostname of the server.
    pub host: String,
    /// The port of the server.
    pub port: u16,
    /// Database we should connect to.
    pub database: String,
    /// The username to authenticate as.
    pub user: StringOrSecret,
    /// Password used for authentication.
    pub password: CatalogItemId,
    /// A tunnel through which to route traffic.
    pub tunnel: Tunnel<C>,
    /// Level of encryption to use for the connection.
    ///
    /// NOTE: `resolve_config` in this file does not yet apply this value to
    /// the `tiberius` configuration; the call is commented out behind a TODO.
    pub encryption: mz_sql_server_util::config::EncryptionLevel,
}
2052
impl<C: ConnectionAccess> SqlServerConnectionDetails<C> {
    /// Always returns `true`.
    ///
    /// NOTE(review): presumably consulted by the caller to decide whether a
    /// SQL Server connection is validated when no explicit choice was made --
    /// confirm against the call site.
    fn validate_by_default(&self) -> bool {
        true
    }
}
2058
2059impl SqlServerConnectionDetails<InlinedConnection> {
2060    /// Attempts to open a connection to the upstream SQL Server instance.
2061    async fn validate(
2062        &self,
2063        _id: CatalogItemId,
2064        storage_configuration: &StorageConfiguration,
2065    ) -> Result<(), anyhow::Error> {
2066        let config = self
2067            .resolve_config(
2068                &storage_configuration.connection_context.secrets_reader,
2069                storage_configuration,
2070                InTask::No,
2071            )
2072            .await?;
2073        tracing::debug!(?config, "Validating SQL Server connection");
2074
2075        // Just connecting is enough to validate, no need to send any queries.
2076        let _client = mz_sql_server_util::Client::connect(config).await?;
2077
2078        Ok(())
2079    }
2080
2081    /// Resolve all of the connection details (e.g. read from the [`SecretsReader`])
2082    /// so the returned [`Config`] can be used to open a connection with the
2083    /// upstream system.
2084    ///
2085    /// The provided [`InTask`] argument determines whether any I/O is run in an
2086    /// [`mz_ore::task`] (i.e. a different thread) or directly in the returned
2087    /// future. The main goal here is to prevent running I/O in timely threads.
2088    ///
2089    /// [`Config`]: mz_sql_server_util::Config
2090    pub async fn resolve_config(
2091        &self,
2092        secrets_reader: &Arc<dyn mz_secrets::SecretsReader>,
2093        storage_configuration: &StorageConfiguration,
2094        in_task: InTask,
2095    ) -> Result<mz_sql_server_util::Config, anyhow::Error> {
2096        let dyncfg = storage_configuration.config_set();
2097        let mut inner_config = tiberius::Config::new();
2098
2099        // Setup default connection params.
2100        inner_config.host(&self.host);
2101        inner_config.port(self.port);
2102        inner_config.database(self.database.clone());
2103        // TODO(sql_server1): Figure out the right settings for encryption.
2104        // inner_config.encryption(self.encryption.into());
2105        inner_config.application_name("materialize");
2106
2107        // Read our auth settings from
2108        let user = self
2109            .user
2110            .get_string(in_task, secrets_reader)
2111            .await
2112            .context("username")?;
2113        let password = secrets_reader
2114            .read_string_in_task_if(in_task, self.password)
2115            .await
2116            .context("password")?;
2117        // TODO(sql_server3): Support other methods of authentication besides
2118        // username and password.
2119        inner_config.authentication(tiberius::AuthMethod::sql_server(user, password));
2120
2121        // TODO(sql_server2): Fork the tiberius library and add support for
2122        // specifying a cert bundle from a binary blob.
2123        //
2124        // See: <https://github.com/prisma/tiberius/pull/290>
2125        inner_config.trust_cert();
2126
2127        // Prevent users from probing our internal network ports by trying to
2128        // connect to localhost, or another non-external IP.
2129        let enfoce_external_addresses = ENFORCE_EXTERNAL_ADDRESSES.get(dyncfg);
2130
2131        let tunnel = match &self.tunnel {
2132            Tunnel::Direct => mz_sql_server_util::config::TunnelConfig::Direct,
2133            Tunnel::Ssh(SshTunnel {
2134                connection_id,
2135                connection: ssh_connection,
2136            }) => {
2137                let secret = secrets_reader
2138                    .read_in_task_if(in_task, *connection_id)
2139                    .await
2140                    .context("ssh secret")?;
2141                let key_pair = SshKeyPair::from_bytes(&secret).context("ssh key pair")?;
2142                // Ensure any SSH-bastion host we connect to is resolved to an
2143                // external address.
2144                let addresses = resolve_address(&ssh_connection.host, enfoce_external_addresses)
2145                    .await
2146                    .context("ssh tunnel")?;
2147
2148                let config = SshTunnelConfig {
2149                    host: addresses.into_iter().map(|a| a.to_string()).collect(),
2150                    port: ssh_connection.port,
2151                    user: ssh_connection.user.clone(),
2152                    key_pair,
2153                };
2154                mz_sql_server_util::config::TunnelConfig::Ssh {
2155                    config,
2156                    manager: storage_configuration
2157                        .connection_context
2158                        .ssh_tunnel_manager
2159                        .clone(),
2160                    timeout: storage_configuration.parameters.ssh_timeout_config.clone(),
2161                    host: self.host.clone(),
2162                    port: self.port,
2163                }
2164            }
2165            Tunnel::AwsPrivatelink(private_link_connection) => {
2166                assert_none!(private_link_connection.port);
2167                mz_sql_server_util::config::TunnelConfig::AwsPrivatelink {
2168                    connection_id: private_link_connection.connection_id,
2169                }
2170            }
2171        };
2172
2173        Ok(mz_sql_server_util::Config::new(
2174            inner_config,
2175            tunnel,
2176            in_task,
2177        ))
2178    }
2179}
2180
2181impl<R: ConnectionResolver> IntoInlineConnection<SqlServerConnectionDetails, R>
2182    for SqlServerConnectionDetails<ReferencedConnection>
2183{
2184    fn into_inline_connection(self, r: R) -> SqlServerConnectionDetails {
2185        let SqlServerConnectionDetails {
2186            host,
2187            port,
2188            database,
2189            user,
2190            password,
2191            tunnel,
2192            encryption,
2193        } = self;
2194
2195        SqlServerConnectionDetails {
2196            host,
2197            port,
2198            database,
2199            user,
2200            password,
2201            tunnel: tunnel.into_inline_connection(&r),
2202            encryption,
2203        }
2204    }
2205}
2206
2207impl<C: ConnectionAccess> AlterCompatible for SqlServerConnectionDetails<C> {
2208    fn alter_compatible(
2209        &self,
2210        id: mz_repr::GlobalId,
2211        other: &Self,
2212    ) -> Result<(), crate::controller::AlterError> {
2213        let SqlServerConnectionDetails {
2214            tunnel,
2215            // TODO(sql_server2): Figure out how these variables are allowed to change.
2216            host: _,
2217            port: _,
2218            database: _,
2219            user: _,
2220            password: _,
2221            encryption: _,
2222        } = self;
2223
2224        let compatibility_checks = [(tunnel.alter_compatible(id, &other.tunnel).is_ok(), "tunnel")];
2225
2226        for (compatible, field) in compatibility_checks {
2227            if !compatible {
2228                tracing::warn!(
2229                    "SqlServerConnectionDetails incompatible at {field}:\nself:\n{:#?}\n\nother\n{:#?}",
2230                    self,
2231                    other
2232                );
2233
2234                return Err(AlterError { id });
2235            }
2236        }
2237        Ok(())
2238    }
2239}
2240
2241impl RustType<ProtoSqlServerConnectionDetails> for SqlServerConnectionDetails {
2242    fn into_proto(&self) -> ProtoSqlServerConnectionDetails {
2243        ProtoSqlServerConnectionDetails {
2244            host: self.host.into_proto(),
2245            port: self.port.into_proto(),
2246            database: self.database.into_proto(),
2247            user: Some(self.user.into_proto()),
2248            password: Some(self.password.into_proto()),
2249            tunnel: Some(self.tunnel.into_proto()),
2250            encryption: self.encryption.into_proto().into(),
2251        }
2252    }
2253
2254    fn from_proto(proto: ProtoSqlServerConnectionDetails) -> Result<Self, TryFromProtoError> {
2255        Ok(SqlServerConnectionDetails {
2256            host: proto.host,
2257            port: proto.port.into_rust()?,
2258            database: proto.database.into_rust()?,
2259            user: proto
2260                .user
2261                .into_rust_if_some("ProtoSqlServerConnectionDetails::user")?,
2262            password: proto
2263                .password
2264                .into_rust_if_some("ProtoSqlServerConnectionDetails::password")?,
2265            tunnel: proto
2266                .tunnel
2267                .into_rust_if_some("ProtoSqlServerConnectionDetails::tunnel")?,
2268            encryption: ProtoSqlServerEncryptionLevel::try_from(proto.encryption)?.into_rust()?,
2269        })
2270    }
2271}
2272
2273impl RustType<ProtoSqlServerEncryptionLevel> for mz_sql_server_util::config::EncryptionLevel {
2274    fn into_proto(&self) -> ProtoSqlServerEncryptionLevel {
2275        match self {
2276            Self::None => ProtoSqlServerEncryptionLevel::SqlServerNone,
2277            Self::Login => ProtoSqlServerEncryptionLevel::SqlServerLogin,
2278            Self::Preferred => ProtoSqlServerEncryptionLevel::SqlServerPreferred,
2279            Self::Required => ProtoSqlServerEncryptionLevel::SqlServerRequired,
2280        }
2281    }
2282
2283    fn from_proto(proto: ProtoSqlServerEncryptionLevel) -> Result<Self, TryFromProtoError> {
2284        Ok(match proto {
2285            ProtoSqlServerEncryptionLevel::SqlServerNone => {
2286                mz_sql_server_util::config::EncryptionLevel::None
2287            }
2288            ProtoSqlServerEncryptionLevel::SqlServerLogin => {
2289                mz_sql_server_util::config::EncryptionLevel::Login
2290            }
2291            ProtoSqlServerEncryptionLevel::SqlServerPreferred => {
2292                mz_sql_server_util::config::EncryptionLevel::Preferred
2293            }
2294            ProtoSqlServerEncryptionLevel::SqlServerRequired => {
2295                mz_sql_server_util::config::EncryptionLevel::Required
2296            }
2297        })
2298    }
2299}
2300
/// A connection to an SSH tunnel.
///
/// Elsewhere in this file `host` is treated as the SSH bastion host: it is
/// resolved to addresses before an `SshTunnelConfig` is built from this
/// struct's `port` and `user`.
#[derive(Arbitrary, Clone, Debug, Eq, PartialEq, Hash, Serialize, Deserialize)]
pub struct SshConnection {
    /// The host of the SSH bastion.
    pub host: String,
    /// The port passed through to the SSH tunnel configuration.
    pub port: u16,
    /// The user passed through to the SSH tunnel configuration.
    pub user: String,
}
2308
2309use self::inline::{
2310    ConnectionAccess, ConnectionResolver, InlinedConnection, IntoInlineConnection,
2311    ReferencedConnection,
2312};
2313
2314impl RustType<ProtoSshConnection> for SshConnection {
2315    fn into_proto(&self) -> ProtoSshConnection {
2316        ProtoSshConnection {
2317            host: self.host.into_proto(),
2318            port: self.port.into_proto(),
2319            user: self.user.into_proto(),
2320        }
2321    }
2322
2323    fn from_proto(proto: ProtoSshConnection) -> Result<Self, TryFromProtoError> {
2324        Ok(SshConnection {
2325            host: proto.host,
2326            port: proto.port.into_rust()?,
2327            user: proto.user,
2328        })
2329    }
2330}
2331
impl AlterCompatible for SshConnection {
    /// Always succeeds: no field of an SSH connection restricts what it may
    /// be altered to, so neither `_id` nor `_other` is inspected.
    fn alter_compatible(&self, _id: GlobalId, _other: &Self) -> Result<(), AlterError> {
        // Every element of the SSH connection is configurable.
        Ok(())
    }
}
2338
/// Specifies an AWS PrivateLink service for a [`Tunnel`].
#[derive(Arbitrary, Clone, Debug, Eq, PartialEq, Hash, Serialize, Deserialize)]
pub struct AwsPrivatelink {
    /// The ID of the connection to the AWS PrivateLink service.
    pub connection_id: CatalogItemId,
    /// The availability zone to use when connecting to the AWS PrivateLink service.
    pub availability_zone: Option<String>,
    /// The port to use when connecting to the AWS PrivateLink service, if
    /// different from the port in [`KafkaBroker::address`].
    ///
    /// The PostgreSQL, MySQL, and SQL Server configuration code in this file
    /// asserts that this is `None` for their tunnels.
    pub port: Option<u16>,
}
2350
2351impl RustType<ProtoAwsPrivatelink> for AwsPrivatelink {
2352    fn into_proto(&self) -> ProtoAwsPrivatelink {
2353        ProtoAwsPrivatelink {
2354            connection_id: Some(self.connection_id.into_proto()),
2355            availability_zone: self.availability_zone.into_proto(),
2356            port: self.port.into_proto(),
2357        }
2358    }
2359
2360    fn from_proto(proto: ProtoAwsPrivatelink) -> Result<Self, TryFromProtoError> {
2361        Ok(AwsPrivatelink {
2362            connection_id: proto
2363                .connection_id
2364                .into_rust_if_some("ProtoAwsPrivatelink::connection_id")?,
2365            availability_zone: proto.availability_zone.into_rust()?,
2366            port: proto.port.into_rust()?,
2367        })
2368    }
2369}
2370
2371impl AlterCompatible for AwsPrivatelink {
2372    fn alter_compatible(&self, id: GlobalId, other: &Self) -> Result<(), AlterError> {
2373        let AwsPrivatelink {
2374            connection_id,
2375            availability_zone: _,
2376            port: _,
2377        } = self;
2378
2379        let compatibility_checks = [(connection_id == &other.connection_id, "connection_id")];
2380
2381        for (compatible, field) in compatibility_checks {
2382            if !compatible {
2383                tracing::warn!(
2384                    "AwsPrivatelink incompatible at {field}:\nself:\n{:#?}\n\nother\n{:#?}",
2385                    self,
2386                    other
2387                );
2388
2389                return Err(AlterError { id });
2390            }
2391        }
2392
2393        Ok(())
2394    }
2395}
2396
/// Specifies an SSH tunnel connection.
///
/// The type parameter `C` selects how the underlying SSH connection is
/// represented (`C::Ssh`); it defaults to [`InlinedConnection`].
#[derive(Arbitrary, Clone, Debug, Eq, PartialEq, Hash, Serialize, Deserialize)]
pub struct SshTunnel<C: ConnectionAccess = InlinedConnection> {
    /// The ID of the SSH connection. When connecting, the SSH key pair is
    /// read from the secrets reader under this same ID.
    pub connection_id: CatalogItemId,
    /// The SSH connection object, in the representation chosen by `C`.
    pub connection: C::Ssh,
}
2405
2406impl<R: ConnectionResolver> IntoInlineConnection<SshTunnel, R> for SshTunnel<ReferencedConnection> {
2407    fn into_inline_connection(self, r: R) -> SshTunnel {
2408        let SshTunnel {
2409            connection,
2410            connection_id,
2411        } = self;
2412
2413        SshTunnel {
2414            connection: r.resolve_connection(connection).unwrap_ssh(),
2415            connection_id,
2416        }
2417    }
2418}
2419
2420impl RustType<ProtoSshTunnel> for SshTunnel<InlinedConnection> {
2421    fn into_proto(&self) -> ProtoSshTunnel {
2422        ProtoSshTunnel {
2423            connection_id: Some(self.connection_id.into_proto()),
2424            connection: Some(self.connection.into_proto()),
2425        }
2426    }
2427
2428    fn from_proto(proto: ProtoSshTunnel) -> Result<Self, TryFromProtoError> {
2429        Ok(SshTunnel {
2430            connection_id: proto
2431                .connection_id
2432                .into_rust_if_some("ProtoSshTunnel::connection_id")?,
2433            connection: proto
2434                .connection
2435                .into_rust_if_some("ProtoSshTunnel::connection")?,
2436        })
2437    }
2438}
2439
2440impl SshTunnel<InlinedConnection> {
2441    /// Like [`SshTunnelConfig::connect`], but the SSH key is loaded from a
2442    /// secret.
2443    async fn connect(
2444        &self,
2445        storage_configuration: &StorageConfiguration,
2446        remote_host: &str,
2447        remote_port: u16,
2448        in_task: InTask,
2449    ) -> Result<ManagedSshTunnelHandle, anyhow::Error> {
2450        // Ensure any ssh-bastion host we connect to is resolved to an external address.
2451        let resolved = resolve_address(
2452            &self.connection.host,
2453            ENFORCE_EXTERNAL_ADDRESSES.get(storage_configuration.config_set()),
2454        )
2455        .await?;
2456        storage_configuration
2457            .connection_context
2458            .ssh_tunnel_manager
2459            .connect(
2460                SshTunnelConfig {
2461                    host: resolved
2462                        .iter()
2463                        .map(|a| a.to_string())
2464                        .collect::<BTreeSet<_>>(),
2465                    port: self.connection.port,
2466                    user: self.connection.user.clone(),
2467                    key_pair: SshKeyPair::from_bytes(
2468                        &storage_configuration
2469                            .connection_context
2470                            .secrets_reader
2471                            .read_in_task_if(in_task, self.connection_id)
2472                            .await?,
2473                    )?,
2474                },
2475                remote_host,
2476                remote_port,
2477                storage_configuration.parameters.ssh_timeout_config,
2478                in_task,
2479            )
2480            .await
2481    }
2482}
2483
2484impl<C: ConnectionAccess> AlterCompatible for SshTunnel<C> {
2485    fn alter_compatible(&self, id: GlobalId, other: &Self) -> Result<(), AlterError> {
2486        let SshTunnel {
2487            connection_id,
2488            connection,
2489        } = self;
2490
2491        let compatibility_checks = [
2492            (connection_id == &other.connection_id, "connection_id"),
2493            (
2494                connection.alter_compatible(id, &other.connection).is_ok(),
2495                "connection",
2496            ),
2497        ];
2498
2499        for (compatible, field) in compatibility_checks {
2500            if !compatible {
2501                tracing::warn!(
2502                    "SshTunnel incompatible at {field}:\nself:\n{:#?}\n\nother\n{:#?}",
2503                    self,
2504                    other
2505                );
2506
2507                return Err(AlterError { id });
2508            }
2509        }
2510
2511        Ok(())
2512    }
2513}
2514
2515impl SshConnection {
2516    #[allow(clippy::unused_async)]
2517    async fn validate(
2518        &self,
2519        id: CatalogItemId,
2520        storage_configuration: &StorageConfiguration,
2521    ) -> Result<(), anyhow::Error> {
2522        let secret = storage_configuration
2523            .connection_context
2524            .secrets_reader
2525            .read_in_task_if(
2526                // We are in a normal tokio context during validation, already.
2527                InTask::No,
2528                id,
2529            )
2530            .await?;
2531        let key_pair = SshKeyPair::from_bytes(&secret)?;
2532
2533        // Ensure any ssh-bastion host we connect to is resolved to an external address.
2534        let resolved = resolve_address(
2535            &self.host,
2536            ENFORCE_EXTERNAL_ADDRESSES.get(storage_configuration.config_set()),
2537        )
2538        .await?;
2539
2540        let config = SshTunnelConfig {
2541            host: resolved
2542                .iter()
2543                .map(|a| a.to_string())
2544                .collect::<BTreeSet<_>>(),
2545            port: self.port,
2546            user: self.user.clone(),
2547            key_pair,
2548        };
2549        // Note that we do NOT use the `SshTunnelManager` here, as we want to validate that we
2550        // can actually create a new connection to the ssh bastion, without tunneling.
2551        config
2552            .validate(storage_configuration.parameters.ssh_timeout_config)
2553            .await
2554    }
2555
2556    fn validate_by_default(&self) -> bool {
2557        false
2558    }
2559}
2560
2561impl AwsPrivatelinkConnection {
2562    #[allow(clippy::unused_async)]
2563    async fn validate(
2564        &self,
2565        id: CatalogItemId,
2566        storage_configuration: &StorageConfiguration,
2567    ) -> Result<(), anyhow::Error> {
2568        let Some(ref cloud_resource_reader) = storage_configuration
2569            .connection_context
2570            .cloud_resource_reader
2571        else {
2572            return Err(anyhow!("AWS PrivateLink connections are unsupported"));
2573        };
2574
2575        // No need to optionally run this in a task, as we are just validating from envd.
2576        let status = cloud_resource_reader.read(id).await?;
2577
2578        let availability = status
2579            .conditions
2580            .as_ref()
2581            .and_then(|conditions| conditions.iter().find(|c| c.type_ == "Available"));
2582
2583        match availability {
2584            Some(condition) if condition.status == "True" => Ok(()),
2585            Some(condition) => Err(anyhow!("{}", condition.message)),
2586            None => Err(anyhow!("Endpoint availability is unknown")),
2587        }
2588    }
2589
2590    fn validate_by_default(&self) -> bool {
2591        false
2592    }
2593}