governor/state/direct/
streams.rs

1#![cfg(feature = "std")]
2
3use std::prelude::v1::*;
4
5use crate::{clock, Jitter, NotUntil, RateLimiter};
6use crate::{
7    middleware::RateLimitingMiddleware,
8    state::{DirectStateStore, NotKeyed},
9};
10use futures::task::{Context, Poll};
11use futures::{Future, Sink, Stream};
12use futures_timer::Delay;
13use std::pin::Pin;
14use std::time::Duration;
15
16/// Allows converting a [`futures::Stream`] combinator into a rate-limited stream.
17pub trait StreamRateLimitExt<'a>: Stream {
18    /// Limits the rate at which the stream produces items.
19    ///
20    /// Note that this combinator limits the rate at which it yields
21    /// items, not necessarily the rate at which the underlying stream is polled.
22    /// The combinator will buffer at most one item in order to adhere to the
23    /// given limiter. I.e. if it already has an item buffered and needs to wait
24    /// it will not `poll` the underlying stream.
25    fn ratelimit_stream<
26        D: DirectStateStore,
27        C: clock::Clock,
28        MW: RateLimitingMiddleware<C::Instant>,
29    >(
30        self,
31        limiter: &'a RateLimiter<NotKeyed, D, C, MW>,
32    ) -> RatelimitedStream<'a, Self, D, C, MW>
33    where
34        Self: Sized,
35        C: clock::ReasonablyRealtime;
36
37    /// Limits the rate at which the stream produces items, with a randomized wait period.
38    ///
39    /// Note that this combinator limits the rate at which it yields
40    /// items, not necessarily the rate at which the underlying stream is polled.
41    /// The combinator will buffer at most one item in order to adhere to the
42    /// given limiter. I.e. if it already has an item buffered and needs to wait
43    /// it will not `poll` the underlying stream.
44    fn ratelimit_stream_with_jitter<
45        D: DirectStateStore,
46        C: clock::Clock,
47        MW: RateLimitingMiddleware<C::Instant>,
48    >(
49        self,
50        limiter: &'a RateLimiter<NotKeyed, D, C, MW>,
51        jitter: Jitter,
52    ) -> RatelimitedStream<'a, Self, D, C, MW>
53    where
54        Self: Sized,
55        C: clock::ReasonablyRealtime;
56}
57
58impl<'a, S: Stream> StreamRateLimitExt<'a> for S {
59    fn ratelimit_stream<
60        D: DirectStateStore,
61        C: clock::Clock,
62        MW: RateLimitingMiddleware<C::Instant>,
63    >(
64        self,
65        limiter: &'a RateLimiter<NotKeyed, D, C, MW>,
66    ) -> RatelimitedStream<'a, Self, D, C, MW>
67    where
68        Self: Sized,
69        C: clock::ReasonablyRealtime,
70    {
71        self.ratelimit_stream_with_jitter(limiter, Jitter::NONE)
72    }
73
74    fn ratelimit_stream_with_jitter<
75        D: DirectStateStore,
76        C: clock::Clock,
77        MW: RateLimitingMiddleware<C::Instant>,
78    >(
79        self,
80        limiter: &'a RateLimiter<NotKeyed, D, C, MW>,
81        jitter: Jitter,
82    ) -> RatelimitedStream<'a, Self, D, C, MW>
83    where
84        Self: Sized,
85        C: clock::ReasonablyRealtime,
86    {
87        RatelimitedStream {
88            inner: self,
89            limiter,
90            buf: None,
91            delay: Delay::new(Duration::new(0, 0)),
92            jitter,
93            state: State::ReadInner,
94        }
95    }
96}
97
/// The step `RatelimitedStream::poll_next` is currently on in its
/// read → check → wait cycle.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum State {
    /// Nothing is buffered; the next step is to poll the inner stream.
    ReadInner,
    /// An item is buffered and the limiter must be consulted before yielding it.
    NotReady,
    /// An item is buffered, the limiter refused it, and the delay timer is
    /// being awaited before the limiter is consulted again.
    Wait,
}
103
104/// A [`Stream`][futures::Stream] combinator which will limit the rate of items being received.
105///
106/// This is produced by the [`StreamRateLimitExt::ratelimit_stream`] and
107/// [`StreamRateLimitExt::ratelimit_stream_with_jitter`] methods.
108pub struct RatelimitedStream<
109    'a,
110    S: Stream,
111    D: DirectStateStore,
112    C: clock::Clock,
113    MW: RateLimitingMiddleware<C::Instant>,
114> {
115    inner: S,
116    limiter: &'a RateLimiter<NotKeyed, D, C, MW>,
117    delay: Delay,
118    buf: Option<S::Item>,
119    jitter: Jitter,
120    state: State,
121}
122
123/// Conversion methods for the stream combinator.
124impl<
125        'a,
126        S: Stream,
127        D: DirectStateStore,
128        C: clock::Clock,
129        MW: RateLimitingMiddleware<C::Instant>,
130    > RatelimitedStream<'a, S, D, C, MW>
131{
132    /// Acquires a reference to the underlying stream that this combinator is pulling from.
133    /// ```rust
134    /// # use futures::{Stream, stream};
135    /// # use governor::{prelude::*, Quota, RateLimiter};
136    /// # use nonzero_ext::nonzero;
137    /// let inner = stream::repeat(());
138    /// let lim = RateLimiter::direct(Quota::per_second(nonzero!(10u32)));
139    /// let outer = inner.clone().ratelimit_stream(&lim);
140    /// assert!(outer.get_ref().size_hint().1.is_none());
141    /// assert_eq!(outer.size_hint(), outer.get_ref().size_hint());
142    /// ```
143    pub fn get_ref(&self) -> &S {
144        &self.inner
145    }
146
147    /// Acquires a mutable reference to the underlying stream that this combinator is pulling from.
148    /// ```rust
149    /// # use futures::{stream, StreamExt};
150    /// # use futures::executor::block_on;
151    /// # use governor::{prelude::*, Quota, RateLimiter};
152    /// # use nonzero_ext::nonzero;
153    /// let inner = stream::repeat(());
154    /// let lim = RateLimiter::direct(Quota::per_second(nonzero!(10u32)));
155    /// let mut outer = inner.clone().ratelimit_stream(&lim);
156    /// assert_eq!(block_on(outer.get_mut().next()), Some(()));
157    /// ```
158    pub fn get_mut(&mut self) -> &mut S {
159        &mut self.inner
160    }
161
162    /// Consumes this combinator, returning the underlying stream and any item
163    /// which it has already produced but which is still being held back
164    /// in order to abide by the limiter.
165    /// ```rust
166    /// # use futures::{stream, StreamExt};
167    /// # use futures::executor::block_on;
168    /// # use governor::{prelude::*, Quota, RateLimiter};
169    /// # use nonzero_ext::nonzero;
170    /// let inner = stream::repeat(());
171    /// let lim = RateLimiter::direct(Quota::per_second(nonzero!(10u32)));
172    /// let mut outer = inner.clone().ratelimit_stream(&lim);
173    /// let (mut inner_again, _) = outer.into_inner();
174    /// assert_eq!(block_on(inner_again.next()), Some(()));
175    /// ```
176    pub fn into_inner(self) -> (S, Option<S::Item>) {
177        (self.inner, self.buf)
178    }
179}
180
181/// Implements the [`futures::Stream`] combinator.
182impl<'a, S: Stream, D: DirectStateStore, C: clock::Clock, MW> Stream
183    for RatelimitedStream<'a, S, D, C, MW>
184where
185    S: Unpin,
186    S::Item: Unpin,
187    Self: Unpin,
188    C: clock::ReasonablyRealtime,
189    MW: RateLimitingMiddleware<C::Instant, NegativeOutcome = NotUntil<C::Instant>>,
190{
191    type Item = S::Item;
192
193    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
194        loop {
195            match self.state {
196                State::ReadInner => {
197                    let inner = Pin::new(&mut self.inner);
198                    match inner.poll_next(cx) {
199                        Poll::Pending => return Poll::Pending,
200                        Poll::Ready(None) => {
201                            // never talk tome or my inner again
202                            return Poll::Ready(None);
203                        }
204                        Poll::Ready(Some(x)) => {
205                            self.buf.replace(x);
206                            self.state = State::NotReady;
207                        }
208                    }
209                }
210                State::NotReady => {
211                    let reference = self.limiter.reference_reading();
212                    if let Err(negative) = self.limiter.check() {
213                        let earliest = negative.wait_time_with_offset(reference, self.jitter);
214                        self.delay.reset(earliest);
215                        let future = Pin::new(&mut self.delay);
216                        match future.poll(cx) {
217                            Poll::Pending => {
218                                self.state = State::Wait;
219                                return Poll::Pending;
220                            }
221                            Poll::Ready(_) => {}
222                        }
223                    } else {
224                        self.state = State::ReadInner;
225                        return Poll::Ready(self.buf.take());
226                    }
227                }
228                State::Wait => {
229                    let future = Pin::new(&mut self.delay);
230                    match future.poll(cx) {
231                        Poll::Pending => {
232                            return Poll::Pending;
233                        }
234                        Poll::Ready(_) => {
235                            self.state = State::NotReady;
236                        }
237                    }
238                }
239            }
240        }
241    }
242
243    fn size_hint(&self) -> (usize, Option<usize>) {
244        self.inner.size_hint()
245    }
246}
247
248/// Pass-through implementation for [`futures::Sink`] if the Stream also implements it.
249impl<
250        'a,
251        Item,
252        S: Stream + Sink<Item>,
253        D: DirectStateStore,
254        C: clock::Clock,
255        MW: RateLimitingMiddleware<C::Instant>,
256    > Sink<Item> for RatelimitedStream<'a, S, D, C, MW>
257where
258    S: Unpin,
259    S::Item: Unpin,
260{
261    type Error = <S as Sink<Item>>::Error;
262
263    fn poll_ready(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
264        let inner = Pin::new(&mut self.inner);
265        inner.poll_ready(cx)
266    }
267
268    fn start_send(mut self: Pin<&mut Self>, item: Item) -> Result<(), Self::Error> {
269        let inner = Pin::new(&mut self.inner);
270        inner.start_send(item)
271    }
272
273    fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
274        let inner = Pin::new(&mut self.inner);
275        inner.poll_flush(cx)
276    }
277
278    fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
279        let inner = Pin::new(&mut self.inner);
280        inner.poll_close(cx)
281    }
282}