aws_config/meta/credentials/
chain.rs
1use aws_credential_types::{
7 provider::{self, error::CredentialsError, future, ProvideCredentials},
8 Credentials,
9};
10use aws_smithy_types::error::display::DisplayErrorContext;
11use std::borrow::Cow;
12use tracing::Instrument;
13
14#[derive(Debug)]
35pub struct CredentialsProviderChain {
36 providers: Vec<(Cow<'static, str>, Box<dyn ProvideCredentials>)>,
37}
38
39impl CredentialsProviderChain {
40 pub fn first_try(
42 name: impl Into<Cow<'static, str>>,
43 provider: impl ProvideCredentials + 'static,
44 ) -> Self {
45 CredentialsProviderChain {
46 providers: vec![(name.into(), Box::new(provider))],
47 }
48 }
49
50 pub fn or_else(
52 mut self,
53 name: impl Into<Cow<'static, str>>,
54 provider: impl ProvideCredentials + 'static,
55 ) -> Self {
56 self.providers.push((name.into(), Box::new(provider)));
57 self
58 }
59
60 #[cfg(feature = "rustls")]
62 pub async fn or_default_provider(self) -> Self {
63 self.or_else(
64 "DefaultProviderChain",
65 crate::default_provider::credentials::default_provider().await,
66 )
67 }
68
69 #[cfg(feature = "rustls")]
71 pub async fn default_provider() -> Self {
72 Self::first_try(
73 "DefaultProviderChain",
74 crate::default_provider::credentials::default_provider().await,
75 )
76 }
77
78 async fn credentials(&self) -> provider::Result {
79 for (name, provider) in &self.providers {
80 let span = tracing::debug_span!("load_credentials", provider = %name);
81 match provider.provide_credentials().instrument(span).await {
82 Ok(credentials) => {
83 tracing::debug!(provider = %name, "loaded credentials");
84 return Ok(credentials);
85 }
86 Err(err @ CredentialsError::CredentialsNotLoaded(_)) => {
87 tracing::debug!(provider = %name, context = %DisplayErrorContext(&err), "provider in chain did not provide credentials");
88 }
89 Err(err) => {
90 tracing::warn!(provider = %name, error = %DisplayErrorContext(&err), "provider failed to provide credentials");
91 return Err(err);
92 }
93 }
94 }
95 Err(CredentialsError::not_loaded(
96 "no providers in chain provided credentials",
97 ))
98 }
99}
100
101impl ProvideCredentials for CredentialsProviderChain {
102 fn provide_credentials<'a>(&'a self) -> future::ProvideCredentials<'_>
103 where
104 Self: 'a,
105 {
106 future::ProvideCredentials::new(self.credentials())
107 }
108
109 fn fallback_on_interrupt(&self) -> Option<Credentials> {
110 for (_, provider) in &self.providers {
111 match provider.fallback_on_interrupt() {
112 creds @ Some(_) => return creds,
113 None => {}
114 }
115 }
116 None
117 }
118}
119
#[cfg(test)]
mod tests {
    use std::time::Duration;

    use aws_credential_types::{
        credential_fn::provide_credentials_fn,
        provider::{error::CredentialsError, future, ProvideCredentials},
        Credentials,
    };
    use aws_smithy_async::future::timeout::Timeout;

    use crate::meta::credentials::CredentialsProviderChain;

    /// Test provider whose `provide_credentials` takes 200ms, but whose
    /// `fallback_on_interrupt` returns the same credentials immediately.
    #[derive(Debug)]
    struct FallbackCredentials(Credentials);

    impl ProvideCredentials for FallbackCredentials {
        fn provide_credentials<'a>(&'a self) -> future::ProvideCredentials<'a>
        where
            Self: 'a,
        {
            future::ProvideCredentials::new(async {
                tokio::time::sleep(Duration::from_millis(200)).await;
                Ok(self.0.clone())
            })
        }

        fn fallback_on_interrupt(&self) -> Option<Credentials> {
            Some(self.0.clone())
        }
    }

    /// Builds the two-provider chain shared by the tests.
    ///
    /// - `provider1` sleeps 200ms and then reports "not loaded".
    /// - `provider2` is a [`FallbackCredentials`] that takes another 200ms.
    fn slow_two_provider_chain() -> CredentialsProviderChain {
        CredentialsProviderChain::first_try(
            "provider1",
            provide_credentials_fn(|| async {
                tokio::time::sleep(Duration::from_millis(200)).await;
                Err(CredentialsError::not_loaded(
                    "no providers in chain provided credentials",
                ))
            }),
        )
        .or_else("provider2", FallbackCredentials(Credentials::for_tests()))
    }

    /// Interrupts `provide_credentials` on the shared chain after `timeout_ms`.
    ///
    /// Asserts that `fallback_on_interrupt` then yields the same credentials an
    /// uninterrupted call returns.
    async fn assert_fallback_after_timeout(timeout_ms: u64) {
        let chain = slow_two_provider_chain();

        // Reference value: let one full run complete (provider1 falls through to provider2).
        let expected = chain.provide_credentials().await.unwrap();

        let timeout = Timeout::new(
            chain.provide_credentials(),
            tokio::time::sleep(Duration::from_millis(timeout_ms)),
        );
        match timeout.await {
            Ok(_) => panic!("provide_credentials completed before timeout future"),
            Err(_err) => match chain.fallback_on_interrupt() {
                Some(actual) => assert_eq!(actual, expected),
                None => panic!(
                    "provide_credentials timed out and no credentials returned from fallback_on_interrupt"
                ),
            },
        };
    }

    #[tokio::test]
    async fn fallback_credentials_should_be_returned_from_provider2_on_timeout_while_provider2_was_providing_credentials(
    ) {
        // At 300ms provider1 (200ms) has finished and provider2 is still running.
        assert_fallback_after_timeout(300).await;
    }

    #[tokio::test]
    async fn fallback_credentials_should_be_returned_from_provider2_on_timeout_while_provider1_was_providing_credentials(
    ) {
        // At 100ms provider1 is still sleeping; provider2 has not started yet.
        assert_fallback_after_timeout(100).await;
    }
}