// aws_config/imds/client/token.rs
use crate::identity::IdentityCache;
use crate::imds::client::error::{ImdsError, TokenError, TokenErrorKind};
use aws_smithy_async::time::SharedTimeSource;
use aws_smithy_runtime::client::orchestrator::operation::Operation;
use aws_smithy_runtime::expiring_cache::ExpiringCache;
use aws_smithy_runtime_api::box_error::BoxError;
use aws_smithy_runtime_api::client::auth::static_resolver::StaticAuthSchemeOptionResolver;
use aws_smithy_runtime_api::client::auth::{
AuthScheme, AuthSchemeEndpointConfig, AuthSchemeId, Sign,
};
use aws_smithy_runtime_api::client::identity::{
Identity, IdentityFuture, ResolveIdentity, SharedIdentityResolver,
};
use aws_smithy_runtime_api::client::orchestrator::{HttpRequest, HttpResponse, OrchestratorError};
use aws_smithy_runtime_api::client::runtime_components::{
GetIdentityResolver, RuntimeComponents, RuntimeComponentsBuilder,
};
use aws_smithy_runtime_api::client::runtime_plugin::{RuntimePlugin, SharedRuntimePlugin};
use aws_smithy_types::body::SdkBody;
use aws_smithy_types::config_bag::ConfigBag;
use http::{HeaderValue, Uri};
use std::borrow::Cow;
use std::fmt;
use std::sync::Arc;
use std::time::{Duration, SystemTime};
/// Buffer duration handed to the token cache (`ExpiringCache::new`) and logged as
/// `buffer_time` when a token is served from the cache.
///
/// NOTE(review): presumably a token within this window of its expiry is treated as
/// expired so it is refreshed early — confirm against `ExpiringCache`.
const TOKEN_REFRESH_BUFFER: Duration = Duration::from_secs(120);
/// Header carrying the token TTL in seconds: set on the token request and read back
/// from the token response.
const X_AWS_EC2_METADATA_TOKEN_TTL_SECONDS: &str = "x-aws-ec2-metadata-token-ttl-seconds";
/// Header that the signer appends to outgoing requests with the token value; the same
/// string is used as the auth scheme ID.
const X_AWS_EC2_METADATA_TOKEN: &str = "x-aws-ec2-metadata-token";
/// Auth scheme ID under which the token auth scheme and its identity resolver are registered.
const IMDS_TOKEN_AUTH_SCHEME: AuthSchemeId = AuthSchemeId::new(X_AWS_EC2_METADATA_TOKEN);
/// Token as parsed from a token response: the token value together with its
/// time-to-live, before an absolute expiry has been computed.
#[derive(Debug)]
struct TtlToken {
/// Token value built from the response body; marked sensitive by `parse_token_response`.
value: HeaderValue,
/// Time-to-live taken from the TTL response header (whole seconds).
ttl: Duration,
}
/// Token with an absolute expiry time.
///
/// This is the value stored in the token cache and carried as the data of the resolved
/// `Identity`. `Debug` is implemented by hand (not derived) so the value can be redacted.
#[derive(Clone)]
struct Token {
// Header value appended to requests by the signer.
value: HeaderValue,
// Point in time at which the token expires (load time plus the returned TTL).
expiry: SystemTime,
}
/// Hand-written `Debug` so that the token value never shows up in debug output.
impl fmt::Debug for Token {
    /// Renders a `Token` struct whose `value` field is a fixed placeholder string;
    /// only `expiry` is printed for real.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let mut rendered = f.debug_struct("Token");
        // The placeholder stands in for the real header value.
        rendered.field("value", &"** redacted **");
        rendered.field("expiry", &self.expiry);
        rendered.finish()
    }
}
/// Runtime plugin that contributes the token auth scheme, its option resolver, a
/// non-caching identity cache, and the token identity resolver.
#[derive(Debug)]
pub(super) struct TokenRuntimePlugin {
// Pre-built components handed out (borrowed) from `runtime_components`.
components: RuntimeComponentsBuilder,
}
impl TokenRuntimePlugin {
    /// Creates the plugin and eagerly assembles every component it will provide.
    ///
    /// `common_plugin` and `token_ttl` are forwarded unchanged to `TokenResolver::new`.
    pub(super) fn new(common_plugin: SharedRuntimePlugin, token_ttl: Duration) -> Self {
        let builder = RuntimeComponentsBuilder::new("TokenRuntimePlugin");
        // Register the token auth scheme and make it the only selectable option.
        let builder = builder.with_auth_scheme(TokenAuthScheme::new());
        let builder = builder.with_auth_scheme_option_resolver(Some(
            StaticAuthSchemeOptionResolver::new(vec![IMDS_TOKEN_AUTH_SCHEME]),
        ));
        // The identity cache is set to `no_cache()`; the resolver keeps its own token cache.
        let builder = builder.with_identity_cache(Some(IdentityCache::no_cache()));
        let components = builder.with_identity_resolver(
            IMDS_TOKEN_AUTH_SCHEME,
            TokenResolver::new(common_plugin, token_ttl),
        );
        Self { components }
    }
}
impl RuntimePlugin for TokenRuntimePlugin {
    /// Returns a borrow of the components assembled in `TokenRuntimePlugin::new`.
    ///
    /// The components passed in by the caller are not consulted.
    fn runtime_components(
        &self,
        _current_components: &RuntimeComponentsBuilder,
    ) -> Cow<'_, RuntimeComponentsBuilder> {
        let own_components = &self.components;
        Cow::Borrowed(own_components)
    }
}
/// Shared state behind `TokenResolver`: the token cache and the operation that
/// fetches a fresh token.
#[derive(Debug)]
struct TokenResolverInner {
/// Cache holding the current token; created with `TOKEN_REFRESH_BUFFER`.
cache: ExpiringCache<Token, ImdsError>,
/// Operation that issues `PUT /latest/api/token` and parses the response into a `TtlToken`.
refresh: Operation<(), TtlToken, TokenError>,
}
/// Identity resolver that supplies the token, loading and caching it on demand.
///
/// Cloning copies only the `Arc`, so all clones share one cache and one refresh operation.
#[derive(Clone, Debug)]
struct TokenResolver {
// Shared cache + refresh operation.
inner: Arc<TokenResolverInner>,
}
impl TokenResolver {
    /// Builds a resolver with a fresh token cache and a prepared "get-token" operation.
    ///
    /// `common_plugin` is installed on the operation as a runtime plugin. `token_ttl` is
    /// sent, in whole seconds, in the `x-aws-ec2-metadata-token-ttl-seconds` request header.
    fn new(common_plugin: SharedRuntimePlugin, token_ttl: Duration) -> Self {
        let cache = ExpiringCache::new(TOKEN_REFRESH_BUFFER);
        let ttl_seconds = token_ttl.as_secs();
        let refresh = Operation::builder()
            .service_name("imds")
            .operation_name("get-token")
            .runtime_plugin(common_plugin)
            // The token request itself is sent without auth.
            .no_auth()
            .with_connection_poisoning()
            .serializer(move |_| {
                // Every token request is the same empty-bodied PUT carrying the requested TTL.
                let request = http::Request::builder()
                    .method("PUT")
                    .uri(Uri::from_static("/latest/api/token"))
                    .header(X_AWS_EC2_METADATA_TOKEN_TTL_SECONDS, ttl_seconds)
                    .body(SdkBody::empty())
                    .expect("valid HTTP request");
                Ok(request.try_into().unwrap())
            })
            .deserializer(move |response| {
                parse_token_response(response).map_err(OrchestratorError::operation)
            })
            .build();
        let inner = TokenResolverInner { cache, refresh };
        Self {
            inner: Arc::new(inner),
        }
    }

    /// Invokes the refresh operation and converts its result into a `Token`.
    ///
    /// On success returns the token together with its expiry, computed as the time read
    /// from `time_source` *after* the operation completed plus the returned TTL. On
    /// failure the operation error is wrapped with `ImdsError::failed_to_load_token`.
    ///
    /// NOTE(review): `SystemTime + Duration` panics when the sum is not representable,
    /// so an extremely large TTL in the response would panic here.
    async fn get_token(
        &self,
        time_source: SharedTimeSource,
    ) -> Result<(Token, SystemTime), ImdsError> {
        let outcome = self.inner.refresh.invoke(()).await;
        // The clock is read once the call has finished, regardless of its outcome.
        let now = time_source.now();
        match outcome {
            Ok(ttl_token) => {
                let expiry = now + ttl_token.ttl;
                let token = Token {
                    value: ttl_token.value,
                    expiry,
                };
                Ok((token, expiry))
            }
            Err(err) => Err(ImdsError::failed_to_load_token(err)),
        }
    }
}
/// Turns a token response into a `TtlToken`.
///
/// # Errors
/// - `InvalidParameters` for status 400 and `Forbidden` for status 403.
/// - `InvalidToken` when the body bytes are not a valid header value.
/// - `NoTtl` when the TTL header is absent, `InvalidTtl` when it does not parse as `u64`.
///
/// Any other status code falls through to body/header parsing.
///
/// # Panics
/// Panics with "non-streaming response" if the body bytes are not available.
fn parse_token_response(response: &HttpResponse) -> Result<TtlToken, TokenError> {
    let status = response.status().as_u16();
    if status == 400 {
        return Err(TokenErrorKind::InvalidParameters.into());
    }
    if status == 403 {
        return Err(TokenErrorKind::Forbidden.into());
    }

    // The token is the raw response body; it is later sent back as a header value.
    let body = response.body().bytes().expect("non-streaming response");
    let mut value = HeaderValue::from_bytes(body).map_err(|_| TokenErrorKind::InvalidToken)?;
    value.set_sensitive(true);

    let ttl_header = response
        .headers()
        .get(X_AWS_EC2_METADATA_TOKEN_TTL_SECONDS)
        .ok_or(TokenErrorKind::NoTtl)?;
    let ttl_seconds: u64 = ttl_header
        .parse()
        .map_err(|_parse_error| TokenErrorKind::InvalidTtl)?;

    Ok(TtlToken {
        value,
        ttl: Duration::from_secs(ttl_seconds),
    })
}
impl ResolveIdentity for TokenResolver {
    /// Resolves the token identity, serving it from the cache when possible.
    ///
    /// The cache is first asked for an unexpired token at the current time; if it has
    /// none, the token is loaded through `get_token` via the cache's `get_or_load`.
    /// The returned `Identity` carries the `Token` and its expiry.
    ///
    /// # Panics
    /// Panics if `components` does not provide a time source.
    fn resolve_identity<'a>(
        &'a self,
        components: &'a RuntimeComponents,
        _config_bag: &'a ConfigBag,
    ) -> IdentityFuture<'a> {
        let time_source = components
            .time_source()
            .expect("time source required for IMDS token caching");
        IdentityFuture::new(async {
            let now = time_source.now();
            let preloaded_token = self.inner.cache.yield_or_clear_if_expired(now).await;
            let token = if let Some(cached) = preloaded_token {
                tracing::trace!(
                    buffer_time=?TOKEN_REFRESH_BUFFER,
                    expiration=?cached.expiry,
                    now=?now,
                    "loaded IMDS token from cache");
                cached
            } else {
                tracing::debug!("IMDS token cache miss");
                self.inner
                    .cache
                    .get_or_load(|| async { self.get_token(time_source).await })
                    .await?
            };
            let expiry = token.expiry;
            Ok(Identity::new(token, Some(expiry)))
        })
    }
}
/// Auth scheme registered under `IMDS_TOKEN_AUTH_SCHEME`; pairs the token identity
/// resolver with `TokenSigner`.
#[derive(Debug)]
struct TokenAuthScheme {
// Signer returned from `AuthScheme::signer`.
signer: TokenSigner,
}
impl TokenAuthScheme {
    /// Creates the auth scheme with its (stateless) `TokenSigner`.
    fn new() -> Self {
        let signer = TokenSigner;
        Self { signer }
    }
}
impl AuthScheme for TokenAuthScheme {
    /// Identifies this scheme as `IMDS_TOKEN_AUTH_SCHEME`.
    fn scheme_id(&self) -> AuthSchemeId {
        IMDS_TOKEN_AUTH_SCHEME
    }

    /// Looks up the identity resolver registered for `IMDS_TOKEN_AUTH_SCHEME`, if any.
    fn identity_resolver(
        &self,
        identity_resolvers: &dyn GetIdentityResolver,
    ) -> Option<SharedIdentityResolver> {
        let scheme = IMDS_TOKEN_AUTH_SCHEME;
        identity_resolvers.identity_resolver(scheme)
    }

    /// Exposes the scheme's `TokenSigner`.
    fn signer(&self) -> &dyn Sign {
        let signer: &dyn Sign = &self.signer;
        signer
    }
}
/// Stateless signer that attaches the token to a request as the
/// `x-aws-ec2-metadata-token` header.
#[derive(Debug)]
struct TokenSigner;
impl Sign for TokenSigner {
    /// "Signs" the request by appending the token as the `x-aws-ec2-metadata-token` header.
    ///
    /// The endpoint config, runtime components and config bag are not used. Always
    /// returns `Ok(())`.
    ///
    /// # Panics
    /// Panics with "correct type" if the identity's data is not a `Token`.
    fn sign_http_request(
        &self,
        request: &mut HttpRequest,
        identity: &Identity,
        _auth_scheme_endpoint_config: AuthSchemeEndpointConfig<'_>,
        _runtime_components: &RuntimeComponents,
        _config_bag: &ConfigBag,
    ) -> Result<(), BoxError> {
        let token = identity.data::<Token>().expect("correct type");
        // The identity only hands out a borrow, so the header value is cloned.
        let header_value = token.value.clone();
        let headers = request.headers_mut();
        headers.append(X_AWS_EC2_METADATA_TOKEN, header_value);
        Ok(())
    }
}