mz_mysql_util/
aws_rds.rs
1use aws_config::{Region, SdkConfig};
11use aws_credential_types::provider::ProvideCredentials;
12use aws_credential_types::provider::error::CredentialsError;
13use aws_sigv4::http_request::{
14 self, SignableBody, SignableRequest, SignatureLocation, SigningError, SigningSettings,
15};
16use aws_sigv4::sign::v4::SigningParams;
17use aws_sigv4::sign::v4::signing_params::BuildError;
18use std::time::Duration;
19
/// Lifetime requested for the presigned token, in seconds (15 minutes).
const AWS_TOKEN_EXPIRATION_SECONDS: u64 = 15 * 60;
/// Service name handed to the SigV4 signing parameters.
const AWS_SERVICE: &str = "rds-db";
/// Value of the `Action` query parameter in the request that gets signed.
const DB_ACTION: &str = "connect";
23
24#[derive(Debug, thiserror::Error)]
25pub enum RdsTokenError {
26 #[error("Credentials provider required to sign RDS IAM request")]
28 MissingCredentialsProvider,
29
30 #[error(transparent)]
32 CredentialsError(#[from] CredentialsError),
33
34 #[error(transparent)]
36 SigningParametersBuildError(#[from] BuildError),
37
38 #[error(transparent)]
40 SigningError(#[from] SigningError),
41}
42
43pub(crate) async fn rds_auth_token(
46 host: &str,
47 port: u16,
48 username: &str,
49 aws_config: &SdkConfig,
50) -> Result<String, RdsTokenError> {
51 let credentials = aws_config
52 .credentials_provider()
53 .ok_or(RdsTokenError::MissingCredentialsProvider)?
54 .provide_credentials()
55 .await?;
56
57 let region = aws_config
59 .region()
60 .cloned()
61 .unwrap_or_else(|| Region::new("us-east-1"));
62
63 let time_source = aws_config.time_source().unwrap_or_default();
65
66 let mut signing_settings = SigningSettings::default();
67 signing_settings.expires_in = Some(Duration::from_secs(AWS_TOKEN_EXPIRATION_SECONDS));
68 signing_settings.signature_location = SignatureLocation::QueryParams;
69
70 let identity = credentials.into();
71 let signing_params = SigningParams::builder()
72 .identity(&identity)
73 .region(region.as_ref())
74 .name(AWS_SERVICE)
75 .time(time_source.now())
76 .settings(signing_settings)
77 .build()?;
78
79 let url = format!(
80 "https://{}:{}/?Action={}&DBUser={}",
81 host, port, DB_ACTION, username
82 );
83
84 let signable_req =
85 SignableRequest::new("GET", &url, std::iter::empty(), SignableBody::Bytes(&[]))?;
86
87 let (signing_instructions, _sig) =
88 http_request::sign(signable_req, &signing_params.into())?.into_parts();
89
90 let mut url = url::Url::parse(&url).unwrap_or_else(|_| panic!("expect to parse {url}"));
91 for (key, val) in signing_instructions.params() {
92 url.query_pairs_mut().append_pair(key, val);
93 }
94
95 Ok(url.to_string().split_off("https://".len()))
96}
97
#[cfg(test)]
mod test {
    use aws_credential_types::Credentials;
    use aws_types::sdk_config::{SharedCredentialsProvider, TimeSource};
    use std::time::SystemTime;

    use super::*;

    /// Time source that always reports the same instant, making the
    /// generated signature reproducible.
    #[derive(Debug)]
    struct FixedTimeSource {
        instant: SystemTime,
    }

    impl TimeSource for FixedTimeSource {
        fn now(&self) -> SystemTime {
            self.instant
        }
    }

    /// Signs a token with static credentials at a fixed instant and checks
    /// the complete token string, signature included.
    #[mz_ore::test(tokio::test)]
    #[cfg_attr(miri, ignore)]
    async fn test_signature() {
        // 1740690000 seconds after the epoch is 2025-02-27T21:00:00Z, which is
        // the `X-Amz-Date` that appears in the expected token below.
        let frozen_clock = FixedTimeSource {
            instant: SystemTime::UNIX_EPOCH + Duration::from_secs(1_740_690_000),
        };
        let static_credentials = Credentials::new("drjekyll", "mrhyde", None, None, "test");

        // No region is configured, so the `us-east-1` fallback is exercised.
        let config = SdkConfig::builder()
            .credentials_provider(SharedCredentialsProvider::new(static_credentials))
            .time_source(frozen_clock)
            .build();

        let token = rds_auth_token("mysql", 3306, "root", &config)
            .await
            .unwrap();

        let expected = "mysql:3306/?Action=connect&DBUser=root&X-Amz-Algorithm=AWS4-HMAC-SHA256&X-Amz-Credential=drjekyll%2F20250227%2Fus-east-1%2Frds-db%2Faws4_request&X-Amz-Date=20250227T210000Z&X-Amz-Expires=900&X-Amz-SignedHeaders=host&X-Amz-Signature=6eb04929394feeb9e070621ecafc731145914cc53e542d5302c251b705c4ac72";
        assert_eq!(token, expected);
    }
}