1use std::prelude::v1::*;
2
3use crate::{
4 clock,
5 middleware::RateLimitingMiddleware,
6 state::{DirectStateStore, NotKeyed},
7 Jitter, NotUntil, RateLimiter,
8};
9use futures_timer::Delay;
10use futures_util::task::{Context, Poll};
11use futures_util::{Future, Sink, Stream};
12use std::marker::PhantomData;
13use std::pin::Pin;
14
15pub trait SinkRateLimitExt<Item, S>: Sink<Item>
17where
18 S: Sink<Item>,
19{
20 fn ratelimit_sink<
22 D: DirectStateStore,
23 C: clock::ReasonablyRealtime,
24 MW: RateLimitingMiddleware<C::Instant, NegativeOutcome = NotUntil<C::Instant>>,
25 >(
26 self,
27 limiter: &'_ RateLimiter<NotKeyed, D, C, MW>,
28 ) -> RatelimitedSink<'_, Item, S, D, C, MW>
29 where
30 Self: Sized;
31
32 #[cfg(feature = "jitter")]
35 fn ratelimit_sink_with_jitter<
36 D: DirectStateStore,
37 C: clock::ReasonablyRealtime,
38 MW: RateLimitingMiddleware<C::Instant, NegativeOutcome = NotUntil<C::Instant>>,
39 >(
40 self,
41 limiter: &'_ RateLimiter<NotKeyed, D, C, MW>,
42 jitter: Jitter,
43 ) -> RatelimitedSink<'_, Item, S, D, C, MW>
44 where
45 Self: Sized;
46}
47
48impl<Item, S: Sink<Item>> SinkRateLimitExt<Item, S> for S {
49 fn ratelimit_sink<
50 D: DirectStateStore,
51 C: clock::ReasonablyRealtime,
52 MW: RateLimitingMiddleware<C::Instant, NegativeOutcome = NotUntil<C::Instant>>,
53 >(
54 self,
55 limiter: &RateLimiter<NotKeyed, D, C, MW>,
56 ) -> RatelimitedSink<'_, Item, S, D, C, MW>
57 where
58 Self: Sized,
59 {
60 RatelimitedSink::new(self, limiter, Jitter::NONE)
61 }
62
63 #[cfg(feature = "jitter")]
64 fn ratelimit_sink_with_jitter<
65 D: DirectStateStore,
66 C: clock::ReasonablyRealtime,
67 MW: RateLimitingMiddleware<C::Instant, NegativeOutcome = NotUntil<C::Instant>>,
68 >(
69 self,
70 limiter: &RateLimiter<NotKeyed, D, C, MW>,
71 jitter: Jitter,
72 ) -> RatelimitedSink<'_, Item, S, D, C, MW>
73 where
74 Self: Sized,
75 {
76 RatelimitedSink::new(self, limiter, jitter)
77 }
78}
79
/// Progress of `RatelimitedSink::poll_ready` towards accepting the next item.
#[derive(Debug)]
enum State {
    /// The limiter has not yet admitted the next item. `poll_ready` will call `check`.
    /// This is the initial state, and it is re-entered after every `start_send`.
    NotReady,
    /// The limiter refused. `poll_ready` waits for the internal timer to fire before
    /// checking again.
    Wait,
    /// The limiter admitted an item. `poll_ready` now only waits on the wrapped sink, and
    /// `start_send` is permitted.
    Ready,
}
86
87pub struct RatelimitedSink<
90 'a,
91 Item,
92 S: Sink<Item>,
93 D: DirectStateStore,
94 C: clock::ReasonablyRealtime,
95 MW: RateLimitingMiddleware<C::Instant, NegativeOutcome = NotUntil<C::Instant>>,
96> {
97 inner: S,
98 state: State,
99 limiter: &'a RateLimiter<NotKeyed, D, C, MW>,
100 delay: Delay,
101 jitter: Jitter,
102 phantom: PhantomData<Item>,
103}
104
105impl<
107 'a,
108 Item,
109 S: Sink<Item>,
110 D: DirectStateStore,
111 C: clock::ReasonablyRealtime,
112 MW: RateLimitingMiddleware<C::Instant, NegativeOutcome = NotUntil<C::Instant>>,
113 > RatelimitedSink<'a, Item, S, D, C, MW>
114{
115 fn new(inner: S, limiter: &'a RateLimiter<NotKeyed, D, C, MW>, jitter: Jitter) -> Self {
116 RatelimitedSink {
117 inner,
118 limiter,
119 delay: Delay::new(Default::default()),
120 state: State::NotReady,
121 jitter,
122 phantom: PhantomData,
123 }
124 }
125
126 pub fn get_ref(&self) -> &S {
128 &self.inner
129 }
130
131 pub fn get_mut(&mut self) -> &mut S {
145 &mut self.inner
146 }
147
148 pub fn into_inner(self) -> S {
150 self.inner
151 }
152}
153
154impl<
155 Item,
156 S: Sink<Item>,
157 D: DirectStateStore,
158 C: clock::ReasonablyRealtime,
159 MW: RateLimitingMiddleware<C::Instant, NegativeOutcome = NotUntil<C::Instant>>,
160 > Sink<Item> for RatelimitedSink<'_, Item, S, D, C, MW>
161where
162 S: Unpin,
163 Item: Unpin,
164{
165 type Error = S::Error;
166
167 fn poll_ready(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
168 loop {
169 match self.state {
170 State::NotReady => {
171 let reference = self.limiter.reference_reading();
172 if let Err(negative) = self.limiter.check() {
173 let earliest = negative.wait_time_with_offset(reference, self.jitter);
174 self.delay.reset(earliest);
175 let future = Pin::new(&mut self.delay);
176 match future.poll(cx) {
177 Poll::Pending => {
178 self.state = State::Wait;
179 return Poll::Pending;
180 }
181 Poll::Ready(_) => {}
182 }
183 } else {
184 self.state = State::Ready;
185 }
186 }
187 State::Wait => {
188 let future = Pin::new(&mut self.delay);
189 match future.poll(cx) {
190 Poll::Pending => {
191 return Poll::Pending;
192 }
193 Poll::Ready(_) => {
194 self.state = State::NotReady;
195 }
196 }
197 }
198 State::Ready => {
199 let inner = Pin::new(&mut self.inner);
200 return inner.poll_ready(cx);
201 }
202 }
203 }
204 }
205
206 fn start_send(mut self: Pin<&mut Self>, item: Item) -> Result<(), Self::Error> {
207 match self.state {
208 State::Wait | State::NotReady => {
209 unreachable!("Must not start_send before we're ready"); }
211 State::Ready => {
212 self.state = State::NotReady;
213 let inner = Pin::new(&mut self.inner);
214 inner.start_send(item)
215 }
216 }
217 }
218
219 fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
220 let inner = Pin::new(&mut self.inner);
221 inner.poll_flush(cx)
222 }
223
224 fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
225 let inner = Pin::new(&mut self.inner);
226 inner.poll_close(cx)
227 }
228}
229
230impl<
232 Item,
233 S: Stream + Sink<Item>,
234 D: DirectStateStore,
235 C: clock::ReasonablyRealtime,
236 MW: RateLimitingMiddleware<C::Instant, NegativeOutcome = NotUntil<C::Instant>>,
237 > Stream for RatelimitedSink<'_, Item, S, D, C, MW>
238where
239 S::Item: Unpin,
240 S: Unpin,
241 Item: Unpin,
242{
243 type Item = <S as Stream>::Item;
244
245 fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
246 let inner = Pin::new(&mut self.inner);
247 inner.poll_next(cx)
248 }
249
250 fn size_hint(&self) -> (usize, Option<usize>) {
251 self.inner.size_hint()
252 }
253}