1
  2
  3
  4
  5
  6
  7
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
// Copyright Materialize, Inc. and contributors. All rights reserved.
//
// Use of this software is governed by the Business Source License
// included in the LICENSE file.
//
// As of the Change Date specified in that file, in accordance with
// the Business Source License, use of this software will be governed
// by the Apache License, Version 2.0.

//! Utility functions for AWS.

use anyhow::{anyhow, Context};
use rusoto_core::Region;
use rusoto_credential::{
    AutoRefreshingProvider, AwsCredentials, ChainProvider, ProvideAwsCredentials, StaticProvider,
};
use rusoto_sts::{GetCallerIdentityRequest, Sts, StsClient};
use serde::{Deserialize, Serialize};
use tokio::time::{self, Duration};

/// Maximum time Materialize waits for the various parts of initial
/// authorization.
// TODO(#7115): Make this configurable everywhere it is used
pub const AUTH_TIMEOUT: Duration = Duration::from_millis(10_000);

/// Information required to connect to AWS.
///
/// Credentials are optional because in most cases users should use the
/// [`ChainProvider`] to pull information from the process or AWS environment.
#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
pub struct ConnectInfo {
    /// The AWS Region to connect to.
    pub region: Region,
    /// Statically provided credentials. When `None`, credentials are obtained
    /// from the environment instead.
    pub credentials: Option<Credentials>,
}

/// A thin duplicate of [`AwsCredentials`] that can derive `Serialize` and
/// `Deserialize`.
///
/// `Debug` is implemented by hand (rather than derived) so that the secret
/// access key and session token are never written out verbatim, e.g. when a
/// containing [`ConnectInfo`] is logged or included in an error message.
#[derive(Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct Credentials {
    /// AWS access key ID.
    key: String,
    /// AWS secret access key.
    secret: String,
    /// Optional session token.
    token: Option<String>,
}

impl std::fmt::Debug for Credentials {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Only the (non-secret) key ID is shown; the secret and token are
        // replaced with a placeholder. `token` still reveals whether one is set.
        f.debug_struct("Credentials")
            .field("key", &self.key)
            .field("secret", &"<redacted>")
            .field("token", &self.token.as_ref().map(|_| "<redacted>"))
            .finish()
    }
}

impl From<Credentials> for AwsCredentials {
    fn from(creds: Credentials) -> AwsCredentials {
        AwsCredentials::new(creds.key, creds.secret, creds.token, None)
    }
}

impl ConnectInfo {
    /// Construct a [`ConnectInfo`] for `region`.
    ///
    /// `key` and `secret` must either both be present or both be absent.
    /// When both are present, they are stored as static credentials together
    /// with the optional session `token`. When both are absent, no static
    /// credentials are stored (and `token` is ignored).
    ///
    /// # Errors
    ///
    /// Returns an error if exactly one of `key` and `secret` is provided.
    pub fn new(
        region: Region,
        key: Option<String>,
        secret: Option<String>,
        token: Option<String>,
    ) -> Result<ConnectInfo, anyhow::Error> {
        let credentials = match (key, secret) {
            (Some(key), Some(secret)) => Some(Credentials { key, secret, token }),
            (None, None) => None,
            // Exactly one half of the key pair was supplied.
            _ => anyhow::bail!(
                "Both aws_access_key_id and aws_secret_access_key \
                 must be provided, or neither"
            ),
        };
        Ok(ConnectInfo {
            region,
            credentials,
        })
    }
}

/// Fetches the AWS account number of the caller via AWS Security Token Service.
///
/// Builds an STS client for `region` that authenticates with `provider`, then
/// issues a `GetCallerIdentity` request, bounded by `timeout`.
///
/// For details about STS, see [AWS documentation][].
///
/// # Errors
///
/// Returns an error if the HTTP client cannot be created, if the request does
/// not complete within `timeout`, if STS returns an error, or if the response
/// does not contain an account ID.
///
/// [AWS documentation]: https://docs.aws.amazon.com/STS/latest/APIReference/API_GetCallerIdentity.html
pub async fn account(
    provider: impl ProvideAwsCredentials + Send + Sync + 'static,
    region: Region,
    timeout: Duration,
) -> Result<String, anyhow::Error> {
    let dispatcher =
        crate::client::http().context("creating HTTP client for AWS STS Account verification")?;
    let sts_client = StsClient::new_with(dispatcher, provider, region);
    let get_identity = sts_client.get_caller_identity(GetCallerIdentityRequest {});
    // The outer `?` handles the timeout elapsing; the inner one handles an STS
    // error. A `&'static str` is a valid context, so no `String` is allocated.
    let account = time::timeout(timeout, get_identity)
        .await
        .context("timeout while retrieving AWS account number from STS")?
        .context("retrieving AWS account ID")?
        .account
        .ok_or_else(|| anyhow!("AWS did not return account ID"))?;
    Ok(account)
}

/// Verify that the provided credentials are legitimate
///
/// This uses an [always-valid][] API request to check that the AWS credentials
/// provided are recognized by AWS. It does not verify that the credentials can
/// perform all of the actions required for any specific source.
///
/// [always-valid]: https://docs.aws.amazon.com/STS/latest/APIReference/API_GetCallerIdentity.html
pub async fn validate_credentials(
    conn_info: ConnectInfo,
    timeout: Duration,
) -> Result<(), anyhow::Error> {
    if let Some(creds) = conn_info.credentials {
        let provider = StaticProvider::from(AwsCredentials::from(creds));
        account(provider.clone(), conn_info.region, timeout)
            .await
            .context("Using statically provided credentials")?;
    } else {
        let mut provider = ChainProvider::new();
        provider.set_timeout(Duration::from_secs(10));
        let provider =
            AutoRefreshingProvider::new(provider).context("generating AWS credentials")?;
        account(provider.clone(), conn_info.region, timeout)
            .await
            .context("Looking through the environment for credentials")?;
    }
    Ok(())
}