reqwest_retry/middleware.rs
1//! `RetryTransientMiddleware` implements retrying requests on transient errors.
2use crate::retryable_strategy::RetryableStrategy;
3use crate::{retryable::Retryable, retryable_strategy::DefaultRetryableStrategy};
4use anyhow::anyhow;
5use chrono::Utc;
6use reqwest::{Request, Response};
7use reqwest_middleware::{Error, Middleware, Next, Result};
8use retry_policies::RetryPolicy;
9use task_local_extensions::Extensions;
10
11/// `RetryTransientMiddleware` offers retry logic for requests that fail in a transient manner
12/// and can be safely executed again.
13///
14/// Currently, it allows setting a [RetryPolicy][retry_policies::RetryPolicy] algorithm for calculating the __wait_time__
15/// between each request retry.
16///
17///```rust
18/// use reqwest_middleware::ClientBuilder;
19/// use reqwest_retry::{RetryTransientMiddleware, policies::ExponentialBackoff};
20/// use reqwest::Client;
21///
22/// // We create a ExponentialBackoff retry policy which implements `RetryPolicy`.
23/// let retry_policy = ExponentialBackoff {
24/// /// How many times the policy will tell the middleware to retry the request.
25/// max_n_retries: 3,
26/// max_retry_interval: std::time::Duration::from_millis(30),
27/// min_retry_interval: std::time::Duration::from_millis(100),
28/// backoff_exponent: 2,
29/// };
30///
31/// let retry_transient_middleware = RetryTransientMiddleware::new_with_policy(retry_policy);
32/// let client = ClientBuilder::new(Client::new()).with(retry_transient_middleware).build();
33///```
34///
35/// # Note
36///
37/// This middleware always errors when given requests with streaming bodies, before even executing
38/// the request. When this happens you'll get an [`Error::Middleware`] with the message
39/// 'Request object is not clonable. Are you passing a streaming body?'.
40///
41/// Some workaround suggestions:
42/// * If you can fit the data in memory, you can instead build static request bodies e.g. with
43/// `Body`'s `From<String>` or `From<Bytes>` implementations.
44/// * You can wrap this middleware in a custom one which skips retries for streaming requests.
45/// * You can write a custom retry middleware that builds new streaming requests from the data
46/// source directly, avoiding the issue of streaming requests not being clonable.
47pub struct RetryTransientMiddleware<
48 T: RetryPolicy + Send + Sync + 'static,
49 R: RetryableStrategy + Send + Sync + 'static = DefaultRetryableStrategy,
50> {
51 retry_policy: T,
52 retryable_strategy: R,
53}
54
55impl<T: RetryPolicy + Send + Sync> RetryTransientMiddleware<T, DefaultRetryableStrategy> {
56 /// Construct `RetryTransientMiddleware` with a [retry_policy][RetryPolicy].
57 pub fn new_with_policy(retry_policy: T) -> Self {
58 Self::new_with_policy_and_strategy(retry_policy, DefaultRetryableStrategy)
59 }
60}
61
62impl<T, R> RetryTransientMiddleware<T, R>
63where
64 T: RetryPolicy + Send + Sync,
65 R: RetryableStrategy + Send + Sync,
66{
67 /// Construct `RetryTransientMiddleware` with a [retry_policy][RetryPolicy] and [retryable_strategy](RetryableStrategy).
68 pub fn new_with_policy_and_strategy(retry_policy: T, retryable_strategy: R) -> Self {
69 Self {
70 retry_policy,
71 retryable_strategy,
72 }
73 }
74}
75
76#[cfg_attr(not(target_arch = "wasm32"), async_trait::async_trait)]
77#[cfg_attr(target_arch = "wasm32", async_trait::async_trait(?Send))]
78impl<T, R> Middleware for RetryTransientMiddleware<T, R>
79where
80 T: RetryPolicy + Send + Sync,
81 R: RetryableStrategy + Send + Sync + 'static,
82{
83 async fn handle(
84 &self,
85 req: Request,
86 extensions: &mut Extensions,
87 next: Next<'_>,
88 ) -> Result<Response> {
89 // TODO: Ideally we should create a new instance of the `Extensions` map to pass
90 // downstream. This will guard against previous retries poluting `Extensions`.
91 // That is, we only return what's populated in the typemap for the last retry attempt
92 // and copy those into the the `global` Extensions map.
93 self.execute_with_retry(req, next, extensions).await
94 }
95}
96
97impl<T, R> RetryTransientMiddleware<T, R>
98where
99 T: RetryPolicy + Send + Sync,
100 R: RetryableStrategy + Send + Sync,
101{
102 /// This function will try to execute the request, if it fails
103 /// with an error classified as transient it will call itself
104 /// to retry the request.
105 async fn execute_with_retry<'a>(
106 &'a self,
107 req: Request,
108 next: Next<'a>,
109 ext: &'a mut Extensions,
110 ) -> Result<Response> {
111 let mut n_past_retries = 0;
112 loop {
113 // Cloning the request object before-the-fact is not ideal..
114 // However, if the body of the request is not static, e.g of type `Bytes`,
115 // the Clone operation should be of constant complexity and not O(N)
116 // since the byte abstraction is a shared pointer over a buffer.
117 let duplicate_request = req.try_clone().ok_or_else(|| {
118 Error::Middleware(anyhow!(
119 "Request object is not clonable. Are you passing a streaming body?".to_string()
120 ))
121 })?;
122
123 let result = next.clone().run(duplicate_request, ext).await;
124
125 // We classify the response which will return None if not
126 // errors were returned.
127 break match self.retryable_strategy.handle(&result) {
128 Some(Retryable::Transient) => {
129 // If the response failed and the error type was transient
130 // we can safely try to retry the request.
131 let retry_decision = self.retry_policy.should_retry(n_past_retries);
132 if let retry_policies::RetryDecision::Retry { execute_after } = retry_decision {
133 let duration = (execute_after - Utc::now())
134 .to_std()
135 .map_err(Error::middleware)?;
136 // Sleep the requested amount before we try again.
137 tracing::warn!(
138 "Retry attempt #{}. Sleeping {:?} before the next attempt",
139 n_past_retries,
140 duration
141 );
142 #[cfg(not(target_arch = "wasm32"))]
143 tokio::time::sleep(duration).await;
144 #[cfg(target_arch = "wasm32")]
145 wasm_timer::Delay::new(duration)
146 .await
147 .expect("failed sleeping");
148
149 n_past_retries += 1;
150 continue;
151 } else {
152 result
153 }
154 }
155 Some(_) | None => result,
156 };
157 }
158 }
159}