1#[cfg(feature = "discover")]
4use crate::discover::{Change, Discover};
5#[cfg(feature = "discover")]
6use futures_core::{ready, Stream};
7#[cfg(feature = "discover")]
8use pin_project_lite::pin_project;
9#[cfg(feature = "discover")]
10use std::pin::Pin;
11
12use super::completion::{CompleteOnResponse, TrackCompletion, TrackCompletionFuture};
13use super::Load;
14use std::sync::Arc;
15use std::task::{Context, Poll};
16use tower_service::Service;
17
18#[derive(Debug)]
20pub struct PendingRequests<S, C = CompleteOnResponse> {
21 service: S,
22 ref_count: RefCount,
23 completion: C,
24}
25
/// A cheaply clonable counter: its value is the number of live clones,
/// read from the strong count of the shared `Arc`.
#[derive(Default, Debug, Clone)]
struct RefCount(Arc<()>);
29
#[cfg(feature = "discover")]
pin_project! {
    /// Wraps a [`Discover`] stream so that every service it inserts is
    /// wrapped in [`PendingRequests`] with a clone of `completion`.
    #[cfg_attr(docsrs, doc(cfg(feature = "discover")))]
    #[derive(Debug)]
    pub struct PendingRequestsDiscover<D, C = CompleteOnResponse> {
        // The underlying discovery stream; structurally pinned so it can be polled.
        #[pin]
        discover: D,
        // Cloned into each newly inserted `PendingRequests`.
        completion: C,
    }
}
41
/// The number of requests currently in flight, as reported by
/// [`PendingRequests`]'s `Load` implementation.
#[derive(Debug, Default, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)]
pub struct Count(usize);
45
46#[derive(Debug)]
48pub struct Handle(RefCount);
49
50impl<S, C> PendingRequests<S, C> {
53 pub fn new(service: S, completion: C) -> Self {
55 Self {
56 service,
57 completion,
58 ref_count: RefCount::default(),
59 }
60 }
61
62 fn handle(&self) -> Handle {
63 Handle(self.ref_count.clone())
64 }
65}
66
67impl<S, C> Load for PendingRequests<S, C> {
68 type Metric = Count;
69
70 fn load(&self) -> Count {
71 Count(self.ref_count.ref_count() - 1)
73 }
74}
75
76impl<S, C, Request> Service<Request> for PendingRequests<S, C>
77where
78 S: Service<Request>,
79 C: TrackCompletion<Handle, S::Response>,
80{
81 type Response = C::Output;
82 type Error = S::Error;
83 type Future = TrackCompletionFuture<S::Future, C, Handle>;
84
85 fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
86 self.service.poll_ready(cx)
87 }
88
89 fn call(&mut self, req: Request) -> Self::Future {
90 TrackCompletionFuture::new(
91 self.completion.clone(),
92 self.handle(),
93 self.service.call(req),
94 )
95 }
96}
97
#[cfg(feature = "discover")]
impl<D, C> PendingRequestsDiscover<D, C> {
    /// Wraps `discover` so each service it inserts is wrapped in
    /// [`PendingRequests`] with a clone of `completion`.
    ///
    /// The `Request` type parameter exists only to check, at construction
    /// time, that `completion` can track the discovered services' responses.
    pub fn new<Request>(discover: D, completion: C) -> Self
    where
        D: Discover,
        D::Service: Service<Request>,
        C: TrackCompletion<Handle, <D::Service as Service<Request>>::Response>,
    {
        Self { discover, completion }
    }
}
115
#[cfg(feature = "discover")]
impl<D, C> Stream for PendingRequestsDiscover<D, C>
where
    D: Discover,
    C: Clone,
{
    type Item = Result<Change<D::Key, PendingRequests<D::Service, C>>, D::Error>;

    /// Polls the inner discovery stream. Each inserted service is wrapped in
    /// [`PendingRequests`]; removals, errors and end-of-stream pass through.
    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        let this = self.project();

        let next = match ready!(this.discover.poll_discover(cx)) {
            Some(Ok(change)) => change,
            Some(Err(err)) => return Poll::Ready(Some(Err(err))),
            None => return Poll::Ready(None),
        };

        let wrapped = match next {
            Change::Insert(key, svc) => {
                Change::Insert(key, PendingRequests::new(svc, this.completion.clone()))
            }
            Change::Remove(key) => Change::Remove(key),
        };

        Poll::Ready(Some(Ok(wrapped)))
    }
}
138
139impl RefCount {
142 pub(crate) fn ref_count(&self) -> usize {
143 Arc::strong_count(&self.0)
144 }
145}
146
#[cfg(test)]
mod tests {
    use super::*;
    use futures_util::future;
    use std::task::{Context, Poll};

    // Trivial service: always ready, answers every call with `Ok(())`.
    struct Svc;

    impl Service<()> for Svc {
        type Response = ();
        type Error = ();
        type Future = future::Ready<Result<(), ()>>;

        fn poll_ready(&mut self, _: &mut Context<'_>) -> Poll<Result<(), ()>> {
            Poll::Ready(Ok(()))
        }

        fn call(&mut self, (): ()) -> Self::Future {
            future::ok(())
        }
    }

    // Asserts that `svc` currently reports exactly `expected` pending requests.
    fn assert_load<S, C>(svc: &PendingRequests<S, C>, expected: usize) {
        assert_eq!(svc.load(), Count(expected));
    }

    #[test]
    fn default() {
        let mut svc = PendingRequests::new(Svc, CompleteOnResponse);
        assert_load(&svc, 0);

        let first = svc.call(());
        assert_load(&svc, 1);

        let second = svc.call(());
        assert_load(&svc, 2);

        let () = tokio_test::block_on(first).unwrap();
        assert_load(&svc, 1);

        let () = tokio_test::block_on(second).unwrap();
        assert_load(&svc, 0);
    }

    #[test]
    fn with_completion() {
        // Completion policy that hands the tracking handle back to the caller,
        // so the request stays counted until the caller drops the output.
        #[derive(Clone)]
        struct IntoHandle;

        impl TrackCompletion<Handle, ()> for IntoHandle {
            type Output = Handle;
            fn track_completion(&self, handle: Handle, (): ()) -> Handle {
                handle
            }
        }

        let mut svc = PendingRequests::new(Svc, IntoHandle);
        assert_load(&svc, 0);

        let pending = svc.call(());
        assert_load(&svc, 1);
        let first = tokio_test::block_on(pending).unwrap();
        assert_load(&svc, 1);

        let pending = svc.call(());
        assert_load(&svc, 2);
        let second = tokio_test::block_on(pending).unwrap();
        assert_load(&svc, 2);

        drop(second);
        assert_load(&svc, 1);

        drop(first);
        assert_load(&svc, 0);
    }
}