1use crate::imds::client::error::{BuildError, ImdsError, InnerImdsError, InvalidEndpointMode};
11use crate::imds::client::token::TokenRuntimePlugin;
12use crate::provider_config::ProviderConfig;
13use crate::PKG_VERSION;
14use aws_runtime::user_agent::{ApiMetadata, AwsUserAgent, UserAgentInterceptor};
15use aws_smithy_runtime::client::orchestrator::operation::Operation;
16use aws_smithy_runtime::client::retries::strategy::StandardRetryStrategy;
17use aws_smithy_runtime_api::box_error::BoxError;
18use aws_smithy_runtime_api::client::auth::AuthSchemeOptionResolverParams;
19use aws_smithy_runtime_api::client::endpoint::{
20 EndpointFuture, EndpointResolverParams, ResolveEndpoint,
21};
22use aws_smithy_runtime_api::client::interceptors::context::InterceptorContext;
23use aws_smithy_runtime_api::client::orchestrator::{
24 HttpRequest, OrchestratorError, SensitiveOutput,
25};
26use aws_smithy_runtime_api::client::result::ConnectorError;
27use aws_smithy_runtime_api::client::result::SdkError;
28use aws_smithy_runtime_api::client::retries::classifiers::{
29 ClassifyRetry, RetryAction, SharedRetryClassifier,
30};
31use aws_smithy_runtime_api::client::runtime_components::RuntimeComponentsBuilder;
32use aws_smithy_runtime_api::client::runtime_plugin::{RuntimePlugin, SharedRuntimePlugin};
33use aws_smithy_types::body::SdkBody;
34use aws_smithy_types::config_bag::{FrozenLayer, Layer};
35use aws_smithy_types::endpoint::Endpoint;
36use aws_smithy_types::retry::RetryConfig;
37use aws_smithy_types::timeout::TimeoutConfig;
38use aws_types::os_shim_internal::Env;
39use http::Uri;
40use std::borrow::Cow;
41use std::error::Error as _;
42use std::fmt;
43use std::str::FromStr;
44use std::sync::Arc;
45use std::time::Duration;
46
47pub mod error;
48mod token;
49
/// Token TTL handed to the token runtime plugin when [`Builder::token_ttl`] is not set:
/// six hours, i.e. 21,600 seconds.
const DEFAULT_TOKEN_TTL: Duration = Duration::from_secs(6 * 60 * 60);
/// Value passed to `RetryConfig::with_max_attempts` when [`Builder::max_attempts`] is not set.
const DEFAULT_ATTEMPTS: u32 = 4;
/// Connect timeout used when [`Builder::connect_timeout`] is not set (one second).
const DEFAULT_CONNECT_TIMEOUT: Duration = Duration::from_millis(1_000);
/// Read timeout used when [`Builder::read_timeout`] is not set (one second).
const DEFAULT_READ_TIMEOUT: Duration = Duration::from_millis(1_000);
55
56fn user_agent() -> AwsUserAgent {
57 AwsUserAgent::new_from_environment(Env::real(), ApiMetadata::new("imds", PKG_VERSION))
58}
59
/// Client for the EC2 Instance Metadata Service (IMDS).
///
/// Construct one with [`Client::builder`]. The client wraps a single pre-configured
/// `get` operation; see [`Client::get`].
#[derive(Clone, Debug)]
pub struct Client {
    // The "get" operation: the input is the request path, the output is the response
    // body wrapped in `SensitiveString`, and modeled errors are `InnerImdsError`.
    operation: Operation<String, SensitiveString, InnerImdsError>,
}
132
impl Client {
    /// Returns a [`Builder`] with every option unset; defaults are applied by [`Builder::build`].
    pub fn builder() -> Builder {
        Builder::default()
    }

    /// Fetches `path` (for example `"/latest/metadata"`) from IMDS and returns the response body.
    ///
    /// `path` becomes the request URI and is combined with the resolved IMDS endpoint. The body
    /// is returned as a [`SensitiveString`], so it is redacted from `Debug` output.
    ///
    /// # Errors
    ///
    /// - Construction failures whose source is an [`ImdsError`] return that error unchanged.
    ///   Any other construction failure becomes `ImdsError::unexpected`.
    /// - A non-success status becomes `ImdsError::error_response` with the raw response.
    /// - A success response whose body is not valid UTF-8 becomes `ImdsError::unexpected`.
    /// - A dispatch failure whose `ConnectorError` wraps an [`ImdsError`] returns that error.
    ///   Any other dispatch failure becomes `ImdsError::unexpected`.
    /// - Timeout and response errors become `ImdsError::io_error`.
    pub async fn get(&self, path: impl Into<String>) -> Result<SensitiveString, ImdsError> {
        self.operation
            .invoke(path.into())
            .await
            .map_err(|err| match err {
                // The binding name suggests these are token-fetch failures, presumably raised by
                // `TokenRuntimePlugin` as `ImdsError`s; they are returned without re-wrapping.
                SdkError::ConstructionFailure(_) if err.source().is_some() => {
                    match err.into_source().map(|e| e.downcast::<ImdsError>()) {
                        Ok(Ok(token_failure)) => *token_failure,
                        Ok(Err(err)) => ImdsError::unexpected(err),
                        Err(err) => ImdsError::unexpected(err),
                    }
                }
                SdkError::ConstructionFailure(_) => ImdsError::unexpected(err),
                // Modeled errors produced by the deserializer installed in `Builder::build`.
                SdkError::ServiceError(context) => match context.err() {
                    InnerImdsError::InvalidUtf8 => {
                        ImdsError::unexpected("IMDS returned invalid UTF-8")
                    }
                    InnerImdsError::BadStatus => ImdsError::error_response(context.into_raw()),
                },
                // Unwrap the `ConnectorError`. If it carries an `ImdsError`, surface that
                // directly; everything else is wrapped as unexpected.
                err @ SdkError::DispatchFailure(_) => match err.into_source() {
                    Ok(source) => match source.downcast::<ConnectorError>() {
                        Ok(source) => match source.into_source().downcast::<ImdsError>() {
                            Ok(source) => *source,
                            Err(err) => ImdsError::unexpected(err),
                        },
                        Err(err) => ImdsError::unexpected(err),
                    },
                    Err(err) => ImdsError::unexpected(err),
                },
                SdkError::TimeoutError(_) | SdkError::ResponseError(_) => ImdsError::io_error(err),
                // Any remaining `SdkError` variant is reported as unexpected.
                _ => ImdsError::unexpected(err),
            })
    }
}
196
/// A `String` wrapper whose `Debug` output never shows the wrapped text.
///
/// IMDS response bodies are returned in this type so that logging them via `Debug`
/// does not leak their contents. Read the value through [`AsRef<str>`], or convert it
/// back into a `String` with `From`/`Into`.
#[derive(Clone)]
pub struct SensitiveString(String);

impl fmt::Debug for SensitiveString {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // Print a fixed placeholder in place of the real contents.
        let placeholder = "** redacted **";
        let mut tuple = f.debug_tuple("SensitiveString");
        tuple.field(&placeholder);
        tuple.finish()
    }
}

impl AsRef<str> for SensitiveString {
    fn as_ref(&self) -> &str {
        self.0.as_str()
    }
}

impl From<String> for SensitiveString {
    fn from(value: String) -> Self {
        SensitiveString(value)
    }
}

impl From<SensitiveString> for String {
    fn from(value: SensitiveString) -> Self {
        let SensitiveString(inner) = value;
        inner
    }
}
226
227#[derive(Debug)]
231struct ImdsCommonRuntimePlugin {
232 config: FrozenLayer,
233 components: RuntimeComponentsBuilder,
234}
235
236impl ImdsCommonRuntimePlugin {
237 fn new(
238 config: &ProviderConfig,
239 endpoint_resolver: ImdsEndpointResolver,
240 retry_config: RetryConfig,
241 timeout_config: TimeoutConfig,
242 ) -> Self {
243 let mut layer = Layer::new("ImdsCommonRuntimePlugin");
244 layer.store_put(AuthSchemeOptionResolverParams::new(()));
245 layer.store_put(EndpointResolverParams::new(()));
246 layer.store_put(SensitiveOutput);
247 layer.store_put(retry_config);
248 layer.store_put(timeout_config);
249 layer.store_put(user_agent());
250
251 Self {
252 config: layer.freeze(),
253 components: RuntimeComponentsBuilder::new("ImdsCommonRuntimePlugin")
254 .with_http_client(config.http_client())
255 .with_endpoint_resolver(Some(endpoint_resolver))
256 .with_interceptor(UserAgentInterceptor::new())
257 .with_retry_classifier(SharedRetryClassifier::new(ImdsResponseRetryClassifier))
258 .with_retry_strategy(Some(StandardRetryStrategy::new()))
259 .with_time_source(Some(config.time_source()))
260 .with_sleep_impl(config.sleep_impl()),
261 }
262 }
263}
264
265impl RuntimePlugin for ImdsCommonRuntimePlugin {
266 fn config(&self) -> Option<FrozenLayer> {
267 Some(self.config.clone())
268 }
269
270 fn runtime_components(
271 &self,
272 _current_components: &RuntimeComponentsBuilder,
273 ) -> Cow<'_, RuntimeComponentsBuilder> {
274 Cow::Borrowed(&self.components)
275 }
276}
277
/// Selects which well-known IMDS address is used when no endpoint URI is configured.
///
/// Parsed case-insensitively from `"ipv4"` / `"ipv6"` via [`FromStr`].
#[derive(Debug, Clone)]
#[non_exhaustive]
pub enum EndpointMode {
    /// IPv4 endpoint: `http://169.254.169.254`.
    IpV4,
    /// IPv6 endpoint: `http://[fd00:ec2::254]`.
    IpV6,
}
293
294impl FromStr for EndpointMode {
295 type Err = InvalidEndpointMode;
296
297 fn from_str(value: &str) -> Result<Self, Self::Err> {
298 match value {
299 _ if value.eq_ignore_ascii_case("ipv4") => Ok(EndpointMode::IpV4),
300 _ if value.eq_ignore_ascii_case("ipv6") => Ok(EndpointMode::IpV6),
301 other => Err(InvalidEndpointMode::new(other.to_owned())),
302 }
303 }
304}
305
306impl EndpointMode {
307 fn endpoint(&self) -> Uri {
309 match self {
310 EndpointMode::IpV4 => Uri::from_static("http://169.254.169.254"),
311 EndpointMode::IpV6 => Uri::from_static("http://[fd00:ec2::254]"),
312 }
313 }
314}
315
/// Builder for [`Client`].
///
/// Every field starts as `None`; [`Builder::build`] substitutes defaults for unset values.
#[derive(Default, Debug, Clone)]
pub struct Builder {
    // Max attempts for the standard retry config (default: `DEFAULT_ATTEMPTS`).
    max_attempts: Option<u32>,
    // Explicit endpoint. When unset, `build` resolves the endpoint from the provider
    // config's environment and profile.
    endpoint: Option<EndpointSource>,
    // Endpoint mode used only when no endpoint URI is configured anywhere.
    mode_override: Option<EndpointMode>,
    // Session token TTL (default: `DEFAULT_TOKEN_TTL`).
    token_ttl: Option<Duration>,
    // Connect timeout (default: `DEFAULT_CONNECT_TIMEOUT`).
    connect_timeout: Option<Duration>,
    // Read timeout (default: `DEFAULT_READ_TIMEOUT`).
    read_timeout: Option<Duration>,
    // Provider config (default: `ProviderConfig::default()`).
    config: Option<ProviderConfig>,
}
327
328impl Builder {
329 pub fn max_attempts(mut self, max_attempts: u32) -> Self {
333 self.max_attempts = Some(max_attempts);
334 self
335 }
336
337 pub fn configure(mut self, provider_config: &ProviderConfig) -> Self {
351 self.config = Some(provider_config.clone());
352 self
353 }
354
355 pub fn endpoint(mut self, endpoint: impl AsRef<str>) -> Result<Self, BoxError> {
361 let uri: Uri = endpoint.as_ref().parse()?;
362 self.endpoint = Some(EndpointSource::Explicit(uri));
363 Ok(self)
364 }
365
366 pub fn endpoint_mode(mut self, mode: EndpointMode) -> Self {
371 self.mode_override = Some(mode);
372 self
373 }
374
375 pub fn token_ttl(mut self, ttl: Duration) -> Self {
381 self.token_ttl = Some(ttl);
382 self
383 }
384
385 pub fn connect_timeout(mut self, timeout: Duration) -> Self {
389 self.connect_timeout = Some(timeout);
390 self
391 }
392
393 pub fn read_timeout(mut self, timeout: Duration) -> Self {
397 self.read_timeout = Some(timeout);
398 self
399 }
400
401 pub fn build(self) -> Client {
410 let config = self.config.unwrap_or_default();
411 let timeout_config = TimeoutConfig::builder()
412 .connect_timeout(self.connect_timeout.unwrap_or(DEFAULT_CONNECT_TIMEOUT))
413 .read_timeout(self.read_timeout.unwrap_or(DEFAULT_READ_TIMEOUT))
414 .build();
415 let endpoint_source = self
416 .endpoint
417 .unwrap_or_else(|| EndpointSource::Env(config.clone()));
418 let endpoint_resolver = ImdsEndpointResolver {
419 endpoint_source: Arc::new(endpoint_source),
420 mode_override: self.mode_override,
421 };
422 let retry_config = RetryConfig::standard()
423 .with_max_attempts(self.max_attempts.unwrap_or(DEFAULT_ATTEMPTS));
424 let common_plugin = SharedRuntimePlugin::new(ImdsCommonRuntimePlugin::new(
425 &config,
426 endpoint_resolver,
427 retry_config,
428 timeout_config,
429 ));
430 let operation = Operation::builder()
431 .service_name("imds")
432 .operation_name("get")
433 .runtime_plugin(common_plugin.clone())
434 .runtime_plugin(TokenRuntimePlugin::new(
435 common_plugin,
436 self.token_ttl.unwrap_or(DEFAULT_TOKEN_TTL),
437 ))
438 .with_connection_poisoning()
439 .serializer(|path| {
440 Ok(HttpRequest::try_from(
441 http::Request::builder()
442 .uri(path)
443 .body(SdkBody::empty())
444 .expect("valid request"),
445 )
446 .unwrap())
447 })
448 .deserializer(|response| {
449 if response.status().is_success() {
450 std::str::from_utf8(response.body().bytes().expect("non-streaming response"))
451 .map(|data| SensitiveString::from(data.to_string()))
452 .map_err(|_| OrchestratorError::operation(InnerImdsError::InvalidUtf8))
453 } else {
454 Err(OrchestratorError::operation(InnerImdsError::BadStatus))
455 }
456 })
457 .build();
458 Client { operation }
459 }
460}
461
/// Environment variable names consulted when resolving the IMDS endpoint.
///
/// They are read only when no explicit endpoint was given to the builder.
mod env {
    /// Endpoint URI. Takes precedence over the profile endpoint and over any endpoint mode.
    pub(super) const ENDPOINT: &str = "AWS_EC2_METADATA_SERVICE_ENDPOINT";
    /// Endpoint mode (`ipv4` / `ipv6`).
    ///
    /// Used unless the builder sets a mode override; takes precedence over the profile setting.
    pub(super) const ENDPOINT_MODE: &str = "AWS_EC2_METADATA_SERVICE_ENDPOINT_MODE";
}
466
/// Profile keys consulted when resolving the IMDS endpoint.
///
/// Each key is read only when the corresponding environment variable is not set.
mod profile_keys {
    /// Endpoint URI.
    pub(super) const ENDPOINT: &str = "ec2_metadata_service_endpoint";
    /// Endpoint mode (`ipv4` / `ipv6`).
    pub(super) const ENDPOINT_MODE: &str = "ec2_metadata_service_endpoint_mode";
}
471
/// Where the IMDS endpoint URI comes from.
#[derive(Debug, Clone)]
enum EndpointSource {
    /// A URI given explicitly to [`Builder::endpoint`].
    Explicit(Uri),
    /// Resolved at request time from the environment and profile of this provider config.
    Env(ProviderConfig),
}
478
479impl EndpointSource {
480 async fn endpoint(&self, mode_override: Option<EndpointMode>) -> Result<Uri, BuildError> {
481 match self {
482 EndpointSource::Explicit(uri) => {
483 if mode_override.is_some() {
484 tracing::warn!(endpoint = ?uri, mode = ?mode_override,
485 "Endpoint mode override was set in combination with an explicit endpoint. \
486 The mode override will be ignored.")
487 }
488 Ok(uri.clone())
489 }
490 EndpointSource::Env(conf) => {
491 let env = conf.env();
492 let profile = conf.profile().await;
494 let uri_override = if let Ok(uri) = env.get(env::ENDPOINT) {
495 Some(Cow::Owned(uri))
496 } else {
497 profile
498 .and_then(|profile| profile.get(profile_keys::ENDPOINT))
499 .map(Cow::Borrowed)
500 };
501 if let Some(uri) = uri_override {
502 return Uri::try_from(uri.as_ref()).map_err(BuildError::invalid_endpoint_uri);
503 }
504
505 let mode = if let Some(mode) = mode_override {
507 mode
508 } else if let Ok(mode) = env.get(env::ENDPOINT_MODE) {
509 mode.parse::<EndpointMode>()
510 .map_err(BuildError::invalid_endpoint_mode)?
511 } else if let Some(mode) = profile.and_then(|p| p.get(profile_keys::ENDPOINT_MODE))
512 {
513 mode.parse::<EndpointMode>()
514 .map_err(BuildError::invalid_endpoint_mode)?
515 } else {
516 EndpointMode::IpV4
517 };
518
519 Ok(mode.endpoint())
520 }
521 }
522 }
523}
524
/// Endpoint resolver for IMDS requests; delegates to [`EndpointSource::endpoint`].
#[derive(Clone, Debug)]
struct ImdsEndpointResolver {
    // Wrapped in `Arc`, so clones of the resolver share one `EndpointSource`.
    endpoint_source: Arc<EndpointSource>,
    // Mode set through `Builder::endpoint_mode`, if any.
    mode_override: Option<EndpointMode>,
}
530
531impl ResolveEndpoint for ImdsEndpointResolver {
532 fn resolve_endpoint<'a>(&'a self, _: &'a EndpointResolverParams) -> EndpointFuture<'a> {
533 EndpointFuture::new(async move {
534 self.endpoint_source
535 .endpoint(self.mode_override.clone())
536 .await
537 .map(|uri| Endpoint::builder().url(uri.to_string()).build())
538 .map_err(|err| err.into())
539 })
540 }
541}
542
/// Retry classifier for IMDS responses: 5xx and 401 statuses are classified as retryable
/// server errors; everything else gets no retry action.
#[derive(Clone, Debug)]
struct ImdsResponseRetryClassifier;
554
555impl ClassifyRetry for ImdsResponseRetryClassifier {
556 fn name(&self) -> &'static str {
557 "ImdsResponseRetryClassifier"
558 }
559
560 fn classify_retry(&self, ctx: &InterceptorContext) -> RetryAction {
561 if let Some(response) = ctx.response() {
562 let status = response.status();
563 match status {
564 _ if status.is_server_error() => RetryAction::server_error(),
565 _ if status.as_u16() == 401 => RetryAction::server_error(),
567 _ => RetryAction::NoActionIndicated,
569 }
570 } else {
571 RetryAction::NoActionIndicated
575 }
576 }
577}
578
// Tests for the IMDS client.
//
// Most tests drive the client against a `StaticReplayClient` preloaded with the exact
// requests and responses expected, in order, then call `assert_requests_match(&[])` to
// check that the requests actually sent match them. The `pub(crate)` helpers are
// presumably shared with other test modules in the crate.
#[cfg(test)]
pub(crate) mod test {
    use crate::imds::client::{Client, EndpointMode, ImdsResponseRetryClassifier};
    use crate::provider_config::ProviderConfig;
    use aws_smithy_async::rt::sleep::TokioSleep;
    use aws_smithy_async::test_util::{instant_time_and_sleep, InstantSleep};
    use aws_smithy_runtime::client::http::test_util::{
        capture_request, ReplayEvent, StaticReplayClient,
    };
    use aws_smithy_runtime::test_util::capture_test_logs::capture_test_logs;
    use aws_smithy_runtime_api::client::interceptors::context::{
        Input, InterceptorContext, Output,
    };
    use aws_smithy_runtime_api::client::orchestrator::{
        HttpRequest, HttpResponse, OrchestratorError,
    };
    use aws_smithy_runtime_api::client::result::ConnectorError;
    use aws_smithy_runtime_api::client::retries::classifiers::{ClassifyRetry, RetryAction};
    use aws_smithy_types::body::SdkBody;
    use aws_smithy_types::error::display::DisplayErrorContext;
    use aws_types::os_shim_internal::{Env, Fs};
    use http::header::USER_AGENT;
    use http::Uri;
    use serde::Deserialize;
    use std::collections::HashMap;
    use std::error::Error;
    use std::io;
    use std::time::{Duration, UNIX_EPOCH};
    use tracing_test::traced_test;

    // Asserts that the full error chain of `$err`, rendered with `DisplayErrorContext`,
    // contains the text `$contains`.
    macro_rules! assert_full_error_contains {
        ($err:expr, $contains:expr) => {
            let err = $err;
            let message = format!(
                "{}",
                aws_smithy_types::error::display::DisplayErrorContext(&err)
            );
            assert!(
                message.contains($contains),
                "Error message '{message}' didn't contain text '{}'",
                $contains
            );
        };
    }

    // Two distinct token values, so tests can tell which token a request carried.
    const TOKEN_A: &str = "AQAEAFTNrA4eEGx0AQgJ1arIq_Cc-t4tWt3fB0Hd8RKhXlKc5ccvhg==";
    const TOKEN_B: &str = "alternatetoken==";

    /// Expected token request: `PUT {base}/latest/api/token` with the requested TTL header.
    pub(crate) fn token_request(base: &str, ttl: u32) -> HttpRequest {
        http::Request::builder()
            .uri(format!("{}/latest/api/token", base))
            .header("x-aws-ec2-metadata-token-ttl-seconds", ttl)
            .method("PUT")
            .body(SdkBody::empty())
            .unwrap()
            .try_into()
            .unwrap()
    }

    /// 200 token response with `token` as the body and `ttl` in the TTL header.
    pub(crate) fn token_response(ttl: u32, token: &'static str) -> HttpResponse {
        HttpResponse::try_from(
            http::Response::builder()
                .status(200)
                .header("X-aws-ec2-metadata-token-ttl-seconds", ttl)
                .body(SdkBody::from(token))
                .unwrap(),
        )
        .unwrap()
    }

    /// Expected metadata request: `GET {path}` carrying `token` in the token header.
    pub(crate) fn imds_request(path: &'static str, token: &str) -> HttpRequest {
        http::Request::builder()
            .uri(Uri::from_static(path))
            .method("GET")
            .header("x-aws-ec2-metadata-token", token)
            .body(SdkBody::empty())
            .unwrap()
            .try_into()
            .unwrap()
    }

    /// 200 response whose body is `body`.
    pub(crate) fn imds_response(body: &'static str) -> HttpResponse {
        HttpResponse::try_from(
            http::Response::builder()
                .status(200)
                .body(SdkBody::from(body))
                .unwrap(),
        )
        .unwrap()
    }

    /// Builds a client over `http_client` with no environment/profile configuration and an
    /// unlogged `InstantSleep`. Also pauses tokio's clock.
    pub(crate) fn make_imds_client(http_client: &StaticReplayClient) -> super::Client {
        tokio::time::pause();
        super::Client::builder()
            .configure(
                &ProviderConfig::no_configuration()
                    .with_sleep_impl(InstantSleep::unlogged())
                    .with_http_client(http_client.clone()),
            )
            .build()
    }

    /// Creates a replay client from `events` and an IMDS client that uses it.
    fn mock_imds_client(events: Vec<ReplayEvent>) -> (Client, StaticReplayClient) {
        let http_client = StaticReplayClient::new(events);
        let client = make_imds_client(&http_client);
        (client, http_client)
    }

    // One token fetch serves two consecutive metadata requests.
    #[tokio::test]
    async fn client_caches_token() {
        let (client, http_client) = mock_imds_client(vec![
            ReplayEvent::new(
                token_request("http://169.254.169.254", 21600),
                token_response(21600, TOKEN_A),
            ),
            ReplayEvent::new(
                imds_request("http://169.254.169.254/latest/metadata", TOKEN_A),
                imds_response(r#"test-imds-output"#),
            ),
            ReplayEvent::new(
                imds_request("http://169.254.169.254/latest/metadata2", TOKEN_A),
                imds_response("output2"),
            ),
        ]);
        let metadata = client.get("/latest/metadata").await.expect("failed");
        assert_eq!("test-imds-output", metadata.as_ref());
        let metadata = client.get("/latest/metadata2").await.expect("failed");
        assert_eq!("output2", metadata.as_ref());
        http_client.assert_requests_match(&[]);
    }

    // After advancing time by the full 600s TTL, a new token is fetched (IPv6 endpoint).
    // The client returned by `mock_imds_client` is discarded in favor of one using a
    // controllable time source.
    #[tokio::test]
    async fn token_can_expire() {
        let (_, http_client) = mock_imds_client(vec![
            ReplayEvent::new(
                token_request("http://[fd00:ec2::254]", 600),
                token_response(600, TOKEN_A),
            ),
            ReplayEvent::new(
                imds_request("http://[fd00:ec2::254]/latest/metadata", TOKEN_A),
                imds_response(r#"test-imds-output1"#),
            ),
            ReplayEvent::new(
                token_request("http://[fd00:ec2::254]", 600),
                token_response(600, TOKEN_B),
            ),
            ReplayEvent::new(
                imds_request("http://[fd00:ec2::254]/latest/metadata", TOKEN_B),
                imds_response(r#"test-imds-output2"#),
            ),
        ]);
        let (time_source, sleep) = instant_time_and_sleep(UNIX_EPOCH);
        let client = super::Client::builder()
            .configure(
                &ProviderConfig::no_configuration()
                    .with_http_client(http_client.clone())
                    .with_time_source(time_source.clone())
                    .with_sleep_impl(sleep),
            )
            .endpoint_mode(EndpointMode::IpV6)
            .token_ttl(Duration::from_secs(600))
            .build();

        let resp1 = client.get("/latest/metadata").await.expect("success");
        time_source.advance(Duration::from_secs(600));
        let resp2 = client.get("/latest/metadata").await.expect("success");
        http_client.assert_requests_match(&[]);
        assert_eq!("test-imds-output1", resp1.as_ref());
        assert_eq!("test-imds-output2", resp2.as_ref());
    }

    // With a 600s TTL, the token is still reused at t=400s but replaced at t=550s, i.e.
    // before its nominal expiry.
    #[tokio::test]
    async fn token_refresh_buffer() {
        let _logs = capture_test_logs();
        let (_, http_client) = mock_imds_client(vec![
            ReplayEvent::new(
                token_request("http://[fd00:ec2::254]", 600),
                token_response(600, TOKEN_A),
            ),
            ReplayEvent::new(
                imds_request("http://[fd00:ec2::254]/latest/metadata", TOKEN_A),
                imds_response(r#"test-imds-output1"#),
            ),
            ReplayEvent::new(
                imds_request("http://[fd00:ec2::254]/latest/metadata", TOKEN_A),
                imds_response(r#"test-imds-output2"#),
            ),
            ReplayEvent::new(
                token_request("http://[fd00:ec2::254]", 600),
                token_response(600, TOKEN_B),
            ),
            ReplayEvent::new(
                imds_request("http://[fd00:ec2::254]/latest/metadata", TOKEN_B),
                imds_response(r#"test-imds-output3"#),
            ),
        ]);
        let (time_source, sleep) = instant_time_and_sleep(UNIX_EPOCH);
        let client = super::Client::builder()
            .configure(
                &ProviderConfig::no_configuration()
                    .with_sleep_impl(sleep)
                    .with_http_client(http_client.clone())
                    .with_time_source(time_source.clone()),
            )
            .endpoint_mode(EndpointMode::IpV6)
            .token_ttl(Duration::from_secs(600))
            .build();

        tracing::info!("resp1 -----------------------------------------------------------");
        let resp1 = client.get("/latest/metadata").await.expect("success");
        time_source.advance(Duration::from_secs(400));
        tracing::info!("resp2 -----------------------------------------------------------");
        let resp2 = client.get("/latest/metadata").await.expect("success");
        time_source.advance(Duration::from_secs(150));
        tracing::info!("resp3 -----------------------------------------------------------");
        let resp3 = client.get("/latest/metadata").await.expect("success");
        http_client.assert_requests_match(&[]);
        assert_eq!("test-imds-output1", resp1.as_ref());
        assert_eq!("test-imds-output2", resp2.as_ref());
        assert_eq!("test-imds-output3", resp3.as_ref());
    }

    // A 500 on a metadata request is retried with the same token, and every request
    // carries a User-Agent header.
    #[tokio::test]
    #[traced_test]
    async fn retry_500() {
        let (client, http_client) = mock_imds_client(vec![
            ReplayEvent::new(
                token_request("http://169.254.169.254", 21600),
                token_response(21600, TOKEN_A),
            ),
            ReplayEvent::new(
                imds_request("http://169.254.169.254/latest/metadata", TOKEN_A),
                http::Response::builder()
                    .status(500)
                    .body(SdkBody::empty())
                    .unwrap(),
            ),
            ReplayEvent::new(
                imds_request("http://169.254.169.254/latest/metadata", TOKEN_A),
                imds_response("ok"),
            ),
        ]);
        assert_eq!(
            "ok",
            client
                .get("/latest/metadata")
                .await
                .expect("success")
                .as_ref()
        );
        http_client.assert_requests_match(&[]);

        for request in http_client.actual_requests() {
            assert!(request.headers().get(USER_AGENT).is_some());
        }
    }

    // A 500 on the token request is retried.
    #[tokio::test]
    #[traced_test]
    async fn retry_token_failure() {
        let (client, http_client) = mock_imds_client(vec![
            ReplayEvent::new(
                token_request("http://169.254.169.254", 21600),
                http::Response::builder()
                    .status(500)
                    .body(SdkBody::empty())
                    .unwrap(),
            ),
            ReplayEvent::new(
                token_request("http://169.254.169.254", 21600),
                token_response(21600, TOKEN_A),
            ),
            ReplayEvent::new(
                imds_request("http://169.254.169.254/latest/metadata", TOKEN_A),
                imds_response("ok"),
            ),
        ]);
        assert_eq!(
            "ok",
            client
                .get("/latest/metadata")
                .await
                .expect("success")
                .as_ref()
        );
        http_client.assert_requests_match(&[]);
    }

    // A 401 on a metadata request leads to a new token fetch, then a retry with the new
    // token. The first token response advertises a TTL of 0.
    #[tokio::test]
    #[traced_test]
    async fn retry_metadata_401() {
        let (client, http_client) = mock_imds_client(vec![
            ReplayEvent::new(
                token_request("http://169.254.169.254", 21600),
                token_response(0, TOKEN_A),
            ),
            ReplayEvent::new(
                imds_request("http://169.254.169.254/latest/metadata", TOKEN_A),
                http::Response::builder()
                    .status(401)
                    .body(SdkBody::empty())
                    .unwrap(),
            ),
            ReplayEvent::new(
                token_request("http://169.254.169.254", 21600),
                token_response(21600, TOKEN_B),
            ),
            ReplayEvent::new(
                imds_request("http://169.254.169.254/latest/metadata", TOKEN_B),
                imds_response("ok"),
            ),
        ]);
        assert_eq!(
            "ok",
            client
                .get("/latest/metadata")
                .await
                .expect("success")
                .as_ref()
        );
        http_client.assert_requests_match(&[]);
    }

    // A 403 on the token request is not retried, and the error mentions "forbidden".
    #[tokio::test]
    #[traced_test]
    async fn no_403_retry() {
        let (client, http_client) = mock_imds_client(vec![ReplayEvent::new(
            token_request("http://169.254.169.254", 21600),
            http::Response::builder()
                .status(403)
                .body(SdkBody::empty())
                .unwrap(),
        )]);
        let err = client.get("/latest/metadata").await.expect_err("no token");
        assert_full_error_contains!(err, "forbidden");
        http_client.assert_requests_match(&[]);
    }

    // The classifier indicates no action for a 200 response, or when there is only a
    // connector error and no response.
    #[test]
    fn successful_response_properly_classified() {
        let mut ctx = InterceptorContext::new(Input::doesnt_matter());
        ctx.set_output_or_error(Ok(Output::doesnt_matter()));
        ctx.set_response(imds_response("").map(|_| SdkBody::empty()));
        let classifier = ImdsResponseRetryClassifier;
        assert_eq!(
            RetryAction::NoActionIndicated,
            classifier.classify_retry(&ctx)
        );

        let mut ctx = InterceptorContext::new(Input::doesnt_matter());
        ctx.set_output_or_error(Err(OrchestratorError::connector(ConnectorError::io(
            io::Error::new(io::ErrorKind::BrokenPipe, "fail to parse").into(),
        ))));
        assert_eq!(
            RetryAction::NoActionIndicated,
            classifier.classify_retry(&ctx)
        );
    }

    // A token body that is not a valid header value yields an "invalid token" error.
    #[tokio::test]
    async fn invalid_token() {
        let (client, http_client) = mock_imds_client(vec![ReplayEvent::new(
            token_request("http://169.254.169.254", 21600),
            token_response(21600, "invalid\nheader\nvalue\0"),
        )]);
        let err = client.get("/latest/metadata").await.expect_err("no token");
        assert_full_error_contains!(err, "invalid token");
        http_client.assert_requests_match(&[]);
    }

    // A 200 metadata response whose body is not UTF-8 yields an "invalid UTF-8" error.
    #[tokio::test]
    async fn non_utf8_response() {
        let (client, http_client) = mock_imds_client(vec![
            ReplayEvent::new(
                token_request("http://169.254.169.254", 21600),
                token_response(21600, TOKEN_A).map(SdkBody::from),
            ),
            ReplayEvent::new(
                imds_request("http://169.254.169.254/latest/metadata", TOKEN_A),
                http::Response::builder()
                    .status(200)
                    .body(SdkBody::from(vec![0xA0, 0xA1]))
                    .unwrap(),
            ),
        ]);
        let err = client.get("/latest/metadata").await.expect_err("no token");
        assert_full_error_contains!(err, "invalid UTF-8");
        http_client.assert_requests_match(&[]);
    }

    // Uses a real HTTP client against 240.0.0.0. The token fetch is expected to fail with
    // a timeout after roughly the 1s default connect timeout (checked as between 1s and 2s).
    #[cfg_attr(windows, ignore)]
    #[tokio::test]
    #[cfg(feature = "rustls")]
    async fn one_second_connect_timeout() {
        use crate::imds::client::ImdsError;
        use aws_smithy_types::error::display::DisplayErrorContext;
        use std::time::SystemTime;

        let client = Client::builder()
            .endpoint("http://240.0.0.0")
            .expect("valid uri")
            .build();
        let now = SystemTime::now();
        let resp = client
            .get("/latest/metadata")
            .await
            .expect_err("240.0.0.0 will never resolve");
        match resp {
            err @ ImdsError::FailedToLoadToken(_)
                if format!("{}", DisplayErrorContext(&err)).contains("timeout") => {}
            other => panic!(
                "wrong error, expected construction failure with TimedOutError inside: {}",
                DisplayErrorContext(&other)
            ),
        }
        let time_elapsed = now.elapsed().unwrap();
        assert!(
            time_elapsed > Duration::from_secs(1),
            "time_elapsed should be greater than 1s but was {:?}",
            time_elapsed
        );
        assert!(
            time_elapsed < Duration::from_secs(2),
            "time_elapsed should be less than 2s but was {:?}",
            time_elapsed
        );
    }

    /// One endpoint-resolution case from `test-data/imds-config/imds-endpoint-tests.json`.
    #[derive(Debug, Deserialize)]
    struct ImdsConfigTest {
        env: HashMap<String, String>,
        fs: HashMap<String, String>,
        endpoint_override: Option<String>,
        mode_override: Option<String>,
        // Ok: expected request URI; Err: text expected in the error message.
        result: Result<String, String>,
        docs: String,
    }

    // Runs every case in the endpoint-resolution test file through `check`.
    #[tokio::test]
    async fn endpoint_config_tests() -> Result<(), Box<dyn Error>> {
        let _logs = capture_test_logs();

        let test_cases = std::fs::read_to_string("test-data/imds-config/imds-endpoint-tests.json")?;
        #[derive(Deserialize)]
        struct TestCases {
            tests: Vec<ImdsConfigTest>,
        }

        let test_cases: TestCases = serde_json::from_str(&test_cases)?;
        let test_cases = test_cases.tests;
        for test in test_cases {
            check(test).await;
        }
        Ok(())
    }

    /// Builds a client from one test case's env/fs/overrides, then either compares the URI
    /// of the captured request (`Ok`) or checks that the error message contains the expected
    /// text (`Err`).
    async fn check(test_case: ImdsConfigTest) {
        let (http_client, watcher) = capture_request(None);
        let provider_config = ProviderConfig::no_configuration()
            .with_sleep_impl(TokioSleep::new())
            .with_env(Env::from(test_case.env))
            .with_fs(Fs::from_map(test_case.fs))
            .with_http_client(http_client);
        let mut imds_client = Client::builder().configure(&provider_config);
        if let Some(endpoint_override) = test_case.endpoint_override {
            imds_client = imds_client
                .endpoint(endpoint_override)
                .expect("invalid URI");
        }

        if let Some(mode_override) = test_case.mode_override {
            imds_client = imds_client.endpoint_mode(mode_override.parse().unwrap());
        }

        let imds_client = imds_client.build();
        match &test_case.result {
            Ok(uri) => {
                // The call's result is irrelevant; only the first captured request's URI is checked.
                let _ = imds_client.get("/hello").await;
                assert_eq!(&watcher.expect_request().uri().to_string(), uri);
            }
            Err(expected) => {
                let err = imds_client.get("/hello").await.expect_err("it should fail");
                let message = format!("{}", DisplayErrorContext(&err));
                assert!(
                    message.contains(expected),
                    "{}\nexpected error: {expected}\nactual error: {message}",
                    test_case.docs
                );
            }
        };
    }
}