governor/
middleware.rs

1//! Additional, customizable behavior for rate limiters.
2//!
3//! Rate-limiting middleware follows the principle that basic
4//! rate-limiting should be very cheap, and unless users desire more
5//! behavior, they should not pay any extra price.
6//!
7//! However, if you do desire more information about what the
8//! rate-limiter does (or the ability to install hooks in the
9//! decision-making process), you can. The [`RateLimitingMiddleware`]
10//! trait in this module allows you to customize:
11//!
12//! * Any additional code that gets run when a rate-limiting decision is made.
13//! * What value is returned in the positive or negative case.
14//!
15//! Writing middleware does **not** let you override rate-limiting
16//! decisions: They remain either positive (returning `Ok`) or negative
17//! (returning `Err`). However, you can override the values returned
18//! inside the Result for either decision.
19//!
20//! This crate ships two middlewares (named after their behavior in the
21//! positive outcome):
22//!
23//! * The cheapest still-useful one, [`NoOpMiddleware`], named after its
24//!   behavior in the positive case. In the positive case it returns
25//!   `Ok(())`; in the negative case, `Err(`[`NotUntil`]`)`.
26//!
27//! * A more informative middleware, [`StateInformationMiddleware`], which
28//!   returns `Ok(`[`StateSnapshot`]`)`, or
29//!   `Err(`[`NotUntil`]`)`.
30//!
31//! ## Using a custom middleware
32//!
33//! Middlewares are attached to the
34//! [`RateLimiter`][crate::RateLimiter] at construction time using
35//! [`RateLimiter::with_middleware`][crate::RateLimiter::with_middleware]:
36//!
37//! ```rust
38//! # #[cfg(feature = "std")]
39//! # fn main () {
40//! # use nonzero_ext::nonzero;
41//! use governor::{RateLimiter, Quota, middleware::StateInformationMiddleware};
42//! let lim = RateLimiter::direct(Quota::per_hour(nonzero!(1_u32)))
43//!     .with_middleware::<StateInformationMiddleware>();
44//!
45//! // A positive outcome with additional information:
46//! assert!(
47//!     lim.check()
48//!         // Here we receive an Ok(StateSnapshot):
49//!         .map(|outcome| assert_eq!(outcome.remaining_burst_capacity(), 0))
50//!         .is_ok()
51//! );
52//!
53//! // The negative case:
54//! assert!(
55//!     lim.check()
56//!         // Here we receive Err(NotUntil):
57//!         .map_err(|outcome| assert_eq!(outcome.quota().burst_size().get(), 1))
58//!         .is_err()
59//! );
60//! # }
61//! # #[cfg(not(feature = "std"))]
62//! # fn main() {}
63//! ```
64//!
65//! You can define your own middleware by `impl`ing [`RateLimitingMiddleware`].
66use core::fmt;
67use std::{cmp, marker::PhantomData};
68
69use crate::{clock, nanos::Nanos, NotUntil, Quota};
70
71/// Information about the rate-limiting state used to reach a decision.
72#[derive(Clone, PartialEq, Eq, Debug)]
73pub struct StateSnapshot {
74    /// The "weight" of a single packet in units of time.
75    t: Nanos,
76
77    /// The "burst capacity" of the bucket.
78    tau: Nanos,
79
80    /// The time at which the measurement was taken.
81    pub(crate) time_of_measurement: Nanos,
82
83    /// The next time a cell is expected to arrive
84    pub(crate) tat: Nanos,
85}
86
87impl StateSnapshot {
88    #[inline]
89    pub(crate) fn new(t: Nanos, tau: Nanos, time_of_measurement: Nanos, tat: Nanos) -> Self {
90        Self {
91            t,
92            tau,
93            time_of_measurement,
94            tat,
95        }
96    }
97
98    /// Returns the quota used to make the rate limiting decision.
99    pub fn quota(&self) -> Quota {
100        Quota::from_gcra_parameters(self.t, self.tau)
101    }
102
103    /// Returns the number of cells that can be let through in
104    /// addition to a (possible) positive outcome.
105    ///
106    /// If this state snapshot is based on a negative rate limiting
107    /// outcome, this method returns 0.
108    pub fn remaining_burst_capacity(&self) -> u32 {
109        let t0 = self.time_of_measurement + self.t;
110        (cmp::min(
111            (t0 + self.tau).saturating_sub(self.tat).as_u64(),
112            self.tau.as_u64(),
113        ) / self.t.as_u64()) as u32
114    }
115}
116
117/// Defines the behavior and return values of rate limiting decisions.
118///
119/// While the rate limiter defines whether a decision is positive, the
120/// middleware defines what additional values (other than `Ok` or `Err`)
121/// are returned from the [`RateLimiter`][crate::RateLimiter]'s check methods.
122///
123/// The default middleware in this crate is [`NoOpMiddleware`] (which does
124/// nothing in the positive case and returns [`NotUntil`] in the
125/// negative) - so it does only the smallest amount of work it needs to do
126/// in order to be useful to users.
127///
128/// Other middleware gets to adjust these trade-offs: The pre-made
129/// [`StateInformationMiddleware`] returns quota and burst capacity
130/// information, while custom middleware could return a set of HTTP
131/// headers or increment counters per each rate limiter key's decision.
132///
133/// # Defining your own middleware
134///
135/// Here's an example of a rate limiting middleware that does no
136/// computations at all on positive and negative outcomes: All the
137/// information that a caller will receive is that a request should be
138/// allowed or disallowed. This can allow for faster negative outcome
139/// handling, and is useful if you don't need to tell users when they
140/// can try again (or anything at all about their rate limiting
141/// status).
142///
143/// ```rust
144/// # use std::num::NonZeroU32;
145/// # use nonzero_ext::*;
146/// use governor::{middleware::{RateLimitingMiddleware, StateSnapshot},
147///                Quota, RateLimiter, clock::Reference};
148/// # #[cfg(feature = "std")]
149/// # fn main () {
150/// #[derive(Debug)]
151/// struct NullMiddleware;
152///
153/// impl<P: Reference> RateLimitingMiddleware<P> for NullMiddleware {
154///     type PositiveOutcome = ();
155///     type NegativeOutcome = ();
156///
157///     fn allow<K>(_key: &K, _state: impl Into<StateSnapshot>) -> Self::PositiveOutcome {}
158///     fn disallow<K>(_: &K, _: impl Into<StateSnapshot>, _: P) -> Self::NegativeOutcome {}
159/// }
160///
161/// let lim = RateLimiter::direct(Quota::per_hour(nonzero!(1_u32)))
162///     .with_middleware::<NullMiddleware>();
163///
164/// assert_eq!(lim.check(), Ok(()));
165/// assert_eq!(lim.check(), Err(()));
166/// # }
167/// # #[cfg(not(feature = "std"))]
168/// # fn main() {}
169/// ```
170pub trait RateLimitingMiddleware<P: clock::Reference>: fmt::Debug {
171    /// The type that's returned by the rate limiter when a cell is allowed.
172    ///
173    /// By default, rate limiters return `Ok(())`, which does not give
174    /// much information. By using custom middleware, users can obtain
175    /// more information about the rate limiter state that was used to
176    /// come to a decision. That state can then be used to pass
177    /// information downstream about, e.g. how much burst capacity is
178    /// remaining.
179    type PositiveOutcome: Sized;
180
181    /// The type that's returned by the rate limiter when a cell is *not* allowed.
182    ///
183    /// By default, rate limiters return `Err(NotUntil)`, which
184    /// allows interrogating the minimum amount of time to wait until
185    /// a client can expect to have a cell allowed again.
186    type NegativeOutcome: Sized;
187
188    /// Called when a positive rate-limiting decision is made.
189    ///
190    /// This function is able to affect the return type of
191    /// [RateLimiter.check](../struct.RateLimiter.html#method.check)
192    /// (and others) in the Ok case: Whatever is returned here is the
193    /// value of the Ok result returned from the check functions.
194    ///
195    /// The function is passed a snapshot of the rate-limiting state
196    /// updated to *after* the decision was reached: E.g., if there
197    /// was one cell left in the burst capacity before the decision
198    /// was reached, the [`StateSnapshot::remaining_burst_capacity`]
199    /// method will return 0.
200    fn allow<K>(key: &K, state: impl Into<StateSnapshot>) -> Self::PositiveOutcome;
201
202    /// Called when a negative rate-limiting decision is made (the
203    /// "not allowed but OK" case).
204    ///
205    /// This method returns whatever value is returned inside the
206    /// `Err` variant a [`RateLimiter`][crate::RateLimiter]'s check
207    /// method returns.
208    fn disallow<K>(
209        key: &K,
210        limiter: impl Into<StateSnapshot>,
211        start_time: P,
212    ) -> Self::NegativeOutcome;
213}
214
215/// A middleware that does nothing and returns `()` in the positive outcome.
216pub struct NoOpMiddleware<P: clock::Reference = <clock::DefaultClock as clock::Clock>::Instant> {
217    phantom: PhantomData<P>,
218}
219
220impl<P: clock::Reference> std::fmt::Debug for NoOpMiddleware<P> {
221    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
222        write!(f, "NoOpMiddleware")
223    }
224}
225
226impl<P: clock::Reference> RateLimitingMiddleware<P> for NoOpMiddleware<P> {
227    /// By default, rate limiters return nothing other than an
228    /// indicator that the element should be let through.
229    type PositiveOutcome = ();
230
231    type NegativeOutcome = NotUntil<P>;
232
233    #[inline]
234    /// Returns `()` and has no side-effects.
235    fn allow<K>(_key: &K, _state: impl Into<StateSnapshot>) -> Self::PositiveOutcome {}
236
237    #[inline]
238    /// Returns the error indicating what
239    fn disallow<K>(
240        _key: &K,
241        state: impl Into<StateSnapshot>,
242        start_time: P,
243    ) -> Self::NegativeOutcome {
244        NotUntil::new(state.into(), start_time)
245    }
246}
247
248/// Middleware that returns the state of the rate limiter if a
249/// positive decision is reached.
250#[derive(Debug)]
251pub struct StateInformationMiddleware;
252
253impl<P: clock::Reference> RateLimitingMiddleware<P> for StateInformationMiddleware {
254    /// The state snapshot returned from the limiter.
255    type PositiveOutcome = StateSnapshot;
256
257    type NegativeOutcome = NotUntil<P>;
258
259    fn allow<K>(_key: &K, state: impl Into<StateSnapshot>) -> Self::PositiveOutcome {
260        state.into()
261    }
262
263    fn disallow<K>(
264        _key: &K,
265        state: impl Into<StateSnapshot>,
266        start_time: P,
267    ) -> Self::NegativeOutcome {
268        NotUntil::new(state.into(), start_time)
269    }
270}
271
#[cfg(all(feature = "std", test))]
mod test {
    use std::time::Duration;

    use super::*;

    /// Both middlewares' `Debug` output is just their type name.
    #[test]
    fn middleware_impl_derives() {
        let state_info = format!("{:?}", StateInformationMiddleware);
        assert_eq!(state_info, "StateInformationMiddleware");

        let no_op: NoOpMiddleware<Duration> = NoOpMiddleware {
            phantom: PhantomData,
        };
        assert_eq!(format!("{:?}", no_op), "NoOpMiddleware");
    }
}