// tower/load/pending_requests.rs

//! A [`Load`] implementation that measures load using the number of in-flight requests.
3#[cfg(feature = "discover")]
4use crate::discover::{Change, Discover};
5#[cfg(feature = "discover")]
6use futures_core::{ready, Stream};
7#[cfg(feature = "discover")]
8use pin_project_lite::pin_project;
9#[cfg(feature = "discover")]
10use std::pin::Pin;
11
12use super::completion::{CompleteOnResponse, TrackCompletion, TrackCompletionFuture};
13use super::Load;
14use std::sync::Arc;
15use std::task::{Context, Poll};
16use tower_service::Service;
17
18/// Measures the load of the underlying service using the number of currently-pending requests.
19#[derive(Debug)]
20pub struct PendingRequests<S, C = CompleteOnResponse> {
21    service: S,
22    ref_count: RefCount,
23    completion: C,
24}
25
/// Reference counter shared by a [`PendingRequests`] and the [`Handle`]s it issues.
///
/// The [`PendingRequests`] holds one reference, and each live [`Handle`] holds another.
#[derive(Debug, Default, Clone)]
struct RefCount(Arc<()>);
#[cfg(feature = "discover")]
pin_project! {
    /// Adapts a `D`-typed stream of discovered services so that every inserted
    /// service is wrapped in [`PendingRequests`] using a clone of `completion`.
    #[cfg_attr(docsrs, doc(cfg(feature = "discover")))]
    #[derive(Debug)]
    pub struct PendingRequestsDiscover<D, C = CompleteOnResponse> {
        // The underlying discovery stream; structurally pinned so it can be polled.
        #[pin]
        discover: D,
        // Completion policy cloned into each newly discovered service.
        completion: C,
    }
}
/// How many requests to a service have been issued but not yet released.
///
/// Values compare by the wrapped number, so a smaller `Count` indicates a less
/// loaded service.
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq, PartialOrd, Ord)]
pub struct Count(usize);
46/// Tracks an in-flight request by reference count.
47#[derive(Debug)]
48#[allow(dead_code)]
49pub struct Handle(RefCount);
50
// ===== impl PendingRequests =====
53impl<S, C> PendingRequests<S, C> {
54    /// Wraps an `S`-typed service so that its load is tracked by the number of pending requests.
55    pub fn new(service: S, completion: C) -> Self {
56        Self {
57            service,
58            completion,
59            ref_count: RefCount::default(),
60        }
61    }
62
63    fn handle(&self) -> Handle {
64        Handle(self.ref_count.clone())
65    }
66}
67
68impl<S, C> Load for PendingRequests<S, C> {
69    type Metric = Count;
70
71    fn load(&self) -> Count {
72        // Count the number of references that aren't `self`.
73        Count(self.ref_count.ref_count() - 1)
74    }
75}
76
77impl<S, C, Request> Service<Request> for PendingRequests<S, C>
78where
79    S: Service<Request>,
80    C: TrackCompletion<Handle, S::Response>,
81{
82    type Response = C::Output;
83    type Error = S::Error;
84    type Future = TrackCompletionFuture<S::Future, C, Handle>;
85
86    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
87        self.service.poll_ready(cx)
88    }
89
90    fn call(&mut self, req: Request) -> Self::Future {
91        TrackCompletionFuture::new(
92            self.completion.clone(),
93            self.handle(),
94            self.service.call(req),
95        )
96    }
97}
98
// ===== impl PendingRequestsDiscover =====
#[cfg(feature = "discover")]
impl<D, C> PendingRequestsDiscover<D, C> {
    /// Builds a discovery stream that wraps every service produced by `discover`
    /// in [`PendingRequests`].
    ///
    /// The `Request` type parameter and its bounds only check, at construction
    /// time, that `completion` can track responses of the discovered services.
    pub const fn new<Request>(discover: D, completion: C) -> Self
    where
        D: Discover,
        D::Service: Service<Request>,
        C: TrackCompletion<Handle, <D::Service as Service<Request>>::Response>,
    {
        Self { discover, completion }
    }
}
#[cfg(feature = "discover")]
impl<D, C> Stream for PendingRequestsDiscover<D, C>
where
    D: Discover,
    C: Clone,
{
    type Item = Result<Change<D::Key, PendingRequests<D::Service, C>>, D::Error>;

    /// Polls the inner discovery stream.
    ///
    /// Inserted services are wrapped in [`PendingRequests`]; removals and errors
    /// pass through unchanged.
    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        let this = self.project();
        let change = match ready!(this.discover.poll_discover(cx)) {
            // The underlying discovery stream is exhausted.
            None => return Poll::Ready(None),
            // Discovery errors are forwarded to the caller as-is.
            Some(Err(err)) => return Poll::Ready(Some(Err(err))),
            // Each newly discovered service gets its own pending-request counter.
            Some(Ok(Change::Insert(key, svc))) => {
                Change::Insert(key, PendingRequests::new(svc, this.completion.clone()))
            }
            Some(Ok(Change::Remove(key))) => Change::Remove(key),
        };
        Poll::Ready(Some(Ok(change)))
    }
}
// ===== impl RefCount =====
142impl RefCount {
143    pub(crate) fn ref_count(&self) -> usize {
144        Arc::strong_count(&self.0)
145    }
146}
147
#[cfg(test)]
mod tests {
    use super::*;
    use futures_util::future;
    use std::task::{Context, Poll};

    /// A trivial service whose response future is immediately ready.
    struct Svc;

    impl Service<()> for Svc {
        type Response = ();
        type Error = ();
        type Future = future::Ready<Result<(), ()>>;

        fn poll_ready(&mut self, _: &mut Context<'_>) -> Poll<Result<(), ()>> {
            Poll::Ready(Ok(()))
        }

        fn call(&mut self, (): ()) -> Self::Future {
            future::ok(())
        }
    }

    #[test]
    fn default() {
        let mut svc = PendingRequests::new(Svc, CompleteOnResponse);
        assert_eq!(svc.load(), Count(0));

        // Each call is counted as soon as its future is created.
        let first = svc.call(());
        assert_eq!(svc.load(), Count(1));
        let second = svc.call(());
        assert_eq!(svc.load(), Count(2));

        // Driving a response future to completion releases its slot.
        let () = tokio_test::block_on(first).unwrap();
        assert_eq!(svc.load(), Count(1));
        let () = tokio_test::block_on(second).unwrap();
        assert_eq!(svc.load(), Count(0));
    }

    #[test]
    fn with_completion() {
        /// Completion policy that returns the request's `Handle` to the caller.
        #[derive(Clone)]
        struct IntoHandle;

        impl TrackCompletion<Handle, ()> for IntoHandle {
            type Output = Handle;

            fn track_completion(&self, handle: Handle, (): ()) -> Handle {
                handle
            }
        }

        let mut svc = PendingRequests::new(Svc, IntoHandle);
        assert_eq!(svc.load(), Count(0));

        // The handle outlives the response future, so resolving the future does
        // not lower the load.
        let pending = svc.call(());
        assert_eq!(svc.load(), Count(1));
        let first = tokio_test::block_on(pending).unwrap();
        assert_eq!(svc.load(), Count(1));

        let pending = svc.call(());
        assert_eq!(svc.load(), Count(2));
        let second = tokio_test::block_on(pending).unwrap();
        assert_eq!(svc.load(), Count(2));

        // Dropping the returned handles is what finally lowers the load.
        drop(second);
        assert_eq!(svc.load(), Count(1));
        drop(first);
        assert_eq!(svc.load(), Count(0));
    }
}