governor/middleware.rs
1//! Additional, customizable behavior for rate limiters.
2//!
3//! Rate-limiting middleware follows the principle that basic
4//! rate-limiting should be very cheap, and unless users desire more
5//! behavior, they should not pay any extra price.
6//!
7//! However, if you do desire more information about what the
8//! rate-limiter does (or the ability to install hooks in the
9//! decision-making process), you can. The [`RateLimitingMiddleware`]
10//! trait in this module allows you to customize:
11//!
12//! * Any additional code that gets run when a rate-limiting decision is made.
13//! * What value is returned in the positive or negative case.
14//!
15//! Writing middleware does **not** let you override rate-limiting
16//! decisions: They remain either positive (returning `Ok`) or negative
17//! (returning `Err`). However, you can override the values returned
18//! inside the Result for either decision.
19//!
//! This crate ships two middlewares (named after their behavior in the
//! positive outcome):
//!
//! * The cheapest still-useful one, [`NoOpMiddleware`], returns `Ok(())`
//!   in the positive case and `Err(`[`NotUntil`]`)` in the negative case.
//!
//! * A more informative middleware, [`StateInformationMiddleware`], returns
//!   `Ok(`[`StateSnapshot`]`)` in the positive case and
//!   `Err(`[`NotUntil`]`)` in the negative case.
30//!
31//! ## Using a custom middleware
32//!
33//! Middlewares are attached to the
34//! [`RateLimiter`][crate::RateLimiter] at construction time using
35//! [`RateLimiter::with_middleware`][crate::RateLimiter::with_middleware]:
36//!
37//! ```rust
38//! # #[cfg(feature = "std")]
39//! # fn main () {
40//! # use nonzero_ext::nonzero;
41//! use governor::{RateLimiter, Quota, middleware::StateInformationMiddleware};
42//! let lim = RateLimiter::direct(Quota::per_hour(nonzero!(1_u32)))
43//! .with_middleware::<StateInformationMiddleware>();
44//!
45//! // A positive outcome with additional information:
46//! assert!(
47//! lim.check()
48//! // Here we receive an Ok(StateSnapshot):
49//! .map(|outcome| assert_eq!(outcome.remaining_burst_capacity(), 0))
50//! .is_ok()
51//! );
52//!
53//! // The negative case:
54//! assert!(
55//! lim.check()
56//! // Here we receive Err(NotUntil):
57//! .map_err(|outcome| assert_eq!(outcome.quota().burst_size().get(), 1))
58//! .is_err()
59//! );
60//! # }
61//! # #[cfg(not(feature = "std"))]
62//! # fn main() {}
63//! ```
64//!
65//! You can define your own middleware by `impl`ing [`RateLimitingMiddleware`].
66use core::fmt;
67use std::{cmp, marker::PhantomData};
68
69use crate::{clock, nanos::Nanos, NotUntil, Quota};
70
71/// Information about the rate-limiting state used to reach a decision.
72#[derive(Clone, PartialEq, Eq, Debug)]
73pub struct StateSnapshot {
74 /// The "weight" of a single packet in units of time.
75 t: Nanos,
76
77 /// The "burst capacity" of the bucket.
78 tau: Nanos,
79
80 /// The time at which the measurement was taken.
81 pub(crate) time_of_measurement: Nanos,
82
83 /// The next time a cell is expected to arrive
84 pub(crate) tat: Nanos,
85}
86
87impl StateSnapshot {
88 #[inline]
89 pub(crate) fn new(t: Nanos, tau: Nanos, time_of_measurement: Nanos, tat: Nanos) -> Self {
90 Self {
91 t,
92 tau,
93 time_of_measurement,
94 tat,
95 }
96 }
97
98 /// Returns the quota used to make the rate limiting decision.
99 pub fn quota(&self) -> Quota {
100 Quota::from_gcra_parameters(self.t, self.tau)
101 }
102
103 /// Returns the number of cells that can be let through in
104 /// addition to a (possible) positive outcome.
105 ///
106 /// If this state snapshot is based on a negative rate limiting
107 /// outcome, this method returns 0.
108 pub fn remaining_burst_capacity(&self) -> u32 {
109 let t0 = self.time_of_measurement + self.t;
110 (cmp::min(
111 (t0 + self.tau).saturating_sub(self.tat).as_u64(),
112 self.tau.as_u64(),
113 ) / self.t.as_u64()) as u32
114 }
115}
116
117/// Defines the behavior and return values of rate limiting decisions.
118///
119/// While the rate limiter defines whether a decision is positive, the
120/// middleware defines what additional values (other than `Ok` or `Err`)
121/// are returned from the [`RateLimiter`][crate::RateLimiter]'s check methods.
122///
123/// The default middleware in this crate is [`NoOpMiddleware`] (which does
124/// nothing in the positive case and returns [`NotUntil`] in the
125/// negative) - so it does only the smallest amount of work it needs to do
126/// in order to be useful to users.
127///
128/// Other middleware gets to adjust these trade-offs: The pre-made
129/// [`StateInformationMiddleware`] returns quota and burst capacity
130/// information, while custom middleware could return a set of HTTP
131/// headers or increment counters per each rate limiter key's decision.
132///
133/// # Defining your own middleware
134///
135/// Here's an example of a rate limiting middleware that does no
136/// computations at all on positive and negative outcomes: All the
137/// information that a caller will receive is that a request should be
138/// allowed or disallowed. This can allow for faster negative outcome
139/// handling, and is useful if you don't need to tell users when they
140/// can try again (or anything at all about their rate limiting
141/// status).
142///
143/// ```rust
144/// # use std::num::NonZeroU32;
145/// # use nonzero_ext::*;
146/// use governor::{middleware::{RateLimitingMiddleware, StateSnapshot},
147/// Quota, RateLimiter, clock::Reference};
148/// # #[cfg(feature = "std")]
149/// # fn main () {
150/// #[derive(Debug)]
151/// struct NullMiddleware;
152///
153/// impl<P: Reference> RateLimitingMiddleware<P> for NullMiddleware {
154/// type PositiveOutcome = ();
155/// type NegativeOutcome = ();
156///
157/// fn allow<K>(_key: &K, _state: impl Into<StateSnapshot>) -> Self::PositiveOutcome {}
158/// fn disallow<K>(_: &K, _: impl Into<StateSnapshot>, _: P) -> Self::NegativeOutcome {}
159/// }
160///
161/// let lim = RateLimiter::direct(Quota::per_hour(nonzero!(1_u32)))
162/// .with_middleware::<NullMiddleware>();
163///
164/// assert_eq!(lim.check(), Ok(()));
165/// assert_eq!(lim.check(), Err(()));
166/// # }
167/// # #[cfg(not(feature = "std"))]
168/// # fn main() {}
169/// ```
170pub trait RateLimitingMiddleware<P: clock::Reference>: fmt::Debug {
171 /// The type that's returned by the rate limiter when a cell is allowed.
172 ///
173 /// By default, rate limiters return `Ok(())`, which does not give
174 /// much information. By using custom middleware, users can obtain
175 /// more information about the rate limiter state that was used to
176 /// come to a decision. That state can then be used to pass
177 /// information downstream about, e.g. how much burst capacity is
178 /// remaining.
179 type PositiveOutcome: Sized;
180
181 /// The type that's returned by the rate limiter when a cell is *not* allowed.
182 ///
183 /// By default, rate limiters return `Err(NotUntil)`, which
184 /// allows interrogating the minimum amount of time to wait until
185 /// a client can expect to have a cell allowed again.
186 type NegativeOutcome: Sized;
187
188 /// Called when a positive rate-limiting decision is made.
189 ///
190 /// This function is able to affect the return type of
191 /// [RateLimiter.check](../struct.RateLimiter.html#method.check)
192 /// (and others) in the Ok case: Whatever is returned here is the
193 /// value of the Ok result returned from the check functions.
194 ///
195 /// The function is passed a snapshot of the rate-limiting state
196 /// updated to *after* the decision was reached: E.g., if there
197 /// was one cell left in the burst capacity before the decision
198 /// was reached, the [`StateSnapshot::remaining_burst_capacity`]
199 /// method will return 0.
200 fn allow<K>(key: &K, state: impl Into<StateSnapshot>) -> Self::PositiveOutcome;
201
202 /// Called when a negative rate-limiting decision is made (the
203 /// "not allowed but OK" case).
204 ///
205 /// This method returns whatever value is returned inside the
206 /// `Err` variant a [`RateLimiter`][crate::RateLimiter]'s check
207 /// method returns.
208 fn disallow<K>(
209 key: &K,
210 limiter: impl Into<StateSnapshot>,
211 start_time: P,
212 ) -> Self::NegativeOutcome;
213}
214
215/// A middleware that does nothing and returns `()` in the positive outcome.
216pub struct NoOpMiddleware<P: clock::Reference = <clock::DefaultClock as clock::Clock>::Instant> {
217 phantom: PhantomData<P>,
218}
219
220impl<P: clock::Reference> std::fmt::Debug for NoOpMiddleware<P> {
221 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
222 write!(f, "NoOpMiddleware")
223 }
224}
225
226impl<P: clock::Reference> RateLimitingMiddleware<P> for NoOpMiddleware<P> {
227 /// By default, rate limiters return nothing other than an
228 /// indicator that the element should be let through.
229 type PositiveOutcome = ();
230
231 type NegativeOutcome = NotUntil<P>;
232
233 #[inline]
234 /// Returns `()` and has no side-effects.
235 fn allow<K>(_key: &K, _state: impl Into<StateSnapshot>) -> Self::PositiveOutcome {}
236
237 #[inline]
238 /// Returns the error indicating what
239 fn disallow<K>(
240 _key: &K,
241 state: impl Into<StateSnapshot>,
242 start_time: P,
243 ) -> Self::NegativeOutcome {
244 NotUntil::new(state.into(), start_time)
245 }
246}
247
248/// Middleware that returns the state of the rate limiter if a
249/// positive decision is reached.
250#[derive(Debug)]
251pub struct StateInformationMiddleware;
252
253impl<P: clock::Reference> RateLimitingMiddleware<P> for StateInformationMiddleware {
254 /// The state snapshot returned from the limiter.
255 type PositiveOutcome = StateSnapshot;
256
257 type NegativeOutcome = NotUntil<P>;
258
259 fn allow<K>(_key: &K, state: impl Into<StateSnapshot>) -> Self::PositiveOutcome {
260 state.into()
261 }
262
263 fn disallow<K>(
264 _key: &K,
265 state: impl Into<StateSnapshot>,
266 start_time: P,
267 ) -> Self::NegativeOutcome {
268 NotUntil::new(state.into(), start_time)
269 }
270}
271
#[cfg(all(feature = "std", test))]
mod test {
    use std::time::Duration;

    use super::*;

    /// Both middlewares' `Debug` output is just the bare type name.
    #[test]
    fn middleware_impl_derives() {
        let state_info_dbg = format!("{:?}", StateInformationMiddleware);
        assert_eq!(state_info_dbg, "StateInformationMiddleware");

        let no_op: NoOpMiddleware<Duration> = NoOpMiddleware {
            phantom: PhantomData,
        };
        let no_op_dbg = format!("{:?}", no_op);
        assert_eq!(no_op_dbg, "NoOpMiddleware");
    }
}