use std::borrow::Cow;
use std::collections::{BTreeMap, BTreeSet};
use std::net::SocketAddr;
use std::sync::Arc;
use anyhow::{anyhow, Context};
use itertools::Itertools;
use mz_ccsr::tls::{Certificate, Identity};
use mz_cloud_resources::{vpc_endpoint_host, AwsExternalIdPrefix, CloudResourceReader};
use mz_dyncfg::ConfigSet;
use mz_kafka_util::client::{
BrokerRewrite, MzClientContext, MzKafkaError, TunnelConfig, TunnelingClientContext,
};
use mz_ore::error::ErrorExt;
use mz_ore::future::{InTask, OreFutureExt};
use mz_ore::netio::resolve_address;
use mz_proto::tokio_postgres::any_ssl_mode;
use mz_proto::{IntoRustIfSome, ProtoType, RustType, TryFromProtoError};
use mz_repr::url::any_url;
use mz_repr::GlobalId;
use mz_secrets::SecretsReader;
use mz_ssh_util::keys::SshKeyPair;
use mz_ssh_util::tunnel::SshTunnelConfig;
use mz_ssh_util::tunnel_manager::{ManagedSshTunnelHandle, SshTunnelManager};
use mz_tls_util::Pkcs12Archive;
use mz_tracing::CloneableEnvFilter;
use proptest::strategy::Strategy;
use proptest_derive::Arbitrary;
use rdkafka::client::BrokerAddr;
use rdkafka::config::FromClientConfigAndContext;
use rdkafka::consumer::{BaseConsumer, Consumer};
use rdkafka::ClientContext;
use regex::Regex;
use serde::{Deserialize, Deserializer, Serialize};
use tokio::net;
use tokio::runtime::Handle;
use tokio_postgres::config::SslMode;
use tracing::{debug, warn};
use url::Url;
use crate::configuration::StorageConfiguration;
use crate::connections::aws::{AwsConnection, AwsConnectionValidationError};
use crate::controller::AlterError;
use crate::dyncfgs::{ENFORCE_EXTERNAL_ADDRESSES, KAFKA_CLIENT_ID_ENRICHMENT_RULES};
use crate::errors::{ContextCreationError, CsrConnectError};
use crate::AlterCompatible;
pub mod aws;
pub mod inline;
include!(concat!(env!("OUT_DIR"), "/mz_storage_types.connections.rs"));
/// Module-private extension trait for reading secrets with an [`InTask`] flag.
///
/// The implementation for `Arc<dyn SecretsReader>` in this file forwards the
/// flag to `run_in_task_if`.
#[async_trait::async_trait]
trait SecretsReaderExt {
/// Reads the secret identified by `id` as raw bytes.
async fn read_in_task_if(
&self,
in_task: InTask,
id: GlobalId,
) -> Result<Vec<u8>, anyhow::Error>;
/// Reads the secret identified by `id` as a `String`.
async fn read_string_in_task_if(
&self,
in_task: InTask,
id: GlobalId,
) -> Result<String, anyhow::Error>;
}
#[async_trait::async_trait]
impl SecretsReaderExt for Arc<dyn SecretsReader> {
    /// Reads the bytes of secret `id`, handing the read to `run_in_task_if`
    /// together with the caller's `in_task` flag.
    async fn read_in_task_if(
        &self,
        in_task: InTask,
        id: GlobalId,
    ) -> Result<Vec<u8>, anyhow::Error> {
        // The future owns its own handle to the reader so that it does not
        // borrow from `self`.
        let reader = Arc::clone(self);
        let fetch = async move { reader.read(id).await };
        fetch
            .run_in_task_if(in_task, || "secrets_reader_read".to_string())
            .await
    }

    /// Reads secret `id` as a string, handing the read to `run_in_task_if`
    /// together with the caller's `in_task` flag.
    async fn read_string_in_task_if(
        &self,
        in_task: InTask,
        id: GlobalId,
    ) -> Result<String, anyhow::Error> {
        // The future owns its own handle to the reader so that it does not
        // borrow from `self`.
        let reader = Arc::clone(self);
        let fetch = async move { reader.read_string(id).await };
        fetch
            .run_in_task_if(in_task, || "secrets_reader_read".to_string())
            .await
    }
}
/// A value that is either held inline as text or stored as a secret that is
/// referenced by its ID.
#[derive(Arbitrary, Clone, Debug, Eq, PartialEq, Hash, Serialize, Deserialize)]
pub enum StringOrSecret {
/// The value itself, held inline.
String(String),
/// The ID under which the value can be read from a `SecretsReader`.
Secret(GlobalId),
}
impl StringOrSecret {
    /// Resolves this value to an owned string.
    ///
    /// An inline string is cloned; a secret is read through `secrets_reader`,
    /// with `in_task` forwarded to the read.
    ///
    /// # Errors
    ///
    /// Returns any error produced while reading the secret.
    pub async fn get_string(
        &self,
        in_task: InTask,
        secrets_reader: &Arc<dyn SecretsReader>,
    ) -> anyhow::Result<String> {
        let secret_id = match self {
            Self::String(inline) => return Ok(inline.clone()),
            Self::Secret(id) => *id,
        };
        secrets_reader
            .read_string_in_task_if(in_task, secret_id)
            .await
    }

    /// Returns the inline string.
    ///
    /// # Panics
    ///
    /// Panics if this value is a secret.
    pub fn unwrap_string(&self) -> &str {
        if let Self::String(inline) = self {
            inline.as_str()
        } else {
            panic!("StringOrSecret::unwrap_string called on a secret")
        }
    }

    /// Returns the ID of the secret.
    ///
    /// # Panics
    ///
    /// Panics if this value is an inline string.
    pub fn unwrap_secret(&self) -> GlobalId {
        if let Self::Secret(id) = self {
            *id
        } else {
            panic!("StringOrSecret::unwrap_secret called on a string")
        }
    }
}
impl RustType<ProtoStringOrSecret> for StringOrSecret {
    /// Encodes this value as a `ProtoStringOrSecret` whose `kind` is always set.
    fn into_proto(&self) -> ProtoStringOrSecret {
        use proto_string_or_secret::Kind;
        let kind = match self {
            Self::String(inline) => Kind::String(inline.clone()),
            Self::Secret(id) => Kind::Secret(id.into_proto()),
        };
        ProtoStringOrSecret { kind: Some(kind) }
    }

    /// Decodes a `ProtoStringOrSecret`.
    ///
    /// # Errors
    ///
    /// Fails with a missing-field error when `kind` is unset, or with the
    /// error from decoding the secret's ID.
    fn from_proto(proto: ProtoStringOrSecret) -> Result<Self, TryFromProtoError> {
        use proto_string_or_secret::Kind;
        match proto.kind {
            Some(Kind::String(inline)) => Ok(Self::String(inline)),
            Some(Kind::Secret(id)) => Ok(Self::Secret(GlobalId::from_proto(id)?)),
            None => Err(TryFromProtoError::missing_field(
                "ProtoStringOrSecret::kind",
            )),
        }
    }
}
impl<V: std::fmt::Display> From<V> for StringOrSecret {
fn from(v: V) -> StringOrSecret {
StringOrSecret::String(format!("{}", v))
}
}
/// Shared, process-level state that connections in this file consult when
/// they are turned into live clients.
#[derive(Debug, Clone)]
pub struct ConnectionContext {
/// Identifier of the environment; embedded in generated Kafka progress topic
/// names and in `KafkaConnection::id_base`.
pub environment_id: String,
/// Log level handed to the Kafka client configuration.
pub librdkafka_log_level: tracing::Level,
/// Optional prefix for AWS external IDs (not used in this part of the file).
pub aws_external_id_prefix: Option<AwsExternalIdPrefix>,
/// Optional AWS role ARN for connections (not used in this part of the file).
pub aws_connection_role_arn: Option<String>,
/// Reader used to fetch secret contents by ID.
pub secrets_reader: Arc<dyn SecretsReader>,
/// Optional reader for cloud resources (not used in this part of the file).
pub cloud_resource_reader: Option<Arc<dyn CloudResourceReader>>,
/// Manager for SSH tunnels; cloned into Kafka client contexts and passed to
/// Postgres connection attempts.
pub ssh_tunnel_manager: SshTunnelManager,
}
impl ConnectionContext {
    /// Assembles a context from values supplied at process startup.
    ///
    /// The librdkafka log level is derived from `startup_log_level` via
    /// `mz_ore::tracing::crate_level` for the `"librdkafka"` crate, and a
    /// default SSH tunnel manager is created.
    pub fn from_cli_args(
        environment_id: String,
        startup_log_level: &CloneableEnvFilter,
        aws_external_id_prefix: Option<AwsExternalIdPrefix>,
        aws_connection_role_arn: Option<String>,
        secrets_reader: Arc<dyn SecretsReader>,
        cloud_resource_reader: Option<Arc<dyn CloudResourceReader>>,
    ) -> ConnectionContext {
        let librdkafka_log_level =
            mz_ore::tracing::crate_level(&startup_log_level.clone().into(), "librdkafka");
        Self {
            environment_id,
            librdkafka_log_level,
            aws_external_id_prefix,
            aws_connection_role_arn,
            secrets_reader,
            cloud_resource_reader,
            ssh_tunnel_manager: SshTunnelManager::default(),
        }
    }

    /// Builds a context for tests: fixed environment ID, `INFO` librdkafka
    /// logging, no AWS or cloud-resource configuration, and a default SSH
    /// tunnel manager.
    pub fn for_tests(secrets_reader: Arc<dyn SecretsReader>) -> ConnectionContext {
        Self {
            environment_id: String::from("test-environment-id"),
            librdkafka_log_level: tracing::Level::INFO,
            aws_external_id_prefix: None,
            aws_connection_role_arn: None,
            secrets_reader,
            cloud_resource_reader: None,
            ssh_tunnel_manager: SshTunnelManager::default(),
        }
    }
}
/// The kinds of external connection described in this file.
///
/// The parameter `C` selects how nested connections are represented; it
/// defaults to `InlinedConnection`.
#[derive(Arbitrary, Clone, Debug, Eq, PartialEq, Hash, Serialize, Deserialize)]
pub enum Connection<C: ConnectionAccess = InlinedConnection> {
/// A Kafka cluster.
Kafka(KafkaConnection<C>),
/// A schema registry (CSR).
Csr(CsrConnection<C>),
/// A PostgreSQL server.
Postgres(PostgresConnection<C>),
/// An SSH connection.
Ssh(SshConnection),
/// An AWS connection.
Aws(AwsConnection),
/// An AWS PrivateLink connection.
AwsPrivatelink(AwsPrivatelinkConnection),
/// A MySQL server.
MySql(MySqlConnection<C>),
}
impl<R: ConnectionResolver> IntoInlineConnection<Connection, R>
    for Connection<ReferencedConnection>
{
    /// Converts a referenced connection into its inlined form.
    ///
    /// Kafka, CSR, Postgres and MySQL connections are inlined using `r`; the
    /// SSH, AWS and AWS PrivateLink variants are carried over unchanged.
    fn into_inline_connection(self, r: R) -> Connection {
        match self {
            // Variants that need the resolver.
            Self::Kafka(inner) => Connection::Kafka(inner.into_inline_connection(r)),
            Self::Csr(inner) => Connection::Csr(inner.into_inline_connection(r)),
            Self::Postgres(inner) => Connection::Postgres(inner.into_inline_connection(r)),
            Self::MySql(inner) => Connection::MySql(inner.into_inline_connection(r)),
            // Variants that are passed through as-is.
            Self::Ssh(inner) => Connection::Ssh(inner),
            Self::Aws(inner) => Connection::Aws(inner),
            Self::AwsPrivatelink(inner) => Connection::AwsPrivatelink(inner),
        }
    }
}
impl<C: ConnectionAccess> Connection<C> {
    /// Reports whether this connection should be validated by default, as
    /// decided by the wrapped connection type.
    pub fn validate_by_default(&self) -> bool {
        match self {
            Self::Kafka(inner) => inner.validate_by_default(),
            Self::Csr(inner) => inner.validate_by_default(),
            Self::Postgres(inner) => inner.validate_by_default(),
            Self::MySql(inner) => inner.validate_by_default(),
            Self::Ssh(inner) => inner.validate_by_default(),
            Self::Aws(inner) => inner.validate_by_default(),
            Self::AwsPrivatelink(inner) => inner.validate_by_default(),
        }
    }
}
impl Connection<InlinedConnection> {
    /// Validates this connection by delegating to the wrapped connection's own
    /// `validate` method.
    ///
    /// # Errors
    ///
    /// Returns the wrapped connection's validation error, converted into
    /// [`ConnectionValidationError`].
    pub async fn validate(
        &self,
        id: GlobalId,
        storage_configuration: &StorageConfiguration,
    ) -> Result<(), ConnectionValidationError> {
        match self {
            Connection::Kafka(conn) => conn.validate(id, storage_configuration).await?,
            Connection::Csr(conn) => conn.validate(id, storage_configuration).await?,
            Connection::Postgres(conn) => conn.validate(id, storage_configuration).await?,
            Connection::Ssh(conn) => conn.validate(id, storage_configuration).await?,
            Connection::Aws(conn) => conn.validate(id, storage_configuration).await?,
            Connection::AwsPrivatelink(conn) => conn.validate(id, storage_configuration).await?,
            Connection::MySql(conn) => conn.validate(id, storage_configuration).await?,
        }
        Ok(())
    }

    /// Returns the wrapped Kafka connection.
    ///
    /// # Panics
    ///
    /// Panics if this is not the `Kafka` variant.
    pub fn unwrap_kafka(self) -> <InlinedConnection as ConnectionAccess>::Kafka {
        match self {
            Self::Kafka(conn) => conn,
            o => unreachable!("{o:?} is not a Kafka connection"),
        }
    }

    /// Returns the wrapped Postgres connection.
    ///
    /// # Panics
    ///
    /// Panics if this is not the `Postgres` variant.
    pub fn unwrap_pg(self) -> <InlinedConnection as ConnectionAccess>::Pg {
        match self {
            Self::Postgres(conn) => conn,
            o => unreachable!("{o:?} is not a Postgres connection"),
        }
    }

    /// Returns the wrapped MySQL connection.
    ///
    /// # Panics
    ///
    /// Panics if this is not the `MySql` variant.
    pub fn unwrap_mysql(self) -> <InlinedConnection as ConnectionAccess>::MySql {
        match self {
            Self::MySql(conn) => conn,
            o => unreachable!("{o:?} is not a MySQL connection"),
        }
    }

    /// Returns the wrapped SSH connection.
    ///
    /// # Panics
    ///
    /// Panics if this is not the `Ssh` variant.
    pub fn unwrap_ssh(self) -> <InlinedConnection as ConnectionAccess>::Ssh {
        match self {
            Self::Ssh(conn) => conn,
            o => unreachable!("{o:?} is not an SSH connection"),
        }
    }

    /// Returns the wrapped schema registry (CSR) connection.
    ///
    /// # Panics
    ///
    /// Panics if this is not the `Csr` variant.
    pub fn unwrap_csr(self) -> <InlinedConnection as ConnectionAccess>::Csr {
        match self {
            Self::Csr(conn) => conn,
            // The panic message names the connection kind that was expected.
            o => unreachable!("{o:?} is not a CSR connection"),
        }
    }
}
/// Error returned when validating a [`Connection`] fails.
#[derive(thiserror::Error, Debug)]
pub enum ConnectionValidationError {
/// An AWS-specific validation failure; displayed exactly as the inner error.
#[error(transparent)]
Aws(#[from] AwsConnectionValidationError),
/// Any other failure; displayed via `display_with_causes()`.
#[error("{}", .0.display_with_causes())]
Other(#[from] anyhow::Error),
}
impl ConnectionValidationError {
pub fn detail(&self) -> Option<String> {
match self {
ConnectionValidationError::Aws(e) => e.detail(),
ConnectionValidationError::Other(_) => None,
}
}
pub fn hint(&self) -> Option<String> {
match self {
ConnectionValidationError::Aws(e) => e.hint(),
ConnectionValidationError::Other(_) => None,
}
}
}
impl<C: ConnectionAccess> AlterCompatible for Connection<C> {
fn alter_compatible(&self, id: mz_repr::GlobalId, other: &Self) -> Result<(), AlterError> {
match (self, other) {
(Self::Aws(s), Self::Aws(o)) => s.alter_compatible(id, o),
(Self::AwsPrivatelink(s), Self::AwsPrivatelink(o)) => s.alter_compatible(id, o),
(Self::Ssh(s), Self::Ssh(o)) => s.alter_compatible(id, o),
(Self::Csr(s), Self::Csr(o)) => s.alter_compatible(id, o),
(Self::Kafka(s), Self::Kafka(o)) => s.alter_compatible(id, o),
(Self::Postgres(s), Self::Postgres(o)) => s.alter_compatible(id, o),
(Self::MySql(s), Self::MySql(o)) => s.alter_compatible(id, o),
_ => {
tracing::warn!(
"Connection incompatible:\nself:\n{:#?}\n\nother\n{:#?}",
self,
other
);
Err(AlterError { id })
}
}
}
}
/// Description of an AWS PrivateLink connection.
#[derive(Arbitrary, Clone, Debug, Eq, PartialEq, Hash, Serialize, Deserialize)]
pub struct AwsPrivatelinkConnection {
/// Name of the PrivateLink service.
pub service_name: String,
/// Availability zones associated with the connection.
pub availability_zones: Vec<String>,
}
impl AlterCompatible for AwsPrivatelinkConnection {
/// Always succeeds: no field of an AWS PrivateLink connection is checked, so
/// every alteration is treated as compatible.
fn alter_compatible(&self, _id: GlobalId, _other: &Self) -> Result<(), AlterError> {
Ok(())
}
}
/// TLS settings for a Kafka connection.
///
/// When present on a `KafkaConnection`, the security protocol becomes `SSL`
/// (or `SASL_SSL` together with SASL).
#[derive(Arbitrary, Clone, Debug, Eq, PartialEq, Hash, Serialize, Deserialize)]
pub struct KafkaTlsConfig {
/// Client identity; supplies the `ssl.key.pem` and `ssl.certificate.pem`
/// client options.
pub identity: Option<TlsIdentity>,
/// Root certificate; supplies the `ssl.ca.pem` client option.
pub root_cert: Option<StringOrSecret>,
}
/// SASL settings for a Kafka connection.
///
/// When present on a `KafkaConnection`, the security protocol becomes
/// `SASL_PLAINTEXT` (or `SASL_SSL` together with TLS).
#[derive(Clone, Debug, Eq, PartialEq, Hash, Serialize, Deserialize, Arbitrary)]
pub struct KafkaSaslConfig {
/// Value of the `sasl.mechanisms` client option.
pub mechanism: String,
/// Value of the `sasl.username` client option.
pub username: StringOrSecret,
/// ID of the secret holding the `sasl.password` client option.
pub password: GlobalId,
}
/// A single Kafka broker and the tunnel used to reach it.
#[derive(Arbitrary, Clone, Debug, Eq, PartialEq, Hash, Serialize, Deserialize)]
pub struct KafkaBroker<C: ConnectionAccess = InlinedConnection> {
/// Broker address. `KafkaConnection::create_with_context` splits it at the
/// first `:` into host and port, defaulting the port to `9092`.
pub address: String,
/// Tunnel to use for this broker specifically.
pub tunnel: Tunnel<C>,
}
impl<R: ConnectionResolver> IntoInlineConnection<KafkaBroker, R>
    for KafkaBroker<ReferencedConnection>
{
    /// Inlines the broker's tunnel using `r`; the address is moved over
    /// unchanged.
    fn into_inline_connection(self, r: R) -> KafkaBroker {
        KafkaBroker {
            address: self.address,
            tunnel: self.tunnel.into_inline_connection(r),
        }
    }
}
impl RustType<ProtoKafkaBroker> for KafkaBroker {
    /// Encodes this broker as a `ProtoKafkaBroker`; `tunnel` is always set.
    fn into_proto(&self) -> ProtoKafkaBroker {
        ProtoKafkaBroker {
            address: self.address.into_proto(),
            tunnel: Some(self.tunnel.into_proto()),
        }
    }

    /// Decodes a `ProtoKafkaBroker`.
    ///
    /// # Errors
    ///
    /// Fails if `tunnel` is unset or if any field fails to decode.
    fn from_proto(proto: ProtoKafkaBroker) -> Result<Self, TryFromProtoError> {
        Ok(KafkaBroker {
            address: proto.address.into_rust()?,
            // The label names the message actually being decoded so that a
            // missing-field error points at `ProtoKafkaBroker`.
            tunnel: proto
                .tunnel
                .into_rust_if_some("ProtoKafkaBroker::tunnel")?,
        })
    }
}
/// Description of a connection to a Kafka cluster.
#[derive(Arbitrary, Clone, Debug, Eq, PartialEq, Hash, Serialize, Deserialize)]
pub struct KafkaConnection<C: ConnectionAccess = InlinedConnection> {
/// Brokers to bootstrap from, each with its own tunnel. Must be empty when
/// `default_tunnel` is an AWS PrivateLink tunnel (asserted in
/// `create_with_context`).
pub brokers: Vec<KafkaBroker<C>>,
/// Tunnel installed as the client's default tunnel.
pub default_tunnel: Tunnel<C>,
/// Explicit progress topic name; when `None` a name is generated by
/// `progress_topic()`.
pub progress_topic: Option<String>,
/// Additional client options, keyed by option name.
pub options: BTreeMap<String, StringOrSecret>,
/// Optional TLS settings.
pub tls: Option<KafkaTlsConfig>,
/// Optional SASL settings.
pub sasl: Option<KafkaSaslConfig>,
}
impl<R: ConnectionResolver> IntoInlineConnection<KafkaConnection, R>
    for KafkaConnection<ReferencedConnection>
{
    /// Inlines every broker's tunnel and the default tunnel using `r`; all
    /// other fields are moved over unchanged.
    fn into_inline_connection(self, r: R) -> KafkaConnection {
        KafkaConnection {
            // The resolver is lent out per broker so that it can be reused for
            // the default tunnel afterwards.
            brokers: self
                .brokers
                .into_iter()
                .map(|referenced| referenced.into_inline_connection(&r))
                .collect(),
            default_tunnel: self.default_tunnel.into_inline_connection(&r),
            progress_topic: self.progress_topic,
            options: self.options,
            tls: self.tls,
            sasl: self.sasl,
        }
    }
}
impl<C: ConnectionAccess> KafkaConnection<C> {
    /// Returns the name of the progress topic for this connection.
    ///
    /// An explicitly configured topic is returned borrowed. Otherwise the name
    /// `_materialize-progress-{environment_id}-{connection_id}` is generated.
    pub fn progress_topic(
        &self,
        connection_context: &ConnectionContext,
        connection_id: GlobalId,
    ) -> Cow<str> {
        match self.progress_topic.as_deref() {
            Some(configured) => Cow::Borrowed(configured),
            None => {
                let generated = format!(
                    "_materialize-progress-{}-{}",
                    connection_context.environment_id, connection_id,
                );
                Cow::Owned(generated)
            }
        }
    }

    /// Kafka connections are validated by default.
    fn validate_by_default(&self) -> bool {
        true
    }
}
impl KafkaConnection {
/// Returns the identifier base
/// `materialize-{environment_id}-{connection_id}-{object_id}`.
pub fn id_base(
connection_context: &ConnectionContext,
connection_id: GlobalId,
object_id: GlobalId,
) -> String {
format!(
"materialize-{}-{}-{}",
connection_context.environment_id, connection_id, object_id,
)
}
/// Appends suffixes to `client_id` according to the enrichment rules stored
/// in `KAFKA_CLIENT_ID_ENRICHMENT_RULES`.
///
/// The rules are decoded from JSON as a list of `{pattern, payload}` entries.
/// For every rule whose regex `pattern` matches the address of at least one
/// broker, `-` followed by the rule's `payload` is appended. If the rules
/// cannot be decoded, a warning is logged and `client_id` is left unchanged.
pub fn enrich_client_id(&self, configs: &ConfigSet, client_id: &mut String) {
// One entry of the rules list.
#[derive(Debug, Deserialize)]
struct EnrichmentRule {
#[serde(deserialize_with = "deserialize_regex")]
pattern: Regex,
payload: String,
}
// Deserializes a string and compiles it into a `Regex`, surfacing compile
// errors as deserialization errors.
fn deserialize_regex<'de, D>(deserializer: D) -> Result<Regex, D::Error>
where
D: Deserializer<'de>,
{
let buf = String::deserialize(deserializer)?;
Regex::new(&buf).map_err(serde::de::Error::custom)
}
let rules = KAFKA_CLIENT_ID_ENRICHMENT_RULES.get(configs);
let rules = match serde_json::from_value::<Vec<EnrichmentRule>>(rules) {
Ok(rules) => rules,
Err(e) => {
// Malformed rules are not fatal: log and skip enrichment entirely.
warn!(%e, "failed to decode kafka_client_id_enrichment_rules");
return;
}
};
debug!(?self.brokers, "evaluating client ID enrichment rules");
// Rules are applied in list order; every matching rule contributes a suffix.
for rule in rules {
let is_match = self
.brokers
.iter()
.any(|b| rule.pattern.is_match(&b.address));
debug!(?rule, is_match, "evaluated client ID enrichment rule");
if is_match {
client_id.push('-');
client_id.push_str(&rule.payload);
}
}
}
/// Creates a Kafka client of type `T` for this connection, wrapping `context`
/// in a `TunnelingClientContext` configured with this connection's tunnels.
///
/// The client configuration is assembled from the connection's `options`,
/// then `allow.auto.create.topics`, `bootstrap.servers`, `security.protocol`
/// and the TLS/SASL options derived from `tls` and `sasl`. Secret-valued
/// options are read through the secrets reader. `extra_options` are set last.
///
/// # Errors
///
/// Returns an error if a secret cannot be read, an SSH key pair cannot be
/// parsed, an address cannot be resolved, an SSH tunnel cannot be added, or
/// the client cannot be created from the final configuration.
///
/// # Panics
///
/// Panics if `default_tunnel` is an AWS PrivateLink tunnel and `brokers` is
/// not empty.
pub async fn create_with_context<C, T>(
&self,
storage_configuration: &StorageConfiguration,
context: C,
extra_options: &BTreeMap<&str, String>,
in_task: InTask,
) -> Result<T, ContextCreationError>
where
C: ClientContext,
T: FromClientConfigAndContext<TunnelingClientContext<C>>,
{
let mut options = self.options.clone();
// Set after cloning, so it replaces any user-supplied value for this key.
options.insert("allow.auto.create.topics".into(), "false".into());
let brokers = match &self.default_tunnel {
Tunnel::AwsPrivatelink(t) => {
// With a default PrivateLink tunnel no brokers may be listed; the
// bootstrap address is the VPC endpoint host plus the tunnel's port
// (9092 when unset).
assert!(&self.brokers.is_empty());
format!(
"{}:{}",
vpc_endpoint_host(
t.connection_id,
// No availability zone is passed for the default tunnel.
None, ),
t.port.unwrap_or(9092)
)
}
// Otherwise bootstrap from the comma-separated broker addresses.
_ => self.brokers.iter().map(|b| &b.address).join(","),
};
options.insert("bootstrap.servers".into(), brokers.into());
// The protocol follows from which of TLS and SASL are configured.
let security_protocol = match (self.tls.is_some(), self.sasl.is_some()) {
(false, false) => "PLAINTEXT",
(true, false) => "SSL",
(false, true) => "SASL_PLAINTEXT",
(true, true) => "SASL_SSL",
};
options.insert("security.protocol".into(), security_protocol.into());
if let Some(tls) = &self.tls {
if let Some(root_cert) = &tls.root_cert {
options.insert("ssl.ca.pem".into(), root_cert.clone());
}
if let Some(identity) = &tls.identity {
options.insert("ssl.key.pem".into(), StringOrSecret::Secret(identity.key));
options.insert("ssl.certificate.pem".into(), identity.cert.clone());
}
}
if let Some(sasl) = &self.sasl {
options.insert("sasl.mechanisms".into(), (&sasl.mechanism).into());
options.insert("sasl.username".into(), sasl.username.clone());
options.insert(
"sasl.password".into(),
StringOrSecret::Secret(sasl.password),
);
}
let mut config = mz_kafka_util::client::create_new_client_config(
storage_configuration
.connection_context
.librdkafka_log_level,
storage_configuration.parameters.kafka_timeout_config,
);
// Resolve every option to a plain string, reading secrets where needed.
for (k, v) in options {
config.set(
k,
v.get_string(
in_task,
&storage_configuration.connection_context.secrets_reader,
)
.await
.context("reading kafka secret")?,
);
}
// Caller-supplied options are set after the connection's own options.
// NOTE(review): presumably a later `set` overrides an earlier one for the
// same key — confirm.
for (k, v) in extra_options {
config.set(*k, v);
}
let mut context = TunnelingClientContext::new(
context,
Handle::current(),
storage_configuration
.connection_context
.ssh_tunnel_manager
.clone(),
storage_configuration.parameters.ssh_timeout_config,
in_task,
);
// Install the default tunnel, if any.
match &self.default_tunnel {
Tunnel::Direct => {
}
Tunnel::AwsPrivatelink(pl) => {
context.set_default_tunnel(TunnelConfig::StaticHost(vpc_endpoint_host(
pl.connection_id,
// No availability zone is passed for the default tunnel.
None, )));
}
Tunnel::Ssh(ssh_tunnel) => {
// The SSH key pair is stored as a secret under the connection's ID.
let secret = storage_configuration
.connection_context
.secrets_reader
.read_in_task_if(in_task, ssh_tunnel.connection_id)
.await?;
let key_pair = SshKeyPair::from_bytes(&secret)?;
let resolved = resolve_address(
&ssh_tunnel.connection.host,
ENFORCE_EXTERNAL_ADDRESSES.get(storage_configuration.config_set()),
)
.await?;
context.set_default_tunnel(TunnelConfig::Ssh(SshTunnelConfig {
host: resolved
.iter()
.map(|a| a.to_string())
.collect::<BTreeSet<_>>(),
port: ssh_tunnel.connection.port,
user: ssh_tunnel.connection.user.clone(),
key_pair,
}));
}
}
// Install per-broker tunnels and rewrites.
for broker in &self.brokers {
// Split at the first `:` only; the remainder is used as the port text,
// defaulting to "9092".
// NOTE(review): `splitn` always yields at least one item, so the
// "BROKER is not address:port" error looks unreachable — confirm.
let mut addr_parts = broker.address.splitn(2, ':');
let addr = BrokerAddr {
host: addr_parts
.next()
.context("BROKER is not address:port")?
.into(),
port: addr_parts.next().unwrap_or("9092").into(),
};
match &broker.tunnel {
Tunnel::Direct => {
}
Tunnel::AwsPrivatelink(aws_privatelink) => {
// Rewrite the broker address to the VPC endpoint host, taking the
// tunnel's availability zone and port into account.
let host = mz_cloud_resources::vpc_endpoint_host(
aws_privatelink.connection_id,
aws_privatelink.availability_zone.as_deref(),
);
let port = aws_privatelink.port;
context.add_broker_rewrite(
addr,
BrokerRewrite {
host: host.clone(),
port,
},
);
}
Tunnel::Ssh(ssh_tunnel) => {
let ssh_host_resolved = resolve_address(
&ssh_tunnel.connection.host,
ENFORCE_EXTERNAL_ADDRESSES.get(storage_configuration.config_set()),
)
.await?;
context
.add_ssh_tunnel(
addr,
SshTunnelConfig {
host: ssh_host_resolved
.iter()
.map(|a| a.to_string())
.collect::<BTreeSet<_>>(),
port: ssh_tunnel.connection.port,
user: ssh_tunnel.connection.user.clone(),
// The key pair secret is read separately for each broker.
key_pair: SshKeyPair::from_bytes(
&storage_configuration
.connection_context
.secrets_reader
.read_in_task_if(in_task, ssh_tunnel.connection_id)
.await?,
)?,
},
)
.await
.map_err(ContextCreationError::Ssh)?;
}
}
}
Ok(config.create_with_context(context)?)
}
/// Validates the connection by creating a consumer and fetching cluster
/// metadata within the configured `fetch_metadata_timeout`.
///
/// # Errors
///
/// Returns an error if the consumer cannot be created or the metadata fetch
/// fails. On a failed fetch, an error collected by the client context is
/// preferred over the error returned by `fetch_metadata`.
async fn validate(
&self,
_id: GlobalId,
storage_configuration: &StorageConfiguration,
) -> Result<(), anyhow::Error> {
let (context, error_rx) = MzClientContext::with_errors();
let consumer: BaseConsumer<_> = self
.create_with_context(
storage_configuration,
context,
&BTreeMap::new(),
InTask::No,
)
.await?;
let consumer = Arc::new(consumer);
let timeout = storage_configuration
.parameters
.kafka_timeout_config
.fetch_metadata_timeout;
// `fetch_metadata` is run on a blocking task, using a second handle to the
// consumer.
let result = mz_ore::task::spawn_blocking(|| "kafka_get_metadata", {
let consumer = Arc::clone(&consumer);
move || consumer.fetch_metadata(None, timeout)
})
.await?;
match result {
Ok(_) => Ok(()),
Err(err) => {
// Pick one of the errors reported through the client context: keep the
// first error that is not `Internal`; while the current pick is
// `Internal`, replace it with the next error received.
let main_err = error_rx.try_iter().reduce(|cur, new| match cur {
MzKafkaError::Internal(_) => new,
_ => cur,
});
// The consumer is dropped only after the error channel has been drained.
// NOTE(review): this ordering looks deliberate — confirm before moving.
drop(consumer);
// Fall back to the `fetch_metadata` error if nothing was collected.
match main_err {
Some(err) => Err(err.into()),
None => Err(err.into()),
}
}
}
}
}
impl<C: ConnectionAccess> AlterCompatible for KafkaConnection<C> {
    /// Checks whether `self` may be altered into `other`.
    ///
    /// Only `progress_topic` has to stay the same; every other field may
    /// change. On a mismatch a warning is logged and an `AlterError` for `id`
    /// is returned.
    fn alter_compatible(&self, id: GlobalId, other: &Self) -> Result<(), AlterError> {
        // Every field is named here so that adding a field to the struct
        // forces a decision about its compatibility.
        let KafkaConnection {
            progress_topic,
            brokers: _,
            default_tunnel: _,
            options: _,
            tls: _,
            sasl: _,
        } = self;

        let compatibility_checks = [(progress_topic == &other.progress_topic, "progress_topic")];
        let first_failure = compatibility_checks
            .iter()
            .find(|(compatible, _)| !*compatible);

        if let Some((_, field)) = first_failure {
            tracing::warn!(
                "KafkaConnection incompatible at {field}:\nself:\n{:#?}\n\nother\n{:#?}",
                self,
                other
            );
            return Err(AlterError { id });
        }
        Ok(())
    }
}
impl RustType<ProtoKafkaConnectionTlsConfig> for KafkaTlsConfig {
    /// Encodes the TLS settings as a `ProtoKafkaConnectionTlsConfig`.
    fn into_proto(&self) -> ProtoKafkaConnectionTlsConfig {
        let Self {
            identity,
            root_cert,
        } = self;
        ProtoKafkaConnectionTlsConfig {
            identity: identity.into_proto(),
            root_cert: root_cert.into_proto(),
        }
    }

    /// Decodes a `ProtoKafkaConnectionTlsConfig`.
    ///
    /// # Errors
    ///
    /// Fails if either field fails to decode.
    fn from_proto(proto: ProtoKafkaConnectionTlsConfig) -> Result<Self, TryFromProtoError> {
        let ProtoKafkaConnectionTlsConfig {
            identity,
            root_cert,
        } = proto;
        Ok(Self {
            root_cert: root_cert.into_rust()?,
            identity: identity.into_rust()?,
        })
    }
}
impl RustType<ProtoKafkaConnectionSaslConfig> for KafkaSaslConfig {
    /// Encodes the SASL settings; `username` and `password` are always set.
    fn into_proto(&self) -> ProtoKafkaConnectionSaslConfig {
        let Self {
            mechanism,
            username,
            password,
        } = self;
        ProtoKafkaConnectionSaslConfig {
            mechanism: mechanism.into_proto(),
            username: Some(username.into_proto()),
            password: Some(password.into_proto()),
        }
    }

    /// Decodes a `ProtoKafkaConnectionSaslConfig`.
    ///
    /// # Errors
    ///
    /// Fails if `username` or `password` is unset or fails to decode.
    fn from_proto(proto: ProtoKafkaConnectionSaslConfig) -> Result<Self, TryFromProtoError> {
        let ProtoKafkaConnectionSaslConfig {
            mechanism,
            username,
            password,
        } = proto;
        Ok(Self {
            mechanism,
            username: username.into_rust_if_some("ProtoKafkaConnectionSaslConfig::username")?,
            password: password.into_rust_if_some("ProtoKafkaConnectionSaslConfig::password")?,
        })
    }
}
impl RustType<ProtoKafkaConnection> for KafkaConnection {
    /// Encodes this connection as a `ProtoKafkaConnection`; `default_tunnel`
    /// is always set.
    fn into_proto(&self) -> ProtoKafkaConnection {
        let Self {
            brokers,
            default_tunnel,
            progress_topic,
            options,
            tls,
            sasl,
        } = self;
        ProtoKafkaConnection {
            brokers: brokers.into_proto(),
            default_tunnel: Some(default_tunnel.into_proto()),
            progress_topic: progress_topic.into_proto(),
            // Each option value is encoded individually; keys are copied.
            options: options
                .iter()
                .map(|(name, value)| (name.clone(), value.into_proto()))
                .collect(),
            tls: tls.into_proto(),
            sasl: sasl.into_proto(),
        }
    }

    /// Decodes a `ProtoKafkaConnection`.
    ///
    /// # Errors
    ///
    /// Fails if `default_tunnel` is unset or if any field or option value
    /// fails to decode.
    fn from_proto(proto: ProtoKafkaConnection) -> Result<Self, TryFromProtoError> {
        let ProtoKafkaConnection {
            brokers,
            default_tunnel,
            progress_topic,
            options,
            tls,
            sasl,
        } = proto;
        Ok(Self {
            brokers: brokers.into_rust()?,
            default_tunnel: default_tunnel
                .into_rust_if_some("ProtoKafkaConnection::default_tunnel")?,
            progress_topic,
            // Stops at the first option value that fails to decode.
            options: options
                .into_iter()
                .map(|(name, value)| {
                    StringOrSecret::from_proto(value).map(|decoded| (name, decoded))
                })
                .collect::<Result<_, _>>()?,
            tls: tls.into_rust()?,
            sasl: sasl.into_rust()?,
        })
    }
}
/// Description of a connection to a schema registry (CSR).
#[derive(Clone, Debug, Eq, PartialEq, Hash, Serialize, Deserialize, Arbitrary)]
pub struct CsrConnection<C: ConnectionAccess = InlinedConnection> {
/// URL of the schema registry.
#[proptest(strategy = "any_url()")]
pub url: Url,
/// Optional PEM root certificate added to the client's trusted roots.
pub tls_root_cert: Option<StringOrSecret>,
/// Optional client TLS identity (certificate plus key secret).
pub tls_identity: Option<TlsIdentity>,
/// Optional HTTP authentication credentials.
pub http_auth: Option<CsrConnectionHttpAuth>,
/// Tunnel through which the registry is reached.
pub tunnel: Tunnel<C>,
}
impl<R: ConnectionResolver> IntoInlineConnection<CsrConnection, R>
    for CsrConnection<ReferencedConnection>
{
    /// Inlines the tunnel using `r`; all other fields are moved over
    /// unchanged.
    fn into_inline_connection(self, r: R) -> CsrConnection {
        CsrConnection {
            tunnel: self.tunnel.into_inline_connection(r),
            url: self.url,
            tls_root_cert: self.tls_root_cert,
            tls_identity: self.tls_identity,
            http_auth: self.http_auth,
        }
    }
}
impl<C: ConnectionAccess> CsrConnection<C> {
/// Schema registry connections are validated by default.
fn validate_by_default(&self) -> bool {
true
}
}
impl CsrConnection {
    /// Builds a schema registry client for this connection.
    ///
    /// Applies the optional root certificate, client identity and HTTP
    /// credentials (reading secrets where needed), then configures how the
    /// registry host is reached depending on the tunnel:
    ///
    /// * `Direct`: the URL's host is resolved up front.
    /// * `Ssh`: an SSH tunnel to the registry is opened and the client is
    ///   pointed at the tunnel's local address.
    /// * `AwsPrivatelink`: the VPC endpoint host is looked up and used as the
    ///   address for the URL's host.
    ///
    /// # Errors
    ///
    /// Returns an error if a secret cannot be read, a certificate or identity
    /// cannot be parsed, the URL has no host, address resolution fails, the
    /// SSH tunnel cannot be established, or the client cannot be built.
    ///
    /// # Panics
    ///
    /// Panics if the tunnel is an AWS PrivateLink tunnel with a port set.
    pub async fn connect(
        &self,
        storage_configuration: &StorageConfiguration,
        in_task: InTask,
    ) -> Result<mz_ccsr::Client, CsrConnectError> {
        let mut client_config = mz_ccsr::ClientConfig::new(self.url.clone());
        if let Some(root_cert) = &self.tls_root_cert {
            let root_cert = root_cert
                .get_string(
                    in_task,
                    &storage_configuration.connection_context.secrets_reader,
                )
                .await?;
            let root_cert = Certificate::from_pem(root_cert.as_bytes())?;
            client_config = client_config.add_root_certificate(root_cert);
        }

        if let Some(tls_identity) = &self.tls_identity {
            let key = &storage_configuration
                .connection_context
                .secrets_reader
                .read_string_in_task_if(in_task, tls_identity.key)
                .await?;
            let cert = tls_identity
                .cert
                .get_string(
                    in_task,
                    &storage_configuration.connection_context.secrets_reader,
                )
                .await?;
            let ident = Identity::from_pem(key.as_bytes(), cert.as_bytes())?;
            client_config = client_config.identity(ident);
        }

        if let Some(http_auth) = &self.http_auth {
            let username = http_auth
                .username
                .get_string(
                    in_task,
                    &storage_configuration.connection_context.secrets_reader,
                )
                .await?;
            let password = match http_auth.password {
                None => None,
                Some(password) => Some(
                    storage_configuration
                        .connection_context
                        .secrets_reader
                        .read_string_in_task_if(in_task, password)
                        .await?,
                ),
            };
            client_config = client_config.auth(username, password);
        }

        // Port used when building socket addresses for `resolve_to_addrs`.
        // NOTE(review): this looks like a placeholder that `resolve_to_addrs`
        // ignores — confirm against `mz_ccsr`.
        const DUMMY_PORT: u16 = 11111;

        let host = self
            .url
            .host_str()
            .ok_or_else(|| anyhow!("url missing host"))?;
        match &self.tunnel {
            Tunnel::Direct => {
                let resolved = resolve_address(
                    host,
                    ENFORCE_EXTERNAL_ADDRESSES.get(storage_configuration.config_set()),
                )
                .await?;
                client_config = client_config.resolve_to_addrs(
                    host,
                    &resolved
                        .iter()
                        .map(|addr| SocketAddr::new(*addr, DUMMY_PORT))
                        .collect::<Vec<_>>(),
                )
            }
            Tunnel::Ssh(ssh_tunnel) => {
                let ssh_tunnel = ssh_tunnel
                    .connect(
                        storage_configuration,
                        host,
                        // `Url::port` is `None` whenever the URL uses its
                        // scheme's default port, so ask for the scheme default
                        // (e.g. 443 for https) before falling back to 80.
                        self.url.port_or_known_default().unwrap_or(80),
                        in_task,
                    )
                    .await
                    .map_err(CsrConnectError::Ssh)?;

                client_config = client_config
                    .resolve_to_addrs(
                        host,
                        &[SocketAddr::new(ssh_tunnel.local_addr().ip(), DUMMY_PORT)],
                    )
                    // The URL is recomputed on demand from the tunnel's current
                    // local port.
                    .dynamic_url({
                        let remote_url = self.url.clone();
                        move || {
                            let mut url = remote_url.clone();
                            url.set_port(Some(ssh_tunnel.local_addr().port()))
                                .expect("cannot fail");
                            url
                        }
                    });
            }
            Tunnel::AwsPrivatelink(connection) => {
                assert!(connection.port.is_none());

                let privatelink_host = mz_cloud_resources::vpc_endpoint_host(
                    connection.connection_id,
                    connection.availability_zone.as_deref(),
                );
                let addrs: Vec<_> = net::lookup_host((privatelink_host, DUMMY_PORT))
                    .await
                    .context("resolving PrivateLink host")?
                    .collect();
                client_config = client_config.resolve_to_addrs(host, &addrs)
            }
        }

        Ok(client_config.build()?)
    }

    /// Validates the connection by connecting and listing the registry's
    /// subjects.
    ///
    /// # Errors
    ///
    /// Returns an error if the client cannot be built or the listing request
    /// fails.
    async fn validate(
        &self,
        _id: GlobalId,
        storage_configuration: &StorageConfiguration,
    ) -> Result<(), anyhow::Error> {
        let client = self
            .connect(
                storage_configuration,
                InTask::No,
            )
            .await?;
        client.list_subjects().await?;
        Ok(())
    }
}
impl RustType<ProtoCsrConnection> for CsrConnection {
    /// Encodes this connection as a `ProtoCsrConnection`; `url` and `tunnel`
    /// are always set.
    fn into_proto(&self) -> ProtoCsrConnection {
        let Self {
            url,
            tls_root_cert,
            tls_identity,
            http_auth,
            tunnel,
        } = self;
        ProtoCsrConnection {
            url: Some(url.into_proto()),
            tls_root_cert: tls_root_cert.into_proto(),
            tls_identity: tls_identity.into_proto(),
            http_auth: http_auth.into_proto(),
            tunnel: Some(tunnel.into_proto()),
        }
    }

    /// Decodes a `ProtoCsrConnection`.
    ///
    /// # Errors
    ///
    /// Fails if `url` or `tunnel` is unset or if any field fails to decode.
    fn from_proto(proto: ProtoCsrConnection) -> Result<Self, TryFromProtoError> {
        let ProtoCsrConnection {
            url,
            tls_root_cert,
            tls_identity,
            http_auth,
            tunnel,
        } = proto;
        Ok(Self {
            url: url.into_rust_if_some("ProtoCsrConnection::url")?,
            tls_root_cert: tls_root_cert.into_rust()?,
            tls_identity: tls_identity.into_rust()?,
            http_auth: http_auth.into_rust()?,
            tunnel: tunnel.into_rust_if_some("ProtoCsrConnection::tunnel")?,
        })
    }
}
impl<C: ConnectionAccess> AlterCompatible for CsrConnection<C> {
    /// Checks whether `self` may be altered into `other`.
    ///
    /// Only the tunnels have to be alter-compatible; every other field may
    /// change. On a mismatch a warning is logged and an `AlterError` for `id`
    /// is returned.
    fn alter_compatible(&self, id: GlobalId, other: &Self) -> Result<(), AlterError> {
        // Every field is named here so that adding a field to the struct
        // forces a decision about its compatibility.
        let CsrConnection {
            tunnel,
            url: _,
            tls_root_cert: _,
            tls_identity: _,
            http_auth: _,
        } = self;

        let compatibility_checks = [(tunnel.alter_compatible(id, &other.tunnel).is_ok(), "tunnel")];
        let first_failure = compatibility_checks
            .iter()
            .find(|(compatible, _)| !*compatible);

        if let Some((_, field)) = first_failure {
            tracing::warn!(
                "CsrConnection incompatible at {field}:\nself:\n{:#?}\n\nother\n{:#?}",
                self,
                other
            );
            return Err(AlterError { id });
        }
        Ok(())
    }
}
/// A TLS client identity: a certificate plus the secret holding its key.
#[derive(Arbitrary, Clone, Debug, Eq, PartialEq, Hash, Serialize, Deserialize)]
pub struct TlsIdentity {
/// The client certificate, inline or stored as a secret.
pub cert: StringOrSecret,
/// ID of the secret that holds the private key.
pub key: GlobalId,
}
impl RustType<ProtoTlsIdentity> for TlsIdentity {
    /// Encodes the identity as a `ProtoTlsIdentity`; both fields are always
    /// set.
    fn into_proto(&self) -> ProtoTlsIdentity {
        let Self { cert, key } = self;
        ProtoTlsIdentity {
            cert: Some(cert.into_proto()),
            key: Some(key.into_proto()),
        }
    }

    /// Decodes a `ProtoTlsIdentity`.
    ///
    /// # Errors
    ///
    /// Fails if `cert` or `key` is unset or fails to decode.
    fn from_proto(proto: ProtoTlsIdentity) -> Result<Self, TryFromProtoError> {
        let ProtoTlsIdentity { cert, key } = proto;
        Ok(Self {
            cert: cert.into_rust_if_some("ProtoTlsIdentity::cert")?,
            key: key.into_rust_if_some("ProtoTlsIdentity::key")?,
        })
    }
}
/// HTTP authentication credentials for a schema registry connection.
#[derive(Arbitrary, Clone, Debug, Eq, PartialEq, Hash, Serialize, Deserialize)]
pub struct CsrConnectionHttpAuth {
/// The username, inline or stored as a secret.
pub username: StringOrSecret,
/// ID of the secret that holds the password, if there is one.
pub password: Option<GlobalId>,
}
impl RustType<ProtoCsrConnectionHttpAuth> for CsrConnectionHttpAuth {
    /// Encodes the credentials; `username` is always set.
    fn into_proto(&self) -> ProtoCsrConnectionHttpAuth {
        let Self { username, password } = self;
        ProtoCsrConnectionHttpAuth {
            username: Some(username.into_proto()),
            password: password.into_proto(),
        }
    }

    /// Decodes a `ProtoCsrConnectionHttpAuth`.
    ///
    /// # Errors
    ///
    /// Fails if `username` is unset or if either field fails to decode.
    fn from_proto(proto: ProtoCsrConnectionHttpAuth) -> Result<Self, TryFromProtoError> {
        let ProtoCsrConnectionHttpAuth { username, password } = proto;
        Ok(Self {
            username: username.into_rust_if_some("ProtoCsrConnectionHttpAuth::username")?,
            password: password.into_rust()?,
        })
    }
}
/// Description of a connection to a PostgreSQL server.
#[derive(Clone, Debug, Eq, PartialEq, Hash, Serialize, Deserialize, Arbitrary)]
pub struct PostgresConnection<C: ConnectionAccess = InlinedConnection> {
/// Hostname of the server.
pub host: String,
/// Port of the server.
pub port: u16,
/// Name of the database to connect to.
pub database: String,
/// The user name, inline or stored as a secret.
pub user: StringOrSecret,
/// ID of the secret that holds the password, if there is one.
pub password: Option<GlobalId>,
/// Tunnel through which the server is reached.
pub tunnel: Tunnel<C>,
/// SSL mode passed to the Postgres client configuration.
#[proptest(strategy = "any_ssl_mode()")]
pub tls_mode: SslMode,
/// Optional root certificate passed to the client as `ssl_root_cert`.
pub tls_root_cert: Option<StringOrSecret>,
/// Optional client TLS identity (certificate plus key secret).
pub tls_identity: Option<TlsIdentity>,
}
impl<R: ConnectionResolver> IntoInlineConnection<PostgresConnection, R>
    for PostgresConnection<ReferencedConnection>
{
    /// Inlines the tunnel using `r`; all other fields are moved over
    /// unchanged.
    fn into_inline_connection(self, r: R) -> PostgresConnection {
        PostgresConnection {
            tunnel: self.tunnel.into_inline_connection(r),
            host: self.host,
            port: self.port,
            database: self.database,
            user: self.user,
            password: self.password,
            tls_mode: self.tls_mode,
            tls_root_cert: self.tls_root_cert,
            tls_identity: self.tls_identity,
        }
    }
}
impl<C: ConnectionAccess> PostgresConnection<C> {
/// Postgres connections are validated by default.
fn validate_by_default(&self) -> bool {
true
}
}
impl PostgresConnection<InlinedConnection> {
/// Builds an `mz_postgres_util::Config` for this connection.
///
/// Host, port, database, user and SSL mode are always set. The password,
/// root certificate and client identity are set when configured, reading
/// secrets through `secrets_reader`. The tunnel configuration is derived
/// from `self.tunnel`, and timeouts come from `storage_configuration`.
///
/// # Errors
///
/// Returns an error if a secret cannot be read, the SSH key pair cannot be
/// parsed, an address cannot be resolved, or the final config is rejected.
///
/// # Panics
///
/// Panics if the tunnel is an AWS PrivateLink tunnel with a port set.
pub async fn config(
&self,
secrets_reader: &Arc<dyn mz_secrets::SecretsReader>,
storage_configuration: &StorageConfiguration,
in_task: InTask,
) -> Result<mz_postgres_util::Config, anyhow::Error> {
let mut config = tokio_postgres::Config::new();
config
.host(&self.host)
.port(self.port)
.dbname(&self.database)
.user(&self.user.get_string(in_task, secrets_reader).await?)
.ssl_mode(self.tls_mode);
if let Some(password) = self.password {
let password = secrets_reader
.read_string_in_task_if(in_task, password)
.await?;
config.password(password);
}
if let Some(tls_root_cert) = &self.tls_root_cert {
let tls_root_cert = tls_root_cert.get_string(in_task, secrets_reader).await?;
config.ssl_root_cert(tls_root_cert.as_bytes());
}
if let Some(tls_identity) = &self.tls_identity {
let cert = tls_identity
.cert
.get_string(in_task, secrets_reader)
.await?;
let key = secrets_reader
.read_string_in_task_if(in_task, tls_identity.key)
.await?;
config.ssl_cert(cert.as_bytes()).ssl_key(key.as_bytes());
}
let tunnel = match &self.tunnel {
Tunnel::Direct => {
// Resolve the server's own host up front and pass the addresses along.
let resolved = resolve_address(
&self.host,
ENFORCE_EXTERNAL_ADDRESSES.get(storage_configuration.config_set()),
)
.await?;
mz_postgres_util::TunnelConfig::Direct {
resolved_ips: Some(resolved),
}
}
Tunnel::Ssh(SshTunnel {
connection_id,
connection,
}) => {
// The SSH key pair is stored as a secret under the connection's ID.
let secret = secrets_reader
.read_in_task_if(in_task, *connection_id)
.await?;
let key_pair = SshKeyPair::from_bytes(&secret)?;
// Here it is the SSH host that gets resolved, not the Postgres host.
let resolved = resolve_address(
&connection.host,
ENFORCE_EXTERNAL_ADDRESSES.get(storage_configuration.config_set()),
)
.await?;
mz_postgres_util::TunnelConfig::Ssh {
config: SshTunnelConfig {
host: resolved
.iter()
.map(|a| a.to_string())
.collect::<BTreeSet<_>>(),
port: connection.port,
user: connection.user.clone(),
key_pair,
},
}
}
Tunnel::AwsPrivatelink(connection) => {
// A port on the PrivateLink tunnel is not supported here.
assert!(connection.port.is_none());
mz_postgres_util::TunnelConfig::AwsPrivatelink {
connection_id: connection.connection_id,
}
}
};
Ok(mz_postgres_util::Config::new(
config,
tunnel,
storage_configuration
.parameters
.pg_source_tcp_timeouts
.clone(),
storage_configuration.parameters.ssh_timeout_config,
in_task,
)?)
}
/// Validates the connection by building its config and opening a connection
/// with it.
///
/// # Errors
///
/// Returns an error if the config cannot be built or the connection attempt
/// fails.
async fn validate(
&self,
_id: GlobalId,
storage_configuration: &StorageConfiguration,
) -> Result<(), anyhow::Error> {
let config = self
.config(
&storage_configuration.connection_context.secrets_reader,
storage_configuration,
InTask::No,
)
.await?;
config
.connect(
"connection validation",
&storage_configuration.connection_context.ssh_tunnel_manager,
)
.await?;
Ok(())
}
}
impl<C: ConnectionAccess> AlterCompatible for PostgresConnection<C> {
    /// Checks whether `self` may be altered into `other`.
    ///
    /// Only the tunnels have to be alter-compatible; every other field may
    /// change. On a mismatch a warning is logged and an `AlterError` for `id`
    /// is returned.
    fn alter_compatible(&self, id: GlobalId, other: &Self) -> Result<(), AlterError> {
        // Every field is named here so that adding a field to the struct
        // forces a decision about its compatibility.
        let PostgresConnection {
            tunnel,
            host: _,
            port: _,
            database: _,
            user: _,
            password: _,
            tls_mode: _,
            tls_root_cert: _,
            tls_identity: _,
        } = self;

        let compatibility_checks = [(tunnel.alter_compatible(id, &other.tunnel).is_ok(), "tunnel")];
        let first_failure = compatibility_checks
            .iter()
            .find(|(compatible, _)| !*compatible);

        if let Some((_, field)) = first_failure {
            tracing::warn!(
                "PostgresConnection incompatible at {field}:\nself:\n{:#?}\n\nother\n{:#?}",
                self,
                other
            );
            return Err(AlterError { id });
        }
        Ok(())
    }
}
impl RustType<ProtoPostgresConnection> for PostgresConnection {
    /// Encodes this connection as a `ProtoPostgresConnection`; `user`,
    /// `tls_mode` and `tunnel` are always set.
    fn into_proto(&self) -> ProtoPostgresConnection {
        let Self {
            host,
            port,
            database,
            user,
            password,
            tunnel,
            tls_mode,
            tls_root_cert,
            tls_identity,
        } = self;
        ProtoPostgresConnection {
            host: host.into_proto(),
            port: port.into_proto(),
            database: database.into_proto(),
            user: Some(user.into_proto()),
            password: password.into_proto(),
            tls_mode: Some(tls_mode.into_proto()),
            tls_root_cert: tls_root_cert.into_proto(),
            tls_identity: tls_identity.into_proto(),
            tunnel: Some(tunnel.into_proto()),
        }
    }

    /// Decodes a `ProtoPostgresConnection`.
    ///
    /// # Errors
    ///
    /// Fails if `user`, `tunnel` or `tls_mode` is unset or if any field fails
    /// to decode.
    fn from_proto(proto: ProtoPostgresConnection) -> Result<Self, TryFromProtoError> {
        let ProtoPostgresConnection {
            host,
            port,
            database,
            user,
            password,
            tls_mode,
            tls_root_cert,
            tls_identity,
            tunnel,
        } = proto;
        Ok(Self {
            host,
            port: port.into_rust()?,
            database,
            user: user.into_rust_if_some("ProtoPostgresConnection::user")?,
            password: password.into_rust()?,
            tunnel: tunnel.into_rust_if_some("ProtoPostgresConnection::tunnel")?,
            tls_mode: tls_mode.into_rust_if_some("ProtoPostgresConnection::tls_mode")?,
            tls_root_cert: tls_root_cert.into_rust()?,
            tls_identity: tls_identity.into_rust()?,
        })
    }
}
/// Selects how a connection's traffic is tunneled.
///
/// The type parameter `C` determines how the SSH variant holds its connection
/// (see [`SshTunnel`]); it defaults to `InlinedConnection`.
#[derive(Arbitrary, Clone, Debug, Eq, PartialEq, Hash, Serialize, Deserialize)]
pub enum Tunnel<C: ConnectionAccess = InlinedConnection> {
/// No tunnel.
Direct,
/// Tunnel through the SSH connection described by the contained [`SshTunnel`].
Ssh(SshTunnel<C>),
/// Tunnel through the AWS PrivateLink endpoint described by the contained
/// [`AwsPrivatelink`].
AwsPrivatelink(AwsPrivatelink),
}
impl<R: ConnectionResolver> IntoInlineConnection<Tunnel, R> for Tunnel<ReferencedConnection> {
    /// Converts a `Tunnel<ReferencedConnection>` into the inlined form.
    ///
    /// Only the SSH variant needs the resolver `r`; the other variants are carried
    /// over unchanged.
    fn into_inline_connection(self, r: R) -> Tunnel {
        match self {
            Self::Direct => Tunnel::Direct,
            Self::AwsPrivatelink(privatelink) => Tunnel::AwsPrivatelink(privatelink),
            Self::Ssh(referenced) => Tunnel::Ssh(referenced.into_inline_connection(r)),
        }
    }
}
/// Protobuf round-tripping for an inlined [`Tunnel`].
impl RustType<ProtoTunnel> for Tunnel<InlinedConnection> {
    /// Encodes the variant into the `tunnel` oneof, which is always populated.
    fn into_proto(&self) -> ProtoTunnel {
        use proto_tunnel::Tunnel as ProtoTunnelField;

        let field = match self {
            Tunnel::Direct => ProtoTunnelField::Direct(()),
            Tunnel::Ssh(ssh) => ProtoTunnelField::Ssh(ssh.into_proto()),
            Tunnel::AwsPrivatelink(aws) => ProtoTunnelField::AwsPrivatelink(aws.into_proto()),
        };
        ProtoTunnel {
            tunnel: Some(field),
        }
    }

    /// Decodes the `tunnel` oneof.
    ///
    /// # Errors
    ///
    /// Returns a missing-field error for `ProtoTunnel::tunnel` when the oneof is unset,
    /// or whatever error the nested conversion produces.
    fn from_proto(proto: ProtoTunnel) -> Result<Self, TryFromProtoError> {
        use proto_tunnel::Tunnel as ProtoTunnelField;

        let Some(field) = proto.tunnel else {
            return Err(TryFromProtoError::missing_field("ProtoTunnel::tunnel"));
        };
        Ok(match field {
            ProtoTunnelField::Direct(()) => Tunnel::Direct,
            ProtoTunnelField::Ssh(ssh) => Tunnel::Ssh(ssh.into_rust()?),
            ProtoTunnelField::AwsPrivatelink(aws) => Tunnel::AwsPrivatelink(aws.into_rust()?),
        })
    }
}
impl<C: ConnectionAccess> AlterCompatible for Tunnel<C> {
    /// Decides whether two tunnels are compatible.
    ///
    /// Two SSH tunnels defer to [`SshTunnel`]'s own compatibility check. Every other
    /// pairing is compared with `==`, so differing variants are incompatible and two
    /// `AwsPrivatelink` tunnels must agree on every field.
    ///
    /// # Errors
    ///
    /// Returns `AlterError { id }` (after logging a warning with both values) when the
    /// tunnels are incompatible.
    fn alter_compatible(&self, id: GlobalId, other: &Self) -> Result<(), AlterError> {
        let compatible = if let (Self::Ssh(ours), Self::Ssh(theirs)) = (self, other) {
            ours.alter_compatible(id, theirs).is_ok()
        } else {
            self == other
        };

        if compatible {
            return Ok(());
        }

        tracing::warn!(
            "Tunnel incompatible:\nself:\n{:#?}\n\nother\n{:#?}",
            self,
            other
        );
        Err(AlterError { id })
    }
}
/// TLS mode of a [`MySqlConnection`].
///
/// The per-variant notes describe how `MySqlConnection::config` translates each mode
/// into `mysql_async::SslOpts`.
#[derive(Arbitrary, Clone, Debug, Eq, PartialEq, Hash, Serialize, Deserialize)]
pub enum MySqlSslMode {
/// No SSL options are set.
Disabled,
/// SSL options are set, accepting invalid certificates and skipping domain
/// validation.
Required,
/// SSL options are set, skipping domain validation only.
VerifyCa,
/// Default SSL options are used, with nothing skipped.
VerifyIdentity,
}
/// Protobuf round-tripping for [`MySqlSslMode`], which is carried as a raw `i32`.
impl RustType<i32> for MySqlSslMode {
    /// Maps the mode to the matching `ProtoMySqlSslMode` variant and returns its
    /// integer value.
    fn into_proto(&self) -> i32 {
        let proto_mode = match self {
            MySqlSslMode::Disabled => ProtoMySqlSslMode::Disabled,
            MySqlSslMode::Required => ProtoMySqlSslMode::Required,
            MySqlSslMode::VerifyCa => ProtoMySqlSslMode::VerifyCa,
            MySqlSslMode::VerifyIdentity => ProtoMySqlSslMode::VerifyIdentity,
        };
        proto_mode.into()
    }

    /// Decodes the integer into a [`MySqlSslMode`].
    ///
    /// # Errors
    ///
    /// Returns `TryFromProtoError::UnknownEnumVariant("tls_mode")` when the integer
    /// does not name a `ProtoMySqlSslMode` variant.
    fn from_proto(proto: i32) -> Result<Self, TryFromProtoError> {
        let Some(proto_mode) = ProtoMySqlSslMode::from_i32(proto) else {
            return Err(TryFromProtoError::UnknownEnumVariant(
                "tls_mode".to_string(),
            ));
        };
        Ok(match proto_mode {
            ProtoMySqlSslMode::Disabled => MySqlSslMode::Disabled,
            ProtoMySqlSslMode::Required => MySqlSslMode::Required,
            ProtoMySqlSslMode::VerifyCa => MySqlSslMode::VerifyCa,
            ProtoMySqlSslMode::VerifyIdentity => MySqlSslMode::VerifyIdentity,
        })
    }
}
/// Proptest strategy that picks one of the four [`MySqlSslMode`] variants.
pub fn any_mysql_ssl_mode() -> impl Strategy<Value = MySqlSslMode> {
    let modes = vec![
        MySqlSslMode::Disabled,
        MySqlSslMode::Required,
        MySqlSslMode::VerifyCa,
        MySqlSslMode::VerifyIdentity,
    ];
    proptest::sample::select(modes)
}
/// Parameters describing a connection to a MySQL server.
///
/// The type parameter `C` only affects how the `tunnel` field holds its SSH
/// connection; it defaults to `InlinedConnection`.
#[derive(Clone, Debug, Eq, PartialEq, Hash, Serialize, Deserialize, Arbitrary)]
pub struct MySqlConnection<C: ConnectionAccess = InlinedConnection> {
/// Host name or IP address of the server.
pub host: String,
/// TCP port of the server.
pub port: u16,
/// User name to authenticate as.
pub user: StringOrSecret,
/// Id of the secret holding the password, if any.
pub password: Option<GlobalId>,
/// Tunnel to use when connecting.
pub tunnel: Tunnel<C>,
/// TLS mode; generated by `any_mysql_ssl_mode` in property tests.
#[proptest(strategy = "any_mysql_ssl_mode()")]
pub tls_mode: MySqlSslMode,
/// Optional root certificate. `config` only applies it in the `VerifyCa` and
/// `VerifyIdentity` modes.
pub tls_root_cert: Option<StringOrSecret>,
/// Optional client TLS identity (key and certificate).
pub tls_identity: Option<TlsIdentity>,
}
impl<R: ConnectionResolver> IntoInlineConnection<MySqlConnection, R>
    for MySqlConnection<ReferencedConnection>
{
    /// Converts a `MySqlConnection<ReferencedConnection>` into the inlined form.
    ///
    /// The tunnel is inlined through the resolver `r`; all remaining fields are moved
    /// across unchanged.
    fn into_inline_connection(self, r: R) -> MySqlConnection {
        let tunnel = self.tunnel.into_inline_connection(r);
        MySqlConnection {
            host: self.host,
            port: self.port,
            user: self.user,
            password: self.password,
            tunnel,
            tls_mode: self.tls_mode,
            tls_root_cert: self.tls_root_cert,
            tls_identity: self.tls_identity,
        }
    }
}
impl<C: ConnectionAccess> MySqlConnection<C> {
/// Always returns `true`.
///
/// NOTE(review): presumably tells the caller whether this connection type should be
/// validated unless told otherwise; confirm against the call site.
fn validate_by_default(&self) -> bool {
true
}
}
impl MySqlConnection<InlinedConnection> {
/// Builds the `mz_mysql_util::Config` used to connect to this MySQL server.
///
/// Resolves the user, password and TLS material through `secrets_reader`, translates
/// `tls_mode` into `mysql_async::SslOpts`, derives the tunnel configuration from
/// `self.tunnel`, and applies the MySQL source timeouts from `storage_configuration`.
///
/// # Errors
///
/// Propagates failures from secret reads, PEM-to-PKCS#12 conversion, SSH key
/// parsing, address resolution and timeout application.
///
/// # Panics
///
/// Panics if the tunnel is `Tunnel::AwsPrivatelink` and its `port` is set.
pub async fn config(
&self,
secrets_reader: &Arc<dyn mz_secrets::SecretsReader>,
storage_configuration: &StorageConfiguration,
in_task: InTask,
) -> Result<mz_mysql_util::Config, anyhow::Error> {
// Start with host, port and the user name resolved to a plain string.
let mut opts = mysql_async::OptsBuilder::default()
.ip_or_hostname(&self.host)
.tcp_port(self.port)
.user(Some(&self.user.get_string(in_task, secrets_reader).await?));
// The password, when present, is identified by a secret id and read through the
// secrets reader.
if let Some(password) = self.password {
let password = secrets_reader
.read_string_in_task_if(in_task, password)
.await?;
opts = opts.pass(Some(password));
}
// Translate the TLS mode into SSL options; `None` means no SSL options are set.
let mut ssl_opts = match self.tls_mode {
MySqlSslMode::Disabled => None,
MySqlSslMode::Required => Some(
mysql_async::SslOpts::default()
.with_danger_accept_invalid_certs(true)
.with_danger_skip_domain_validation(true),
),
MySqlSslMode::VerifyCa => {
Some(mysql_async::SslOpts::default().with_danger_skip_domain_validation(true))
}
MySqlSslMode::VerifyIdentity => Some(mysql_async::SslOpts::default()),
};
// A configured root certificate is only applied in the two verifying modes; in
// the other modes it is not even read.
if matches!(
self.tls_mode,
MySqlSslMode::VerifyCa | MySqlSslMode::VerifyIdentity
) {
if let Some(tls_root_cert) = &self.tls_root_cert {
let tls_root_cert = tls_root_cert.get_string(in_task, secrets_reader).await?;
ssl_opts = ssl_opts.map(|opts| {
opts.with_root_certs(vec![tls_root_cert.as_bytes().to_vec().into()])
});
}
}
// The client identity's key and certificate are read and converted in every mode,
// but because `ssl_opts` is `None` when TLS is disabled, the `map` below then
// discards the result.
if let Some(identity) = &self.tls_identity {
let key = secrets_reader
.read_string_in_task_if(in_task, identity.key)
.await?;
let cert = identity.cert.get_string(in_task, secrets_reader).await?;
// Convert the PEM key and certificate into a password-protected PKCS#12 archive.
let Pkcs12Archive { der, pass } =
mz_tls_util::pkcs12der_from_pem(key.as_bytes(), cert.as_bytes())?;
ssl_opts = ssl_opts.map(|opts| {
opts.with_client_identity(Some(
mysql_async::ClientIdentity::new(der.into()).with_password(pass),
))
});
}
opts = opts.ssl_opts(ssl_opts);
// Derive the tunnel configuration.
let tunnel = match &self.tunnel {
Tunnel::Direct => {
// Resolve the server host, passing along the `ENFORCE_EXTERNAL_ADDRESSES`
// setting.
let resolved = resolve_address(
&self.host,
ENFORCE_EXTERNAL_ADDRESSES.get(storage_configuration.config_set()),
)
.await?;
mz_mysql_util::TunnelConfig::Direct {
resolved_ips: Some(resolved),
}
}
Tunnel::Ssh(SshTunnel {
connection_id,
connection,
}) => {
// The secret stored under the SSH connection's id holds the key pair bytes.
let secret = secrets_reader
.read_in_task_if(in_task, *connection_id)
.await?;
let key_pair = SshKeyPair::from_bytes(&secret)?;
// Here it is the SSH host, not the MySQL host, that gets resolved.
let resolved = resolve_address(
&connection.host,
ENFORCE_EXTERNAL_ADDRESSES.get(storage_configuration.config_set()),
)
.await?;
mz_mysql_util::TunnelConfig::Ssh {
config: SshTunnelConfig {
host: resolved
.iter()
.map(|a| a.to_string())
.collect::<BTreeSet<_>>(),
port: connection.port,
user: connection.user.clone(),
key_pair,
},
}
}
Tunnel::AwsPrivatelink(connection) => {
// Panics if the PrivateLink tunnel carries a port.
assert!(connection.port.is_none());
mz_mysql_util::TunnelConfig::AwsPrivatelink {
connection_id: connection.connection_id,
}
}
};
// Apply the configured MySQL source timeouts on top of the options built so far.
opts = storage_configuration
.parameters
.mysql_source_timeouts
.apply_to_opts(opts)?;
Ok(mz_mysql_util::Config::new(
opts.into(),
tunnel,
storage_configuration.parameters.ssh_timeout_config,
in_task,
))
}
/// Checks that this connection can actually be established.
///
/// Builds the configuration with `InTask::No`, connects through the SSH tunnel
/// manager from the connection context, and then disconnects again.
///
/// # Errors
///
/// Propagates any error from building the configuration, connecting or
/// disconnecting.
async fn validate(
&self,
_id: GlobalId,
storage_configuration: &StorageConfiguration,
) -> Result<(), anyhow::Error> {
let config = self
.config(
&storage_configuration.connection_context.secrets_reader,
storage_configuration,
InTask::No,
)
.await?;
let conn = config
.connect(
"connection validation",
&storage_configuration.connection_context.ssh_tunnel_manager,
)
.await?;
// Disconnect explicitly; a failure to do so is reported as a validation error.
conn.disconnect().await?;
Ok(())
}
}
/// Protobuf round-tripping for [`MySqlConnection`].
impl RustType<ProtoMySqlConnection> for MySqlConnection {
    /// Encodes every field; `user` and `tunnel` are wrapped in `Some`.
    fn into_proto(&self) -> ProtoMySqlConnection {
        let MySqlConnection {
            host,
            port,
            user,
            password,
            tunnel,
            tls_mode,
            tls_root_cert,
            tls_identity,
        } = self;
        ProtoMySqlConnection {
            host: host.into_proto(),
            port: port.into_proto(),
            user: Some(user.into_proto()),
            password: password.into_proto(),
            tls_mode: tls_mode.into_proto(),
            tls_root_cert: tls_root_cert.into_proto(),
            tls_identity: tls_identity.into_proto(),
            tunnel: Some(tunnel.into_proto()),
        }
    }

    /// Decodes a [`MySqlConnection`], stopping at the first field that fails to
    /// convert.
    ///
    /// `user` and `tunnel` go through `into_rust_if_some`, which is handed the field's
    /// name (presumably to report it when the field is absent).
    fn from_proto(proto: ProtoMySqlConnection) -> Result<Self, TryFromProtoError> {
        let ProtoMySqlConnection {
            host,
            port,
            user,
            password,
            tls_mode,
            tls_root_cert,
            tls_identity,
            tunnel,
        } = proto;

        // Conversions run in this fixed order so the first failing field determines
        // the returned error.
        let port = port.into_rust()?;
        let user = user.into_rust_if_some("ProtoMySqlConnection::user")?;
        let password = password.into_rust()?;
        let tunnel = tunnel.into_rust_if_some("ProtoMySqlConnection::tunnel")?;
        let tls_mode = tls_mode.into_rust()?;
        let tls_root_cert = tls_root_cert.into_rust()?;
        let tls_identity = tls_identity.into_rust()?;

        Ok(MySqlConnection {
            host,
            port,
            user,
            password,
            tunnel,
            tls_mode,
            tls_root_cert,
            tls_identity,
        })
    }
}
impl<C: ConnectionAccess> AlterCompatible for MySqlConnection<C> {
    /// Decides whether `self` and `other` are compatible.
    ///
    /// Only the `tunnel` field takes part in the decision; every other field is
    /// explicitly ignored.
    ///
    /// # Errors
    ///
    /// Returns `AlterError { id }` (after logging a warning with both values) when the
    /// tunnels are incompatible.
    fn alter_compatible(&self, id: GlobalId, other: &Self) -> Result<(), AlterError> {
        // Destructure exhaustively so that a newly added field fails to compile here
        // until someone decides whether it belongs in the checks.
        let MySqlConnection {
            tunnel,
            host: _,
            port: _,
            user: _,
            password: _,
            tls_mode: _,
            tls_root_cert: _,
            tls_identity: _,
        } = self;

        let compatibility_checks = [(tunnel.alter_compatible(id, &other.tunnel).is_ok(), "tunnel")];
        let first_failure = compatibility_checks
            .into_iter()
            .find(|(compatible, _)| !compatible);

        if let Some((_, field)) = first_failure {
            tracing::warn!(
                "MySqlConnection incompatible at {field}:\nself:\n{:#?}\n\nother\n{:#?}",
                self,
                other
            );
            return Err(AlterError { id });
        }
        Ok(())
    }
}
/// Parameters describing an SSH server used as a tunnel endpoint.
#[derive(Arbitrary, Clone, Debug, Eq, PartialEq, Hash, Serialize, Deserialize)]
pub struct SshConnection {
/// Host name or IP address of the SSH server.
pub host: String,
/// TCP port of the SSH server.
pub port: u16,
/// User name to log in as.
pub user: String,
/// Optional pair of public keys; the protobuf encoding names the first element
/// the primary key and the second the secondary key.
pub public_keys: Option<(String, String)>,
}
use proto_ssh_connection::ProtoPublicKeys;
use self::inline::{
ConnectionAccess, ConnectionResolver, InlinedConnection, IntoInlineConnection,
ReferencedConnection,
};
/// Protobuf round-tripping for a `(primary, secondary)` pair of public keys.
impl RustType<ProtoPublicKeys> for (String, String) {
    /// Encodes element 0 as the primary key and element 1 as the secondary key.
    fn into_proto(&self) -> ProtoPublicKeys {
        let (primary, secondary) = self;
        ProtoPublicKeys {
            primary_public_key: primary.into_proto(),
            secondary_public_key: secondary.into_proto(),
        }
    }

    /// Decodes the pair; this conversion never fails.
    fn from_proto(proto: ProtoPublicKeys) -> Result<Self, TryFromProtoError> {
        let ProtoPublicKeys {
            primary_public_key,
            secondary_public_key,
        } = proto;
        Ok((primary_public_key, secondary_public_key))
    }
}
/// Protobuf round-tripping for [`SshConnection`].
impl RustType<ProtoSshConnection> for SshConnection {
    /// Encodes every field of the connection.
    fn into_proto(&self) -> ProtoSshConnection {
        let SshConnection {
            host,
            port,
            user,
            public_keys,
        } = self;
        ProtoSshConnection {
            host: host.into_proto(),
            port: port.into_proto(),
            user: user.into_proto(),
            public_keys: public_keys.into_proto(),
        }
    }

    /// Decodes an [`SshConnection`]; `port` is converted before `public_keys`, so a
    /// bad port is reported first.
    fn from_proto(proto: ProtoSshConnection) -> Result<Self, TryFromProtoError> {
        let ProtoSshConnection {
            host,
            port,
            user,
            public_keys,
        } = proto;
        let port = port.into_rust()?;
        let public_keys = public_keys.into_rust()?;
        Ok(SshConnection {
            host,
            port,
            user,
            public_keys,
        })
    }
}
impl AlterCompatible for SshConnection {
/// Always succeeds: no field of an [`SshConnection`] is compared, so any two
/// values are treated as compatible.
fn alter_compatible(&self, _id: GlobalId, _other: &Self) -> Result<(), AlterError> {
Ok(())
}
}
/// Parameters for tunneling through an AWS PrivateLink connection.
#[derive(Arbitrary, Clone, Debug, Eq, PartialEq, Hash, Serialize, Deserialize)]
pub struct AwsPrivatelink {
/// Id of the AWS PrivateLink connection to tunnel through.
pub connection_id: GlobalId,
/// Optional availability zone.
pub availability_zone: Option<String>,
/// Optional port. The Postgres and MySQL `config` methods assert that it is
/// `None`.
pub port: Option<u16>,
}
/// Protobuf round-tripping for [`AwsPrivatelink`].
impl RustType<ProtoAwsPrivatelink> for AwsPrivatelink {
    /// Encodes every field; `connection_id` is wrapped in `Some`.
    fn into_proto(&self) -> ProtoAwsPrivatelink {
        let AwsPrivatelink {
            connection_id,
            availability_zone,
            port,
        } = self;
        ProtoAwsPrivatelink {
            connection_id: Some(connection_id.into_proto()),
            availability_zone: availability_zone.into_proto(),
            port: port.into_proto(),
        }
    }

    /// Decodes an [`AwsPrivatelink`], converting `connection_id`, then
    /// `availability_zone`, then `port`, and stopping at the first failure.
    fn from_proto(proto: ProtoAwsPrivatelink) -> Result<Self, TryFromProtoError> {
        let ProtoAwsPrivatelink {
            connection_id,
            availability_zone,
            port,
        } = proto;
        let connection_id =
            connection_id.into_rust_if_some("ProtoAwsPrivatelink::connection_id")?;
        let availability_zone = availability_zone.into_rust()?;
        let port = port.into_rust()?;
        Ok(AwsPrivatelink {
            connection_id,
            availability_zone,
            port,
        })
    }
}
impl AlterCompatible for AwsPrivatelink {
    /// Decides whether `self` and `other` are compatible.
    ///
    /// Only `connection_id` must match; `availability_zone` and `port` are explicitly
    /// ignored.
    ///
    /// # Errors
    ///
    /// Returns `AlterError { id }` (after logging a warning with both values) when the
    /// connection ids differ.
    fn alter_compatible(&self, id: GlobalId, other: &Self) -> Result<(), AlterError> {
        // Destructure exhaustively so that a newly added field fails to compile here
        // until someone decides whether it belongs in the checks.
        let AwsPrivatelink {
            connection_id,
            availability_zone: _,
            port: _,
        } = self;

        let compatibility_checks = [(connection_id == &other.connection_id, "connection_id")];
        let first_failure = compatibility_checks
            .into_iter()
            .find(|(compatible, _)| !compatible);

        if let Some((_, field)) = first_failure {
            tracing::warn!(
                "AwsPrivatelink incompatible at {field}:\nself:\n{:#?}\n\nother\n{:#?}",
                self,
                other
            );
            return Err(AlterError { id });
        }
        Ok(())
    }
}
/// An SSH tunnel: the id of an SSH connection together with the connection itself.
///
/// How the connection is held is chosen by `C` through the associated type
/// `C::Ssh`; `C` defaults to `InlinedConnection`.
#[derive(Arbitrary, Clone, Debug, Eq, PartialEq, Hash, Serialize, Deserialize)]
pub struct SshTunnel<C: ConnectionAccess = InlinedConnection> {
/// Id of the SSH connection. It is also the id under which the connection's key
/// pair secret is read.
pub connection_id: GlobalId,
/// The SSH connection, in the representation selected by `C`.
pub connection: C::Ssh,
}
impl<R: ConnectionResolver> IntoInlineConnection<SshTunnel, R> for SshTunnel<ReferencedConnection> {
    /// Converts an `SshTunnel<ReferencedConnection>` into the inlined form.
    ///
    /// The referenced connection is looked up through `r` and unwrapped as an SSH
    /// connection; `connection_id` is carried over unchanged.
    fn into_inline_connection(self, r: R) -> SshTunnel {
        let resolved = r.resolve_connection(self.connection);
        SshTunnel {
            connection_id: self.connection_id,
            connection: resolved.unwrap_ssh(),
        }
    }
}
/// Protobuf round-tripping for an inlined [`SshTunnel`].
impl RustType<ProtoSshTunnel> for SshTunnel<InlinedConnection> {
    /// Encodes both fields, each wrapped in `Some`.
    fn into_proto(&self) -> ProtoSshTunnel {
        let SshTunnel {
            connection_id,
            connection,
        } = self;
        ProtoSshTunnel {
            connection_id: Some(connection_id.into_proto()),
            connection: Some(connection.into_proto()),
        }
    }

    /// Decodes an [`SshTunnel`], converting `connection_id` before `connection` and
    /// stopping at the first failure.
    fn from_proto(proto: ProtoSshTunnel) -> Result<Self, TryFromProtoError> {
        let ProtoSshTunnel {
            connection_id,
            connection,
        } = proto;
        let connection_id = connection_id.into_rust_if_some("ProtoSshTunnel::connection_id")?;
        let connection = connection.into_rust_if_some("ProtoSshTunnel::connection")?;
        Ok(SshTunnel {
            connection_id,
            connection,
        })
    }
}
impl SshTunnel<InlinedConnection> {
    /// Opens a managed SSH tunnel to `remote_host:remote_port` through this tunnel's
    /// SSH connection.
    ///
    /// The SSH host is resolved first (passing along the `ENFORCE_EXTERNAL_ADDRESSES`
    /// setting). The key pair is then read from the secret stored under
    /// `connection_id`, and the tunnel is requested from the SSH tunnel manager in the
    /// connection context.
    ///
    /// # Errors
    ///
    /// Propagates failures from address resolution, the secret read, key pair parsing
    /// and the tunnel manager.
    async fn connect(
        &self,
        storage_configuration: &StorageConfiguration,
        remote_host: &str,
        remote_port: u16,
        in_task: InTask,
    ) -> Result<ManagedSshTunnelHandle, anyhow::Error> {
        let connection_context = &storage_configuration.connection_context;

        let enforce_external =
            ENFORCE_EXTERNAL_ADDRESSES.get(storage_configuration.config_set());
        let resolved = resolve_address(&self.connection.host, enforce_external).await?;
        let host = resolved
            .iter()
            .map(|a| a.to_string())
            .collect::<BTreeSet<_>>();

        let secret = connection_context
            .secrets_reader
            .read_in_task_if(in_task, self.connection_id)
            .await?;
        let key_pair = SshKeyPair::from_bytes(&secret)?;

        let tunnel_config = SshTunnelConfig {
            host,
            port: self.connection.port,
            user: self.connection.user.clone(),
            key_pair,
        };

        connection_context
            .ssh_tunnel_manager
            .connect(
                tunnel_config,
                remote_host,
                remote_port,
                storage_configuration.parameters.ssh_timeout_config,
                in_task,
            )
            .await
    }
}
impl<C: ConnectionAccess> AlterCompatible for SshTunnel<C> {
    /// Decides whether `self` and `other` are compatible.
    ///
    /// The connection ids must be equal and the connections themselves must pass
    /// their own compatibility check. Both checks are evaluated up front; the first
    /// failing one, in that order, is named in the warning.
    ///
    /// # Errors
    ///
    /// Returns `AlterError { id }` (after logging a warning with both values) when
    /// either check fails.
    fn alter_compatible(&self, id: GlobalId, other: &Self) -> Result<(), AlterError> {
        let SshTunnel {
            connection_id,
            connection,
        } = self;

        let ids_match = connection_id == &other.connection_id;
        let connections_compatible = connection.alter_compatible(id, &other.connection).is_ok();
        let compatibility_checks = [
            (ids_match, "connection_id"),
            (connections_compatible, "connection"),
        ];
        let first_failure = compatibility_checks
            .into_iter()
            .find(|(compatible, _)| !compatible);

        if let Some((_, field)) = first_failure {
            tracing::warn!(
                "SshTunnel incompatible at {field}:\nself:\n{:#?}\n\nother\n{:#?}",
                self,
                other
            );
            return Err(AlterError { id });
        }
        Ok(())
    }
}
impl SshConnection {
    /// Checks that this SSH connection is usable.
    ///
    /// Reads the key pair from the secret stored under `id` (using `InTask::No`),
    /// resolves the SSH host (passing along the `ENFORCE_EXTERNAL_ADDRESSES`
    /// setting), and asks the resulting `SshTunnelConfig` to validate itself with the
    /// configured SSH timeouts.
    ///
    /// # Errors
    ///
    /// Propagates failures from the secret read, key pair parsing, address resolution
    /// and the tunnel configuration's own validation.
    // The former `#[allow(clippy::unused_async)]` was dropped: the body awaits, so
    // the lint cannot fire and the suppression only hid future regressions.
    async fn validate(
        &self,
        id: GlobalId,
        storage_configuration: &StorageConfiguration,
    ) -> Result<(), anyhow::Error> {
        let secret = storage_configuration
            .connection_context
            .secrets_reader
            .read_in_task_if(InTask::No, id)
            .await?;
        let key_pair = SshKeyPair::from_bytes(&secret)?;

        let resolved = resolve_address(
            &self.host,
            ENFORCE_EXTERNAL_ADDRESSES.get(storage_configuration.config_set()),
        )
        .await?;

        let config = SshTunnelConfig {
            host: resolved
                .iter()
                .map(|a| a.to_string())
                .collect::<BTreeSet<_>>(),
            port: self.port,
            user: self.user.clone(),
            key_pair,
        };
        config
            .validate(storage_configuration.parameters.ssh_timeout_config)
            .await
    }

    /// Always returns `false`.
    ///
    /// NOTE(review): presumably tells the caller whether this connection type should
    /// be validated unless told otherwise; confirm against the call site.
    fn validate_by_default(&self) -> bool {
        false
    }
}
impl AwsPrivatelinkConnection {
    /// Checks that the PrivateLink endpoint identified by `id` is available.
    ///
    /// Reads the endpoint's status through the cloud resource reader and looks for
    /// the condition whose type is `"Available"`; validation succeeds only when that
    /// condition's status is `"True"`.
    ///
    /// # Errors
    ///
    /// - No cloud resource reader is configured.
    /// - Reading the status fails.
    /// - The `"Available"` condition is present with another status; the error
    ///   carries the condition's message.
    /// - The `"Available"` condition is absent.
    // The former `#[allow(clippy::unused_async)]` was dropped: the body awaits, so
    // the lint cannot fire and the suppression only hid future regressions.
    async fn validate(
        &self,
        id: GlobalId,
        storage_configuration: &StorageConfiguration,
    ) -> Result<(), anyhow::Error> {
        let Some(ref cloud_resource_reader) = storage_configuration
            .connection_context
            .cloud_resource_reader
        else {
            return Err(anyhow!("AWS PrivateLink connections are unsupported"));
        };

        let status = cloud_resource_reader.read(id).await?;

        let availability = status
            .conditions
            .as_ref()
            .and_then(|conditions| conditions.iter().find(|c| c.type_ == "Available"));

        match availability {
            Some(condition) if condition.status == "True" => Ok(()),
            Some(condition) => Err(anyhow!("{}", condition.message)),
            None => Err(anyhow!("Endpoint availability is unknown")),
        }
    }

    /// Always returns `false`.
    ///
    /// NOTE(review): presumably tells the caller whether this connection type should
    /// be validated unless told otherwise; confirm against the call site.
    fn validate_by_default(&self) -> bool {
        false
    }
}