// aws_smithy_runtime/client/retries/classifiers.rs
use aws_smithy_runtime_api::client::interceptors::context::InterceptorContext;
use aws_smithy_runtime_api::client::retries::classifiers::{
    ClassifyRetry, RetryAction, RetryClassifierPriority, SharedRetryClassifier,
};
use aws_smithy_types::retry::ProvideErrorKind;
use std::borrow::Cow;
use std::error::Error as StdError;
use std::marker::PhantomData;

/// A retry classifier for an operation's modeled error type `E`.
///
/// When an attempt fails with an operation error that downcasts to `E`, the error's own
/// [`ProvideErrorKind::retryable_error_kind`] decides whether (and how) to retry.
#[derive(Debug)]
pub struct ModeledAsRetryableClassifier<E> {
    _inner: PhantomData<E>,
}

// Implemented by hand: `#[derive(Default)]` would add an `E: Default` bound even though only
// a `PhantomData<E>` is stored, and modeled error types generally do not implement `Default`.
impl<E> Default for ModeledAsRetryableClassifier<E> {
    fn default() -> Self {
        Self {
            _inner: PhantomData,
        }
    }
}

21impl<E> ModeledAsRetryableClassifier<E> {
22    pub fn new() -> Self {
24        Self {
25            _inner: PhantomData,
26        }
27    }
28
29    pub fn priority() -> RetryClassifierPriority {
31        RetryClassifierPriority::modeled_as_retryable_classifier()
32    }
33}
34
35impl<E> ClassifyRetry for ModeledAsRetryableClassifier<E>
36where
37    E: StdError + ProvideErrorKind + Send + Sync + 'static,
38{
39    fn classify_retry(&self, ctx: &InterceptorContext) -> RetryAction {
40        let output_or_error = ctx.output_or_error();
42        let error = match output_or_error {
44            Some(Ok(_)) | None => return RetryAction::NoActionIndicated,
45            Some(Err(err)) => err,
46        };
47        error
49            .as_operation_error()
50            .and_then(|err| err.downcast_ref::<E>())
52            .and_then(|err| err.retryable_error_kind().map(RetryAction::retryable_error))
54            .unwrap_or_default()
55    }
56
57    fn name(&self) -> &'static str {
58        "Errors Modeled As Retryable"
59    }
60
61    fn priority(&self) -> RetryClassifierPriority {
62        Self::priority()
63    }
64}
65
/// A retry classifier for transient failures: response errors, timeout errors, and
/// connector errors caused by timeouts or I/O.
///
/// The type parameter `E` is only a marker; the classification logic in this file does not
/// inspect it.
#[derive(Debug)]
pub struct TransientErrorClassifier<E> {
    _inner: PhantomData<E>,
}

// Implemented by hand: `#[derive(Default)]` would add an `E: Default` bound even though only
// a `PhantomData<E>` is stored, and error types generally do not implement `Default`.
impl<E> Default for TransientErrorClassifier<E> {
    fn default() -> Self {
        Self {
            _inner: PhantomData,
        }
    }
}

72impl<E> TransientErrorClassifier<E> {
73    pub fn new() -> Self {
75        Self {
76            _inner: PhantomData,
77        }
78    }
79
80    pub fn priority() -> RetryClassifierPriority {
82        RetryClassifierPriority::transient_error_classifier()
83    }
84}
85
86impl<E> ClassifyRetry for TransientErrorClassifier<E>
87where
88    E: StdError + Send + Sync + 'static,
89{
90    fn classify_retry(&self, ctx: &InterceptorContext) -> RetryAction {
91        let output_or_error = ctx.output_or_error();
93        let error = match output_or_error {
95            Some(Ok(_)) | None => return RetryAction::NoActionIndicated,
96            Some(Err(err)) => err,
97        };
98
99        if error.is_response_error() || error.is_timeout_error() {
100            RetryAction::transient_error()
101        } else if let Some(error) = error.as_connector_error() {
102            if error.is_timeout() || error.is_io() {
103                RetryAction::transient_error()
104            } else {
105                error
106                    .as_other()
107                    .map(RetryAction::retryable_error)
108                    .unwrap_or_default()
109            }
110        } else {
111            RetryAction::NoActionIndicated
112        }
113    }
114
115    fn name(&self) -> &'static str {
116        "Retryable Smithy Errors"
117    }
118
119    fn priority(&self) -> RetryClassifierPriority {
120        Self::priority()
121    }
122}
123
/// HTTP status codes treated as transient by `HttpStatusCodeClassifier::default()`:
/// 500 Internal Server Error, 502 Bad Gateway, 503 Service Unavailable, 504 Gateway Timeout.
const TRANSIENT_ERROR_STATUS_CODES: &[u16] = &[500, 502, 503, 504];

/// A retry classifier that reports a transient error when the HTTP response status code is
/// one of a configured set of codes.
#[derive(Debug)]
pub struct HttpStatusCodeClassifier {
    /// Status codes that are classified as transient errors.
    retryable_status_codes: Cow<'static, [u16]>,
}

impl Default for HttpStatusCodeClassifier {
    /// Creates a classifier for the codes in `TRANSIENT_ERROR_STATUS_CODES`.
    fn default() -> Self {
        // Borrow the static table rather than copying it into a `Vec`: the field is a
        // `Cow<'static, [u16]>` precisely so that the default needs no allocation.
        Self {
            retryable_status_codes: Cow::Borrowed(TRANSIENT_ERROR_STATUS_CODES),
        }
    }
}

139impl HttpStatusCodeClassifier {
140    pub fn new_from_codes(retryable_status_codes: impl Into<Cow<'static, [u16]>>) -> Self {
144        Self {
145            retryable_status_codes: retryable_status_codes.into(),
146        }
147    }
148
149    pub fn priority() -> RetryClassifierPriority {
151        RetryClassifierPriority::http_status_code_classifier()
152    }
153}
154
155impl ClassifyRetry for HttpStatusCodeClassifier {
156    fn classify_retry(&self, ctx: &InterceptorContext) -> RetryAction {
157        let is_retryable = ctx
158            .response()
159            .map(|res| res.status().into())
160            .map(|status| self.retryable_status_codes.contains(&status))
161            .unwrap_or_default();
162
163        if is_retryable {
164            RetryAction::transient_error()
165        } else {
166            RetryAction::NoActionIndicated
167        }
168    }
169
170    fn name(&self) -> &'static str {
171        "HTTP Status Code"
172    }
173
174    fn priority(&self) -> RetryClassifierPriority {
175        Self::priority()
176    }
177}
178
179pub fn run_classifiers_on_ctx(
183    classifiers: impl Iterator<Item = SharedRetryClassifier>,
184    ctx: &InterceptorContext,
185) -> RetryAction {
186    let mut result = RetryAction::NoActionIndicated;
188
189    for classifier in classifiers {
190        let new_result = classifier.classify_retry(ctx);
191
192        if new_result == RetryAction::NoActionIndicated {
195            continue;
196        }
197
198        tracing::trace!(
200            "Classifier '{}' set the result of classification to '{}'",
201            classifier.name(),
202            new_result
203        );
204        result = new_result;
205
206        if result == RetryAction::RetryForbidden {
208            tracing::trace!("retry classification ending early because a `RetryAction::RetryForbidden` was emitted",);
209            break;
210        }
211    }
212
213    result
214}
215
#[cfg(test)]
mod test {
    // Unit tests for the retry classifiers defined in the parent module.
    use super::{HttpStatusCodeClassifier, ModeledAsRetryableClassifier, TransientErrorClassifier};
    use aws_smithy_runtime_api::client::interceptors::context::{Error, Input, InterceptorContext};
    use aws_smithy_runtime_api::client::orchestrator::OrchestratorError;
    use aws_smithy_runtime_api::client::retries::classifiers::{ClassifyRetry, RetryAction};
    use aws_smithy_types::body::SdkBody;
    use aws_smithy_types::retry::{ErrorKind, ProvideErrorKind};
    use std::fmt;

    /// An error type with no retry information, used as the marker type for
    /// `TransientErrorClassifier` in these tests.
    #[derive(Debug, PartialEq, Eq, Clone)]
    struct UnmodeledError;

    impl fmt::Display for UnmodeledError {
        fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
            write!(f, "UnmodeledError")
        }
    }

    impl std::error::Error for UnmodeledError {}

    /// Builds a context holding an HTTP response with the given status code.
    fn context_with_status(status: u16) -> InterceptorContext {
        let response = http_02x::Response::builder()
            .status(status)
            .body("error!")
            .unwrap()
            .map(SdkBody::from);
        let mut ctx = InterceptorContext::new(Input::doesnt_matter());
        ctx.set_response(response.try_into().unwrap());
        ctx
    }

    /// Builds a context whose attempt failed with `error`.
    fn context_with_error(error: OrchestratorError<Error>) -> InterceptorContext {
        let mut ctx = InterceptorContext::new(Input::doesnt_matter());
        ctx.set_output_or_error(Err(error));
        ctx
    }

    #[test]
    fn classify_by_response_status() {
        let policy = HttpStatusCodeClassifier::default();
        let ctx = context_with_status(500);
        assert_eq!(policy.classify_retry(&ctx), RetryAction::transient_error());
    }

    #[test]
    fn classify_by_response_status_not_retryable() {
        // 408 is not in the default set of retryable status codes.
        let policy = HttpStatusCodeClassifier::default();
        let ctx = context_with_status(408);
        assert_eq!(policy.classify_retry(&ctx), RetryAction::NoActionIndicated);
    }

    #[test]
    fn classify_by_error_kind() {
        #[derive(Debug)]
        struct RetryableError;

        impl fmt::Display for RetryableError {
            fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
                write!(f, "Some retryable error")
            }
        }

        impl ProvideErrorKind for RetryableError {
            fn retryable_error_kind(&self) -> Option<ErrorKind> {
                Some(ErrorKind::ClientError)
            }

            fn code(&self) -> Option<&str> {
                unimplemented!()
            }
        }

        impl std::error::Error for RetryableError {}

        let policy = ModeledAsRetryableClassifier::<RetryableError>::new();
        let ctx = context_with_error(OrchestratorError::operation(Error::erase(RetryableError)));
        assert_eq!(policy.classify_retry(&ctx), RetryAction::client_error());
    }

    #[test]
    fn classify_response_error() {
        let policy = TransientErrorClassifier::<UnmodeledError>::new();
        let ctx = context_with_error(OrchestratorError::response(
            "I am a response error".into(),
        ));
        assert_eq!(policy.classify_retry(&ctx), RetryAction::transient_error());
    }

    #[test]
    fn test_timeout_error() {
        let policy = TransientErrorClassifier::<UnmodeledError>::new();
        let ctx = context_with_error(OrchestratorError::timeout(
            "I am a timeout error".into(),
        ));
        assert_eq!(policy.classify_retry(&ctx), RetryAction::transient_error());
    }
}