reqwest_retry/
middleware.rs

1//! `RetryTransientMiddleware` implements retrying requests on transient errors.
2use crate::retryable_strategy::RetryableStrategy;
3use crate::{retryable::Retryable, retryable_strategy::DefaultRetryableStrategy};
4use anyhow::anyhow;
5use chrono::Utc;
6use reqwest::{Request, Response};
7use reqwest_middleware::{Error, Middleware, Next, Result};
8use retry_policies::RetryPolicy;
9use task_local_extensions::Extensions;
10
11/// `RetryTransientMiddleware` offers retry logic for requests that fail in a transient manner
12/// and can be safely executed again.
13///
14/// Currently, it allows setting a [RetryPolicy][retry_policies::RetryPolicy] algorithm for calculating the __wait_time__
15/// between each request retry.
16///
17///```rust
18///     use reqwest_middleware::ClientBuilder;
19///     use reqwest_retry::{RetryTransientMiddleware, policies::ExponentialBackoff};
20///     use reqwest::Client;
21///
22///     // We create a ExponentialBackoff retry policy which implements `RetryPolicy`.
23///     let retry_policy = ExponentialBackoff {
24///         /// How many times the policy will tell the middleware to retry the request.
25///         max_n_retries: 3,
///         max_retry_interval: std::time::Duration::from_millis(100),
///         min_retry_interval: std::time::Duration::from_millis(30),
28///         backoff_exponent: 2,
29///     };
30///
31///     let retry_transient_middleware = RetryTransientMiddleware::new_with_policy(retry_policy);
32///     let client = ClientBuilder::new(Client::new()).with(retry_transient_middleware).build();
33///```
34///
35/// # Note
36///
37/// This middleware always errors when given requests with streaming bodies, before even executing
38/// the request. When this happens you'll get an [`Error::Middleware`] with the message
39/// 'Request object is not clonable. Are you passing a streaming body?'.
40///
41/// Some workaround suggestions:
42/// * If you can fit the data in memory, you can instead build static request bodies e.g. with
43/// `Body`'s `From<String>` or `From<Bytes>` implementations.
44/// * You can wrap this middleware in a custom one which skips retries for streaming requests.
45/// * You can write a custom retry middleware that builds new streaming requests from the data
46/// source directly, avoiding the issue of streaming requests not being clonable.
47pub struct RetryTransientMiddleware<
48    T: RetryPolicy + Send + Sync + 'static,
49    R: RetryableStrategy + Send + Sync + 'static = DefaultRetryableStrategy,
50> {
51    retry_policy: T,
52    retryable_strategy: R,
53}
54
55impl<T: RetryPolicy + Send + Sync> RetryTransientMiddleware<T, DefaultRetryableStrategy> {
56    /// Construct `RetryTransientMiddleware` with  a [retry_policy][RetryPolicy].
57    pub fn new_with_policy(retry_policy: T) -> Self {
58        Self::new_with_policy_and_strategy(retry_policy, DefaultRetryableStrategy)
59    }
60}
61
62impl<T, R> RetryTransientMiddleware<T, R>
63where
64    T: RetryPolicy + Send + Sync,
65    R: RetryableStrategy + Send + Sync,
66{
67    /// Construct `RetryTransientMiddleware` with  a [retry_policy][RetryPolicy] and [retryable_strategy](RetryableStrategy).
68    pub fn new_with_policy_and_strategy(retry_policy: T, retryable_strategy: R) -> Self {
69        Self {
70            retry_policy,
71            retryable_strategy,
72        }
73    }
74}
75
76#[cfg_attr(not(target_arch = "wasm32"), async_trait::async_trait)]
77#[cfg_attr(target_arch = "wasm32", async_trait::async_trait(?Send))]
78impl<T, R> Middleware for RetryTransientMiddleware<T, R>
79where
80    T: RetryPolicy + Send + Sync,
81    R: RetryableStrategy + Send + Sync + 'static,
82{
83    async fn handle(
84        &self,
85        req: Request,
86        extensions: &mut Extensions,
87        next: Next<'_>,
88    ) -> Result<Response> {
89        // TODO: Ideally we should create a new instance of the `Extensions` map to pass
90        // downstream. This will guard against previous retries poluting `Extensions`.
91        // That is, we only return what's populated in the typemap for the last retry attempt
92        // and copy those into the the `global` Extensions map.
93        self.execute_with_retry(req, next, extensions).await
94    }
95}
96
97impl<T, R> RetryTransientMiddleware<T, R>
98where
99    T: RetryPolicy + Send + Sync,
100    R: RetryableStrategy + Send + Sync,
101{
102    /// This function will try to execute the request, if it fails
103    /// with an error classified as transient it will call itself
104    /// to retry the request.
105    async fn execute_with_retry<'a>(
106        &'a self,
107        req: Request,
108        next: Next<'a>,
109        ext: &'a mut Extensions,
110    ) -> Result<Response> {
111        let mut n_past_retries = 0;
112        loop {
113            // Cloning the request object before-the-fact is not ideal..
114            // However, if the body of the request is not static, e.g of type `Bytes`,
115            // the Clone operation should be of constant complexity and not O(N)
116            // since the byte abstraction is a shared pointer over a buffer.
117            let duplicate_request = req.try_clone().ok_or_else(|| {
118                Error::Middleware(anyhow!(
119                    "Request object is not clonable. Are you passing a streaming body?".to_string()
120                ))
121            })?;
122
123            let result = next.clone().run(duplicate_request, ext).await;
124
125            // We classify the response which will return None if not
126            // errors were returned.
127            break match self.retryable_strategy.handle(&result) {
128                Some(Retryable::Transient) => {
129                    // If the response failed and the error type was transient
130                    // we can safely try to retry the request.
131                    let retry_decision = self.retry_policy.should_retry(n_past_retries);
132                    if let retry_policies::RetryDecision::Retry { execute_after } = retry_decision {
133                        let duration = (execute_after - Utc::now())
134                            .to_std()
135                            .map_err(Error::middleware)?;
136                        // Sleep the requested amount before we try again.
137                        tracing::warn!(
138                            "Retry attempt #{}. Sleeping {:?} before the next attempt",
139                            n_past_retries,
140                            duration
141                        );
142                        #[cfg(not(target_arch = "wasm32"))]
143                        tokio::time::sleep(duration).await;
144                        #[cfg(target_arch = "wasm32")]
145                        wasm_timer::Delay::new(duration)
146                            .await
147                            .expect("failed sleeping");
148
149                        n_past_retries += 1;
150                        continue;
151                    } else {
152                        result
153                    }
154                }
155                Some(_) | None => result,
156            };
157        }
158    }
159}