use aws_credential_types::{
provider::{self, error::CredentialsError, future, ProvideCredentials},
Credentials,
};
use aws_smithy_types::error::display::DisplayErrorContext;
use std::borrow::Cow;
use tracing::Instrument;
/// An ordered collection of named credential providers.
///
/// When credentials are requested, the providers are consulted in the order in which
/// they were added: the first one that returns credentials wins, a provider that reports
/// `CredentialsError::CredentialsNotLoaded` is skipped, and any other error is returned
/// to the caller immediately.
#[derive(Debug)]
pub struct CredentialsProviderChain {
/// Providers in the order they are tried, each paired with the name that is attached
/// to the tracing output produced while that provider is being queried.
providers: Vec<(Cow<'static, str>, Box<dyn ProvideCredentials>)>,
}
impl CredentialsProviderChain {
pub fn first_try(
name: impl Into<Cow<'static, str>>,
provider: impl ProvideCredentials + 'static,
) -> Self {
CredentialsProviderChain {
providers: vec![(name.into(), Box::new(provider))],
}
}
pub fn or_else(
mut self,
name: impl Into<Cow<'static, str>>,
provider: impl ProvideCredentials + 'static,
) -> Self {
self.providers.push((name.into(), Box::new(provider)));
self
}
#[cfg(feature = "rustls")]
pub async fn or_default_provider(self) -> Self {
self.or_else(
"DefaultProviderChain",
crate::default_provider::credentials::default_provider().await,
)
}
#[cfg(feature = "rustls")]
pub async fn default_provider() -> Self {
Self::first_try(
"DefaultProviderChain",
crate::default_provider::credentials::default_provider().await,
)
}
async fn credentials(&self) -> provider::Result {
for (name, provider) in &self.providers {
let span = tracing::debug_span!("load_credentials", provider = %name);
match provider.provide_credentials().instrument(span).await {
Ok(credentials) => {
tracing::debug!(provider = %name, "loaded credentials");
return Ok(credentials);
}
Err(err @ CredentialsError::CredentialsNotLoaded(_)) => {
tracing::debug!(provider = %name, context = %DisplayErrorContext(&err), "provider in chain did not provide credentials");
}
Err(err) => {
tracing::warn!(provider = %name, error = %DisplayErrorContext(&err), "provider failed to provide credentials");
return Err(err);
}
}
}
Err(CredentialsError::not_loaded(
"no providers in chain provided credentials",
))
}
}
impl ProvideCredentials for CredentialsProviderChain {
    /// Wraps the chain's sequential provider lookup in the future type required by the
    /// `ProvideCredentials` trait.
    fn provide_credentials<'a>(&'a self) -> future::ProvideCredentials<'_>
    where
        Self: 'a,
    {
        let lookup = self.credentials();
        future::ProvideCredentials::new(lookup)
    }

    /// Asks each provider, in chain order, for its `fallback_on_interrupt` credentials
    /// and returns the first `Some` encountered; `None` when no provider offers any.
    fn fallback_on_interrupt(&self) -> Option<Credentials> {
        self.providers
            .iter()
            .find_map(|(_, provider)| provider.fallback_on_interrupt())
    }
}
#[cfg(test)]
mod tests {
    use std::time::Duration;

    use aws_credential_types::{
        credential_fn::provide_credentials_fn,
        provider::{error::CredentialsError, future, ProvideCredentials},
        Credentials,
    };
    use aws_smithy_async::future::timeout::Timeout;

    use crate::meta::credentials::CredentialsProviderChain;

    /// Test provider that returns the wrapped credentials after a 200ms delay and also
    /// exposes them, without delay, through `fallback_on_interrupt`.
    #[derive(Debug)]
    struct FallbackCredentials(Credentials);

    impl ProvideCredentials for FallbackCredentials {
        fn provide_credentials<'a>(&'a self) -> future::ProvideCredentials<'a>
        where
            Self: 'a,
        {
            future::ProvideCredentials::new(async {
                tokio::time::sleep(Duration::from_millis(200)).await;
                Ok(self.0.clone())
            })
        }

        fn fallback_on_interrupt(&self) -> Option<Credentials> {
            Some(self.0.clone())
        }
    }

    /// Builds the chain shared by both tests: `provider1` sleeps 200ms and then reports
    /// that it has no credentials, `provider2` is a `FallbackCredentials`.
    fn slow_miss_then_fallback_chain() -> CredentialsProviderChain {
        CredentialsProviderChain::first_try(
            "provider1",
            provide_credentials_fn(|| async {
                tokio::time::sleep(Duration::from_millis(200)).await;
                Err(CredentialsError::not_loaded(
                    "no providers in chain provided credentials",
                ))
            }),
        )
        .or_else("provider2", FallbackCredentials(Credentials::for_tests()))
    }

    /// Races a credentials lookup against a sleep of `deadline`, expects the sleep to
    /// win, and then checks that `fallback_on_interrupt` yields the same credentials an
    /// uninterrupted lookup produced.
    async fn assert_fallback_matches_when_interrupted_after(deadline: Duration) {
        let chain = slow_miss_then_fallback_chain();
        // Reference value: what the chain returns when it is allowed to finish.
        let expected = chain.provide_credentials().await.unwrap();

        let timeout = Timeout::new(chain.provide_credentials(), tokio::time::sleep(deadline));
        if timeout.await.is_ok() {
            panic!("provide_credentials completed before timeout future");
        }

        match chain.fallback_on_interrupt() {
            Some(actual) => assert_eq!(actual, expected),
            None => panic!(
                "provide_credentials timed out and no credentials returned from fallback_on_interrupt"
            ),
        }
    }

    #[tokio::test]
    async fn fallback_credentials_should_be_returned_from_provider2_on_timeout_while_provider2_was_providing_credentials(
    ) {
        // 300ms: provider1 (200ms) has finished, provider2 is still sleeping.
        assert_fallback_matches_when_interrupted_after(Duration::from_millis(300)).await;
    }

    #[tokio::test]
    async fn fallback_credentials_should_be_returned_from_provider2_on_timeout_while_provider1_was_providing_credentials(
    ) {
        // 100ms: provider1 (200ms) is still sleeping.
        assert_fallback_matches_when_interrupted_after(Duration::from_millis(100)).await;
    }
}