use crate::imds::client::error::{BuildError, ImdsError, InnerImdsError, InvalidEndpointMode};
use crate::imds::client::token::TokenRuntimePlugin;
use crate::provider_config::ProviderConfig;
use crate::PKG_VERSION;
use aws_runtime::user_agent::{ApiMetadata, AwsUserAgent, UserAgentInterceptor};
use aws_smithy_runtime::client::orchestrator::operation::Operation;
use aws_smithy_runtime::client::retries::strategy::StandardRetryStrategy;
use aws_smithy_runtime_api::box_error::BoxError;
use aws_smithy_runtime_api::client::auth::AuthSchemeOptionResolverParams;
use aws_smithy_runtime_api::client::endpoint::{
EndpointFuture, EndpointResolverParams, ResolveEndpoint,
};
use aws_smithy_runtime_api::client::interceptors::context::InterceptorContext;
use aws_smithy_runtime_api::client::orchestrator::{
HttpRequest, OrchestratorError, SensitiveOutput,
};
use aws_smithy_runtime_api::client::result::ConnectorError;
use aws_smithy_runtime_api::client::result::SdkError;
use aws_smithy_runtime_api::client::retries::classifiers::{
ClassifyRetry, RetryAction, SharedRetryClassifier,
};
use aws_smithy_runtime_api::client::runtime_components::RuntimeComponentsBuilder;
use aws_smithy_runtime_api::client::runtime_plugin::{RuntimePlugin, SharedRuntimePlugin};
use aws_smithy_types::body::SdkBody;
use aws_smithy_types::config_bag::{FrozenLayer, Layer};
use aws_smithy_types::endpoint::Endpoint;
use aws_smithy_types::retry::RetryConfig;
use aws_smithy_types::timeout::TimeoutConfig;
use aws_types::os_shim_internal::Env;
use http::Uri;
use std::borrow::Cow;
use std::error::Error as _;
use std::fmt;
use std::str::FromStr;
use std::sync::Arc;
use std::time::Duration;
// Error types used by this client: `BuildError`, `ImdsError`, `InnerImdsError`
// and `InvalidEndpointMode` are imported from here at the top of the file.
pub mod error;
// Home of `TokenRuntimePlugin`, which `Builder::build` registers on the operation.
mod token;
/// Token lifetime used when `Builder::token_ttl` is not called: six hours.
const DEFAULT_TOKEN_TTL: Duration = Duration::from_secs(6 * 60 * 60);
/// Attempt limit given to the retry configuration when `Builder::max_attempts` is not called.
const DEFAULT_ATTEMPTS: u32 = 4;
/// Connect timeout used when `Builder::connect_timeout` is not called: one second.
const DEFAULT_CONNECT_TIMEOUT: Duration = Duration::from_millis(1_000);
/// Read timeout used when `Builder::read_timeout` is not called: one second.
const DEFAULT_READ_TIMEOUT: Duration = Duration::from_millis(1_000);
/// Builds the user agent stored in the plugin's config layer, identifying the
/// API as `"imds"` at this crate's `PKG_VERSION`.
fn user_agent() -> AwsUserAgent {
    let api_metadata = ApiMetadata::new("imds", PKG_VERSION);
    AwsUserAgent::new_from_environment(Env::real(), api_metadata)
}
/// Client for the instance metadata service (`"imds"`), wrapping a single `get` operation.
///
/// Instances are created through [`Client::builder`].
#[derive(Clone, Debug)]
pub struct Client {
// Takes the request path as a `String` and yields the response body as a `SensitiveString`.
operation: Operation<String, SensitiveString, InnerImdsError>,
}
impl Client {
/// Returns a [`Builder`] with no settings applied.
pub fn builder() -> Builder {
Builder::default()
}
/// Invokes the operation with `path` and returns the response body on success.
///
/// # Errors
///
/// Every `SdkError` produced by the operation is converted into an [`ImdsError`]:
/// an `ImdsError` found as the source of a construction or dispatch failure is
/// returned as-is, a non-success status becomes an error-response error, timeouts
/// and response errors become I/O errors, and everything else is reported as
/// unexpected.
pub async fn get(&self, path: impl Into<String>) -> Result<SensitiveString, ImdsError> {
self.operation
.invoke(path.into())
.await
.map_err(|err| match err {
// A construction failure that carries a source: if that source is itself an
// `ImdsError`, surface it directly as the top-level error.
SdkError::ConstructionFailure(_) if err.source().is_some() => {
match err.into_source().map(|e| e.downcast::<ImdsError>()) {
Ok(Ok(token_failure)) => *token_failure,
Ok(Err(err)) => ImdsError::unexpected(err),
Err(err) => ImdsError::unexpected(err),
}
}
// A construction failure without a source.
SdkError::ConstructionFailure(_) => ImdsError::unexpected(err),
// Errors raised by the deserializer configured in `Builder::build`.
SdkError::ServiceError(context) => match context.err() {
InnerImdsError::InvalidUtf8 => {
ImdsError::unexpected("IMDS returned invalid UTF-8")
}
InnerImdsError::BadStatus => ImdsError::error_response(context.into_raw()),
},
// Dispatch failure: unwrap the `ConnectorError` and, if its source is an
// `ImdsError`, return that source directly.
err @ SdkError::DispatchFailure(_) => match err.into_source() {
Ok(source) => match source.downcast::<ConnectorError>() {
Ok(source) => match source.into_source().downcast::<ImdsError>() {
Ok(source) => *source,
Err(err) => ImdsError::unexpected(err),
},
Err(err) => ImdsError::unexpected(err),
},
Err(err) => ImdsError::unexpected(err),
},
SdkError::TimeoutError(_) | SdkError::ResponseError(_) => ImdsError::io_error(err),
// Any other `SdkError` variant.
_ => ImdsError::unexpected(err),
})
}
}
/// String wrapper whose `Debug` output never reveals the wrapped text.
///
/// The contents remain reachable through `AsRef<str>` and `Into<String>`.
#[derive(Clone)]
pub struct SensitiveString(String);

impl fmt::Debug for SensitiveString {
    /// Prints a fixed placeholder in place of the wrapped value.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let mut redacted = f.debug_tuple("SensitiveString");
        redacted.field(&"** redacted **");
        redacted.finish()
    }
}

impl AsRef<str> for SensitiveString {
    /// Borrows the wrapped text.
    fn as_ref(&self) -> &str {
        self.0.as_str()
    }
}

impl From<String> for SensitiveString {
    /// Wraps `value` without copying it.
    fn from(value: String) -> Self {
        SensitiveString(value)
    }
}

impl From<SensitiveString> for String {
    /// Unwraps the inner `String`.
    fn from(value: SensitiveString) -> Self {
        let SensitiveString(inner) = value;
        inner
    }
}
/// Runtime plugin holding the config layer and runtime components that
/// `Builder::build` shares between the operation and the token plugin.
#[derive(Debug)]
struct ImdsCommonRuntimePlugin {
    /// Frozen layer handed out by [`RuntimePlugin::config`].
    config: FrozenLayer,
    /// Components handed out by [`RuntimePlugin::runtime_components`].
    components: RuntimeComponentsBuilder,
}

impl ImdsCommonRuntimePlugin {
    /// Assembles the config layer and the runtime components.
    ///
    /// The HTTP client, time source and sleep implementation are taken from
    /// `config`; the endpoint resolver, retry configuration and timeout
    /// configuration are supplied by the caller.
    fn new(
        config: &ProviderConfig,
        endpoint_resolver: ImdsEndpointResolver,
        retry_config: RetryConfig,
        timeout_config: TimeoutConfig,
    ) -> Self {
        const PLUGIN_NAME: &str = "ImdsCommonRuntimePlugin";

        let mut layer = Layer::new(PLUGIN_NAME);
        layer.store_put(AuthSchemeOptionResolverParams::new(()));
        layer.store_put(EndpointResolverParams::new(()));
        layer.store_put(SensitiveOutput);
        layer.store_put(retry_config);
        layer.store_put(timeout_config);
        layer.store_put(user_agent());
        let frozen = layer.freeze();

        let retry_classifier = SharedRetryClassifier::new(ImdsResponseRetryClassifier);
        let components = RuntimeComponentsBuilder::new(PLUGIN_NAME)
            .with_http_client(config.http_client())
            .with_endpoint_resolver(Some(endpoint_resolver))
            .with_interceptor(UserAgentInterceptor::new())
            .with_retry_classifier(retry_classifier)
            .with_retry_strategy(Some(StandardRetryStrategy::new()))
            .with_time_source(Some(config.time_source()))
            .with_sleep_impl(config.sleep_impl());

        Self {
            config: frozen,
            components,
        }
    }
}

impl RuntimePlugin for ImdsCommonRuntimePlugin {
    /// Returns a clone of the frozen config layer.
    fn config(&self) -> Option<FrozenLayer> {
        let layer = self.config.clone();
        Some(layer)
    }

    /// Lends out the pre-built components; the current components are ignored.
    fn runtime_components(
        &self,
        _current_components: &RuntimeComponentsBuilder,
    ) -> Cow<'_, RuntimeComponentsBuilder> {
        let components = &self.components;
        Cow::Borrowed(components)
    }
}
/// IP family that selects which default IMDS endpoint is used.
#[derive(Debug, Clone)]
#[non_exhaustive]
pub enum EndpointMode {
    /// Use the IPv4 endpoint, `http://169.254.169.254`.
    IpV4,
    /// Use the IPv6 endpoint, `http://[fd00:ec2::254]`.
    IpV6,
}

impl FromStr for EndpointMode {
    type Err = InvalidEndpointMode;

    /// Parses `"ipv4"` or `"ipv6"`, ignoring ASCII case.
    ///
    /// Any other input yields an [`InvalidEndpointMode`] carrying the
    /// rejected text.
    fn from_str(value: &str) -> Result<Self, Self::Err> {
        if value.eq_ignore_ascii_case("ipv4") {
            Ok(EndpointMode::IpV4)
        } else if value.eq_ignore_ascii_case("ipv6") {
            Ok(EndpointMode::IpV6)
        } else {
            Err(InvalidEndpointMode::new(value.to_owned()))
        }
    }
}

impl EndpointMode {
    /// Returns the default endpoint URI for this mode.
    fn endpoint(&self) -> Uri {
        let address = match self {
            EndpointMode::IpV4 => "http://169.254.169.254",
            EndpointMode::IpV6 => "http://[fd00:ec2::254]",
        };
        Uri::from_static(address)
    }
}
#[derive(Default, Debug, Clone)]
pub struct Builder {
max_attempts: Option<u32>,
endpoint: Option<EndpointSource>,
mode_override: Option<EndpointMode>,
token_ttl: Option<Duration>,
connect_timeout: Option<Duration>,
read_timeout: Option<Duration>,
config: Option<ProviderConfig>,
}
impl Builder {
pub fn max_attempts(mut self, max_attempts: u32) -> Self {
self.max_attempts = Some(max_attempts);
self
}
pub fn configure(mut self, provider_config: &ProviderConfig) -> Self {
self.config = Some(provider_config.clone());
self
}
pub fn endpoint(mut self, endpoint: impl AsRef<str>) -> Result<Self, BoxError> {
let uri: Uri = endpoint.as_ref().parse()?;
self.endpoint = Some(EndpointSource::Explicit(uri));
Ok(self)
}
pub fn endpoint_mode(mut self, mode: EndpointMode) -> Self {
self.mode_override = Some(mode);
self
}
pub fn token_ttl(mut self, ttl: Duration) -> Self {
self.token_ttl = Some(ttl);
self
}
pub fn connect_timeout(mut self, timeout: Duration) -> Self {
self.connect_timeout = Some(timeout);
self
}
pub fn read_timeout(mut self, timeout: Duration) -> Self {
self.read_timeout = Some(timeout);
self
}
pub fn build(self) -> Client {
let config = self.config.unwrap_or_default();
let timeout_config = TimeoutConfig::builder()
.connect_timeout(self.connect_timeout.unwrap_or(DEFAULT_CONNECT_TIMEOUT))
.read_timeout(self.read_timeout.unwrap_or(DEFAULT_READ_TIMEOUT))
.build();
let endpoint_source = self
.endpoint
.unwrap_or_else(|| EndpointSource::Env(config.clone()));
let endpoint_resolver = ImdsEndpointResolver {
endpoint_source: Arc::new(endpoint_source),
mode_override: self.mode_override,
};
let retry_config = RetryConfig::standard()
.with_max_attempts(self.max_attempts.unwrap_or(DEFAULT_ATTEMPTS));
let common_plugin = SharedRuntimePlugin::new(ImdsCommonRuntimePlugin::new(
&config,
endpoint_resolver,
retry_config,
timeout_config,
));
let operation = Operation::builder()
.service_name("imds")
.operation_name("get")
.runtime_plugin(common_plugin.clone())
.runtime_plugin(TokenRuntimePlugin::new(
common_plugin,
self.token_ttl.unwrap_or(DEFAULT_TOKEN_TTL),
))
.with_connection_poisoning()
.serializer(|path| {
Ok(HttpRequest::try_from(
http::Request::builder()
.uri(path)
.body(SdkBody::empty())
.expect("valid request"),
)
.unwrap())
})
.deserializer(|response| {
if response.status().is_success() {
std::str::from_utf8(response.body().bytes().expect("non-streaming response"))
.map(|data| SensitiveString::from(data.to_string()))
.map_err(|_| OrchestratorError::operation(InnerImdsError::InvalidUtf8))
} else {
Err(OrchestratorError::operation(InnerImdsError::BadStatus))
}
})
.build();
Client { operation }
}
}
// Environment variable names read by `EndpointSource::endpoint` when no explicit
// endpoint was configured.
mod env {
// Endpoint URI override.
pub(super) const ENDPOINT: &str = "AWS_EC2_METADATA_SERVICE_ENDPOINT";
// Endpoint mode; the value is parsed as an `EndpointMode`.
pub(super) const ENDPOINT_MODE: &str = "AWS_EC2_METADATA_SERVICE_ENDPOINT_MODE";
}
// Profile keys read by `EndpointSource::endpoint` when the matching environment
// variable is not set.
mod profile_keys {
// Endpoint URI override.
pub(super) const ENDPOINT: &str = "ec2_metadata_service_endpoint";
// Endpoint mode; the value is parsed as an `EndpointMode`.
pub(super) const ENDPOINT_MODE: &str = "ec2_metadata_service_endpoint_mode";
}
/// Where the endpoint URI comes from.
#[derive(Debug, Clone)]
enum EndpointSource {
    /// A URI supplied through `Builder::endpoint`.
    Explicit(Uri),
    /// Resolve from the environment and profile reachable through this configuration.
    Env(ProviderConfig),
}

impl EndpointSource {
    /// Resolves the endpoint URI.
    ///
    /// An explicit URI is returned unchanged (a mode override is ignored with
    /// a warning). Otherwise the lookup order is: endpoint URI from the
    /// environment, endpoint URI from the profile, then the default URI of the
    /// endpoint mode taken from `mode_override`, the environment, the profile,
    /// or finally `EndpointMode::IpV4`.
    ///
    /// # Errors
    ///
    /// Returns a [`BuildError`] when a configured URI or mode fails to parse.
    async fn endpoint(&self, mode_override: Option<EndpointMode>) -> Result<Uri, BuildError> {
        let conf = match self {
            EndpointSource::Explicit(uri) => {
                if mode_override.is_some() {
                    tracing::warn!(endpoint = ?uri, mode = ?mode_override,
                        "Endpoint mode override was set in combination with an explicit endpoint. \
                        The mode override will be ignored.");
                }
                return Ok(uri.clone());
            }
            EndpointSource::Env(conf) => conf,
        };

        let env = conf.env();
        let profile = conf.profile().await;

        // A full endpoint URI from the environment or profile wins over any mode.
        let uri_override = match env.get(env::ENDPOINT) {
            Ok(from_env) => Some(Cow::Owned(from_env)),
            Err(_) => profile
                .and_then(|profile| profile.get(profile_keys::ENDPOINT))
                .map(Cow::Borrowed),
        };
        if let Some(raw_uri) = uri_override {
            return Uri::try_from(raw_uri.as_ref()).map_err(BuildError::invalid_endpoint_uri);
        }

        let parse_mode = |raw: &str| {
            raw.parse::<EndpointMode>()
                .map_err(BuildError::invalid_endpoint_mode)
        };
        let mode = match mode_override {
            Some(explicit) => explicit,
            None => match env.get(env::ENDPOINT_MODE) {
                Ok(from_env) => parse_mode(&from_env)?,
                Err(_) => match profile.and_then(|p| p.get(profile_keys::ENDPOINT_MODE)) {
                    Some(from_profile) => parse_mode(from_profile)?,
                    None => EndpointMode::IpV4,
                },
            },
        };
        Ok(mode.endpoint())
    }
}
/// Endpoint resolver backed by an [`EndpointSource`] and an optional mode override.
#[derive(Clone, Debug)]
struct ImdsEndpointResolver {
    /// Source consulted on every resolution.
    endpoint_source: Arc<EndpointSource>,
    /// Mode override forwarded to the source.
    mode_override: Option<EndpointMode>,
}

impl ResolveEndpoint for ImdsEndpointResolver {
    /// Resolves the URI through the endpoint source and wraps it in an
    /// [`Endpoint`]; the resolver params are ignored.
    fn resolve_endpoint<'a>(&'a self, _: &'a EndpointResolverParams) -> EndpointFuture<'a> {
        EndpointFuture::new(async move {
            let mode = self.mode_override.clone();
            let resolved = self.endpoint_source.endpoint(mode).await;
            match resolved {
                Ok(uri) => Ok(Endpoint::builder().url(uri.to_string()).build()),
                Err(build_error) => Err(build_error.into()),
            }
        })
    }
}
/// Retry classifier that looks only at the HTTP status of a received response.
#[derive(Clone, Debug)]
struct ImdsResponseRetryClassifier;

impl ClassifyRetry for ImdsResponseRetryClassifier {
    fn name(&self) -> &'static str {
        "ImdsResponseRetryClassifier"
    }

    /// Requests a retry for 5xx and 401 responses.
    ///
    /// Every other status, and a context without any response, yields
    /// `NoActionIndicated`.
    fn classify_retry(&self, ctx: &InterceptorContext) -> RetryAction {
        match ctx.response().map(|response| response.status()) {
            Some(status) if status.is_server_error() || status.as_u16() == 401 => {
                RetryAction::server_error()
            }
            _ => RetryAction::NoActionIndicated,
        }
    }
}
#[cfg(test)]
pub(crate) mod test {
use crate::imds::client::{Client, EndpointMode, ImdsResponseRetryClassifier};
use crate::provider_config::ProviderConfig;
use aws_smithy_async::rt::sleep::TokioSleep;
use aws_smithy_async::test_util::{instant_time_and_sleep, InstantSleep};
use aws_smithy_runtime::client::http::test_util::{
capture_request, ReplayEvent, StaticReplayClient,
};
use aws_smithy_runtime::test_util::capture_test_logs::capture_test_logs;
use aws_smithy_runtime_api::client::interceptors::context::{
Input, InterceptorContext, Output,
};
use aws_smithy_runtime_api::client::orchestrator::{
HttpRequest, HttpResponse, OrchestratorError,
};
use aws_smithy_runtime_api::client::result::ConnectorError;
use aws_smithy_runtime_api::client::retries::classifiers::{ClassifyRetry, RetryAction};
use aws_smithy_types::body::SdkBody;
use aws_smithy_types::error::display::DisplayErrorContext;
use aws_types::os_shim_internal::{Env, Fs};
use http::header::USER_AGENT;
use http::Uri;
use serde::Deserialize;
use std::collections::HashMap;
use std::error::Error;
use std::io;
use std::time::{Duration, UNIX_EPOCH};
use tracing_test::traced_test;
// Asserts that the error `$err`, rendered through `DisplayErrorContext`,
// contains the text `$contains`. On failure the rendered message is included in
// the panic output.
macro_rules! assert_full_error_contains {
($err:expr, $contains:expr) => {
let err = $err;
let message = format!(
"{}",
aws_smithy_types::error::display::DisplayErrorContext(&err)
);
assert!(
message.contains($contains),
"Error message '{message}' didn't contain text '{}'",
$contains
);
};
}
// Two distinct token values, so tests can tell which token a metadata request carried.
const TOKEN_A: &str = "AQAEAFTNrA4eEGx0AQgJ1arIq_Cc-t4tWt3fB0Hd8RKhXlKc5ccvhg==";
const TOKEN_B: &str = "alternatetoken==";
pub(crate) fn token_request(base: &str, ttl: u32) -> HttpRequest {
http::Request::builder()
.uri(format!("{}/latest/api/token", base))
.header("x-aws-ec2-metadata-token-ttl-seconds", ttl)
.method("PUT")
.body(SdkBody::empty())
.unwrap()
.try_into()
.unwrap()
}
/// Canned 200 token response whose body is `token` and whose TTL header is `ttl`.
pub(crate) fn token_response(ttl: u32, token: &'static str) -> HttpResponse {
    let response = http::Response::builder()
        .status(200)
        .header("X-aws-ec2-metadata-token-ttl-seconds", ttl)
        .body(SdkBody::from(token))
        .unwrap();
    HttpResponse::try_from(response).unwrap()
}
pub(crate) fn imds_request(path: &'static str, token: &str) -> HttpRequest {
http::Request::builder()
.uri(Uri::from_static(path))
.method("GET")
.header("x-aws-ec2-metadata-token", token)
.body(SdkBody::empty())
.unwrap()
.try_into()
.unwrap()
}
/// Canned 200 metadata response with `body` as its payload.
pub(crate) fn imds_response(body: &'static str) -> HttpResponse {
    let response = http::Response::builder()
        .status(200)
        .body(SdkBody::from(body))
        .unwrap();
    HttpResponse::try_from(response).unwrap()
}
/// Builds a client wired to `http_client`, using an instant sleep
/// implementation and no other configuration. Tokio time is paused first.
pub(crate) fn make_imds_client(http_client: &StaticReplayClient) -> super::Client {
    tokio::time::pause();
    let provider_config = ProviderConfig::no_configuration()
        .with_sleep_impl(InstantSleep::unlogged())
        .with_http_client(http_client.clone());
    super::Client::builder().configure(&provider_config).build()
}
/// Creates a replay HTTP client for `events` together with a client wired to it.
fn mock_imds_client(events: Vec<ReplayEvent>) -> (Client, StaticReplayClient) {
    let replay_client = StaticReplayClient::new(events);
    let imds_client = make_imds_client(&replay_client);
    (imds_client, replay_client)
}
// Two metadata requests are served with a single token exchange: the replay
// list contains only one token request, so the second `get` must reuse TOKEN_A.
#[tokio::test]
async fn client_caches_token() {
let (client, http_client) = mock_imds_client(vec![
ReplayEvent::new(
token_request("http://169.254.169.254", 21600),
token_response(21600, TOKEN_A),
),
ReplayEvent::new(
imds_request("http://169.254.169.254/latest/metadata", TOKEN_A),
imds_response(r#"test-imds-output"#),
),
ReplayEvent::new(
imds_request("http://169.254.169.254/latest/metadata2", TOKEN_A),
imds_response("output2"),
),
]);
let metadata = client.get("/latest/metadata").await.expect("failed");
assert_eq!("test-imds-output", metadata.as_ref());
let metadata = client.get("/latest/metadata2").await.expect("failed");
assert_eq!("output2", metadata.as_ref());
http_client.assert_requests_match(&[]);
}
// With a 600s token TTL, advancing the time source by 600s between two calls
// must trigger a second token request (TOKEN_B) before the second metadata call.
#[tokio::test]
async fn token_can_expire() {
// Only the replay HTTP client is kept; the test builds its own client below
// with a controllable time source.
let (_, http_client) = mock_imds_client(vec![
ReplayEvent::new(
token_request("http://[fd00:ec2::254]", 600),
token_response(600, TOKEN_A),
),
ReplayEvent::new(
imds_request("http://[fd00:ec2::254]/latest/metadata", TOKEN_A),
imds_response(r#"test-imds-output1"#),
),
ReplayEvent::new(
token_request("http://[fd00:ec2::254]", 600),
token_response(600, TOKEN_B),
),
ReplayEvent::new(
imds_request("http://[fd00:ec2::254]/latest/metadata", TOKEN_B),
imds_response(r#"test-imds-output2"#),
),
]);
let (time_source, sleep) = instant_time_and_sleep(UNIX_EPOCH);
let client = super::Client::builder()
.configure(
&ProviderConfig::no_configuration()
.with_http_client(http_client.clone())
.with_time_source(time_source.clone())
.with_sleep_impl(sleep),
)
.endpoint_mode(EndpointMode::IpV6)
.token_ttl(Duration::from_secs(600))
.build();
let resp1 = client.get("/latest/metadata").await.expect("success");
time_source.advance(Duration::from_secs(600));
let resp2 = client.get("/latest/metadata").await.expect("success");
http_client.assert_requests_match(&[]);
assert_eq!("test-imds-output1", resp1.as_ref());
assert_eq!("test-imds-output2", resp2.as_ref());
}
// With a 600s token TTL: at 400s the cached TOKEN_A is still used, but at 550s
// (before the TTL has fully elapsed) the replay expects a fresh token request.
#[tokio::test]
async fn token_refresh_buffer() {
let _logs = capture_test_logs();
// Only the replay HTTP client is kept; the test builds its own client below
// with a controllable time source.
let (_, http_client) = mock_imds_client(vec![
ReplayEvent::new(
token_request("http://[fd00:ec2::254]", 600),
token_response(600, TOKEN_A),
),
ReplayEvent::new(
imds_request("http://[fd00:ec2::254]/latest/metadata", TOKEN_A),
imds_response(r#"test-imds-output1"#),
),
ReplayEvent::new(
imds_request("http://[fd00:ec2::254]/latest/metadata", TOKEN_A),
imds_response(r#"test-imds-output2"#),
),
ReplayEvent::new(
token_request("http://[fd00:ec2::254]", 600),
token_response(600, TOKEN_B),
),
ReplayEvent::new(
imds_request("http://[fd00:ec2::254]/latest/metadata", TOKEN_B),
imds_response(r#"test-imds-output3"#),
),
]);
let (time_source, sleep) = instant_time_and_sleep(UNIX_EPOCH);
let client = super::Client::builder()
.configure(
&ProviderConfig::no_configuration()
.with_sleep_impl(sleep)
.with_http_client(http_client.clone())
.with_time_source(time_source.clone()),
)
.endpoint_mode(EndpointMode::IpV6)
.token_ttl(Duration::from_secs(600))
.build();
tracing::info!("resp1 -----------------------------------------------------------");
let resp1 = client.get("/latest/metadata").await.expect("success");
// 400s elapsed: the second call is expected to reuse TOKEN_A.
time_source.advance(Duration::from_secs(400));
tracing::info!("resp2 -----------------------------------------------------------");
let resp2 = client.get("/latest/metadata").await.expect("success");
// 550s elapsed in total: the third call is expected to fetch TOKEN_B first.
time_source.advance(Duration::from_secs(150));
tracing::info!("resp3 -----------------------------------------------------------");
let resp3 = client.get("/latest/metadata").await.expect("success");
http_client.assert_requests_match(&[]);
assert_eq!("test-imds-output1", resp1.as_ref());
assert_eq!("test-imds-output2", resp2.as_ref());
assert_eq!("test-imds-output3", resp3.as_ref());
}
// A 500 on the metadata request is retried with the same token, and every
// request that was sent carries a User-Agent header.
#[tokio::test]
#[traced_test]
async fn retry_500() {
let (client, http_client) = mock_imds_client(vec![
ReplayEvent::new(
token_request("http://169.254.169.254", 21600),
token_response(21600, TOKEN_A),
),
ReplayEvent::new(
imds_request("http://169.254.169.254/latest/metadata", TOKEN_A),
http::Response::builder()
.status(500)
.body(SdkBody::empty())
.unwrap(),
),
ReplayEvent::new(
imds_request("http://169.254.169.254/latest/metadata", TOKEN_A),
imds_response("ok"),
),
]);
assert_eq!(
"ok",
client
.get("/latest/metadata")
.await
.expect("success")
.as_ref()
);
http_client.assert_requests_match(&[]);
for request in http_client.actual_requests() {
assert!(request.headers().get(USER_AGENT).is_some());
}
}
// A 500 on the token request is retried; the metadata request then succeeds
// with the token from the second attempt.
#[tokio::test]
#[traced_test]
async fn retry_token_failure() {
let (client, http_client) = mock_imds_client(vec![
ReplayEvent::new(
token_request("http://169.254.169.254", 21600),
http::Response::builder()
.status(500)
.body(SdkBody::empty())
.unwrap(),
),
ReplayEvent::new(
token_request("http://169.254.169.254", 21600),
token_response(21600, TOKEN_A),
),
ReplayEvent::new(
imds_request("http://169.254.169.254/latest/metadata", TOKEN_A),
imds_response("ok"),
),
]);
assert_eq!(
"ok",
client
.get("/latest/metadata")
.await
.expect("success")
.as_ref()
);
http_client.assert_requests_match(&[]);
}
// A 401 on the metadata request leads to a new token request and a retry of
// the metadata request with the new token (TOKEN_B).
#[tokio::test]
#[traced_test]
async fn retry_metadata_401() {
let (client, http_client) = mock_imds_client(vec![
ReplayEvent::new(
token_request("http://169.254.169.254", 21600),
// The first token is returned with a TTL header of 0.
token_response(0, TOKEN_A),
),
ReplayEvent::new(
imds_request("http://169.254.169.254/latest/metadata", TOKEN_A),
http::Response::builder()
.status(401)
.body(SdkBody::empty())
.unwrap(),
),
ReplayEvent::new(
token_request("http://169.254.169.254", 21600),
token_response(21600, TOKEN_B),
),
ReplayEvent::new(
imds_request("http://169.254.169.254/latest/metadata", TOKEN_B),
imds_response("ok"),
),
]);
assert_eq!(
"ok",
client
.get("/latest/metadata")
.await
.expect("success")
.as_ref()
);
http_client.assert_requests_match(&[]);
}
// A 403 on the token request is not retried: the replay list holds a single
// event and the call must fail with an error mentioning "forbidden".
#[tokio::test]
#[traced_test]
async fn no_403_retry() {
let (client, http_client) = mock_imds_client(vec![ReplayEvent::new(
token_request("http://169.254.169.254", 21600),
http::Response::builder()
.status(403)
.body(SdkBody::empty())
.unwrap(),
)]);
let err = client.get("/latest/metadata").await.expect_err("no token");
assert_full_error_contains!(err, "forbidden");
http_client.assert_requests_match(&[]);
}
// The classifier indicates no retry action both for a successful (200)
// response and for a connector error that has no response attached.
#[test]
fn successful_response_properly_classified() {
// Case 1: successful output with a 200 response.
let mut ctx = InterceptorContext::new(Input::doesnt_matter());
ctx.set_output_or_error(Ok(Output::doesnt_matter()));
ctx.set_response(imds_response("").map(|_| SdkBody::empty()));
let classifier = ImdsResponseRetryClassifier;
assert_eq!(
RetryAction::NoActionIndicated,
classifier.classify_retry(&ctx)
);
// Case 2: connector I/O error; no response is set on the context.
let mut ctx = InterceptorContext::new(Input::doesnt_matter());
ctx.set_output_or_error(Err(OrchestratorError::connector(ConnectorError::io(
io::Error::new(io::ErrorKind::BrokenPipe, "fail to parse").into(),
))));
assert_eq!(
RetryAction::NoActionIndicated,
classifier.classify_retry(&ctx)
);
}
// A token body containing newlines and a NUL byte must be rejected with an
// error mentioning "invalid token"; no metadata request is expected.
#[tokio::test]
async fn invalid_token() {
let (client, http_client) = mock_imds_client(vec![ReplayEvent::new(
token_request("http://169.254.169.254", 21600),
token_response(21600, "invalid\nheader\nvalue\0"),
)]);
let err = client.get("/latest/metadata").await.expect_err("no token");
assert_full_error_contains!(err, "invalid token");
http_client.assert_requests_match(&[]);
}
// A 200 metadata response whose body is not valid UTF-8 must fail with an
// error mentioning "invalid UTF-8".
#[tokio::test]
async fn non_utf8_response() {
let (client, http_client) = mock_imds_client(vec![
ReplayEvent::new(
token_request("http://169.254.169.254", 21600),
token_response(21600, TOKEN_A).map(SdkBody::from),
),
ReplayEvent::new(
imds_request("http://169.254.169.254/latest/metadata", TOKEN_A),
http::Response::builder()
.status(200)
// 0xA0 0xA1 is not a valid UTF-8 sequence.
.body(SdkBody::from(vec![0xA0, 0xA1]))
.unwrap(),
),
]);
let err = client.get("/latest/metadata").await.expect_err("no token");
assert_full_error_contains!(err, "invalid UTF-8");
http_client.assert_requests_match(&[]);
}
// Connecting to 240.0.0.0 is expected to fail with a token-loading error that
// mentions a timeout, after more than one second but less than two.
//
// Elapsed time is measured with the monotonic `Instant`, so a wall-clock
// adjustment during the test can neither skew the measurement nor make
// `elapsed` fail.
#[cfg_attr(windows, ignore)]
#[tokio::test]
#[cfg(feature = "rustls")]
async fn one_second_connect_timeout() {
    use crate::imds::client::ImdsError;
    use aws_smithy_types::error::display::DisplayErrorContext;
    use std::time::Instant;

    let client = Client::builder()
        .endpoint("http://240.0.0.0")
        .expect("valid uri")
        .build();
    let started = Instant::now();
    let resp = client
        .get("/latest/metadata")
        .await
        .expect_err("240.0.0.0 will never resolve");
    match resp {
        err @ ImdsError::FailedToLoadToken(_)
            if format!("{}", DisplayErrorContext(&err)).contains("timeout") => {}
        other => panic!(
            "wrong error, expected construction failure with TimedOutError inside: {}",
            DisplayErrorContext(&other)
        ),
    }
    let time_elapsed = started.elapsed();
    assert!(
        time_elapsed > Duration::from_secs(1),
        "time_elapsed should be greater than 1s but was {:?}",
        time_elapsed
    );
    assert!(
        time_elapsed < Duration::from_secs(2),
        "time_elapsed should be less than 2s but was {:?}",
        time_elapsed
    );
}
// One endpoint-configuration test case, deserialized from the JSON test file
// and executed by `check`.
#[derive(Debug, Deserialize)]
struct ImdsConfigTest {
// Environment variables for the case (turned into an `Env`).
env: HashMap<String, String>,
// In-memory file system contents (passed to `Fs::from_map`).
fs: HashMap<String, String>,
// Explicit endpoint passed to `Builder::endpoint`, if any.
endpoint_override: Option<String>,
// Endpoint mode string parsed and passed to `Builder::endpoint_mode`, if any.
mode_override: Option<String>,
// `Ok`: URI expected on the captured request. `Err`: text the error must contain.
result: Result<String, String>,
// Description of the case, printed when an error-message assertion fails.
docs: String,
}
// Loads every case from `test-data/imds-config/imds-endpoint-tests.json` and
// runs each of them through `check`, in file order.
#[tokio::test]
async fn endpoint_config_tests() -> Result<(), Box<dyn Error>> {
    #[derive(Deserialize)]
    struct TestCases {
        tests: Vec<ImdsConfigTest>,
    }

    let _logs = capture_test_logs();
    let raw_json = std::fs::read_to_string("test-data/imds-config/imds-endpoint-tests.json")?;
    let parsed: TestCases = serde_json::from_str(&raw_json)?;
    for case in parsed.tests {
        check(case).await;
    }
    Ok(())
}
async fn check(test_case: ImdsConfigTest) {
let (http_client, watcher) = capture_request(None);
let provider_config = ProviderConfig::no_configuration()
.with_sleep_impl(TokioSleep::new())
.with_env(Env::from(test_case.env))
.with_fs(Fs::from_map(test_case.fs))
.with_http_client(http_client);
let mut imds_client = Client::builder().configure(&provider_config);
if let Some(endpoint_override) = test_case.endpoint_override {
imds_client = imds_client
.endpoint(endpoint_override)
.expect("invalid URI");
}
if let Some(mode_override) = test_case.mode_override {
imds_client = imds_client.endpoint_mode(mode_override.parse().unwrap());
}
let imds_client = imds_client.build();
match &test_case.result {
Ok(uri) => {
let _ = imds_client.get("/hello").await;
assert_eq!(&watcher.expect_request().uri().to_string(), uri);
}
Err(expected) => {
let err = imds_client.get("/hello").await.expect_err("it should fail");
let message = format!("{}", DisplayErrorContext(&err));
assert!(
message.contains(expected),
"{}\nexpected error: {expected}\nactual error: {message}",
test_case.docs
);
}
};
}
}