1#[cfg(feature = "discover")]
4use crate::discover::{Change, Discover};
5#[cfg(feature = "discover")]
6use futures_core::{ready, Stream};
7#[cfg(feature = "discover")]
8use pin_project_lite::pin_project;
9#[cfg(feature = "discover")]
10use std::pin::Pin;
11
12use super::completion::{CompleteOnResponse, TrackCompletion, TrackCompletionFuture};
13use super::Load;
14use std::sync::Arc;
15use std::task::{Context, Poll};
16use tower_service::Service;
17
18#[derive(Debug)]
20pub struct PendingRequests<S, C = CompleteOnResponse> {
21 service: S,
22 ref_count: RefCount,
23 completion: C,
24}
25
/// A shared, reference-counted token used to track outstanding requests.
///
/// Counting is done through the `Arc`'s strong count: cloning this value adds
/// one reference, and dropping a clone removes one.
#[derive(Default, Debug, Clone)]
struct RefCount(Arc<()>);
29
#[cfg(feature = "discover")]
pin_project! {
    /// Adapts a [`Discover`] so that every service it inserts is wrapped in
    /// [`PendingRequests`], using a clone of `completion` for each one.
    #[cfg_attr(docsrs, doc(cfg(feature = "discover")))]
    #[derive(Debug)]
    pub struct PendingRequestsDiscover<D, C = CompleteOnResponse> {
        // Structurally pinned: polled through a pinned projection.
        #[pin]
        discover: D,
        // Template cloned into each newly discovered `PendingRequests`.
        completion: C,
    }
}
41
/// The number of requests currently in flight for a [`PendingRequests`]
/// service; larger values compare as more loaded.
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq, PartialOrd, Ord)]
pub struct Count(usize);
45
46#[derive(Debug)]
48#[allow(dead_code)]
49pub struct Handle(RefCount);
50
51impl<S, C> PendingRequests<S, C> {
54 pub fn new(service: S, completion: C) -> Self {
56 Self {
57 service,
58 completion,
59 ref_count: RefCount::default(),
60 }
61 }
62
63 fn handle(&self) -> Handle {
64 Handle(self.ref_count.clone())
65 }
66}
67
68impl<S, C> Load for PendingRequests<S, C> {
69 type Metric = Count;
70
71 fn load(&self) -> Count {
72 Count(self.ref_count.ref_count() - 1)
74 }
75}
76
77impl<S, C, Request> Service<Request> for PendingRequests<S, C>
78where
79 S: Service<Request>,
80 C: TrackCompletion<Handle, S::Response>,
81{
82 type Response = C::Output;
83 type Error = S::Error;
84 type Future = TrackCompletionFuture<S::Future, C, Handle>;
85
86 fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
87 self.service.poll_ready(cx)
88 }
89
90 fn call(&mut self, req: Request) -> Self::Future {
91 TrackCompletionFuture::new(
92 self.completion.clone(),
93 self.handle(),
94 self.service.call(req),
95 )
96 }
97}
98
#[cfg(feature = "discover")]
impl<D, C> PendingRequestsDiscover<D, C> {
    /// Wraps `discover` so that each service it inserts is measured with
    /// [`PendingRequests`].
    ///
    /// The `Request` type parameter only constrains `completion`: it must be
    /// able to track responses of the discovered services. It is not used
    /// otherwise.
    pub const fn new<Request>(discover: D, completion: C) -> Self
    where
        D: Discover,
        D::Service: Service<Request>,
        C: TrackCompletion<Handle, <D::Service as Service<Request>>::Response>,
    {
        Self {
            discover,
            completion,
        }
    }
}
116
#[cfg(feature = "discover")]
impl<D, C> Stream for PendingRequestsDiscover<D, C>
where
    D: Discover,
    C: Clone,
{
    type Item = Result<Change<D::Key, PendingRequests<D::Service, C>>, D::Error>;

    /// Polls the inner discover.
    ///
    /// Inserted services are wrapped in a [`PendingRequests`] built with a
    /// clone of the stored completion strategy. Removals and errors pass
    /// through unchanged, and the stream ends when the inner one does.
    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        let projected = self.project();
        let next = ready!(projected.discover.poll_discover(cx));

        let change = match next {
            None => return Poll::Ready(None),
            Some(Err(err)) => return Poll::Ready(Some(Err(err))),
            Some(Ok(Change::Insert(key, svc))) => {
                let tracked = PendingRequests::new(svc, projected.completion.clone());
                Change::Insert(key, tracked)
            }
            Some(Ok(Change::Remove(key))) => Change::Remove(key),
        };

        Poll::Ready(Some(Ok(change)))
    }
}
139
140impl RefCount {
143 pub(crate) fn ref_count(&self) -> usize {
144 Arc::strong_count(&self.0)
145 }
146}
147
#[cfg(test)]
mod tests {
    use super::*;
    use futures_util::future;
    use std::task::{Context, Poll};

    /// A trivial service that is always ready and immediately succeeds with `()`.
    struct Svc;

    impl Service<()> for Svc {
        type Response = ();
        type Error = ();
        type Future = future::Ready<Result<(), ()>>;

        fn poll_ready(&mut self, _: &mut Context<'_>) -> Poll<Result<(), ()>> {
            Poll::Ready(Ok(()))
        }

        fn call(&mut self, (): ()) -> Self::Future {
            future::ok(())
        }
    }

    /// With `CompleteOnResponse`, a request stops counting as soon as its
    /// response future resolves.
    #[test]
    fn default() {
        let mut service = PendingRequests::new(Svc, CompleteOnResponse);
        assert_eq!(service.load(), Count(0));

        let first = service.call(());
        assert_eq!(service.load(), Count(1));

        let second = service.call(());
        assert_eq!(service.load(), Count(2));

        let () = tokio_test::block_on(first).unwrap();
        assert_eq!(service.load(), Count(1));

        let () = tokio_test::block_on(second).unwrap();
        assert_eq!(service.load(), Count(0));
    }

    /// When the completion strategy hands the `Handle` back to the caller,
    /// the request keeps counting until the caller drops that handle.
    #[test]
    fn with_completion() {
        #[derive(Clone)]
        struct IntoHandle;

        impl TrackCompletion<Handle, ()> for IntoHandle {
            type Output = Handle;

            fn track_completion(&self, i: Handle, (): ()) -> Handle {
                i
            }
        }

        let mut service = PendingRequests::new(Svc, IntoHandle);
        assert_eq!(service.load(), Count(0));

        let pending = service.call(());
        assert_eq!(service.load(), Count(1));
        let held_first = tokio_test::block_on(pending).unwrap();
        assert_eq!(service.load(), Count(1));

        let pending = service.call(());
        assert_eq!(service.load(), Count(2));
        let held_second = tokio_test::block_on(pending).unwrap();
        assert_eq!(service.load(), Count(2));

        drop(held_second);
        assert_eq!(service.load(), Count(1));

        drop(held_first);
        assert_eq!(service.load(), Count(0));
    }
}