// tower/load/peak_ewma.rs

//! A `Load` implementation that measures load using the Peak-EWMA response latency.
2
3#[cfg(feature = "discover")]
4use crate::discover::{Change, Discover};
5#[cfg(feature = "discover")]
6use futures_core::{ready, Stream};
7#[cfg(feature = "discover")]
8use pin_project_lite::pin_project;
9#[cfg(feature = "discover")]
10use std::pin::Pin;
11
12use super::completion::{CompleteOnResponse, TrackCompletion, TrackCompletionFuture};
13use super::Load;
14use std::task::{Context, Poll};
15use std::{
16    sync::{Arc, Mutex},
17    time::Duration,
18};
19use tokio::time::Instant;
20use tower_service::Service;
21use tracing::trace;
22
23/// Measures the load of the underlying service using Peak-EWMA load measurement.
24///
25/// [`PeakEwma`] implements [`Load`] with the [`Cost`] metric that estimates the amount of
26/// pending work to an endpoint. Work is calculated by multiplying the
27/// exponentially-weighted moving average (EWMA) of response latencies by the number of
28/// pending requests. The Peak-EWMA algorithm is designed to be especially sensitive to
29/// worst-case latencies. Over time, the peak latency value decays towards the moving
30/// average of latencies to the endpoint.
31///
32/// When no latency information has been measured for an endpoint, an arbitrary default
33/// RTT of 1 second is used to prevent the endpoint from being overloaded before a
34/// meaningful baseline can be established..
35///
36/// ## Note
37///
38/// This is derived from [Finagle][finagle], which is distributed under the Apache V2
39/// license. Copyright 2017, Twitter Inc.
40///
41/// [finagle]:
42/// https://github.com/twitter/finagle/blob/9cc08d15216497bb03a1cafda96b7266cfbbcff1/finagle-core/src/main/scala/com/twitter/finagle/loadbalancer/PeakEwma.scala
43#[derive(Debug)]
44pub struct PeakEwma<S, C = CompleteOnResponse> {
45    service: S,
46    decay_ns: f64,
47    rtt_estimate: Arc<Mutex<RttEstimate>>,
48    completion: C,
49}
50
#[cfg(feature = "discover")]
pin_project! {
    /// Wraps a `D`-typed stream of discovered services with `PeakEwma`.
    ///
    /// Each service inserted by the underlying [`Discover`] is wrapped in a [`PeakEwma`]
    /// built from this value's `default_rtt`, `decay_ns`, and a clone of `completion`;
    /// removals are passed through unchanged.
    #[cfg_attr(docsrs, doc(cfg(feature = "discover")))]
    #[derive(Debug)]
    pub struct PeakEwmaDiscover<D, C = CompleteOnResponse> {
        // The underlying discovery stream; pinned because it is polled through `Pin`.
        #[pin]
        discover: D,
        // Decay period, in nanoseconds, handed to each new `PeakEwma`.
        decay_ns: f64,
        // Initial RTT estimate for each newly discovered service.
        default_rtt: Duration,
        // Completion policy cloned into each new `PeakEwma`.
        completion: C,
    }
}
64
/// The relative cost of communicating with a service.
///
/// The wrapped value estimates the pending work on a service. It is the Peak-EWMA
/// latency estimate, in nanoseconds, multiplied by the number of pending requests
/// plus one. Costs are ordered through `PartialOrd`.
#[derive(Clone, Copy, Debug)]
#[derive(PartialEq, PartialOrd)]
pub struct Cost(f64);
71
72/// Tracks an in-flight request and updates the RTT-estimate on Drop.
73#[derive(Debug)]
74pub struct Handle {
75    sent_at: Instant,
76    decay_ns: f64,
77    rtt_estimate: Arc<Mutex<RttEstimate>>,
78}
79
/// The current Peak-EWMA RTT estimate together with the time it was last updated.
#[derive(Debug)]
struct RttEstimate {
    // Instant of the most recent update; the decay applied next is based on it.
    update_at: Instant,
    // Current round-trip-time estimate, in nanoseconds.
    rtt_ns: f64,
}

/// Nanoseconds per millisecond, used to render nanosecond estimates as milliseconds.
const NANOS_PER_MILLI: f64 = 1e6;
88
89// ===== impl PeakEwma =====
90
91impl<S, C> PeakEwma<S, C> {
92    /// Wraps an `S`-typed service so that its load is tracked by the EWMA of its peak latency.
93    pub fn new(service: S, default_rtt: Duration, decay_ns: f64, completion: C) -> Self {
94        debug_assert!(decay_ns > 0.0, "decay_ns must be positive");
95        Self {
96            service,
97            decay_ns,
98            rtt_estimate: Arc::new(Mutex::new(RttEstimate::new(nanos(default_rtt)))),
99            completion,
100        }
101    }
102
103    fn handle(&self) -> Handle {
104        Handle {
105            decay_ns: self.decay_ns,
106            sent_at: Instant::now(),
107            rtt_estimate: self.rtt_estimate.clone(),
108        }
109    }
110}
111
112impl<S, C, Request> Service<Request> for PeakEwma<S, C>
113where
114    S: Service<Request>,
115    C: TrackCompletion<Handle, S::Response>,
116{
117    type Response = C::Output;
118    type Error = S::Error;
119    type Future = TrackCompletionFuture<S::Future, C, Handle>;
120
121    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
122        self.service.poll_ready(cx)
123    }
124
125    fn call(&mut self, req: Request) -> Self::Future {
126        TrackCompletionFuture::new(
127            self.completion.clone(),
128            self.handle(),
129            self.service.call(req),
130        )
131    }
132}
133
134impl<S, C> Load for PeakEwma<S, C> {
135    type Metric = Cost;
136
137    fn load(&self) -> Self::Metric {
138        let pending = Arc::strong_count(&self.rtt_estimate) as u32 - 1;
139
140        // Update the RTT estimate to account for decay since the last update.
141        // If an estimate has not been established, a default is provided
142        let estimate = self.update_estimate();
143
144        let cost = Cost(estimate * f64::from(pending + 1));
145        trace!(
146            "load estimate={:.0}ms pending={} cost={:?}",
147            estimate / NANOS_PER_MILLI,
148            pending,
149            cost,
150        );
151        cost
152    }
153}
154
155impl<S, C> PeakEwma<S, C> {
156    fn update_estimate(&self) -> f64 {
157        let mut rtt = self.rtt_estimate.lock().expect("peak ewma prior_estimate");
158        rtt.decay(self.decay_ns)
159    }
160}
161
162// ===== impl PeakEwmaDiscover =====
163
#[cfg(feature = "discover")]
impl<D, C> PeakEwmaDiscover<D, C> {
    /// Wraps a `D`-typed [`Discover`] so that services have a [`PeakEwma`] load metric.
    ///
    /// The provided `default_rtt` is used as the default RTT estimate for newly
    /// added services.
    ///
    /// The `decay` value determines over what time period an RTT estimate should
    /// decay.
    ///
    /// The `completion` policy is cloned into each newly discovered service.
    pub fn new<Request>(discover: D, default_rtt: Duration, decay: Duration, completion: C) -> Self
    where
        D: Discover,
        D::Service: Service<Request>,
        C: TrackCompletion<Handle, <D::Service as Service<Request>>::Response>,
    {
        PeakEwmaDiscover {
            discover,
            // Stored in nanoseconds, the unit `PeakEwma::new` expects.
            decay_ns: nanos(decay),
            default_rtt,
            completion,
        }
    }
}
187
#[cfg(feature = "discover")]
#[cfg_attr(docsrs, doc(cfg(feature = "discover")))]
impl<D, C> Stream for PeakEwmaDiscover<D, C>
where
    D: Discover,
    C: Clone,
{
    type Item = Result<Change<D::Key, PeakEwma<D::Service, C>>, D::Error>;

    /// Polls the underlying [`Discover`] and wraps each inserted service in a
    /// [`PeakEwma`]. Removals and discovery errors are forwarded unchanged.
    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        let projected = self.project();
        let discovered = match ready!(projected.discover.poll_discover(cx)) {
            // The underlying stream is exhausted.
            None => return Poll::Ready(None),
            // `?` forwards a discovery error as `Poll::Ready(Some(Err(_)))`.
            Some(result) => result?,
        };

        let wrapped = match discovered {
            Change::Insert(key, service) => {
                let tracked = PeakEwma::new(
                    service,
                    *projected.default_rtt,
                    *projected.decay_ns,
                    projected.completion.clone(),
                );
                Change::Insert(key, tracked)
            }
            Change::Remove(key) => Change::Remove(key),
        };

        Poll::Ready(Some(Ok(wrapped)))
    }
}
216
217// ===== impl RttEstimate =====
218
219impl RttEstimate {
220    fn new(rtt_ns: f64) -> Self {
221        debug_assert!(0.0 < rtt_ns, "rtt must be positive");
222        Self {
223            rtt_ns,
224            update_at: Instant::now(),
225        }
226    }
227
228    /// Decays the RTT estimate with a decay period of `decay_ns`.
229    fn decay(&mut self, decay_ns: f64) -> f64 {
230        // Updates with a 0 duration so that the estimate decays towards 0.
231        let now = Instant::now();
232        self.update(now, now, decay_ns)
233    }
234
235    /// Updates the Peak-EWMA RTT estimate.
236    ///
237    /// The elapsed time from `sent_at` to `recv_at` is added
238    fn update(&mut self, sent_at: Instant, recv_at: Instant, decay_ns: f64) -> f64 {
239        debug_assert!(
240            sent_at <= recv_at,
241            "recv_at={:?} after sent_at={:?}",
242            recv_at,
243            sent_at
244        );
245        let rtt = nanos(recv_at.saturating_duration_since(sent_at));
246
247        let now = Instant::now();
248        debug_assert!(
249            self.update_at <= now,
250            "update_at={:?} in the future",
251            self.update_at
252        );
253
254        self.rtt_ns = if self.rtt_ns < rtt {
255            // For Peak-EWMA, always use the worst-case (peak) value as the estimate for
256            // subsequent requests.
257            trace!(
258                "update peak rtt={}ms prior={}ms",
259                rtt / NANOS_PER_MILLI,
260                self.rtt_ns / NANOS_PER_MILLI,
261            );
262            rtt
263        } else {
264            // When an RTT is observed that is less than the estimated RTT, we decay the
265            // prior estimate according to how much time has elapsed since the last
266            // update. The inverse of the decay is used to scale the estimate towards the
267            // observed RTT value.
268            let elapsed = nanos(now.saturating_duration_since(self.update_at));
269            let decay = (-elapsed / decay_ns).exp();
270            let recency = 1.0 - decay;
271            let next_estimate = (self.rtt_ns * decay) + (rtt * recency);
272            trace!(
273                "update rtt={:03.0}ms decay={:06.0}ns; next={:03.0}ms",
274                rtt / NANOS_PER_MILLI,
275                self.rtt_ns - next_estimate,
276                next_estimate / NANOS_PER_MILLI,
277            );
278            next_estimate
279        };
280        self.update_at = now;
281
282        self.rtt_ns
283    }
284}
285
286// ===== impl Handle =====
287
288impl Drop for Handle {
289    fn drop(&mut self) {
290        let recv_at = Instant::now();
291
292        if let Ok(mut rtt) = self.rtt_estimate.lock() {
293            rtt.update(self.sent_at, recv_at, self.decay_ns);
294        }
295    }
296}
297
// ===== utilities =====

/// Converts a [`Duration`] into a nanosecond count expressed as an `f64`.
///
/// The whole-second part is multiplied out with saturating arithmetic. Durations longer
/// than `u64::MAX` nanoseconds (about 585 years) therefore clamp there before the
/// sub-second nanoseconds are added. This is far beyond any realistic request latency.
fn nanos(d: Duration) -> f64 {
    const NANOS_PER_SEC: u64 = 1_000_000_000;
    let whole_secs = d.as_secs().saturating_mul(NANOS_PER_SEC) as f64;
    whole_secs + f64::from(d.subsec_nanos())
}
310
#[cfg(test)]
mod tests {
    use futures_util::future;
    use std::time::Duration;
    use tokio::time;
    use tokio_test::{assert_ready, assert_ready_ok, task};

    use super::*;

    /// A trivial service whose calls resolve immediately with `Ok(())`.
    struct Svc;
    impl Service<()> for Svc {
        type Response = ();
        type Error = ();
        type Future = future::Ready<Result<(), ()>>;

        fn poll_ready(&mut self, _: &mut Context<'_>) -> Poll<Result<(), ()>> {
            Poll::Ready(Ok(()))
        }

        fn call(&mut self, (): ()) -> Self::Future {
            future::ok(())
        }
    }

    /// The default RTT estimate decays, so that new nodes are considered if the
    /// default RTT is too high.
    #[tokio::test]
    async fn default_decay() {
        // Pause tokio's clock so time only moves via `time::advance`.
        time::pause();

        let svc = PeakEwma::new(
            Svc,
            Duration::from_millis(10),
            NANOS_PER_MILLI * 1_000.0,
            CompleteOnResponse,
        );
        // With no elapsed time and nothing pending, the cost is the default RTT.
        let Cost(load) = svc.load();
        assert_eq!(load, 10.0 * NANOS_PER_MILLI);

        time::advance(Duration::from_millis(100)).await;
        let Cost(load) = svc.load();
        assert!(9.0 * NANOS_PER_MILLI < load && load < 10.0 * NANOS_PER_MILLI);

        time::advance(Duration::from_millis(100)).await;
        let Cost(load) = svc.load();
        assert!(8.0 * NANOS_PER_MILLI < load && load < 9.0 * NANOS_PER_MILLI);
    }

    // Pending requests multiply the cost. A completed request whose RTT exceeds the
    // estimate replaces it outright (the "peak"). With no traffic, the estimate then
    // decays towards zero.
    #[tokio::test]
    async fn compound_decay() {
        time::pause();

        let mut svc = PeakEwma::new(
            Svc,
            Duration::from_millis(20),
            NANOS_PER_MILLI * 1_000.0,
            CompleteOnResponse,
        );
        assert_eq!(svc.load(), Cost(20.0 * NANOS_PER_MILLI));

        // One pending request roughly doubles the (slightly decayed) estimate.
        time::advance(Duration::from_millis(100)).await;
        let mut rsp0 = task::spawn(svc.call(()));
        assert!(svc.load() > Cost(20.0 * NANOS_PER_MILLI));

        time::advance(Duration::from_millis(100)).await;
        let mut rsp1 = task::spawn(svc.call(()));
        assert!(svc.load() > Cost(40.0 * NANOS_PER_MILLI));

        // rsp0 was sent 200ms ago. The expected cost is a 200ms estimate with rsp1
        // still pending (2 * 200ms).
        time::advance(Duration::from_millis(100)).await;
        let () = assert_ready_ok!(rsp0.poll());
        assert_eq!(svc.load(), Cost(400_000_000.0));

        // rsp1 also took 200ms. With nothing pending, the cost equals the estimate.
        time::advance(Duration::from_millis(100)).await;
        let () = assert_ready_ok!(rsp1.poll());
        assert_eq!(svc.load(), Cost(200_000_000.0));

        // Check that values decay as time elapses
        time::advance(Duration::from_secs(1)).await;
        assert!(svc.load() < Cost(100_000_000.0));

        time::advance(Duration::from_secs(10)).await;
        assert!(svc.load() < Cost(100_000.0));
    }

    #[test]
    fn nanos() {
        assert_eq!(super::nanos(Duration::new(0, 0)), 0.0);
        assert_eq!(super::nanos(Duration::new(0, 123)), 123.0);
        assert_eq!(super::nanos(Duration::new(1, 23)), 1_000_000_023.0);
        // The seconds part saturates at `u64::MAX` nanoseconds before the
        // sub-second part is added.
        assert_eq!(
            super::nanos(Duration::new(::std::u64::MAX, 999_999_999)),
            18446744074709553000.0
        );
    }
}