tower/limit/rate/service.rs

1use super::Rate;
2use futures_core::ready;
3use std::{
4    future::Future,
5    pin::Pin,
6    task::{Context, Poll},
7};
8use tokio::time::{Instant, Sleep};
9use tower_service::Service;
10
11/// Enforces a rate limit on the number of requests the underlying
12/// service can handle over a period of time.
13#[derive(Debug)]
14pub struct RateLimit<T> {
15    inner: T,
16    rate: Rate,
17    state: State,
18    sleep: Pin<Box<Sleep>>,
19}
20
21#[derive(Debug)]
22enum State {
23    // The service has hit its limit
24    Limited,
25    Ready { until: Instant, rem: u64 },
26}
27
28impl<T> RateLimit<T> {
29    /// Create a new rate limiter
30    pub fn new(inner: T, rate: Rate) -> Self {
31        let until = Instant::now();
32        let state = State::Ready {
33            until,
34            rem: rate.num(),
35        };
36
37        RateLimit {
38            inner,
39            rate,
40            state,
41            // The sleep won't actually be used with this duration, but
42            // we create it eagerly so that we can reset it in place rather than
43            // `Box::pin`ning a new `Sleep` every time we need one.
44            sleep: Box::pin(tokio::time::sleep_until(until)),
45        }
46    }
47
48    /// Get a reference to the inner service
49    pub fn get_ref(&self) -> &T {
50        &self.inner
51    }
52
53    /// Get a mutable reference to the inner service
54    pub fn get_mut(&mut self) -> &mut T {
55        &mut self.inner
56    }
57
58    /// Consume `self`, returning the inner service
59    pub fn into_inner(self) -> T {
60        self.inner
61    }
62}
63
64impl<S, Request> Service<Request> for RateLimit<S>
65where
66    S: Service<Request>,
67{
68    type Response = S::Response;
69    type Error = S::Error;
70    type Future = S::Future;
71
72    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
73        match self.state {
74            State::Ready { .. } => return Poll::Ready(ready!(self.inner.poll_ready(cx))),
75            State::Limited => {
76                if Pin::new(&mut self.sleep).poll(cx).is_pending() {
77                    tracing::trace!("rate limit exceeded; sleeping.");
78                    return Poll::Pending;
79                }
80            }
81        }
82
83        self.state = State::Ready {
84            until: Instant::now() + self.rate.per(),
85            rem: self.rate.num(),
86        };
87
88        Poll::Ready(ready!(self.inner.poll_ready(cx)))
89    }
90
91    fn call(&mut self, request: Request) -> Self::Future {
92        match self.state {
93            State::Ready { mut until, mut rem } => {
94                let now = Instant::now();
95
96                // If the period has elapsed, reset it.
97                if now >= until {
98                    until = now + self.rate.per();
99                    rem = self.rate.num();
100                }
101
102                if rem > 1 {
103                    rem -= 1;
104                    self.state = State::Ready { until, rem };
105                } else {
106                    // The service is disabled until further notice
107                    // Reset the sleep future in place, so that we don't have to
108                    // deallocate the existing box and allocate a new one.
109                    self.sleep.as_mut().reset(until);
110                    self.state = State::Limited;
111                }
112
113                // Call the inner future
114                self.inner.call(request)
115            }
116            State::Limited => panic!("service not ready; poll_ready must be called first"),
117        }
118    }
119}
120
#[cfg(feature = "load")]
#[cfg_attr(docsrs, doc(cfg(feature = "load")))]
impl<S> crate::load::Load for RateLimit<S>
where
    S: crate::load::Load,
{
    type Metric = S::Metric;

    /// The limiter contributes no load metric of its own. The wrapped
    /// service's metric is reported unchanged.
    fn load(&self) -> Self::Metric {
        S::load(&self.inner)
    }
}