Skip to main content

mz_mysql_util/
aws_rds.rs

// Copyright Materialize, Inc. and contributors. All rights reserved.
//
// Use of this software is governed by the Business Source License
// included in the LICENSE file.
//
// As of the Change Date specified in that file, in accordance with
// the Business Source License, use of this software will be governed
// by the Apache License, Version 2.0.

10use aws_config::{Region, SdkConfig};
11use aws_credential_types::provider::ProvideCredentials;
12use aws_credential_types::provider::error::CredentialsError;
13use aws_sigv4::http_request::{
14    self, SignableBody, SignableRequest, SignatureLocation, SigningError, SigningSettings,
15};
16use aws_sigv4::sign::v4::SigningParams;
17use aws_sigv4::sign::v4::signing_params::BuildError;
18use std::time::Duration;
19
/// Service name placed in the SigV4 credential scope of an RDS IAM token.
const AWS_SERVICE: &str = "rds-db";
/// Value sent as the `Action` query parameter of the token URL.
const DB_ACTION: &str = "connect";
/// How long a generated token stays valid, in seconds (15 minutes).
const AWS_TOKEN_EXPIRATION_SECONDS: u64 = 15 * 60;

24#[derive(Debug, thiserror::Error)]
25pub enum RdsTokenError {
26    /// The AWS SdkConfig did not contain a credentials provider.
27    #[error("Credentials provider required to sign RDS IAM request")]
28    MissingCredentialsProvider,
29
30    /// The supplied credentials provider failed to generate credentials for URL signing.
31    #[error(transparent)]
32    CredentialsError(#[from] CredentialsError),
33
34    /// The signing parameters could not be created due to a missing required argument.
35    #[error(transparent)]
36    SigningParametersBuildError(#[from] BuildError),
37
38    /// The URL could not be signed.
39    #[error(transparent)]
40    SigningError(#[from] SigningError),
41
42    /// The host could not be parsed as a valid URL.
43    #[error("invalid RDS endpoint: {0}")]
44    InvalidEndpoint(#[from] url::ParseError),
45}
46
47// Generate an RDS authentication token.  This should mirror what can be found
48// in aws_sdk_rds::auth_token, but without restricted creates.
49pub(crate) async fn rds_auth_token(
50    host: &str,
51    port: u16,
52    username: &str,
53    aws_config: &SdkConfig,
54) -> Result<String, RdsTokenError> {
55    let credentials = aws_config
56        .credentials_provider()
57        .ok_or(RdsTokenError::MissingCredentialsProvider)?
58        .provide_credentials()
59        .await?;
60
61    // mirror existing SDK behaviour: default to us-east-1,
62    let region = aws_config
63        .region()
64        .cloned()
65        .unwrap_or_else(|| Region::new("us-east-1"));
66
67    // defaults to SystemTime
68    let time_source = aws_config.time_source().unwrap_or_default();
69
70    let mut signing_settings = SigningSettings::default();
71    signing_settings.expires_in = Some(Duration::from_secs(AWS_TOKEN_EXPIRATION_SECONDS));
72    signing_settings.signature_location = SignatureLocation::QueryParams;
73
74    let identity = credentials.into();
75    let signing_params = SigningParams::builder()
76        .identity(&identity)
77        .region(region.as_ref())
78        .name(AWS_SERVICE)
79        .time(time_source.now())
80        .settings(signing_settings)
81        .build()?;
82
83    let mut url = url::Url::parse(&format!("https://{}:{}/", host, port))?;
84    url.query_pairs_mut()
85        .append_pair("Action", DB_ACTION)
86        .append_pair("DBUser", username);
87
88    let url_str = url.as_str().to_owned();
89    let signable_req = SignableRequest::new(
90        "GET",
91        &url_str,
92        std::iter::empty(),
93        SignableBody::Bytes(&[]),
94    )?;
95
96    let (signing_instructions, _sig) =
97        http_request::sign(signable_req, &signing_params.into())?.into_parts();
98
99    for (key, val) in signing_instructions.params() {
100        url.query_pairs_mut().append_pair(key, val);
101    }
102
103    Ok(url.as_str()["https://".len()..].to_owned())
104}
105
#[cfg(test)]
mod test {
    use aws_credential_types::Credentials;
    use aws_types::sdk_config::{SharedCredentialsProvider, TimeSource};
    use std::time::SystemTime;

    use super::*;

    /// A `TimeSource` pinned to a fixed instant so that generated signatures
    /// are reproducible.
    #[derive(Debug)]
    struct TestTimeSource {
        time: SystemTime,
    }

    impl TimeSource for TestTimeSource {
        fn now(&self) -> SystemTime {
            // `SystemTime` is `Copy`; return it directly instead of cloning.
            self.time
        }
    }

    /// Checks a known-good token, username encoding, and rejection of an
    /// empty host.
    #[mz_ore::test(tokio::test)]
    #[cfg_attr(miri, ignore)] // unsupported operation: can't call foreign function `getauxval` on OS `linux`
    async fn test_signature() {
        let time_source = TestTimeSource {
            time: SystemTime::UNIX_EPOCH + Duration::from_secs(1740690000),
        };
        let aws_config = SdkConfig::builder()
            .credentials_provider(SharedCredentialsProvider::new(Credentials::new(
                "drjekyll", "mrhyde", None, None, "test",
            )))
            .time_source(time_source)
            .build();
        let signature = rds_auth_token("mysql", 3306, "root", &aws_config)
            .await
            .unwrap();
        assert_eq!(
            &signature,
            "mysql:3306/?Action=connect&DBUser=root&X-Amz-Algorithm=AWS4-HMAC-SHA256&X-Amz-Credential=drjekyll%2F20250227%2Fus-east-1%2Frds-db%2Faws4_request&X-Amz-Date=20250227T210000Z&X-Amz-Expires=900&X-Amz-SignedHeaders=host&X-Amz-Signature=6eb04929394feeb9e070621ecafc731145914cc53e542d5302c251b705c4ac72"
        );

        // Usernames with URL-special characters must be percent-encoded in the
        // token, not interpolated raw into the query string.
        for (username, reject_key) in [("root&Foo=bar", Some("Foo")), ("root#admin", None)] {
            let token = rds_auth_token("mysql", 3306, username, &aws_config)
                .await
                .unwrap();
            let parsed = url::Url::parse(&format!("https://{token}")).unwrap();
            let pairs: std::collections::BTreeMap<_, _> = parsed.query_pairs().collect();
            assert_eq!(pairs.get("DBUser").unwrap(), username);
            if let Some(key) = reject_key {
                assert!(!pairs.contains_key(key));
            }
        }

        assert!(rds_auth_token("", 3306, "root", &aws_config).await.is_err());
    }
}