backon/retry.rs
1use core::future::Future;
2use core::pin::Pin;
3use core::task::ready;
4use core::task::Context;
5use core::task::Poll;
6use core::time::Duration;
7
8use crate::backoff::BackoffBuilder;
9use crate::sleep::MaybeSleeper;
10use crate::Backoff;
11use crate::DefaultSleeper;
12use crate::Sleeper;
13
14/// Retryable will add retry support for functions that produce futures with results.
15///
16/// This means all types that implement `FnMut() -> impl Future<Output = Result<T, E>>`
17/// will be able to use `retry`.
18///
19/// For example:
20///
21/// - Functions without extra args:
22///
23/// ```ignore
24/// async fn fetch() -> Result<String> {
25/// Ok(reqwest::get("https://www.rust-lang.org").await?.text().await?)
26/// }
27/// ```
28///
29/// - Closures
30///
31/// ```ignore
32/// || async {
33/// let x = reqwest::get("https://www.rust-lang.org")
34/// .await?
35/// .text()
36/// .await?;
37///
38/// Err(anyhow::anyhow!(x))
39/// }
40/// ```
41pub trait Retryable<
42 B: BackoffBuilder,
43 T,
44 E,
45 Fut: Future<Output = Result<T, E>>,
46 FutureFn: FnMut() -> Fut,
47>
48{
49 /// Generate a new retry
50 fn retry(self, builder: B) -> Retry<B::Backoff, T, E, Fut, FutureFn>;
51}
52
53impl<B, T, E, Fut, FutureFn> Retryable<B, T, E, Fut, FutureFn> for FutureFn
54where
55 B: BackoffBuilder,
56 Fut: Future<Output = Result<T, E>>,
57 FutureFn: FnMut() -> Fut,
58{
59 fn retry(self, builder: B) -> Retry<B::Backoff, T, E, Fut, FutureFn> {
60 Retry::new(self, builder.build())
61 }
62}
63
64/// Struct generated by [`Retryable`].
65pub struct Retry<
66 B: Backoff,
67 T,
68 E,
69 Fut: Future<Output = Result<T, E>>,
70 FutureFn: FnMut() -> Fut,
71 SF: MaybeSleeper = DefaultSleeper,
72 RF = fn(&E) -> bool,
73 NF = fn(&E, Duration),
74 AF = fn(&E, Option<Duration>) -> Option<Duration>,
75> {
76 backoff: B,
77 future_fn: FutureFn,
78
79 retryable_fn: RF,
80 notify_fn: NF,
81 sleep_fn: SF,
82 adjust_fn: AF,
83
84 state: State<T, E, Fut, SF::Sleep>,
85}
86
87impl<B, T, E, Fut, FutureFn> Retry<B, T, E, Fut, FutureFn>
88where
89 B: Backoff,
90 Fut: Future<Output = Result<T, E>>,
91 FutureFn: FnMut() -> Fut,
92{
93 /// Initiate a new retry.
94 fn new(future_fn: FutureFn, backoff: B) -> Self {
95 Retry {
96 backoff,
97 future_fn,
98
99 retryable_fn: |_: &E| true,
100 notify_fn: |_: &E, _: Duration| {},
101 adjust_fn: |_: &E, dur: Option<Duration>| dur,
102 sleep_fn: DefaultSleeper::default(),
103
104 state: State::Idle,
105 }
106 }
107}
108
109impl<B, T, E, Fut, FutureFn, SF, RF, NF, AF> Retry<B, T, E, Fut, FutureFn, SF, RF, NF, AF>
110where
111 B: Backoff,
112 Fut: Future<Output = Result<T, E>>,
113 FutureFn: FnMut() -> Fut,
114 SF: MaybeSleeper,
115 RF: FnMut(&E) -> bool,
116 NF: FnMut(&E, Duration),
117 AF: FnMut(&E, Option<Duration>) -> Option<Duration>,
118{
119 /// Set the sleeper for retrying.
120 ///
121 /// The sleeper should implement the [`Sleeper`] trait. The simplest way is to use a closure that returns a `Future<Output=()>`.
122 ///
123 /// If not specified, we use the [`DefaultSleeper`].
124 ///
125 /// ```no_run
126 /// use std::future::ready;
127 ///
128 /// use anyhow::Result;
129 /// use backon::ExponentialBuilder;
130 /// use backon::Retryable;
131 ///
132 /// async fn fetch() -> Result<String> {
133 /// Ok(reqwest::get("https://www.rust-lang.org")
134 /// .await?
135 /// .text()
136 /// .await?)
137 /// }
138 ///
139 /// #[tokio::main(flavor = "current_thread")]
140 /// async fn main() -> Result<()> {
141 /// let content = fetch
142 /// .retry(ExponentialBuilder::default())
143 /// .sleep(|_| ready(()))
144 /// .await?;
145 /// println!("fetch succeeded: {}", content);
146 ///
147 /// Ok(())
148 /// }
149 /// ```
150 pub fn sleep<SN: Sleeper>(self, sleep_fn: SN) -> Retry<B, T, E, Fut, FutureFn, SN, RF, NF, AF> {
151 Retry {
152 backoff: self.backoff,
153 retryable_fn: self.retryable_fn,
154 notify_fn: self.notify_fn,
155 future_fn: self.future_fn,
156 sleep_fn,
157 adjust_fn: self.adjust_fn,
158 state: State::Idle,
159 }
160 }
161
162 /// Set the conditions for retrying.
163 ///
164 /// If not specified, all errors are considered retryable.
165 ///
166 /// # Examples
167 ///
168 /// ```no_run
169 /// use anyhow::Result;
170 /// use backon::ExponentialBuilder;
171 /// use backon::Retryable;
172 ///
173 /// async fn fetch() -> Result<String> {
174 /// Ok(reqwest::get("https://www.rust-lang.org")
175 /// .await?
176 /// .text()
177 /// .await?)
178 /// }
179 ///
180 /// #[tokio::main(flavor = "current_thread")]
181 /// async fn main() -> Result<()> {
182 /// let content = fetch
183 /// .retry(ExponentialBuilder::default())
184 /// .when(|e| e.to_string() == "EOF")
185 /// .await?;
186 /// println!("fetch succeeded: {}", content);
187 ///
188 /// Ok(())
189 /// }
190 /// ```
191 pub fn when<RN: FnMut(&E) -> bool>(
192 self,
193 retryable: RN,
194 ) -> Retry<B, T, E, Fut, FutureFn, SF, RN, NF, AF> {
195 Retry {
196 backoff: self.backoff,
197 retryable_fn: retryable,
198 notify_fn: self.notify_fn,
199 future_fn: self.future_fn,
200 sleep_fn: self.sleep_fn,
201 adjust_fn: self.adjust_fn,
202 state: self.state,
203 }
204 }
205
206 /// Set to notify for all retry attempts.
207 ///
208 /// When a retry happens, the input function will be invoked with the error and the sleep duration before pausing.
209 ///
210 /// If not specified, this operation does nothing.
211 ///
212 /// # Examples
213 ///
214 /// ```no_run
215 /// use core::time::Duration;
216 ///
217 /// use anyhow::Result;
218 /// use backon::ExponentialBuilder;
219 /// use backon::Retryable;
220 ///
221 /// async fn fetch() -> Result<String> {
222 /// Ok(reqwest::get("https://www.rust-lang.org")
223 /// .await?
224 /// .text()
225 /// .await?)
226 /// }
227 ///
228 /// #[tokio::main(flavor = "current_thread")]
229 /// async fn main() -> Result<()> {
230 /// let content = fetch
231 /// .retry(ExponentialBuilder::default())
232 /// .notify(|err: &anyhow::Error, dur: Duration| {
233 /// println!("retrying error {:?} with sleeping {:?}", err, dur);
234 /// })
235 /// .await?;
236 /// println!("fetch succeeded: {}", content);
237 ///
238 /// Ok(())
239 /// }
240 /// ```
241 pub fn notify<NN: FnMut(&E, Duration)>(
242 self,
243 notify: NN,
244 ) -> Retry<B, T, E, Fut, FutureFn, SF, RF, NN, AF> {
245 Retry {
246 backoff: self.backoff,
247 retryable_fn: self.retryable_fn,
248 notify_fn: notify,
249 sleep_fn: self.sleep_fn,
250 future_fn: self.future_fn,
251 adjust_fn: self.adjust_fn,
252 state: self.state,
253 }
254 }
255
256 /// Sets the function to adjust the backoff duration for retry attempts.
257 ///
258 /// When a retry occurs, the provided function will be called with the error and the proposed backoff duration, allowing you to modify the final duration used.
259 ///
260 /// If the function returns `None`, it indicates that no further retries should be made, and the error will be returned regardless of the backoff duration provided by the input.
261 ///
262 /// If no `adjust` function is specified, the original backoff duration from the input will be used without modification.
263 ///
264 /// `adjust` can be used to implement dynamic backoff strategies, such as adjust backoff values from the http `Retry-After` headers.
265 ///
266 /// # Examples
267 ///
268 /// ```no_run
269 /// use core::time::Duration;
270 /// use std::error::Error;
271 /// use std::fmt::Display;
272 /// use std::fmt::Formatter;
273 ///
274 /// use anyhow::Result;
275 /// use backon::ExponentialBuilder;
276 /// use backon::Retryable;
277 /// use reqwest::header::HeaderMap;
278 /// use reqwest::StatusCode;
279 ///
280 /// #[derive(Debug)]
281 /// struct HttpError {
282 /// headers: HeaderMap,
283 /// }
284 ///
285 /// impl Display for HttpError {
286 /// fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
287 /// write!(f, "http error")
288 /// }
289 /// }
290 ///
291 /// impl Error for HttpError {}
292 ///
293 /// async fn fetch() -> Result<String> {
294 /// let resp = reqwest::get("https://www.rust-lang.org").await?;
295 /// if resp.status() != StatusCode::OK {
296 /// let source = HttpError {
297 /// headers: resp.headers().clone(),
298 /// };
299 /// return Err(anyhow::Error::new(source));
300 /// }
301 /// Ok(resp.text().await?)
302 /// }
303 ///
304 /// #[tokio::main(flavor = "current_thread")]
305 /// async fn main() -> Result<()> {
306 /// let content = fetch
307 /// .retry(ExponentialBuilder::default())
308 /// .adjust(|err, dur| {
309 /// match err.downcast_ref::<HttpError>() {
310 /// Some(v) => {
311 /// if let Some(retry_after) = v.headers.get("Retry-After") {
312 /// // Parse the Retry-After header and adjust the backoff duration
313 /// let retry_after = retry_after.to_str().unwrap_or("0");
314 /// let retry_after = retry_after.parse::<u64>().unwrap_or(0);
315 /// Some(Duration::from_secs(retry_after))
316 /// } else {
317 /// dur
318 /// }
319 /// }
320 /// None => dur,
321 /// }
322 /// })
323 /// .await?;
324 /// println!("fetch succeeded: {}", content);
325 ///
326 /// Ok(())
327 /// }
328 /// ```
329 pub fn adjust<NAF: FnMut(&E, Option<Duration>) -> Option<Duration>>(
330 self,
331 adjust: NAF,
332 ) -> Retry<B, T, E, Fut, FutureFn, SF, RF, NF, NAF> {
333 Retry {
334 backoff: self.backoff,
335 retryable_fn: self.retryable_fn,
336 notify_fn: self.notify_fn,
337 sleep_fn: self.sleep_fn,
338 future_fn: self.future_fn,
339 adjust_fn: adjust,
340 state: self.state,
341 }
342 }
343}
344
/// Internal state machine driven by `Retry::poll`.
#[derive(Default)]
enum State<T, E, Fut, SleepFut>
where
    Fut: Future<Output = Result<T, E>>,
    SleepFut: Future<Output = ()>,
{
    /// No attempt in flight; the next poll calls the wrapped function.
    #[default]
    Idle,
    /// An attempt's future is being driven.
    Polling(Fut),
    /// Waiting for the delay before the next attempt to elapse.
    Sleeping(SleepFut),
}
353
354impl<B, T, E, Fut, FutureFn, SF, RF, NF, AF> Future
355 for Retry<B, T, E, Fut, FutureFn, SF, RF, NF, AF>
356where
357 B: Backoff,
358 Fut: Future<Output = Result<T, E>>,
359 FutureFn: FnMut() -> Fut,
360 SF: Sleeper,
361 RF: FnMut(&E) -> bool,
362 NF: FnMut(&E, Duration),
363 AF: FnMut(&E, Option<Duration>) -> Option<Duration>,
364{
365 type Output = Result<T, E>;
366
367 fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
368 // Safety: This is safe because we don't move the `Retry` struct itself,
369 // only its internal state.
370 //
371 // We do the exactly same thing like `pin_project` but without depending on it directly.
372 let this = unsafe { self.get_unchecked_mut() };
373
374 loop {
375 match &mut this.state {
376 State::Idle => {
377 let fut = (this.future_fn)();
378 this.state = State::Polling(fut);
379 continue;
380 }
381 State::Polling(fut) => {
382 // Safety: This is safe because we don't move the `Retry` struct and this fut,
383 // only its internal state.
384 //
385 // We do the exactly same thing like `pin_project` but without depending on it directly.
386 let mut fut = unsafe { Pin::new_unchecked(fut) };
387
388 match ready!(fut.as_mut().poll(cx)) {
389 Ok(v) => return Poll::Ready(Ok(v)),
390 Err(err) => {
391 // If input error is not retryable, return error directly.
392 if !(this.retryable_fn)(&err) {
393 return Poll::Ready(Err(err));
394 }
395 let adjusted_backoff = (this.adjust_fn)(&err, this.backoff.next());
396 match adjusted_backoff {
397 None => return Poll::Ready(Err(err)),
398 Some(dur) => {
399 (this.notify_fn)(&err, dur);
400 this.state = State::Sleeping(this.sleep_fn.sleep(dur));
401 continue;
402 }
403 }
404 }
405 }
406 }
407 State::Sleeping(sl) => {
408 // Safety: This is safe because we don't move the `Retry` struct and this fut,
409 // only its internal state.
410 //
411 // We do the exactly same thing like `pin_project` but without depending on it directly.
412 let mut sl = unsafe { Pin::new_unchecked(sl) };
413
414 ready!(sl.as_mut().poll(cx));
415 this.state = State::Idle;
416 continue;
417 }
418 }
419 }
420 }
421}
422
#[cfg(test)]
#[cfg(any(feature = "tokio-sleep", feature = "gloo-timers-sleep"))]
mod default_sleeper_tests {
    extern crate alloc;

    use alloc::string::ToString;
    use alloc::vec;
    use alloc::vec::Vec;
    use core::time::Duration;

    use tokio::sync::Mutex;
    #[cfg(not(target_arch = "wasm32"))]
    use tokio::test;
    #[cfg(target_arch = "wasm32")]
    use wasm_bindgen_test::wasm_bindgen_test as test;

    use super::*;
    use crate::ExponentialBuilder;

    /// Operation that fails on every attempt.
    async fn always_error() -> anyhow::Result<()> {
        Err(anyhow::anyhow!("test_query meets error"))
    }

    #[test]
    async fn test_retry() {
        let result = always_error
            .retry(ExponentialBuilder::default().with_min_delay(Duration::from_millis(1)))
            .await;

        assert!(result.is_err());
        assert_eq!("test_query meets error", result.unwrap_err().to_string());
    }

    #[test]
    async fn test_retry_with_not_retryable_error() {
        let error_times = Mutex::new(0);

        let f = || async {
            let mut x = error_times.lock().await;
            *x += 1;
            Err::<(), anyhow::Error>(anyhow::anyhow!("not retryable"))
        };

        let backoff = ExponentialBuilder::default().with_min_delay(Duration::from_millis(1));
        let result = f
            .retry(backoff)
            // Only retry if the error message is `retryable`.
            .when(|e| e.to_string() == "retryable")
            .await;

        assert!(result.is_err());
        assert_eq!("not retryable", result.unwrap_err().to_string());
        // `f` always returns error "not retryable", so it should be executed
        // only once.
        assert_eq!(*error_times.lock().await, 1);
    }

    #[test]
    async fn test_retry_with_retryable_error() {
        let error_times = Mutex::new(0);

        let f = || async {
            let mut x = error_times.lock().await;
            *x += 1;
            Err::<(), anyhow::Error>(anyhow::anyhow!("retryable"))
        };

        let backoff = ExponentialBuilder::default().with_min_delay(Duration::from_millis(1));
        let result = f
            .retry(backoff)
            // Only retry if the error message is `retryable`.
            .when(|e| e.to_string() == "retryable")
            .await;

        assert!(result.is_err());
        assert_eq!("retryable", result.unwrap_err().to_string());
        // `f` always returns error "retryable", so it should be executed
        // 4 times (retry 3 times).
        assert_eq!(*error_times.lock().await, 4);
    }

    #[test]
    async fn test_retry_with_adjust() {
        // `adjust` is `FnMut`, so a plain counter is enough; no lock (and no
        // dependency on `std`) is needed.
        let mut adjust_calls = 0usize;

        let f = || async { Err::<(), anyhow::Error>(anyhow::anyhow!("retryable")) };

        let backoff = ExponentialBuilder::default().with_min_delay(Duration::from_millis(1));
        let result = f
            .retry(backoff)
            // Only retry if the error message is `retryable`.
            .when(|e| e.to_string() == "retryable")
            .adjust(|_, dur| {
                adjust_calls += 1;
                dur
            })
            .await;

        assert!(result.is_err());
        assert_eq!("retryable", result.unwrap_err().to_string());
        // `adjust` runs once per retryable error: three times with a delay and
        // once more when the backoff is exhausted.
        assert_eq!(adjust_calls, 4);
    }

    #[test]
    async fn test_retry_with_adjust_returning_none() {
        let mut attempts = 0usize;

        let f = || {
            attempts += 1;
            async { Err::<(), anyhow::Error>(anyhow::anyhow!("retryable")) }
        };

        let backoff = ExponentialBuilder::default().with_min_delay(Duration::from_millis(1));
        let result = f.retry(backoff).adjust(|_, _| None).await;

        assert!(result.is_err());
        assert_eq!("retryable", result.unwrap_err().to_string());
        // `adjust` returning `None` stops retrying right after the first attempt.
        assert_eq!(attempts, 1);
    }

    #[test]
    async fn test_retry_with_adjust_extending_backoff() {
        let mut attempts = 0usize;
        let mut adjust_calls = 0usize;

        let f = || {
            attempts += 1;
            async { Err::<(), anyhow::Error>(anyhow::anyhow!("retryable")) }
        };

        let backoff = ExponentialBuilder::default().with_min_delay(Duration::from_millis(1));
        let result = f
            .retry(backoff)
            .adjust(|_, _| {
                adjust_calls += 1;
                if adjust_calls < 6 {
                    Some(Duration::from_millis(1))
                } else {
                    None
                }
            })
            .await;

        assert!(result.is_err());
        // Returning `Some` keeps retrying past the backoff's own 3 retries.
        assert_eq!(attempts, 6);
    }

    #[test]
    async fn test_fn_mut_when_and_notify() {
        let mut calls_retryable: Vec<()> = vec![];
        let mut calls_notify: Vec<()> = vec![];

        let f = || async { Err::<(), anyhow::Error>(anyhow::anyhow!("retryable")) };

        let backoff = ExponentialBuilder::default().with_min_delay(Duration::from_millis(1));
        let result = f
            .retry(backoff)
            .when(|_| {
                calls_retryable.push(());
                true
            })
            .notify(|_, _| {
                calls_notify.push(());
            })
            .await;

        assert!(result.is_err());
        assert_eq!("retryable", result.unwrap_err().to_string());
        // `f` always returns error "retryable", so it should be executed
        // 4 times (retry 3 times).
        assert_eq!(calls_retryable.len(), 4);
        assert_eq!(calls_notify.len(), 3);
    }
}
556
#[cfg(test)]
mod custom_sleeper_tests {
    extern crate alloc;

    use alloc::string::ToString;
    use core::future::ready;
    use core::time::Duration;

    #[cfg(not(target_arch = "wasm32"))]
    use tokio::test;
    #[cfg(target_arch = "wasm32")]
    use wasm_bindgen_test::wasm_bindgen_test as test;

    use super::*;
    use crate::ExponentialBuilder;

    /// Operation that fails on every attempt.
    async fn always_error() -> anyhow::Result<()> {
        Err(anyhow::anyhow!("test_query meets error"))
    }

    /// With a sleeper that completes immediately, the retry still ends with
    /// the operation's own error.
    #[test]
    async fn test_retry_with_sleep() {
        let backoff = ExponentialBuilder::default().with_min_delay(Duration::from_millis(1));
        let outcome = always_error.retry(backoff).sleep(|_| ready(())).await;

        let err = outcome.expect_err("an always-failing operation must end in an error");
        assert_eq!(err.to_string(), "test_query meets error");
    }
}