use anyhow::{anyhow, bail};
use aws_config::sts::AssumeRoleProvider;
use aws_credential_types::provider::{ProvideCredentials, SharedCredentialsProvider};
use aws_credential_types::Credentials;
use aws_sdk_sts::error::SdkError;
use aws_sdk_sts::operation::get_caller_identity::GetCallerIdentityError;
use aws_types::region::Region;
use aws_types::SdkConfig;
use mz_ore::error::ErrorExt;
use mz_ore::future::{InTask, OreFutureExt};
use mz_proto::{IntoRustIfSome, ProtoType, RustType, TryFromProtoError};
use mz_repr::GlobalId;
use proptest_derive::Arbitrary;
use serde::{Deserialize, Serialize};
use serde_json::json;
use crate::connections::inline::{
ConnectionAccess, ConnectionResolver, InlinedConnection, IntoInlineConnection,
ReferencedConnection,
};
use crate::controller::AlterError;
use crate::AlterCompatible;
use crate::{
configuration::StorageConfiguration,
connections::{ConnectionContext, StringOrSecret},
};
include!(concat!(
env!("OUT_DIR"),
"/mz_storage_types.connections.aws.rs"
));
/// An AWS connection: how to authenticate, plus optional region and endpoint
/// overrides that are applied when building an AWS SDK config for it.
#[derive(Arbitrary, Clone, Debug, Eq, PartialEq, Serialize, Deserialize, Hash)]
pub struct AwsConnection {
/// How credentials for this connection are obtained.
pub auth: AwsAuth,
/// Region to configure on the SDK config loader; left unset when `None`.
pub region: Option<String>,
/// Custom endpoint URL to configure on the SDK config loader; left unset
/// when `None`.
pub endpoint: Option<String>,
}
impl RustType<ProtoAwsConnection> for AwsConnection {
fn into_proto(&self) -> ProtoAwsConnection {
let auth = match &self.auth {
AwsAuth::Credentials(credentials) => {
proto_aws_connection::Auth::Credentials(credentials.into_proto())
}
AwsAuth::AssumeRole(assume_role) => {
proto_aws_connection::Auth::AssumeRole(assume_role.into_proto())
}
};
ProtoAwsConnection {
auth: Some(auth),
region: self.region.clone(),
endpoint: self.endpoint.clone(),
}
}
fn from_proto(proto: ProtoAwsConnection) -> Result<Self, TryFromProtoError> {
let auth = match proto.auth.expect("auth expected") {
proto_aws_connection::Auth::Credentials(credentials) => {
AwsAuth::Credentials(credentials.into_rust()?)
}
proto_aws_connection::Auth::AssumeRole(assume_role) => {
AwsAuth::AssumeRole(assume_role.into_rust()?)
}
};
Ok(AwsConnection {
auth,
region: proto.region,
endpoint: proto.endpoint,
})
}
}
impl AlterCompatible for AwsConnection {
    /// Every alteration of an AWS connection is accepted: neither the
    /// connection ID nor the replacement value is inspected.
    fn alter_compatible(&self, _id: GlobalId, _other: &Self) -> Result<(), AlterError> {
        Ok(())
    }
}
/// The authentication method used by an [`AwsConnection`].
#[derive(Arbitrary, Clone, Debug, Eq, PartialEq, Serialize, Deserialize, Hash)]
pub enum AwsAuth {
/// Static access-key credentials.
Credentials(AwsCredentials),
/// Assume an IAM role, chained through the environment's configured AWS
/// connection role.
AssumeRole(AwsAssumeRole),
}
/// Static AWS credentials, with secret material held in the secrets store.
#[derive(Arbitrary, Clone, Debug, Eq, PartialEq, Serialize, Deserialize, Hash)]
pub struct AwsCredentials {
/// Access key ID, given either inline or as a reference to a secret.
pub access_key_id: StringOrSecret,
/// ID of the secret holding the secret access key.
pub secret_access_key: GlobalId,
/// Optional session token, given either inline or as a reference to a secret.
pub session_token: Option<StringOrSecret>,
}
impl AwsCredentials {
    /// Resolves the access key ID, secret access key, and optional session
    /// token, then wraps them in static [`Credentials`].
    ///
    /// Values are read in that order. Any read failure is reported as a
    /// generic "internal error" without the underlying cause.
    ///
    /// NOTE(review): `in_task` is honored for the access key ID and session
    /// token, but the secret access key is always read via `read_string`.
    async fn load_credentials_provider(
        &self,
        connection_context: &ConnectionContext,
        in_task: InTask,
    ) -> Result<impl ProvideCredentials, anyhow::Error> {
        let reader = &connection_context.secrets_reader;

        let access_key_id = self
            .access_key_id
            .get_string(in_task, reader)
            .await
            .map_err(|_| {
                anyhow!("internal error: failed to read access key ID from secret store")
            })?;

        let secret_access_key = reader
            .read_string(self.secret_access_key)
            .await
            .map_err(|_| {
                anyhow!("internal error: failed to read secret access key from secret store")
            })?;

        let session_token = if let Some(token) = &self.session_token {
            let resolved = token.get_string(in_task, reader).await.map_err(|_| {
                anyhow!("internal error: failed to read session token from secret store")
            })?;
            Some(resolved)
        } else {
            None
        };

        Ok(Credentials::from_keys(
            access_key_id,
            secret_access_key,
            session_token,
        ))
    }
}
impl RustType<ProtoAwsCredentials> for AwsCredentials {
    /// Encodes the credentials; the two required fields are always `Some`.
    fn into_proto(&self) -> ProtoAwsCredentials {
        let Self {
            access_key_id,
            secret_access_key,
            session_token,
        } = self;
        ProtoAwsCredentials {
            access_key_id: Some(access_key_id.into_proto()),
            secret_access_key: Some(secret_access_key.into_proto()),
            session_token: session_token.into_proto(),
        }
    }

    /// Decodes credentials, failing if a required field is absent.
    fn from_proto(proto: ProtoAwsCredentials) -> Result<Self, TryFromProtoError> {
        let ProtoAwsCredentials {
            access_key_id,
            secret_access_key,
            session_token,
        } = proto;
        Ok(Self {
            access_key_id: access_key_id
                .into_rust_if_some("ProtoAwsCredentials::access_key_id")?,
            secret_access_key: secret_access_key
                .into_rust_if_some("ProtoAwsCredentials::secret_access_key")?,
            session_token: session_token.into_rust()?,
        })
    }
}
/// Configuration for authenticating by assuming an IAM role.
#[derive(Arbitrary, Clone, Debug, Eq, PartialEq, Serialize, Deserialize, Hash)]
pub struct AwsAssumeRole {
/// ARN of the role to assume.
pub arn: String,
/// Session name to use when assuming `arn`. When `None`, it defaults to
/// `"{environment_id}-{connection_id}"`.
pub session_name: Option<String>,
}
impl AwsAssumeRole {
async fn load_credentials_provider(
&self,
connection_context: &ConnectionContext,
connection_id: GlobalId,
) -> Result<impl ProvideCredentials, anyhow::Error> {
let external_id = self.external_id(connection_context, connection_id)?;
self.dangerously_load_credentials_provider(
connection_context,
connection_id,
Some(external_id),
)
.await
}
async fn dangerously_load_credentials_provider(
&self,
connection_context: &ConnectionContext,
connection_id: GlobalId,
external_id: Option<String>,
) -> Result<impl ProvideCredentials, anyhow::Error> {
let Some(aws_connection_role_arn) = &connection_context.aws_connection_role_arn else {
bail!("internal error: no AWS connection role configured");
};
let assume_role_sdk_config = mz_aws_util::defaults().load().await;
let default_session_name =
format!("{}-{}", &connection_context.environment_id, connection_id);
let jump_credentials = AssumeRoleProvider::builder(aws_connection_role_arn)
.configure(&assume_role_sdk_config)
.session_name(default_session_name.clone())
.build()
.await;
let mut credentials = AssumeRoleProvider::builder(&self.arn)
.configure(&assume_role_sdk_config)
.session_name(self.session_name.clone().unwrap_or(default_session_name));
if let Some(external_id) = external_id {
credentials = credentials.external_id(external_id);
}
Ok(credentials.build_from_provider(jump_credentials).await)
}
pub fn external_id(
&self,
connection_context: &ConnectionContext,
connection_id: GlobalId,
) -> Result<String, anyhow::Error> {
let Some(aws_external_id_prefix) = &connection_context.aws_external_id_prefix else {
bail!("internal error: no AWS external ID prefix configured");
};
Ok(format!("mz_{}_{}", aws_external_id_prefix, connection_id))
}
pub fn example_trust_policy(
&self,
connection_context: &ConnectionContext,
connection_id: GlobalId,
) -> Result<serde_json::Value, anyhow::Error> {
let Some(aws_connection_role_arn) = &connection_context.aws_connection_role_arn else {
bail!("internal error: no AWS connection role configured");
};
Ok(json!(
{
"Version": "2012-10-17",
"Statement": [
{
"Effect": "Allow",
"Principal": {
"AWS": aws_connection_role_arn
},
"Action": "sts:AssumeRole",
"Condition": {
"StringEquals": {
"sts:ExternalId": self.external_id(connection_context, connection_id)?
}
}
}
]
}
))
}
}
impl RustType<ProtoAwsAssumeRole> for AwsAssumeRole {
    /// Encodes the role configuration field-for-field.
    fn into_proto(&self) -> ProtoAwsAssumeRole {
        let Self { arn, session_name } = self;
        ProtoAwsAssumeRole {
            arn: arn.clone(),
            session_name: session_name.clone(),
        }
    }

    /// Decodes the role configuration; this conversion cannot fail.
    fn from_proto(proto: ProtoAwsAssumeRole) -> Result<Self, TryFromProtoError> {
        let ProtoAwsAssumeRole { arn, session_name } = proto;
        Ok(Self { arn, session_name })
    }
}
impl AwsConnection {
    /// Builds an AWS SDK config for this connection.
    ///
    /// Credentials are resolved according to `self.auth`, and then region and
    /// endpoint overrides are applied. When `in_task` requests it, the work
    /// runs in a separate task labelled `load_sdk_config`.
    pub async fn load_sdk_config(
        &self,
        connection_context: &ConnectionContext,
        connection_id: GlobalId,
        in_task: InTask,
    ) -> Result<SdkConfig, anyhow::Error> {
        // Owned copies so the `async move` future does not borrow from the
        // caller (presumably required to run it as a separate task).
        let ctx = connection_context.clone();
        let connection = self.clone();
        let load = async move {
            let provider = match &connection.auth {
                AwsAuth::Credentials(static_credentials) => {
                    let inner = static_credentials
                        .load_credentials_provider(&ctx, InTask::No)
                        .await?;
                    SharedCredentialsProvider::new(inner)
                }
                AwsAuth::AssumeRole(role) => {
                    let inner = role.load_credentials_provider(&ctx, connection_id).await?;
                    SharedCredentialsProvider::new(inner)
                }
            };
            connection.load_sdk_config_from_credentials(provider).await
        };
        load.run_in_task_if(in_task, || "load_sdk_config".to_string())
            .await
    }

    /// Loads an SDK config from the `mz_aws_util` defaults using
    /// `credentials`. It also applies `self.region` and `self.endpoint`
    /// when they are set.
    async fn load_sdk_config_from_credentials(
        &self,
        credentials: impl ProvideCredentials + 'static,
    ) -> Result<SdkConfig, anyhow::Error> {
        let loader = mz_aws_util::defaults().credentials_provider(credentials);
        let loader = match &self.region {
            Some(region) => loader.region(Region::new(region.clone())),
            None => loader,
        };
        let loader = match &self.endpoint {
            Some(endpoint) => loader.endpoint_url(endpoint),
            None => loader,
        };
        Ok(loader.load().await)
    }

    /// Validates the connection against AWS.
    ///
    /// 1. The connection, as configured, must pass an STS `GetCallerIdentity`
    ///    call.
    /// 2. For assume-role connections, the role is also assumed *without* an
    ///    external ID. If that succeeds, the role's trust policy is deemed
    ///    insecure and `RoleDoesNotRequireExternalId` is returned. If that
    ///    second call fails, validation passes.
    pub(crate) async fn validate(
        &self,
        id: GlobalId,
        storage_configuration: &StorageConfiguration,
    ) -> Result<(), AwsConnectionValidationError> {
        let ctx = &storage_configuration.connection_context;

        let sdk_config = self.load_sdk_config(ctx, id, InTask::No).await?;
        let sts = aws_sdk_sts::Client::new(&sdk_config);
        let _ = sts.get_caller_identity().send().await?;

        let AwsAuth::AssumeRole(role) = &self.auth else {
            return Ok(());
        };

        // Deliberately omit the external ID: a correctly configured role must
        // reject this attempt.
        let unguarded_provider = role
            .dangerously_load_credentials_provider(ctx, id, None)
            .await?;
        let unguarded_config = self
            .load_sdk_config_from_credentials(unguarded_provider)
            .await?;
        let unguarded_sts = aws_sdk_sts::Client::new(&unguarded_config);
        let probe = unguarded_sts.get_caller_identity().send().await;
        match probe {
            Ok(_) => Err(AwsConnectionValidationError::RoleDoesNotRequireExternalId {
                role_arn: role.arn.clone(),
            }),
            Err(_) => Ok(()),
        }
    }

    /// AWS connections are not validated by default.
    pub(crate) fn validate_by_default(&self) -> bool {
        false
    }
}
/// Errors produced while validating an [`AwsConnection`].
#[derive(thiserror::Error, Debug)]
pub enum AwsConnectionValidationError {
/// Assuming the connection's role succeeded even though no external ID was
/// supplied; `role_arn` identifies that role.
#[error("role trust policy does not require an external ID")]
RoleDoesNotRequireExternalId { role_arn: String },
/// The STS `GetCallerIdentity` call made with the connection's configured
/// credentials failed.
#[error("{}", .0.display_with_causes())]
StsGetCallerIdentityError(#[from] SdkError<GetCallerIdentityError>),
/// Any other failure, e.g. while loading credentials or the SDK config.
#[error("{}", .0.display_with_causes())]
Other(#[from] anyhow::Error),
}
impl AwsConnectionValidationError {
pub fn detail(&self) -> Option<String> {
match self {
AwsConnectionValidationError::RoleDoesNotRequireExternalId {
role_arn
} => Some(format!("The trust policy for the connection's role ({role_arn}) is insecure and allows any Materialize customer to assume the role.")),
_ => None
}
}
pub fn hint(&self) -> Option<String> {
match self {
AwsConnectionValidationError::RoleDoesNotRequireExternalId { .. } => {
Some("See: https://materialize.com/s/aws-connection-role-trust-policy".into())
}
_ => None,
}
}
}
/// An AWS connection together with its connection `GlobalId`.
///
/// `C` controls the form of `connection`. [`InlinedConnection`] (the default)
/// holds the resolved connection. [`ReferencedConnection`] holds a reference
/// that `into_inline_connection` resolves.
#[derive(Arbitrary, Clone, Debug, Eq, PartialEq, Hash, Serialize, Deserialize)]
pub struct AwsConnectionReference<C: ConnectionAccess = InlinedConnection> {
/// ID of the connection.
pub connection_id: GlobalId,
/// The connection, in the form selected by `C`.
pub connection: C::Aws,
}
impl<R: ConnectionResolver> IntoInlineConnection<AwsConnectionReference, R>
    for AwsConnectionReference<ReferencedConnection>
{
    /// Resolves the referenced connection through `r` and keeps the same
    /// connection ID.
    ///
    /// NOTE(review): `unwrap_aws` presumably panics if the resolved connection
    /// is not an AWS connection; confirm.
    fn into_inline_connection(self, r: R) -> AwsConnectionReference {
        let inlined = r.resolve_connection(self.connection).unwrap_aws();
        AwsConnectionReference {
            connection: inlined,
            connection_id: self.connection_id,
        }
    }
}
impl RustType<ProtoAwsConnectionReference> for AwsConnectionReference<InlinedConnection> {
    /// Encodes the reference; both fields are always populated.
    fn into_proto(&self) -> ProtoAwsConnectionReference {
        let Self {
            connection_id,
            connection,
        } = self;
        ProtoAwsConnectionReference {
            connection_id: Some(connection_id.into_proto()),
            connection: Some(connection.into_proto()),
        }
    }

    /// Decodes the reference, failing if either field is absent. The
    /// `connection_id` field is checked first.
    fn from_proto(proto: ProtoAwsConnectionReference) -> Result<Self, TryFromProtoError> {
        let ProtoAwsConnectionReference {
            connection_id,
            connection,
        } = proto;
        Ok(Self {
            connection_id: connection_id
                .into_rust_if_some("ProtoAwsConnectionReference::connection_id")?,
            connection: connection
                .into_rust_if_some("ProtoAwsConnectionReference::connection")?,
        })
    }
}