1use std::prelude::v1::*;
2
3use crate::{
4 clock,
5 middleware::RateLimitingMiddleware,
6 state::{DirectStateStore, NotKeyed},
7 Jitter, NotUntil, RateLimiter,
8};
9use futures::task::{Context, Poll};
10use futures::{Future, Sink, Stream};
11use futures_timer::Delay;
12use std::marker::PhantomData;
13use std::pin::Pin;
14
15pub trait SinkRateLimitExt<Item, S>: Sink<Item>
17where
18 S: Sink<Item>,
19{
20 fn ratelimit_sink<
22 D: DirectStateStore,
23 C: clock::ReasonablyRealtime,
24 MW: RateLimitingMiddleware<C::Instant, NegativeOutcome = NotUntil<C::Instant>>,
25 >(
26 self,
27 limiter: &'_ RateLimiter<NotKeyed, D, C, MW>,
28 ) -> RatelimitedSink<'_, Item, S, D, C, MW>
29 where
30 Self: Sized;
31
32 #[cfg(feature = "jitter")]
35 fn ratelimit_sink_with_jitter<
36 D: DirectStateStore,
37 C: clock::ReasonablyRealtime,
38 MW: RateLimitingMiddleware<C::Instant, NegativeOutcome = NotUntil<C::Instant>>,
39 >(
40 self,
41 limiter: &'_ RateLimiter<NotKeyed, D, C, MW>,
42 jitter: Jitter,
43 ) -> RatelimitedSink<'_, Item, S, D, C, MW>
44 where
45 Self: Sized;
46}
47
48impl<Item, S: Sink<Item>> SinkRateLimitExt<Item, S> for S {
49 fn ratelimit_sink<
50 D: DirectStateStore,
51 C: clock::ReasonablyRealtime,
52 MW: RateLimitingMiddleware<C::Instant, NegativeOutcome = NotUntil<C::Instant>>,
53 >(
54 self,
55 limiter: &RateLimiter<NotKeyed, D, C, MW>,
56 ) -> RatelimitedSink<Item, S, D, C, MW>
57 where
58 Self: Sized,
59 {
60 RatelimitedSink::new(self, limiter, Jitter::NONE)
61 }
62
63 #[cfg(feature = "jitter")]
64 fn ratelimit_sink_with_jitter<
65 D: DirectStateStore,
66 C: clock::ReasonablyRealtime,
67 MW: RateLimitingMiddleware<C::Instant, NegativeOutcome = NotUntil<C::Instant>>,
68 >(
69 self,
70 limiter: &RateLimiter<NotKeyed, D, C, MW>,
71 jitter: Jitter,
72 ) -> RatelimitedSink<Item, S, D, C, MW>
73 where
74 Self: Sized,
75 {
76 RatelimitedSink::new(self, limiter, jitter)
77 }
78}
79
/// Where the rate-limited sink's `poll_ready` state machine currently stands.
#[derive(Debug)]
enum State {
    /// No permit is held: the limiter has to be consulted before the next item.
    NotReady,
    /// The limiter refused; the `delay` timer must fire before asking it again.
    Wait,
    /// A permit was granted: the sink is waiting for the inner sink and `start_send`.
    Ready,
}
86
87pub struct RatelimitedSink<
90 'a,
91 Item,
92 S: Sink<Item>,
93 D: DirectStateStore,
94 C: clock::ReasonablyRealtime,
95 MW: RateLimitingMiddleware<C::Instant, NegativeOutcome = NotUntil<C::Instant>>,
96> {
97 inner: S,
98 state: State,
99 limiter: &'a RateLimiter<NotKeyed, D, C, MW>,
100 delay: Delay,
101 jitter: Jitter,
102 phantom: PhantomData<Item>,
103}
104
105impl<
107 'a,
108 Item,
109 S: Sink<Item>,
110 D: DirectStateStore,
111 C: clock::ReasonablyRealtime,
112 MW: RateLimitingMiddleware<C::Instant, NegativeOutcome = NotUntil<C::Instant>>,
113 > RatelimitedSink<'a, Item, S, D, C, MW>
114{
115 fn new(inner: S, limiter: &'a RateLimiter<NotKeyed, D, C, MW>, jitter: Jitter) -> Self {
116 RatelimitedSink {
117 inner,
118 limiter,
119 delay: Delay::new(Default::default()),
120 state: State::NotReady,
121 jitter,
122 phantom: PhantomData,
123 }
124 }
125
126 pub fn get_ref(&self) -> &S {
128 &self.inner
129 }
130
131 pub fn get_mut(&mut self) -> &mut S {
145 &mut self.inner
146 }
147
148 pub fn into_inner(self) -> S {
150 self.inner
151 }
152}
153
154impl<
155 'a,
156 Item,
157 S: Sink<Item>,
158 D: DirectStateStore,
159 C: clock::ReasonablyRealtime,
160 MW: RateLimitingMiddleware<C::Instant, NegativeOutcome = NotUntil<C::Instant>>,
161 > Sink<Item> for RatelimitedSink<'a, Item, S, D, C, MW>
162where
163 S: Unpin,
164 Item: Unpin,
165{
166 type Error = S::Error;
167
168 fn poll_ready(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
169 loop {
170 match self.state {
171 State::NotReady => {
172 let reference = self.limiter.reference_reading();
173 if let Err(negative) = self.limiter.check() {
174 let earliest = negative.wait_time_with_offset(reference, self.jitter);
175 self.delay.reset(earliest);
176 let future = Pin::new(&mut self.delay);
177 match future.poll(cx) {
178 Poll::Pending => {
179 self.state = State::Wait;
180 return Poll::Pending;
181 }
182 Poll::Ready(_) => {}
183 }
184 } else {
185 self.state = State::Ready;
186 }
187 }
188 State::Wait => {
189 let future = Pin::new(&mut self.delay);
190 match future.poll(cx) {
191 Poll::Pending => {
192 return Poll::Pending;
193 }
194 Poll::Ready(_) => {
195 self.state = State::NotReady;
196 }
197 }
198 }
199 State::Ready => {
200 let inner = Pin::new(&mut self.inner);
201 return inner.poll_ready(cx);
202 }
203 }
204 }
205 }
206
207 fn start_send(mut self: Pin<&mut Self>, item: Item) -> Result<(), Self::Error> {
208 match self.state {
209 State::Wait | State::NotReady => {
210 unreachable!("Must not start_send before we're ready"); }
212 State::Ready => {
213 self.state = State::NotReady;
214 let inner = Pin::new(&mut self.inner);
215 inner.start_send(item)
216 }
217 }
218 }
219
220 fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
221 let inner = Pin::new(&mut self.inner);
222 inner.poll_flush(cx)
223 }
224
225 fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
226 let inner = Pin::new(&mut self.inner);
227 inner.poll_close(cx)
228 }
229}
230
231impl<
233 'a,
234 Item,
235 S: Stream + Sink<Item>,
236 D: DirectStateStore,
237 C: clock::ReasonablyRealtime,
238 MW: RateLimitingMiddleware<C::Instant, NegativeOutcome = NotUntil<C::Instant>>,
239 > Stream for RatelimitedSink<'a, Item, S, D, C, MW>
240where
241 S::Item: Unpin,
242 S: Unpin,
243 Item: Unpin,
244{
245 type Item = <S as Stream>::Item;
246
247 fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
248 let inner = Pin::new(&mut self.inner);
249 inner.poll_next(cx)
250 }
251
252 fn size_hint(&self) -> (usize, Option<usize>) {
253 self.inner.size_hint()
254 }
255}