1use aws_config::{Region, SdkConfig};
11use aws_credential_types::provider::ProvideCredentials;
12use aws_credential_types::provider::error::CredentialsError;
13use aws_sigv4::http_request::{
14 self, SignableBody, SignableRequest, SignatureLocation, SigningError, SigningSettings,
15};
16use aws_sigv4::sign::v4::SigningParams;
17use aws_sigv4::sign::v4::signing_params::BuildError;
18use std::time::Duration;
19
/// Validity window of a generated token, in seconds (fifteen minutes).
const AWS_TOKEN_EXPIRATION_SECONDS: u64 = 15 * 60;
/// SigV4 service name that RDS IAM authentication requests are signed for.
const AWS_SERVICE: &str = "rds-db";
/// Value placed in the `Action` query parameter of the signed request.
const DB_ACTION: &str = "connect";
23
/// Errors that can occur while generating an RDS IAM authentication token
/// (see `rds_auth_token`).
#[derive(Debug, thiserror::Error)]
pub enum RdsTokenError {
    /// The supplied `SdkConfig` has no credentials provider, so there is nothing to sign with.
    #[error("Credentials provider required to sign RDS IAM request")]
    MissingCredentialsProvider,

    /// The credentials provider failed to produce credentials.
    #[error(transparent)]
    CredentialsError(#[from] CredentialsError),

    /// The SigV4 signing parameters could not be built.
    #[error(transparent)]
    SigningParametersBuildError(#[from] BuildError),

    /// Constructing the signable request or computing the SigV4 signature failed.
    #[error(transparent)]
    SigningError(#[from] SigningError),

    /// The host/port pair does not form a valid endpoint URL.
    #[error("invalid RDS endpoint: {0}")]
    InvalidEndpoint(#[from] url::ParseError),
}
46
47pub(crate) async fn rds_auth_token(
50 host: &str,
51 port: u16,
52 username: &str,
53 aws_config: &SdkConfig,
54) -> Result<String, RdsTokenError> {
55 let credentials = aws_config
56 .credentials_provider()
57 .ok_or(RdsTokenError::MissingCredentialsProvider)?
58 .provide_credentials()
59 .await?;
60
61 let region = aws_config
63 .region()
64 .cloned()
65 .unwrap_or_else(|| Region::new("us-east-1"));
66
67 let time_source = aws_config.time_source().unwrap_or_default();
69
70 let mut signing_settings = SigningSettings::default();
71 signing_settings.expires_in = Some(Duration::from_secs(AWS_TOKEN_EXPIRATION_SECONDS));
72 signing_settings.signature_location = SignatureLocation::QueryParams;
73
74 let identity = credentials.into();
75 let signing_params = SigningParams::builder()
76 .identity(&identity)
77 .region(region.as_ref())
78 .name(AWS_SERVICE)
79 .time(time_source.now())
80 .settings(signing_settings)
81 .build()?;
82
83 let mut url = url::Url::parse(&format!("https://{}:{}/", host, port))?;
84 url.query_pairs_mut()
85 .append_pair("Action", DB_ACTION)
86 .append_pair("DBUser", username);
87
88 let url_str = url.as_str().to_owned();
89 let signable_req = SignableRequest::new(
90 "GET",
91 &url_str,
92 std::iter::empty(),
93 SignableBody::Bytes(&[]),
94 )?;
95
96 let (signing_instructions, _sig) =
97 http_request::sign(signable_req, &signing_params.into())?.into_parts();
98
99 for (key, val) in signing_instructions.params() {
100 url.query_pairs_mut().append_pair(key, val);
101 }
102
103 Ok(url.as_str()["https://".len()..].to_owned())
104}
105
#[cfg(test)]
mod test {
    use aws_credential_types::Credentials;
    use aws_types::sdk_config::{SharedCredentialsProvider, TimeSource};
    use std::time::SystemTime;

    use super::*;

    /// A [`TimeSource`] pinned to a fixed instant so that generated signatures are
    /// deterministic.
    #[derive(Debug)]
    struct TestTimeSource {
        time: SystemTime,
    }

    impl TimeSource for TestTimeSource {
        fn now(&self) -> SystemTime {
            // `SystemTime` is `Copy`, so no clone is needed.
            self.time
        }
    }

    #[mz_ore::test(tokio::test)]
    #[cfg_attr(miri, ignore)]
    async fn test_signature() {
        // 1740690000 s after the epoch is 2025-02-27T21:00:00Z.
        let time_source = TestTimeSource {
            time: SystemTime::UNIX_EPOCH + Duration::from_secs(1740690000),
        };
        // No region is configured, so the token is signed for the `us-east-1` fallback.
        let aws_config = SdkConfig::builder()
            .credentials_provider(SharedCredentialsProvider::new(Credentials::new(
                "drjekyll", "mrhyde", None, None, "test",
            )))
            .time_source(time_source)
            .build();

        // Golden token for fixed credentials, time and region.
        let signature = rds_auth_token("mysql", 3306, "root", &aws_config)
            .await
            .unwrap();
        assert_eq!(
            &signature,
            "mysql:3306/?Action=connect&DBUser=root&X-Amz-Algorithm=AWS4-HMAC-SHA256&X-Amz-Credential=drjekyll%2F20250227%2Fus-east-1%2Frds-db%2Faws4_request&X-Amz-Date=20250227T210000Z&X-Amz-Expires=900&X-Amz-SignedHeaders=host&X-Amz-Signature=6eb04929394feeb9e070621ecafc731145914cc53e542d5302c251b705c4ac72"
        );

        // Usernames containing URL metacharacters must round-trip through `DBUser` unchanged
        // and must not inject extra query parameters.
        for (username, reject_key) in [("root&Foo=bar", Some("Foo")), ("root#admin", None)] {
            let token = rds_auth_token("mysql", 3306, username, &aws_config)
                .await
                .unwrap();
            let parsed = url::Url::parse(&format!("https://{token}")).unwrap();
            let pairs: std::collections::BTreeMap<_, _> = parsed.query_pairs().collect();
            assert_eq!(pairs.get("DBUser").unwrap(), username);
            if let Some(key) = reject_key {
                assert!(!pairs.contains_key(key));
            }
        }

        // An empty host cannot form a valid endpoint.
        assert!(rds_auth_token("", 3306, "root", &aws_config).await.is_err());
    }
}