kube_runtime/utils/
stream_backoff.rs
1use std::{future::Future, pin::Pin, task::Poll};
2
3use backoff::backoff::Backoff;
4use futures::{Stream, TryStream};
5use pin_project::pin_project;
6use tokio::time::{sleep, Instant, Sleep};
7
8#[pin_project]
15pub struct StreamBackoff<S, B> {
16 #[pin]
17 stream: S,
18 backoff: B,
19 #[pin]
20 state: State,
21}
22
/// Internal state machine driving [`StreamBackoff`].
#[pin_project(project = StreamBackoffStateProj)]
// `Sleep` is stored inline rather than boxed, which makes `BackingOff` larger than the other
// (fieldless) variants; boxing it would cost a heap allocation for every backoff. The size
// trade-off is accepted here, hence the lint allowance.
#[allow(clippy::large_enum_variant)]
enum State {
    /// Waiting for the backoff sleep to elapse; the inner stream is not polled meanwhile.
    BackingOff(#[pin] Sleep),
    /// The backoff policy returned no further delay after an error; only `None` is yielded now.
    GivenUp,
    /// Not backing off; the next poll goes straight to the inner stream.
    Awake,
}
32
33impl<S: TryStream, B: Backoff> StreamBackoff<S, B> {
34 pub fn new(stream: S, backoff: B) -> Self {
35 Self {
36 stream,
37 backoff,
38 state: State::Awake,
39 }
40 }
41}
42
43impl<S: TryStream, B: Backoff> Stream for StreamBackoff<S, B> {
44 type Item = Result<S::Ok, S::Error>;
45
46 fn poll_next(self: Pin<&mut Self>, cx: &mut std::task::Context<'_>) -> Poll<Option<Self::Item>> {
47 let mut this = self.project();
48 match this.state.as_mut().project() {
49 StreamBackoffStateProj::BackingOff(mut backoff_sleep) => match backoff_sleep.as_mut().poll(cx) {
50 Poll::Ready(()) => {
51 tracing::debug!(deadline = ?backoff_sleep.deadline(), "Backoff complete, waking up");
52 this.state.set(State::Awake)
53 }
54 Poll::Pending => {
55 let deadline = backoff_sleep.deadline();
56 tracing::trace!(
57 ?deadline,
58 remaining_duration = ?deadline.saturating_duration_since(Instant::now()),
59 "Still waiting for backoff sleep to complete"
60 );
61 return Poll::Pending;
62 }
63 },
64 StreamBackoffStateProj::GivenUp => {
65 tracing::debug!("Backoff has given up, stream is closed");
66 return Poll::Ready(None);
67 }
68 StreamBackoffStateProj::Awake => {}
69 }
70
71 let next_item = this.stream.try_poll_next(cx);
72 match &next_item {
73 Poll::Ready(Some(Err(_))) => {
74 if let Some(backoff_duration) = this.backoff.next_backoff() {
75 let backoff_sleep = sleep(backoff_duration);
76 tracing::debug!(
77 deadline = ?backoff_sleep.deadline(),
78 duration = ?backoff_duration,
79 "Error received, backing off"
80 );
81 this.state.set(State::BackingOff(backoff_sleep));
82 } else {
83 tracing::debug!("Error received, giving up");
84 this.state.set(State::GivenUp);
85 }
86 }
87 Poll::Ready(_) => {
88 tracing::trace!("Non-error received, resetting backoff");
89 this.backoff.reset();
90 }
91 Poll::Pending => {}
92 }
93 next_item
94 }
95}
96
#[cfg(test)]
pub(crate) mod tests {
    use std::{pin::pin, task::Poll, time::Duration};

    use super::StreamBackoff;
    use backoff::backoff::Backoff;
    use futures::{channel::mpsc, poll, stream, StreamExt};

    /// An error delays the following item by the backoff duration; `Ok` items flow through.
    #[tokio::test]
    async fn stream_should_back_off() {
        // Pause Tokio's clock so the test controls when the backoff sleep elapses (via `advance`).
        tokio::time::pause();
        let tick = Duration::from_secs(1);
        let rx = stream::iter([Ok(0), Ok(1), Err(2), Ok(3), Ok(4)]);
        let mut rx = pin!(StreamBackoff::new(rx, backoff::backoff::Constant::new(tick)));
        assert_eq!(poll!(rx.next()), Poll::Ready(Some(Ok(0))));
        assert_eq!(poll!(rx.next()), Poll::Ready(Some(Ok(1))));
        // The error itself is yielded immediately...
        assert_eq!(poll!(rx.next()), Poll::Ready(Some(Err(2))));
        // ...but the next poll is held back until the backoff sleep has elapsed.
        assert_eq!(poll!(rx.next()), Poll::Pending);
        tokio::time::advance(tick * 2).await;
        assert_eq!(poll!(rx.next()), Poll::Ready(Some(Ok(3))));
        assert_eq!(poll!(rx.next()), Poll::Ready(Some(Ok(4))));
        assert_eq!(poll!(rx.next()), Poll::Ready(None));
    }

    /// The backoff duration grows across errors and is only reset by a non-error item.
    #[tokio::test]
    async fn backoff_time_should_update() {
        tokio::time::pause();
        let (tx, rx) = mpsc::unbounded();
        // LinearBackoff hands out 2s, 4s, 6s, ... until it is reset.
        let mut rx = pin!(StreamBackoff::new(rx, LinearBackoff::new(Duration::from_secs(2))));
        tx.unbounded_send(Ok(0)).unwrap();
        assert_eq!(poll!(rx.next()), Poll::Ready(Some(Ok(0))));
        tx.unbounded_send(Ok(1)).unwrap();
        assert_eq!(poll!(rx.next()), Poll::Ready(Some(Ok(1))));
        // First error: a 2s backoff starts.
        tx.unbounded_send(Err(2)).unwrap();
        assert_eq!(poll!(rx.next()), Poll::Ready(Some(Err(2))));
        assert_eq!(poll!(rx.next()), Poll::Pending);
        // 3s >= 2s: the sleep completes, but the channel is empty, so the poll is still pending.
        tokio::time::advance(Duration::from_secs(3)).await;
        assert_eq!(poll!(rx.next()), Poll::Pending);
        // Second error with no `Ok` in between: the policy was not reset, so this backoff is 4s.
        tx.unbounded_send(Err(3)).unwrap();
        assert_eq!(poll!(rx.next()), Poll::Ready(Some(Err(3))));
        tx.unbounded_send(Ok(4)).unwrap();
        assert_eq!(poll!(rx.next()), Poll::Pending);
        // 3s < 4s: still backing off even though `Ok(4)` is already queued.
        tokio::time::advance(Duration::from_secs(3)).await;
        assert_eq!(poll!(rx.next()), Poll::Pending);
        // 5s in total >= 4s: the sleep fires and the queued item is delivered.
        tokio::time::advance(Duration::from_secs(2)).await;
        assert_eq!(poll!(rx.next()), Poll::Ready(Some(Ok(4))));
        assert_eq!(poll!(rx.next()), Poll::Pending);
        // Dropping the sender ends the inner stream, which ends the wrapper as well.
        drop(tx);
        assert_eq!(poll!(rx.next()), Poll::Ready(None));
    }

    /// With `Stop`, no backoff duration is offered, so the stream ends right after the first
    /// error and `Ok(3)` is never delivered.
    #[tokio::test]
    async fn backoff_should_close_when_requested() {
        assert_eq!(
            StreamBackoff::new(
                stream::iter([Ok(0), Ok(1), Err(2), Ok(3)]),
                backoff::backoff::Stop {}
            )
            .collect::<Vec<_>>()
            .await,
            vec![Ok(0), Ok(1), Err(2)]
        );
    }

    /// Deterministic test policy whose delay grows by `interval` on every `next_backoff` call
    /// (`interval`, `2 * interval`, ...) until `reset` is called. It never gives up.
    pub struct LinearBackoff {
        /// Amount added to the delay on each `next_backoff` call.
        interval: Duration,
        /// Delay returned by the most recent `next_backoff`; zero after construction or reset.
        current_duration: Duration,
    }

    impl LinearBackoff {
        /// Creates a policy whose first backoff is `interval`.
        pub fn new(interval: Duration) -> Self {
            Self {
                interval,
                current_duration: Duration::ZERO,
            }
        }
    }

    impl Backoff for LinearBackoff {
        fn next_backoff(&mut self) -> Option<Duration> {
            // NOTE: `Duration += Duration` panics on overflow; irrelevant for the short test sequences.
            self.current_duration += self.interval;
            Some(self.current_duration)
        }

        fn reset(&mut self) {
            self.current_duration = Duration::ZERO
        }
    }
}