1use super::Rate;
2use futures_core::ready;
3use std::{
4 future::Future,
5 pin::Pin,
6 task::{Context, Poll},
7};
8use tokio::time::{Instant, Sleep};
9use tower_service::Service;
1011/// Enforces a rate limit on the number of requests the underlying
12/// service can handle over a period of time.
13#[derive(Debug)]
14pub struct RateLimit<T> {
15 inner: T,
16 rate: Rate,
17 state: State,
18 sleep: Pin<Box<Sleep>>,
19}
2021#[derive(Debug)]
22enum State {
23// The service has hit its limit
24Limited,
25 Ready { until: Instant, rem: u64 },
26}
2728impl<T> RateLimit<T> {
29/// Create a new rate limiter
30pub fn new(inner: T, rate: Rate) -> Self {
31let until = Instant::now();
32let state = State::Ready {
33 until,
34 rem: rate.num(),
35 };
3637 RateLimit {
38 inner,
39 rate,
40 state,
41// The sleep won't actually be used with this duration, but
42 // we create it eagerly so that we can reset it in place rather than
43 // `Box::pin`ning a new `Sleep` every time we need one.
44sleep: Box::pin(tokio::time::sleep_until(until)),
45 }
46 }
4748/// Get a reference to the inner service
49pub fn get_ref(&self) -> &T {
50&self.inner
51 }
5253/// Get a mutable reference to the inner service
54pub fn get_mut(&mut self) -> &mut T {
55&mut self.inner
56 }
5758/// Consume `self`, returning the inner service
59pub fn into_inner(self) -> T {
60self.inner
61 }
62}
6364impl<S, Request> Service<Request> for RateLimit<S>
65where
66S: Service<Request>,
67{
68type Response = S::Response;
69type Error = S::Error;
70type Future = S::Future;
7172fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
73match self.state {
74 State::Ready { .. } => return Poll::Ready(ready!(self.inner.poll_ready(cx))),
75 State::Limited => {
76if Pin::new(&mut self.sleep).poll(cx).is_pending() {
77tracing::trace!("rate limit exceeded; sleeping.");
78return Poll::Pending;
79 }
80 }
81 }
8283self.state = State::Ready {
84 until: Instant::now() + self.rate.per(),
85 rem: self.rate.num(),
86 };
8788 Poll::Ready(ready!(self.inner.poll_ready(cx)))
89 }
9091fn call(&mut self, request: Request) -> Self::Future {
92match self.state {
93 State::Ready { mut until, mut rem } => {
94let now = Instant::now();
9596// If the period has elapsed, reset it.
97if now >= until {
98 until = now + self.rate.per();
99 rem = self.rate.num();
100 }
101102if rem > 1 {
103 rem -= 1;
104self.state = State::Ready { until, rem };
105 } else {
106// The service is disabled until further notice
107 // Reset the sleep future in place, so that we don't have to
108 // deallocate the existing box and allocate a new one.
109self.sleep.as_mut().reset(until);
110self.state = State::Limited;
111 }
112113// Call the inner future
114self.inner.call(request)
115 }
116 State::Limited => panic!("service not ready; poll_ready must be called first"),
117 }
118 }
119}
#[cfg(feature = "load")]
#[cfg_attr(docsrs, doc(cfg(feature = "load")))]
impl<S> crate::load::Load for RateLimit<S>
where
    S: crate::load::Load,
{
    type Metric = S::Metric;

    /// Reports the wrapped service's load metric unchanged.
    fn load(&self) -> S::Metric {
        self.inner.load()
    }
}