mz_mysql_util/
aws_rds.rs

1// Copyright Materialize, Inc. and contributors. All rights reserved.
2//
3// Use of this software is governed by the Business Source License
4// included in the LICENSE file.
5//
6// As of the Change Date specified in that file, in accordance with
7// the Business Source License, use of this software will be governed
8// by the Apache License, Version 2.0.
9
10use aws_config::{Region, SdkConfig};
11use aws_credential_types::provider::ProvideCredentials;
12use aws_credential_types::provider::error::CredentialsError;
13use aws_sigv4::http_request::{
14    self, SignableBody, SignableRequest, SignatureLocation, SigningError, SigningSettings,
15};
16use aws_sigv4::sign::v4::SigningParams;
17use aws_sigv4::sign::v4::signing_params::BuildError;
18use std::time::Duration;
19
/// Signing name passed to the SigV4 signer for RDS IAM database authentication.
const AWS_SERVICE: &str = "rds-db";
/// Value of the `Action` query parameter in the presigned URL.
const DB_ACTION: &str = "connect";
/// Lifetime of a generated token, in seconds (15 minutes); used as the presign `expires_in`.
const AWS_TOKEN_EXPIRATION_SECONDS: u64 = 15 * 60;
23
24#[derive(Debug, thiserror::Error)]
25pub enum RdsTokenError {
26    /// The AWS SdkConfig did not contain a credentials provider.
27    #[error("Credentials provider required to sign RDS IAM request")]
28    MissingCredentialsProvider,
29
30    /// The supplied credentials provider failed to generate credentials for URL signing.
31    #[error(transparent)]
32    CredentialsError(#[from] CredentialsError),
33
34    /// The signing parameters could not be created due to a missing required argument.
35    #[error(transparent)]
36    SigningParametersBuildError(#[from] BuildError),
37
38    /// The URL could not be signed.
39    #[error(transparent)]
40    SigningError(#[from] SigningError),
41}
42
43// Generate an RDS authentication token.  This should mirror what can be found
44// in aws_sdk_rds::auth_token, but without restricted creates.
45pub(crate) async fn rds_auth_token(
46    host: &str,
47    port: u16,
48    username: &str,
49    aws_config: &SdkConfig,
50) -> Result<String, RdsTokenError> {
51    let credentials = aws_config
52        .credentials_provider()
53        .ok_or(RdsTokenError::MissingCredentialsProvider)?
54        .provide_credentials()
55        .await?;
56
57    // mirror existing SDK behaviour: default to us-east-1,
58    let region = aws_config
59        .region()
60        .cloned()
61        .unwrap_or_else(|| Region::new("us-east-1"));
62
63    // defaults to SystemTime
64    let time_source = aws_config.time_source().unwrap_or_default();
65
66    let mut signing_settings = SigningSettings::default();
67    signing_settings.expires_in = Some(Duration::from_secs(AWS_TOKEN_EXPIRATION_SECONDS));
68    signing_settings.signature_location = SignatureLocation::QueryParams;
69
70    let identity = credentials.into();
71    let signing_params = SigningParams::builder()
72        .identity(&identity)
73        .region(region.as_ref())
74        .name(AWS_SERVICE)
75        .time(time_source.now())
76        .settings(signing_settings)
77        .build()?;
78
79    let url = format!(
80        "https://{}:{}/?Action={}&DBUser={}",
81        host, port, DB_ACTION, username
82    );
83
84    let signable_req =
85        SignableRequest::new("GET", &url, std::iter::empty(), SignableBody::Bytes(&[]))?;
86
87    let (signing_instructions, _sig) =
88        http_request::sign(signable_req, &signing_params.into())?.into_parts();
89
90    let mut url = url::Url::parse(&url).unwrap_or_else(|_| panic!("expect to parse {url}"));
91    for (key, val) in signing_instructions.params() {
92        url.query_pairs_mut().append_pair(key, val);
93    }
94
95    Ok(url.to_string().split_off("https://".len()))
96}
97
#[cfg(test)]
mod test {
    use aws_credential_types::Credentials;
    use aws_types::sdk_config::{SharedCredentialsProvider, TimeSource};
    use std::time::SystemTime;

    use super::*;

    /// A `TimeSource` that always reports the same fixed instant, so signatures are
    /// deterministic across runs.
    #[derive(Debug)]
    struct TestTimeSource {
        time: SystemTime,
    }

    impl TimeSource for TestTimeSource {
        fn now(&self) -> SystemTime {
            // `SystemTime` is `Copy`; no clone needed.
            self.time
        }
    }

    /// Pins the exact token produced for fixed credentials, host, user and time.
    #[mz_ore::test(tokio::test)]
    #[cfg_attr(miri, ignore)] // unsupported operation: can't call foreign function `getauxval` on OS `linux`
    async fn test_signature() {
        // 1740690000s after the epoch is 2025-02-27T21:00:00Z, matching `X-Amz-Date` below.
        let time_source = TestTimeSource {
            time: SystemTime::UNIX_EPOCH + Duration::from_secs(1740690000),
        };
        let aws_config = SdkConfig::builder()
            .credentials_provider(SharedCredentialsProvider::new(Credentials::new(
                "drjekyll", "mrhyde", None, None, "test",
            )))
            .time_source(time_source)
            .build();
        let signature = rds_auth_token("mysql", 3306, "root", &aws_config)
            .await
            .unwrap();
        assert_eq!(
            &signature,
            "mysql:3306/?Action=connect&DBUser=root&X-Amz-Algorithm=AWS4-HMAC-SHA256&X-Amz-Credential=drjekyll%2F20250227%2Fus-east-1%2Frds-db%2Faws4_request&X-Amz-Date=20250227T210000Z&X-Amz-Expires=900&X-Amz-SignedHeaders=host&X-Amz-Signature=6eb04929394feeb9e070621ecafc731145914cc53e542d5302c251b705c4ac72"
        );
    }
}