1#![cfg(feature = "std")]
2
3use std::prelude::v1::*;
4
5use crate::{clock, Jitter, NotUntil, RateLimiter};
6use crate::{
7 middleware::RateLimitingMiddleware,
8 state::{DirectStateStore, NotKeyed},
9};
10use futures::task::{Context, Poll};
11use futures::{Future, Sink, Stream};
12use futures_timer::Delay;
13use std::pin::Pin;
14use std::time::Duration;
15
16pub trait StreamRateLimitExt<'a>: Stream {
18 fn ratelimit_stream<
26 D: DirectStateStore,
27 C: clock::Clock,
28 MW: RateLimitingMiddleware<C::Instant>,
29 >(
30 self,
31 limiter: &'a RateLimiter<NotKeyed, D, C, MW>,
32 ) -> RatelimitedStream<'a, Self, D, C, MW>
33 where
34 Self: Sized,
35 C: clock::ReasonablyRealtime;
36
37 fn ratelimit_stream_with_jitter<
45 D: DirectStateStore,
46 C: clock::Clock,
47 MW: RateLimitingMiddleware<C::Instant>,
48 >(
49 self,
50 limiter: &'a RateLimiter<NotKeyed, D, C, MW>,
51 jitter: Jitter,
52 ) -> RatelimitedStream<'a, Self, D, C, MW>
53 where
54 Self: Sized,
55 C: clock::ReasonablyRealtime;
56}
57
58impl<'a, S: Stream> StreamRateLimitExt<'a> for S {
59 fn ratelimit_stream<
60 D: DirectStateStore,
61 C: clock::Clock,
62 MW: RateLimitingMiddleware<C::Instant>,
63 >(
64 self,
65 limiter: &'a RateLimiter<NotKeyed, D, C, MW>,
66 ) -> RatelimitedStream<'a, Self, D, C, MW>
67 where
68 Self: Sized,
69 C: clock::ReasonablyRealtime,
70 {
71 self.ratelimit_stream_with_jitter(limiter, Jitter::NONE)
72 }
73
74 fn ratelimit_stream_with_jitter<
75 D: DirectStateStore,
76 C: clock::Clock,
77 MW: RateLimitingMiddleware<C::Instant>,
78 >(
79 self,
80 limiter: &'a RateLimiter<NotKeyed, D, C, MW>,
81 jitter: Jitter,
82 ) -> RatelimitedStream<'a, Self, D, C, MW>
83 where
84 Self: Sized,
85 C: clock::ReasonablyRealtime,
86 {
87 RatelimitedStream {
88 inner: self,
89 limiter,
90 buf: None,
91 delay: Delay::new(Duration::new(0, 0)),
92 jitter,
93 state: State::ReadInner,
94 }
95 }
96}
97
/// Position of a [`RatelimitedStream`] in its polling state machine.
enum State {
    /// Nothing is buffered; the next step is to poll the inner stream for an item.
    ReadInner,
    /// An item is buffered; the limiter must be checked before it may be released.
    NotReady,
    /// An item is buffered and the limiter refused it; waiting for `delay` to fire
    /// before checking the limiter again.
    Wait,
}
103
104pub struct RatelimitedStream<
109 'a,
110 S: Stream,
111 D: DirectStateStore,
112 C: clock::Clock,
113 MW: RateLimitingMiddleware<C::Instant>,
114> {
115 inner: S,
116 limiter: &'a RateLimiter<NotKeyed, D, C, MW>,
117 delay: Delay,
118 buf: Option<S::Item>,
119 jitter: Jitter,
120 state: State,
121}
122
123impl<
125 'a,
126 S: Stream,
127 D: DirectStateStore,
128 C: clock::Clock,
129 MW: RateLimitingMiddleware<C::Instant>,
130 > RatelimitedStream<'a, S, D, C, MW>
131{
132 pub fn get_ref(&self) -> &S {
144 &self.inner
145 }
146
147 pub fn get_mut(&mut self) -> &mut S {
159 &mut self.inner
160 }
161
162 pub fn into_inner(self) -> (S, Option<S::Item>) {
177 (self.inner, self.buf)
178 }
179}
180
181impl<'a, S: Stream, D: DirectStateStore, C: clock::Clock, MW> Stream
183 for RatelimitedStream<'a, S, D, C, MW>
184where
185 S: Unpin,
186 S::Item: Unpin,
187 Self: Unpin,
188 C: clock::ReasonablyRealtime,
189 MW: RateLimitingMiddleware<C::Instant, NegativeOutcome = NotUntil<C::Instant>>,
190{
191 type Item = S::Item;
192
193 fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
194 loop {
195 match self.state {
196 State::ReadInner => {
197 let inner = Pin::new(&mut self.inner);
198 match inner.poll_next(cx) {
199 Poll::Pending => return Poll::Pending,
200 Poll::Ready(None) => {
201 return Poll::Ready(None);
203 }
204 Poll::Ready(Some(x)) => {
205 self.buf.replace(x);
206 self.state = State::NotReady;
207 }
208 }
209 }
210 State::NotReady => {
211 let reference = self.limiter.reference_reading();
212 if let Err(negative) = self.limiter.check() {
213 let earliest = negative.wait_time_with_offset(reference, self.jitter);
214 self.delay.reset(earliest);
215 let future = Pin::new(&mut self.delay);
216 match future.poll(cx) {
217 Poll::Pending => {
218 self.state = State::Wait;
219 return Poll::Pending;
220 }
221 Poll::Ready(_) => {}
222 }
223 } else {
224 self.state = State::ReadInner;
225 return Poll::Ready(self.buf.take());
226 }
227 }
228 State::Wait => {
229 let future = Pin::new(&mut self.delay);
230 match future.poll(cx) {
231 Poll::Pending => {
232 return Poll::Pending;
233 }
234 Poll::Ready(_) => {
235 self.state = State::NotReady;
236 }
237 }
238 }
239 }
240 }
241 }
242
243 fn size_hint(&self) -> (usize, Option<usize>) {
244 self.inner.size_hint()
245 }
246}
247
248impl<
250 'a,
251 Item,
252 S: Stream + Sink<Item>,
253 D: DirectStateStore,
254 C: clock::Clock,
255 MW: RateLimitingMiddleware<C::Instant>,
256 > Sink<Item> for RatelimitedStream<'a, S, D, C, MW>
257where
258 S: Unpin,
259 S::Item: Unpin,
260{
261 type Error = <S as Sink<Item>>::Error;
262
263 fn poll_ready(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
264 let inner = Pin::new(&mut self.inner);
265 inner.poll_ready(cx)
266 }
267
268 fn start_send(mut self: Pin<&mut Self>, item: Item) -> Result<(), Self::Error> {
269 let inner = Pin::new(&mut self.inner);
270 inner.start_send(item)
271 }
272
273 fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
274 let inner = Pin::new(&mut self.inner);
275 inner.poll_flush(cx)
276 }
277
278 fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
279 let inner = Pin::new(&mut self.inner);
280 inner.poll_close(cx)
281 }
282}