governor/state/direct/sinks.rs

1use std::prelude::v1::*;
2
3use crate::{
4    clock,
5    middleware::RateLimitingMiddleware,
6    state::{DirectStateStore, NotKeyed},
7    Jitter, NotUntil, RateLimiter,
8};
9use futures::task::{Context, Poll};
10use futures::{Future, Sink, Stream};
11use futures_timer::Delay;
12use std::marker::PhantomData;
13use std::pin::Pin;
14
15/// Allows converting a [`futures::Sink`] combinator into a rate-limited sink.
16pub trait SinkRateLimitExt<Item, S>: Sink<Item>
17where
18    S: Sink<Item>,
19{
20    /// Limits the rate at which items can be put into the current sink.
21    fn ratelimit_sink<
22        D: DirectStateStore,
23        C: clock::ReasonablyRealtime,
24        MW: RateLimitingMiddleware<C::Instant, NegativeOutcome = NotUntil<C::Instant>>,
25    >(
26        self,
27        limiter: &'_ RateLimiter<NotKeyed, D, C, MW>,
28    ) -> RatelimitedSink<'_, Item, S, D, C, MW>
29    where
30        Self: Sized;
31
32    /// Limits the rate at which items can be put into the current sink, with a randomized wait
33    /// period.
34    #[cfg(feature = "jitter")]
35    fn ratelimit_sink_with_jitter<
36        D: DirectStateStore,
37        C: clock::ReasonablyRealtime,
38        MW: RateLimitingMiddleware<C::Instant, NegativeOutcome = NotUntil<C::Instant>>,
39    >(
40        self,
41        limiter: &'_ RateLimiter<NotKeyed, D, C, MW>,
42        jitter: Jitter,
43    ) -> RatelimitedSink<'_, Item, S, D, C, MW>
44    where
45        Self: Sized;
46}
47
48impl<Item, S: Sink<Item>> SinkRateLimitExt<Item, S> for S {
49    fn ratelimit_sink<
50        D: DirectStateStore,
51        C: clock::ReasonablyRealtime,
52        MW: RateLimitingMiddleware<C::Instant, NegativeOutcome = NotUntil<C::Instant>>,
53    >(
54        self,
55        limiter: &RateLimiter<NotKeyed, D, C, MW>,
56    ) -> RatelimitedSink<Item, S, D, C, MW>
57    where
58        Self: Sized,
59    {
60        RatelimitedSink::new(self, limiter, Jitter::NONE)
61    }
62
63    #[cfg(feature = "jitter")]
64    fn ratelimit_sink_with_jitter<
65        D: DirectStateStore,
66        C: clock::ReasonablyRealtime,
67        MW: RateLimitingMiddleware<C::Instant, NegativeOutcome = NotUntil<C::Instant>>,
68    >(
69        self,
70        limiter: &RateLimiter<NotKeyed, D, C, MW>,
71        jitter: Jitter,
72    ) -> RatelimitedSink<Item, S, D, C, MW>
73    where
74        Self: Sized,
75    {
76        RatelimitedSink::new(self, limiter, jitter)
77    }
78}
79
/// Progress of a [`RatelimitedSink`] towards admitting its next item.
#[derive(Debug)]
enum State {
    /// No admission is held; `poll_ready` must consult the rate limiter.
    NotReady,
    /// The limiter refused; `poll_ready` is waiting for the internal delay before re-checking.
    Wait,
    /// The limiter admitted an item that has not been sent yet; `start_send` may proceed.
    Ready,
}
86
87/// A [`Sink`][futures::Sink] combinator that only allows sending elements when the rate-limiter
88/// allows it.
89pub struct RatelimitedSink<
90    'a,
91    Item,
92    S: Sink<Item>,
93    D: DirectStateStore,
94    C: clock::ReasonablyRealtime,
95    MW: RateLimitingMiddleware<C::Instant, NegativeOutcome = NotUntil<C::Instant>>,
96> {
97    inner: S,
98    state: State,
99    limiter: &'a RateLimiter<NotKeyed, D, C, MW>,
100    delay: Delay,
101    jitter: Jitter,
102    phantom: PhantomData<Item>,
103}
104
105/// Conversion methods for the sink combinator.
106impl<
107        'a,
108        Item,
109        S: Sink<Item>,
110        D: DirectStateStore,
111        C: clock::ReasonablyRealtime,
112        MW: RateLimitingMiddleware<C::Instant, NegativeOutcome = NotUntil<C::Instant>>,
113    > RatelimitedSink<'a, Item, S, D, C, MW>
114{
115    fn new(inner: S, limiter: &'a RateLimiter<NotKeyed, D, C, MW>, jitter: Jitter) -> Self {
116        RatelimitedSink {
117            inner,
118            limiter,
119            delay: Delay::new(Default::default()),
120            state: State::NotReady,
121            jitter,
122            phantom: PhantomData,
123        }
124    }
125
126    /// Acquires a reference to the underlying sink that this combinator is sending into.
127    pub fn get_ref(&self) -> &S {
128        &self.inner
129    }
130
131    /// Acquires a mutable reference to the underlying sink that this combinator is sending into.
132    ///
133    /// ```
134    /// # futures::executor::block_on(async {
135    /// # use futures::sink::{self, SinkExt};
136    /// # use nonzero_ext::nonzero;
137    /// use governor::{prelude::*, RateLimiter, Quota};
138    /// let drain = sink::drain();
139    /// let lim = RateLimiter::direct(Quota::per_second(nonzero!(10u32)));
140    /// let mut limited = drain.ratelimit_sink(&lim);
141    /// limited.get_mut().send(5).await?;
142    /// # Ok::<(), futures::never::Never>(()) }).unwrap();
143    /// ```
144    pub fn get_mut(&mut self) -> &mut S {
145        &mut self.inner
146    }
147
148    /// Consumes this combinator, returning the underlying sink.
149    pub fn into_inner(self) -> S {
150        self.inner
151    }
152}
153
154impl<
155        'a,
156        Item,
157        S: Sink<Item>,
158        D: DirectStateStore,
159        C: clock::ReasonablyRealtime,
160        MW: RateLimitingMiddleware<C::Instant, NegativeOutcome = NotUntil<C::Instant>>,
161    > Sink<Item> for RatelimitedSink<'a, Item, S, D, C, MW>
162where
163    S: Unpin,
164    Item: Unpin,
165{
166    type Error = S::Error;
167
168    fn poll_ready(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
169        loop {
170            match self.state {
171                State::NotReady => {
172                    let reference = self.limiter.reference_reading();
173                    if let Err(negative) = self.limiter.check() {
174                        let earliest = negative.wait_time_with_offset(reference, self.jitter);
175                        self.delay.reset(earliest);
176                        let future = Pin::new(&mut self.delay);
177                        match future.poll(cx) {
178                            Poll::Pending => {
179                                self.state = State::Wait;
180                                return Poll::Pending;
181                            }
182                            Poll::Ready(_) => {}
183                        }
184                    } else {
185                        self.state = State::Ready;
186                    }
187                }
188                State::Wait => {
189                    let future = Pin::new(&mut self.delay);
190                    match future.poll(cx) {
191                        Poll::Pending => {
192                            return Poll::Pending;
193                        }
194                        Poll::Ready(_) => {
195                            self.state = State::NotReady;
196                        }
197                    }
198                }
199                State::Ready => {
200                    let inner = Pin::new(&mut self.inner);
201                    return inner.poll_ready(cx);
202                }
203            }
204        }
205    }
206
207    fn start_send(mut self: Pin<&mut Self>, item: Item) -> Result<(), Self::Error> {
208        match self.state {
209            State::Wait | State::NotReady => {
210                unreachable!("Must not start_send before we're ready"); // !no_rcov!
211            }
212            State::Ready => {
213                self.state = State::NotReady;
214                let inner = Pin::new(&mut self.inner);
215                inner.start_send(item)
216            }
217        }
218    }
219
220    fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
221        let inner = Pin::new(&mut self.inner);
222        inner.poll_flush(cx)
223    }
224
225    fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
226        let inner = Pin::new(&mut self.inner);
227        inner.poll_close(cx)
228    }
229}
230
231/// Pass-through implementation for [`futures::Stream`] if the Sink also implements it.
232impl<
233        'a,
234        Item,
235        S: Stream + Sink<Item>,
236        D: DirectStateStore,
237        C: clock::ReasonablyRealtime,
238        MW: RateLimitingMiddleware<C::Instant, NegativeOutcome = NotUntil<C::Instant>>,
239    > Stream for RatelimitedSink<'a, Item, S, D, C, MW>
240where
241    S::Item: Unpin,
242    S: Unpin,
243    Item: Unpin,
244{
245    type Item = <S as Stream>::Item;
246
247    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
248        let inner = Pin::new(&mut self.inner);
249        inner.poll_next(cx)
250    }
251
252    fn size_hint(&self) -> (usize, Option<usize>) {
253        self.inner.size_hint()
254    }
255}