1use std::prelude::v1::*;
2
3use crate::{clock, Jitter, NotUntil, RateLimiter};
4use crate::{
5 middleware::RateLimitingMiddleware,
6 state::{DirectStateStore, NotKeyed},
7};
8use futures_timer::Delay;
9use futures_util::task::{Context, Poll};
10use futures_util::{Future, Sink, Stream};
11use std::pin::Pin;
12use std::time::Duration;
13
14pub trait StreamRateLimitExt<'a>: Stream {
16 fn ratelimit_stream<
24 D: DirectStateStore,
25 C: clock::ReasonablyRealtime,
26 MW: RateLimitingMiddleware<C::Instant>,
27 >(
28 self,
29 limiter: &'a RateLimiter<NotKeyed, D, C, MW>,
30 ) -> RatelimitedStream<'a, Self, D, C, MW>
31 where
32 Self: Sized;
33
34 fn ratelimit_stream_with_jitter<
42 D: DirectStateStore,
43 C: clock::ReasonablyRealtime,
44 MW: RateLimitingMiddleware<C::Instant>,
45 >(
46 self,
47 limiter: &'a RateLimiter<NotKeyed, D, C, MW>,
48 jitter: Jitter,
49 ) -> RatelimitedStream<'a, Self, D, C, MW>
50 where
51 Self: Sized;
52}
53
54impl<'a, S: Stream> StreamRateLimitExt<'a> for S {
55 fn ratelimit_stream<
56 D: DirectStateStore,
57 C: clock::ReasonablyRealtime,
58 MW: RateLimitingMiddleware<C::Instant>,
59 >(
60 self,
61 limiter: &'a RateLimiter<NotKeyed, D, C, MW>,
62 ) -> RatelimitedStream<'a, Self, D, C, MW>
63 where
64 Self: Sized,
65 {
66 self.ratelimit_stream_with_jitter(limiter, Jitter::NONE)
67 }
68
69 fn ratelimit_stream_with_jitter<
70 D: DirectStateStore,
71 C: clock::ReasonablyRealtime,
72 MW: RateLimitingMiddleware<C::Instant>,
73 >(
74 self,
75 limiter: &'a RateLimiter<NotKeyed, D, C, MW>,
76 jitter: Jitter,
77 ) -> RatelimitedStream<'a, Self, D, C, MW>
78 where
79 Self: Sized,
80 {
81 RatelimitedStream {
82 inner: self,
83 limiter,
84 buf: None,
85 delay: Delay::new(Duration::new(0, 0)),
86 jitter,
87 state: State::ReadInner,
88 }
89 }
90}
91
/// Position of a `RatelimitedStream` in its per-item read / check / wait cycle.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum State {
    /// Nothing is buffered; the next step is to poll the inner stream for an item.
    ReadInner,
    /// An item is buffered and the rate limiter must be (re-)checked before releasing it.
    NotReady,
    /// An item is buffered and the stream is sleeping on its delay before re-checking.
    Wait,
}
97
98pub struct RatelimitedStream<
103 'a,
104 S: Stream,
105 D: DirectStateStore,
106 C: clock::Clock,
107 MW: RateLimitingMiddleware<C::Instant>,
108> {
109 inner: S,
110 limiter: &'a RateLimiter<NotKeyed, D, C, MW>,
111 delay: Delay,
112 buf: Option<S::Item>,
113 jitter: Jitter,
114 state: State,
115}
116
117impl<S: Stream, D: DirectStateStore, C: clock::Clock, MW: RateLimitingMiddleware<C::Instant>>
119 RatelimitedStream<'_, S, D, C, MW>
120{
121 pub fn get_ref(&self) -> &S {
133 &self.inner
134 }
135
136 pub fn get_mut(&mut self) -> &mut S {
148 &mut self.inner
149 }
150
151 pub fn into_inner(self) -> (S, Option<S::Item>) {
166 (self.inner, self.buf)
167 }
168}
169
170impl<S: Stream, D: DirectStateStore, C: clock::Clock, MW> Stream
172 for RatelimitedStream<'_, S, D, C, MW>
173where
174 S: Unpin,
175 S::Item: Unpin,
176 Self: Unpin,
177 C: clock::ReasonablyRealtime,
178 MW: RateLimitingMiddleware<C::Instant, NegativeOutcome = NotUntil<C::Instant>>,
179{
180 type Item = S::Item;
181
182 fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
183 loop {
184 match self.state {
185 State::ReadInner => {
186 let inner = Pin::new(&mut self.inner);
187 match inner.poll_next(cx) {
188 Poll::Pending => return Poll::Pending,
189 Poll::Ready(None) => {
190 return Poll::Ready(None);
192 }
193 Poll::Ready(Some(x)) => {
194 self.buf.replace(x);
195 self.state = State::NotReady;
196 }
197 }
198 }
199 State::NotReady => {
200 let reference = self.limiter.reference_reading();
201 if let Err(negative) = self.limiter.check() {
202 let earliest = negative.wait_time_with_offset(reference, self.jitter);
203 self.delay.reset(earliest);
204 let future = Pin::new(&mut self.delay);
205 match future.poll(cx) {
206 Poll::Pending => {
207 self.state = State::Wait;
208 return Poll::Pending;
209 }
210 Poll::Ready(_) => {}
211 }
212 } else {
213 self.state = State::ReadInner;
214 return Poll::Ready(self.buf.take());
215 }
216 }
217 State::Wait => {
218 let future = Pin::new(&mut self.delay);
219 match future.poll(cx) {
220 Poll::Pending => {
221 return Poll::Pending;
222 }
223 Poll::Ready(_) => {
224 self.state = State::NotReady;
225 }
226 }
227 }
228 }
229 }
230 }
231
232 fn size_hint(&self) -> (usize, Option<usize>) {
233 self.inner.size_hint()
234 }
235}
236
237impl<
239 Item,
240 S: Stream + Sink<Item>,
241 D: DirectStateStore,
242 C: clock::Clock,
243 MW: RateLimitingMiddleware<C::Instant>,
244 > Sink<Item> for RatelimitedStream<'_, S, D, C, MW>
245where
246 S: Unpin,
247 S::Item: Unpin,
248{
249 type Error = <S as Sink<Item>>::Error;
250
251 fn poll_ready(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
252 let inner = Pin::new(&mut self.inner);
253 inner.poll_ready(cx)
254 }
255
256 fn start_send(mut self: Pin<&mut Self>, item: Item) -> Result<(), Self::Error> {
257 let inner = Pin::new(&mut self.inner);
258 inner.start_send(item)
259 }
260
261 fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
262 let inner = Pin::new(&mut self.inner);
263 inner.poll_flush(cx)
264 }
265
266 fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
267 let inner = Pin::new(&mut self.inner);
268 inner.poll_close(cx)
269 }
270}