1use crate::state::StateStore;
2use crate::InsufficientCapacity;
3use crate::{clock, middleware::StateSnapshot, Quota};
4use crate::{middleware::RateLimitingMiddleware, nanos::Nanos};
5use std::num::NonZeroU32;
6use std::time::Duration;
7use std::{cmp, fmt};
8
9#[cfg(feature = "std")]
10use crate::Jitter;
11
12#[derive(Debug, PartialEq, Eq)]
17pub struct NotUntil<P: clock::Reference> {
18 state: StateSnapshot,
19 start: P,
20}
21
22impl<P: clock::Reference> NotUntil<P> {
23 #[inline]
25 pub(crate) fn new(state: StateSnapshot, start: P) -> Self {
26 Self { state, start }
27 }
28
29 #[inline]
33 pub fn earliest_possible(&self) -> P {
34 let tat: Nanos = self.state.tat;
35 self.start + tat
36 }
37
38 #[inline]
45 pub fn wait_time_from(&self, from: P) -> Duration {
46 let earliest = self.earliest_possible();
47 earliest.duration_since(earliest.min(from)).into()
48 }
49
50 #[inline]
52 pub fn quota(&self) -> Quota {
53 self.state.quota()
54 }
55
56 #[cfg(feature = "std")] #[inline]
58 pub(crate) fn earliest_possible_with_offset(&self, jitter: Jitter) -> P {
59 let tat = jitter + self.state.tat;
60 self.start + tat
61 }
62
63 #[cfg(feature = "std")] #[inline]
65 pub(crate) fn wait_time_with_offset(&self, from: P, jitter: Jitter) -> Duration {
66 let earliest = self.earliest_possible_with_offset(jitter);
67 earliest.duration_since(earliest.min(from)).into()
68 }
69}
70
71impl<P: clock::Reference> fmt::Display for NotUntil<P> {
72 fn fmt(&self, f: &mut fmt::Formatter) -> Result<(), fmt::Error> {
73 write!(f, "rate-limited until {:?}", self.start + self.state.tat)
74 }
75}
76
77#[derive(Debug, PartialEq, Eq)]
78pub(crate) struct Gcra {
79 t: Nanos,
81
82 tau: Nanos,
84}
85
86impl Gcra {
87 pub(crate) fn new(quota: Quota) -> Self {
88 let tau: Nanos = (cmp::max(quota.replenish_1_per, Duration::from_nanos(1))
89 * quota.max_burst.get())
90 .into();
91 let t: Nanos = quota.replenish_1_per.into();
92 Gcra { t, tau }
93 }
94
95 fn starting_state(&self, t0: Nanos) -> Nanos {
97 t0 + self.t
98 }
99
100 pub(crate) fn test_and_update<
102 K,
103 P: clock::Reference,
104 S: StateStore<Key = K>,
105 MW: RateLimitingMiddleware<P>,
106 >(
107 &self,
108 start: P,
109 key: &K,
110 state: &S,
111 t0: P,
112 ) -> Result<MW::PositiveOutcome, MW::NegativeOutcome> {
113 let t0 = t0.duration_since(start);
114 let tau = self.tau;
115 let t = self.t;
116 state.measure_and_replace(key, |tat| {
117 let tat = tat.unwrap_or_else(|| self.starting_state(t0));
118 let earliest_time = tat.saturating_sub(tau);
119 if t0 < earliest_time {
120 Err(MW::disallow(
121 key,
122 StateSnapshot::new(self.t, self.tau, earliest_time, earliest_time),
123 start,
124 ))
125 } else {
126 let next = cmp::max(tat, t0) + t;
127 Ok((
128 MW::allow(key, StateSnapshot::new(self.t, self.tau, t0, next)),
129 next,
130 ))
131 }
132 })
133 }
134
135 pub(crate) fn test_n_all_and_update<
137 K,
138 P: clock::Reference,
139 S: StateStore<Key = K>,
140 MW: RateLimitingMiddleware<P>,
141 >(
142 &self,
143 start: P,
144 key: &K,
145 n: NonZeroU32,
146 state: &S,
147 t0: P,
148 ) -> Result<Result<MW::PositiveOutcome, MW::NegativeOutcome>, InsufficientCapacity> {
149 let t0 = t0.duration_since(start);
150 let tau = self.tau;
151 let t = self.t;
152 let additional_weight = t * (n.get() - 1) as u64;
153
154 if additional_weight + t > tau {
157 return Err(InsufficientCapacity((tau.as_u64() / t.as_u64()) as u32));
158 }
159 Ok(state.measure_and_replace(key, |tat| {
160 let tat = tat.unwrap_or_else(|| self.starting_state(t0));
161 let earliest_time = (tat + additional_weight).saturating_sub(tau);
162 if t0 < earliest_time {
163 Err(MW::disallow(
164 key,
165 StateSnapshot::new(self.t, self.tau, earliest_time, earliest_time),
166 start,
167 ))
168 } else {
169 let next = cmp::max(tat, t0) + t + additional_weight;
170 Ok((
171 MW::allow(key, StateSnapshot::new(self.t, self.tau, t0, next)),
172 next,
173 ))
174 }
175 }))
176 }
177}
178
#[cfg(test)]
mod test {
    use super::*;
    use crate::Quota;
    use std::num::NonZeroU32;

    use proptest::prelude::*;

    /// Exercises the derived `Debug`, `PartialEq` and `Eq` impls on `Gcra`.
    #[cfg(feature = "std")]
    #[test]
    fn gcra_derives() {
        use all_asserts::assert_gt;
        use nonzero_ext::nonzero;

        let once_per_second = Gcra::new(Quota::per_second(nonzero!(1u32)));
        let twice_per_second = Gcra::new(Quota::per_second(nonzero!(2u32)));

        assert_eq!(once_per_second, once_per_second);
        assert_ne!(once_per_second, twice_per_second);
        assert_gt!(format!("{:?}", once_per_second).len(), 0);
    }

    /// Exercises the trait impls and accessors of a `NotUntil` produced by a
    /// real limiter running on a fake clock.
    #[cfg(feature = "std")]
    #[test]
    fn notuntil_impls() {
        use crate::RateLimiter;
        use all_asserts::assert_gt;
        use clock::FakeRelativeClock;
        use nonzero_ext::nonzero;

        let clock = FakeRelativeClock::default();
        let quota = Quota::per_second(nonzero!(1u32));
        let limiter = RateLimiter::direct_with_clock(quota, &clock);

        // The first cell fits the quota; a second one at the same instant must not.
        assert!(limiter.check().is_ok());
        let rejected = match limiter.check() {
            Ok(_) => panic!("second check at the same instant should be rate-limited"),
            Err(not_until) => not_until,
        };

        assert_eq!(rejected, rejected);
        assert_gt!(format!("{:?}", rejected).len(), 0);
        assert_eq!(format!("{}", rejected), "rate-limited until Nanos(1s)");
        assert_eq!(rejected.quota(), quota);
    }

    /// A proptest-generated positive count drawn from `1..10000`.
    #[derive(Debug)]
    struct Count(NonZeroU32);

    impl Arbitrary for Count {
        type Parameters = ();
        type Strategy = BoxedStrategy<Count>;

        fn arbitrary_with(_args: ()) -> Self::Strategy {
            (1..10000u32)
                .prop_map(|raw| Count(NonZeroU32::new(raw).expect("range starts at 1")))
                .boxed()
        }
    }

    /// Covers the derived `Debug` impl of the test-only `Count` helper.
    #[cfg(feature = "std")]
    #[test]
    fn cover_count_derives() {
        let rendered = format!("{:?}", Count(nonzero_ext::nonzero!(1_u32)));
        assert_eq!(rendered, "Count(1)");
    }

    /// Converting a quota into GCRA parameters and back must be lossless.
    #[test]
    fn roundtrips_quota() {
        proptest!(ProptestConfig::default(), |(per_second: Count, burst: Count)| {
            let original = Quota::per_second(per_second.0).allow_burst(burst.0);
            let params = Gcra::new(original);
            let recovered = Quota::from_gcra_parameters(params.t, params.tau);
            assert_eq!(original, recovered);
        })
    }
}