kube_runtime/utils/stream_backoff.rs

1use std::{future::Future, pin::Pin, task::Poll};
2
3use backoff::backoff::Backoff;
4use futures::{Stream, TryStream};
5use pin_project::pin_project;
6use tokio::time::{sleep, Instant, Sleep};
7
8/// Applies a [`Backoff`] policy to a [`Stream`]
9///
10/// After any [`Err`] is emitted, the stream is paused for [`Backoff::next_backoff`]. The
11/// [`Backoff`] is [`reset`](`Backoff::reset`) on any [`Ok`] value.
12///
13/// If [`Backoff::next_backoff`] returns [`None`] then the backing stream is given up on, and closed.
14#[pin_project]
15pub struct StreamBackoff<S, B> {
16    #[pin]
17    stream: S,
18    backoff: B,
19    #[pin]
20    state: State,
21}
22
23#[pin_project(project = StreamBackoffStateProj)]
24// It's expected to have relatively few but long-lived `StreamBackoff`s in a project, so we would rather have
25// cheaper sleeps than a smaller `StreamBackoff`.
26#[allow(clippy::large_enum_variant)]
27enum State {
28    BackingOff(#[pin] Sleep),
29    GivenUp,
30    Awake,
31}
32
33impl<S: TryStream, B: Backoff> StreamBackoff<S, B> {
34    pub fn new(stream: S, backoff: B) -> Self {
35        Self {
36            stream,
37            backoff,
38            state: State::Awake,
39        }
40    }
41}
42
43impl<S: TryStream, B: Backoff> Stream for StreamBackoff<S, B> {
44    type Item = Result<S::Ok, S::Error>;
45
46    fn poll_next(self: Pin<&mut Self>, cx: &mut std::task::Context<'_>) -> Poll<Option<Self::Item>> {
47        let mut this = self.project();
48        match this.state.as_mut().project() {
49            StreamBackoffStateProj::BackingOff(mut backoff_sleep) => match backoff_sleep.as_mut().poll(cx) {
50                Poll::Ready(()) => {
51                    tracing::debug!(deadline = ?backoff_sleep.deadline(), "Backoff complete, waking up");
52                    this.state.set(State::Awake)
53                }
54                Poll::Pending => {
55                    let deadline = backoff_sleep.deadline();
56                    tracing::trace!(
57                        ?deadline,
58                        remaining_duration = ?deadline.saturating_duration_since(Instant::now()),
59                        "Still waiting for backoff sleep to complete"
60                    );
61                    return Poll::Pending;
62                }
63            },
64            StreamBackoffStateProj::GivenUp => {
65                tracing::debug!("Backoff has given up, stream is closed");
66                return Poll::Ready(None);
67            }
68            StreamBackoffStateProj::Awake => {}
69        }
70
71        let next_item = this.stream.try_poll_next(cx);
72        match &next_item {
73            Poll::Ready(Some(Err(_))) => {
74                if let Some(backoff_duration) = this.backoff.next_backoff() {
75                    let backoff_sleep = sleep(backoff_duration);
76                    tracing::debug!(
77                        deadline = ?backoff_sleep.deadline(),
78                        duration = ?backoff_duration,
79                        "Error received, backing off"
80                    );
81                    this.state.set(State::BackingOff(backoff_sleep));
82                } else {
83                    tracing::debug!("Error received, giving up");
84                    this.state.set(State::GivenUp);
85                }
86            }
87            Poll::Ready(_) => {
88                tracing::trace!("Non-error received, resetting backoff");
89                this.backoff.reset();
90            }
91            Poll::Pending => {}
92        }
93        next_item
94    }
95}
96
#[cfg(test)]
pub(crate) mod tests {
    use std::{pin::pin, task::Poll, time::Duration};

    use super::StreamBackoff;
    use backoff::backoff::Backoff;
    use futures::{channel::mpsc, poll, stream, StreamExt};

    #[tokio::test]
    async fn stream_should_back_off() {
        tokio::time::pause();
        let tick = Duration::from_secs(1);
        let rx = stream::iter([Ok(0), Ok(1), Err(2), Ok(3), Ok(4)]);
        let mut rx = pin!(StreamBackoff::new(rx, backoff::backoff::Constant::new(tick)));
        assert_eq!(poll!(rx.next()), Poll::Ready(Some(Ok(0))));
        assert_eq!(poll!(rx.next()), Poll::Ready(Some(Ok(1))));
        assert_eq!(poll!(rx.next()), Poll::Ready(Some(Err(2))));
        assert_eq!(poll!(rx.next()), Poll::Pending);
        tokio::time::advance(tick * 2).await;
        assert_eq!(poll!(rx.next()), Poll::Ready(Some(Ok(3))));
        assert_eq!(poll!(rx.next()), Poll::Ready(Some(Ok(4))));
        assert_eq!(poll!(rx.next()), Poll::Ready(None));
    }

    #[tokio::test]
    async fn backoff_time_should_update() {
        tokio::time::pause();
        let (tx, rx) = mpsc::unbounded();
        let mut rx = pin!(StreamBackoff::new(rx, LinearBackoff::new(Duration::from_secs(2))));
        tx.unbounded_send(Ok(0)).unwrap();
        assert_eq!(poll!(rx.next()), Poll::Ready(Some(Ok(0))));
        tx.unbounded_send(Ok(1)).unwrap();
        assert_eq!(poll!(rx.next()), Poll::Ready(Some(Ok(1))));
        tx.unbounded_send(Err(2)).unwrap();
        assert_eq!(poll!(rx.next()), Poll::Ready(Some(Err(2))));
        assert_eq!(poll!(rx.next()), Poll::Pending);
        tokio::time::advance(Duration::from_secs(3)).await;
        assert_eq!(poll!(rx.next()), Poll::Pending);
        tx.unbounded_send(Err(3)).unwrap();
        assert_eq!(poll!(rx.next()), Poll::Ready(Some(Err(3))));
        tx.unbounded_send(Ok(4)).unwrap();
        assert_eq!(poll!(rx.next()), Poll::Pending);
        tokio::time::advance(Duration::from_secs(3)).await;
        assert_eq!(poll!(rx.next()), Poll::Pending);
        tokio::time::advance(Duration::from_secs(2)).await;
        assert_eq!(poll!(rx.next()), Poll::Ready(Some(Ok(4))));
        assert_eq!(poll!(rx.next()), Poll::Pending);
        drop(tx);
        assert_eq!(poll!(rx.next()), Poll::Ready(None));
    }

    /// An `Ok` between two errors must reset the policy, so the second backoff is as short as
    /// the first (2s) instead of growing to 4s.
    #[tokio::test]
    async fn backoff_should_reset_after_ok() {
        tokio::time::pause();
        let (tx, rx) = mpsc::unbounded();
        let mut rx = pin!(StreamBackoff::new(rx, LinearBackoff::new(Duration::from_secs(2))));
        tx.unbounded_send(Err(0)).unwrap();
        assert_eq!(poll!(rx.next()), Poll::Ready(Some(Err(0))));
        tx.unbounded_send(Ok(1)).unwrap();
        assert_eq!(poll!(rx.next()), Poll::Pending);
        tokio::time::advance(Duration::from_secs(3)).await;
        assert_eq!(poll!(rx.next()), Poll::Ready(Some(Ok(1))));
        tx.unbounded_send(Err(2)).unwrap();
        assert_eq!(poll!(rx.next()), Poll::Ready(Some(Err(2))));
        tx.unbounded_send(Ok(3)).unwrap();
        assert_eq!(poll!(rx.next()), Poll::Pending);
        // Without the reset this backoff would be 4s and the stream would still be pending.
        tokio::time::advance(Duration::from_secs(3)).await;
        assert_eq!(poll!(rx.next()), Poll::Ready(Some(Ok(3))));
    }

    #[tokio::test]
    async fn backoff_should_close_when_requested() {
        assert_eq!(
            StreamBackoff::new(
                stream::iter([Ok(0), Ok(1), Err(2), Ok(3)]),
                backoff::backoff::Stop {}
            )
            .collect::<Vec<_>>()
            .await,
            vec![Ok(0), Ok(1), Err(2)]
        );
    }

    /// Dynamic backoff policy that is still deterministic and testable
    ///
    /// Each call to `next_backoff` waits `interval` longer than the previous one
    /// (`interval`, `2 * interval`, ...); `reset` starts the sequence over.
    pub struct LinearBackoff {
        interval: Duration,
        current_duration: Duration,
    }

    impl LinearBackoff {
        /// Creates a policy whose first backoff is `interval`.
        pub fn new(interval: Duration) -> Self {
            Self {
                interval,
                current_duration: Duration::ZERO,
            }
        }
    }

    impl Backoff for LinearBackoff {
        fn next_backoff(&mut self) -> Option<Duration> {
            self.current_duration += self.interval;
            Some(self.current_duration)
        }

        fn reset(&mut self) {
            self.current_duration = Duration::ZERO;
        }
    }
}