use aws_smithy_runtime_api::client::interceptors::context::InterceptorContext;
use aws_smithy_runtime_api::client::retries::classifiers::{
ClassifyRetry, RetryAction, RetryClassifierPriority, SharedRetryClassifier,
};
use aws_smithy_types::retry::ProvideErrorKind;
use std::borrow::Cow;
use std::error::Error as StdError;
use std::marker::PhantomData;
/// A retry classifier for operation errors that are modeled as retryable.
///
/// The operation error held by the [`InterceptorContext`] is downcast to the modeled error
/// type `E`. If that succeeds and `E` reports a retryable error kind through
/// [`ProvideErrorKind`], a retry of that kind is indicated.
#[derive(Debug)]
pub struct ModeledAsRetryableClassifier<E> {
    _inner: PhantomData<E>,
}

// `Default` is written by hand rather than derived: `#[derive(Default)]` would require
// `E: Default`, even though only a `PhantomData<E>` is stored. Modeled error types generally
// don't implement `Default`, which made the derived impl unusable in practice.
impl<E> Default for ModeledAsRetryableClassifier<E> {
    fn default() -> Self {
        Self::new()
    }
}

impl<E> ModeledAsRetryableClassifier<E> {
    /// Create a new classifier for the modeled error type `E`.
    pub fn new() -> Self {
        Self {
            _inner: PhantomData,
        }
    }

    /// The priority of this classifier
    /// ([`RetryClassifierPriority::modeled_as_retryable_classifier`]).
    pub fn priority() -> RetryClassifierPriority {
        RetryClassifierPriority::modeled_as_retryable_classifier()
    }
}
impl<E> ClassifyRetry for ModeledAsRetryableClassifier<E>
where
    E: StdError + ProvideErrorKind + Send + Sync + 'static,
{
    /// Indicate a retry when the context holds an operation error of type `E` that reports a
    /// retryable error kind.
    ///
    /// Returns `RetryAction::NoActionIndicated` when there is no error (success, or no result
    /// yet). Returns `RetryAction::default()` when the error is not an operation error, is not
    /// an `E`, or reports no retryable kind.
    fn classify_retry(&self, ctx: &InterceptorContext) -> RetryAction {
        match ctx.output_or_error() {
            Some(Err(orchestrator_error)) => orchestrator_error
                .as_operation_error()
                .and_then(|operation_error| operation_error.downcast_ref::<E>())
                .and_then(|modeled_error| modeled_error.retryable_error_kind())
                .map(RetryAction::retryable_error)
                .unwrap_or_default(),
            Some(Ok(_)) | None => RetryAction::NoActionIndicated,
        }
    }

    fn name(&self) -> &'static str {
        "Errors Modeled As Retryable"
    }

    fn priority(&self) -> RetryClassifierPriority {
        Self::priority()
    }
}
/// A retry classifier for transient failures.
///
/// It covers response errors, timeout errors, and connector errors that are timeouts, I/O
/// failures, or that report an error kind via `as_other`.
///
/// The type parameter `E` is not inspected during classification; it is carried only as a
/// [`PhantomData`] marker.
#[derive(Debug)]
pub struct TransientErrorClassifier<E> {
    _inner: PhantomData<E>,
}

// `Default` is written by hand rather than derived: `#[derive(Default)]` would require
// `E: Default`, which error types generally don't implement, and `E` is never used here
// beyond the marker.
impl<E> Default for TransientErrorClassifier<E> {
    fn default() -> Self {
        Self::new()
    }
}

impl<E> TransientErrorClassifier<E> {
    /// Create a new transient error classifier.
    pub fn new() -> Self {
        Self {
            _inner: PhantomData,
        }
    }

    /// The priority of this classifier
    /// ([`RetryClassifierPriority::transient_error_classifier`]).
    pub fn priority() -> RetryClassifierPriority {
        RetryClassifierPriority::transient_error_classifier()
    }
}
impl<E> ClassifyRetry for TransientErrorClassifier<E>
where
    E: StdError + Send + Sync + 'static,
{
    /// Classify the error in `ctx` as transient when possible.
    ///
    /// - Response and timeout errors indicate a transient-error retry.
    /// - Connector errors that are timeouts or I/O failures also indicate a transient-error
    ///   retry.
    /// - Other connector errors map their `as_other` kind to a retry, falling back to
    ///   `RetryAction::default()`.
    /// - Anything else, including success or no result yet, indicates no action.
    fn classify_retry(&self, ctx: &InterceptorContext) -> RetryAction {
        let orchestrator_error = match ctx.output_or_error() {
            Some(Err(orchestrator_error)) => orchestrator_error,
            Some(Ok(_)) | None => return RetryAction::NoActionIndicated,
        };

        if orchestrator_error.is_response_error() || orchestrator_error.is_timeout_error() {
            return RetryAction::transient_error();
        }

        match orchestrator_error.as_connector_error() {
            Some(connector_error) if connector_error.is_timeout() || connector_error.is_io() => {
                RetryAction::transient_error()
            }
            Some(connector_error) => connector_error
                .as_other()
                .map(RetryAction::retryable_error)
                .unwrap_or_default(),
            None => RetryAction::NoActionIndicated,
        }
    }

    fn name(&self) -> &'static str {
        "Retryable Smithy Errors"
    }

    fn priority(&self) -> RetryClassifierPriority {
        Self::priority()
    }
}
/// HTTP status codes that `HttpStatusCodeClassifier::default()` treats as transient errors.
const TRANSIENT_ERROR_STATUS_CODES: &[u16] = &[500, 502, 503, 504];

/// A retry classifier that indicates a transient-error retry when the HTTP response's status
/// code is in a configured set of retryable codes.
#[derive(Debug)]
pub struct HttpStatusCodeClassifier {
    retryable_status_codes: Cow<'static, [u16]>,
}

impl Default for HttpStatusCodeClassifier {
    /// Treats 500, 502, 503, and 504 responses as transient errors.
    fn default() -> Self {
        // Borrow the static list rather than copying it into a freshly allocated `Vec` on
        // every call; the field is a `Cow<'static, _>` precisely so this is possible.
        Self::new_from_codes(Cow::Borrowed(TRANSIENT_ERROR_STATUS_CODES))
    }
}

impl HttpStatusCodeClassifier {
    /// Create a classifier that treats each code in `retryable_status_codes` as retryable.
    ///
    /// Accepts anything convertible into a `Cow<'static, [u16]>`, such as a `Vec<u16>` or a
    /// `&'static [u16]`.
    pub fn new_from_codes(retryable_status_codes: impl Into<Cow<'static, [u16]>>) -> Self {
        Self {
            retryable_status_codes: retryable_status_codes.into(),
        }
    }

    /// The priority of this classifier
    /// ([`RetryClassifierPriority::http_status_code_classifier`]).
    pub fn priority() -> RetryClassifierPriority {
        RetryClassifierPriority::http_status_code_classifier()
    }
}
impl ClassifyRetry for HttpStatusCodeClassifier {
    /// Indicate a transient-error retry when a response is present and its status code is
    /// one of the configured retryable codes.
    ///
    /// Returns `RetryAction::NoActionIndicated` otherwise, including when there is no
    /// response.
    fn classify_retry(&self, ctx: &InterceptorContext) -> RetryAction {
        let status_is_retryable = match ctx.response() {
            Some(response) => {
                let code: u16 = response.status().into();
                self.retryable_status_codes.contains(&code)
            }
            None => false,
        };

        if !status_is_retryable {
            return RetryAction::NoActionIndicated;
        }
        RetryAction::transient_error()
    }

    fn name(&self) -> &'static str {
        "HTTP Status Code"
    }

    fn priority(&self) -> RetryClassifierPriority {
        Self::priority()
    }
}
/// Run each classifier in `classifiers` against `ctx`, in iteration order, and return the
/// combined retry decision.
///
/// - A classifier returning [`RetryAction::NoActionIndicated`] leaves the current decision
///   untouched.
/// - Any other action replaces the current decision, so the last classifier with an opinion
///   wins.
/// - Classification stops as soon as a classifier returns [`RetryAction::RetryForbidden`].
/// - If no classifier has an opinion, including when `classifiers` is empty, the result is
///   [`RetryAction::NoActionIndicated`].
pub fn run_classifiers_on_ctx(
    classifiers: impl Iterator<Item = SharedRetryClassifier>,
    ctx: &InterceptorContext,
) -> RetryAction {
    let mut decision = RetryAction::NoActionIndicated;
    for classifier in classifiers {
        let proposed = classifier.classify_retry(ctx);
        if proposed != RetryAction::NoActionIndicated {
            tracing::trace!(
                "Classifier '{}' set the result of classification to '{}'",
                classifier.name(),
                proposed
            );
            let forbids_retry = proposed == RetryAction::RetryForbidden;
            decision = proposed;
            if forbids_retry {
                // A forbidden retry is final; later classifiers must not override it.
                tracing::trace!(
                    "retry classification ending early because a `RetryAction::RetryForbidden` was emitted"
                );
                break;
            }
        }
    }
    decision
}
#[cfg(test)]
mod test {
    use super::{
        run_classifiers_on_ctx, HttpStatusCodeClassifier, ModeledAsRetryableClassifier,
        TransientErrorClassifier,
    };
    use aws_smithy_runtime_api::client::interceptors::context::{Error, Input, InterceptorContext};
    use aws_smithy_runtime_api::client::orchestrator::OrchestratorError;
    use aws_smithy_runtime_api::client::retries::classifiers::{
        ClassifyRetry, RetryAction, SharedRetryClassifier,
    };
    use aws_smithy_types::body::SdkBody;
    use aws_smithy_types::retry::{ErrorKind, ProvideErrorKind};
    use std::fmt;

    /// An error type that does not implement `ProvideErrorKind`.
    #[derive(Debug, PartialEq, Eq, Clone)]
    struct UnmodeledError;

    impl fmt::Display for UnmodeledError {
        fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
            write!(f, "UnmodeledError")
        }
    }

    impl std::error::Error for UnmodeledError {}

    /// A modeled error that reports itself as a retryable client error.
    #[derive(Debug)]
    struct RetryableError;

    impl fmt::Display for RetryableError {
        fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
            write!(f, "Some retryable error")
        }
    }

    impl ProvideErrorKind for RetryableError {
        fn retryable_error_kind(&self) -> Option<ErrorKind> {
            Some(ErrorKind::ClientError)
        }

        fn code(&self) -> Option<&str> {
            None
        }
    }

    impl std::error::Error for RetryableError {}

    /// Build a context whose HTTP response has the given status code.
    fn ctx_with_status(status: u16) -> InterceptorContext {
        let res = http::Response::builder()
            .status(status)
            .body("error!")
            .unwrap()
            .map(SdkBody::from);
        let mut ctx = InterceptorContext::new(Input::doesnt_matter());
        ctx.set_response(res.try_into().unwrap());
        ctx
    }

    #[test]
    fn classify_by_response_status() {
        let policy = HttpStatusCodeClassifier::default();
        let ctx = ctx_with_status(500);
        assert_eq!(policy.classify_retry(&ctx), RetryAction::transient_error());
    }

    #[test]
    fn classify_by_every_default_transient_status() {
        let policy = HttpStatusCodeClassifier::default();
        for status in [500_u16, 502, 503, 504] {
            assert_eq!(
                policy.classify_retry(&ctx_with_status(status)),
                RetryAction::transient_error(),
                "status {}",
                status
            );
        }
    }

    #[test]
    fn classify_by_response_status_not_retryable() {
        let policy = HttpStatusCodeClassifier::default();
        let ctx = ctx_with_status(408);
        assert_eq!(policy.classify_retry(&ctx), RetryAction::NoActionIndicated);
    }

    #[test]
    fn classify_by_custom_status_codes() {
        let policy = HttpStatusCodeClassifier::new_from_codes(vec![429_u16]);
        assert_eq!(
            policy.classify_retry(&ctx_with_status(429)),
            RetryAction::transient_error()
        );
        assert_eq!(
            policy.classify_retry(&ctx_with_status(500)),
            RetryAction::NoActionIndicated
        );
    }

    #[test]
    fn classify_by_status_without_response() {
        let policy = HttpStatusCodeClassifier::default();
        let ctx = InterceptorContext::new(Input::doesnt_matter());
        assert_eq!(policy.classify_retry(&ctx), RetryAction::NoActionIndicated);
    }

    #[test]
    fn classify_by_error_kind() {
        let policy = ModeledAsRetryableClassifier::<RetryableError>::new();
        let mut ctx = InterceptorContext::new(Input::doesnt_matter());
        ctx.set_output_or_error(Err(OrchestratorError::operation(Error::erase(
            RetryableError,
        ))));
        assert_eq!(policy.classify_retry(&ctx), RetryAction::client_error());
    }

    #[test]
    fn modeled_classifier_ignores_other_error_types() {
        let policy = ModeledAsRetryableClassifier::<RetryableError>::new();
        let mut ctx = InterceptorContext::new(Input::doesnt_matter());
        ctx.set_output_or_error(Err(OrchestratorError::operation(Error::erase(
            UnmodeledError,
        ))));
        assert_eq!(policy.classify_retry(&ctx), RetryAction::NoActionIndicated);
    }

    #[test]
    fn classifiers_without_output_or_error_indicate_no_action() {
        let ctx = InterceptorContext::new(Input::doesnt_matter());
        assert_eq!(
            ModeledAsRetryableClassifier::<RetryableError>::new().classify_retry(&ctx),
            RetryAction::NoActionIndicated
        );
        assert_eq!(
            TransientErrorClassifier::<UnmodeledError>::new().classify_retry(&ctx),
            RetryAction::NoActionIndicated
        );
    }

    #[test]
    fn classify_response_error() {
        let policy = TransientErrorClassifier::<UnmodeledError>::new();
        let mut ctx = InterceptorContext::new(Input::doesnt_matter());
        ctx.set_output_or_error(Err(OrchestratorError::response(
            "I am a response error".into(),
        )));
        assert_eq!(policy.classify_retry(&ctx), RetryAction::transient_error());
    }

    #[test]
    fn test_timeout_error() {
        let policy = TransientErrorClassifier::<UnmodeledError>::new();
        let mut ctx = InterceptorContext::new(Input::doesnt_matter());
        ctx.set_output_or_error(Err(OrchestratorError::timeout(
            "I am a timeout error".into(),
        )));
        assert_eq!(policy.classify_retry(&ctx), RetryAction::transient_error());
    }

    #[test]
    fn run_classifiers_with_no_classifiers_indicates_no_action() {
        let ctx = InterceptorContext::new(Input::doesnt_matter());
        assert_eq!(
            run_classifiers_on_ctx(std::iter::empty(), &ctx),
            RetryAction::NoActionIndicated
        );
    }

    #[test]
    fn run_classifiers_keeps_last_indicated_action() {
        let ctx = ctx_with_status(500);
        let classifiers = vec![
            SharedRetryClassifier::new(HttpStatusCodeClassifier::default()),
            // There is no output or error to inspect, so this classifier has no opinion and
            // must not override the status-code result above.
            SharedRetryClassifier::new(TransientErrorClassifier::<UnmodeledError>::new()),
        ];
        assert_eq!(
            run_classifiers_on_ctx(classifiers.into_iter(), &ctx),
            RetryAction::transient_error()
        );
    }
}