// governor/state/in_memory.rs

1use std::prelude::v1::*;
2
3use crate::nanos::Nanos;
4use crate::state::{NotKeyed, StateStore};
5use std::fmt;
6use std::fmt::Debug;
7use std::num::NonZeroU64;
8use std::sync::atomic::Ordering;
9use std::time::Duration;
10
11use portable_atomic::AtomicU64;
12
13/// An in-memory representation of a GCRA's rate-limiting state.
14///
15/// Implemented using [`AtomicU64`] operations, this state representation can be used to
16/// construct rate limiting states for other in-memory states: e.g., this crate uses
17/// `InMemoryState` as the states it tracks in the keyed rate limiters it implements.
18///
19/// Internally, the number tracked here is the theoretical arrival time (a GCRA term) in number of
20/// nanoseconds since the rate limiter was created.
21#[derive(Default)]
22pub struct InMemoryState(AtomicU64);
23
24impl InMemoryState {
25    pub(crate) fn measure_and_replace_one<T, F, E>(&self, mut f: F) -> Result<T, E>
26    where
27        F: FnMut(Option<Nanos>) -> Result<(T, Nanos), E>,
28    {
29        let mut prev = self.0.load(Ordering::Acquire);
30        let mut decision = f(NonZeroU64::new(prev).map(|n| n.get().into()));
31        while let Ok((result, new_data)) = decision {
32            match self.0.compare_exchange_weak(
33                prev,
34                new_data.into(),
35                Ordering::Release,
36                Ordering::Relaxed,
37            ) {
38                Ok(_) => return Ok(result),
39                Err(next_prev) => prev = next_prev,
40            }
41            decision = f(NonZeroU64::new(prev).map(|n| n.get().into()));
42        }
43        // This map shouldn't be needed, as we only get here in the error case, but the compiler
44        // can't see it.
45        decision.map(|(result, _)| result)
46    }
47
48    pub(crate) fn is_older_than(&self, nanos: Nanos) -> bool {
49        self.0.load(Ordering::Relaxed) <= nanos.into()
50    }
51}
52
53/// The InMemoryState is the canonical "direct" state store.
54impl StateStore for InMemoryState {
55    type Key = NotKeyed;
56
57    fn measure_and_replace<T, F, E>(&self, _key: &Self::Key, f: F) -> Result<T, E>
58    where
59        F: Fn(Option<Nanos>) -> Result<(T, Nanos), E>,
60    {
61        self.measure_and_replace_one(f)
62    }
63}
64
65impl Debug for InMemoryState {
66    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> Result<(), fmt::Error> {
67        let d = Duration::from_nanos(self.0.load(Ordering::Relaxed));
68        write!(f, "InMemoryState({:?})", d)
69    }
70}
71
#[cfg(test)]
// The `collect()` in `try_triggering_collisions` is deliberate: every thread must be spawned
// before any of them is joined, otherwise they would run one after another and never collide.
#[allow(clippy::needless_collect)]
mod test {
    use all_asserts::assert_gt;

    use super::*;

    /// Spawns `n_threads` threads that each increment the state `tries_per_thread` times
    /// through `measure_and_replace_one`.
    ///
    /// Returns `(final stored value, total number of times the update closure ran)`.
    /// The closure runs one extra time for every failed compare-and-swap, so the second
    /// number exceeds the first exactly when at least one CAS failed.
    #[cfg(feature = "std")]
    fn try_triggering_collisions(n_threads: u64, tries_per_thread: u64) -> (u64, u64) {
        use std::sync::Arc;
        use std::thread;

        let mut state = Arc::new(InMemoryState(AtomicU64::new(0)));
        let workers: Vec<thread::JoinHandle<u64>> = (0..n_threads)
            .map(|_| {
                let state = Arc::clone(&state);
                thread::spawn(move || {
                    let mut hits = 0;
                    for _ in 0..tries_per_thread {
                        let outcome = state.measure_and_replace_one(|old| {
                            hits += 1;
                            let next = old.map(Nanos::as_u64).unwrap_or(0) + 1;
                            Ok::<((), Nanos), ()>(((), Nanos::from(next)))
                        });
                        assert!(outcome.is_ok());
                    }
                    hits
                })
            })
            .collect();
        let hits: u64 = workers.into_iter().map(|t| t.join().unwrap()).sum();
        // Every worker's clone of the `Arc` has been dropped once all threads are joined.
        let value = Arc::get_mut(&mut state).unwrap().0.get_mut();
        (*value, hits)
    }

    /// Checks that many threads running simultaneously will collide,
    /// but still leave the correct number recorded in the state.
    #[cfg(feature = "std")]
    #[test]
    fn stresstest_collisions() {
        const THREADS: u64 = 8;
        const MAX_TRIES: u64 = 20_000_000;
        const STEP: u64 = MAX_TRIES / 100;

        let (mut value, mut hits) = (0, 0);
        // Start at one step rather than zero: a run with zero tries can never collide.
        for tries in (STEP..=MAX_TRIES).step_by(STEP as usize) {
            let (observed_value, observed_hits) = try_triggering_collisions(THREADS, tries);
            value = observed_value;
            hits = observed_hits;
            // Whether or not collisions happened, every increment must be accounted for.
            assert_eq!(value, tries * THREADS);
            if hits > value {
                break;
            }
            println!("Didn't trigger a collision in {} iterations", tries);
        }
        assert_gt!(hits, value);
    }

    #[test]
    fn in_memory_state_impls() {
        let state = InMemoryState(AtomicU64::new(0));
        assert_gt!(format!("{:?}", state).len(), 0);
        assert_eq!(format!("{:?}", InMemoryState::default()), "InMemoryState(0ns)");
    }

    /// Pins the single-threaded contract of `measure_and_replace_one` and `is_older_than`.
    #[test]
    fn measure_and_replace_one_semantics() {
        let state = InMemoryState::default();

        // A never-written state is reported as `None`, and the returned TAT is stored.
        let seen = state.measure_and_replace_one(|old| {
            Ok::<_, ()>((old.map(Nanos::as_u64), Nanos::from(5u64)))
        });
        assert_eq!(seen, Ok(None));

        // An error from the closure is passed through and leaves the state untouched.
        let rejected = state.measure_and_replace_one(|old| {
            assert_eq!(old.map(Nanos::as_u64), Some(5));
            Err::<((), Nanos), _>("rejected")
        });
        assert_eq!(rejected, Err("rejected"));
        assert!(state.is_older_than(Nanos::from(5u64)));
        assert!(!state.is_older_than(Nanos::from(4u64)));
        assert_eq!(format!("{:?}", state), "InMemoryState(5ns)");
    }
}