aws_config/meta/credentials/
chain.rs

/*
 * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
 * SPDX-License-Identifier: Apache-2.0
 */

6use aws_credential_types::{
7    provider::{self, error::CredentialsError, future, ProvideCredentials},
8    Credentials,
9};
10use aws_smithy_types::error::display::DisplayErrorContext;
11use std::borrow::Cow;
12use tracing::Instrument;
13
14/// Credentials provider that checks a series of inner providers
15///
16/// Each provider will be evaluated in order:
17/// * If a provider returns valid [`Credentials`] they will be returned immediately.
18///   No other credential providers will be used.
19/// * Otherwise, if a provider returns [`CredentialsError::CredentialsNotLoaded`], the next provider will be checked.
20/// * Finally, if a provider returns any other error condition, an error will be returned immediately.
21///
22/// # Examples
23///
24/// ```no_run
25/// # fn example() {
26/// use aws_config::meta::credentials::CredentialsProviderChain;
27/// use aws_config::environment::credentials::EnvironmentVariableCredentialsProvider;
28/// use aws_config::profile::ProfileFileCredentialsProvider;
29///
30/// let provider = CredentialsProviderChain::first_try("Environment", EnvironmentVariableCredentialsProvider::new())
31///     .or_else("Profile", ProfileFileCredentialsProvider::builder().build());
32/// # }
33/// ```
34#[derive(Debug)]
35pub struct CredentialsProviderChain {
36    providers: Vec<(Cow<'static, str>, Box<dyn ProvideCredentials>)>,
37}
38
39impl CredentialsProviderChain {
40    /// Create a `CredentialsProviderChain` that begins by evaluating this provider
41    pub fn first_try(
42        name: impl Into<Cow<'static, str>>,
43        provider: impl ProvideCredentials + 'static,
44    ) -> Self {
45        CredentialsProviderChain {
46            providers: vec![(name.into(), Box::new(provider))],
47        }
48    }
49
50    /// Add a fallback provider to the credentials provider chain
51    pub fn or_else(
52        mut self,
53        name: impl Into<Cow<'static, str>>,
54        provider: impl ProvideCredentials + 'static,
55    ) -> Self {
56        self.providers.push((name.into(), Box::new(provider)));
57        self
58    }
59
60    /// Add a fallback to the default provider chain
61    #[cfg(feature = "rustls")]
62    pub async fn or_default_provider(self) -> Self {
63        self.or_else(
64            "DefaultProviderChain",
65            crate::default_provider::credentials::default_provider().await,
66        )
67    }
68
69    /// Creates a credential provider chain that starts with the default provider
70    #[cfg(feature = "rustls")]
71    pub async fn default_provider() -> Self {
72        Self::first_try(
73            "DefaultProviderChain",
74            crate::default_provider::credentials::default_provider().await,
75        )
76    }
77
78    async fn credentials(&self) -> provider::Result {
79        for (name, provider) in &self.providers {
80            let span = tracing::debug_span!("load_credentials", provider = %name);
81            match provider.provide_credentials().instrument(span).await {
82                Ok(credentials) => {
83                    tracing::debug!(provider = %name, "loaded credentials");
84                    return Ok(credentials);
85                }
86                Err(err @ CredentialsError::CredentialsNotLoaded(_)) => {
87                    tracing::debug!(provider = %name, context = %DisplayErrorContext(&err), "provider in chain did not provide credentials");
88                }
89                Err(err) => {
90                    tracing::warn!(provider = %name, error = %DisplayErrorContext(&err), "provider failed to provide credentials");
91                    return Err(err);
92                }
93            }
94        }
95        Err(CredentialsError::not_loaded(
96            "no providers in chain provided credentials",
97        ))
98    }
99}
100
101impl ProvideCredentials for CredentialsProviderChain {
102    fn provide_credentials<'a>(&'a self) -> future::ProvideCredentials<'_>
103    where
104        Self: 'a,
105    {
106        future::ProvideCredentials::new(self.credentials())
107    }
108
109    fn fallback_on_interrupt(&self) -> Option<Credentials> {
110        for (_, provider) in &self.providers {
111            match provider.fallback_on_interrupt() {
112                creds @ Some(_) => return creds,
113                None => {}
114            }
115        }
116        None
117    }
118}
119
#[cfg(test)]
mod tests {
    use std::time::Duration;

    use aws_credential_types::{
        credential_fn::provide_credentials_fn,
        provider::{error::CredentialsError, future, ProvideCredentials},
        Credentials,
    };
    use aws_smithy_async::future::timeout::Timeout;

    use crate::meta::credentials::CredentialsProviderChain;

    /// Test provider that takes 200 ms to return its wrapped credentials and also
    /// offers them immediately from `fallback_on_interrupt`.
    #[derive(Debug)]
    struct FallbackCredentials(Credentials);

    impl ProvideCredentials for FallbackCredentials {
        fn provide_credentials<'a>(&'a self) -> future::ProvideCredentials<'a>
        where
            Self: 'a,
        {
            let creds = &self.0;
            future::ProvideCredentials::new(async move {
                tokio::time::sleep(Duration::from_millis(200)).await;
                Ok(creds.clone())
            })
        }

        fn fallback_on_interrupt(&self) -> Option<Credentials> {
            Some(self.0.clone())
        }
    }

    /// Build the two-provider chain shared by the tests below.
    ///
    /// `provider1` spends 200 ms and then reports "not loaded". `provider2` is a
    /// [`FallbackCredentials`] that spends another 200 ms before succeeding.
    fn slow_two_provider_chain() -> CredentialsProviderChain {
        let not_loaded_after_delay = provide_credentials_fn(|| async {
            tokio::time::sleep(Duration::from_millis(200)).await;
            Err(CredentialsError::not_loaded(
                "no providers in chain provided credentials",
            ))
        });
        CredentialsProviderChain::first_try("provider1", not_loaded_after_delay)
            .or_else("provider2", FallbackCredentials(Credentials::for_tests()))
    }

    /// Load once to learn the expected credentials, then race a second load against a
    /// `timeout_ms` timer.
    ///
    /// The timer must win, and `fallback_on_interrupt` must then return the same
    /// credentials as the first load.
    async fn assert_fallback_returned_after_timeout(timeout_ms: u64) {
        let chain = slow_two_provider_chain();

        // First call runs to completion (provider1 declines, provider2 succeeds).
        let expected = chain.provide_credentials().await.unwrap();

        // Second call is cut short by an external timeout.
        let raced = Timeout::new(
            chain.provide_credentials(),
            tokio::time::sleep(Duration::from_millis(timeout_ms)),
        )
        .await;
        assert!(
            raced.is_err(),
            "provide_credentials completed before timeout future"
        );

        let actual = chain.fallback_on_interrupt().expect(
            "provide_credentials timed out and no credentials returned from fallback_on_interrupt",
        );
        assert_eq!(actual, expected);
    }

    #[tokio::test]
    async fn fallback_credentials_should_be_returned_from_provider2_on_timeout_while_provider2_was_providing_credentials(
    ) {
        // At 300 ms provider1 has finished (200 ms) and provider2 is still mid-load.
        assert_fallback_returned_after_timeout(300).await;
    }

    #[tokio::test]
    async fn fallback_credentials_should_be_returned_from_provider2_on_timeout_while_provider1_was_providing_credentials(
    ) {
        // At 100 ms provider1 is still mid-load.
        assert_fallback_returned_after_timeout(100).await;
    }
}