1#[cfg(feature = "discover")]
4use crate::discover::{Change, Discover};
5#[cfg(feature = "discover")]
6use futures_core::{ready, Stream};
7#[cfg(feature = "discover")]
8use pin_project_lite::pin_project;
9#[cfg(feature = "discover")]
10use std::pin::Pin;
11
12use super::completion::{CompleteOnResponse, TrackCompletion, TrackCompletionFuture};
13use super::Load;
14use std::task::{Context, Poll};
15use std::{
16 sync::{Arc, Mutex},
17 time::Duration,
18};
19use tokio::time::Instant;
20use tower_service::Service;
21use tracing::trace;
22
/// Measures the load of the wrapped service using a Peak-EWMA round-trip-time
/// estimate.
///
/// The cost reported by [`Load::load`] is the current (decayed) RTT estimate
/// multiplied by the number of outstanding requests plus one.
#[derive(Debug)]
pub struct PeakEwma<S, C = CompleteOnResponse> {
    // The wrapped service that requests are forwarded to.
    service: S,
    // Decay time constant, in nanoseconds, applied by `RttEstimate::update`.
    decay_ns: f64,
    // Shared RTT estimate. A clone of this `Arc` is stored in every `Handle`
    // created by `handle()`, so the strong count also reflects how many of
    // those handles are still alive.
    rtt_estimate: Arc<Mutex<RttEstimate>>,
    // Completion strategy, cloned into each response future.
    completion: C,
}
50
#[cfg(feature = "discover")]
pin_project! {
    /// Wraps a `D`-typed stream of discovered services so that each inserted
    /// service is instrumented with [`PeakEwma`].
    #[cfg_attr(docsrs, doc(cfg(feature = "discover")))]
    #[derive(Debug)]
    pub struct PeakEwmaDiscover<D, C = CompleteOnResponse> {
        // Inner discovery stream (structurally pinned).
        #[pin]
        discover: D,
        // Decay time constant in nanoseconds, handed to every `PeakEwma`.
        decay_ns: f64,
        // Initial RTT estimate for newly discovered services.
        default_rtt: Duration,
        // Completion strategy, cloned into every `PeakEwma`.
        completion: C,
    }
}
64
/// The relative cost of sending a request to a service.
///
/// Computed by [`PeakEwma`]'s `load` as the RTT estimate (in nanoseconds)
/// multiplied by the number of pending requests plus one; a smaller value
/// indicates a less loaded service.
#[derive(Copy, Clone, Debug, PartialEq, PartialOrd)]
pub struct Cost(f64);
71
/// Tracks a single request and folds its observed round-trip time into the
/// shared RTT estimate when dropped.
#[derive(Debug)]
pub struct Handle {
    // When the request was issued.
    sent_at: Instant,
    // Decay time constant in nanoseconds, forwarded to `RttEstimate::update`.
    decay_ns: f64,
    // Shared estimate owned jointly with the originating `PeakEwma`.
    rtt_estimate: Arc<Mutex<RttEstimate>>,
}
79
/// The current RTT estimate together with the time it was last updated.
#[derive(Debug)]
struct RttEstimate {
    // Time of the most recent update; used to compute the decay interval.
    update_at: Instant,
    // Current estimate, in nanoseconds.
    rtt_ns: f64,
}
86
/// Number of nanoseconds in one millisecond, used to render nanosecond
/// estimates as milliseconds in trace output.
const NANOS_PER_MILLI: f64 = 1e6;
88
89impl<S, C> PeakEwma<S, C> {
92 pub fn new(service: S, default_rtt: Duration, decay_ns: f64, completion: C) -> Self {
94 debug_assert!(decay_ns > 0.0, "decay_ns must be positive");
95 Self {
96 service,
97 decay_ns,
98 rtt_estimate: Arc::new(Mutex::new(RttEstimate::new(nanos(default_rtt)))),
99 completion,
100 }
101 }
102
103 fn handle(&self) -> Handle {
104 Handle {
105 decay_ns: self.decay_ns,
106 sent_at: Instant::now(),
107 rtt_estimate: self.rtt_estimate.clone(),
108 }
109 }
110}
111
112impl<S, C, Request> Service<Request> for PeakEwma<S, C>
113where
114 S: Service<Request>,
115 C: TrackCompletion<Handle, S::Response>,
116{
117 type Response = C::Output;
118 type Error = S::Error;
119 type Future = TrackCompletionFuture<S::Future, C, Handle>;
120
121 fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
122 self.service.poll_ready(cx)
123 }
124
125 fn call(&mut self, req: Request) -> Self::Future {
126 TrackCompletionFuture::new(
127 self.completion.clone(),
128 self.handle(),
129 self.service.call(req),
130 )
131 }
132}
133
134impl<S, C> Load for PeakEwma<S, C> {
135 type Metric = Cost;
136
137 fn load(&self) -> Self::Metric {
138 let pending = Arc::strong_count(&self.rtt_estimate) as u32 - 1;
139
140 let estimate = self.update_estimate();
143
144 let cost = Cost(estimate * f64::from(pending + 1));
145 trace!(
146 "load estimate={:.0}ms pending={} cost={:?}",
147 estimate / NANOS_PER_MILLI,
148 pending,
149 cost,
150 );
151 cost
152 }
153}
154
155impl<S, C> PeakEwma<S, C> {
156 fn update_estimate(&self) -> f64 {
157 let mut rtt = self.rtt_estimate.lock().expect("peak ewma prior_estimate");
158 rtt.decay(self.decay_ns)
159 }
160}
161
#[cfg(feature = "discover")]
impl<D, C> PeakEwmaDiscover<D, C> {
    /// Wraps `discover` so that each service it inserts is tracked by a
    /// [`PeakEwma`] seeded with `default_rtt` and decaying over `decay`.
    ///
    /// The `Request` type parameter exists only so that the bounds can check
    /// that `completion` is compatible with the discovered services' responses.
    pub fn new<Request>(discover: D, default_rtt: Duration, decay: Duration, completion: C) -> Self
    where
        D: Discover,
        D::Service: Service<Request>,
        C: TrackCompletion<Handle, <D::Service as Service<Request>>::Response>,
    {
        let decay_ns = nanos(decay);
        Self {
            discover,
            decay_ns,
            default_rtt,
            completion,
        }
    }
}
187
#[cfg(feature = "discover")]
impl<D, C> Stream for PeakEwmaDiscover<D, C>
where
    D: Discover,
    C: Clone,
{
    type Item = Result<Change<D::Key, PeakEwma<D::Service, C>>, D::Error>;

    /// Forwards changes from the inner discover stream. Inserted services are
    /// wrapped in a [`PeakEwma`]; removals and errors pass through unchanged.
    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        let projected = self.project();
        let wrapped = match ready!(projected.discover.poll_discover(cx)).transpose()? {
            Some(Change::Insert(key, service)) => {
                let tracked = PeakEwma::new(
                    service,
                    *projected.default_rtt,
                    *projected.decay_ns,
                    projected.completion.clone(),
                );
                Change::Insert(key, tracked)
            }
            Some(Change::Remove(key)) => Change::Remove(key),
            None => return Poll::Ready(None),
        };

        Poll::Ready(Some(Ok(wrapped)))
    }
}
215
216impl RttEstimate {
219 fn new(rtt_ns: f64) -> Self {
220 debug_assert!(0.0 < rtt_ns, "rtt must be positive");
221 Self {
222 rtt_ns,
223 update_at: Instant::now(),
224 }
225 }
226
227 fn decay(&mut self, decay_ns: f64) -> f64 {
229 let now = Instant::now();
231 self.update(now, now, decay_ns)
232 }
233
234 fn update(&mut self, sent_at: Instant, recv_at: Instant, decay_ns: f64) -> f64 {
238 debug_assert!(
239 sent_at <= recv_at,
240 "recv_at={:?} after sent_at={:?}",
241 recv_at,
242 sent_at
243 );
244 let rtt = nanos(recv_at.saturating_duration_since(sent_at));
245
246 let now = Instant::now();
247 debug_assert!(
248 self.update_at <= now,
249 "update_at={:?} in the future",
250 self.update_at
251 );
252
253 self.rtt_ns = if self.rtt_ns < rtt {
254 trace!(
257 "update peak rtt={}ms prior={}ms",
258 rtt / NANOS_PER_MILLI,
259 self.rtt_ns / NANOS_PER_MILLI,
260 );
261 rtt
262 } else {
263 let elapsed = nanos(now.saturating_duration_since(self.update_at));
268 let decay = (-elapsed / decay_ns).exp();
269 let recency = 1.0 - decay;
270 let next_estimate = (self.rtt_ns * decay) + (rtt * recency);
271 trace!(
272 "update rtt={:03.0}ms decay={:06.0}ns; next={:03.0}ms",
273 rtt / NANOS_PER_MILLI,
274 self.rtt_ns - next_estimate,
275 next_estimate / NANOS_PER_MILLI,
276 );
277 next_estimate
278 };
279 self.update_at = now;
280
281 self.rtt_ns
282 }
283}
284
285impl Drop for Handle {
288 fn drop(&mut self) {
289 let recv_at = Instant::now();
290
291 if let Ok(mut rtt) = self.rtt_estimate.lock() {
292 rtt.update(self.sent_at, recv_at, self.decay_ns);
293 }
294 }
295}
296
/// Converts a `Duration` to a floating-point number of nanoseconds.
///
/// The whole-seconds part saturates at `u64::MAX` nanoseconds before the
/// conversion to `f64`, so extremely long durations do not overflow.
fn nanos(d: Duration) -> f64 {
    const NANOS_PER_SEC: u64 = 1_000_000_000;
    let sub_second = f64::from(d.subsec_nanos());
    let whole_seconds = d.as_secs().saturating_mul(NANOS_PER_SEC) as f64;
    sub_second + whole_seconds
}
309
#[cfg(test)]
mod tests {
    use futures_util::future;
    use std::time::Duration;
    use tokio::time;
    use tokio_test::{assert_ready, assert_ready_ok, task};

    use super::*;

    // A trivial service that is always ready and answers every request
    // immediately.
    struct Svc;
    impl Service<()> for Svc {
        type Response = ();
        type Error = ();
        type Future = future::Ready<Result<(), ()>>;

        fn poll_ready(&mut self, _: &mut Context<'_>) -> Poll<Result<(), ()>> {
            Poll::Ready(Ok(()))
        }

        fn call(&mut self, (): ()) -> Self::Future {
            future::ok(())
        }
    }

    // With no traffic, the default RTT estimate decays over time.
    #[tokio::test]
    async fn default_decay() {
        // Freeze tokio's clock so that `time::advance` fully controls elapsed time.
        time::pause();

        let svc = PeakEwma::new(
            Svc,
            Duration::from_millis(10),
            NANOS_PER_MILLI * 1_000.0,
            CompleteOnResponse,
        );
        let Cost(load) = svc.load();
        assert_eq!(load, 10.0 * NANOS_PER_MILLI);

        time::advance(Duration::from_millis(100)).await;
        let Cost(load) = svc.load();
        assert!(9.0 * NANOS_PER_MILLI < load && load < 10.0 * NANOS_PER_MILLI);

        time::advance(Duration::from_millis(100)).await;
        let Cost(load) = svc.load();
        assert!(8.0 * NANOS_PER_MILLI < load && load < 9.0 * NANOS_PER_MILLI);
    }

    // Pending requests multiply the cost, slow completions raise the peak,
    // and the estimate then decays once traffic stops.
    #[tokio::test]
    async fn compound_decay() {
        time::pause();

        let mut svc = PeakEwma::new(
            Svc,
            Duration::from_millis(20),
            NANOS_PER_MILLI * 1_000.0,
            CompleteOnResponse,
        );
        assert_eq!(svc.load(), Cost(20.0 * NANOS_PER_MILLI));

        // One request in flight: the cost is roughly twice the decayed estimate.
        time::advance(Duration::from_millis(100)).await;
        let mut rsp0 = task::spawn(svc.call(()));
        assert!(svc.load() > Cost(20.0 * NANOS_PER_MILLI));

        time::advance(Duration::from_millis(100)).await;
        let mut rsp1 = task::spawn(svc.call(()));
        assert!(svc.load() > Cost(40.0 * NANOS_PER_MILLI));

        // rsp0 completes after 200ms, which exceeds the estimate, so the
        // estimate jumps to that peak. rsp1 is still pending: 200ms * 2.
        time::advance(Duration::from_millis(100)).await;
        let () = assert_ready_ok!(rsp0.poll());
        assert_eq!(svc.load(), Cost(400_000_000.0));

        // rsp1 also completes after 200ms. Nothing is pending: 200ms * 1.
        time::advance(Duration::from_millis(100)).await;
        let () = assert_ready_ok!(rsp1.poll());
        assert_eq!(svc.load(), Cost(200_000_000.0));

        // With no further traffic, the estimate decays toward zero.
        time::advance(Duration::from_secs(1)).await;
        assert!(svc.load() < Cost(100_000_000.0));

        time::advance(Duration::from_secs(10)).await;
        assert!(svc.load() < Cost(100_000.0));
    }

    // Duration-to-nanoseconds conversion, including the saturating case.
    #[test]
    fn nanos() {
        assert_eq!(super::nanos(Duration::new(0, 0)), 0.0);
        assert_eq!(super::nanos(Duration::new(0, 123)), 123.0);
        assert_eq!(super::nanos(Duration::new(1, 23)), 1_000_000_023.0);
        assert_eq!(
            super::nanos(Duration::new(::std::u64::MAX, 999_999_999)),
            18446744074709553000.0
        );
    }
}