retry_policies/
retry_policy.rs1use std::time::{Duration, SystemTime};
2
3use rand::{
4 distr::uniform::{UniformFloat, UniformSampler},
5 Rng,
6};
7
8pub trait RetryPolicy {
10 fn should_retry(&self, request_start_time: SystemTime, n_past_retries: u32) -> RetryDecision;
12}
13
/// Outcome of a [`RetryPolicy::should_retry`] call.
///
/// Derives the same value-type traits as `Jitter` so decisions can be copied
/// and compared (e.g. with `assert_eq!` in tests).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum RetryDecision {
    /// Retry the task after the `execute_after` timestamp.
    Retry { execute_after: SystemTime },
    /// Do not retry the task.
    DoNotRetry,
}
22
/// Strategy used to randomise a retry interval.
#[non_exhaustive]
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Jitter {
    /// Keep the computed interval exactly as it is.
    None,
    /// Scale the computed interval by a uniform random factor drawn from `0.0..1.0`.
    Full,
    /// Pick a value between half of the minimum interval and the computed interval.
    Bounded,
}
34
35impl Jitter {
36 const BOUNDED_MIN_BOUND_FRACTION: f64 = 0.5;
39
40 pub(crate) fn apply(
41 &self,
42 interval: Duration,
43 min_interval: Duration,
44 rng: &mut impl Rng,
45 ) -> Duration {
46 match self {
47 Jitter::None => interval,
48 Jitter::Full => {
49 let jitter_factor = UniformFloat::<f64>::sample_single(0.0, 1.0, rng)
50 .expect("Sample range should be valid");
51
52 interval.mul_f64(jitter_factor)
53 }
54 Jitter::Bounded => {
55 let jitter_factor = UniformFloat::<f64>::sample_single(0.0, 1.0, rng)
56 .expect("Sample range should be valid");
57
58 let jittered_wait_for = (interval
59 - min_interval.mul_f64(Self::BOUNDED_MIN_BOUND_FRACTION))
60 .mul_f64(jitter_factor);
61
62 jittered_wait_for + min_interval.mul_f64(Self::BOUNDED_MIN_BOUND_FRACTION)
63 }
64 }
65 }
66}
67
#[cfg(test)]
mod tests {
    use rand::{rngs::StdRng, SeedableRng};

    use super::*;
    use std::time::Duration;

    const SEED: u64 = 3097268606784207815;

    /// Number of draws the range-checking tests take from one seeded RNG.
    const SAMPLES: usize = 1_000;

    /// Deterministic RNG so every run exercises the same draws (reproducible failures).
    fn seeded_rng() -> StdRng {
        StdRng::seed_from_u64(SEED)
    }

    #[test]
    fn test_jitter_none() {
        let jitter = Jitter::None;
        let min_interval = Duration::from_secs(5);
        let interval = Duration::from_secs(10);
        assert_eq!(
            jitter.apply(interval, min_interval, &mut seeded_rng()),
            interval,
        );
    }

    #[test]
    fn test_jitter_full() {
        let jitter = Jitter::Full;
        let min_interval = Duration::from_secs(5);
        let interval = Duration::from_secs(10);
        let mut rng = seeded_rng();
        for _ in 0..SAMPLES {
            let result = jitter.apply(interval, min_interval, &mut rng);
            // `Duration` is unsigned, so only the upper bound needs checking.
            assert!(result <= interval, "{result:?} exceeds {interval:?}");
        }
    }

    #[test]
    fn test_jitter_bounded() {
        let jitter = Jitter::Bounded;
        let min_interval = Duration::from_secs(5);
        let interval = Duration::from_secs(10);
        let lower = min_interval.mul_f64(Jitter::BOUNDED_MIN_BOUND_FRACTION);
        let mut rng = seeded_rng();
        for _ in 0..SAMPLES {
            let result = jitter.apply(interval, min_interval, &mut rng);
            assert!(
                lower <= result && result <= interval,
                "{result:?} outside [{lower:?}, {interval:?}]"
            );
        }
    }

    #[test]
    fn test_jitter_bounded_first_retry() {
        let jitter = Jitter::Bounded;
        let min_interval = Duration::from_secs(1);
        let interval = min_interval;
        let mut rng = seeded_rng();
        let result = jitter.apply(interval, min_interval, &mut rng);
        assert!(
            result < interval,
            "should have jittered to below the min interval"
        );
        assert_eq!(result, Duration::from_nanos(708_215_236));
    }
}