use super::client::error::ImdsError;
use crate::imds::{self, Client};
use crate::json_credentials::{parse_json_credentials, JsonCredentials, RefreshableCredentials};
use crate::provider_config::ProviderConfig;
use aws_credential_types::provider::{self, error::CredentialsError, future, ProvideCredentials};
use aws_credential_types::Credentials;
use aws_smithy_async::time::SharedTimeSource;
use aws_types::os_shim_internal::Env;
use std::borrow::Cow;
use std::error::Error as StdError;
use std::fmt;
use std::sync::{Arc, RwLock};
use std::time::{Duration, SystemTime};
const CREDENTIAL_EXPIRATION_INTERVAL: Duration = Duration::from_secs(10 * 60);
const WARNING_FOR_EXTENDING_CREDENTIALS_EXPIRY: &str =
"Attempting credential expiration extension due to a credential service availability issue. \
A refresh of these credentials will be attempted again within the next";
#[derive(Debug)]
struct ImdsCommunicationError {
source: Box<dyn StdError + Send + Sync + 'static>,
}
impl fmt::Display for ImdsCommunicationError {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
write!(f, "could not communicate with IMDS")
}
}
impl StdError for ImdsCommunicationError {
fn source(&self) -> Option<&(dyn StdError + 'static)> {
Some(self.source.as_ref())
}
}
#[derive(Debug)]
pub struct ImdsCredentialsProvider {
client: Client,
env: Env,
profile: Option<String>,
time_source: SharedTimeSource,
last_retrieved_credentials: Arc<RwLock<Option<Credentials>>>,
}
#[derive(Default, Debug)]
pub struct Builder {
provider_config: Option<ProviderConfig>,
profile_override: Option<String>,
imds_override: Option<imds::Client>,
last_retrieved_credentials: Option<Credentials>,
}
impl Builder {
pub fn configure(mut self, provider_config: &ProviderConfig) -> Self {
self.provider_config = Some(provider_config.clone());
self
}
pub fn profile(mut self, profile: impl Into<String>) -> Self {
self.profile_override = Some(profile.into());
self
}
pub fn imds_client(mut self, client: imds::Client) -> Self {
self.imds_override = Some(client);
self
}
#[allow(dead_code)]
#[cfg(test)]
fn last_retrieved_credentials(mut self, credentials: Credentials) -> Self {
self.last_retrieved_credentials = Some(credentials);
self
}
pub fn build(self) -> ImdsCredentialsProvider {
let provider_config = self.provider_config.unwrap_or_default();
let env = provider_config.env();
let client = self
.imds_override
.unwrap_or_else(|| imds::Client::builder().configure(&provider_config).build());
ImdsCredentialsProvider {
client,
env,
profile: self.profile_override,
time_source: provider_config.time_source(),
last_retrieved_credentials: Arc::new(RwLock::new(self.last_retrieved_credentials)),
}
}
}
mod codes {
pub(super) const ASSUME_ROLE_UNAUTHORIZED_ACCESS: &str = "AssumeRoleUnauthorizedAccess";
}
impl ProvideCredentials for ImdsCredentialsProvider {
fn provide_credentials<'a>(&'a self) -> future::ProvideCredentials<'a>
where
Self: 'a,
{
future::ProvideCredentials::new(self.credentials())
}
fn fallback_on_interrupt(&self) -> Option<Credentials> {
self.last_retrieved_credentials.read().unwrap().clone()
}
}
impl ImdsCredentialsProvider {
pub fn builder() -> Builder {
Builder::default()
}
fn imds_disabled(&self) -> bool {
match self.env.get(super::env::EC2_METADATA_DISABLED) {
Ok(value) => value.eq_ignore_ascii_case("true"),
_ => false,
}
}
async fn get_profile_uncached(&self) -> Result<String, CredentialsError> {
match self
.client
.get("/latest/meta-data/iam/security-credentials/")
.await
{
Ok(profile) => Ok(profile.as_ref().into()),
Err(ImdsError::ErrorResponse(context))
if context.response().status().as_u16() == 404 =>
{
tracing::warn!(
"received 404 from IMDS when loading profile information. \
Hint: This instance may not have an IAM role associated."
);
Err(CredentialsError::not_loaded("received 404 from IMDS"))
}
Err(ImdsError::FailedToLoadToken(context)) if context.is_dispatch_failure() => {
Err(CredentialsError::not_loaded(ImdsCommunicationError {
source: context.into_source().into(),
}))
}
Err(other) => Err(CredentialsError::provider_error(other)),
}
}
fn maybe_extend_expiration(&self, expiration: SystemTime) -> SystemTime {
let now = self.time_source.now();
if now < expiration {
return expiration;
}
let mut rng = fastrand::Rng::with_seed(
now.duration_since(SystemTime::UNIX_EPOCH)
.expect("now should be after UNIX EPOCH")
.as_secs(),
);
let refresh_offset = CREDENTIAL_EXPIRATION_INTERVAL + Duration::from_secs(rng.u64(0..=300));
let new_expiry = now + refresh_offset;
tracing::warn!(
"{WARNING_FOR_EXTENDING_CREDENTIALS_EXPIRY} {:.2} minutes.",
refresh_offset.as_secs_f64() / 60.0,
);
new_expiry
}
async fn retrieve_credentials(&self) -> provider::Result {
if self.imds_disabled() {
tracing::debug!(
"IMDS disabled because AWS_EC2_METADATA_DISABLED env var was set to `true`"
);
return Err(CredentialsError::not_loaded(
"IMDS disabled by AWS_ECS_METADATA_DISABLED env var",
));
}
tracing::debug!("loading credentials from IMDS");
let profile: Cow<'_, str> = match &self.profile {
Some(profile) => profile.into(),
None => self.get_profile_uncached().await?.into(),
};
tracing::debug!(profile = %profile, "loaded profile");
let credentials = self
.client
.get(format!(
"/latest/meta-data/iam/security-credentials/{}",
profile
))
.await
.map_err(CredentialsError::provider_error)?;
match parse_json_credentials(credentials.as_ref()) {
Ok(JsonCredentials::RefreshableCredentials(RefreshableCredentials {
access_key_id,
secret_access_key,
session_token,
expiration,
..
})) => {
let expiration = self.maybe_extend_expiration(expiration);
let creds = Credentials::new(
access_key_id,
secret_access_key,
Some(session_token.to_string()),
expiration.into(),
"IMDSv2",
);
*self.last_retrieved_credentials.write().unwrap() = Some(creds.clone());
Ok(creds)
}
Ok(JsonCredentials::Error { code, message })
if code == codes::ASSUME_ROLE_UNAUTHORIZED_ACCESS =>
{
Err(CredentialsError::invalid_configuration(format!(
"Incorrect IMDS/IAM configuration: [{}] {}. \
Hint: Does this role have a trust relationship with EC2?",
code, message
)))
}
Ok(JsonCredentials::Error { code, message }) => {
Err(CredentialsError::provider_error(format!(
"Error retrieving credentials from IMDS: {} {}",
code, message
)))
}
Err(invalid) => Err(CredentialsError::unhandled(invalid)),
}
}
async fn credentials(&self) -> provider::Result {
match self.retrieve_credentials().await {
creds @ Ok(_) => creds,
err => match &*self.last_retrieved_credentials.read().unwrap() {
Some(creds) => Ok(creds.clone()),
_ => err,
},
}
}
}
#[cfg(test)]
mod test {
use super::*;
use crate::imds::client::test::{
imds_request, imds_response, make_imds_client, token_request, token_response,
};
use crate::provider_config::ProviderConfig;
use aws_credential_types::provider::ProvideCredentials;
use aws_smithy_async::test_util::instant_time_and_sleep;
use aws_smithy_runtime::client::http::test_util::{ReplayEvent, StaticReplayClient};
use aws_smithy_types::body::SdkBody;
use std::time::{Duration, UNIX_EPOCH};
use tracing_test::traced_test;
const TOKEN_A: &str = "token_a";
#[tokio::test]
async fn profile_is_not_cached() {
let http_client = StaticReplayClient::new(vec![
ReplayEvent::new(
token_request("http://169.254.169.254", 21600),
token_response(21600, TOKEN_A),
),
ReplayEvent::new(
imds_request("http://169.254.169.254/latest/meta-data/iam/security-credentials/", TOKEN_A),
imds_response(r#"profile-name"#),
),
ReplayEvent::new(
imds_request("http://169.254.169.254/latest/meta-data/iam/security-credentials/profile-name", TOKEN_A),
imds_response("{\n \"Code\" : \"Success\",\n \"LastUpdated\" : \"2021-09-20T21:42:26Z\",\n \"Type\" : \"AWS-HMAC\",\n \"AccessKeyId\" : \"ASIARTEST\",\n \"SecretAccessKey\" : \"testsecret\",\n \"Token\" : \"testtoken\",\n \"Expiration\" : \"2021-09-21T04:16:53Z\"\n}"),
),
ReplayEvent::new(
imds_request("http://169.254.169.254/latest/meta-data/iam/security-credentials/", TOKEN_A),
imds_response(r#"different-profile"#),
),
ReplayEvent::new(
imds_request("http://169.254.169.254/latest/meta-data/iam/security-credentials/different-profile", TOKEN_A),
imds_response("{\n \"Code\" : \"Success\",\n \"LastUpdated\" : \"2021-09-20T21:42:26Z\",\n \"Type\" : \"AWS-HMAC\",\n \"AccessKeyId\" : \"ASIARTEST2\",\n \"SecretAccessKey\" : \"testsecret\",\n \"Token\" : \"testtoken\",\n \"Expiration\" : \"2021-09-21T04:16:53Z\"\n}"),
),
]);
let client = ImdsCredentialsProvider::builder()
.imds_client(make_imds_client(&http_client))
.build();
let creds1 = client.provide_credentials().await.expect("valid creds");
let creds2 = client.provide_credentials().await.expect("valid creds");
assert_eq!(creds1.access_key_id(), "ASIARTEST");
assert_eq!(creds2.access_key_id(), "ASIARTEST2");
http_client.assert_requests_match(&[]);
}
#[tokio::test]
#[traced_test]
async fn credentials_not_stale_should_be_used_as_they_are() {
let http_client = StaticReplayClient::new(vec![
ReplayEvent::new(
token_request("http://169.254.169.254", 21600),
token_response(21600, TOKEN_A),
),
ReplayEvent::new(
imds_request("http://169.254.169.254/latest/meta-data/iam/security-credentials/", TOKEN_A),
imds_response(r#"profile-name"#),
),
ReplayEvent::new(
imds_request("http://169.254.169.254/latest/meta-data/iam/security-credentials/profile-name", TOKEN_A),
imds_response("{\n \"Code\" : \"Success\",\n \"LastUpdated\" : \"2021-09-20T21:42:26Z\",\n \"Type\" : \"AWS-HMAC\",\n \"AccessKeyId\" : \"ASIARTEST\",\n \"SecretAccessKey\" : \"testsecret\",\n \"Token\" : \"testtoken\",\n \"Expiration\" : \"2021-09-21T04:16:53Z\"\n}"),
),
]);
let time_of_request_to_fetch_credentials = UNIX_EPOCH + Duration::from_secs(1632197810);
let (time_source, sleep) = instant_time_and_sleep(time_of_request_to_fetch_credentials);
let provider_config = ProviderConfig::no_configuration()
.with_http_client(http_client.clone())
.with_sleep_impl(sleep)
.with_time_source(time_source);
let client = crate::imds::Client::builder()
.configure(&provider_config)
.build();
let provider = ImdsCredentialsProvider::builder()
.configure(&provider_config)
.imds_client(client)
.build();
let creds = provider.provide_credentials().await.expect("valid creds");
assert_eq!(
creds.expiry(),
UNIX_EPOCH.checked_add(Duration::from_secs(1632197813))
);
http_client.assert_requests_match(&[]);
assert!(!logs_contain(WARNING_FOR_EXTENDING_CREDENTIALS_EXPIRY));
}
#[tokio::test]
#[traced_test]
async fn expired_credentials_should_be_extended() {
let http_client = StaticReplayClient::new(vec![
ReplayEvent::new(
token_request("http://169.254.169.254", 21600),
token_response(21600, TOKEN_A),
),
ReplayEvent::new(
imds_request("http://169.254.169.254/latest/meta-data/iam/security-credentials/", TOKEN_A),
imds_response(r#"profile-name"#),
),
ReplayEvent::new(
imds_request("http://169.254.169.254/latest/meta-data/iam/security-credentials/profile-name", TOKEN_A),
imds_response("{\n \"Code\" : \"Success\",\n \"LastUpdated\" : \"2021-09-20T21:42:26Z\",\n \"Type\" : \"AWS-HMAC\",\n \"AccessKeyId\" : \"ASIARTEST\",\n \"SecretAccessKey\" : \"testsecret\",\n \"Token\" : \"testtoken\",\n \"Expiration\" : \"2021-09-21T04:16:53Z\"\n}"),
),
]);
let time_of_request_to_fetch_credentials = UNIX_EPOCH + Duration::from_secs(1632246085);
let (time_source, sleep) = instant_time_and_sleep(time_of_request_to_fetch_credentials);
let provider_config = ProviderConfig::no_configuration()
.with_http_client(http_client.clone())
.with_sleep_impl(sleep)
.with_time_source(time_source);
let client = crate::imds::Client::builder()
.configure(&provider_config)
.build();
let provider = ImdsCredentialsProvider::builder()
.configure(&provider_config)
.imds_client(client)
.build();
let creds = provider.provide_credentials().await.expect("valid creds");
assert!(creds.expiry().unwrap() > time_of_request_to_fetch_credentials);
http_client.assert_requests_match(&[]);
assert!(logs_contain(WARNING_FOR_EXTENDING_CREDENTIALS_EXPIRY));
}
#[tokio::test]
#[cfg(feature = "rustls")]
async fn read_timeout_during_credentials_refresh_should_yield_last_retrieved_credentials() {
let client = crate::imds::Client::builder()
.endpoint("http://240.0.0.0")
.unwrap()
.build();
let expected = aws_credential_types::Credentials::for_tests();
let provider = ImdsCredentialsProvider::builder()
.imds_client(client)
.last_retrieved_credentials(expected.clone())
.build();
let actual = provider.provide_credentials().await;
assert_eq!(actual.unwrap(), expected);
}
#[tokio::test]
#[cfg(feature = "rustls")]
async fn read_timeout_during_credentials_refresh_should_error_without_last_retrieved_credentials(
) {
let client = crate::imds::Client::builder()
.endpoint("http://240.0.0.0")
.unwrap()
.build();
let provider = ImdsCredentialsProvider::builder()
.imds_client(client)
.build();
let actual = provider.provide_credentials().await;
assert!(
matches!(actual, Err(CredentialsError::CredentialsNotLoaded(_))),
"\nexpected: Err(CredentialsError::CredentialsNotLoaded(_))\nactual: {actual:?}"
);
}
#[cfg_attr(windows, ignore)]
#[tokio::test]
#[cfg(feature = "rustls")]
async fn external_timeout_during_credentials_refresh_should_yield_last_retrieved_credentials() {
use aws_smithy_async::rt::sleep::AsyncSleep;
let client = crate::imds::Client::builder()
.endpoint("http://240.0.0.0")
.unwrap()
.build();
let expected = aws_credential_types::Credentials::for_tests();
let provider = ImdsCredentialsProvider::builder()
.imds_client(client)
.last_retrieved_credentials(expected.clone())
.build();
let sleeper = aws_smithy_async::rt::sleep::TokioSleep::new();
let timeout = aws_smithy_async::future::timeout::Timeout::new(
provider.provide_credentials(),
sleeper.sleep(std::time::Duration::from_millis(100)),
);
match timeout.await {
Ok(_) => panic!("provide_credentials completed before timeout future"),
Err(_err) => match provider.fallback_on_interrupt() {
Some(actual) => assert_eq!(actual, expected),
None => panic!(
"provide_credentials timed out and no credentials returned from fallback_on_interrupt"
),
},
};
}
#[tokio::test]
async fn fallback_credentials_should_be_used_when_imds_returns_500_during_credentials_refresh()
{
let http_client = StaticReplayClient::new(vec![
ReplayEvent::new(
token_request("http://169.254.169.254", 21600),
token_response(21600, TOKEN_A),
),
ReplayEvent::new(
imds_request("http://169.254.169.254/latest/meta-data/iam/security-credentials/", TOKEN_A),
imds_response(r#"profile-name"#),
),
ReplayEvent::new(
imds_request("http://169.254.169.254/latest/meta-data/iam/security-credentials/profile-name", TOKEN_A),
imds_response("{\n \"Code\" : \"Success\",\n \"LastUpdated\" : \"2021-09-20T21:42:26Z\",\n \"Type\" : \"AWS-HMAC\",\n \"AccessKeyId\" : \"ASIARTEST\",\n \"SecretAccessKey\" : \"testsecret\",\n \"Token\" : \"testtoken\",\n \"Expiration\" : \"2021-09-21T04:16:53Z\"\n}"),
),
ReplayEvent::new(
imds_request("http://169.254.169.254/latest/meta-data/iam/security-credentials/", TOKEN_A),
http::Response::builder().status(500).body(SdkBody::empty()).unwrap(),
),
]);
let provider = ImdsCredentialsProvider::builder()
.imds_client(make_imds_client(&http_client))
.build();
let creds1 = provider.provide_credentials().await.expect("valid creds");
assert_eq!(creds1.access_key_id(), "ASIARTEST");
let creds2 = provider.provide_credentials().await.expect("valid creds");
assert_eq!(creds1, creds2);
http_client.assert_requests_match(&[]);
}
}