// tower/load/pending_requests.rs

//! A [`Load`] implementation that measures load using the number of in-flight requests.

3#[cfg(feature = "discover")]
4use crate::discover::{Change, Discover};
5#[cfg(feature = "discover")]
6use futures_core::{ready, Stream};
7#[cfg(feature = "discover")]
8use pin_project_lite::pin_project;
9#[cfg(feature = "discover")]
10use std::pin::Pin;
11
12use super::completion::{CompleteOnResponse, TrackCompletion, TrackCompletionFuture};
13use super::Load;
14use std::sync::Arc;
15use std::task::{Context, Poll};
16use tower_service::Service;
17
18/// Measures the load of the underlying service using the number of currently-pending requests.
19#[derive(Debug)]
20pub struct PendingRequests<S, C = CompleteOnResponse> {
21    service: S,
22    ref_count: RefCount,
23    completion: C,
24}
25
/// Reference-counted token shared by a [`PendingRequests`] and every [`Handle`] it issues.
///
/// The number of live clones is the counter itself; the `()` payload carries no data.
#[derive(Clone, Debug, Default)]
struct RefCount(Arc<()>);

#[cfg(feature = "discover")]
pin_project! {
    /// Adapts a `D`-typed stream of discovered services so that every inserted
    /// service comes out wrapped in [`PendingRequests`].
    #[cfg_attr(docsrs, doc(cfg(feature = "discover")))]
    #[derive(Debug)]
    pub struct PendingRequestsDiscover<D, C = CompleteOnResponse> {
        // The underlying discovery stream; accessed through a pinned projection.
        #[pin]
        discover: D,
        // Cloned into each `PendingRequests` built for an inserted service.
        completion: C,
    }
}

/// The number of requests currently pending on a given service.
///
/// This is the metric reported by [`PendingRequests`]; values compare numerically,
/// so a smaller `Count` means fewer in-flight requests.
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq, PartialOrd, Ord)]
pub struct Count(usize);

46/// Tracks an in-flight request by reference count.
47#[derive(Debug)]
48pub struct Handle(RefCount);
49
50// ===== impl PendingRequests =====
51
52impl<S, C> PendingRequests<S, C> {
53    /// Wraps an `S`-typed service so that its load is tracked by the number of pending requests.
54    pub fn new(service: S, completion: C) -> Self {
55        Self {
56            service,
57            completion,
58            ref_count: RefCount::default(),
59        }
60    }
61
62    fn handle(&self) -> Handle {
63        Handle(self.ref_count.clone())
64    }
65}
66
67impl<S, C> Load for PendingRequests<S, C> {
68    type Metric = Count;
69
70    fn load(&self) -> Count {
71        // Count the number of references that aren't `self`.
72        Count(self.ref_count.ref_count() - 1)
73    }
74}
75
76impl<S, C, Request> Service<Request> for PendingRequests<S, C>
77where
78    S: Service<Request>,
79    C: TrackCompletion<Handle, S::Response>,
80{
81    type Response = C::Output;
82    type Error = S::Error;
83    type Future = TrackCompletionFuture<S::Future, C, Handle>;
84
85    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
86        self.service.poll_ready(cx)
87    }
88
89    fn call(&mut self, req: Request) -> Self::Future {
90        TrackCompletionFuture::new(
91            self.completion.clone(),
92            self.handle(),
93            self.service.call(req),
94        )
95    }
96}
97
// ===== impl PendingRequestsDiscover =====

#[cfg(feature = "discover")]
impl<D, C> PendingRequestsDiscover<D, C> {
    /// Wraps a [`Discover`], wrapping all of its services with [`PendingRequests`].
    ///
    /// The `Request` type parameter exists only to express the bounds tying
    /// `completion` to the discovered services' response type. No value of that
    /// type is stored.
    pub fn new<Request>(discover: D, completion: C) -> Self
    where
        D: Discover,
        D::Service: Service<Request>,
        C: TrackCompletion<Handle, <D::Service as Service<Request>>::Response>,
    {
        Self {
            discover,
            completion,
        }
    }
}

#[cfg(feature = "discover")]
impl<D, C> Stream for PendingRequestsDiscover<D, C>
where
    D: Discover,
    C: Clone,
{
    type Item = Result<Change<D::Key, PendingRequests<D::Service, C>>, D::Error>;

    /// Yields the next discovery change set.
    ///
    /// Inserted services are wrapped in [`PendingRequests`] using a clone of the
    /// configured completion strategy. Removals and errors are forwarded as-is,
    /// and the stream ends when the inner discovery stream ends.
    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        let this = self.project();
        let change = match ready!(this.discover.poll_discover(cx)) {
            None => return Poll::Ready(None),
            Some(Err(error)) => return Poll::Ready(Some(Err(error))),
            Some(Ok(Change::Insert(key, svc))) => {
                Change::Insert(key, PendingRequests::new(svc, this.completion.clone()))
            }
            Some(Ok(Change::Remove(key))) => Change::Remove(key),
        };

        Poll::Ready(Some(Ok(change)))
    }
}

139// ==== RefCount ====
140
141impl RefCount {
142    pub(crate) fn ref_count(&self) -> usize {
143        Arc::strong_count(&self.0)
144    }
145}
146
#[cfg(test)]
mod tests {
    use super::*;
    use futures_util::future;
    use std::task::{Context, Poll};

    /// Minimal service that is always ready and answers every request immediately.
    struct Svc;
    impl Service<()> for Svc {
        type Response = ();
        type Error = ();
        type Future = future::Ready<Result<(), ()>>;

        fn poll_ready(&mut self, _: &mut Context<'_>) -> Poll<Result<(), ()>> {
            Poll::Ready(Ok(()))
        }

        fn call(&mut self, (): ()) -> Self::Future {
            future::ok(())
        }
    }

    /// With `CompleteOnResponse`, a request stops counting once its response
    /// future has resolved.
    #[test]
    fn default() {
        let mut svc = PendingRequests::new(Svc, CompleteOnResponse);
        assert_eq!(svc.load(), Count(0));

        let first = svc.call(());
        assert_eq!(svc.load(), Count(1));

        let second = svc.call(());
        assert_eq!(svc.load(), Count(2));

        let () = tokio_test::block_on(first).unwrap();
        assert_eq!(svc.load(), Count(1));

        let () = tokio_test::block_on(second).unwrap();
        assert_eq!(svc.load(), Count(0));
    }

    /// A completion strategy that hands the `Handle` back as the response keeps
    /// the request counted until the caller drops that handle.
    #[test]
    fn with_completion() {
        #[derive(Clone)]
        struct IntoHandle;
        impl TrackCompletion<Handle, ()> for IntoHandle {
            type Output = Handle;
            fn track_completion(&self, handle: Handle, (): ()) -> Handle {
                handle
            }
        }

        let mut svc = PendingRequests::new(Svc, IntoHandle);
        assert_eq!(svc.load(), Count(0));

        let pending = svc.call(());
        assert_eq!(svc.load(), Count(1));
        let first_handle = tokio_test::block_on(pending).unwrap();
        // The response *is* the handle, so the request is still counted.
        assert_eq!(svc.load(), Count(1));

        let pending = svc.call(());
        assert_eq!(svc.load(), Count(2));
        let second_handle = tokio_test::block_on(pending).unwrap();
        assert_eq!(svc.load(), Count(2));

        drop(second_handle);
        assert_eq!(svc.load(), Count(1));

        drop(first_handle);
        assert_eq!(svc.load(), Count(0));
    }
}