tower/load/
peak_ewma.rs

1//! A `Load` implementation that measures load using the PeakEWMA response latency.
2
3#[cfg(feature = "discover")]
4use crate::discover::{Change, Discover};
5#[cfg(feature = "discover")]
6use futures_core::{ready, Stream};
7#[cfg(feature = "discover")]
8use pin_project_lite::pin_project;
9#[cfg(feature = "discover")]
10use std::pin::Pin;
11
12use super::completion::{CompleteOnResponse, TrackCompletion, TrackCompletionFuture};
13use super::Load;
14use std::task::{Context, Poll};
15use std::{
16    sync::{Arc, Mutex},
17    time::Duration,
18};
19use tokio::time::Instant;
20use tower_service::Service;
21use tracing::trace;
22
23/// Measures the load of the underlying service using Peak-EWMA load measurement.
24///
25/// [`PeakEwma`] implements [`Load`] with the [`Cost`] metric that estimates the amount of
26/// pending work to an endpoint. Work is calculated by multiplying the
27/// exponentially-weighted moving average (EWMA) of response latencies by the number of
28/// pending requests. The Peak-EWMA algorithm is designed to be especially sensitive to
29/// worst-case latencies. Over time, the peak latency value decays towards the moving
30/// average of latencies to the endpoint.
31///
32/// When no latency information has been measured for an endpoint, an arbitrary default
33/// RTT of 1 second is used to prevent the endpoint from being overloaded before a
34/// meaningful baseline can be established..
35///
36/// ## Note
37///
38/// This is derived from [Finagle][finagle], which is distributed under the Apache V2
39/// license. Copyright 2017, Twitter Inc.
40///
41/// [finagle]:
42/// https://github.com/twitter/finagle/blob/9cc08d15216497bb03a1cafda96b7266cfbbcff1/finagle-core/src/main/scala/com/twitter/finagle/loadbalancer/PeakEwma.scala
43#[derive(Debug)]
44pub struct PeakEwma<S, C = CompleteOnResponse> {
45    service: S,
46    decay_ns: f64,
47    rtt_estimate: Arc<Mutex<RttEstimate>>,
48    completion: C,
49}
50
#[cfg(feature = "discover")]
pin_project! {
    /// Wraps a `D`-typed stream of discovered services with `PeakEwma`.
    ///
    /// Every service inserted by the underlying discover stream is wrapped in a
    /// [`PeakEwma`] configured with this struct's default RTT, decay period and
    /// completion strategy; removals are passed through unchanged.
    #[cfg_attr(docsrs, doc(cfg(feature = "discover")))]
    #[derive(Debug)]
    pub struct PeakEwmaDiscover<D, C = CompleteOnResponse> {
        // The underlying discovery stream; pinned because it is polled in place.
        #[pin]
        discover: D,
        // Decay period, in nanoseconds, given to each wrapped service.
        decay_ns: f64,
        // Initial RTT estimate for each newly inserted service.
        default_rtt: Duration,
        // Completion strategy cloned into each wrapped service.
        completion: C,
    }
}
64
/// Represents the relative cost of communicating with a service.
///
/// The underlying value estimates the amount of pending work to a service: the Peak-EWMA
/// latency estimate (in nanoseconds) multiplied by the number of pending requests plus
/// one. Costs are compared through the derived `PartialOrd`.
#[derive(Copy, Clone, Debug, PartialEq, PartialOrd)]
pub struct Cost(f64);
71
/// Tracks an in-flight request and updates the RTT-estimate on Drop.
///
/// While alive, a `Handle` holds a strong reference to the shared estimate; this is how
/// [`PeakEwma`] counts pending requests.
#[derive(Debug)]
pub struct Handle {
    /// When the handle was created (just before the request was forwarded); the RTT is
    /// measured from here to the moment the handle is dropped.
    sent_at: Instant,
    /// Decay period, in nanoseconds, copied from the owning [`PeakEwma`].
    decay_ns: f64,
    /// Shared estimate that is updated when this handle is dropped.
    rtt_estimate: Arc<Mutex<RttEstimate>>,
}
79
/// Holds the current RTT estimate and the last time this value was updated.
#[derive(Debug)]
struct RttEstimate {
    /// Time of the last update (or of construction); decay is computed relative to this.
    update_at: Instant,
    /// Current Peak-EWMA RTT estimate, in nanoseconds.
    rtt_ns: f64,
}
86
/// Nanoseconds per millisecond; used to render nanosecond estimates as milliseconds.
const NANOS_PER_MILLI: f64 = 1e6;
88
89// ===== impl PeakEwma =====
90
91impl<S, C> PeakEwma<S, C> {
92    /// Wraps an `S`-typed service so that its load is tracked by the EWMA of its peak latency.
93    pub fn new(service: S, default_rtt: Duration, decay_ns: f64, completion: C) -> Self {
94        debug_assert!(decay_ns > 0.0, "decay_ns must be positive");
95        Self {
96            service,
97            decay_ns,
98            rtt_estimate: Arc::new(Mutex::new(RttEstimate::new(nanos(default_rtt)))),
99            completion,
100        }
101    }
102
103    fn handle(&self) -> Handle {
104        Handle {
105            decay_ns: self.decay_ns,
106            sent_at: Instant::now(),
107            rtt_estimate: self.rtt_estimate.clone(),
108        }
109    }
110}
111
112impl<S, C, Request> Service<Request> for PeakEwma<S, C>
113where
114    S: Service<Request>,
115    C: TrackCompletion<Handle, S::Response>,
116{
117    type Response = C::Output;
118    type Error = S::Error;
119    type Future = TrackCompletionFuture<S::Future, C, Handle>;
120
121    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
122        self.service.poll_ready(cx)
123    }
124
125    fn call(&mut self, req: Request) -> Self::Future {
126        TrackCompletionFuture::new(
127            self.completion.clone(),
128            self.handle(),
129            self.service.call(req),
130        )
131    }
132}
133
134impl<S, C> Load for PeakEwma<S, C> {
135    type Metric = Cost;
136
137    fn load(&self) -> Self::Metric {
138        let pending = Arc::strong_count(&self.rtt_estimate) as u32 - 1;
139
140        // Update the RTT estimate to account for decay since the last update.
141        // If an estimate has not been established, a default is provided
142        let estimate = self.update_estimate();
143
144        let cost = Cost(estimate * f64::from(pending + 1));
145        trace!(
146            "load estimate={:.0}ms pending={} cost={:?}",
147            estimate / NANOS_PER_MILLI,
148            pending,
149            cost,
150        );
151        cost
152    }
153}
154
155impl<S, C> PeakEwma<S, C> {
156    fn update_estimate(&self) -> f64 {
157        let mut rtt = self.rtt_estimate.lock().expect("peak ewma prior_estimate");
158        rtt.decay(self.decay_ns)
159    }
160}
161
162// ===== impl PeakEwmaDiscover =====
163
#[cfg(feature = "discover")]
impl<D, C> PeakEwmaDiscover<D, C> {
    /// Wraps a `D`-typed [`Discover`] so that services have a [`PeakEwma`] load metric.
    ///
    /// The provided `default_rtt` is used as the initial RTT estimate for newly
    /// added services, until real latencies have been observed.
    ///
    /// The `decay` value determines over what time period an RTT estimate should
    /// decay.
    ///
    /// The `Request` type parameter is not stored; it only ensures that `completion` is
    /// compatible with the responses of the discovered services.
    pub fn new<Request>(discover: D, default_rtt: Duration, decay: Duration, completion: C) -> Self
    where
        D: Discover,
        D::Service: Service<Request>,
        C: TrackCompletion<Handle, <D::Service as Service<Request>>::Response>,
    {
        PeakEwmaDiscover {
            discover,
            decay_ns: nanos(decay),
            default_rtt,
            completion,
        }
    }
}
187
#[cfg(feature = "discover")]
impl<D, C> Stream for PeakEwmaDiscover<D, C>
where
    D: Discover,
    C: Clone,
{
    type Item = Result<Change<D::Key, PeakEwma<D::Service, C>>, D::Error>;

    /// Polls the inner discover stream, wrapping each inserted service in a [`PeakEwma`].
    ///
    /// Removals and errors are forwarded unchanged; the stream ends when the inner one does.
    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        let this = self.project();
        let polled = ready!(this.discover.poll_discover(cx));
        let change = match polled {
            None => return Poll::Ready(None),
            Some(Err(error)) => return Poll::Ready(Some(Err(error))),
            Some(Ok(Change::Remove(key))) => Change::Remove(key),
            Some(Ok(Change::Insert(key, service))) => Change::Insert(
                key,
                PeakEwma::new(
                    service,
                    *this.default_rtt,
                    *this.decay_ns,
                    this.completion.clone(),
                ),
            ),
        };
        Poll::Ready(Some(Ok(change)))
    }
}
215
216// ===== impl RttEstimate =====
217
218impl RttEstimate {
219    fn new(rtt_ns: f64) -> Self {
220        debug_assert!(0.0 < rtt_ns, "rtt must be positive");
221        Self {
222            rtt_ns,
223            update_at: Instant::now(),
224        }
225    }
226
227    /// Decays the RTT estimate with a decay period of `decay_ns`.
228    fn decay(&mut self, decay_ns: f64) -> f64 {
229        // Updates with a 0 duration so that the estimate decays towards 0.
230        let now = Instant::now();
231        self.update(now, now, decay_ns)
232    }
233
234    /// Updates the Peak-EWMA RTT estimate.
235    ///
236    /// The elapsed time from `sent_at` to `recv_at` is added
237    fn update(&mut self, sent_at: Instant, recv_at: Instant, decay_ns: f64) -> f64 {
238        debug_assert!(
239            sent_at <= recv_at,
240            "recv_at={:?} after sent_at={:?}",
241            recv_at,
242            sent_at
243        );
244        let rtt = nanos(recv_at.saturating_duration_since(sent_at));
245
246        let now = Instant::now();
247        debug_assert!(
248            self.update_at <= now,
249            "update_at={:?} in the future",
250            self.update_at
251        );
252
253        self.rtt_ns = if self.rtt_ns < rtt {
254            // For Peak-EWMA, always use the worst-case (peak) value as the estimate for
255            // subsequent requests.
256            trace!(
257                "update peak rtt={}ms prior={}ms",
258                rtt / NANOS_PER_MILLI,
259                self.rtt_ns / NANOS_PER_MILLI,
260            );
261            rtt
262        } else {
263            // When an RTT is observed that is less than the estimated RTT, we decay the
264            // prior estimate according to how much time has elapsed since the last
265            // update. The inverse of the decay is used to scale the estimate towards the
266            // observed RTT value.
267            let elapsed = nanos(now.saturating_duration_since(self.update_at));
268            let decay = (-elapsed / decay_ns).exp();
269            let recency = 1.0 - decay;
270            let next_estimate = (self.rtt_ns * decay) + (rtt * recency);
271            trace!(
272                "update rtt={:03.0}ms decay={:06.0}ns; next={:03.0}ms",
273                rtt / NANOS_PER_MILLI,
274                self.rtt_ns - next_estimate,
275                next_estimate / NANOS_PER_MILLI,
276            );
277            next_estimate
278        };
279        self.update_at = now;
280
281        self.rtt_ns
282    }
283}
284
285// ===== impl Handle =====
286
287impl Drop for Handle {
288    fn drop(&mut self) {
289        let recv_at = Instant::now();
290
291        if let Ok(mut rtt) = self.rtt_estimate.lock() {
292            rtt.update(self.sent_at, recv_at, self.decay_ns);
293        }
294    }
295}
296
// ===== helpers =====
298
/// Converts a [`Duration`] to a number of nanoseconds as an `f64`.
///
/// The whole-second part saturates at `u64::MAX` nanoseconds before conversion. Results
/// are therefore capped near that value (about 585 years), which is far beyond any
/// realistic request latency. The `u64` to `f64` conversion also loses precision above
/// 2^53 nanoseconds.
fn nanos(d: Duration) -> f64 {
    const NANOS_PER_SEC: u64 = 1_000_000_000;
    let whole_secs = d.as_secs().saturating_mul(NANOS_PER_SEC) as f64;
    let sub_sec = f64::from(d.subsec_nanos());
    whole_secs + sub_sec
}
309
#[cfg(test)]
mod tests {
    use futures_util::future;
    use std::time::Duration;
    use tokio::time;
    use tokio_test::{assert_ready, assert_ready_ok, task};

    use super::*;

    struct Svc;
    impl Service<()> for Svc {
        type Response = ();
        type Error = ();
        type Future = future::Ready<Result<(), ()>>;

        fn poll_ready(&mut self, _: &mut Context<'_>) -> Poll<Result<(), ()>> {
            Poll::Ready(Ok(()))
        }

        fn call(&mut self, (): ()) -> Self::Future {
            future::ok(())
        }
    }

    /// The default RTT estimate decays, so that new nodes are considered if the
    /// default RTT is too high.
    #[tokio::test]
    async fn default_decay() {
        time::pause();

        let svc = PeakEwma::new(
            Svc,
            Duration::from_millis(10),
            NANOS_PER_MILLI * 1_000.0,
            CompleteOnResponse,
        );
        let Cost(load) = svc.load();
        assert_eq!(load, 10.0 * NANOS_PER_MILLI);

        time::advance(Duration::from_millis(100)).await;
        let Cost(load) = svc.load();
        assert!(9.0 * NANOS_PER_MILLI < load && load < 10.0 * NANOS_PER_MILLI);

        time::advance(Duration::from_millis(100)).await;
        let Cost(load) = svc.load();
        assert!(8.0 * NANOS_PER_MILLI < load && load < 9.0 * NANOS_PER_MILLI);
    }

    /// Pending requests multiply the cost. An observed RTT above the estimate replaces it
    /// as the new peak, and the estimate decays towards zero once the endpoint is idle.
    #[tokio::test]
    async fn compound_decay() {
        time::pause();

        let mut svc = PeakEwma::new(
            Svc,
            Duration::from_millis(20),
            NANOS_PER_MILLI * 1_000.0,
            CompleteOnResponse,
        );
        assert_eq!(svc.load(), Cost(20.0 * NANOS_PER_MILLI));

        time::advance(Duration::from_millis(100)).await;
        let mut rsp0 = task::spawn(svc.call(()));
        assert!(svc.load() > Cost(20.0 * NANOS_PER_MILLI));

        time::advance(Duration::from_millis(100)).await;
        let mut rsp1 = task::spawn(svc.call(()));
        assert!(svc.load() > Cost(40.0 * NANOS_PER_MILLI));

        time::advance(Duration::from_millis(100)).await;
        let () = assert_ready_ok!(rsp0.poll());
        assert_eq!(svc.load(), Cost(400_000_000.0));

        time::advance(Duration::from_millis(100)).await;
        let () = assert_ready_ok!(rsp1.poll());
        assert_eq!(svc.load(), Cost(200_000_000.0));

        // Check that values decay as time elapses
        time::advance(Duration::from_secs(1)).await;
        assert!(svc.load() < Cost(100_000_000.0));

        time::advance(Duration::from_secs(10)).await;
        assert!(svc.load() < Cost(100_000.0));
    }

    #[test]
    fn nanos() {
        assert_eq!(super::nanos(Duration::new(0, 0)), 0.0);
        assert_eq!(super::nanos(Duration::new(0, 123)), 123.0);
        assert_eq!(super::nanos(Duration::new(1, 23)), 1_000_000_023.0);
        assert_eq!(
            super::nanos(Duration::new(u64::MAX, 999_999_999)),
            18446744074709553000.0
        );
    }
}