governor/state/direct/
streams.rs

1use std::prelude::v1::*;
2
3use crate::{clock, Jitter, NotUntil, RateLimiter};
4use crate::{
5    middleware::RateLimitingMiddleware,
6    state::{DirectStateStore, NotKeyed},
7};
8use futures_timer::Delay;
9use futures_util::task::{Context, Poll};
10use futures_util::{Future, Sink, Stream};
11use std::pin::Pin;
12use std::time::Duration;
13
14/// Allows converting a [`futures_util::Stream`] combinator into a rate-limited stream.
15pub trait StreamRateLimitExt<'a>: Stream {
16    /// Limits the rate at which the stream produces items.
17    ///
18    /// Note that this combinator limits the rate at which it yields
19    /// items, not necessarily the rate at which the underlying stream is polled.
20    /// The combinator will buffer at most one item in order to adhere to the
21    /// given limiter. I.e. if it already has an item buffered and needs to wait
22    /// it will not `poll` the underlying stream.
23    fn ratelimit_stream<
24        D: DirectStateStore,
25        C: clock::ReasonablyRealtime,
26        MW: RateLimitingMiddleware<C::Instant>,
27    >(
28        self,
29        limiter: &'a RateLimiter<NotKeyed, D, C, MW>,
30    ) -> RatelimitedStream<'a, Self, D, C, MW>
31    where
32        Self: Sized;
33
34    /// Limits the rate at which the stream produces items, with a randomized wait period.
35    ///
36    /// Note that this combinator limits the rate at which it yields
37    /// items, not necessarily the rate at which the underlying stream is polled.
38    /// The combinator will buffer at most one item in order to adhere to the
39    /// given limiter. I.e. if it already has an item buffered and needs to wait
40    /// it will not `poll` the underlying stream.
41    fn ratelimit_stream_with_jitter<
42        D: DirectStateStore,
43        C: clock::ReasonablyRealtime,
44        MW: RateLimitingMiddleware<C::Instant>,
45    >(
46        self,
47        limiter: &'a RateLimiter<NotKeyed, D, C, MW>,
48        jitter: Jitter,
49    ) -> RatelimitedStream<'a, Self, D, C, MW>
50    where
51        Self: Sized;
52}
53
54impl<'a, S: Stream> StreamRateLimitExt<'a> for S {
55    fn ratelimit_stream<
56        D: DirectStateStore,
57        C: clock::ReasonablyRealtime,
58        MW: RateLimitingMiddleware<C::Instant>,
59    >(
60        self,
61        limiter: &'a RateLimiter<NotKeyed, D, C, MW>,
62    ) -> RatelimitedStream<'a, Self, D, C, MW>
63    where
64        Self: Sized,
65    {
66        self.ratelimit_stream_with_jitter(limiter, Jitter::NONE)
67    }
68
69    fn ratelimit_stream_with_jitter<
70        D: DirectStateStore,
71        C: clock::ReasonablyRealtime,
72        MW: RateLimitingMiddleware<C::Instant>,
73    >(
74        self,
75        limiter: &'a RateLimiter<NotKeyed, D, C, MW>,
76        jitter: Jitter,
77    ) -> RatelimitedStream<'a, Self, D, C, MW>
78    where
79        Self: Sized,
80    {
81        RatelimitedStream {
82            inner: self,
83            limiter,
84            buf: None,
85            delay: Delay::new(Duration::new(0, 0)),
86            jitter,
87            state: State::ReadInner,
88        }
89    }
90}
91
/// Phase of the `RatelimitedStream` polling state machine.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum State {
    /// No item is buffered; the next poll pulls one from the inner stream.
    ReadInner,
    /// An item is buffered and the limiter must be asked whether it may be released.
    NotReady,
    /// The limiter rejected the buffered item; waiting for the delay timer to
    /// elapse before asking the limiter again.
    Wait,
}
97
98/// A [`Stream`][futures_util::Stream] combinator which will limit the rate of items being received.
99///
100/// This is produced by the [`StreamRateLimitExt::ratelimit_stream`] and
101/// [`StreamRateLimitExt::ratelimit_stream_with_jitter`] methods.
102pub struct RatelimitedStream<
103    'a,
104    S: Stream,
105    D: DirectStateStore,
106    C: clock::Clock,
107    MW: RateLimitingMiddleware<C::Instant>,
108> {
109    inner: S,
110    limiter: &'a RateLimiter<NotKeyed, D, C, MW>,
111    delay: Delay,
112    buf: Option<S::Item>,
113    jitter: Jitter,
114    state: State,
115}
116
117/// Conversion methods for the stream combinator.
118impl<S: Stream, D: DirectStateStore, C: clock::Clock, MW: RateLimitingMiddleware<C::Instant>>
119    RatelimitedStream<'_, S, D, C, MW>
120{
121    /// Acquires a reference to the underlying stream that this combinator is pulling from.
122    /// ```rust
123    /// # use futures_util::{Stream, stream};
124    /// # use governor::{prelude::*, Quota, RateLimiter};
125    /// # use nonzero_ext::nonzero;
126    /// let inner = stream::repeat(());
127    /// let lim = RateLimiter::direct(Quota::per_second(nonzero!(10u32)));
128    /// let outer = inner.clone().ratelimit_stream(&lim);
129    /// assert!(outer.get_ref().size_hint().1.is_none());
130    /// assert_eq!(outer.size_hint(), outer.get_ref().size_hint());
131    /// ```
132    pub fn get_ref(&self) -> &S {
133        &self.inner
134    }
135
136    /// Acquires a mutable reference to the underlying stream that this combinator is pulling from.
137    /// ```rust
138    /// # use futures_util::{stream, StreamExt};
139    /// # use futures_executor::block_on;
140    /// # use governor::{prelude::*, Quota, RateLimiter};
141    /// # use nonzero_ext::nonzero;
142    /// let inner = stream::repeat(());
143    /// let lim = RateLimiter::direct(Quota::per_second(nonzero!(10u32)));
144    /// let mut outer = inner.clone().ratelimit_stream(&lim);
145    /// assert_eq!(block_on(outer.get_mut().next()), Some(()));
146    /// ```
147    pub fn get_mut(&mut self) -> &mut S {
148        &mut self.inner
149    }
150
151    /// Consumes this combinator, returning the underlying stream and any item
152    /// which it has already produced but which is still being held back
153    /// in order to abide by the limiter.
154    /// ```rust
155    /// # use futures_util::{stream, StreamExt};
156    /// # use futures_executor::block_on;
157    /// # use governor::{prelude::*, Quota, RateLimiter};
158    /// # use nonzero_ext::nonzero;
159    /// let inner = stream::repeat(());
160    /// let lim = RateLimiter::direct(Quota::per_second(nonzero!(10u32)));
161    /// let mut outer = inner.clone().ratelimit_stream(&lim);
162    /// let (mut inner_again, _) = outer.into_inner();
163    /// assert_eq!(block_on(inner_again.next()), Some(()));
164    /// ```
165    pub fn into_inner(self) -> (S, Option<S::Item>) {
166        (self.inner, self.buf)
167    }
168}
169
170/// Implements the [`futures_util::Stream`] combinator.
171impl<S: Stream, D: DirectStateStore, C: clock::Clock, MW> Stream
172    for RatelimitedStream<'_, S, D, C, MW>
173where
174    S: Unpin,
175    S::Item: Unpin,
176    Self: Unpin,
177    C: clock::ReasonablyRealtime,
178    MW: RateLimitingMiddleware<C::Instant, NegativeOutcome = NotUntil<C::Instant>>,
179{
180    type Item = S::Item;
181
182    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
183        loop {
184            match self.state {
185                State::ReadInner => {
186                    let inner = Pin::new(&mut self.inner);
187                    match inner.poll_next(cx) {
188                        Poll::Pending => return Poll::Pending,
189                        Poll::Ready(None) => {
190                            // never talk tome or my inner again
191                            return Poll::Ready(None);
192                        }
193                        Poll::Ready(Some(x)) => {
194                            self.buf.replace(x);
195                            self.state = State::NotReady;
196                        }
197                    }
198                }
199                State::NotReady => {
200                    let reference = self.limiter.reference_reading();
201                    if let Err(negative) = self.limiter.check() {
202                        let earliest = negative.wait_time_with_offset(reference, self.jitter);
203                        self.delay.reset(earliest);
204                        let future = Pin::new(&mut self.delay);
205                        match future.poll(cx) {
206                            Poll::Pending => {
207                                self.state = State::Wait;
208                                return Poll::Pending;
209                            }
210                            Poll::Ready(_) => {}
211                        }
212                    } else {
213                        self.state = State::ReadInner;
214                        return Poll::Ready(self.buf.take());
215                    }
216                }
217                State::Wait => {
218                    let future = Pin::new(&mut self.delay);
219                    match future.poll(cx) {
220                        Poll::Pending => {
221                            return Poll::Pending;
222                        }
223                        Poll::Ready(_) => {
224                            self.state = State::NotReady;
225                        }
226                    }
227                }
228            }
229        }
230    }
231
232    fn size_hint(&self) -> (usize, Option<usize>) {
233        self.inner.size_hint()
234    }
235}
236
237/// Pass-through implementation for [`futures_util::Sink`] if the Stream also implements it.
238impl<
239        Item,
240        S: Stream + Sink<Item>,
241        D: DirectStateStore,
242        C: clock::Clock,
243        MW: RateLimitingMiddleware<C::Instant>,
244    > Sink<Item> for RatelimitedStream<'_, S, D, C, MW>
245where
246    S: Unpin,
247    S::Item: Unpin,
248{
249    type Error = <S as Sink<Item>>::Error;
250
251    fn poll_ready(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
252        let inner = Pin::new(&mut self.inner);
253        inner.poll_ready(cx)
254    }
255
256    fn start_send(mut self: Pin<&mut Self>, item: Item) -> Result<(), Self::Error> {
257        let inner = Pin::new(&mut self.inner);
258        inner.start_send(item)
259    }
260
261    fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
262        let inner = Pin::new(&mut self.inner);
263        inner.poll_flush(cx)
264    }
265
266    fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
267        let inner = Pin::new(&mut self.inner);
268        inner.poll_close(cx)
269    }
270}