governor/state/direct/sinks.rs

1use std::prelude::v1::*;
2
3use crate::{
4    clock,
5    middleware::RateLimitingMiddleware,
6    state::{DirectStateStore, NotKeyed},
7    Jitter, NotUntil, RateLimiter,
8};
9use futures_timer::Delay;
10use futures_util::task::{Context, Poll};
11use futures_util::{Future, Sink, Stream};
12use std::marker::PhantomData;
13use std::pin::Pin;
14
15/// Allows converting a [`futures_util::Sink`] combinator into a rate-limited sink.
16pub trait SinkRateLimitExt<Item, S>: Sink<Item>
17where
18    S: Sink<Item>,
19{
20    /// Limits the rate at which items can be put into the current sink.
21    fn ratelimit_sink<
22        D: DirectStateStore,
23        C: clock::ReasonablyRealtime,
24        MW: RateLimitingMiddleware<C::Instant, NegativeOutcome = NotUntil<C::Instant>>,
25    >(
26        self,
27        limiter: &'_ RateLimiter<NotKeyed, D, C, MW>,
28    ) -> RatelimitedSink<'_, Item, S, D, C, MW>
29    where
30        Self: Sized;
31
32    /// Limits the rate at which items can be put into the current sink, with a randomized wait
33    /// period.
34    #[cfg(feature = "jitter")]
35    fn ratelimit_sink_with_jitter<
36        D: DirectStateStore,
37        C: clock::ReasonablyRealtime,
38        MW: RateLimitingMiddleware<C::Instant, NegativeOutcome = NotUntil<C::Instant>>,
39    >(
40        self,
41        limiter: &'_ RateLimiter<NotKeyed, D, C, MW>,
42        jitter: Jitter,
43    ) -> RatelimitedSink<'_, Item, S, D, C, MW>
44    where
45        Self: Sized;
46}
47
48impl<Item, S: Sink<Item>> SinkRateLimitExt<Item, S> for S {
49    fn ratelimit_sink<
50        D: DirectStateStore,
51        C: clock::ReasonablyRealtime,
52        MW: RateLimitingMiddleware<C::Instant, NegativeOutcome = NotUntil<C::Instant>>,
53    >(
54        self,
55        limiter: &RateLimiter<NotKeyed, D, C, MW>,
56    ) -> RatelimitedSink<'_, Item, S, D, C, MW>
57    where
58        Self: Sized,
59    {
60        RatelimitedSink::new(self, limiter, Jitter::NONE)
61    }
62
63    #[cfg(feature = "jitter")]
64    fn ratelimit_sink_with_jitter<
65        D: DirectStateStore,
66        C: clock::ReasonablyRealtime,
67        MW: RateLimitingMiddleware<C::Instant, NegativeOutcome = NotUntil<C::Instant>>,
68    >(
69        self,
70        limiter: &RateLimiter<NotKeyed, D, C, MW>,
71        jitter: Jitter,
72    ) -> RatelimitedSink<'_, Item, S, D, C, MW>
73    where
74        Self: Sized,
75    {
76        RatelimitedSink::new(self, limiter, jitter)
77    }
78}
79
/// Internal state machine driving `RatelimitedSink`'s `poll_ready` / `start_send`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum State {
    /// No cell has been granted for the next item yet; the rate limiter must be checked.
    NotReady,
    /// The limiter refused; waiting for the internal delay timer before checking again.
    Wait,
    /// The limiter granted a cell; readiness is now decided by the inner sink.
    Ready,
}
86
87/// A [`Sink`][futures_util::Sink] combinator that only allows sending elements when the rate-limiter
88/// allows it.
89pub struct RatelimitedSink<
90    'a,
91    Item,
92    S: Sink<Item>,
93    D: DirectStateStore,
94    C: clock::ReasonablyRealtime,
95    MW: RateLimitingMiddleware<C::Instant, NegativeOutcome = NotUntil<C::Instant>>,
96> {
97    inner: S,
98    state: State,
99    limiter: &'a RateLimiter<NotKeyed, D, C, MW>,
100    delay: Delay,
101    jitter: Jitter,
102    phantom: PhantomData<Item>,
103}
104
105/// Conversion methods for the sink combinator.
106impl<
107        'a,
108        Item,
109        S: Sink<Item>,
110        D: DirectStateStore,
111        C: clock::ReasonablyRealtime,
112        MW: RateLimitingMiddleware<C::Instant, NegativeOutcome = NotUntil<C::Instant>>,
113    > RatelimitedSink<'a, Item, S, D, C, MW>
114{
115    fn new(inner: S, limiter: &'a RateLimiter<NotKeyed, D, C, MW>, jitter: Jitter) -> Self {
116        RatelimitedSink {
117            inner,
118            limiter,
119            delay: Delay::new(Default::default()),
120            state: State::NotReady,
121            jitter,
122            phantom: PhantomData,
123        }
124    }
125
126    /// Acquires a reference to the underlying sink that this combinator is sending into.
127    pub fn get_ref(&self) -> &S {
128        &self.inner
129    }
130
131    /// Acquires a mutable reference to the underlying sink that this combinator is sending into.
132    ///
133    /// ```
134    /// # futures_executor::block_on(async {
135    /// # use futures_util::sink::{self, SinkExt};
136    /// # use nonzero_ext::nonzero;
137    /// use governor::{prelude::*, RateLimiter, Quota};
138    /// let drain = sink::drain();
139    /// let lim = RateLimiter::direct(Quota::per_second(nonzero!(10u32)));
140    /// let mut limited = drain.ratelimit_sink(&lim);
141    /// limited.get_mut().send(5).await?;
142    /// # Ok::<(), futures_util::never::Never>(()) }).unwrap();
143    /// ```
144    pub fn get_mut(&mut self) -> &mut S {
145        &mut self.inner
146    }
147
148    /// Consumes this combinator, returning the underlying sink.
149    pub fn into_inner(self) -> S {
150        self.inner
151    }
152}
153
154impl<
155        Item,
156        S: Sink<Item>,
157        D: DirectStateStore,
158        C: clock::ReasonablyRealtime,
159        MW: RateLimitingMiddleware<C::Instant, NegativeOutcome = NotUntil<C::Instant>>,
160    > Sink<Item> for RatelimitedSink<'_, Item, S, D, C, MW>
161where
162    S: Unpin,
163    Item: Unpin,
164{
165    type Error = S::Error;
166
167    fn poll_ready(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
168        loop {
169            match self.state {
170                State::NotReady => {
171                    let reference = self.limiter.reference_reading();
172                    if let Err(negative) = self.limiter.check() {
173                        let earliest = negative.wait_time_with_offset(reference, self.jitter);
174                        self.delay.reset(earliest);
175                        let future = Pin::new(&mut self.delay);
176                        match future.poll(cx) {
177                            Poll::Pending => {
178                                self.state = State::Wait;
179                                return Poll::Pending;
180                            }
181                            Poll::Ready(_) => {}
182                        }
183                    } else {
184                        self.state = State::Ready;
185                    }
186                }
187                State::Wait => {
188                    let future = Pin::new(&mut self.delay);
189                    match future.poll(cx) {
190                        Poll::Pending => {
191                            return Poll::Pending;
192                        }
193                        Poll::Ready(_) => {
194                            self.state = State::NotReady;
195                        }
196                    }
197                }
198                State::Ready => {
199                    let inner = Pin::new(&mut self.inner);
200                    return inner.poll_ready(cx);
201                }
202            }
203        }
204    }
205
206    fn start_send(mut self: Pin<&mut Self>, item: Item) -> Result<(), Self::Error> {
207        match self.state {
208            State::Wait | State::NotReady => {
209                unreachable!("Must not start_send before we're ready"); // !no_rcov!
210            }
211            State::Ready => {
212                self.state = State::NotReady;
213                let inner = Pin::new(&mut self.inner);
214                inner.start_send(item)
215            }
216        }
217    }
218
219    fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
220        let inner = Pin::new(&mut self.inner);
221        inner.poll_flush(cx)
222    }
223
224    fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
225        let inner = Pin::new(&mut self.inner);
226        inner.poll_close(cx)
227    }
228}
229
230/// Pass-through implementation for [`futures_util::Stream`] if the Sink also implements it.
231impl<
232        Item,
233        S: Stream + Sink<Item>,
234        D: DirectStateStore,
235        C: clock::ReasonablyRealtime,
236        MW: RateLimitingMiddleware<C::Instant, NegativeOutcome = NotUntil<C::Instant>>,
237    > Stream for RatelimitedSink<'_, Item, S, D, C, MW>
238where
239    S::Item: Unpin,
240    S: Unpin,
241    Item: Unpin,
242{
243    type Item = <S as Stream>::Item;
244
245    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
246        let inner = Pin::new(&mut self.inner);
247        inner.poll_next(cx)
248    }
249
250    fn size_hint(&self) -> (usize, Option<usize>) {
251        self.inner.size_hint()
252    }
253}