aws_config/
ecs.rs

1/*
2 * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
3 * SPDX-License-Identifier: Apache-2.0
4 */
5
6//! Ecs Credentials Provider
7//!
8//! This credential provider is frequently used with an AWS-provided credentials service (e.g.
9//! [IAM Roles for tasks](https://docs.aws.amazon.com/AmazonECS/latest/developerguide/task-iam-roles.html)).
10//! However, it's possible to use environment variables to configure this provider to use your own
11//! credentials sources.
12//!
13//! This provider is part of the [default credentials chain](crate::default_provider::credentials).
14//!
15//! ## Configuration
16//! **First**: It will check the value of `$AWS_CONTAINER_CREDENTIALS_RELATIVE_URI`. It will use this
17//! to construct a URI rooted at `http://169.254.170.2`. For example, if the value of the environment
18//! variable was `/credentials`, the SDK would look for credentials at `http://169.254.170.2/credentials`.
19//!
20//! **Next**: It will check the value of `$AWS_CONTAINER_CREDENTIALS_FULL_URI`. This specifies the full
//! URL to load credentials. The URL MUST satisfy one of the following two properties:
//! 1. The URL begins with `https`.
//! 2. The URL refers to an allowed IP address. If a URL contains a domain name instead of an IP address,
//!    a DNS lookup will be performed. ALL resolved IP addresses MUST refer to an allowed IP address, or
//!    the credentials provider will return `CredentialsError::InvalidConfiguration`. Valid IP addresses are:
//!     - Loopback interfaces
//!     - The [ECS Task Metadata V2](https://docs.aws.amazon.com/AmazonECS/latest/developerguide/task-metadata-endpoint-v2.html)
//!       address, i.e. `169.254.170.2`
//!     - [EKS Pod Identity](https://docs.aws.amazon.com/eks/latest/userguide/pod-identities.html) addresses,
//!       i.e. `169.254.170.23` or `fd00:ec2::23`
31//!
32//! **Next**: It will check the value of `$AWS_CONTAINER_AUTHORIZATION_TOKEN_FILE`. If this is set,
33//! the filename specified will be read, and the value passed in the `Authorization` header. If the file
34//! cannot be read, an error is returned.
35//!
36//! **Finally**: It will check the value of `$AWS_CONTAINER_AUTHORIZATION_TOKEN`. If this is set, the
37//! value will be passed in the `Authorization` header.
38//!
39//! ## Credentials Format
40//! Credentials MUST be returned in a JSON format:
41//! ```json
42//! {
43//!    "AccessKeyId" : "MUA...",
44//!    "SecretAccessKey" : "/7PC5om....",
45//!    "Token" : "AQoDY....=",
46//!    "Expiration" : "2016-02-25T06:03:31Z"
47//!  }
48//! ```
49//!
//! Credential errors MAY be returned with `code` and `message` fields:
51//! ```json
52//! {
53//!   "code": "ErrorCode",
54//!   "message": "Helpful error message."
55//! }
56//! ```
57
58use crate::http_credential_provider::HttpCredentialProvider;
59use crate::provider_config::ProviderConfig;
60use aws_credential_types::provider::{self, error::CredentialsError, future, ProvideCredentials};
61use aws_smithy_runtime::client::endpoint::apply_endpoint;
62use aws_smithy_runtime_api::client::dns::{ResolveDns, ResolveDnsError, SharedDnsResolver};
63use aws_smithy_runtime_api::client::http::HttpConnectorSettings;
64use aws_smithy_runtime_api::shared::IntoShared;
65use aws_smithy_types::error::display::DisplayErrorContext;
66use aws_types::os_shim_internal::{Env, Fs};
67use http::header::InvalidHeaderValue;
68use http::uri::{InvalidUri, PathAndQuery, Scheme};
69use http::{HeaderValue, Uri};
70use std::error::Error;
71use std::fmt::{Display, Formatter};
72use std::net::{IpAddr, Ipv4Addr, Ipv6Addr};
73use std::time::Duration;
74use tokio::sync::OnceCell;
75
/// Connect timeout used by the HTTP client that fetches container credentials.
const DEFAULT_CONNECT_TIMEOUT: Duration = Duration::from_secs(2);
/// Read timeout used by the HTTP client that fetches container credentials.
const DEFAULT_READ_TIMEOUT: Duration = Duration::from_secs(5);
78
/// Root of the ECS credentials endpoint; relative URIs from the environment are joined onto it.
///
/// See <https://docs.aws.amazon.com/AmazonECS/latest/developerguide/task-metadata-endpoint-v2.html>.
const BASE_HOST: &str = "http://169.254.170.2";

/// Env var holding a path that is rooted at `BASE_HOST` (takes precedence over the full URI).
const ENV_RELATIVE_URI: &str = "AWS_CONTAINER_CREDENTIALS_RELATIVE_URI";
/// Env var holding a complete credentials URL, subject to validation.
const ENV_FULL_URI: &str = "AWS_CONTAINER_CREDENTIALS_FULL_URI";
/// Env var holding the literal `Authorization` header value.
const ENV_AUTHORIZATION_TOKEN: &str = "AWS_CONTAINER_AUTHORIZATION_TOKEN";
/// Env var holding a path to a file whose contents are the `Authorization` header value.
const ENV_AUTHORIZATION_TOKEN_FILE: &str = "AWS_CONTAINER_AUTHORIZATION_TOKEN_FILE";
85
86/// Credential provider for ECS and generalized HTTP credentials
87///
88/// See the [module](crate::ecs) documentation for more details.
89///
90/// This credential provider is part of the default chain.
91#[derive(Debug)]
92pub struct EcsCredentialsProvider {
93    inner: OnceCell<Provider>,
94    env: Env,
95    fs: Fs,
96    builder: Builder,
97}
98
99impl EcsCredentialsProvider {
100    /// Builder for [`EcsCredentialsProvider`]
101    pub fn builder() -> Builder {
102        Builder::default()
103    }
104
105    /// Load credentials from this credentials provider
106    pub async fn credentials(&self) -> provider::Result {
107        let env_token_file = self.env.get(ENV_AUTHORIZATION_TOKEN_FILE).ok();
108        let env_token = self.env.get(ENV_AUTHORIZATION_TOKEN).ok();
109        let auth = if let Some(auth_token_file) = env_token_file {
110            let auth = self
111                .fs
112                .read_to_end(auth_token_file)
113                .await
114                .map_err(CredentialsError::provider_error)?;
115            Some(HeaderValue::from_bytes(auth.as_slice()).map_err(|err| {
116                let auth_token = String::from_utf8_lossy(auth.as_slice()).to_string();
117                tracing::warn!(token = %auth_token, "invalid auth token");
118                CredentialsError::invalid_configuration(EcsConfigurationError::InvalidAuthToken {
119                    err,
120                    value: auth_token,
121                })
122            })?)
123        } else if let Some(auth_token) = env_token {
124            Some(HeaderValue::from_str(&auth_token).map_err(|err| {
125                tracing::warn!(token = %auth_token, "invalid auth token");
126                CredentialsError::invalid_configuration(EcsConfigurationError::InvalidAuthToken {
127                    err,
128                    value: auth_token,
129                })
130            })?)
131        } else {
132            None
133        };
134        match self.provider().await {
135            Provider::NotConfigured => {
136                Err(CredentialsError::not_loaded("ECS provider not configured"))
137            }
138            Provider::InvalidConfiguration(err) => {
139                Err(CredentialsError::invalid_configuration(format!("{}", err)))
140            }
141            Provider::Configured(provider) => provider.credentials(auth).await,
142        }
143    }
144
145    async fn provider(&self) -> &Provider {
146        self.inner
147            .get_or_init(|| Provider::make(self.builder.clone()))
148            .await
149    }
150}
151
152impl ProvideCredentials for EcsCredentialsProvider {
153    fn provide_credentials<'a>(&'a self) -> future::ProvideCredentials<'a>
154    where
155        Self: 'a,
156    {
157        future::ProvideCredentials::new(self.credentials())
158    }
159}
160
161/// Inner Provider that can record failed configuration state
162#[derive(Debug)]
163#[allow(clippy::large_enum_variant)]
164enum Provider {
165    Configured(HttpCredentialProvider),
166    NotConfigured,
167    InvalidConfiguration(EcsConfigurationError),
168}
169
170impl Provider {
171    async fn uri(env: Env, dns: Option<SharedDnsResolver>) -> Result<Uri, EcsConfigurationError> {
172        let relative_uri = env.get(ENV_RELATIVE_URI).ok();
173        let full_uri = env.get(ENV_FULL_URI).ok();
174        if let Some(relative_uri) = relative_uri {
175            Self::build_full_uri(relative_uri)
176        } else if let Some(full_uri) = full_uri {
177            let dns = dns.or_else(default_dns);
178            validate_full_uri(&full_uri, dns)
179                .await
180                .map_err(|err| EcsConfigurationError::InvalidFullUri { err, uri: full_uri })
181        } else {
182            Err(EcsConfigurationError::NotConfigured)
183        }
184    }
185
186    async fn make(builder: Builder) -> Self {
187        let provider_config = builder.provider_config.unwrap_or_default();
188        let env = provider_config.env();
189        let uri = match Self::uri(env, builder.dns).await {
190            Ok(uri) => uri,
191            Err(EcsConfigurationError::NotConfigured) => return Provider::NotConfigured,
192            Err(err) => return Provider::InvalidConfiguration(err),
193        };
194        let path = uri.path().to_string();
195        let endpoint = {
196            let mut parts = uri.into_parts();
197            parts.path_and_query = Some(PathAndQuery::from_static("/"));
198            Uri::from_parts(parts)
199        }
200        .expect("parts will be valid")
201        .to_string();
202
203        let http_provider = HttpCredentialProvider::builder()
204            .configure(&provider_config)
205            .http_connector_settings(
206                HttpConnectorSettings::builder()
207                    .connect_timeout(DEFAULT_CONNECT_TIMEOUT)
208                    .read_timeout(DEFAULT_READ_TIMEOUT)
209                    .build(),
210            )
211            .build("EcsContainer", &endpoint, path);
212        Provider::Configured(http_provider)
213    }
214
215    fn build_full_uri(relative_uri: String) -> Result<Uri, EcsConfigurationError> {
216        let mut relative_uri = match relative_uri.parse::<Uri>() {
217            Ok(uri) => uri,
218            Err(invalid_uri) => {
219                tracing::warn!(uri = %DisplayErrorContext(&invalid_uri), "invalid URI loaded from environment");
220                return Err(EcsConfigurationError::InvalidRelativeUri {
221                    err: invalid_uri,
222                    uri: relative_uri,
223                });
224            }
225        };
226        let endpoint = Uri::from_static(BASE_HOST);
227        apply_endpoint(&mut relative_uri, &endpoint, None)
228            .expect("appending relative URLs to the ECS endpoint should always succeed");
229        Ok(relative_uri)
230    }
231}
232
233#[derive(Debug)]
234enum EcsConfigurationError {
235    InvalidRelativeUri {
236        err: InvalidUri,
237        uri: String,
238    },
239    InvalidFullUri {
240        err: InvalidFullUriError,
241        uri: String,
242    },
243    InvalidAuthToken {
244        err: InvalidHeaderValue,
245        value: String,
246    },
247    NotConfigured,
248}
249
250impl Display for EcsConfigurationError {
251    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
252        match self {
253            EcsConfigurationError::InvalidRelativeUri { err, uri } => write!(
254                f,
255                "invalid relative URI for ECS provider ({}): {}",
256                err, uri
257            ),
258            EcsConfigurationError::InvalidFullUri { err, uri } => {
259                write!(f, "invalid full URI for ECS provider ({}): {}", err, uri)
260            }
261            EcsConfigurationError::NotConfigured => write!(
262                f,
263                "No environment variables were set to configure ECS provider"
264            ),
265            EcsConfigurationError::InvalidAuthToken { err, value } => write!(
266                f,
267                "`{}` could not be used as a header value for the auth token. {}",
268                value, err
269            ),
270        }
271    }
272}
273
274impl Error for EcsConfigurationError {
275    fn source(&self) -> Option<&(dyn Error + 'static)> {
276        match &self {
277            EcsConfigurationError::InvalidRelativeUri { err, .. } => Some(err),
278            EcsConfigurationError::InvalidFullUri { err, .. } => Some(err),
279            EcsConfigurationError::InvalidAuthToken { err, .. } => Some(err),
280            EcsConfigurationError::NotConfigured => None,
281        }
282    }
283}
284
285/// Builder for [`EcsCredentialsProvider`]
286#[derive(Default, Debug, Clone)]
287pub struct Builder {
288    provider_config: Option<ProviderConfig>,
289    dns: Option<SharedDnsResolver>,
290    connect_timeout: Option<Duration>,
291    read_timeout: Option<Duration>,
292}
293
294impl Builder {
295    /// Override the configuration used for this provider
296    pub fn configure(mut self, provider_config: &ProviderConfig) -> Self {
297        self.provider_config = Some(provider_config.clone());
298        self
299    }
300
301    /// Override the DNS resolver used to validate URIs
302    ///
303    /// URIs must refer to valid IP addresses as defined in the module documentation. The [`ResolveDns`]
304    /// implementation is used to retrieve IP addresses for a given domain.
305    pub fn dns(mut self, dns: impl ResolveDns + 'static) -> Self {
306        self.dns = Some(dns.into_shared());
307        self
308    }
309
310    /// Override the connect timeout for the HTTP client
311    ///
312    /// This value defaults to 2 seconds
313    pub fn connect_timeout(mut self, timeout: Duration) -> Self {
314        self.connect_timeout = Some(timeout);
315        self
316    }
317
318    /// Override the read timeout for the HTTP client
319    ///
320    /// This value defaults to 5 seconds
321    pub fn read_timeout(mut self, timeout: Duration) -> Self {
322        self.read_timeout = Some(timeout);
323        self
324    }
325
326    /// Create an [`EcsCredentialsProvider`] from this builder
327    pub fn build(self) -> EcsCredentialsProvider {
328        let env = self
329            .provider_config
330            .as_ref()
331            .map(|config| config.env())
332            .unwrap_or_default();
333        let fs = self
334            .provider_config
335            .as_ref()
336            .map(|config| config.fs())
337            .unwrap_or_default();
338        EcsCredentialsProvider {
339            inner: OnceCell::new(),
340            env,
341            fs,
342            builder: self,
343        }
344    }
345}
346
347#[derive(Debug)]
348enum InvalidFullUriErrorKind {
349    /// The provided URI could not be parsed as a URI
350    #[non_exhaustive]
351    InvalidUri(InvalidUri),
352
353    /// No Dns resolver was provided
354    #[non_exhaustive]
355    NoDnsResolver,
356
357    /// The URI did not specify a host
358    #[non_exhaustive]
359    MissingHost,
360
361    /// The URI did not refer to an allowed IP address
362    #[non_exhaustive]
363    DisallowedIP,
364
365    /// DNS lookup failed when attempting to resolve the host to an IP Address for validation.
366    DnsLookupFailed(ResolveDnsError),
367}
368
369/// Invalid Full URI
370///
371/// When the full URI setting is used, the URI must either be HTTPS, point to a loopback interface,
372/// or point to known ECS/EKS container IPs.
373#[derive(Debug)]
374pub struct InvalidFullUriError {
375    kind: InvalidFullUriErrorKind,
376}
377
378impl Display for InvalidFullUriError {
379    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
380        use InvalidFullUriErrorKind::*;
381        match self.kind {
382            InvalidUri(_) => write!(f, "URI was invalid"),
383            MissingHost => write!(f, "URI did not specify a host"),
384            DisallowedIP => {
385                write!(f, "URI did not refer to an allowed IP address")
386            }
387            DnsLookupFailed(_) => {
388                write!(
389                    f,
390                    "failed to perform DNS lookup while validating URI"
391                )
392            }
393            NoDnsResolver => write!(f, "no DNS resolver was provided. Enable `rt-tokio` or provide a `dns` resolver to the builder.")
394        }
395    }
396}
397
398impl Error for InvalidFullUriError {
399    fn source(&self) -> Option<&(dyn Error + 'static)> {
400        use InvalidFullUriErrorKind::*;
401        match &self.kind {
402            InvalidUri(err) => Some(err),
403            DnsLookupFailed(err) => Some(err as _),
404            _ => None,
405        }
406    }
407}
408
409impl From<InvalidFullUriErrorKind> for InvalidFullUriError {
410    fn from(kind: InvalidFullUriErrorKind) -> Self {
411        Self { kind }
412    }
413}
414
415/// Validate that `uri` is valid to be used as a full provider URI
416/// Either:
417/// 1. The URL is uses `https`
418/// 2. The URL refers to an allowed IP. If a URL contains a domain name instead of an IP address,
419/// a DNS lookup will be performed. ALL resolved IP addresses MUST refer to an allowed IP, or
420/// the credentials provider will return `CredentialsError::InvalidConfiguration`. Allowed IPs
421/// are the loopback interfaces, and the known ECS/EKS container IPs.
422async fn validate_full_uri(
423    uri: &str,
424    dns: Option<SharedDnsResolver>,
425) -> Result<Uri, InvalidFullUriError> {
426    let uri = uri
427        .parse::<Uri>()
428        .map_err(InvalidFullUriErrorKind::InvalidUri)?;
429    if uri.scheme() == Some(&Scheme::HTTPS) {
430        return Ok(uri);
431    }
432    // For HTTP URIs, we need to validate that it points to a valid IP
433    let host = uri.host().ok_or(InvalidFullUriErrorKind::MissingHost)?;
434    let maybe_ip = if host.starts_with('[') && host.ends_with(']') {
435        host[1..host.len() - 1].parse::<IpAddr>()
436    } else {
437        host.parse::<IpAddr>()
438    };
439    let is_allowed = match maybe_ip {
440        Ok(addr) => is_full_uri_ip_allowed(&addr),
441        Err(_domain_name) => {
442            let dns = dns.ok_or(InvalidFullUriErrorKind::NoDnsResolver)?;
443            dns.resolve_dns(host)
444                .await
445                .map_err(|err| InvalidFullUriErrorKind::DnsLookupFailed(ResolveDnsError::new(err)))?
446                .iter()
447                    .all(|addr| {
448                        if !is_full_uri_ip_allowed(addr) {
449                            tracing::warn!(
450                                addr = ?addr,
451                                "HTTP credential provider cannot be used: Address does not resolve to an allowed IP."
452                            )
453                        };
454                        is_full_uri_ip_allowed(addr)
455                    })
456        }
457    };
458    match is_allowed {
459        true => Ok(uri),
460        false => Err(InvalidFullUriErrorKind::DisallowedIP.into()),
461    }
462}
463
/// ECS Task Metadata V2 address, `169.254.170.2`.
const ECS_CONTAINER_IPV4: IpAddr = IpAddr::V4(Ipv4Addr::new(169, 254, 170, 2));

/// EKS Pod Identity IPv4 address, `169.254.170.23`.
const EKS_CONTAINER_IPV4: IpAddr = IpAddr::V4(Ipv4Addr::new(169, 254, 170, 23));

/// EKS Pod Identity IPv6 address, `fd00:ec2::23`.
const EKS_CONTAINER_IPV6: IpAddr = IpAddr::V6(Ipv6Addr::new(0xFD00, 0x0EC2, 0, 0, 0, 0, 0, 0x23));

/// Whether a plain-HTTP full URI may target `ip`.
///
/// Any loopback address is allowed, as are the fixed ECS/EKS container addresses above.
fn is_full_uri_ip_allowed(ip: &IpAddr) -> bool {
    const CONTAINER_IPS: [IpAddr; 3] = [ECS_CONTAINER_IPV4, EKS_CONTAINER_IPV4, EKS_CONTAINER_IPV6];
    ip.is_loopback() || CONTAINER_IPS.contains(ip)
}
478
479/// Default DNS resolver impl
480///
481/// DNS resolution is required to validate that provided URIs point to a valid IP address
482#[cfg(any(not(feature = "rt-tokio"), target_family = "wasm"))]
483fn default_dns() -> Option<SharedDnsResolver> {
484    None
485}
486#[cfg(all(feature = "rt-tokio", not(target_family = "wasm")))]
487fn default_dns() -> Option<SharedDnsResolver> {
488    use aws_smithy_runtime::client::dns::TokioDnsResolver;
489    Some(TokioDnsResolver::new().into_shared())
490}
491
492#[cfg(test)]
493mod test {
494    use super::*;
495    use crate::provider_config::ProviderConfig;
496    use crate::test_case::{no_traffic_client, GenericTestResult};
497    use aws_credential_types::provider::ProvideCredentials;
498    use aws_credential_types::Credentials;
499    use aws_smithy_async::future::never::Never;
500    use aws_smithy_async::rt::sleep::TokioSleep;
501    use aws_smithy_runtime::client::http::test_util::{ReplayEvent, StaticReplayClient};
502    use aws_smithy_runtime_api::client::dns::DnsFuture;
503    use aws_smithy_runtime_api::client::http::HttpClient;
504    use aws_smithy_runtime_api::shared::IntoShared;
505    use aws_smithy_types::body::SdkBody;
506    use aws_types::os_shim_internal::Env;
507    use futures_util::FutureExt;
508    use http::header::AUTHORIZATION;
509    use http::Uri;
510    use serde::Deserialize;
511    use std::collections::HashMap;
512    use std::error::Error;
513    use std::ffi::OsString;
514    use std::net::IpAddr;
515    use std::time::{Duration, UNIX_EPOCH};
516    use tracing_test::traced_test;
517
518    fn provider(
519        env: Env,
520        fs: Fs,
521        http_client: impl HttpClient + 'static,
522    ) -> EcsCredentialsProvider {
523        let provider_config = ProviderConfig::empty()
524            .with_env(env)
525            .with_fs(fs)
526            .with_http_client(http_client)
527            .with_sleep_impl(TokioSleep::new());
528        Builder::default().configure(&provider_config).build()
529    }
530
531    #[derive(Deserialize)]
532    struct EcsUriTest {
533        env: HashMap<String, String>,
534        result: GenericTestResult<String>,
535    }
536
537    impl EcsUriTest {
538        async fn check(&self) {
539            let env = Env::from(self.env.clone());
540            let uri = Provider::uri(env, Some(TestDns::default().into_shared()))
541                .await
542                .map(|uri| uri.to_string());
543            self.result.assert_matches(uri.as_ref());
544        }
545    }
546
547    #[tokio::test]
548    async fn run_config_tests() -> Result<(), Box<dyn Error>> {
549        let test_cases = std::fs::read_to_string("test-data/ecs-tests.json")?;
550        #[derive(Deserialize)]
551        struct TestCases {
552            tests: Vec<EcsUriTest>,
553        }
554
555        let test_cases: TestCases = serde_json::from_str(&test_cases)?;
556        let test_cases = test_cases.tests;
557        for test in test_cases {
558            test.check().await
559        }
560        Ok(())
561    }
562
563    #[test]
564    fn validate_uri_https() {
565        // over HTTPs, any URI is fine
566        let dns = Some(NeverDns.into_shared());
567        assert_eq!(
568            validate_full_uri("https://amazon.com", None)
569                .now_or_never()
570                .unwrap()
571                .expect("valid"),
572            Uri::from_static("https://amazon.com")
573        );
574        // over HTTP, it will try to lookup
575        assert!(
576            validate_full_uri("http://amazon.com", dns)
577                .now_or_never()
578                .is_none(),
579            "DNS lookup should occur, but it will never return"
580        );
581
582        let no_dns_error = validate_full_uri("http://amazon.com", None)
583            .now_or_never()
584            .unwrap()
585            .expect_err("DNS service is required");
586        assert!(
587            matches!(
588                no_dns_error,
589                InvalidFullUriError {
590                    kind: InvalidFullUriErrorKind::NoDnsResolver
591                }
592            ),
593            "expected no dns service, got: {}",
594            no_dns_error
595        );
596    }
597
598    #[test]
599    fn valid_uri_loopback() {
600        assert_eq!(
601            validate_full_uri("http://127.0.0.1:8080/get-credentials", None)
602                .now_or_never()
603                .unwrap()
604                .expect("valid uri"),
605            Uri::from_static("http://127.0.0.1:8080/get-credentials")
606        );
607
608        let err = validate_full_uri("http://192.168.10.120/creds", None)
609            .now_or_never()
610            .unwrap()
611            .expect_err("not a loopback");
612        assert!(matches!(
613            err,
614            InvalidFullUriError {
615                kind: InvalidFullUriErrorKind::DisallowedIP
616            }
617        ));
618    }
619
620    #[test]
621    fn valid_uri_ecs_eks() {
622        assert_eq!(
623            validate_full_uri("http://169.254.170.2:8080/get-credentials", None)
624                .now_or_never()
625                .unwrap()
626                .expect("valid uri"),
627            Uri::from_static("http://169.254.170.2:8080/get-credentials")
628        );
629        assert_eq!(
630            validate_full_uri("http://169.254.170.23:8080/get-credentials", None)
631                .now_or_never()
632                .unwrap()
633                .expect("valid uri"),
634            Uri::from_static("http://169.254.170.23:8080/get-credentials")
635        );
636        assert_eq!(
637            validate_full_uri("http://[fd00:ec2::23]:8080/get-credentials", None)
638                .now_or_never()
639                .unwrap()
640                .expect("valid uri"),
641            Uri::from_static("http://[fd00:ec2::23]:8080/get-credentials")
642        );
643
644        let err = validate_full_uri("http://169.254.171.23/creds", None)
645            .now_or_never()
646            .unwrap()
647            .expect_err("not an ecs/eks container address");
648        assert!(matches!(
649            err,
650            InvalidFullUriError {
651                kind: InvalidFullUriErrorKind::DisallowedIP
652            }
653        ));
654
655        let err = validate_full_uri("http://[fd00:ec2::2]/creds", None)
656            .now_or_never()
657            .unwrap()
658            .expect_err("not an ecs/eks container address");
659        assert!(matches!(
660            err,
661            InvalidFullUriError {
662                kind: InvalidFullUriErrorKind::DisallowedIP
663            }
664        ));
665    }
666
667    #[test]
668    fn all_addrs_local() {
669        let dns = Some(
670            TestDns::with_fallback(vec![
671                "127.0.0.1".parse().unwrap(),
672                "127.0.0.2".parse().unwrap(),
673                "169.254.170.23".parse().unwrap(),
674                "fd00:ec2::23".parse().unwrap(),
675            ])
676            .into_shared(),
677        );
678        let resp = validate_full_uri("http://localhost:8888", dns)
679            .now_or_never()
680            .unwrap();
681        assert!(resp.is_ok(), "Should be valid: {:?}", resp);
682    }
683
684    #[test]
685    fn all_addrs_not_local() {
686        let dns = Some(
687            TestDns::with_fallback(vec![
688                "127.0.0.1".parse().unwrap(),
689                "192.168.0.1".parse().unwrap(),
690            ])
691            .into_shared(),
692        );
693        let resp = validate_full_uri("http://localhost:8888", dns)
694            .now_or_never()
695            .unwrap();
696        assert!(
697            matches!(
698                resp,
699                Err(InvalidFullUriError {
700                    kind: InvalidFullUriErrorKind::DisallowedIP
701                })
702            ),
703            "Should be invalid: {:?}",
704            resp
705        );
706    }
707
708    fn creds_request(uri: &str, auth: Option<&str>) -> http::Request<SdkBody> {
709        let mut builder = http::Request::builder();
710        if let Some(auth) = auth {
711            builder = builder.header(AUTHORIZATION, auth);
712        }
713        builder.uri(uri).body(SdkBody::empty()).unwrap()
714    }
715
716    fn ok_creds_response() -> http::Response<SdkBody> {
717        http::Response::builder()
718            .status(200)
719            .body(SdkBody::from(
720                r#" {
721                       "AccessKeyId" : "AKID",
722                       "SecretAccessKey" : "SECRET",
723                       "Token" : "TOKEN....=",
724                       "Expiration" : "2009-02-13T23:31:30Z"
725                     }"#,
726            ))
727            .unwrap()
728    }
729
730    #[track_caller]
731    fn assert_correct(creds: Credentials) {
732        assert_eq!(creds.access_key_id(), "AKID");
733        assert_eq!(creds.secret_access_key(), "SECRET");
734        assert_eq!(creds.session_token().unwrap(), "TOKEN....=");
735        assert_eq!(
736            creds.expiry().unwrap(),
737            UNIX_EPOCH + Duration::from_secs(1234567890)
738        );
739    }
740
741    #[tokio::test]
742    async fn load_valid_creds_auth() {
743        let env = Env::from_slice(&[
744            ("AWS_CONTAINER_CREDENTIALS_RELATIVE_URI", "/credentials"),
745            ("AWS_CONTAINER_AUTHORIZATION_TOKEN", "Basic password"),
746        ]);
747        let http_client = StaticReplayClient::new(vec![ReplayEvent::new(
748            creds_request("http://169.254.170.2/credentials", Some("Basic password")),
749            ok_creds_response(),
750        )]);
751        let provider = provider(env, Fs::default(), http_client.clone());
752        let creds = provider
753            .provide_credentials()
754            .await
755            .expect("valid credentials");
756        assert_correct(creds);
757        http_client.assert_requests_match(&[]);
758    }
759
760    #[tokio::test]
761    async fn load_valid_creds_auth_file() {
762        let env = Env::from_slice(&[
763            (
764                "AWS_CONTAINER_CREDENTIALS_FULL_URI",
765                "http://169.254.170.23/v1/credentials",
766            ),
767            (
768                "AWS_CONTAINER_AUTHORIZATION_TOKEN_FILE",
769                "/var/run/secrets/pods.eks.amazonaws.com/serviceaccount/eks-pod-identity-token",
770            ),
771        ]);
772        let fs = Fs::from_raw_map(HashMap::from([(
773            OsString::from(
774                "/var/run/secrets/pods.eks.amazonaws.com/serviceaccount/eks-pod-identity-token",
775            ),
776            "Basic password".into(),
777        )]));
778
779        let http_client = StaticReplayClient::new(vec![ReplayEvent::new(
780            creds_request(
781                "http://169.254.170.23/v1/credentials",
782                Some("Basic password"),
783            ),
784            ok_creds_response(),
785        )]);
786        let provider = provider(env, fs, http_client.clone());
787        let creds = provider
788            .provide_credentials()
789            .await
790            .expect("valid credentials");
791        assert_correct(creds);
792        http_client.assert_requests_match(&[]);
793    }
794
795    #[tokio::test]
796    async fn auth_file_precedence_over_env() {
797        let env = Env::from_slice(&[
798            (
799                "AWS_CONTAINER_CREDENTIALS_FULL_URI",
800                "http://169.254.170.23/v1/credentials",
801            ),
802            (
803                "AWS_CONTAINER_AUTHORIZATION_TOKEN_FILE",
804                "/var/run/secrets/pods.eks.amazonaws.com/serviceaccount/eks-pod-identity-token",
805            ),
806            ("AWS_CONTAINER_AUTHORIZATION_TOKEN", "unused"),
807        ]);
808        let fs = Fs::from_raw_map(HashMap::from([(
809            OsString::from(
810                "/var/run/secrets/pods.eks.amazonaws.com/serviceaccount/eks-pod-identity-token",
811            ),
812            "Basic password".into(),
813        )]));
814
815        let http_client = StaticReplayClient::new(vec![ReplayEvent::new(
816            creds_request(
817                "http://169.254.170.23/v1/credentials",
818                Some("Basic password"),
819            ),
820            ok_creds_response(),
821        )]);
822        let provider = provider(env, fs, http_client.clone());
823        let creds = provider
824            .provide_credentials()
825            .await
826            .expect("valid credentials");
827        assert_correct(creds);
828        http_client.assert_requests_match(&[]);
829    }
830
831    #[tokio::test]
832    async fn fs_missing_file() {
833        let env = Env::from_slice(&[
834            (
835                "AWS_CONTAINER_CREDENTIALS_FULL_URI",
836                "http://169.254.170.23/v1/credentials",
837            ),
838            (
839                "AWS_CONTAINER_AUTHORIZATION_TOKEN_FILE",
840                "/var/run/secrets/pods.eks.amazonaws.com/serviceaccount/eks-pod-identity-token",
841            ),
842        ]);
843        let fs = Fs::from_raw_map(HashMap::new());
844
845        let provider = provider(env, fs, no_traffic_client());
846        let err = provider.credentials().await.expect_err("no JWT token file");
847        match err {
848            CredentialsError::ProviderError { .. } => { /* ok */ }
849            _ => panic!("incorrect error variant"),
850        }
851    }
852
853    #[tokio::test]
854    async fn retry_5xx() {
855        let env = Env::from_slice(&[("AWS_CONTAINER_CREDENTIALS_RELATIVE_URI", "/credentials")]);
856        let http_client = StaticReplayClient::new(vec![
857            ReplayEvent::new(
858                creds_request("http://169.254.170.2/credentials", None),
859                http::Response::builder()
860                    .status(500)
861                    .body(SdkBody::empty())
862                    .unwrap(),
863            ),
864            ReplayEvent::new(
865                creds_request("http://169.254.170.2/credentials", None),
866                ok_creds_response(),
867            ),
868        ]);
869        tokio::time::pause();
870        let provider = provider(env, Fs::default(), http_client.clone());
871        let creds = provider
872            .provide_credentials()
873            .await
874            .expect("valid credentials");
875        assert_correct(creds);
876    }
877
878    #[tokio::test]
879    async fn load_valid_creds_no_auth() {
880        let env = Env::from_slice(&[("AWS_CONTAINER_CREDENTIALS_RELATIVE_URI", "/credentials")]);
881        let http_client = StaticReplayClient::new(vec![ReplayEvent::new(
882            creds_request("http://169.254.170.2/credentials", None),
883            ok_creds_response(),
884        )]);
885        let provider = provider(env, Fs::default(), http_client.clone());
886        let creds = provider
887            .provide_credentials()
888            .await
889            .expect("valid credentials");
890        assert_correct(creds);
891        http_client.assert_requests_match(&[]);
892    }
893
894    // ignored by default because it relies on actual DNS resolution
895    #[allow(unused_attributes)]
896    #[tokio::test]
897    #[traced_test]
898    #[ignore]
899    async fn real_dns_lookup() {
900        let dns = Some(
901            default_dns()
902                .expect("feature must be enabled")
903                .into_shared(),
904        );
905        let err = validate_full_uri("http://www.amazon.com/creds", dns.clone())
906            .await
907            .expect_err("not a valid IP");
908        assert!(
909            matches!(
910                err,
911                InvalidFullUriError {
912                    kind: InvalidFullUriErrorKind::DisallowedIP
913                }
914            ),
915            "{:?}",
916            err
917        );
918        assert!(logs_contain("Address does not resolve to an allowed IP"));
919        validate_full_uri("http://localhost:8888/creds", dns.clone())
920            .await
921            .expect("localhost is the loopback interface");
922        validate_full_uri("http://169.254.170.2.backname.io:8888/creds", dns.clone())
923            .await
924            .expect("169.254.170.2.backname.io is the ecs container address");
925        validate_full_uri("http://169.254.170.23.backname.io:8888/creds", dns.clone())
926            .await
927            .expect("169.254.170.23.backname.io is the eks pod identity address");
928        validate_full_uri("http://fd00-ec2--23.backname.io:8888/creds", dns)
929            .await
930            .expect("fd00-ec2--23.backname.io is the eks pod identity address");
931    }
932
933    /// Always returns the same IP addresses
934    #[derive(Clone, Debug)]
935    struct TestDns {
936        addrs: HashMap<String, Vec<IpAddr>>,
937        fallback: Vec<IpAddr>,
938    }
939
940    /// Default that returns a loopback for `localhost` and a non-loopback for all other hostnames
941    impl Default for TestDns {
942        fn default() -> Self {
943            let mut addrs = HashMap::new();
944            addrs.insert(
945                "localhost".into(),
946                vec!["127.0.0.1".parse().unwrap(), "127.0.0.2".parse().unwrap()],
947            );
948            TestDns {
949                addrs,
950                // non-loopback address
951                fallback: vec!["72.21.210.29".parse().unwrap()],
952            }
953        }
954    }
955
956    impl TestDns {
957        fn with_fallback(fallback: Vec<IpAddr>) -> Self {
958            TestDns {
959                addrs: Default::default(),
960                fallback,
961            }
962        }
963    }
964
965    impl ResolveDns for TestDns {
966        fn resolve_dns<'a>(&'a self, name: &'a str) -> DnsFuture<'a> {
967            DnsFuture::ready(Ok(self.addrs.get(name).unwrap_or(&self.fallback).clone()))
968        }
969    }
970
971    #[derive(Debug)]
972    struct NeverDns;
973    impl ResolveDns for NeverDns {
974        fn resolve_dns<'a>(&'a self, _name: &'a str) -> DnsFuture<'a> {
975            DnsFuture::new(async {
976                Never::new().await;
977                unreachable!()
978            })
979        }
980    }
981}