use crate::state::StateStore;
use crate::InsufficientCapacity;
use crate::{clock, middleware::StateSnapshot, Quota};
use crate::{middleware::RateLimitingMiddleware, nanos::Nanos};
use std::num::NonZeroU32;
use std::time::Duration;
use std::{cmp, fmt};
#[cfg(feature = "std")]
use crate::Jitter;
/// Negative rate-limiting outcome: describes a rejected decision and when a cell could
/// next be accepted.
///
/// The earliest acceptable instant is `start` advanced by the snapshot's theoretical
/// arrival time (`state.tat`).
#[derive(Debug, PartialEq, Eq)]
pub struct NotUntil<P: clock::Reference> {
/// Snapshot of the limiter state captured when the decision was made.
state: StateSnapshot,
/// Reference instant that the snapshot's nanosecond offsets are measured from.
start: P,
}
impl<P: clock::Reference> NotUntil<P> {
#[inline]
pub(crate) fn new(state: StateSnapshot, start: P) -> Self {
Self { state, start }
}
#[inline]
pub fn earliest_possible(&self) -> P {
let tat: Nanos = self.state.tat;
self.start + tat
}
#[inline]
pub fn wait_time_from(&self, from: P) -> Duration {
let earliest = self.earliest_possible();
earliest.duration_since(earliest.min(from)).into()
}
#[inline]
pub fn quota(&self) -> Quota {
self.state.quota()
}
#[cfg(feature = "std")] #[inline]
pub(crate) fn earliest_possible_with_offset(&self, jitter: Jitter) -> P {
let tat = jitter + self.state.tat;
self.start + tat
}
#[cfg(feature = "std")] #[inline]
pub(crate) fn wait_time_with_offset(&self, from: P, jitter: Jitter) -> Duration {
let earliest = self.earliest_possible_with_offset(jitter);
earliest.duration_since(earliest.min(from)).into()
}
}
impl<P: clock::Reference> fmt::Display for NotUntil<P> {
    /// Renders as `rate-limited until <instant>`, where the instant is the earliest
    /// acceptable time (`start` plus the snapshot's TAT) in its `Debug` form.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let until = self.start + self.state.tat;
        write!(f, "rate-limited until {:?}", until)
    }
}
/// Parameters of the generic cell rate algorithm (GCRA) derived from a [`Quota`].
#[derive(Debug, PartialEq, Eq)]
pub(crate) struct Gcra {
/// Replenishment interval for a single cell (the quota's `replenish_1_per`).
t: Nanos,
/// Tolerance window: `t` multiplied by the quota's maximum burst size.
tau: Nanos,
}
impl Gcra {
    /// Builds the GCRA parameters for `quota`.
    ///
    /// `t` is the quota's per-cell replenishment interval; `tau` is that interval multiplied
    /// by the maximum burst size.
    pub(crate) fn new(quota: Quota) -> Self {
        let tau: Nanos = (quota.replenish_1_per * quota.max_burst.get()).into();
        let t: Nanos = quota.replenish_1_per.into();
        Gcra { t, tau }
    }

    /// Theoretical arrival time (TAT) assumed for a key with no stored state: one
    /// replenishment interval after `t0`.
    fn starting_state(&self, t0: Nanos) -> Nanos {
        t0 + self.t
    }

    /// Decides whether a single cell arriving at `t0` conforms, updating `key`'s TAT if so.
    ///
    /// `start` is the reference instant that stored TATs are measured from. On success the
    /// middleware's `allow` outcome is returned and the new TAT (`next`) is handed back to the
    /// state store; otherwise the middleware's `disallow` outcome is returned with a snapshot
    /// whose TAT is the earliest conforming time.
    pub(crate) fn test_and_update<
        K,
        P: clock::Reference,
        S: StateStore<Key = K>,
        MW: RateLimitingMiddleware<P>,
    >(
        &self,
        start: P,
        key: &K,
        state: &S,
        t0: P,
    ) -> Result<MW::PositiveOutcome, MW::NegativeOutcome> {
        let t0 = t0.duration_since(start);
        let tau = self.tau;
        let t = self.t;
        state.measure_and_replace(key, |tat| {
            let tat = tat.unwrap_or_else(|| self.starting_state(t0));
            // A cell conforms when it arrives no earlier than `tau` before the current TAT.
            let earliest_time = tat.saturating_sub(tau);
            if t0 < earliest_time {
                Err(MW::disallow(
                    key,
                    StateSnapshot::new(self.t, self.tau, earliest_time, earliest_time),
                    start,
                ))
            } else {
                // Advance from whichever is later, the old TAT or now, so idle time is not
                // accumulated as extra capacity.
                let next = cmp::max(tat, t0) + t;
                Ok((
                    MW::allow(key, StateSnapshot::new(self.t, self.tau, t0, next)),
                    next,
                ))
            }
        })
    }

    /// Decides whether a batch of `n` cells arriving together at `t0` conforms as a whole,
    /// updating `key`'s TAT by `n * t` if so.
    ///
    /// # Errors
    ///
    /// Returns `Err(InsufficientCapacity(max))` when `n` cells can never fit within `tau`,
    /// where `max` is the number of cells that do fit (`tau / t`). Otherwise returns
    /// `Ok` wrapping the middleware's allow/disallow outcome.
    pub(crate) fn test_n_all_and_update<
        K,
        P: clock::Reference,
        S: StateStore<Key = K>,
        MW: RateLimitingMiddleware<P>,
    >(
        &self,
        start: P,
        key: &K,
        n: NonZeroU32,
        state: &S,
        t0: P,
    ) -> Result<Result<MW::PositiveOutcome, MW::NegativeOutcome>, InsufficientCapacity> {
        let t0 = t0.duration_since(start);
        let tau = self.tau;
        let t = self.t;

        // The batch needs `n * t` of capacity. Compute it with checked arithmetic: for long
        // replenishment intervals and large `n` the product exceeds `u64`, which would panic in
        // debug builds and silently wrap (possibly passing this check) in release builds.
        let batch_fits = match t.as_u64().checked_mul(u64::from(n.get())) {
            Some(needed) => needed <= tau.as_u64(),
            None => false,
        };
        if !batch_fits {
            // Only reachable with `t > 0` (otherwise `needed` is 0 and always fits).
            return Err(InsufficientCapacity((tau.as_u64() / t.as_u64()) as u32));
        }
        // Cannot overflow: (n - 1) * t < n * t <= tau.
        // This is the weight of the cells *in addition* to the first one.
        let additional_weight = t * u64::from(n.get() - 1);

        Ok(state.measure_and_replace(key, |tat| {
            let tat = tat.unwrap_or_else(|| self.starting_state(t0));
            // The whole batch must fit: shift the TAT by the extra cells before applying `tau`.
            let earliest_time = (tat + additional_weight).saturating_sub(tau);
            if t0 < earliest_time {
                Err(MW::disallow(
                    key,
                    StateSnapshot::new(self.t, self.tau, earliest_time, earliest_time),
                    start,
                ))
            } else {
                let next = cmp::max(tat, t0) + t + additional_weight;
                Ok((
                    MW::allow(key, StateSnapshot::new(self.t, self.tau, t0, next)),
                    next,
                ))
            }
        }))
    }
}
// Unit and property tests for the GCRA parameters and the `NotUntil` outcome.
#[cfg(test)]
mod test {
use super::*;
use crate::Quota;
use std::num::NonZeroU32;
use proptest::prelude::*;
// Exercises the derived `PartialEq` and `Debug` impls on `Gcra`.
#[cfg(feature = "std")]
#[test]
fn gcra_derives() {
use all_asserts::assert_gt;
use nonzero_ext::nonzero;
let g = Gcra::new(Quota::per_second(nonzero!(1u32)));
let g2 = Gcra::new(Quota::per_second(nonzero!(2u32)));
assert_eq!(g, g);
assert_ne!(g, g2);
assert_gt!(format!("{:?}", g).len(), 0);
}
// Exercises `NotUntil`'s derives, `Display` output and `quota()` using a fake clock:
// with a 1/s quota the first check succeeds and the immediate second one is rejected.
#[cfg(feature = "std")]
#[test]
fn notuntil_impls() {
use crate::RateLimiter;
use all_asserts::assert_gt;
use clock::FakeRelativeClock;
use nonzero_ext::nonzero;
let clock = FakeRelativeClock::default();
let quota = Quota::per_second(nonzero!(1u32));
let lb = RateLimiter::direct_with_clock(quota, &clock);
assert!(lb.check().is_ok());
assert!(lb
.check()
.map_err(|nu| {
assert_eq!(nu, nu);
assert_gt!(format!("{:?}", nu).len(), 0);
assert_eq!(format!("{}", nu), "rate-limited until Nanos(1s)");
assert_eq!(nu.quota(), quota);
})
.is_err());
}
// Proptest input: a non-zero count in `1..10000`.
#[derive(Debug)]
struct Count(NonZeroU32);
impl Arbitrary for Count {
type Parameters = ();
fn arbitrary_with(_args: ()) -> Self::Strategy {
(1..10000u32)
.prop_map(|x| Count(NonZeroU32::new(x).unwrap()))
.boxed()
}
type Strategy = BoxedStrategy<Count>;
}
// Covers the derived `Debug` impl on the test helper `Count`.
#[cfg(feature = "std")]
#[test]
fn cover_count_derives() {
assert_eq!(
format!("{:?}", Count(nonzero_ext::nonzero!(1_u32))),
"Count(1)"
);
}
// Property: converting a quota to GCRA parameters and back yields the original quota.
#[test]
fn roundtrips_quota() {
proptest!(ProptestConfig::default(), |(per_second: Count, burst: Count)| {
let quota = Quota::per_second(per_second.0).allow_burst(burst.0);
let gcra = Gcra::new(quota);
let back = Quota::from_gcra_parameters(gcra.t, gcra.tau);
assert_eq!(quota, back);
})
}
}