reqwest_retry/retryable_strategy.rs
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213
use crate::retryable::Retryable;
use http::StatusCode;
use reqwest_middleware::Error;
/// A strategy to create a [`Retryable`] from a [`Result<reqwest::Response, reqwest_middleware::Error>`]
///
/// A [`RetryableStrategy`] has a single `handler` functions.
/// The result of calling the request could be:
/// - [`reqwest::Response`] In case the request has been sent and received correctly
/// This could however still mean that the server responded with a erroneous response.
/// For example a HTTP statuscode of 500
/// - [`reqwest_middleware::Error`] In this case the request actually failed.
/// This could, for example, be caused by a timeout on the connection.
///
/// Example:
///
/// ```
/// use reqwest_retry::{default_on_request_failure, policies::ExponentialBackoff, Retryable, RetryableStrategy, RetryTransientMiddleware};
/// use reqwest::{Request, Response};
/// use reqwest_middleware::{ClientBuilder, Middleware, Next, Result};
/// use task_local_extensions::Extensions;
///
/// // Log each request to show that the requests will be retried
/// struct LoggingMiddleware;
///
/// #[async_trait::async_trait]
/// impl Middleware for LoggingMiddleware {
/// async fn handle(
/// &self,
/// req: Request,
/// extensions: &mut Extensions,
/// next: Next<'_>,
/// ) -> Result<Response> {
/// println!("Request started {}", req.url());
/// let res = next.run(req, extensions).await;
/// println!("Request finished");
/// res
/// }
/// }
///
/// // Just a toy example, retry when the successful response code is 201, else do nothing.
/// struct Retry201;
/// impl RetryableStrategy for Retry201 {
/// fn handle(&self, res: &Result<reqwest::Response>) -> Option<Retryable> {
/// match res {
/// // retry if 201
/// Ok(success) if success.status() == 201 => Some(Retryable::Transient),
/// // otherwise do not retry a successful request
/// Ok(success) => None,
/// // but maybe retry a request failure
/// Err(error) => default_on_request_failure(error),
/// }
/// }
/// }
///
/// #[tokio::main]
/// async fn main() {
/// // Exponential backoff with max 2 retries
/// let retry_policy = ExponentialBackoff::builder()
/// .build_with_max_retries(2);
///
/// // Create the actual middleware, with the exponential backoff and custom retry stategy.
/// let ret_s = RetryTransientMiddleware::new_with_policy_and_strategy(
/// retry_policy,
/// Retry201,
/// );
///
/// let client = ClientBuilder::new(reqwest::Client::new())
/// // Retry failed requests.
/// .with(ret_s)
/// // Log the requests
/// .with(LoggingMiddleware)
/// .build();
///
/// // Send request which should get a 201 response. So it will be retried
/// let r = client
/// .get("https://httpbin.org/status/201")
/// .send()
/// .await;
/// println!("{:?}", r);
///
/// // Send request which should get a 200 response. So it will not be retried
/// let r = client
/// .get("https://httpbin.org/status/200")
/// .send()
/// .await;
/// println!("{:?}", r);
/// }
/// ```
pub trait RetryableStrategy {
fn handle(&self, res: &Result<reqwest::Response, Error>) -> Option<Retryable>;
}
/// The default [`RetryableStrategy`] for [`RetryTransientMiddleware`](crate::RetryTransientMiddleware).
pub struct DefaultRetryableStrategy;
impl RetryableStrategy for DefaultRetryableStrategy {
fn handle(&self, res: &Result<reqwest::Response, Error>) -> Option<Retryable> {
match res {
Ok(success) => default_on_request_success(success),
Err(error) => default_on_request_failure(error),
}
}
}
/// Default request success retry strategy.
///
/// Will only retry if:
/// * The status was 5XX (server error)
/// * The status was 408 (request timeout) or 429 (too many requests)
///
/// Note that success here means that the request finished without interruption, not that it was logically OK.
pub fn default_on_request_success(success: &reqwest::Response) -> Option<Retryable> {
let status = success.status();
if status.is_server_error() {
Some(Retryable::Transient)
} else if status.is_client_error()
&& status != StatusCode::REQUEST_TIMEOUT
&& status != StatusCode::TOO_MANY_REQUESTS
{
Some(Retryable::Fatal)
} else if status.is_success() {
None
} else if status == StatusCode::REQUEST_TIMEOUT || status == StatusCode::TOO_MANY_REQUESTS {
Some(Retryable::Transient)
} else {
Some(Retryable::Fatal)
}
}
/// Default request failure retry strategy.
///
/// Will only retry if the request failed due to a network error
pub fn default_on_request_failure(error: &Error) -> Option<Retryable> {
match error {
// If something fails in the middleware we're screwed.
Error::Middleware(_) => Some(Retryable::Fatal),
Error::Reqwest(error) => {
#[cfg(not(target_arch = "wasm32"))]
let is_connect = error.is_connect();
#[cfg(target_arch = "wasm32")]
let is_connect = false;
if error.is_timeout() || is_connect {
Some(Retryable::Transient)
} else if error.is_body()
|| error.is_decode()
|| error.is_builder()
|| error.is_redirect()
{
Some(Retryable::Fatal)
} else if error.is_request() {
// It seems that hyper::Error(IncompleteMessage) is not correctly handled by reqwest.
// Here we check if the Reqwest error was originated by hyper and map it consistently.
#[cfg(not(target_arch = "wasm32"))]
if let Some(hyper_error) = get_source_error_type::<hyper::Error>(&error) {
// The hyper::Error(IncompleteMessage) is raised if the HTTP response is well formatted but does not contain all the bytes.
// This can happen when the server has started sending back the response but the connection is cut halfway thorugh.
// We can safely retry the call, hence marking this error as [`Retryable::Transient`].
// Instead hyper::Error(Canceled) is raised when the connection is
// gracefully closed on the server side.
if hyper_error.is_incomplete_message() || hyper_error.is_canceled() {
Some(Retryable::Transient)
// Try and downcast the hyper error to io::Error if that is the
// underlying error, and try and classify it.
} else if let Some(io_error) =
get_source_error_type::<std::io::Error>(hyper_error)
{
Some(classify_io_error(io_error))
} else {
Some(Retryable::Fatal)
}
} else {
Some(Retryable::Fatal)
}
#[cfg(target_arch = "wasm32")]
Some(Retryable::Fatal)
} else {
// We omit checking if error.is_status() since we check that already.
// However, if Response::error_for_status is used the status will still
// remain in the response object.
None
}
}
}
}
#[cfg(not(target_arch = "wasm32"))]
fn classify_io_error(error: &std::io::Error) -> Retryable {
match error.kind() {
std::io::ErrorKind::ConnectionReset | std::io::ErrorKind::ConnectionAborted => {
Retryable::Transient
}
_ => Retryable::Fatal,
}
}
/// Downcasts the given err source into T.
#[cfg(not(target_arch = "wasm32"))]
fn get_source_error_type<T: std::error::Error + 'static>(
err: &dyn std::error::Error,
) -> Option<&T> {
let mut source = err.source();
while let Some(err) = source {
if let Some(err) = err.downcast_ref::<T>() {
return Some(err);
}
source = err.source();
}
None
}