1#[cfg(feature = "discover")]
4use crate::discover::{Change, Discover};
5#[cfg(feature = "discover")]
6use futures_core::{ready, Stream};
7#[cfg(feature = "discover")]
8use pin_project_lite::pin_project;
9#[cfg(feature = "discover")]
10use std::pin::Pin;
11
12use super::completion::{CompleteOnResponse, TrackCompletion, TrackCompletionFuture};
13use super::Load;
14use std::task::{Context, Poll};
15use std::{
16 sync::{Arc, Mutex},
17 time::Duration,
18};
19use tokio::time::Instant;
20use tower_service::Service;
21use tracing::trace;
22
/// Wraps a service and tracks a peak-sensitive, exponentially decaying estimate of its
/// round-trip time (RTT), which is exposed through [`Load`] as a [`Cost`].
#[derive(Debug)]
pub struct PeakEwma<S, C = CompleteOnResponse> {
    // The wrapped service; readiness checks and calls are forwarded to it.
    service: S,
    // Decay time constant in nanoseconds: the prior estimate keeps a weight of
    // `exp(-elapsed / decay_ns)` at each update.
    decay_ns: f64,
    // Shared with every outstanding `Handle`; the extra strong references double as the
    // count of in-flight requests.
    rtt_estimate: Arc<Mutex<RttEstimate>>,
    // Completion tracker cloned into every response future; presumably decides when the
    // request's `Handle` is released (see `TrackCompletion`) — confirm there.
    completion: C,
}
50
#[cfg(feature = "discover")]
pin_project! {
    /// Wraps a `D`-typed stream of discovered services, wrapping each inserted service in a
    /// [`PeakEwma`] load tracker.
    #[cfg_attr(docsrs, doc(cfg(feature = "discover")))]
    #[derive(Debug)]
    pub struct PeakEwmaDiscover<D, C = CompleteOnResponse> {
        // The underlying discovery stream; structurally pinned so it can be polled.
        #[pin]
        discover: D,
        // Decay time constant (ns) given to every `PeakEwma` this stream creates.
        decay_ns: f64,
        // Initial RTT estimate for newly inserted services.
        default_rtt: Duration,
        // Completion tracker cloned into every `PeakEwma` this stream creates.
        completion: C,
    }
}
64
/// The relative cost of sending a request to a service, as reported by [`Load`].
///
/// Computed by `PeakEwma::load` as the decayed RTT estimate (in nanoseconds) multiplied by
/// the number of in-flight requests plus one.
#[derive(Copy, Clone, Debug, PartialEq, PartialOrd)]
pub struct Cost(f64);
71
/// Per-request guard created when a request is dispatched.
///
/// While alive it holds a strong reference to the service's RTT estimate, so it is counted
/// as pending; when dropped it records the time elapsed since `sent_at` as an RTT sample.
#[derive(Debug)]
pub struct Handle {
    // When the request was dispatched.
    sent_at: Instant,
    // Decay time constant (ns) copied from the owning `PeakEwma`.
    decay_ns: f64,
    // Shared estimate updated from this handle's `Drop` impl.
    rtt_estimate: Arc<Mutex<RttEstimate>>,
}
79
/// Mutable RTT estimate shared between a `PeakEwma` and its outstanding `Handle`s.
#[derive(Debug)]
struct RttEstimate {
    // Instant of the most recent update; decay is applied relative to it.
    update_at: Instant,
    // Current estimate, in nanoseconds.
    rtt_ns: f64,
}
86
/// Nanoseconds per millisecond; used to render nanosecond estimates as milliseconds.
const NANOS_PER_MILLI: f64 = 1e6;
88
89impl<S, C> PeakEwma<S, C> {
92 pub fn new(service: S, default_rtt: Duration, decay_ns: f64, completion: C) -> Self {
94 debug_assert!(decay_ns > 0.0, "decay_ns must be positive");
95 Self {
96 service,
97 decay_ns,
98 rtt_estimate: Arc::new(Mutex::new(RttEstimate::new(nanos(default_rtt)))),
99 completion,
100 }
101 }
102
103 fn handle(&self) -> Handle {
104 Handle {
105 decay_ns: self.decay_ns,
106 sent_at: Instant::now(),
107 rtt_estimate: self.rtt_estimate.clone(),
108 }
109 }
110}
111
112impl<S, C, Request> Service<Request> for PeakEwma<S, C>
113where
114 S: Service<Request>,
115 C: TrackCompletion<Handle, S::Response>,
116{
117 type Response = C::Output;
118 type Error = S::Error;
119 type Future = TrackCompletionFuture<S::Future, C, Handle>;
120
121 fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
122 self.service.poll_ready(cx)
123 }
124
125 fn call(&mut self, req: Request) -> Self::Future {
126 TrackCompletionFuture::new(
127 self.completion.clone(),
128 self.handle(),
129 self.service.call(req),
130 )
131 }
132}
133
134impl<S, C> Load for PeakEwma<S, C> {
135 type Metric = Cost;
136
137 fn load(&self) -> Self::Metric {
138 let pending = Arc::strong_count(&self.rtt_estimate) as u32 - 1;
139
140 let estimate = self.update_estimate();
143
144 let cost = Cost(estimate * f64::from(pending + 1));
145 trace!(
146 "load estimate={:.0}ms pending={} cost={:?}",
147 estimate / NANOS_PER_MILLI,
148 pending,
149 cost,
150 );
151 cost
152 }
153}
154
155impl<S, C> PeakEwma<S, C> {
156 fn update_estimate(&self) -> f64 {
157 let mut rtt = self.rtt_estimate.lock().expect("peak ewma prior_estimate");
158 rtt.decay(self.decay_ns)
159 }
160}
161
#[cfg(feature = "discover")]
impl<D, C> PeakEwmaDiscover<D, C> {
    /// Wraps `discover` so that every inserted service is wrapped in a [`PeakEwma`].
    ///
    /// Each wrapped service starts with `default_rtt` as its RTT estimate and decays with
    /// time constant `decay`. `completion` is cloned into each wrapper. `decay` is not
    /// validated here. A zero value is only caught, in debug builds, when a service is
    /// inserted and `PeakEwma::new` runs.
    pub fn new<Request>(discover: D, default_rtt: Duration, decay: Duration, completion: C) -> Self
    where
        D: Discover,
        D::Service: Service<Request>,
        C: TrackCompletion<Handle, <D::Service as Service<Request>>::Response>,
    {
        let decay_ns = nanos(decay);
        Self {
            discover,
            decay_ns,
            default_rtt,
            completion,
        }
    }
}
187
#[cfg(feature = "discover")]
#[cfg_attr(docsrs, doc(cfg(feature = "discover")))]
impl<D, C> Stream for PeakEwmaDiscover<D, C>
where
    D: Discover,
    C: Clone,
{
    type Item = Result<Change<D::Key, PeakEwma<D::Service, C>>, D::Error>;

    /// Polls the inner discovery stream, wrapping every inserted service in a [`PeakEwma`].
    ///
    /// Removals are forwarded unchanged, errors are propagated, and the end of the inner
    /// stream is forwarded as `None`.
    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        let this = self.project();
        let next = ready!(this.discover.poll_discover(cx));
        let change = match next.transpose()? {
            Some(Change::Insert(key, service)) => {
                let tracked = PeakEwma::new(
                    service,
                    *this.default_rtt,
                    *this.decay_ns,
                    this.completion.clone(),
                );
                Change::Insert(key, tracked)
            }
            Some(Change::Remove(key)) => Change::Remove(key),
            None => return Poll::Ready(None),
        };
        Poll::Ready(Some(Ok(change)))
    }
}
216
217impl RttEstimate {
220 fn new(rtt_ns: f64) -> Self {
221 debug_assert!(0.0 < rtt_ns, "rtt must be positive");
222 Self {
223 rtt_ns,
224 update_at: Instant::now(),
225 }
226 }
227
228 fn decay(&mut self, decay_ns: f64) -> f64 {
230 let now = Instant::now();
232 self.update(now, now, decay_ns)
233 }
234
235 fn update(&mut self, sent_at: Instant, recv_at: Instant, decay_ns: f64) -> f64 {
239 debug_assert!(
240 sent_at <= recv_at,
241 "recv_at={:?} after sent_at={:?}",
242 recv_at,
243 sent_at
244 );
245 let rtt = nanos(recv_at.saturating_duration_since(sent_at));
246
247 let now = Instant::now();
248 debug_assert!(
249 self.update_at <= now,
250 "update_at={:?} in the future",
251 self.update_at
252 );
253
254 self.rtt_ns = if self.rtt_ns < rtt {
255 trace!(
258 "update peak rtt={}ms prior={}ms",
259 rtt / NANOS_PER_MILLI,
260 self.rtt_ns / NANOS_PER_MILLI,
261 );
262 rtt
263 } else {
264 let elapsed = nanos(now.saturating_duration_since(self.update_at));
269 let decay = (-elapsed / decay_ns).exp();
270 let recency = 1.0 - decay;
271 let next_estimate = (self.rtt_ns * decay) + (rtt * recency);
272 trace!(
273 "update rtt={:03.0}ms decay={:06.0}ns; next={:03.0}ms",
274 rtt / NANOS_PER_MILLI,
275 self.rtt_ns - next_estimate,
276 next_estimate / NANOS_PER_MILLI,
277 );
278 next_estimate
279 };
280 self.update_at = now;
281
282 self.rtt_ns
283 }
284}
285
286impl Drop for Handle {
289 fn drop(&mut self) {
290 let recv_at = Instant::now();
291
292 if let Ok(mut rtt) = self.rtt_estimate.lock() {
293 rtt.update(self.sent_at, recv_at, self.decay_ns);
294 }
295 }
296}
297
/// Converts a `Duration` to a number of nanoseconds, as an `f64`.
///
/// The whole-seconds part saturates at `u64::MAX` nanoseconds, so very long durations are
/// clamped. The sub-second part is still added afterwards.
fn nanos(d: Duration) -> f64 {
    const NANOS_PER_SEC: u64 = 1_000_000_000;
    let secs_as_nanos = d.as_secs().checked_mul(NANOS_PER_SEC).unwrap_or(u64::MAX);
    f64::from(d.subsec_nanos()) + secs_as_nanos as f64
}
310
#[cfg(test)]
mod tests {
    use futures_util::future;
    use std::time::Duration;
    use tokio::time;
    use tokio_test::{assert_ready, assert_ready_ok, task};

    use super::*;

    // Minimal service whose every call resolves immediately to `Ok(())`.
    struct Svc;
    impl Service<()> for Svc {
        type Response = ();
        type Error = ();
        type Future = future::Ready<Result<(), ()>>;

        fn poll_ready(&mut self, _: &mut Context<'_>) -> Poll<Result<(), ()>> {
            Poll::Ready(Ok(()))
        }

        fn call(&mut self, (): ()) -> Self::Future {
            future::ok(())
        }
    }

    /// With no requests in flight, the default RTT estimate decays as time passes.
    #[tokio::test]
    async fn default_decay() {
        // Pause tokio's clock so `time::advance` controls elapsed time deterministically.
        time::pause();

        // 10ms default RTT, 1s decay time constant.
        let svc = PeakEwma::new(
            Svc,
            Duration::from_millis(10),
            NANOS_PER_MILLI * 1_000.0,
            CompleteOnResponse,
        );
        let Cost(load) = svc.load();
        assert_eq!(load, 10.0 * NANOS_PER_MILLI);

        // 10ms * exp(-0.1) ~= 9.05ms
        time::advance(Duration::from_millis(100)).await;
        let Cost(load) = svc.load();
        assert!(9.0 * NANOS_PER_MILLI < load && load < 10.0 * NANOS_PER_MILLI);

        // ~9.05ms * exp(-0.1) ~= 8.19ms
        time::advance(Duration::from_millis(100)).await;
        let Cost(load) = svc.load();
        assert!(8.0 * NANOS_PER_MILLI < load && load < 9.0 * NANOS_PER_MILLI);
    }

    /// In-flight requests multiply the cost, and completed slow requests raise the estimate
    /// to the observed peak. The estimate then decays back toward zero.
    #[tokio::test]
    async fn compound_decay() {
        time::pause();

        // 20ms default RTT, 1s decay time constant.
        let mut svc = PeakEwma::new(
            Svc,
            Duration::from_millis(20),
            NANOS_PER_MILLI * 1_000.0,
            CompleteOnResponse,
        );
        assert_eq!(svc.load(), Cost(20.0 * NANOS_PER_MILLI));

        // t=100ms: one request in flight, so cost ~= 18.1ms * 2.
        time::advance(Duration::from_millis(100)).await;
        let mut rsp0 = task::spawn(svc.call(()));
        assert!(svc.load() > Cost(20.0 * NANOS_PER_MILLI));

        // t=200ms: two requests in flight, so cost ~= 16.4ms * 3.
        time::advance(Duration::from_millis(100)).await;
        let mut rsp1 = task::spawn(svc.call(()));
        assert!(svc.load() > Cost(40.0 * NANOS_PER_MILLI));

        // t=300ms: rsp0 resolves after a 200ms round trip. That exceeds the decayed estimate
        // and becomes the new peak. One request is still pending: 200ms * 2.
        time::advance(Duration::from_millis(100)).await;
        let () = assert_ready_ok!(rsp0.poll());
        assert_eq!(svc.load(), Cost(400_000_000.0));

        // t=400ms: rsp1 resolves after 200ms, matching the estimate. Nothing is pending: 200ms * 1.
        time::advance(Duration::from_millis(100)).await;
        let () = assert_ready_ok!(rsp1.poll());
        assert_eq!(svc.load(), Cost(200_000_000.0));

        // Check that values decay as time elapses: 200ms * exp(-1) ~= 73.6ms.
        time::advance(Duration::from_secs(1)).await;
        assert!(svc.load() < Cost(100_000_000.0));

        // ~73.6ms * exp(-10) ~= 3.3us.
        time::advance(Duration::from_secs(10)).await;
        assert!(svc.load() < Cost(100_000.0));
    }

    #[test]
    fn nanos() {
        assert_eq!(super::nanos(Duration::new(0, 0)), 0.0);
        assert_eq!(super::nanos(Duration::new(0, 123)), 123.0);
        assert_eq!(super::nanos(Duration::new(1, 23)), 1_000_000_023.0);
        // The whole-seconds part saturates at u64::MAX nanoseconds before the sub-second
        // part is added.
        assert_eq!(
            super::nanos(Duration::new(::std::u64::MAX, 999_999_999)),
            18446744074709553000.0
        );
    }
}