use aws_sigv4::http_request::{
PayloadChecksumKind, PercentEncodingMode, SessionTokenMode, SignableBody, SignatureLocation,
SigningInstructions, SigningSettings, UriPathNormalizationMode,
};
use aws_smithy_runtime_api::box_error::BoxError;
use aws_smithy_runtime_api::client::auth::AuthSchemeEndpointConfig;
use aws_smithy_runtime_api::client::identity::Identity;
use aws_smithy_runtime_api::client::orchestrator::HttpRequest;
use aws_smithy_types::config_bag::{ConfigBag, Storable, StoreReplace};
use aws_smithy_types::Document;
use aws_types::region::{Region, SigningRegion, SigningRegionSet};
use aws_types::SigningName;
use std::error::Error as StdError;
use std::fmt;
use std::time::Duration;
// Submodule `sigv4`; its contents are not visible from this file — presumably the SigV4 signer.
pub mod sigv4;
// Submodule `sigv4a`; only compiled when the `sigv4a` cargo feature is enabled.
#[cfg(feature = "sigv4a")]
pub mod sigv4a;
/// Selects where the computed signature is placed on the outgoing request.
///
/// Each variant is translated to the matching [`SignatureLocation`] when the signer
/// settings for an operation are built.
#[derive(Debug, Eq, PartialEq, Clone, Copy)]
pub enum HttpSignatureType {
/// Translated to [`SignatureLocation::Headers`].
HttpRequestHeaders,
/// Translated to [`SignatureLocation::QueryParams`].
HttpRequestQueryParams,
}
/// Per-operation options that control how a request is signed.
///
/// Most fields are translated one-to-one into [`SigningSettings`] fields; the default
/// values are listed on each field.
#[derive(Clone, Debug, Eq, PartialEq)]
#[non_exhaustive]
pub struct SigningOptions {
/// `true` selects [`PercentEncodingMode::Double`], `false` selects
/// [`PercentEncodingMode::Single`]. Defaults to `true`.
pub double_uri_encode: bool,
/// `true` selects [`PayloadChecksumKind::XAmzSha256`], `false` selects
/// [`PayloadChecksumKind::NoHeader`]. Defaults to `false`.
pub content_sha256_header: bool,
/// `true` selects [`UriPathNormalizationMode::Enabled`], `false` selects
/// [`UriPathNormalizationMode::Disabled`]. Defaults to `true`.
pub normalize_uri_path: bool,
/// `true` selects [`SessionTokenMode::Exclude`], `false` selects
/// [`SessionTokenMode::Include`]. Defaults to `false`.
pub omit_session_token: bool,
/// Optional replacement for the body that gets signed. Defaults to `None`.
///
/// NOTE(review): not read anywhere in this file — presumably consumed by the signer
/// submodules; confirm there.
pub payload_override: Option<SignableBody<'static>>,
/// Where the signature is placed on the request. Defaults to
/// [`HttpSignatureType::HttpRequestHeaders`].
pub signature_type: HttpSignatureType,
/// Defaults to `false`.
///
/// NOTE(review): not read anywhere in this file — presumably lets signing be skipped
/// under some condition; confirm against the signer submodules.
pub signing_optional: bool,
/// Copied verbatim into the signer settings' `expires_in` field. Defaults to `None`.
pub expires_in: Option<Duration>,
}
impl Default for SigningOptions {
fn default() -> Self {
Self {
double_uri_encode: true,
content_sha256_header: false,
normalize_uri_path: true,
omit_session_token: false,
payload_override: None,
signature_type: HttpSignatureType::HttpRequestHeaders,
signing_optional: false,
expires_in: None,
}
}
}
/// Boxed callback type stored inside [`SigV4SessionTokenNameOverride`].
///
/// The callback receives the signing settings and the config bag and returns either an
/// optional `'static` name override or an error.
pub(crate) type SessionTokenNameOverrideFn = Box<
dyn Fn(&SigningSettings, &ConfigBag) -> Result<Option<&'static str>, BoxError>
+ Send
+ Sync
+ 'static,
>;
/// Wrapper around a callback that can supply an override for the session token name.
///
/// NOTE(review): presumably the returned name replaces the default session token
/// header/query parameter name used when signing — confirm against the signer code.
pub struct SigV4SessionTokenNameOverride {
// The user-supplied callback, boxed so that the struct is not generic.
name_override: SessionTokenNameOverrideFn,
}
impl SigV4SessionTokenNameOverride {
    /// Wraps `name_override` so it can later be invoked through [`Self::name_override`].
    ///
    /// The callback is given the signing settings and the config bag, and returns an
    /// optional `'static` name or an error.
    pub fn new<F>(name_override: F) -> Self
    where
        F: Fn(&SigningSettings, &ConfigBag) -> Result<Option<&'static str>, BoxError>
            + Send
            + Sync
            + 'static,
    {
        // Box first; the unsized coercion to the trait object happens at the field site.
        let name_override = Box::new(name_override);
        Self { name_override }
    }

    /// Invokes the wrapped callback with `settings` and `config_bag` and returns its
    /// result unchanged.
    ///
    /// # Errors
    ///
    /// Returns whatever error the wrapped callback returns.
    pub fn name_override(
        &self,
        settings: &SigningSettings,
        config_bag: &ConfigBag,
    ) -> Result<Option<&'static str>, BoxError> {
        let resolve = &self.name_override;
        resolve(settings, config_bag)
    }
}
impl fmt::Debug for SigV4SessionTokenNameOverride {
    /// Writes only a fixed type label; the wrapped callback is not printed.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.write_str("SessionTokenNameOverride")
    }
}
// Lets a `SigV4SessionTokenNameOverride` be placed in a config bag, using the
// `StoreReplace` storer.
impl Storable for SigV4SessionTokenNameOverride {
type Storer = StoreReplace<Self>;
}
/// Signing configuration for a single operation.
///
/// All of `region`, `region_set` and `name` are optional here; the error type in this
/// file has dedicated "missing" variants for each of them.
#[derive(Clone, Debug, Default, PartialEq, Eq)]
pub struct SigV4OperationSigningConfig {
/// Signing region, if known.
pub region: Option<SigningRegion>,
/// Signing region set, if known.
///
/// NOTE(review): presumably only relevant for SigV4a — confirm against the `sigv4a` module.
pub region_set: Option<SigningRegionSet>,
/// Signing name, if known.
pub name: Option<SigningName>,
/// Options that are translated into the signer's settings.
pub signing_options: SigningOptions,
}
// Lets a `SigV4OperationSigningConfig` be placed in a config bag, using the
// `StoreReplace` storer.
impl Storable for SigV4OperationSigningConfig {
type Storer = StoreReplace<Self>;
}
/// Translates the operation's [`SigningOptions`] into the signer's [`SigningSettings`].
///
/// Starts from `SigningSettings::default()` and overwrites the percent-encoding mode,
/// payload checksum kind, URI path normalization mode, session token mode, signature
/// location and expiry from the corresponding option.
fn settings(operation_config: &SigV4OperationSigningConfig) -> SigningSettings {
    let options = &operation_config.signing_options;

    let percent_encoding_mode = if options.double_uri_encode {
        PercentEncodingMode::Double
    } else {
        PercentEncodingMode::Single
    };
    let payload_checksum_kind = if options.content_sha256_header {
        PayloadChecksumKind::XAmzSha256
    } else {
        PayloadChecksumKind::NoHeader
    };
    let uri_path_normalization_mode = if options.normalize_uri_path {
        UriPathNormalizationMode::Enabled
    } else {
        UriPathNormalizationMode::Disabled
    };
    let session_token_mode = if options.omit_session_token {
        SessionTokenMode::Exclude
    } else {
        SessionTokenMode::Include
    };
    let signature_location = match options.signature_type {
        HttpSignatureType::HttpRequestQueryParams => SignatureLocation::QueryParams,
        HttpSignatureType::HttpRequestHeaders => SignatureLocation::Headers,
    };

    // Field assignment on a default value (rather than a struct literal) mirrors how the
    // settings were originally built and leaves every other field at its default.
    let mut signing_settings = SigningSettings::default();
    signing_settings.percent_encoding_mode = percent_encoding_mode;
    signing_settings.payload_checksum_kind = payload_checksum_kind;
    signing_settings.uri_path_normalization_mode = uri_path_normalization_mode;
    signing_settings.session_token_mode = session_token_mode;
    signing_settings.signature_location = signature_location;
    signing_settings.expires_in = options.expires_in;
    signing_settings
}
/// Errors that can be reported while preparing to sign a request.
#[derive(Debug)]
enum SigV4SigningError {
/// Displayed as "missing operation signing config".
MissingOperationSigningConfig,
/// Displayed as "missing signing region".
MissingSigningRegion,
/// Displayed as "missing signing region set"; only exists with the `sigv4a` feature.
#[cfg(feature = "sigv4a")]
MissingSigningRegionSet,
/// Displayed as "missing signing name".
MissingSigningName,
/// The identity was not of the expected type; carries the offending identity so it can
/// be included (via `Debug`) in the message.
WrongIdentityType(Identity),
/// A field in the endpoint auth scheme config had an unexpected type; carries the
/// field's name.
BadTypeInEndpointAuthSchemeConfig(&'static str),
}
impl fmt::Display for SigV4SigningError {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
use SigV4SigningError::*;
let mut w = |s| f.write_str(s);
match self {
MissingOperationSigningConfig => w("missing operation signing config"),
MissingSigningRegion => w("missing signing region"),
#[cfg(feature = "sigv4a")]
MissingSigningRegionSet => w("missing signing region set"),
MissingSigningName => w("missing signing name"),
WrongIdentityType(identity) => {
write!(f, "wrong identity type for SigV4/sigV4a. Expected AWS credentials but got `{identity:?}`")
}
BadTypeInEndpointAuthSchemeConfig(field_name) => {
write!(
f,
"unexpected type for `{field_name}` in endpoint auth scheme config",
)
}
}
}
}
// Marks `SigV4SigningError` as a standard error type; all trait methods keep their defaults.
impl StdError for SigV4SigningError {}
fn extract_endpoint_auth_scheme_signing_name(
endpoint_config: &AuthSchemeEndpointConfig<'_>,
) -> Result<Option<SigningName>, SigV4SigningError> {
use SigV4SigningError::BadTypeInEndpointAuthSchemeConfig as UnexpectedType;
match extract_field_from_endpoint_config("signingName", endpoint_config) {
Some(Document::String(s)) => Ok(Some(SigningName::from(s.to_string()))),
None => Ok(None),
_ => Err(UnexpectedType("signingName")),
}
}
fn extract_endpoint_auth_scheme_signing_region(
endpoint_config: &AuthSchemeEndpointConfig<'_>,
) -> Result<Option<SigningRegion>, SigV4SigningError> {
use SigV4SigningError::BadTypeInEndpointAuthSchemeConfig as UnexpectedType;
match extract_field_from_endpoint_config("signingRegion", endpoint_config) {
Some(Document::String(s)) => Ok(Some(SigningRegion::from(Region::new(s.clone())))),
None => Ok(None),
_ => Err(UnexpectedType("signingRegion")),
}
}
fn extract_field_from_endpoint_config<'a>(
field_name: &'static str,
endpoint_config: &'a AuthSchemeEndpointConfig<'_>,
) -> Option<&'a Document> {
endpoint_config
.as_document()
.and_then(Document::as_object)
.and_then(|config| config.get(field_name))
}
/// Applies the signer's output to `request`.
///
/// Every header in `instructions` is inserted into the request's headers, carrying over
/// the sensitivity flag reported by the signer. If the instructions contain query
/// parameters, they are inserted into the request URI's query string and the URI is
/// replaced with the rebuilt one.
///
/// # Errors
///
/// Returns an error if a header value is not a valid HTTP header value, if the query
/// writer cannot be created from the current URI, or if the rebuilt URI is rejected by
/// `set_uri`. Headers inserted before the failure remain on the request.
fn apply_signing_instructions(
    instructions: SigningInstructions,
    request: &mut HttpRequest,
) -> Result<(), BoxError> {
    let (new_headers, new_query) = instructions.into_parts();
    for header in new_headers {
        // Surface an invalid value as an error instead of panicking: this function
        // already returns `Result`, and callers cannot recover from a panic.
        let mut value = http::HeaderValue::from_str(header.value())?;
        value.set_sensitive(header.sensitive());
        request.headers_mut().insert(header.name(), value);
    }

    if !new_query.is_empty() {
        let mut query = aws_smithy_http::query_writer::QueryWriter::new_from_string(request.uri())?;
        for (name, value) in new_query {
            query.insert(name, &value);
        }
        request.set_uri(query.build_uri())?;
    }
    Ok(())
}