retry_policies/policies/exponential_backoff.rs
use crate::{RetryDecision, RetryPolicy};
use chrono::Utc;
use rand::distributions::uniform::{UniformFloat, UniformSampler};
use std::{cmp, time::Duration};

const MIN_JITTER: f64 = 0.0;
const MAX_JITTER: f64 = 3.0;

/// A retry policy that waits an exponentially growing, randomly jittered amount
/// of time between attempts, up to a maximum number of retries.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct ExponentialBackoff {
    /// Maximum number of allowed retries.
    pub max_n_retries: u32,
    /// Minimum waiting time between two retry attempts, before jitter is applied.
    pub min_retry_interval: Duration,
    /// Maximum waiting time between two retry attempts.
    pub max_retry_interval: Duration,
    /// Base of the exponential growth: the unjittered wait is
    /// `min_retry_interval * backoff_exponent^n_past_retries`.
    pub backoff_exponent: u32,
}

impl ExponentialBackoff {
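    /// Returns a builder for [`ExponentialBackoff`].
    ///
    /// Illustrative usage sketch; the `retry_policies::policies` import path is
    /// assumed from this file's location and may differ in your setup.
    ///
    /// ```ignore
    /// use retry_policies::policies::ExponentialBackoff;
    /// use std::time::Duration;
    ///
    /// // Wait between 1 s and 60 s (before jitter); with the default backoff
    /// // exponent of 3 the unjittered wait triples on every retry, and the
    /// // policy gives up after 5 retries.
    /// let policy = ExponentialBackoff::builder()
    ///     .retry_bounds(Duration::from_secs(1), Duration::from_secs(60))
    ///     .build_with_max_retries(5);
    /// ```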
    pub fn builder() -> ExponentialBackoffBuilder {
        <_>::default()
    }
}

impl RetryPolicy for ExponentialBackoff {
    fn should_retry(&self, n_past_retries: u32) -> RetryDecision {
        if n_past_retries >= self.max_n_retries {
            RetryDecision::DoNotRetry
        } else {
            // Exponential backoff: min_retry_interval * backoff_exponent^n_past_retries,
            // saturating instead of panicking if the exponentiation overflows.
            let unjittered_wait_for = self.min_retry_interval
                * self
                    .backoff_exponent
                    .checked_pow(n_past_retries)
                    .unwrap_or(u32::MAX);
            // Multiply by a random factor in [MIN_JITTER, MAX_JITTER) to spread out retries.
            let jitter_factor =
                UniformFloat::<f64>::sample_single(MIN_JITTER, MAX_JITTER, &mut rand::thread_rng());
            let jittered_wait_for = unjittered_wait_for.mul_f64(jitter_factor);

            // Never schedule the retry further away than `max_retry_interval`.
            let execute_after =
                Utc::now() + clip_and_convert(jittered_wait_for, self.max_retry_interval);
            RetryDecision::Retry { execute_after }
        }
    }
}

// Clamps `duration` to `max_duration` and converts it to a `chrono::Duration`
// so it can be added to a `DateTime<Utc>`.
fn clip_and_convert(duration: Duration, max_duration: Duration) -> chrono::Duration {
    chrono::Duration::from_std(cmp::min(duration, max_duration)).unwrap()
}

/// Builder for [`ExponentialBackoff`].
pub struct ExponentialBackoffBuilder {
    min_retry_interval: Duration,
    max_retry_interval: Duration,
    backoff_exponent: u32,
}

impl Default for ExponentialBackoffBuilder {
    fn default() -> Self {
        Self {
            min_retry_interval: Duration::from_secs(1),
            max_retry_interval: Duration::from_secs(30 * 60),
            backoff_exponent: 3,
        }
    }
}

impl ExponentialBackoffBuilder {
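    /// Sets the minimum and maximum waiting time bounds used when computing retry intervals.
    ///
    /// # Panics
    ///
    /// Panics if `min_retry_interval` is greater than `max_retry_interval`.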
    pub fn retry_bounds(
        mut self,
        min_retry_interval: Duration,
        max_retry_interval: Duration,
    ) -> Self {
        assert!(
            min_retry_interval <= max_retry_interval,
            "The maximum interval between retries must be greater than or equal to the minimum retry interval."
        );
        self.min_retry_interval = min_retry_interval;
        self.max_retry_interval = max_retry_interval;
        self
    }

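    /// Sets the exponent governing how fast the unjittered retry interval grows
    /// with the number of past retries.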
    pub fn backoff_exponent(mut self, exponent: u32) -> Self {
        self.backoff_exponent = exponent;
        self
    }

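    /// Builds an [`ExponentialBackoff`] policy that gives up after `n` retries.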
    pub fn build_with_max_retries(self, n: u32) -> ExponentialBackoff {
        ExponentialBackoff {
            min_retry_interval: self.min_retry_interval,
            max_retry_interval: self.max_retry_interval,
            backoff_exponent: self.backoff_exponent,
            max_n_retries: n,
        }
    }

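    /// Builds an [`ExponentialBackoff`] policy whose maximum number of retries is
    /// chosen so that, using the mean jitter factor, the expected waiting time
    /// across all retries adds up to roughly `total_duration`.
    ///
    /// Illustrative usage sketch; the import path is assumed from this file's location.
    ///
    /// ```ignore
    /// use retry_policies::policies::ExponentialBackoff;
    /// use std::time::Duration;
    ///
    /// // With the default bounds (1 s to 30 min) and exponent (3), keep retrying
    /// // until roughly 24 hours of expected waiting have accumulated.
    /// let policy = ExponentialBackoff::builder()
    ///     .build_with_total_retry_duration(Duration::from_secs(24 * 60 * 60));
    /// ```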
    pub fn build_with_total_retry_duration(self, total_duration: Duration) -> ExponentialBackoff {
        let mut out = self.build_with_max_retries(0);

        const MEAN_JITTER: f64 = (MIN_JITTER + MAX_JITTER) / 2.0;

        // Expected (mean-jitter) delay before the n-th retry, clamped to the maximum interval.
        let delays = (0u32..).map(|n| {
            let min_interval = out.min_retry_interval;
            let backoff_factor = out.backoff_exponent.checked_pow(n).unwrap_or(u32::MAX);
            let n_delay = (min_interval * backoff_factor).mul_f64(MEAN_JITTER);
            cmp::min(n_delay, out.max_retry_interval)
        });

        // Accumulate expected delays until the total duration is covered.
        let mut approx_total = Duration::from_secs(0);
        for (n, delay) in delays.enumerate() {
            approx_total += delay;
            if approx_total >= total_duration {
                out.max_n_retries = (n + 1) as _;
                break;
            } else if delay == out.max_retry_interval {
                // The delay has saturated at `max_retry_interval`: all remaining retries
                // are equally long, so their number can be computed directly.
                let remaining_s = (total_duration - approx_total).as_secs_f64();
                let additional_tries = (remaining_s / delay.as_secs_f64()).ceil() as usize;
                out.max_n_retries = (n + 1 + additional_tries) as _;
                break;
            }
        }

        out
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use fake::Fake;

    fn get_retry_policy() -> ExponentialBackoff {
        ExponentialBackoff {
            max_n_retries: 6,
            min_retry_interval: Duration::from_secs(1),
            max_retry_interval: Duration::from_secs(5 * 60),
            backoff_exponent: 3,
        }
    }

    #[test]
    fn if_n_past_retries_is_below_maximum_it_decides_to_retry() {
        let policy = get_retry_policy();
        let n_past_retries = (0..policy.max_n_retries).fake();
        assert!(n_past_retries < policy.max_n_retries);

        let decision = policy.should_retry(n_past_retries);

        assert!(matches!(decision, RetryDecision::Retry { .. }));
    }

    #[test]
    fn if_n_past_retries_is_above_maximum_it_decides_to_mark_as_failed() {
        let policy = get_retry_policy();
        let n_past_retries = (policy.max_n_retries..).fake();
        assert!(n_past_retries >= policy.max_n_retries);

        let decision = policy.should_retry(n_past_retries);

        assert!(matches!(decision, RetryDecision::DoNotRetry));
    }

    #[test]
    fn maximum_retry_interval_is_never_exceeded() {
        let policy = get_retry_policy();
        let max_interval = chrono::Duration::from_std(policy.max_retry_interval).unwrap();

        let decision = policy.should_retry(policy.max_n_retries - 1);

        match decision {
            RetryDecision::Retry { execute_after } => {
                assert!((execute_after - Utc::now()) <= max_interval)
            }
            RetryDecision::DoNotRetry => panic!("Expected Retry decision."),
        }
    }

    #[test]
    fn overflow_backoff_exponent_does_not_cause_a_panic() {
        let policy = ExponentialBackoff {
            max_n_retries: u32::MAX,
            backoff_exponent: 2,
            ..get_retry_policy()
        };
        let max_interval = chrono::Duration::from_std(policy.max_retry_interval).unwrap();
        let n_failed_attempts = u32::MAX - 1;

        let decision = policy.should_retry(n_failed_attempts);

        match decision {
            RetryDecision::Retry { execute_after } => {
                assert!((execute_after - Utc::now()) <= max_interval)
            }
            RetryDecision::DoNotRetry => panic!("Expected Retry decision."),
        }
    }

    #[test]
    #[should_panic]
    fn builder_invalid_retry_bounds() {
        ExponentialBackoff::builder().retry_bounds(Duration::from_secs(3), Duration::from_secs(2));
    }
}