aws_config/imds/
client.rs

1/*
2 * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
3 * SPDX-License-Identifier: Apache-2.0
4 */
5
6//! Raw IMDSv2 Client
7//!
8//! Client for direct access to IMDSv2.
9
10use crate::imds::client::error::{BuildError, ImdsError, InnerImdsError, InvalidEndpointMode};
11use crate::imds::client::token::TokenRuntimePlugin;
12use crate::provider_config::ProviderConfig;
13use crate::PKG_VERSION;
14use aws_runtime::user_agent::{ApiMetadata, AwsUserAgent, UserAgentInterceptor};
15use aws_smithy_runtime::client::orchestrator::operation::Operation;
16use aws_smithy_runtime::client::retries::strategy::StandardRetryStrategy;
17use aws_smithy_runtime_api::box_error::BoxError;
18use aws_smithy_runtime_api::client::auth::AuthSchemeOptionResolverParams;
19use aws_smithy_runtime_api::client::endpoint::{
20    EndpointFuture, EndpointResolverParams, ResolveEndpoint,
21};
22use aws_smithy_runtime_api::client::interceptors::context::InterceptorContext;
23use aws_smithy_runtime_api::client::orchestrator::{
24    HttpRequest, OrchestratorError, SensitiveOutput,
25};
26use aws_smithy_runtime_api::client::result::ConnectorError;
27use aws_smithy_runtime_api::client::result::SdkError;
28use aws_smithy_runtime_api::client::retries::classifiers::{
29    ClassifyRetry, RetryAction, SharedRetryClassifier,
30};
31use aws_smithy_runtime_api::client::runtime_components::RuntimeComponentsBuilder;
32use aws_smithy_runtime_api::client::runtime_plugin::{RuntimePlugin, SharedRuntimePlugin};
33use aws_smithy_types::body::SdkBody;
34use aws_smithy_types::config_bag::{FrozenLayer, Layer};
35use aws_smithy_types::endpoint::Endpoint;
36use aws_smithy_types::retry::RetryConfig;
37use aws_smithy_types::timeout::TimeoutConfig;
38use aws_types::os_shim_internal::Env;
39use http::Uri;
40use std::borrow::Cow;
41use std::error::Error as _;
42use std::fmt;
43use std::str::FromStr;
44use std::sync::Arc;
45use std::time::Duration;
46
47pub mod error;
48mod token;
49
/// Default time-to-live requested for IMDS session tokens: six hours.
const DEFAULT_TOKEN_TTL: Duration = Duration::from_secs(6 * 60 * 60);
/// Default maximum number of attempts per IMDS request (passed to `RetryConfig::with_max_attempts`).
const DEFAULT_ATTEMPTS: u32 = 4;
/// Default connect timeout for IMDS requests.
const DEFAULT_CONNECT_TIMEOUT: Duration = Duration::from_secs(1);
/// Default read timeout for IMDS requests.
const DEFAULT_READ_TIMEOUT: Duration = Duration::from_secs(1);
55
56fn user_agent() -> AwsUserAgent {
57    AwsUserAgent::new_from_environment(Env::real(), ApiMetadata::new("imds", PKG_VERSION))
58}
59
60/// IMDSv2 Client
61///
62/// Client for IMDSv2. This client handles fetching tokens, retrying on failure, and token
63/// caching according to the specified token TTL.
64///
65/// _Note: This client ONLY supports IMDSv2. It will not fallback to IMDSv1. See
66/// [transitioning to IMDSv2](https://docs.aws.amazon.com/AWSEC2/latest/UserGuide/configuring-instance-metadata-service.html#instance-metadata-transition-to-version-2)
67/// for more information._
68///
69/// **Note**: When running in a Docker container, all network requests will incur an additional hop. When combined with the default IMDS hop limit of 1, this will cause requests to IMDS to timeout! To fix this issue, you'll need to set the following instance metadata settings :
70/// ```txt
71/// amazonec2-metadata-token=required
72/// amazonec2-metadata-token-response-hop-limit=2
73/// ```
74///
75/// On an instance that is already running, these can be set with [ModifyInstanceMetadataOptions](https://docs.aws.amazon.com/AWSEC2/latest/APIReference/API_ModifyInstanceMetadataOptions.html). On a new instance, these can be set with the `MetadataOptions` field on [RunInstances](https://docs.aws.amazon.com/AWSEC2/latest/APIReference/API_RunInstances.html).
76///
77/// For more information about IMDSv2 vs. IMDSv1 see [this guide](https://docs.aws.amazon.com/AWSEC2/latest/WindowsGuide/configuring-instance-metadata-service.html)
78///
79/// # Client Configuration
80/// The IMDS client can load configuration explicitly, via environment variables, or via
81/// `~/.aws/config`. It will first attempt to resolve an endpoint override. If no endpoint
82/// override exists, it will attempt to resolve an [`EndpointMode`]. If no
83/// [`EndpointMode`] override exists, it will fallback to [`IpV4`](EndpointMode::IpV4). An exhaustive
84/// list is below:
85///
86/// ## Endpoint configuration list
87/// 1. Explicit configuration of `Endpoint` via the [builder](Builder):
88/// ```no_run
89/// use aws_config::imds::client::Client;
90/// # async fn docs() {
91/// let client = Client::builder()
92///   .endpoint("http://customidms:456/").expect("valid URI")
93///   .build();
94/// # }
95/// ```
96///
97/// 2. The `AWS_EC2_METADATA_SERVICE_ENDPOINT` environment variable. Note: If this environment variable
98/// is set, it MUST contain to a valid URI or client construction will fail.
99///
100/// 3. The `ec2_metadata_service_endpoint` field in `~/.aws/config`:
101/// ```ini
102/// [default]
103/// # ... other configuration
104/// ec2_metadata_service_endpoint = http://my-custom-endpoint:444
105/// ```
106///
107/// 4. An explicitly set endpoint mode:
108/// ```no_run
109/// use aws_config::imds::client::{Client, EndpointMode};
110/// # async fn docs() {
111/// let client = Client::builder().endpoint_mode(EndpointMode::IpV6).build();
112/// # }
113/// ```
114///
115/// 5. An [endpoint mode](EndpointMode) loaded from the `AWS_EC2_METADATA_SERVICE_ENDPOINT_MODE` environment
116/// variable. Valid values: `IPv4`, `IPv6`
117///
118/// 6. An [endpoint mode](EndpointMode) loaded from the `ec2_metadata_service_endpoint_mode` field in
119/// `~/.aws/config`:
120/// ```ini
121/// [default]
122/// # ... other configuration
123/// ec2_metadata_service_endpoint_mode = IPv4
124/// ```
125///
126/// 7. The default value of `http://169.254.169.254` will be used.
127///
128#[derive(Clone, Debug)]
129pub struct Client {
130    operation: Operation<String, SensitiveString, InnerImdsError>,
131}
132
133impl Client {
134    /// IMDS client builder
135    pub fn builder() -> Builder {
136        Builder::default()
137    }
138
139    /// Retrieve information from IMDS
140    ///
141    /// This method will handle loading and caching a session token, combining the `path` with the
142    /// configured IMDS endpoint, and retrying potential errors.
143    ///
144    /// For more information about IMDSv2 methods and functionality, see
145    /// [Instance metadata and user data](https://docs.aws.amazon.com/AWSEC2/latest/UserGuide/ec2-instance-metadata.html)
146    ///
147    /// # Examples
148    ///
149    /// ```no_run
150    /// use aws_config::imds::client::Client;
151    /// # async fn docs() {
152    /// let client = Client::builder().build();
153    /// let ami_id = client
154    ///   .get("/latest/meta-data/ami-id")
155    ///   .await
156    ///   .expect("failure communicating with IMDS");
157    /// # }
158    /// ```
159    pub async fn get(&self, path: impl Into<String>) -> Result<SensitiveString, ImdsError> {
160        self.operation
161            .invoke(path.into())
162            .await
163            .map_err(|err| match err {
164                SdkError::ConstructionFailure(_) if err.source().is_some() => {
165                    match err.into_source().map(|e| e.downcast::<ImdsError>()) {
166                        Ok(Ok(token_failure)) => *token_failure,
167                        Ok(Err(err)) => ImdsError::unexpected(err),
168                        Err(err) => ImdsError::unexpected(err),
169                    }
170                }
171                SdkError::ConstructionFailure(_) => ImdsError::unexpected(err),
172                SdkError::ServiceError(context) => match context.err() {
173                    InnerImdsError::InvalidUtf8 => {
174                        ImdsError::unexpected("IMDS returned invalid UTF-8")
175                    }
176                    InnerImdsError::BadStatus => ImdsError::error_response(context.into_raw()),
177                },
178                // If the error source is an ImdsError, then we need to directly return that source.
179                // That way, the IMDS token provider's errors can become the top-level ImdsError.
180                // There is a unit test that checks the correct error is being extracted.
181                err @ SdkError::DispatchFailure(_) => match err.into_source() {
182                    Ok(source) => match source.downcast::<ConnectorError>() {
183                        Ok(source) => match source.into_source().downcast::<ImdsError>() {
184                            Ok(source) => *source,
185                            Err(err) => ImdsError::unexpected(err),
186                        },
187                        Err(err) => ImdsError::unexpected(err),
188                    },
189                    Err(err) => ImdsError::unexpected(err),
190                },
191                SdkError::TimeoutError(_) | SdkError::ResponseError(_) => ImdsError::io_error(err),
192                _ => ImdsError::unexpected(err),
193            })
194    }
195}
196
/// New-type around `String` whose `Debug` output never reveals the wrapped value.
///
/// Read the contents through `as_ref()` or by converting it into a `String`.
#[derive(Clone)]
pub struct SensitiveString(String);

impl fmt::Debug for SensitiveString {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // Print a fixed placeholder instead of the real contents so secrets never reach logs.
        const PLACEHOLDER: &str = "** redacted **";
        let mut tuple = f.debug_tuple("SensitiveString");
        tuple.field(&PLACEHOLDER);
        tuple.finish()
    }
}

impl AsRef<str> for SensitiveString {
    fn as_ref(&self) -> &str {
        self.0.as_str()
    }
}

impl From<String> for SensitiveString {
    fn from(value: String) -> Self {
        SensitiveString(value)
    }
}

impl From<SensitiveString> for String {
    fn from(value: SensitiveString) -> Self {
        let SensitiveString(inner) = value;
        inner
    }
}
226
227/// Runtime plugin that is used by both the IMDS client and the inner client that resolves
228/// the IMDS token and attaches it to requests. This runtime plugin marks the responses as
229/// sensitive, configures user agent headers, and sets up retries and timeouts.
230#[derive(Debug)]
231struct ImdsCommonRuntimePlugin {
232    config: FrozenLayer,
233    components: RuntimeComponentsBuilder,
234}
235
236impl ImdsCommonRuntimePlugin {
237    fn new(
238        config: &ProviderConfig,
239        endpoint_resolver: ImdsEndpointResolver,
240        retry_config: RetryConfig,
241        timeout_config: TimeoutConfig,
242    ) -> Self {
243        let mut layer = Layer::new("ImdsCommonRuntimePlugin");
244        layer.store_put(AuthSchemeOptionResolverParams::new(()));
245        layer.store_put(EndpointResolverParams::new(()));
246        layer.store_put(SensitiveOutput);
247        layer.store_put(retry_config);
248        layer.store_put(timeout_config);
249        layer.store_put(user_agent());
250
251        Self {
252            config: layer.freeze(),
253            components: RuntimeComponentsBuilder::new("ImdsCommonRuntimePlugin")
254                .with_http_client(config.http_client())
255                .with_endpoint_resolver(Some(endpoint_resolver))
256                .with_interceptor(UserAgentInterceptor::new())
257                .with_retry_classifier(SharedRetryClassifier::new(ImdsResponseRetryClassifier))
258                .with_retry_strategy(Some(StandardRetryStrategy::new()))
259                .with_time_source(Some(config.time_source()))
260                .with_sleep_impl(config.sleep_impl()),
261        }
262    }
263}
264
265impl RuntimePlugin for ImdsCommonRuntimePlugin {
266    fn config(&self) -> Option<FrozenLayer> {
267        Some(self.config.clone())
268    }
269
270    fn runtime_components(
271        &self,
272        _current_components: &RuntimeComponentsBuilder,
273    ) -> Cow<'_, RuntimeComponentsBuilder> {
274        Cow::Borrowed(&self.components)
275    }
276}
277
278/// IMDSv2 Endpoint Mode
279///
280/// IMDS can be accessed in two ways:
281/// 1. Via the IpV4 endpoint: `http://169.254.169.254`
282/// 2. Via the Ipv6 endpoint: `http://[fd00:ec2::254]`
283#[derive(Debug, Clone)]
284#[non_exhaustive]
285pub enum EndpointMode {
286    /// IpV4 mode: `http://169.254.169.254`
287    ///
288    /// This mode is the default unless otherwise specified.
289    IpV4,
290    /// IpV6 mode: `http://[fd00:ec2::254]`
291    IpV6,
292}
293
294impl FromStr for EndpointMode {
295    type Err = InvalidEndpointMode;
296
297    fn from_str(value: &str) -> Result<Self, Self::Err> {
298        match value {
299            _ if value.eq_ignore_ascii_case("ipv4") => Ok(EndpointMode::IpV4),
300            _ if value.eq_ignore_ascii_case("ipv6") => Ok(EndpointMode::IpV6),
301            other => Err(InvalidEndpointMode::new(other.to_owned())),
302        }
303    }
304}
305
306impl EndpointMode {
307    /// IMDS URI for this endpoint mode
308    fn endpoint(&self) -> Uri {
309        match self {
310            EndpointMode::IpV4 => Uri::from_static("http://169.254.169.254"),
311            EndpointMode::IpV6 => Uri::from_static("http://[fd00:ec2::254]"),
312        }
313    }
314}
315
316/// IMDSv2 Client Builder
317#[derive(Default, Debug, Clone)]
318pub struct Builder {
319    max_attempts: Option<u32>,
320    endpoint: Option<EndpointSource>,
321    mode_override: Option<EndpointMode>,
322    token_ttl: Option<Duration>,
323    connect_timeout: Option<Duration>,
324    read_timeout: Option<Duration>,
325    config: Option<ProviderConfig>,
326}
327
328impl Builder {
329    /// Override the number of retries for fetching tokens & metadata
330    ///
331    /// By default, 4 attempts will be made.
332    pub fn max_attempts(mut self, max_attempts: u32) -> Self {
333        self.max_attempts = Some(max_attempts);
334        self
335    }
336
337    /// Configure generic options of the [`Client`]
338    ///
339    /// # Examples
340    /// ```no_run
341    /// # async fn test() {
342    /// use aws_config::imds::Client;
343    /// use aws_config::provider_config::ProviderConfig;
344    ///
345    /// let provider = Client::builder()
346    ///     .configure(&ProviderConfig::with_default_region().await)
347    ///     .build();
348    /// # }
349    /// ```
350    pub fn configure(mut self, provider_config: &ProviderConfig) -> Self {
351        self.config = Some(provider_config.clone());
352        self
353    }
354
355    /// Override the endpoint for the [`Client`]
356    ///
357    /// By default, the client will resolve an endpoint from the environment, AWS config, and endpoint mode.
358    ///
359    /// See [`Client`] for more information.
360    pub fn endpoint(mut self, endpoint: impl AsRef<str>) -> Result<Self, BoxError> {
361        let uri: Uri = endpoint.as_ref().parse()?;
362        self.endpoint = Some(EndpointSource::Explicit(uri));
363        Ok(self)
364    }
365
366    /// Override the endpoint mode for [`Client`]
367    ///
368    /// * When set to [`IpV4`](EndpointMode::IpV4), the endpoint will be `http://169.254.169.254`.
369    /// * When set to [`IpV6`](EndpointMode::IpV6), the endpoint will be `http://[fd00:ec2::254]`.
370    pub fn endpoint_mode(mut self, mode: EndpointMode) -> Self {
371        self.mode_override = Some(mode);
372        self
373    }
374
375    /// Override the time-to-live for the session token
376    ///
377    /// Requests to IMDS utilize a session token for authentication. By default, session tokens last
378    /// for 6 hours. When the TTL for the token expires, a new token must be retrieved from the
379    /// metadata service.
380    pub fn token_ttl(mut self, ttl: Duration) -> Self {
381        self.token_ttl = Some(ttl);
382        self
383    }
384
385    /// Override the connect timeout for IMDS
386    ///
387    /// This value defaults to 1 second
388    pub fn connect_timeout(mut self, timeout: Duration) -> Self {
389        self.connect_timeout = Some(timeout);
390        self
391    }
392
393    /// Override the read timeout for IMDS
394    ///
395    /// This value defaults to 1 second
396    pub fn read_timeout(mut self, timeout: Duration) -> Self {
397        self.read_timeout = Some(timeout);
398        self
399    }
400
401    /* TODO(https://github.com/awslabs/aws-sdk-rust/issues/339): Support customizing the port explicitly */
402    /*
403    pub fn port(mut self, port: u32) -> Self {
404        self.port_override = Some(port);
405        self
406    }*/
407
408    /// Build an IMDSv2 Client
409    pub fn build(self) -> Client {
410        let config = self.config.unwrap_or_default();
411        let timeout_config = TimeoutConfig::builder()
412            .connect_timeout(self.connect_timeout.unwrap_or(DEFAULT_CONNECT_TIMEOUT))
413            .read_timeout(self.read_timeout.unwrap_or(DEFAULT_READ_TIMEOUT))
414            .build();
415        let endpoint_source = self
416            .endpoint
417            .unwrap_or_else(|| EndpointSource::Env(config.clone()));
418        let endpoint_resolver = ImdsEndpointResolver {
419            endpoint_source: Arc::new(endpoint_source),
420            mode_override: self.mode_override,
421        };
422        let retry_config = RetryConfig::standard()
423            .with_max_attempts(self.max_attempts.unwrap_or(DEFAULT_ATTEMPTS));
424        let common_plugin = SharedRuntimePlugin::new(ImdsCommonRuntimePlugin::new(
425            &config,
426            endpoint_resolver,
427            retry_config,
428            timeout_config,
429        ));
430        let operation = Operation::builder()
431            .service_name("imds")
432            .operation_name("get")
433            .runtime_plugin(common_plugin.clone())
434            .runtime_plugin(TokenRuntimePlugin::new(
435                common_plugin,
436                self.token_ttl.unwrap_or(DEFAULT_TOKEN_TTL),
437            ))
438            .with_connection_poisoning()
439            .serializer(|path| {
440                Ok(HttpRequest::try_from(
441                    http::Request::builder()
442                        .uri(path)
443                        .body(SdkBody::empty())
444                        .expect("valid request"),
445                )
446                .unwrap())
447            })
448            .deserializer(|response| {
449                if response.status().is_success() {
450                    std::str::from_utf8(response.body().bytes().expect("non-streaming response"))
451                        .map(|data| SensitiveString::from(data.to_string()))
452                        .map_err(|_| OrchestratorError::operation(InnerImdsError::InvalidUtf8))
453                } else {
454                    Err(OrchestratorError::operation(InnerImdsError::BadStatus))
455                }
456            })
457            .build();
458        Client { operation }
459    }
460}
461
/// Environment variable names consulted when resolving the IMDS endpoint.
mod env {
    /// Explicit IMDS endpoint URI.
    pub(super) const ENDPOINT: &str = "AWS_EC2_METADATA_SERVICE_ENDPOINT";
    /// IMDS endpoint mode (`IPv4` or `IPv6`).
    pub(super) const ENDPOINT_MODE: &str = "AWS_EC2_METADATA_SERVICE_ENDPOINT_MODE";
}

/// Profile (`~/.aws/config`) keys consulted when resolving the IMDS endpoint.
mod profile_keys {
    /// Explicit IMDS endpoint URI.
    pub(super) const ENDPOINT: &str = "ec2_metadata_service_endpoint";
    /// IMDS endpoint mode (`IPv4` or `IPv6`).
    pub(super) const ENDPOINT_MODE: &str = "ec2_metadata_service_endpoint_mode";
}
471
472/// Endpoint Configuration Abstraction
473#[derive(Debug, Clone)]
474enum EndpointSource {
475    Explicit(Uri),
476    Env(ProviderConfig),
477}
478
479impl EndpointSource {
480    async fn endpoint(&self, mode_override: Option<EndpointMode>) -> Result<Uri, BuildError> {
481        match self {
482            EndpointSource::Explicit(uri) => {
483                if mode_override.is_some() {
484                    tracing::warn!(endpoint = ?uri, mode = ?mode_override,
485                        "Endpoint mode override was set in combination with an explicit endpoint. \
486                        The mode override will be ignored.")
487                }
488                Ok(uri.clone())
489            }
490            EndpointSource::Env(conf) => {
491                let env = conf.env();
492                // load an endpoint override from the environment
493                let profile = conf.profile().await;
494                let uri_override = if let Ok(uri) = env.get(env::ENDPOINT) {
495                    Some(Cow::Owned(uri))
496                } else {
497                    profile
498                        .and_then(|profile| profile.get(profile_keys::ENDPOINT))
499                        .map(Cow::Borrowed)
500                };
501                if let Some(uri) = uri_override {
502                    return Uri::try_from(uri.as_ref()).map_err(BuildError::invalid_endpoint_uri);
503                }
504
505                // if not, load a endpoint mode from the environment
506                let mode = if let Some(mode) = mode_override {
507                    mode
508                } else if let Ok(mode) = env.get(env::ENDPOINT_MODE) {
509                    mode.parse::<EndpointMode>()
510                        .map_err(BuildError::invalid_endpoint_mode)?
511                } else if let Some(mode) = profile.and_then(|p| p.get(profile_keys::ENDPOINT_MODE))
512                {
513                    mode.parse::<EndpointMode>()
514                        .map_err(BuildError::invalid_endpoint_mode)?
515                } else {
516                    EndpointMode::IpV4
517                };
518
519                Ok(mode.endpoint())
520            }
521        }
522    }
523}
524
525#[derive(Clone, Debug)]
526struct ImdsEndpointResolver {
527    endpoint_source: Arc<EndpointSource>,
528    mode_override: Option<EndpointMode>,
529}
530
531impl ResolveEndpoint for ImdsEndpointResolver {
532    fn resolve_endpoint<'a>(&'a self, _: &'a EndpointResolverParams) -> EndpointFuture<'a> {
533        EndpointFuture::new(async move {
534            self.endpoint_source
535                .endpoint(self.mode_override.clone())
536                .await
537                .map(|uri| Endpoint::builder().url(uri.to_string()).build())
538                .map_err(|err| err.into())
539        })
540    }
541}
542
543/// IMDS Response Retry Classifier
544///
545/// Possible status codes:
546/// - 200 (OK)
547/// - 400 (Missing or invalid parameters) **Not Retryable**
548/// - 401 (Unauthorized, expired token) **Retryable**
549/// - 403 (IMDS disabled): **Not Retryable**
550/// - 404 (Not found): **Not Retryable**
551/// - >=500 (server error): **Retryable**
552#[derive(Clone, Debug)]
553struct ImdsResponseRetryClassifier;
554
555impl ClassifyRetry for ImdsResponseRetryClassifier {
556    fn name(&self) -> &'static str {
557        "ImdsResponseRetryClassifier"
558    }
559
560    fn classify_retry(&self, ctx: &InterceptorContext) -> RetryAction {
561        if let Some(response) = ctx.response() {
562            let status = response.status();
563            match status {
564                _ if status.is_server_error() => RetryAction::server_error(),
565                // 401 indicates that the token has expired, this is retryable
566                _ if status.as_u16() == 401 => RetryAction::server_error(),
567                // This catch-all includes successful responses that fail to parse. These should not be retried.
568                _ => RetryAction::NoActionIndicated,
569            }
570        } else {
571            // Don't retry timeouts for IMDS, or else it will take ~30 seconds for the default
572            // credentials provider chain to fail to provide credentials.
573            // Also don't retry non-responses.
574            RetryAction::NoActionIndicated
575        }
576    }
577}
578
579#[cfg(test)]
580pub(crate) mod test {
581    use crate::imds::client::{Client, EndpointMode, ImdsResponseRetryClassifier};
582    use crate::provider_config::ProviderConfig;
583    use aws_smithy_async::rt::sleep::TokioSleep;
584    use aws_smithy_async::test_util::{instant_time_and_sleep, InstantSleep};
585    use aws_smithy_runtime::client::http::test_util::{
586        capture_request, ReplayEvent, StaticReplayClient,
587    };
588    use aws_smithy_runtime::test_util::capture_test_logs::capture_test_logs;
589    use aws_smithy_runtime_api::client::interceptors::context::{
590        Input, InterceptorContext, Output,
591    };
592    use aws_smithy_runtime_api::client::orchestrator::{
593        HttpRequest, HttpResponse, OrchestratorError,
594    };
595    use aws_smithy_runtime_api::client::result::ConnectorError;
596    use aws_smithy_runtime_api::client::retries::classifiers::{ClassifyRetry, RetryAction};
597    use aws_smithy_types::body::SdkBody;
598    use aws_smithy_types::error::display::DisplayErrorContext;
599    use aws_types::os_shim_internal::{Env, Fs};
600    use http::header::USER_AGENT;
601    use http::Uri;
602    use serde::Deserialize;
603    use std::collections::HashMap;
604    use std::error::Error;
605    use std::io;
606    use std::time::{Duration, UNIX_EPOCH};
607    use tracing_test::traced_test;
608
609    macro_rules! assert_full_error_contains {
610        ($err:expr, $contains:expr) => {
611            let err = $err;
612            let message = format!(
613                "{}",
614                aws_smithy_types::error::display::DisplayErrorContext(&err)
615            );
616            assert!(
617                message.contains($contains),
618                "Error message '{message}' didn't contain text '{}'",
619                $contains
620            );
621        };
622    }
623
624    const TOKEN_A: &str = "AQAEAFTNrA4eEGx0AQgJ1arIq_Cc-t4tWt3fB0Hd8RKhXlKc5ccvhg==";
625    const TOKEN_B: &str = "alternatetoken==";
626
627    pub(crate) fn token_request(base: &str, ttl: u32) -> HttpRequest {
628        http::Request::builder()
629            .uri(format!("{}/latest/api/token", base))
630            .header("x-aws-ec2-metadata-token-ttl-seconds", ttl)
631            .method("PUT")
632            .body(SdkBody::empty())
633            .unwrap()
634            .try_into()
635            .unwrap()
636    }
637
638    pub(crate) fn token_response(ttl: u32, token: &'static str) -> HttpResponse {
639        HttpResponse::try_from(
640            http::Response::builder()
641                .status(200)
642                .header("X-aws-ec2-metadata-token-ttl-seconds", ttl)
643                .body(SdkBody::from(token))
644                .unwrap(),
645        )
646        .unwrap()
647    }
648
649    pub(crate) fn imds_request(path: &'static str, token: &str) -> HttpRequest {
650        http::Request::builder()
651            .uri(Uri::from_static(path))
652            .method("GET")
653            .header("x-aws-ec2-metadata-token", token)
654            .body(SdkBody::empty())
655            .unwrap()
656            .try_into()
657            .unwrap()
658    }
659
660    pub(crate) fn imds_response(body: &'static str) -> HttpResponse {
661        HttpResponse::try_from(
662            http::Response::builder()
663                .status(200)
664                .body(SdkBody::from(body))
665                .unwrap(),
666        )
667        .unwrap()
668    }
669
670    pub(crate) fn make_imds_client(http_client: &StaticReplayClient) -> super::Client {
671        tokio::time::pause();
672        super::Client::builder()
673            .configure(
674                &ProviderConfig::no_configuration()
675                    .with_sleep_impl(InstantSleep::unlogged())
676                    .with_http_client(http_client.clone()),
677            )
678            .build()
679    }
680
681    fn mock_imds_client(events: Vec<ReplayEvent>) -> (Client, StaticReplayClient) {
682        let http_client = StaticReplayClient::new(events);
683        let client = make_imds_client(&http_client);
684        (client, http_client)
685    }
686
687    #[tokio::test]
688    async fn client_caches_token() {
689        let (client, http_client) = mock_imds_client(vec![
690            ReplayEvent::new(
691                token_request("http://169.254.169.254", 21600),
692                token_response(21600, TOKEN_A),
693            ),
694            ReplayEvent::new(
695                imds_request("http://169.254.169.254/latest/metadata", TOKEN_A),
696                imds_response(r#"test-imds-output"#),
697            ),
698            ReplayEvent::new(
699                imds_request("http://169.254.169.254/latest/metadata2", TOKEN_A),
700                imds_response("output2"),
701            ),
702        ]);
703        // load once
704        let metadata = client.get("/latest/metadata").await.expect("failed");
705        assert_eq!("test-imds-output", metadata.as_ref());
706        // load again: the cached token should be used
707        let metadata = client.get("/latest/metadata2").await.expect("failed");
708        assert_eq!("output2", metadata.as_ref());
709        http_client.assert_requests_match(&[]);
710    }
711
712    #[tokio::test]
713    async fn token_can_expire() {
714        let (_, http_client) = mock_imds_client(vec![
715            ReplayEvent::new(
716                token_request("http://[fd00:ec2::254]", 600),
717                token_response(600, TOKEN_A),
718            ),
719            ReplayEvent::new(
720                imds_request("http://[fd00:ec2::254]/latest/metadata", TOKEN_A),
721                imds_response(r#"test-imds-output1"#),
722            ),
723            ReplayEvent::new(
724                token_request("http://[fd00:ec2::254]", 600),
725                token_response(600, TOKEN_B),
726            ),
727            ReplayEvent::new(
728                imds_request("http://[fd00:ec2::254]/latest/metadata", TOKEN_B),
729                imds_response(r#"test-imds-output2"#),
730            ),
731        ]);
732        let (time_source, sleep) = instant_time_and_sleep(UNIX_EPOCH);
733        let client = super::Client::builder()
734            .configure(
735                &ProviderConfig::no_configuration()
736                    .with_http_client(http_client.clone())
737                    .with_time_source(time_source.clone())
738                    .with_sleep_impl(sleep),
739            )
740            .endpoint_mode(EndpointMode::IpV6)
741            .token_ttl(Duration::from_secs(600))
742            .build();
743
744        let resp1 = client.get("/latest/metadata").await.expect("success");
745        // now the cached credential has expired
746        time_source.advance(Duration::from_secs(600));
747        let resp2 = client.get("/latest/metadata").await.expect("success");
748        http_client.assert_requests_match(&[]);
749        assert_eq!("test-imds-output1", resp1.as_ref());
750        assert_eq!("test-imds-output2", resp2.as_ref());
751    }
752
753    /// Tokens are refreshed up to 120 seconds early to avoid using an expired token.
754    #[tokio::test]
755    async fn token_refresh_buffer() {
756        let _logs = capture_test_logs();
757        let (_, http_client) = mock_imds_client(vec![
758            ReplayEvent::new(
759                token_request("http://[fd00:ec2::254]", 600),
760                token_response(600, TOKEN_A),
761            ),
762            // t = 0
763            ReplayEvent::new(
764                imds_request("http://[fd00:ec2::254]/latest/metadata", TOKEN_A),
765                imds_response(r#"test-imds-output1"#),
766            ),
767            // t = 400 (no refresh)
768            ReplayEvent::new(
769                imds_request("http://[fd00:ec2::254]/latest/metadata", TOKEN_A),
770                imds_response(r#"test-imds-output2"#),
771            ),
772            // t = 550 (within buffer)
773            ReplayEvent::new(
774                token_request("http://[fd00:ec2::254]", 600),
775                token_response(600, TOKEN_B),
776            ),
777            ReplayEvent::new(
778                imds_request("http://[fd00:ec2::254]/latest/metadata", TOKEN_B),
779                imds_response(r#"test-imds-output3"#),
780            ),
781        ]);
782        let (time_source, sleep) = instant_time_and_sleep(UNIX_EPOCH);
783        let client = super::Client::builder()
784            .configure(
785                &ProviderConfig::no_configuration()
786                    .with_sleep_impl(sleep)
787                    .with_http_client(http_client.clone())
788                    .with_time_source(time_source.clone()),
789            )
790            .endpoint_mode(EndpointMode::IpV6)
791            .token_ttl(Duration::from_secs(600))
792            .build();
793
794        tracing::info!("resp1 -----------------------------------------------------------");
795        let resp1 = client.get("/latest/metadata").await.expect("success");
796        // now the cached credential has expired
797        time_source.advance(Duration::from_secs(400));
798        tracing::info!("resp2 -----------------------------------------------------------");
799        let resp2 = client.get("/latest/metadata").await.expect("success");
800        time_source.advance(Duration::from_secs(150));
801        tracing::info!("resp3 -----------------------------------------------------------");
802        let resp3 = client.get("/latest/metadata").await.expect("success");
803        http_client.assert_requests_match(&[]);
804        assert_eq!("test-imds-output1", resp1.as_ref());
805        assert_eq!("test-imds-output2", resp2.as_ref());
806        assert_eq!("test-imds-output3", resp3.as_ref());
807    }
808
809    /// 500 error during the GET should be retried
810    #[tokio::test]
811    #[traced_test]
812    async fn retry_500() {
813        let (client, http_client) = mock_imds_client(vec![
814            ReplayEvent::new(
815                token_request("http://169.254.169.254", 21600),
816                token_response(21600, TOKEN_A),
817            ),
818            ReplayEvent::new(
819                imds_request("http://169.254.169.254/latest/metadata", TOKEN_A),
820                http::Response::builder()
821                    .status(500)
822                    .body(SdkBody::empty())
823                    .unwrap(),
824            ),
825            ReplayEvent::new(
826                imds_request("http://169.254.169.254/latest/metadata", TOKEN_A),
827                imds_response("ok"),
828            ),
829        ]);
830        assert_eq!(
831            "ok",
832            client
833                .get("/latest/metadata")
834                .await
835                .expect("success")
836                .as_ref()
837        );
838        http_client.assert_requests_match(&[]);
839
840        // all requests should have a user agent header
841        for request in http_client.actual_requests() {
842            assert!(request.headers().get(USER_AGENT).is_some());
843        }
844    }
845
846    /// 500 error during token acquisition should be retried
847    #[tokio::test]
848    #[traced_test]
849    async fn retry_token_failure() {
850        let (client, http_client) = mock_imds_client(vec![
851            ReplayEvent::new(
852                token_request("http://169.254.169.254", 21600),
853                http::Response::builder()
854                    .status(500)
855                    .body(SdkBody::empty())
856                    .unwrap(),
857            ),
858            ReplayEvent::new(
859                token_request("http://169.254.169.254", 21600),
860                token_response(21600, TOKEN_A),
861            ),
862            ReplayEvent::new(
863                imds_request("http://169.254.169.254/latest/metadata", TOKEN_A),
864                imds_response("ok"),
865            ),
866        ]);
867        assert_eq!(
868            "ok",
869            client
870                .get("/latest/metadata")
871                .await
872                .expect("success")
873                .as_ref()
874        );
875        http_client.assert_requests_match(&[]);
876    }
877
878    /// 401 error during metadata retrieval must be retried
879    #[tokio::test]
880    #[traced_test]
881    async fn retry_metadata_401() {
882        let (client, http_client) = mock_imds_client(vec![
883            ReplayEvent::new(
884                token_request("http://169.254.169.254", 21600),
885                token_response(0, TOKEN_A),
886            ),
887            ReplayEvent::new(
888                imds_request("http://169.254.169.254/latest/metadata", TOKEN_A),
889                http::Response::builder()
890                    .status(401)
891                    .body(SdkBody::empty())
892                    .unwrap(),
893            ),
894            ReplayEvent::new(
895                token_request("http://169.254.169.254", 21600),
896                token_response(21600, TOKEN_B),
897            ),
898            ReplayEvent::new(
899                imds_request("http://169.254.169.254/latest/metadata", TOKEN_B),
900                imds_response("ok"),
901            ),
902        ]);
903        assert_eq!(
904            "ok",
905            client
906                .get("/latest/metadata")
907                .await
908                .expect("success")
909                .as_ref()
910        );
911        http_client.assert_requests_match(&[]);
912    }
913
914    /// 403 responses from IMDS during token acquisition MUST NOT be retried
915    #[tokio::test]
916    #[traced_test]
917    async fn no_403_retry() {
918        let (client, http_client) = mock_imds_client(vec![ReplayEvent::new(
919            token_request("http://169.254.169.254", 21600),
920            http::Response::builder()
921                .status(403)
922                .body(SdkBody::empty())
923                .unwrap(),
924        )]);
925        let err = client.get("/latest/metadata").await.expect_err("no token");
926        assert_full_error_contains!(err, "forbidden");
927        http_client.assert_requests_match(&[]);
928    }
929
930    /// The classifier should return `None` when classifying a successful response.
931    #[test]
932    fn successful_response_properly_classified() {
933        let mut ctx = InterceptorContext::new(Input::doesnt_matter());
934        ctx.set_output_or_error(Ok(Output::doesnt_matter()));
935        ctx.set_response(imds_response("").map(|_| SdkBody::empty()));
936        let classifier = ImdsResponseRetryClassifier;
937        assert_eq!(
938            RetryAction::NoActionIndicated,
939            classifier.classify_retry(&ctx)
940        );
941
942        // Emulate a failure to parse the response body (using an io error since it's easy to construct in a test)
943        let mut ctx = InterceptorContext::new(Input::doesnt_matter());
944        ctx.set_output_or_error(Err(OrchestratorError::connector(ConnectorError::io(
945            io::Error::new(io::ErrorKind::BrokenPipe, "fail to parse").into(),
946        ))));
947        assert_eq!(
948            RetryAction::NoActionIndicated,
949            classifier.classify_retry(&ctx)
950        );
951    }
952
953    // since tokens are sent as headers, the tokens need to be valid header values
954    #[tokio::test]
955    async fn invalid_token() {
956        let (client, http_client) = mock_imds_client(vec![ReplayEvent::new(
957            token_request("http://169.254.169.254", 21600),
958            token_response(21600, "invalid\nheader\nvalue\0"),
959        )]);
960        let err = client.get("/latest/metadata").await.expect_err("no token");
961        assert_full_error_contains!(err, "invalid token");
962        http_client.assert_requests_match(&[]);
963    }
964
965    #[tokio::test]
966    async fn non_utf8_response() {
967        let (client, http_client) = mock_imds_client(vec![
968            ReplayEvent::new(
969                token_request("http://169.254.169.254", 21600),
970                token_response(21600, TOKEN_A).map(SdkBody::from),
971            ),
972            ReplayEvent::new(
973                imds_request("http://169.254.169.254/latest/metadata", TOKEN_A),
974                http::Response::builder()
975                    .status(200)
976                    .body(SdkBody::from(vec![0xA0, 0xA1]))
977                    .unwrap(),
978            ),
979        ]);
980        let err = client.get("/latest/metadata").await.expect_err("no token");
981        assert_full_error_contains!(err, "invalid UTF-8");
982        http_client.assert_requests_match(&[]);
983    }
984
985    // TODO(https://github.com/awslabs/aws-sdk-rust/issues/1117) This test is ignored on Windows because it uses Unix-style paths
986    #[cfg_attr(windows, ignore)]
987    /// Verify that the end-to-end real client has a 1-second connect timeout
988    #[tokio::test]
989    #[cfg(feature = "rustls")]
990    async fn one_second_connect_timeout() {
991        use crate::imds::client::ImdsError;
992        use aws_smithy_types::error::display::DisplayErrorContext;
993        use std::time::SystemTime;
994
995        let client = Client::builder()
996            // 240.* can never be resolved
997            .endpoint("http://240.0.0.0")
998            .expect("valid uri")
999            .build();
1000        let now = SystemTime::now();
1001        let resp = client
1002            .get("/latest/metadata")
1003            .await
1004            .expect_err("240.0.0.0 will never resolve");
1005        match resp {
1006            err @ ImdsError::FailedToLoadToken(_)
1007                if format!("{}", DisplayErrorContext(&err)).contains("timeout") => {} // ok,
1008            other => panic!(
1009                "wrong error, expected construction failure with TimedOutError inside: {}",
1010                DisplayErrorContext(&other)
1011            ),
1012        }
1013        let time_elapsed = now.elapsed().unwrap();
1014        assert!(
1015            time_elapsed > Duration::from_secs(1),
1016            "time_elapsed should be greater than 1s but was {:?}",
1017            time_elapsed
1018        );
1019        assert!(
1020            time_elapsed < Duration::from_secs(2),
1021            "time_elapsed should be less than 2s but was {:?}",
1022            time_elapsed
1023        );
1024    }
1025
    /// One endpoint-configuration test case, deserialized from
    /// `test-data/imds-config/imds-endpoint-tests.json` and executed by `check`.
    #[derive(Debug, Deserialize)]
    struct ImdsConfigTest {
        /// Environment variables installed into the provider config via `Env::from`.
        env: HashMap<String, String>,
        /// Fake filesystem contents installed via `Fs::from_map`
        /// (presumably path -> file contents; confirm against the JSON fixture).
        fs: HashMap<String, String>,
        /// Explicit endpoint passed to the client builder's `endpoint`, if present.
        endpoint_override: Option<String>,
        /// Endpoint-mode string, parsed and passed to the builder's `endpoint_mode`, if present.
        mode_override: Option<String>,
        /// Expected outcome. `Ok(uri)` is the URI the client's request must be sent to.
        /// `Err(msg)` is a substring that must appear in the client's error message.
        result: Result<String, String>,
        /// Human-readable description of the case, included in assertion failure messages
        /// in `check`.
        docs: String,
    }
1035
1036    #[tokio::test]
1037    async fn endpoint_config_tests() -> Result<(), Box<dyn Error>> {
1038        let _logs = capture_test_logs();
1039
1040        let test_cases = std::fs::read_to_string("test-data/imds-config/imds-endpoint-tests.json")?;
1041        #[derive(Deserialize)]
1042        struct TestCases {
1043            tests: Vec<ImdsConfigTest>,
1044        }
1045
1046        let test_cases: TestCases = serde_json::from_str(&test_cases)?;
1047        let test_cases = test_cases.tests;
1048        for test in test_cases {
1049            check(test).await;
1050        }
1051        Ok(())
1052    }
1053
1054    async fn check(test_case: ImdsConfigTest) {
1055        let (http_client, watcher) = capture_request(None);
1056        let provider_config = ProviderConfig::no_configuration()
1057            .with_sleep_impl(TokioSleep::new())
1058            .with_env(Env::from(test_case.env))
1059            .with_fs(Fs::from_map(test_case.fs))
1060            .with_http_client(http_client);
1061        let mut imds_client = Client::builder().configure(&provider_config);
1062        if let Some(endpoint_override) = test_case.endpoint_override {
1063            imds_client = imds_client
1064                .endpoint(endpoint_override)
1065                .expect("invalid URI");
1066        }
1067
1068        if let Some(mode_override) = test_case.mode_override {
1069            imds_client = imds_client.endpoint_mode(mode_override.parse().unwrap());
1070        }
1071
1072        let imds_client = imds_client.build();
1073        match &test_case.result {
1074            Ok(uri) => {
1075                // this request will fail, we just want to capture the endpoint configuration
1076                let _ = imds_client.get("/hello").await;
1077                assert_eq!(&watcher.expect_request().uri().to_string(), uri);
1078            }
1079            Err(expected) => {
1080                let err = imds_client.get("/hello").await.expect_err("it should fail");
1081                let message = format!("{}", DisplayErrorContext(&err));
1082                assert!(
1083                    message.contains(expected),
1084                    "{}\nexpected error: {expected}\nactual error: {message}",
1085                    test_case.docs
1086                );
1087            }
1088        };
1089    }
1090}