retry_policies/
retry_policy.rs

1use std::time::{Duration, SystemTime};
2
3use rand::{
4    distr::uniform::{UniformFloat, UniformSampler},
5    Rng,
6};
7
8/// A policy for deciding whether and when to retry.
9pub trait RetryPolicy {
10    /// Determine if a task should be retried according to a retry policy.
11    fn should_retry(&self, request_start_time: SystemTime, n_past_retries: u32) -> RetryDecision;
12}
13
/// Outcome of evaluating a retry policy for a failed task.
///
/// Derives `Clone`, `Copy`, `PartialEq` and `Eq` (the same set as `Jitter`) so
/// decisions can be copied freely and compared with `==`, e.g. in tests.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum RetryDecision {
    /// Retry after the specified `execute_after` timestamp.
    Retry { execute_after: SystemTime },
    /// Give up.
    DoNotRetry,
}
22
/// Strategy for randomising ("jittering") the computed retry intervals.
#[non_exhaustive]
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum Jitter {
    /// Leave the computed interval exactly as it is.
    None,
    /// Pick an interval between 0 and the calculated backoff duration.
    Full,
    /// Pick an interval between 50% of `min_retry_interval` and the calculated
    /// backoff duration.
    Bounded,
}
34
35impl Jitter {
36    /// The lower bound for the calculated interval, as a fraction of the minimum
37    /// interval.
38    const BOUNDED_MIN_BOUND_FRACTION: f64 = 0.5;
39
40    pub(crate) fn apply(
41        &self,
42        interval: Duration,
43        min_interval: Duration,
44        rng: &mut impl Rng,
45    ) -> Duration {
46        match self {
47            Jitter::None => interval,
48            Jitter::Full => {
49                let jitter_factor = UniformFloat::<f64>::sample_single(0.0, 1.0, rng)
50                    .expect("Sample range should be valid");
51
52                interval.mul_f64(jitter_factor)
53            }
54            Jitter::Bounded => {
55                let jitter_factor = UniformFloat::<f64>::sample_single(0.0, 1.0, rng)
56                    .expect("Sample range should be valid");
57
58                let jittered_wait_for = (interval
59                    - min_interval.mul_f64(Self::BOUNDED_MIN_BOUND_FRACTION))
60                .mul_f64(jitter_factor);
61
62                jittered_wait_for + min_interval.mul_f64(Self::BOUNDED_MIN_BOUND_FRACTION)
63            }
64        }
65    }
66}
67
#[cfg(test)]
mod tests {
    use std::time::Duration;

    use rand::{rngs::StdRng, SeedableRng};

    use super::*;

    /// Seed for the deterministic RNG used by the exact-value regression test.
    const SEED: u64 = 3097268606784207815;

    /// `Jitter::None` must hand back the backoff interval untouched.
    #[test]
    fn test_jitter_none() {
        let (min_interval, interval) = (Duration::from_secs(5), Duration::from_secs(10));
        let jittered = Jitter::None.apply(interval, min_interval, &mut rand::rng());
        assert_eq!(jittered, interval);
    }

    /// `Jitter::Full` may land anywhere between zero and the backoff interval.
    #[test]
    fn test_jitter_full() {
        let (min_interval, interval) = (Duration::from_secs(5), Duration::from_secs(10));
        let jittered = Jitter::Full.apply(interval, min_interval, &mut rand::rng());
        assert!((Duration::ZERO..=interval).contains(&jittered));
    }

    /// `Jitter::Bounded` must stay between the bounded floor and the backoff interval.
    #[test]
    fn test_jitter_bounded() {
        let (min_interval, interval) = (Duration::from_secs(5), Duration::from_secs(10));
        let floor = min_interval.mul_f64(Jitter::BOUNDED_MIN_BOUND_FRACTION);
        let jittered = Jitter::Bounded.apply(interval, min_interval, &mut rand::rng());
        assert!((floor..=interval).contains(&jittered));
    }

    /// On the first retry (interval == min interval), bounded jitter can go below
    /// the min interval; the seeded RNG pins the exact value.
    #[test]
    fn test_jitter_bounded_first_retry() {
        let min_interval = Duration::from_secs(1);
        let mut seeded = StdRng::seed_from_u64(SEED);
        let jittered = Jitter::Bounded.apply(min_interval, min_interval, &mut seeded);
        assert!(
            jittered < min_interval,
            "should have jittered to below the min interval"
        );
        assert_eq!(jittered, Duration::from_nanos(708_215_236));
    }
}
122}