use std::borrow::Cow;
use std::collections::{BTreeMap, BTreeSet};
use std::net::SocketAddr;
use std::sync::Arc;
use anyhow::{anyhow, bail, Context};
use itertools::Itertools;
use mz_ccsr::tls::{Certificate, Identity};
use mz_cloud_resources::{vpc_endpoint_host, AwsExternalIdPrefix, CloudResourceReader};
use mz_dyncfg::ConfigSet;
use mz_kafka_util::client::{
BrokerAddr, BrokerRewrite, MzClientContext, MzKafkaError, TunnelConfig, TunnelingClientContext,
};
use mz_ore::assert_none;
use mz_ore::error::ErrorExt;
use mz_ore::future::{InTask, OreFutureExt};
use mz_ore::netio::resolve_address;
use mz_ore::num::NonNeg;
use mz_postgres_util::tunnel::PostgresFlavor;
use mz_proto::tokio_postgres::any_ssl_mode;
use mz_proto::{IntoRustIfSome, ProtoType, RustType, TryFromProtoError};
use mz_repr::url::any_url;
use mz_repr::{CatalogItemId, GlobalId};
use mz_secrets::SecretsReader;
use mz_ssh_util::keys::SshKeyPair;
use mz_ssh_util::tunnel::SshTunnelConfig;
use mz_ssh_util::tunnel_manager::{ManagedSshTunnelHandle, SshTunnelManager};
use mz_tls_util::Pkcs12Archive;
use mz_tracing::CloneableEnvFilter;
use proptest::strategy::Strategy;
use proptest_derive::Arbitrary;
use rdkafka::config::FromClientConfigAndContext;
use rdkafka::consumer::{BaseConsumer, Consumer};
use rdkafka::ClientContext;
use regex::Regex;
use serde::{Deserialize, Deserializer, Serialize};
use tokio::net;
use tokio::runtime::Handle;
use tokio_postgres::config::SslMode;
use tracing::{debug, warn};
use url::Url;
use crate::configuration::StorageConfiguration;
use crate::connections::aws::{
AwsConnection, AwsConnectionReference, AwsConnectionValidationError,
};
use crate::connections::string_or_secret::StringOrSecret;
use crate::controller::AlterError;
use crate::dyncfgs::{
ENFORCE_EXTERNAL_ADDRESSES, KAFKA_CLIENT_ID_ENRICHMENT_RULES,
KAFKA_DEFAULT_AWS_PRIVATELINK_ENDPOINT_IDENTIFICATION_ALGORITHM,
};
use crate::errors::{ContextCreationError, CsrConnectError};
use crate::AlterCompatible;
pub mod aws;
pub mod inline;
pub mod string_or_secret;
include!(concat!(env!("OUT_DIR"), "/mz_storage_types.connections.rs"));
/// Extension methods for reading a secret while carrying an [`InTask`] value.
///
/// Both methods take the secret's `id` plus an `in_task` argument; the
/// implementation for `Arc<dyn SecretsReader>` in this file forwards
/// `in_task` to `run_in_task_if`.
#[async_trait::async_trait]
trait SecretsReaderExt {
/// Reads the secret identified by `id` and returns its contents as bytes.
async fn read_in_task_if(
&self,
in_task: InTask,
id: CatalogItemId,
) -> Result<Vec<u8>, anyhow::Error>;
/// Reads the secret identified by `id` and returns its contents as a
/// `String`.
async fn read_string_in_task_if(
&self,
in_task: InTask,
id: CatalogItemId,
) -> Result<String, anyhow::Error>;
}
#[async_trait::async_trait]
impl SecretsReaderExt for Arc<dyn SecretsReader> {
    /// Reads the bytes of secret `id`, handing the read future to
    /// `run_in_task_if` together with `in_task`.
    async fn read_in_task_if(
        &self,
        in_task: InTask,
        id: CatalogItemId,
    ) -> Result<Vec<u8>, anyhow::Error> {
        // The future owns its own handle to the reader so it does not borrow
        // from `self`.
        let reader = Arc::clone(self);
        let fetch = async move { reader.read(id).await };
        fetch
            .run_in_task_if(in_task, || String::from("secrets_reader_read"))
            .await
    }

    /// Reads secret `id` as a `String`, handing the read future to
    /// `run_in_task_if` together with `in_task`.
    async fn read_string_in_task_if(
        &self,
        in_task: InTask,
        id: CatalogItemId,
    ) -> Result<String, anyhow::Error> {
        let reader = Arc::clone(self);
        let fetch = async move { reader.read_string(id).await };
        fetch
            .run_in_task_if(in_task, || String::from("secrets_reader_read"))
            .await
    }
}
/// Shared state needed to establish connections to external systems.
///
/// The struct is `Clone`; the reader fields are held behind `Arc`.
#[derive(Debug, Clone)]
pub struct ConnectionContext {
/// Identifier of the environment. In this file it is embedded in the
/// default Kafka progress topic name and in the Kafka client id base.
pub environment_id: String,
/// Log level handed to `create_new_client_config` when building Kafka
/// clients.
pub librdkafka_log_level: tracing::Level,
/// Optional prefix for AWS external IDs.
/// NOTE(review): not consumed in the visible part of this file; presumably
/// used by the `aws` submodule — confirm.
pub aws_external_id_prefix: Option<AwsExternalIdPrefix>,
/// Optional ARN of the role used for AWS connections.
/// NOTE(review): not consumed in the visible part of this file — confirm
/// usage in the `aws` submodule.
pub aws_connection_role_arn: Option<String>,
/// Reader used to fetch secret values (passwords, keys, certificates).
pub secrets_reader: Arc<dyn SecretsReader>,
/// Optional reader for cloud resources.
/// NOTE(review): not consumed in the visible part of this file.
pub cloud_resource_reader: Option<Arc<dyn CloudResourceReader>>,
/// Manager for SSH tunnels; cloned into the Kafka client context when a
/// client is created.
pub ssh_tunnel_manager: SshTunnelManager,
}
impl ConnectionContext {
    /// Builds a `ConnectionContext` from individually supplied settings.
    ///
    /// The librdkafka log level is obtained by asking
    /// `mz_ore::tracing::crate_level` for the `librdkafka` target of
    /// `startup_log_level`. The SSH tunnel manager is default-constructed.
    pub fn from_cli_args(
        environment_id: String,
        startup_log_level: &CloneableEnvFilter,
        aws_external_id_prefix: Option<AwsExternalIdPrefix>,
        aws_connection_role_arn: Option<String>,
        secrets_reader: Arc<dyn SecretsReader>,
        cloud_resource_reader: Option<Arc<dyn CloudResourceReader>>,
    ) -> ConnectionContext {
        let librdkafka_log_level =
            mz_ore::tracing::crate_level(&startup_log_level.clone().into(), "librdkafka");
        let ssh_tunnel_manager = SshTunnelManager::default();
        Self {
            environment_id,
            librdkafka_log_level,
            aws_external_id_prefix,
            aws_connection_role_arn,
            secrets_reader,
            cloud_resource_reader,
            ssh_tunnel_manager,
        }
    }

    /// Builds a `ConnectionContext` with fixed placeholder values, using the
    /// given `secrets_reader`.
    ///
    /// No cloud resource reader is configured and the log level is `INFO`.
    pub fn for_tests(secrets_reader: Arc<dyn SecretsReader>) -> ConnectionContext {
        let external_id_prefix =
            AwsExternalIdPrefix::new_from_cli_argument_or_environment_variable(
                "test-aws-external-id-prefix",
            )
            .expect("infallible");
        let role_arn = String::from("arn:aws:iam::123456789000:role/MaterializeConnection");
        Self {
            environment_id: String::from("test-environment-id"),
            librdkafka_log_level: tracing::Level::INFO,
            aws_external_id_prefix: Some(external_id_prefix),
            aws_connection_role_arn: Some(role_arn),
            secrets_reader,
            cloud_resource_reader: None,
            ssh_tunnel_manager: SshTunnelManager::default(),
        }
    }
}
/// A connection to an external system, one variant per supported kind.
///
/// The type parameter `C` selects how nested connections are represented; it
/// defaults to `InlinedConnection`. The `Ssh`, `Aws` and `AwsPrivatelink`
/// variants do not depend on `C`.
#[derive(Arbitrary, Clone, Debug, Eq, PartialEq, Hash, Serialize, Deserialize)]
pub enum Connection<C: ConnectionAccess = InlinedConnection> {
/// A Kafka connection.
Kafka(KafkaConnection<C>),
/// A schema registry (CSR) connection.
Csr(CsrConnection<C>),
/// A PostgreSQL connection.
Postgres(PostgresConnection<C>),
/// An SSH connection.
Ssh(SshConnection),
/// An AWS connection.
Aws(AwsConnection),
/// An AWS PrivateLink connection.
AwsPrivatelink(AwsPrivatelinkConnection),
/// A MySQL connection.
MySql(MySqlConnection<C>),
}
impl<R: ConnectionResolver> IntoInlineConnection<Connection, R>
    for Connection<ReferencedConnection>
{
    /// Converts a referenced connection into its inlined form.
    ///
    /// Variants that carry nested connections delegate to their own
    /// `into_inline_connection` with the resolver `r`; the `Ssh`, `Aws` and
    /// `AwsPrivatelink` payloads are moved across unchanged.
    fn into_inline_connection(self, r: R) -> Connection {
        match self {
            // Variants that need the resolver.
            Self::Kafka(inner) => Connection::Kafka(inner.into_inline_connection(r)),
            Self::Csr(inner) => Connection::Csr(inner.into_inline_connection(r)),
            Self::Postgres(inner) => Connection::Postgres(inner.into_inline_connection(r)),
            Self::MySql(inner) => Connection::MySql(inner.into_inline_connection(r)),
            // Variants whose payload is independent of the access mode.
            Self::Ssh(inner) => Connection::Ssh(inner),
            Self::Aws(inner) => Connection::Aws(inner),
            Self::AwsPrivatelink(inner) => Connection::AwsPrivatelink(inner),
        }
    }
}
impl<C: ConnectionAccess> Connection<C> {
    /// Reports whether this connection should be validated by default.
    ///
    /// The answer is delegated to the wrapped connection of whichever variant
    /// `self` is.
    pub fn validate_by_default(&self) -> bool {
        match self {
            Self::Kafka(kafka) => kafka.validate_by_default(),
            Self::Csr(csr) => csr.validate_by_default(),
            Self::Postgres(pg) => pg.validate_by_default(),
            Self::Ssh(ssh) => ssh.validate_by_default(),
            Self::Aws(aws) => aws.validate_by_default(),
            Self::AwsPrivatelink(privatelink) => privatelink.validate_by_default(),
            Self::MySql(mysql) => mysql.validate_by_default(),
        }
    }
}
impl Connection<InlinedConnection> {
    /// Validates this connection by delegating to the wrapped connection's
    /// own `validate` method.
    ///
    /// # Errors
    ///
    /// Returns a [`ConnectionValidationError`] converted (via `?`) from
    /// whatever error the wrapped connection's validation produced.
    pub async fn validate(
        &self,
        id: CatalogItemId,
        storage_configuration: &StorageConfiguration,
    ) -> Result<(), ConnectionValidationError> {
        match self {
            Connection::Kafka(conn) => conn.validate(id, storage_configuration).await?,
            Connection::Csr(conn) => conn.validate(id, storage_configuration).await?,
            Connection::Postgres(conn) => conn.validate(id, storage_configuration).await?,
            Connection::Ssh(conn) => conn.validate(id, storage_configuration).await?,
            Connection::Aws(conn) => conn.validate(id, storage_configuration).await?,
            Connection::AwsPrivatelink(conn) => conn.validate(id, storage_configuration).await?,
            Connection::MySql(conn) => conn.validate(id, storage_configuration).await?,
        }
        Ok(())
    }

    /// Returns the wrapped Kafka connection.
    ///
    /// # Panics
    ///
    /// Panics if `self` is not the `Kafka` variant.
    pub fn unwrap_kafka(self) -> <InlinedConnection as ConnectionAccess>::Kafka {
        match self {
            Self::Kafka(conn) => conn,
            o => unreachable!("{o:?} is not a Kafka connection"),
        }
    }

    /// Returns the wrapped Postgres connection.
    ///
    /// # Panics
    ///
    /// Panics if `self` is not the `Postgres` variant.
    pub fn unwrap_pg(self) -> <InlinedConnection as ConnectionAccess>::Pg {
        match self {
            Self::Postgres(conn) => conn,
            o => unreachable!("{o:?} is not a Postgres connection"),
        }
    }

    /// Returns the wrapped MySQL connection.
    ///
    /// # Panics
    ///
    /// Panics if `self` is not the `MySql` variant.
    pub fn unwrap_mysql(self) -> <InlinedConnection as ConnectionAccess>::MySql {
        match self {
            Self::MySql(conn) => conn,
            o => unreachable!("{o:?} is not a MySQL connection"),
        }
    }

    /// Returns the wrapped AWS connection.
    ///
    /// # Panics
    ///
    /// Panics if `self` is not the `Aws` variant.
    pub fn unwrap_aws(self) -> <InlinedConnection as ConnectionAccess>::Aws {
        match self {
            Self::Aws(conn) => conn,
            o => unreachable!("{o:?} is not an AWS connection"),
        }
    }

    /// Returns the wrapped SSH connection.
    ///
    /// # Panics
    ///
    /// Panics if `self` is not the `Ssh` variant.
    pub fn unwrap_ssh(self) -> <InlinedConnection as ConnectionAccess>::Ssh {
        match self {
            Self::Ssh(conn) => conn,
            o => unreachable!("{o:?} is not an SSH connection"),
        }
    }

    /// Returns the wrapped schema registry (CSR) connection.
    ///
    /// # Panics
    ///
    /// Panics if `self` is not the `Csr` variant.
    pub fn unwrap_csr(self) -> <InlinedConnection as ConnectionAccess>::Csr {
        match self {
            Self::Csr(conn) => conn,
            // The message previously claimed "not a Kafka connection", which
            // pointed at the wrong connection kind when this panic fired.
            o => unreachable!("{o:?} is not a CSR connection"),
        }
    }
}
/// Error returned when validating a [`Connection`] fails.
#[derive(thiserror::Error, Debug)]
pub enum ConnectionValidationError {
/// A validation failure specific to AWS connections; displayed
/// transparently as the inner error.
#[error(transparent)]
Aws(#[from] AwsConnectionValidationError),
/// Any other failure; displayed using `display_with_causes` on the inner
/// error.
#[error("{}", .0.display_with_causes())]
Other(#[from] anyhow::Error),
}
impl ConnectionValidationError {
pub fn detail(&self) -> Option<String> {
match self {
ConnectionValidationError::Aws(e) => e.detail(),
ConnectionValidationError::Other(_) => None,
}
}
pub fn hint(&self) -> Option<String> {
match self {
ConnectionValidationError::Aws(e) => e.hint(),
ConnectionValidationError::Other(_) => None,
}
}
}
impl<C: ConnectionAccess> AlterCompatible for Connection<C> {
    /// Checks whether `self` may be altered into `other`.
    ///
    /// When both values are the same variant, the decision is delegated to
    /// the wrapped connections. Any other pairing is logged as a warning and
    /// rejected with an [`AlterError`] carrying `id`.
    fn alter_compatible(&self, id: GlobalId, other: &Self) -> Result<(), AlterError> {
        match (self, other) {
            (Self::Kafka(mine), Self::Kafka(theirs)) => mine.alter_compatible(id, theirs),
            (Self::Csr(mine), Self::Csr(theirs)) => mine.alter_compatible(id, theirs),
            (Self::Postgres(mine), Self::Postgres(theirs)) => mine.alter_compatible(id, theirs),
            (Self::Ssh(mine), Self::Ssh(theirs)) => mine.alter_compatible(id, theirs),
            (Self::Aws(mine), Self::Aws(theirs)) => mine.alter_compatible(id, theirs),
            (Self::AwsPrivatelink(mine), Self::AwsPrivatelink(theirs)) => {
                mine.alter_compatible(id, theirs)
            }
            (Self::MySql(mine), Self::MySql(theirs)) => mine.alter_compatible(id, theirs),
            // Mismatched variants can never be altered into one another.
            _ => {
                tracing::warn!(
                    "Connection incompatible:\nself:\n{:#?}\n\nother\n{:#?}",
                    self,
                    other
                );
                Err(AlterError { id })
            }
        }
    }
}
/// Description of an AWS PrivateLink connection.
#[derive(Arbitrary, Clone, Debug, Eq, PartialEq, Hash, Serialize, Deserialize)]
pub struct AwsPrivatelinkConnection {
/// Name of the PrivateLink service.
pub service_name: String,
/// Availability zones associated with the connection.
/// NOTE(review): the expected format (zone name vs. zone ID) is not
/// visible here — confirm against the caller.
pub availability_zones: Vec<String>,
}
impl AlterCompatible for AwsPrivatelinkConnection {
/// Always succeeds: neither `_id` nor `_other` is inspected, so any two
/// PrivateLink connections are treated as alter-compatible.
fn alter_compatible(&self, _id: GlobalId, _other: &Self) -> Result<(), AlterError> {
Ok(())
}
}
/// TLS settings for a Kafka connection.
#[derive(Arbitrary, Clone, Debug, Eq, PartialEq, Hash, Serialize, Deserialize)]
pub struct KafkaTlsConfig {
/// Optional client identity; when present its key and certificate are
/// passed to the Kafka client as `ssl.key.pem` / `ssl.certificate.pem`.
pub identity: Option<TlsIdentity>,
/// Optional root certificate; when present it is passed to the Kafka
/// client as `ssl.ca.pem`.
pub root_cert: Option<StringOrSecret>,
}
/// SASL settings for a Kafka connection.
#[derive(Clone, Debug, Eq, PartialEq, Hash, Serialize, Deserialize, Arbitrary)]
pub struct KafkaSaslConfig<C: ConnectionAccess = InlinedConnection> {
/// SASL mechanism name; passed to the Kafka client as `sasl.mechanisms`.
pub mechanism: String,
/// SASL username; passed to the Kafka client as `sasl.username`.
pub username: StringOrSecret,
/// Optional secret holding the password; passed as `sasl.password`.
pub password: Option<CatalogItemId>,
/// Optional AWS connection; when present its SDK config is loaded and
/// handed to the Kafka client context.
pub aws: Option<AwsConnectionReference<C>>,
}
impl<R: ConnectionResolver> IntoInlineConnection<KafkaSaslConfig, R>
    for KafkaSaslConfig<ReferencedConnection>
{
    /// Converts the SASL config into its inlined form.
    ///
    /// Only the optional AWS connection reference needs the resolver; the
    /// remaining fields are moved across as-is.
    fn into_inline_connection(self, r: R) -> KafkaSaslConfig {
        let Self {
            mechanism,
            username,
            password,
            aws,
        } = self;
        let aws = aws.map(|reference| reference.into_inline_connection(&r));
        KafkaSaslConfig {
            mechanism,
            username,
            password,
            aws,
        }
    }
}
/// A single Kafka broker and the tunnel used to reach it.
#[derive(Arbitrary, Clone, Debug, Eq, PartialEq, Hash, Serialize, Deserialize)]
pub struct KafkaBroker<C: ConnectionAccess = InlinedConnection> {
/// Broker address as `host` or `host:port`; when the port is omitted,
/// client creation in this file falls back to `9092`.
pub address: String,
/// Tunnel to use for this specific broker.
pub tunnel: Tunnel<C>,
}
impl<R: ConnectionResolver> IntoInlineConnection<KafkaBroker, R>
    for KafkaBroker<ReferencedConnection>
{
    /// Converts the broker into its inlined form by inlining its tunnel; the
    /// address is carried over unchanged.
    fn into_inline_connection(self, r: R) -> KafkaBroker {
        KafkaBroker {
            address: self.address,
            tunnel: self.tunnel.into_inline_connection(r),
        }
    }
}
impl RustType<ProtoKafkaBroker> for KafkaBroker {
    /// Encodes the broker as a `ProtoKafkaBroker`; the tunnel is always
    /// populated.
    fn into_proto(&self) -> ProtoKafkaBroker {
        ProtoKafkaBroker {
            address: self.address.into_proto(),
            tunnel: Some(self.tunnel.into_proto()),
        }
    }

    /// Decodes a broker from its protobuf form.
    ///
    /// # Errors
    ///
    /// Returns a [`TryFromProtoError`] if a field fails to convert or if the
    /// `tunnel` field is absent.
    fn from_proto(proto: ProtoKafkaBroker) -> Result<Self, TryFromProtoError> {
        Ok(KafkaBroker {
            address: proto.address.into_rust()?,
            // The label names the message actually being decoded; it used to
            // say `ProtoKafkaConnection`, which misattributed a missing field.
            tunnel: proto
                .tunnel
                .into_rust_if_some("ProtoKafkaBroker::tunnel")?,
        })
    }
}
/// Options for a Kafka topic.
///
/// In this file the type is used for `KafkaConnection::progress_topic_options`.
/// `Default` leaves both counts unset and the config map empty.
#[derive(Arbitrary, Clone, Debug, Eq, PartialEq, Hash, Serialize, Deserialize, Default)]
pub struct KafkaTopicOptions {
/// Optional, non-negative replication factor.
pub replication_factor: Option<NonNeg<i32>>,
/// Optional, non-negative partition count.
pub partition_count: Option<NonNeg<i32>>,
/// Additional topic configuration as key/value strings.
pub topic_config: BTreeMap<String, String>,
}
impl RustType<ProtoKafkaTopicOptions> for KafkaTopicOptions {
    /// Encodes the options as protobuf, unwrapping the `NonNeg` wrappers to
    /// their plain integer values.
    fn into_proto(&self) -> ProtoKafkaTopicOptions {
        let Self {
            replication_factor,
            partition_count,
            topic_config,
        } = self;
        ProtoKafkaTopicOptions {
            replication_factor: replication_factor.map(|factor| *factor),
            partition_count: partition_count.map(|count| *count),
            topic_config: topic_config.clone(),
        }
    }

    /// Decodes the options from protobuf.
    ///
    /// # Errors
    ///
    /// Returns a [`TryFromProtoError`] if either count cannot be converted
    /// into a `NonNeg` value.
    fn from_proto(proto: ProtoKafkaTopicOptions) -> Result<Self, TryFromProtoError> {
        let replication_factor = proto
            .replication_factor
            .map(NonNeg::try_from)
            .transpose()?;
        let partition_count = proto.partition_count.map(NonNeg::try_from).transpose()?;
        Ok(KafkaTopicOptions {
            replication_factor,
            partition_count,
            topic_config: proto.topic_config,
        })
    }
}
/// Description of a connection to a Kafka cluster.
#[derive(Arbitrary, Clone, Debug, Eq, PartialEq, Hash, Serialize, Deserialize)]
pub struct KafkaConnection<C: ConnectionAccess = InlinedConnection> {
/// Explicitly listed brokers, each with its own tunnel. Must be empty when
/// `default_tunnel` is AWS PrivateLink (client creation asserts this).
pub brokers: Vec<KafkaBroker<C>>,
/// Tunnel installed as the client context's default tunnel.
pub default_tunnel: Tunnel<C>,
/// Explicit progress topic name; when `None`, a name is generated from the
/// environment and connection IDs.
pub progress_topic: Option<String>,
/// Options for the progress topic.
pub progress_topic_options: KafkaTopicOptions,
/// Additional client options; several keys are overwritten during client
/// creation (e.g. `bootstrap.servers`, `security.protocol`).
pub options: BTreeMap<String, StringOrSecret>,
/// Optional TLS settings.
pub tls: Option<KafkaTlsConfig>,
/// Optional SASL settings.
pub sasl: Option<KafkaSaslConfig<C>>,
}
impl<R: ConnectionResolver> IntoInlineConnection<KafkaConnection, R>
    for KafkaConnection<ReferencedConnection>
{
    /// Converts the Kafka connection into its inlined form.
    ///
    /// Every broker, the default tunnel and the optional SASL config are
    /// inlined through a shared borrow of the resolver; all other fields are
    /// moved across unchanged.
    fn into_inline_connection(self, r: R) -> KafkaConnection {
        let resolver = &r;
        KafkaConnection {
            brokers: self
                .brokers
                .into_iter()
                .map(|referenced| referenced.into_inline_connection(resolver))
                .collect(),
            default_tunnel: self.default_tunnel.into_inline_connection(resolver),
            progress_topic: self.progress_topic,
            progress_topic_options: self.progress_topic_options,
            options: self.options,
            tls: self.tls,
            sasl: self
                .sasl
                .map(|referenced| referenced.into_inline_connection(resolver)),
        }
    }
}
impl<C: ConnectionAccess> KafkaConnection<C> {
    /// Returns the name of the progress topic for this connection.
    ///
    /// An explicitly configured topic is borrowed as-is. Otherwise a name is
    /// generated from the environment ID in `connection_context` and from
    /// `connection_id`.
    pub fn progress_topic(
        &self,
        connection_context: &ConnectionContext,
        connection_id: CatalogItemId,
    ) -> Cow<str> {
        match &self.progress_topic {
            Some(configured) => Cow::Borrowed(configured.as_str()),
            None => {
                let generated = format!(
                    "_materialize-progress-{}-{}",
                    connection_context.environment_id, connection_id,
                );
                Cow::Owned(generated)
            }
        }
    }

    /// Kafka connections are always validated by default.
    fn validate_by_default(&self) -> bool {
        true
    }
}
impl KafkaConnection {
/// Builds the base string for client identifiers derived from this
/// connection: `materialize-<environment id>-<connection id>-<object id>`.
pub fn id_base(
    connection_context: &ConnectionContext,
    connection_id: CatalogItemId,
    object_id: GlobalId,
) -> String {
    let environment_id = &connection_context.environment_id;
    format!(
        "materialize-{}-{}-{}",
        environment_id, connection_id, object_id,
    )
}
pub fn enrich_client_id(&self, configs: &ConfigSet, client_id: &mut String) {
#[derive(Debug, Deserialize)]
struct EnrichmentRule {
#[serde(deserialize_with = "deserialize_regex")]
pattern: Regex,
payload: String,
}
fn deserialize_regex<'de, D>(deserializer: D) -> Result<Regex, D::Error>
where
D: Deserializer<'de>,
{
let buf = String::deserialize(deserializer)?;
Regex::new(&buf).map_err(serde::de::Error::custom)
}
let rules = KAFKA_CLIENT_ID_ENRICHMENT_RULES.get(configs);
let rules = match serde_json::from_value::<Vec<EnrichmentRule>>(rules) {
Ok(rules) => rules,
Err(e) => {
warn!(%e, "failed to decode kafka_client_id_enrichment_rules");
return;
}
};
debug!(?self.brokers, "evaluating client ID enrichment rules");
for rule in rules {
let is_match = self
.brokers
.iter()
.any(|b| rule.pattern.is_match(&b.address));
debug!(?rule, is_match, "evaluated client ID enrichment rule");
if is_match {
client_id.push('-');
client_id.push_str(&rule.payload);
}
}
}
/// Creates a Kafka client of type `T` for this connection, wrapping the
/// supplied `context` in a `TunnelingClientContext`.
///
/// The client configuration is assembled from, in order: the connection's
/// own `options`; a set of keys forced by this function
/// (`allow.auto.create.topics`, `bootstrap.servers`, `security.protocol`,
/// plus TLS and SASL keys when configured); and finally `extra_options`,
/// which therefore win over everything else.
///
/// Tunnels are registered on the context: the connection's default tunnel
/// first, then one rewrite or SSH tunnel per explicitly listed broker.
///
/// # Panics
///
/// Panics if the default tunnel is AWS PrivateLink and `brokers` is not
/// empty.
///
/// # Errors
///
/// Returns a `ContextCreationError` if a secret cannot be read, an address
/// cannot be resolved, a broker port fails to parse, an SSH tunnel cannot be
/// set up, the AWS SDK config cannot be loaded, or the underlying client
/// cannot be created.
pub async fn create_with_context<C, T>(
&self,
storage_configuration: &StorageConfiguration,
context: C,
extra_options: &BTreeMap<&str, String>,
in_task: InTask,
) -> Result<T, ContextCreationError>
where
C: ClientContext,
T: FromClientConfigAndContext<TunnelingClientContext<C>>,
{
// Start from the user-provided options; every `insert` below replaces any
// user-provided value for the same key.
let mut options = self.options.clone();
options.insert("allow.auto.create.topics".into(), "false".into());
// Work out the bootstrap servers string.
let brokers = match &self.default_tunnel {
Tunnel::AwsPrivatelink(t) => {
// With a PrivateLink default tunnel no brokers may be listed; the
// bootstrap server is the VPC endpoint host itself.
assert!(&self.brokers.is_empty());
let algo = KAFKA_DEFAULT_AWS_PRIVATELINK_ENDPOINT_IDENTIFICATION_ALGORITHM
.get(storage_configuration.config_set());
options.insert("ssl.endpoint.identification.algorithm".into(), algo.into());
format!(
"{}:{}",
vpc_endpoint_host(
t.connection_id,
None, ),
// Fall back to port 9092 when the tunnel does not specify one.
t.port.unwrap_or(9092)
)
}
// Otherwise: the comma-separated list of configured broker addresses.
_ => self.brokers.iter().map(|b| &b.address).join(","),
};
options.insert("bootstrap.servers".into(), brokers.into());
// The security protocol follows from which of TLS / SASL are configured.
let security_protocol = match (self.tls.is_some(), self.sasl.is_some()) {
(false, false) => "PLAINTEXT",
(true, false) => "SSL",
(false, true) => "SASL_PLAINTEXT",
(true, true) => "SASL_SSL",
};
options.insert("security.protocol".into(), security_protocol.into());
if let Some(tls) = &self.tls {
if let Some(root_cert) = &tls.root_cert {
options.insert("ssl.ca.pem".into(), root_cert.clone());
}
if let Some(identity) = &tls.identity {
options.insert("ssl.key.pem".into(), StringOrSecret::Secret(identity.key));
options.insert("ssl.certificate.pem".into(), identity.cert.clone());
}
}
if let Some(sasl) = &self.sasl {
options.insert("sasl.mechanisms".into(), (&sasl.mechanism).into());
options.insert("sasl.username".into(), sasl.username.clone());
if let Some(password) = sasl.password {
options.insert("sasl.password".into(), StringOrSecret::Secret(password));
}
}
let mut config = mz_kafka_util::client::create_new_client_config(
storage_configuration
.connection_context
.librdkafka_log_level,
storage_configuration.parameters.kafka_timeout_config,
);
// Resolve each option to a concrete string (via the secrets reader where
// needed) and apply it to the client config.
for (k, v) in options {
config.set(
k,
v.get_string(
in_task,
&storage_configuration.connection_context.secrets_reader,
)
.await
.context("reading kafka secret")?,
);
}
// Caller-supplied options are applied last and so take precedence.
for (k, v) in extra_options {
config.set(*k, v);
}
// Load the AWS SDK config only when SASL references an AWS connection.
let aws_config = match self.sasl.as_ref().and_then(|sasl| sasl.aws.as_ref()) {
None => None,
Some(aws) => Some(
aws.connection
.load_sdk_config(
&storage_configuration.connection_context,
aws.connection_id,
in_task,
)
.await?,
),
};
let mut context = TunnelingClientContext::new(
context,
Handle::current(),
storage_configuration
.connection_context
.ssh_tunnel_manager
.clone(),
storage_configuration.parameters.ssh_timeout_config,
aws_config,
in_task,
);
// Install the connection-wide default tunnel.
match &self.default_tunnel {
Tunnel::Direct => {
}
Tunnel::AwsPrivatelink(pl) => {
context.set_default_tunnel(TunnelConfig::StaticHost(vpc_endpoint_host(
pl.connection_id,
None, )));
}
Tunnel::Ssh(ssh_tunnel) => {
// The secret stored under the SSH connection's ID holds the key pair.
let secret = storage_configuration
.connection_context
.secrets_reader
.read_in_task_if(in_task, ssh_tunnel.connection_id)
.await?;
let key_pair = SshKeyPair::from_bytes(&secret)?;
let resolved = resolve_address(
&ssh_tunnel.connection.host,
ENFORCE_EXTERNAL_ADDRESSES.get(storage_configuration.config_set()),
)
.await?;
context.set_default_tunnel(TunnelConfig::Ssh(SshTunnelConfig {
host: resolved
.iter()
.map(|a| a.to_string())
.collect::<BTreeSet<_>>(),
port: ssh_tunnel.connection.port,
user: ssh_tunnel.connection.user.clone(),
key_pair,
}));
}
}
// Register per-broker tunnels / rewrites for explicitly listed brokers.
for broker in &self.brokers {
// Split `host[:port]` at the first colon; the port defaults to 9092.
let mut addr_parts = broker.address.splitn(2, ':');
let addr = BrokerAddr {
host: addr_parts
.next()
.context("BROKER is not address:port")?
.into(),
port: addr_parts
.next()
.unwrap_or("9092")
.parse()
.context("parsing BROKER port")?,
};
match &broker.tunnel {
Tunnel::Direct => {
}
Tunnel::AwsPrivatelink(aws_privatelink) => {
// Rewrite the broker address to the VPC endpoint host (optionally
// scoped to an availability zone) and the tunnel's port.
let host = mz_cloud_resources::vpc_endpoint_host(
aws_privatelink.connection_id,
aws_privatelink.availability_zone.as_deref(),
);
let port = aws_privatelink.port;
context.add_broker_rewrite(
addr,
BrokerRewrite {
host: host.clone(),
port,
},
);
}
Tunnel::Ssh(ssh_tunnel) => {
let ssh_host_resolved = resolve_address(
&ssh_tunnel.connection.host,
ENFORCE_EXTERNAL_ADDRESSES.get(storage_configuration.config_set()),
)
.await?;
context
.add_ssh_tunnel(
addr,
SshTunnelConfig {
host: ssh_host_resolved
.iter()
.map(|a| a.to_string())
.collect::<BTreeSet<_>>(),
port: ssh_tunnel.connection.port,
user: ssh_tunnel.connection.user.clone(),
// The key pair secret is read once per SSH-tunnelled broker.
key_pair: SshKeyPair::from_bytes(
&storage_configuration
.connection_context
.secrets_reader
.read_in_task_if(in_task, ssh_tunnel.connection_id)
.await?,
)?,
},
)
.await
.map_err(ContextCreationError::Ssh)?;
}
}
}
Ok(config.create_with_context(context)?)
}
async fn validate(
&self,
_id: CatalogItemId,
storage_configuration: &StorageConfiguration,
) -> Result<(), anyhow::Error> {
let (context, error_rx) = MzClientContext::with_errors();
let consumer: BaseConsumer<_> = self
.create_with_context(
storage_configuration,
context,
&BTreeMap::new(),
InTask::No,
)
.await?;
let consumer = Arc::new(consumer);
let timeout = storage_configuration
.parameters
.kafka_timeout_config
.fetch_metadata_timeout;
let result = mz_ore::task::spawn_blocking(|| "kafka_get_metadata", {
let consumer = Arc::clone(&consumer);
move || consumer.fetch_metadata(None, timeout)
})
.await?;
match result {
Ok(_) => Ok(()),
Err(err) => {
let main_err = error_rx.try_iter().reduce(|cur, new| match cur {
MzKafkaError::Internal(_) => new,
_ => cur,
});
drop(consumer);
match main_err {
Some(err) => Err(err.into()),
None => Err(err.into()),
}
}
}
}
}
impl<C: ConnectionAccess> AlterCompatible for KafkaConnection<C> {
    /// Checks whether `self` may be altered into `other`.
    ///
    /// Only `progress_topic` and `progress_topic_options` must be equal; all
    /// other fields are free to change. The first mismatching field is
    /// logged as a warning and reported as an [`AlterError`] carrying `id`.
    fn alter_compatible(&self, id: GlobalId, other: &Self) -> Result<(), AlterError> {
        // Destructure exhaustively so that adding a field forces a decision
        // about whether it takes part in the compatibility check.
        let KafkaConnection {
            progress_topic,
            progress_topic_options,
            brokers: _,
            default_tunnel: _,
            options: _,
            tls: _,
            sasl: _,
        } = self;

        let same_topic = progress_topic == &other.progress_topic;
        let same_topic_options = progress_topic_options == &other.progress_topic_options;
        let compatibility_checks = [
            (same_topic, "progress_topic"),
            (same_topic_options, "progress_topic_options"),
        ];

        let first_failure = compatibility_checks
            .iter()
            .find(|(compatible, _)| !*compatible);
        match first_failure {
            None => Ok(()),
            Some((_, field)) => {
                tracing::warn!(
                    "KafkaConnection incompatible at {field}:\nself:\n{:#?}\n\nother\n{:#?}",
                    self,
                    other
                );
                Err(AlterError { id })
            }
        }
    }
}
impl RustType<ProtoKafkaConnectionTlsConfig> for KafkaTlsConfig {
    /// Encodes the TLS config as protobuf.
    fn into_proto(&self) -> ProtoKafkaConnectionTlsConfig {
        let Self {
            identity,
            root_cert,
        } = self;
        ProtoKafkaConnectionTlsConfig {
            identity: identity.into_proto(),
            root_cert: root_cert.into_proto(),
        }
    }

    /// Decodes the TLS config from protobuf.
    ///
    /// # Errors
    ///
    /// Returns a [`TryFromProtoError`] if either field fails to convert; the
    /// root certificate is converted first.
    fn from_proto(proto: ProtoKafkaConnectionTlsConfig) -> Result<Self, TryFromProtoError> {
        let root_cert = proto.root_cert.into_rust()?;
        let identity = proto.identity.into_rust()?;
        Ok(KafkaTlsConfig {
            root_cert,
            identity,
        })
    }
}
impl RustType<ProtoKafkaConnectionSaslConfig> for KafkaSaslConfig {
    /// Encodes the SASL config as protobuf; the username is always
    /// populated.
    fn into_proto(&self) -> ProtoKafkaConnectionSaslConfig {
        let Self {
            mechanism,
            username,
            password,
            aws,
        } = self;
        ProtoKafkaConnectionSaslConfig {
            mechanism: mechanism.into_proto(),
            username: Some(username.into_proto()),
            password: password.into_proto(),
            aws: aws.into_proto(),
        }
    }

    /// Decodes the SASL config from protobuf.
    ///
    /// # Errors
    ///
    /// Returns a [`TryFromProtoError`] if the username is absent or if any
    /// field fails to convert.
    fn from_proto(proto: ProtoKafkaConnectionSaslConfig) -> Result<Self, TryFromProtoError> {
        let username = proto
            .username
            .into_rust_if_some("ProtoKafkaConnectionSaslConfig::username")?;
        let password = proto.password.into_rust()?;
        let aws = proto.aws.into_rust()?;
        Ok(KafkaSaslConfig {
            mechanism: proto.mechanism,
            username,
            password,
            aws,
        })
    }
}
impl RustType<ProtoKafkaConnection> for KafkaConnection {
    /// Encodes the Kafka connection as protobuf. The default tunnel and the
    /// progress topic options are always populated.
    fn into_proto(&self) -> ProtoKafkaConnection {
        let options = self
            .options
            .iter()
            .map(|(key, value)| (key.clone(), value.into_proto()))
            .collect();
        ProtoKafkaConnection {
            brokers: self.brokers.into_proto(),
            default_tunnel: Some(self.default_tunnel.into_proto()),
            progress_topic: self.progress_topic.into_proto(),
            progress_topic_options: Some(self.progress_topic_options.into_proto()),
            options,
            tls: self.tls.into_proto(),
            sasl: self.sasl.into_proto(),
        }
    }

    /// Decodes the Kafka connection from protobuf.
    ///
    /// Absent progress topic options decode to the default options.
    ///
    /// # Errors
    ///
    /// Returns a [`TryFromProtoError`] if the default tunnel is absent or if
    /// any field fails to convert.
    fn from_proto(proto: ProtoKafkaConnection) -> Result<Self, TryFromProtoError> {
        let brokers = proto.brokers.into_rust()?;
        let default_tunnel = proto
            .default_tunnel
            .into_rust_if_some("ProtoKafkaConnection::default_tunnel")?;
        let progress_topic_options = match proto.progress_topic_options {
            None => Default::default(),
            Some(encoded) => encoded.into_rust()?,
        };
        let options = proto
            .options
            .into_iter()
            .map(|(key, encoded)| StringOrSecret::from_proto(encoded).map(|value| (key, value)))
            .collect::<Result<_, _>>()?;
        let tls = proto.tls.into_rust()?;
        let sasl = proto.sasl.into_rust()?;
        Ok(KafkaConnection {
            brokers,
            default_tunnel,
            progress_topic: proto.progress_topic,
            progress_topic_options,
            options,
            tls,
            sasl,
        })
    }
}
/// Description of a connection to a schema registry (CSR).
#[derive(Clone, Debug, Eq, PartialEq, Hash, Serialize, Deserialize, Arbitrary)]
pub struct CsrConnection<C: ConnectionAccess = InlinedConnection> {
/// URL of the schema registry; it must contain a host for `connect` to
/// succeed.
#[proptest(strategy = "any_url()")]
pub url: Url,
/// Optional root certificate (PEM) added to the client's trust roots.
pub tls_root_cert: Option<StringOrSecret>,
/// Optional client identity (key and certificate, PEM) for TLS.
pub tls_identity: Option<TlsIdentity>,
/// Optional HTTP authentication credentials.
pub http_auth: Option<CsrConnectionHttpAuth>,
/// Tunnel through which the registry is reached.
pub tunnel: Tunnel<C>,
}
impl<R: ConnectionResolver> IntoInlineConnection<CsrConnection, R>
    for CsrConnection<ReferencedConnection>
{
    /// Converts the schema registry connection into its inlined form.
    ///
    /// Only the tunnel needs the resolver; every other field is moved across
    /// unchanged.
    fn into_inline_connection(self, r: R) -> CsrConnection {
        CsrConnection {
            url: self.url,
            tls_root_cert: self.tls_root_cert,
            tls_identity: self.tls_identity,
            http_auth: self.http_auth,
            tunnel: self.tunnel.into_inline_connection(r),
        }
    }
}
impl<C: ConnectionAccess> CsrConnection<C> {
/// Schema registry connections are always validated by default.
fn validate_by_default(&self) -> bool {
true
}
}
impl CsrConnection {
/// Builds a schema registry client for this connection.
///
/// TLS material and HTTP credentials are read through the secrets reader in
/// `storage_configuration`. How the registry host is reached depends on the
/// tunnel:
///
/// * `Direct`: the URL's host is resolved up front and the client is pinned
///   to the resolved addresses.
/// * `Ssh`: an SSH tunnel to the registry is opened and the client is
///   pointed at the tunnel's local address, with the URL's port rewritten to
///   the tunnel's local port on each use.
/// * `AwsPrivatelink`: the VPC endpoint host is resolved and the client is
///   pinned to those addresses.
///
/// # Panics
///
/// Panics if the tunnel is AWS PrivateLink and a port is set on it.
///
/// # Errors
///
/// Returns a `CsrConnectError` if the URL has no host, a secret cannot be
/// read, TLS material cannot be parsed, an address cannot be resolved, the
/// SSH tunnel cannot be established, or the client cannot be built.
pub async fn connect(
    &self,
    storage_configuration: &StorageConfiguration,
    in_task: InTask,
) -> Result<mz_ccsr::Client, CsrConnectError> {
    let mut client_config = mz_ccsr::ClientConfig::new(self.url.clone());
    if let Some(root_cert) = &self.tls_root_cert {
        let root_cert = root_cert
            .get_string(
                in_task,
                &storage_configuration.connection_context.secrets_reader,
            )
            .await?;
        let root_cert = Certificate::from_pem(root_cert.as_bytes())?;
        client_config = client_config.add_root_certificate(root_cert);
    }
    if let Some(tls_identity) = &self.tls_identity {
        let key = &storage_configuration
            .connection_context
            .secrets_reader
            .read_string_in_task_if(in_task, tls_identity.key)
            .await?;
        let cert = tls_identity
            .cert
            .get_string(
                in_task,
                &storage_configuration.connection_context.secrets_reader,
            )
            .await?;
        let ident = Identity::from_pem(key.as_bytes(), cert.as_bytes())?;
        client_config = client_config.identity(ident);
    }
    if let Some(http_auth) = &self.http_auth {
        let username = http_auth
            .username
            .get_string(
                in_task,
                &storage_configuration.connection_context.secrets_reader,
            )
            .await?;
        let password = match http_auth.password {
            None => None,
            Some(password) => Some(
                storage_configuration
                    .connection_context
                    .secrets_reader
                    .read_string_in_task_if(in_task, password)
                    .await?,
            ),
        };
        client_config = client_config.auth(username, password);
    }
    // Placeholder port for the socket addresses handed to
    // `resolve_to_addrs`.
    // NOTE(review): presumably the port component is ignored there —
    // confirm against `mz_ccsr`.
    const DUMMY_PORT: u16 = 11111;
    let host = self
        .url
        .host_str()
        .ok_or_else(|| anyhow!("url missing host"))?;
    match &self.tunnel {
        Tunnel::Direct => {
            let resolved = resolve_address(
                host,
                ENFORCE_EXTERNAL_ADDRESSES.get(storage_configuration.config_set()),
            )
            .await?;
            client_config = client_config.resolve_to_addrs(
                host,
                &resolved
                    .iter()
                    .map(|addr| SocketAddr::new(*addr, DUMMY_PORT))
                    .collect::<Vec<_>>(),
            )
        }
        Tunnel::Ssh(ssh_tunnel) => {
            // `Url::port` returns `None` whenever the port equals the
            // scheme's default, so an `https` URL would otherwise be
            // tunnelled to port 80. `port_or_known_default` yields 443 for
            // `https` and 80 for `http`; schemes without a known default
            // keep the previous fallback of 80.
            let remote_port = self.url.port_or_known_default().unwrap_or(80);
            let ssh_tunnel = ssh_tunnel
                .connect(storage_configuration, host, remote_port, in_task)
                .await
                .map_err(CsrConnectError::Ssh)?;
            client_config = client_config
                .resolve_to_addrs(
                    host,
                    &[SocketAddr::new(ssh_tunnel.local_addr().ip(), DUMMY_PORT)],
                )
                .dynamic_url({
                    let remote_url = self.url.clone();
                    // The tunnel handle is moved into the closure; the local
                    // port is looked up again on every invocation.
                    move || {
                        let mut url = remote_url.clone();
                        url.set_port(Some(ssh_tunnel.local_addr().port()))
                            .expect("cannot fail");
                        url
                    }
                });
        }
        Tunnel::AwsPrivatelink(connection) => {
            assert_none!(connection.port);
            let privatelink_host = mz_cloud_resources::vpc_endpoint_host(
                connection.connection_id,
                connection.availability_zone.as_deref(),
            );
            let addrs: Vec<_> = net::lookup_host((privatelink_host, DUMMY_PORT))
                .await
                .context("resolving PrivateLink host")?
                .collect();
            client_config = client_config.resolve_to_addrs(host, &addrs)
        }
    }
    Ok(client_config.build()?)
}
/// Validates the connection by building a client and listing the
/// registry's subjects; the listing itself is discarded.
async fn validate(
    &self,
    _id: CatalogItemId,
    storage_configuration: &StorageConfiguration,
) -> Result<(), anyhow::Error> {
    let registry = self.connect(storage_configuration, InTask::No).await?;
    registry.list_subjects().await?;
    Ok(())
}
}
impl RustType<ProtoCsrConnection> for CsrConnection {
    /// Encodes the schema registry connection as protobuf; the URL and the
    /// tunnel are always populated.
    fn into_proto(&self) -> ProtoCsrConnection {
        let Self {
            url,
            tls_root_cert,
            tls_identity,
            http_auth,
            tunnel,
        } = self;
        ProtoCsrConnection {
            url: Some(url.into_proto()),
            tls_root_cert: tls_root_cert.into_proto(),
            tls_identity: tls_identity.into_proto(),
            http_auth: http_auth.into_proto(),
            tunnel: Some(tunnel.into_proto()),
        }
    }

    /// Decodes the schema registry connection from protobuf.
    ///
    /// # Errors
    ///
    /// Returns a [`TryFromProtoError`] if the URL or tunnel is absent, or if
    /// any field fails to convert.
    fn from_proto(proto: ProtoCsrConnection) -> Result<Self, TryFromProtoError> {
        let url = proto.url.into_rust_if_some("ProtoCsrConnection::url")?;
        let tls_root_cert = proto.tls_root_cert.into_rust()?;
        let tls_identity = proto.tls_identity.into_rust()?;
        let http_auth = proto.http_auth.into_rust()?;
        let tunnel = proto
            .tunnel
            .into_rust_if_some("ProtoCsrConnection::tunnel")?;
        Ok(CsrConnection {
            url,
            tls_root_cert,
            tls_identity,
            http_auth,
            tunnel,
        })
    }
}
impl<C: ConnectionAccess> AlterCompatible for CsrConnection<C> {
    /// Checks whether `self` may be altered into `other`.
    ///
    /// Only the tunnels must be alter-compatible; the URL, TLS settings and
    /// HTTP credentials are free to change. A failure is logged as a warning
    /// and reported as an [`AlterError`] carrying `id`.
    fn alter_compatible(&self, id: GlobalId, other: &Self) -> Result<(), AlterError> {
        // Destructure exhaustively so that adding a field forces a decision
        // about whether it takes part in the compatibility check.
        let CsrConnection {
            tunnel,
            url: _,
            tls_root_cert: _,
            tls_identity: _,
            http_auth: _,
        } = self;

        let tunnel_compatible = tunnel.alter_compatible(id, &other.tunnel).is_ok();
        let compatibility_checks = [(tunnel_compatible, "tunnel")];

        let first_failure = compatibility_checks
            .iter()
            .find(|(compatible, _)| !*compatible);
        match first_failure {
            None => Ok(()),
            Some((_, field)) => {
                tracing::warn!(
                    "CsrConnection incompatible at {field}:\nself:\n{:#?}\n\nother\n{:#?}",
                    self,
                    other
                );
                Err(AlterError { id })
            }
        }
    }
}
/// A TLS client identity: a certificate plus the secret holding its key.
#[derive(Arbitrary, Clone, Debug, Eq, PartialEq, Hash, Serialize, Deserialize)]
pub struct TlsIdentity {
/// The client certificate, given inline or as a secret.
pub cert: StringOrSecret,
/// ID of the secret that holds the private key.
pub key: CatalogItemId,
}
impl RustType<ProtoTlsIdentity> for TlsIdentity {
    /// Encodes the identity as protobuf; both fields are always populated.
    fn into_proto(&self) -> ProtoTlsIdentity {
        let Self { cert, key } = self;
        ProtoTlsIdentity {
            cert: Some(cert.into_proto()),
            key: Some(key.into_proto()),
        }
    }

    /// Decodes the identity from protobuf.
    ///
    /// # Errors
    ///
    /// Returns a [`TryFromProtoError`] if either field is absent or fails to
    /// convert; the certificate is checked first.
    fn from_proto(proto: ProtoTlsIdentity) -> Result<Self, TryFromProtoError> {
        let cert = proto.cert.into_rust_if_some("ProtoTlsIdentity::cert")?;
        let key = proto.key.into_rust_if_some("ProtoTlsIdentity::key")?;
        Ok(TlsIdentity { cert, key })
    }
}
/// HTTP authentication credentials for a schema registry connection.
#[derive(Arbitrary, Clone, Debug, Eq, PartialEq, Hash, Serialize, Deserialize)]
pub struct CsrConnectionHttpAuth {
/// The username, given inline or as a secret.
pub username: StringOrSecret,
/// Optional ID of the secret that holds the password.
pub password: Option<CatalogItemId>,
}
impl RustType<ProtoCsrConnectionHttpAuth> for CsrConnectionHttpAuth {
    /// Encodes the credentials as protobuf; the username is always
    /// populated.
    fn into_proto(&self) -> ProtoCsrConnectionHttpAuth {
        let Self { username, password } = self;
        ProtoCsrConnectionHttpAuth {
            username: Some(username.into_proto()),
            password: password.into_proto(),
        }
    }

    /// Decodes the credentials from protobuf.
    ///
    /// # Errors
    ///
    /// Returns a [`TryFromProtoError`] if the username is absent or if a
    /// field fails to convert.
    fn from_proto(proto: ProtoCsrConnectionHttpAuth) -> Result<Self, TryFromProtoError> {
        let username = proto
            .username
            .into_rust_if_some("ProtoCsrConnectionHttpAuth::username")?;
        let password = proto.password.into_rust()?;
        Ok(CsrConnectionHttpAuth { username, password })
    }
}
/// Description of a connection to a PostgreSQL server.
#[derive(Clone, Debug, Eq, PartialEq, Hash, Serialize, Deserialize, Arbitrary)]
pub struct PostgresConnection<C: ConnectionAccess = InlinedConnection> {
/// Hostname of the server.
pub host: String,
/// Port of the server.
pub port: u16,
/// Name of the database to connect to.
pub database: String,
/// The username, given inline or as a secret.
pub user: StringOrSecret,
/// Optional ID of the secret that holds the password.
pub password: Option<CatalogItemId>,
/// Tunnel through which the server is reached.
pub tunnel: Tunnel<C>,
/// TLS mode passed to the client configuration.
#[proptest(strategy = "any_ssl_mode()")]
pub tls_mode: SslMode,
/// Optional root certificate used for TLS.
pub tls_root_cert: Option<StringOrSecret>,
/// Optional client identity (certificate and key) used for TLS.
pub tls_identity: Option<TlsIdentity>,
/// The flavor of the server.
/// NOTE(review): not read in the visible part of this file — confirm
/// where it is consumed.
pub flavor: PostgresFlavor,
}
impl<R: ConnectionResolver> IntoInlineConnection<PostgresConnection, R>
    for PostgresConnection<ReferencedConnection>
{
    /// Converts the Postgres connection into its inlined form.
    ///
    /// Only the tunnel needs the resolver; every other field is moved across
    /// unchanged.
    fn into_inline_connection(self, r: R) -> PostgresConnection {
        PostgresConnection {
            host: self.host,
            port: self.port,
            database: self.database,
            user: self.user,
            password: self.password,
            tunnel: self.tunnel.into_inline_connection(r),
            tls_mode: self.tls_mode,
            tls_root_cert: self.tls_root_cert,
            tls_identity: self.tls_identity,
            flavor: self.flavor,
        }
    }
}
impl<C: ConnectionAccess> PostgresConnection<C> {
/// Postgres connections are always validated by default.
fn validate_by_default(&self) -> bool {
true
}
}
impl PostgresConnection<InlinedConnection> {
/// Builds an `mz_postgres_util::Config` for this connection.
///
/// Credentials and TLS material are read through `secrets_reader`; timeouts
/// and keepalive settings come from `storage_configuration.parameters`.
///
/// # Panics
///
/// Panics if the tunnel is AWS PrivateLink and a port is set on it.
///
/// # Errors
///
/// Returns an error if a secret cannot be read, the SSH key pair cannot be
/// parsed, an address cannot be resolved, or the final config is rejected by
/// `mz_postgres_util::Config::new`.
pub async fn config(
&self,
secrets_reader: &Arc<dyn mz_secrets::SecretsReader>,
storage_configuration: &StorageConfiguration,
in_task: InTask,
) -> Result<mz_postgres_util::Config, anyhow::Error> {
let params = &storage_configuration.parameters;
// Basic connection settings.
let mut config = tokio_postgres::Config::new();
config
.host(&self.host)
.port(self.port)
.dbname(&self.database)
.user(&self.user.get_string(in_task, secrets_reader).await?)
.ssl_mode(self.tls_mode);
if let Some(password) = self.password {
let password = secrets_reader
.read_string_in_task_if(in_task, password)
.await?;
config.password(password);
}
// Optional TLS material.
if let Some(tls_root_cert) = &self.tls_root_cert {
let tls_root_cert = tls_root_cert.get_string(in_task, secrets_reader).await?;
config.ssl_root_cert(tls_root_cert.as_bytes());
}
if let Some(tls_identity) = &self.tls_identity {
let cert = tls_identity
.cert
.get_string(in_task, secrets_reader)
.await?;
let key = secrets_reader
.read_string_in_task_if(in_task, tls_identity.key)
.await?;
config.ssl_cert(cert.as_bytes()).ssl_key(key.as_bytes());
}
// Client-side timeouts and keepalives; each is applied only when set.
if let Some(connect_timeout) = params.pg_source_connect_timeout {
config.connect_timeout(connect_timeout);
}
if let Some(keepalives_retries) = params.pg_source_tcp_keepalives_retries {
config.keepalives_retries(keepalives_retries);
}
if let Some(keepalives_idle) = params.pg_source_tcp_keepalives_idle {
config.keepalives_idle(keepalives_idle);
}
if let Some(keepalives_interval) = params.pg_source_tcp_keepalives_interval {
config.keepalives_interval(keepalives_interval);
}
if let Some(tcp_user_timeout) = params.pg_source_tcp_user_timeout {
config.tcp_user_timeout(tcp_user_timeout);
}
// Settings passed to the server through the connection's `options`
// string, space-separated.
let mut options = vec![];
if let Some(wal_sender_timeout) = params.pg_source_wal_sender_timeout {
options.push(format!(
"--wal_sender_timeout={}",
wal_sender_timeout.as_millis()
));
};
// The TCP settings are mirrored to the server only when
// `pg_source_tcp_configure_server` is enabled.
if params.pg_source_tcp_configure_server {
if let Some(keepalives_retries) = params.pg_source_tcp_keepalives_retries {
options.push(format!("--tcp_keepalives_count={}", keepalives_retries));
}
if let Some(keepalives_idle) = params.pg_source_tcp_keepalives_idle {
options.push(format!(
"--tcp_keepalives_idle={}",
keepalives_idle.as_secs()
));
}
if let Some(keepalives_interval) = params.pg_source_tcp_keepalives_interval {
options.push(format!(
"--tcp_keepalives_interval={}",
keepalives_interval.as_secs()
));
}
if let Some(tcp_user_timeout) = params.pg_source_tcp_user_timeout {
options.push(format!(
"--tcp_user_timeout={}",
tcp_user_timeout.as_millis()
));
}
}
config.options(options.join(" ").as_str());
// Translate the connection's tunnel into the postgres utility's tunnel
// configuration.
let tunnel = match &self.tunnel {
Tunnel::Direct => {
// Resolve the server host up front and hand the resolved addresses to
// the config.
let resolved = resolve_address(
&self.host,
ENFORCE_EXTERNAL_ADDRESSES.get(storage_configuration.config_set()),
)
.await?;
mz_postgres_util::TunnelConfig::Direct {
resolved_ips: Some(resolved),
}
}
Tunnel::Ssh(SshTunnel {
connection_id,
connection,
}) => {
// The secret stored under the SSH connection's ID holds the key pair.
let secret = secrets_reader
.read_in_task_if(in_task, *connection_id)
.await?;
let key_pair = SshKeyPair::from_bytes(&secret)?;
let resolved = resolve_address(
&connection.host,
ENFORCE_EXTERNAL_ADDRESSES.get(storage_configuration.config_set()),
)
.await?;
mz_postgres_util::TunnelConfig::Ssh {
config: SshTunnelConfig {
host: resolved
.iter()
.map(|a| a.to_string())
.collect::<BTreeSet<_>>(),
port: connection.port,
user: connection.user.clone(),
key_pair,
},
}
}
Tunnel::AwsPrivatelink(connection) => {
// A port on the PrivateLink tunnel is not supported here.
assert_none!(connection.port);
mz_postgres_util::TunnelConfig::AwsPrivatelink {
connection_id: connection.connection_id,
}
}
};
Ok(mz_postgres_util::Config::new(
config,
tunnel,
params.ssh_timeout_config,
in_task,
)?)
}
async fn validate(
&self,
_id: CatalogItemId,
storage_configuration: &StorageConfiguration,
) -> Result<(), anyhow::Error> {
let config = self
.config(
&storage_configuration.connection_context.secrets_reader,
storage_configuration,
InTask::No,
)
.await?;
let client = config
.connect(
"connection validation",
&storage_configuration.connection_context.ssh_tunnel_manager,
)
.await?;
use PostgresFlavor::*;
match (client.server_flavor(), &self.flavor) {
(Vanilla, Yugabyte) => bail!("Expected to find PostgreSQL server, found Yugabyte."),
(Yugabyte, Vanilla) => bail!("Expected to find Yugabyte server, found PostgreSQL."),
(Vanilla, Vanilla) | (Yugabyte, Yugabyte) => {}
}
Ok(())
}
}
impl<C: ConnectionAccess> AlterCompatible for PostgresConnection<C> {
    /// Checks whether `other` may replace `self` in an `ALTER`.
    ///
    /// Only `tunnel` and `flavor` are compared; every other field is
    /// explicitly ignored. On the first mismatch a warning naming the field
    /// is logged and `AlterError { id }` is returned.
    fn alter_compatible(&self, id: GlobalId, other: &Self) -> Result<(), AlterError> {
        // Destructure exhaustively so that a newly added field has to be
        // classified here before the code compiles again.
        let PostgresConnection {
            tunnel,
            flavor,
            host: _,
            port: _,
            database: _,
            user: _,
            password: _,
            tls_mode: _,
            tls_root_cert: _,
            tls_identity: _,
        } = self;

        // Both checks are evaluated up front, before anything is reported.
        let tunnel_ok = tunnel.alter_compatible(id, &other.tunnel).is_ok();
        let flavor_ok = flavor == &other.flavor;

        let first_failure = [(tunnel_ok, "tunnel"), (flavor_ok, "flavor")]
            .into_iter()
            .find(|(ok, _)| !*ok);

        match first_failure {
            None => Ok(()),
            Some((_, field)) => {
                tracing::warn!(
                    "PostgresConnection incompatible at {field}:\nself:\n{:#?}\n\nother\n{:#?}",
                    self,
                    other
                );
                Err(AlterError { id })
            }
        }
    }
}
impl RustType<ProtoPostgresConnection> for PostgresConnection {
    /// Encodes this connection as a `ProtoPostgresConnection`.
    fn into_proto(&self) -> ProtoPostgresConnection {
        let PostgresConnection {
            host,
            port,
            database,
            user,
            password,
            tunnel,
            tls_mode,
            tls_root_cert,
            tls_identity,
            flavor,
        } = self;
        ProtoPostgresConnection {
            host: host.into_proto(),
            port: port.into_proto(),
            database: database.into_proto(),
            user: Some(user.into_proto()),
            password: password.into_proto(),
            tls_mode: Some(tls_mode.into_proto()),
            tls_root_cert: tls_root_cert.into_proto(),
            tls_identity: tls_identity.into_proto(),
            tunnel: Some(tunnel.into_proto()),
            flavor: Some(flavor.into_proto()),
        }
    }

    /// Decodes a `ProtoPostgresConnection`.
    ///
    /// Fails if `user`, `tunnel`, `tls_mode` or `flavor` is absent, or if any
    /// field conversion fails.
    fn from_proto(proto: ProtoPostgresConnection) -> Result<Self, TryFromProtoError> {
        let ProtoPostgresConnection {
            host,
            port,
            database,
            user,
            password,
            tls_mode,
            tls_root_cert,
            tls_identity,
            tunnel,
            flavor,
        } = proto;

        // Conversions run one at a time so the first failure is the one reported.
        let port = port.into_rust()?;
        let user = user.into_rust_if_some("ProtoPostgresConnection::user")?;
        let password = password.into_rust()?;
        let tunnel = tunnel.into_rust_if_some("ProtoPostgresConnection::tunnel")?;
        let tls_mode = tls_mode.into_rust_if_some("ProtoPostgresConnection::tls_mode")?;
        let tls_root_cert = tls_root_cert.into_rust()?;
        let tls_identity = tls_identity.into_rust()?;
        let flavor = flavor.into_rust_if_some("ProtoPostgresConnection::flavor")?;

        Ok(PostgresConnection {
            host,
            port,
            database,
            user,
            password,
            tunnel,
            tls_mode,
            tls_root_cert,
            tls_identity,
            flavor,
        })
    }
}
/// How a connection reaches the external service.
///
/// The type parameter `C` selects how the SSH connection inside
/// [`Tunnel::Ssh`] is represented (see [`SshTunnel`]).
#[derive(Arbitrary, Clone, Debug, Eq, PartialEq, Hash, Serialize, Deserialize)]
pub enum Tunnel<C: ConnectionAccess = InlinedConnection> {
/// No tunnel; the PostgreSQL and MySQL `config` methods in this file
/// resolve the target host themselves in this case.
Direct,
/// Reach the service through the SSH connection described by [`SshTunnel`].
Ssh(SshTunnel<C>),
/// Reach the service through the AWS PrivateLink connection described by
/// [`AwsPrivatelink`].
AwsPrivatelink(AwsPrivatelink),
}
impl<R: ConnectionResolver> IntoInlineConnection<Tunnel, R> for Tunnel<ReferencedConnection> {
    /// Converts a referenced tunnel into its inlined form.
    ///
    /// Only the SSH variant holds a reference that needs resolving through
    /// `r`; the other variants are carried over unchanged.
    fn into_inline_connection(self, r: R) -> Tunnel {
        match self {
            Tunnel::Ssh(referenced) => Tunnel::Ssh(referenced.into_inline_connection(r)),
            Tunnel::AwsPrivatelink(privatelink) => Tunnel::AwsPrivatelink(privatelink),
            Tunnel::Direct => Tunnel::Direct,
        }
    }
}
impl RustType<ProtoTunnel> for Tunnel<InlinedConnection> {
    /// Encodes this tunnel as a `ProtoTunnel`; the inner field is always set.
    fn into_proto(&self) -> ProtoTunnel {
        use proto_tunnel::Tunnel as ProtoTunnelField;
        let field = match self {
            Tunnel::Direct => ProtoTunnelField::Direct(()),
            Tunnel::Ssh(ssh) => ProtoTunnelField::Ssh(ssh.into_proto()),
            Tunnel::AwsPrivatelink(aws) => ProtoTunnelField::AwsPrivatelink(aws.into_proto()),
        };
        ProtoTunnel {
            tunnel: Some(field),
        }
    }

    /// Decodes a `ProtoTunnel`, failing with a missing-field error when the
    /// inner field is absent.
    fn from_proto(proto: ProtoTunnel) -> Result<Self, TryFromProtoError> {
        use proto_tunnel::Tunnel as ProtoTunnelField;
        let Some(field) = proto.tunnel else {
            return Err(TryFromProtoError::missing_field("ProtoTunnel::tunnel"));
        };
        let tunnel = match field {
            ProtoTunnelField::Direct(()) => Tunnel::Direct,
            ProtoTunnelField::Ssh(ssh) => Tunnel::Ssh(ssh.into_rust()?),
            ProtoTunnelField::AwsPrivatelink(aws) => Tunnel::AwsPrivatelink(aws.into_rust()?),
        };
        Ok(tunnel)
    }
}
impl<C: ConnectionAccess> AlterCompatible for Tunnel<C> {
    /// Checks whether `other` may replace `self` in an `ALTER`.
    ///
    /// Two SSH tunnels defer to [`SshTunnel`]'s own compatibility check; any
    /// other pairing must be exactly equal. On failure a warning is logged
    /// and `AlterError { id }` is returned.
    fn alter_compatible(&self, id: GlobalId, other: &Self) -> Result<(), AlterError> {
        let compatible = if let (Self::Ssh(ours), Self::Ssh(theirs)) = (self, other) {
            ours.alter_compatible(id, theirs).is_ok()
        } else {
            self == other
        };

        if compatible {
            return Ok(());
        }

        tracing::warn!(
            "Tunnel incompatible:\nself:\n{:#?}\n\nother\n{:#?}",
            self,
            other
        );
        Err(AlterError { id })
    }
}
/// TLS mode of a MySQL connection.
///
/// The per-variant notes describe how `MySqlConnection::config` in this file
/// translates each mode into `mysql_async::SslOpts`.
#[derive(Arbitrary, Clone, Debug, Eq, PartialEq, Hash, Serialize, Deserialize)]
pub enum MySqlSslMode {
/// No SSL options are configured.
Disabled,
/// SSL options are configured with both `with_danger_accept_invalid_certs`
/// and `with_danger_skip_domain_validation` enabled.
Required,
/// SSL options are configured with only `with_danger_skip_domain_validation`
/// enabled.
VerifyCa,
/// Default SSL options are used, with neither of the "danger" switches set.
VerifyIdentity,
}
impl RustType<i32> for MySqlSslMode {
    /// Encodes the mode as the `i32` value of the matching `ProtoMySqlSslMode`.
    fn into_proto(&self) -> i32 {
        let proto_mode = match self {
            MySqlSslMode::Disabled => ProtoMySqlSslMode::Disabled,
            MySqlSslMode::Required => ProtoMySqlSslMode::Required,
            MySqlSslMode::VerifyCa => ProtoMySqlSslMode::VerifyCa,
            MySqlSslMode::VerifyIdentity => ProtoMySqlSslMode::VerifyIdentity,
        };
        proto_mode.into()
    }

    /// Decodes an `i32` into a mode, reporting an unknown `tls_mode` enum
    /// variant when the value matches no `ProtoMySqlSslMode`.
    fn from_proto(proto: i32) -> Result<Self, TryFromProtoError> {
        let proto_mode = ProtoMySqlSslMode::try_from(proto)
            .map_err(|_| TryFromProtoError::UnknownEnumVariant("tls_mode".to_string()))?;
        Ok(match proto_mode {
            ProtoMySqlSslMode::Disabled => MySqlSslMode::Disabled,
            ProtoMySqlSslMode::Required => MySqlSslMode::Required,
            ProtoMySqlSslMode::VerifyCa => MySqlSslMode::VerifyCa,
            ProtoMySqlSslMode::VerifyIdentity => MySqlSslMode::VerifyIdentity,
        })
    }
}
/// Proptest strategy that selects one of the four [`MySqlSslMode`] variants.
pub fn any_mysql_ssl_mode() -> impl Strategy<Value = MySqlSslMode> {
    let modes = vec![
        MySqlSslMode::Disabled,
        MySqlSslMode::Required,
        MySqlSslMode::VerifyCa,
        MySqlSslMode::VerifyIdentity,
    ];
    proptest::sample::select(modes)
}
/// Description of a connection to a MySQL server.
///
/// The type parameter `C` selects how the tunnel's SSH connection is
/// represented (see [`Tunnel`]).
#[derive(Clone, Debug, Eq, PartialEq, Hash, Serialize, Deserialize, Arbitrary)]
pub struct MySqlConnection<C: ConnectionAccess = InlinedConnection> {
/// Host name or IP address of the server.
pub host: String,
/// TCP port of the server.
pub port: u16,
/// User name, given inline or as a secret.
pub user: StringOrSecret,
/// ID of the secret holding the password, if any.
pub password: Option<CatalogItemId>,
/// How the server is reached.
pub tunnel: Tunnel<C>,
/// TLS mode of the connection.
#[proptest(strategy = "any_mysql_ssl_mode()")]
pub tls_mode: MySqlSslMode,
/// Optional root certificate; `config` applies it only in the `VerifyCa`
/// and `VerifyIdentity` modes.
pub tls_root_cert: Option<StringOrSecret>,
/// Optional client certificate and key; `config` has nothing to attach it
/// to when the mode is `Disabled`.
pub tls_identity: Option<TlsIdentity>,
}
impl<R: ConnectionResolver> IntoInlineConnection<MySqlConnection, R>
    for MySqlConnection<ReferencedConnection>
{
    /// Converts a referenced MySQL connection into its inlined form.
    ///
    /// Only the tunnel needs resolving through `r`; every other field is
    /// moved over unchanged.
    fn into_inline_connection(self, r: R) -> MySqlConnection {
        MySqlConnection {
            host: self.host,
            port: self.port,
            user: self.user,
            password: self.password,
            tunnel: self.tunnel.into_inline_connection(r),
            tls_mode: self.tls_mode,
            tls_root_cert: self.tls_root_cert,
            tls_identity: self.tls_identity,
        }
    }
}
impl<C: ConnectionAccess> MySqlConnection<C> {
/// Always returns `true`.
///
/// NOTE(review): judging by the name, callers presumably use this to decide
/// whether a MySQL connection is validated when no explicit choice was
/// made; the call site is not visible in this section, so confirm there.
fn validate_by_default(&self) -> bool {
true
}
}
impl MySqlConnection<InlinedConnection> {
/// Builds the [`mz_mysql_util::Config`] used to connect to this MySQL server.
///
/// * `secrets_reader` is used to read the password, the TLS client key and
///   the SSH key pair, and is handed to `get_string` for the user and the
///   TLS certificates.
/// * `storage_configuration` supplies `mysql_source_timeouts`, the SSH
///   timeout configuration, and the config set consulted for
///   `ENFORCE_EXTERNAL_ADDRESSES`.
/// * `in_task` is forwarded to every secret read and to
///   `mz_mysql_util::Config::new`.
///
/// # Errors
///
/// Returns an error if a secret read fails, if the client key/certificate
/// cannot be converted to a PKCS#12 archive, if the SSH key pair cannot be
/// parsed, if address resolution fails, or if applying the timeouts fails.
///
/// # Panics
///
/// NOTE(review): `assert_none!` presumably panics when an AWS PrivateLink
/// tunnel carries a `port`; confirm against `mz_ore`.
pub async fn config(
&self,
secrets_reader: &Arc<dyn mz_secrets::SecretsReader>,
storage_configuration: &StorageConfiguration,
in_task: InTask,
) -> Result<mz_mysql_util::Config, anyhow::Error> {
// Basic connection coordinates.
let mut opts = mysql_async::OptsBuilder::default()
.ip_or_hostname(&self.host)
.tcp_port(self.port)
.user(Some(&self.user.get_string(in_task, secrets_reader).await?));
// The password is stored as a secret ID and is only set when present.
if let Some(password) = self.password {
let password = secrets_reader
.read_string_in_task_if(in_task, password)
.await?;
opts = opts.pass(Some(password));
}
// Map the TLS mode onto SSL options; `None` means no SSL options at all.
let mut ssl_opts = match self.tls_mode {
MySqlSslMode::Disabled => None,
MySqlSslMode::Required => Some(
mysql_async::SslOpts::default()
.with_danger_accept_invalid_certs(true)
.with_danger_skip_domain_validation(true),
),
MySqlSslMode::VerifyCa => {
Some(mysql_async::SslOpts::default().with_danger_skip_domain_validation(true))
}
MySqlSslMode::VerifyIdentity => Some(mysql_async::SslOpts::default()),
};
// A root certificate is only read and applied in the verifying modes.
if matches!(
self.tls_mode,
MySqlSslMode::VerifyCa | MySqlSslMode::VerifyIdentity
) {
if let Some(tls_root_cert) = &self.tls_root_cert {
let tls_root_cert = tls_root_cert.get_string(in_task, secrets_reader).await?;
ssl_opts = ssl_opts.map(|opts| {
opts.with_root_certs(vec![tls_root_cert.as_bytes().to_vec().into()])
});
}
}
// The client identity is read and converted regardless of the TLS mode,
// but `map` leaves `ssl_opts` as `None` when the mode is `Disabled`.
if let Some(identity) = &self.tls_identity {
let key = secrets_reader
.read_string_in_task_if(in_task, identity.key)
.await?;
let cert = identity.cert.get_string(in_task, secrets_reader).await?;
// The PEM key and certificate are bundled into a PKCS#12 archive.
let Pkcs12Archive { der, pass } =
mz_tls_util::pkcs12der_from_pem(key.as_bytes(), cert.as_bytes())?;
ssl_opts = ssl_opts.map(|opts| {
opts.with_client_identity(Some(
mysql_async::ClientIdentity::new(der.into()).with_password(pass),
))
});
}
opts = opts.ssl_opts(ssl_opts);
// Translate the tunnel description into the `mz_mysql_util` form.
let tunnel = match &self.tunnel {
Tunnel::Direct => {
// Resolve the database host up front; the dyncfg flag is passed to
// `resolve_address` (presumably to reject non-external addresses).
let resolved = resolve_address(
&self.host,
ENFORCE_EXTERNAL_ADDRESSES.get(storage_configuration.config_set()),
)
.await?;
mz_mysql_util::TunnelConfig::Direct {
resolved_ips: Some(resolved),
}
}
Tunnel::Ssh(SshTunnel {
connection_id,
connection,
}) => {
// The SSH key pair is stored as a secret under the SSH connection's ID.
let secret = secrets_reader
.read_in_task_if(in_task, *connection_id)
.await?;
let key_pair = SshKeyPair::from_bytes(&secret)?;
// Here the SSH host, not the database host, is resolved.
let resolved = resolve_address(
&connection.host,
ENFORCE_EXTERNAL_ADDRESSES.get(storage_configuration.config_set()),
)
.await?;
mz_mysql_util::TunnelConfig::Ssh {
config: SshTunnelConfig {
// Every resolved address is kept, rendered as a string.
host: resolved
.iter()
.map(|a| a.to_string())
.collect::<BTreeSet<_>>(),
port: connection.port,
user: connection.user.clone(),
key_pair,
},
}
}
Tunnel::AwsPrivatelink(connection) => {
// A port on the PrivateLink tunnel is not expected for MySQL.
assert_none!(connection.port);
mz_mysql_util::TunnelConfig::AwsPrivatelink {
connection_id: connection.connection_id,
}
}
};
// Apply the configured MySQL source timeouts last.
opts = storage_configuration
.parameters
.mysql_source_timeouts
.apply_to_opts(opts)?;
Ok(mz_mysql_util::Config::new(
opts.into(),
tunnel,
storage_configuration.parameters.ssh_timeout_config,
in_task,
))
}
/// Validates this connection by building its configuration, opening a
/// connection to the server, and disconnecting again.
///
/// `_id` is unused. Secrets are read with `InTask::No`.
///
/// # Errors
///
/// Returns an error if the configuration cannot be built, or if connecting
/// or disconnecting fails.
async fn validate(
    &self,
    _id: CatalogItemId,
    storage_configuration: &StorageConfiguration,
) -> Result<(), anyhow::Error> {
    let context = &storage_configuration.connection_context;

    let config = self
        .config(&context.secrets_reader, storage_configuration, InTask::No)
        .await?;

    let conn = config
        .connect("connection validation", &context.ssh_tunnel_manager)
        .await?;

    conn.disconnect().await?;
    Ok(())
}
}
impl RustType<ProtoMySqlConnection> for MySqlConnection {
    /// Encodes this connection as a `ProtoMySqlConnection`.
    fn into_proto(&self) -> ProtoMySqlConnection {
        let MySqlConnection {
            host,
            port,
            user,
            password,
            tunnel,
            tls_mode,
            tls_root_cert,
            tls_identity,
        } = self;
        ProtoMySqlConnection {
            host: host.into_proto(),
            port: port.into_proto(),
            user: Some(user.into_proto()),
            password: password.into_proto(),
            tls_mode: tls_mode.into_proto(),
            tls_root_cert: tls_root_cert.into_proto(),
            tls_identity: tls_identity.into_proto(),
            tunnel: Some(tunnel.into_proto()),
        }
    }

    /// Decodes a `ProtoMySqlConnection`.
    ///
    /// Fails if `user` or `tunnel` is absent, or if any field conversion fails.
    fn from_proto(proto: ProtoMySqlConnection) -> Result<Self, TryFromProtoError> {
        let ProtoMySqlConnection {
            host,
            port,
            user,
            password,
            tls_mode,
            tls_root_cert,
            tls_identity,
            tunnel,
        } = proto;

        // Conversions run one at a time so the first failure is the one reported.
        let port = port.into_rust()?;
        let user = user.into_rust_if_some("ProtoMySqlConnection::user")?;
        let password = password.into_rust()?;
        let tunnel = tunnel.into_rust_if_some("ProtoMySqlConnection::tunnel")?;
        let tls_mode = tls_mode.into_rust()?;
        let tls_root_cert = tls_root_cert.into_rust()?;
        let tls_identity = tls_identity.into_rust()?;

        Ok(MySqlConnection {
            host,
            port,
            user,
            password,
            tunnel,
            tls_mode,
            tls_root_cert,
            tls_identity,
        })
    }
}
impl<C: ConnectionAccess> AlterCompatible for MySqlConnection<C> {
    /// Checks whether `other` may replace `self` in an `ALTER`.
    ///
    /// Only `tunnel` is compared; every other field is explicitly ignored.
    /// On a mismatch a warning is logged and `AlterError { id }` is returned.
    fn alter_compatible(&self, id: GlobalId, other: &Self) -> Result<(), AlterError> {
        // Destructure exhaustively so that a newly added field has to be
        // classified here before the code compiles again.
        let MySqlConnection {
            tunnel,
            host: _,
            port: _,
            user: _,
            password: _,
            tls_mode: _,
            tls_root_cert: _,
            tls_identity: _,
        } = self;

        if tunnel.alter_compatible(id, &other.tunnel).is_ok() {
            return Ok(());
        }

        let field = "tunnel";
        tracing::warn!(
            "MySqlConnection incompatible at {field}:\nself:\n{:#?}\n\nother\n{:#?}",
            self,
            other
        );
        Err(AlterError { id })
    }
}
/// Description of an SSH server that connections can tunnel through.
///
/// The key pair is not stored here; the code in this file reads it from the
/// secrets reader using the ID of the SSH connection.
#[derive(Arbitrary, Clone, Debug, Eq, PartialEq, Hash, Serialize, Deserialize)]
pub struct SshConnection {
/// Host name or IP address of the SSH server.
pub host: String,
/// TCP port of the SSH server.
pub port: u16,
/// User to authenticate as.
pub user: String,
}
use self::inline::{
ConnectionAccess, ConnectionResolver, InlinedConnection, IntoInlineConnection,
ReferencedConnection,
};
impl RustType<ProtoSshConnection> for SshConnection {
    /// Encodes this connection as a `ProtoSshConnection`.
    fn into_proto(&self) -> ProtoSshConnection {
        let SshConnection { host, port, user } = self;
        ProtoSshConnection {
            host: host.into_proto(),
            port: port.into_proto(),
            user: user.into_proto(),
        }
    }

    /// Decodes a `ProtoSshConnection`; only the port needs a fallible conversion.
    fn from_proto(proto: ProtoSshConnection) -> Result<Self, TryFromProtoError> {
        let ProtoSshConnection { host, port, user } = proto;
        let port = port.into_rust()?;
        Ok(SshConnection { host, port, user })
    }
}
impl AlterCompatible for SshConnection {
/// Always reports compatibility: no field of the SSH connection is
/// compared, so any `other` is accepted.
fn alter_compatible(&self, _id: GlobalId, _other: &Self) -> Result<(), AlterError> {
Ok(())
}
}
/// Reference to an AWS PrivateLink connection used as a tunnel.
#[derive(Arbitrary, Clone, Debug, Eq, PartialEq, Hash, Serialize, Deserialize)]
pub struct AwsPrivatelink {
/// ID of the AWS PrivateLink connection; this is the only field compared
/// by this type's `alter_compatible`.
pub connection_id: CatalogItemId,
/// Optional availability zone.
/// NOTE(review): not read anywhere in this section of the file; its exact
/// use should be confirmed at the call sites.
pub availability_zone: Option<String>,
/// Optional port. The PostgreSQL and MySQL `config` methods in this file
/// assert that it is `None`.
pub port: Option<u16>,
}
impl RustType<ProtoAwsPrivatelink> for AwsPrivatelink {
    /// Encodes this tunnel reference as a `ProtoAwsPrivatelink`.
    fn into_proto(&self) -> ProtoAwsPrivatelink {
        let AwsPrivatelink {
            connection_id,
            availability_zone,
            port,
        } = self;
        ProtoAwsPrivatelink {
            connection_id: Some(connection_id.into_proto()),
            availability_zone: availability_zone.into_proto(),
            port: port.into_proto(),
        }
    }

    /// Decodes a `ProtoAwsPrivatelink`.
    ///
    /// Fails if `connection_id` is absent or if any field conversion fails.
    fn from_proto(proto: ProtoAwsPrivatelink) -> Result<Self, TryFromProtoError> {
        let ProtoAwsPrivatelink {
            connection_id,
            availability_zone,
            port,
        } = proto;

        let connection_id =
            connection_id.into_rust_if_some("ProtoAwsPrivatelink::connection_id")?;
        let availability_zone = availability_zone.into_rust()?;
        let port = port.into_rust()?;

        Ok(AwsPrivatelink {
            connection_id,
            availability_zone,
            port,
        })
    }
}
impl AlterCompatible for AwsPrivatelink {
    /// Checks whether `other` may replace `self` in an `ALTER`.
    ///
    /// Only `connection_id` is compared; the availability zone and port are
    /// explicitly ignored. On a mismatch a warning is logged and
    /// `AlterError { id }` is returned.
    fn alter_compatible(&self, id: GlobalId, other: &Self) -> Result<(), AlterError> {
        // Destructure exhaustively so that a newly added field has to be
        // classified here before the code compiles again.
        let AwsPrivatelink {
            connection_id,
            availability_zone: _,
            port: _,
        } = self;

        if connection_id == &other.connection_id {
            return Ok(());
        }

        let field = "connection_id";
        tracing::warn!(
            "AwsPrivatelink incompatible at {field}:\nself:\n{:#?}\n\nother\n{:#?}",
            self,
            other
        );
        Err(AlterError { id })
    }
}
/// An SSH tunnel: the ID of an SSH connection plus that connection in the
/// representation selected by `C`.
#[derive(Arbitrary, Clone, Debug, Eq, PartialEq, Hash, Serialize, Deserialize)]
pub struct SshTunnel<C: ConnectionAccess = InlinedConnection> {
/// ID of the SSH connection. The code in this file also uses it as the
/// secret ID from which the SSH key pair is read.
pub connection_id: CatalogItemId,
/// The SSH connection itself, as `C::Ssh`. In the inlined form its `host`,
/// `port` and `user` fields are read directly.
pub connection: C::Ssh,
}
impl<R: ConnectionResolver> IntoInlineConnection<SshTunnel, R> for SshTunnel<ReferencedConnection> {
    /// Converts a referenced SSH tunnel into its inlined form by resolving
    /// the referenced connection through `r` and unwrapping it as an SSH
    /// connection. The connection ID is carried over unchanged.
    fn into_inline_connection(self, r: R) -> SshTunnel {
        let inlined = r.resolve_connection(self.connection).unwrap_ssh();
        SshTunnel {
            connection_id: self.connection_id,
            connection: inlined,
        }
    }
}
impl RustType<ProtoSshTunnel> for SshTunnel<InlinedConnection> {
    /// Encodes this tunnel as a `ProtoSshTunnel`; both fields are always set.
    fn into_proto(&self) -> ProtoSshTunnel {
        let SshTunnel {
            connection_id,
            connection,
        } = self;
        ProtoSshTunnel {
            connection_id: Some(connection_id.into_proto()),
            connection: Some(connection.into_proto()),
        }
    }

    /// Decodes a `ProtoSshTunnel`.
    ///
    /// Fails if either field is absent or cannot be converted.
    fn from_proto(proto: ProtoSshTunnel) -> Result<Self, TryFromProtoError> {
        let ProtoSshTunnel {
            connection_id,
            connection,
        } = proto;

        let connection_id = connection_id.into_rust_if_some("ProtoSshTunnel::connection_id")?;
        let connection = connection.into_rust_if_some("ProtoSshTunnel::connection")?;

        Ok(SshTunnel {
            connection_id,
            connection,
        })
    }
}
impl SshTunnel<InlinedConnection> {
    /// Opens (or obtains) a managed SSH tunnel to `remote_host:remote_port`
    /// through this tunnel's SSH connection.
    ///
    /// The SSH host is resolved first, then the key pair is read from the
    /// secret stored under `connection_id`, and finally the SSH tunnel
    /// manager is asked to connect using the configured SSH timeouts.
    ///
    /// # Errors
    ///
    /// Returns an error if address resolution fails, if the secret cannot be
    /// read or parsed as a key pair, or if the tunnel manager fails to
    /// connect.
    async fn connect(
        &self,
        storage_configuration: &StorageConfiguration,
        remote_host: &str,
        remote_port: u16,
        in_task: InTask,
    ) -> Result<ManagedSshTunnelHandle, anyhow::Error> {
        let context = &storage_configuration.connection_context;

        let enforce_external = ENFORCE_EXTERNAL_ADDRESSES.get(storage_configuration.config_set());
        let resolved = resolve_address(&self.connection.host, enforce_external).await?;
        // Every resolved address is kept, rendered as a string.
        let host: BTreeSet<_> = resolved.iter().map(|addr| addr.to_string()).collect();

        // Keep the raw secret bytes scoped to the key-pair construction.
        let key_pair = {
            let secret = context
                .secrets_reader
                .read_in_task_if(in_task, self.connection_id)
                .await?;
            SshKeyPair::from_bytes(&secret)?
        };

        let tunnel_config = SshTunnelConfig {
            host,
            port: self.connection.port,
            user: self.connection.user.clone(),
            key_pair,
        };

        context
            .ssh_tunnel_manager
            .connect(
                tunnel_config,
                remote_host,
                remote_port,
                storage_configuration.parameters.ssh_timeout_config,
                in_task,
            )
            .await
    }
}
impl<C: ConnectionAccess> AlterCompatible for SshTunnel<C> {
    /// Checks whether `other` may replace `self` in an `ALTER`.
    ///
    /// The connection IDs must be equal and the connections themselves must
    /// be alter-compatible. The first failing check (in that order) is named
    /// in a logged warning and `AlterError { id }` is returned.
    fn alter_compatible(&self, id: GlobalId, other: &Self) -> Result<(), AlterError> {
        let SshTunnel {
            connection_id,
            connection,
        } = self;

        // Both checks are evaluated up front, before anything is reported.
        let same_id = connection_id == &other.connection_id;
        let connection_ok = connection.alter_compatible(id, &other.connection).is_ok();

        let field = if !same_id {
            "connection_id"
        } else if !connection_ok {
            "connection"
        } else {
            return Ok(());
        };

        tracing::warn!(
            "SshTunnel incompatible at {field}:\nself:\n{:#?}\n\nother\n{:#?}",
            self,
            other
        );
        Err(AlterError { id })
    }
}
impl SshConnection {
    /// Validates this SSH connection.
    ///
    /// Reads the key pair from the secret stored under `id` (with
    /// `InTask::No`), resolves the SSH host, and runs `SshTunnelConfig`'s own
    /// validation with the configured SSH timeouts.
    ///
    /// # Errors
    ///
    /// Returns an error if the secret cannot be read or parsed as a key
    /// pair, if address resolution fails, or if the tunnel validation fails.
    #[allow(clippy::unused_async)]
    async fn validate(
        &self,
        id: CatalogItemId,
        storage_configuration: &StorageConfiguration,
    ) -> Result<(), anyhow::Error> {
        let secrets_reader = &storage_configuration.connection_context.secrets_reader;
        let secret = secrets_reader.read_in_task_if(InTask::No, id).await?;
        let key_pair = SshKeyPair::from_bytes(&secret)?;

        let enforce_external = ENFORCE_EXTERNAL_ADDRESSES.get(storage_configuration.config_set());
        let resolved = resolve_address(&self.host, enforce_external).await?;
        // Every resolved address is kept, rendered as a string.
        let host: BTreeSet<_> = resolved.iter().map(|addr| addr.to_string()).collect();

        let config = SshTunnelConfig {
            host,
            port: self.port,
            user: self.user.clone(),
            key_pair,
        };
        let timeouts = storage_configuration.parameters.ssh_timeout_config;
        config.validate(timeouts).await
    }

    /// Always returns `false`.
    fn validate_by_default(&self) -> bool {
        false
    }
}
impl AwsPrivatelinkConnection {
#[allow(clippy::unused_async)]
async fn validate(
&self,
id: CatalogItemId,
storage_configuration: &StorageConfiguration,
) -> Result<(), anyhow::Error> {
let Some(ref cloud_resource_reader) = storage_configuration
.connection_context
.cloud_resource_reader
else {
return Err(anyhow!("AWS PrivateLink connections are unsupported"));
};
let status = cloud_resource_reader.read(id).await?;
let availability = status
.conditions
.as_ref()
.and_then(|conditions| conditions.iter().find(|c| c.type_ == "Available"));
match availability {
Some(condition) if condition.status == "True" => Ok(()),
Some(condition) => Err(anyhow!("{}", condition.message)),
None => Err(anyhow!("Endpoint availability is unknown")),
}
}
fn validate_by_default(&self) -> bool {
false
}
}