1use super::future::ResponseFuture;
2use tokio::sync::{OwnedSemaphorePermit, Semaphore};
3use tokio_util::sync::PollSemaphore;
4use tower_service::Service;
56use futures_core::ready;
7use std::{
8 sync::Arc,
9 task::{Context, Poll},
10};
1112/// Enforces a limit on the concurrent number of requests the underlying
13/// service can handle.
14#[derive(Debug)]
15pub struct ConcurrencyLimit<T> {
16 inner: T,
17 semaphore: PollSemaphore,
18/// The currently acquired semaphore permit, if there is sufficient
19 /// concurrency to send a new request.
20 ///
21 /// The permit is acquired in `poll_ready`, and taken in `call` when sending
22 /// a new request.
23permit: Option<OwnedSemaphorePermit>,
24}
2526impl<T> ConcurrencyLimit<T> {
27/// Create a new concurrency limiter.
28pub fn new(inner: T, max: usize) -> Self {
29Self::with_semaphore(inner, Arc::new(Semaphore::new(max)))
30 }
3132/// Create a new concurrency limiter with a provided shared semaphore
33pub fn with_semaphore(inner: T, semaphore: Arc<Semaphore>) -> Self {
34 ConcurrencyLimit {
35 inner,
36 semaphore: PollSemaphore::new(semaphore),
37 permit: None,
38 }
39 }
4041/// Get a reference to the inner service
42pub fn get_ref(&self) -> &T {
43&self.inner
44 }
4546/// Get a mutable reference to the inner service
47pub fn get_mut(&mut self) -> &mut T {
48&mut self.inner
49 }
5051/// Consume `self`, returning the inner service
52pub fn into_inner(self) -> T {
53self.inner
54 }
55}
5657impl<S, Request> Service<Request> for ConcurrencyLimit<S>
58where
59S: Service<Request>,
60{
61type Response = S::Response;
62type Error = S::Error;
63type Future = ResponseFuture<S::Future>;
6465fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
66// If we haven't already acquired a permit from the semaphore, try to
67 // acquire one first.
68if self.permit.is_none() {
69self.permit = ready!(self.semaphore.poll_acquire(cx));
70debug_assert!(
71self.permit.is_some(),
72"ConcurrencyLimit semaphore is never closed, so `poll_acquire` \
73 should never fail",
74 );
75 }
7677// Once we've acquired a permit (or if we already had one), poll the
78 // inner service.
79self.inner.poll_ready(cx)
80 }
8182fn call(&mut self, request: Request) -> Self::Future {
83// Take the permit
84let permit = self
85.permit
86 .take()
87 .expect("max requests in-flight; poll_ready must be called first");
8889// Call the inner service
90let future = self.inner.call(request);
9192 ResponseFuture::new(future, permit)
93 }
94}
9596impl<T: Clone> Clone for ConcurrencyLimit<T> {
97fn clone(&self) -> Self {
98// Since we hold an `OwnedSemaphorePermit`, we can't derive `Clone`.
99 // Instead, when cloning the service, create a new service with the
100 // same semaphore, but with the permit in the un-acquired state.
101Self {
102 inner: self.inner.clone(),
103 semaphore: self.semaphore.clone(),
104 permit: None,
105 }
106 }
107}
#[cfg(feature = "load")]
#[cfg_attr(docsrs, doc(cfg(feature = "load")))]
impl<S> crate::load::Load for ConcurrencyLimit<S>
where
    S: crate::load::Load,
{
    type Metric = S::Metric;

    /// Reports the wrapped service's load unchanged; the limiter contributes
    /// nothing of its own to the metric.
    fn load(&self) -> Self::Metric {
        <S as crate::load::Load>::load(&self.inner)
    }
}