backon/
retry.rs

1use core::future::Future;
2use core::pin::Pin;
3use core::task::ready;
4use core::task::Context;
5use core::task::Poll;
6use core::time::Duration;
7
8use crate::backoff::BackoffBuilder;
9use crate::sleep::MaybeSleeper;
10use crate::Backoff;
11use crate::DefaultSleeper;
12use crate::Sleeper;
13
14/// Retryable will add retry support for functions that produce futures with results.
15///
16/// This means all types that implement `FnMut() -> impl Future<Output = Result<T, E>>`
17/// will be able to use `retry`.
18///
19/// For example:
20///
21/// - Functions without extra args:
22///
23/// ```ignore
24/// async fn fetch() -> Result<String> {
25///     Ok(reqwest::get("https://www.rust-lang.org").await?.text().await?)
26/// }
27/// ```
28///
29/// - Closures
30///
31/// ```ignore
32/// || async {
33///     let x = reqwest::get("https://www.rust-lang.org")
34///         .await?
35///         .text()
36///         .await?;
37///
38///     Err(anyhow::anyhow!(x))
39/// }
40/// ```
41pub trait Retryable<
42    B: BackoffBuilder,
43    T,
44    E,
45    Fut: Future<Output = Result<T, E>>,
46    FutureFn: FnMut() -> Fut,
47>
48{
49    /// Generate a new retry
50    fn retry(self, builder: B) -> Retry<B::Backoff, T, E, Fut, FutureFn>;
51}
52
53impl<B, T, E, Fut, FutureFn> Retryable<B, T, E, Fut, FutureFn> for FutureFn
54where
55    B: BackoffBuilder,
56    Fut: Future<Output = Result<T, E>>,
57    FutureFn: FnMut() -> Fut,
58{
59    fn retry(self, builder: B) -> Retry<B::Backoff, T, E, Fut, FutureFn> {
60        Retry::new(self, builder.build())
61    }
62}
63
64/// Struct generated by [`Retryable`].
65pub struct Retry<
66    B: Backoff,
67    T,
68    E,
69    Fut: Future<Output = Result<T, E>>,
70    FutureFn: FnMut() -> Fut,
71    SF: MaybeSleeper = DefaultSleeper,
72    RF = fn(&E) -> bool,
73    NF = fn(&E, Duration),
74    AF = fn(&E, Option<Duration>) -> Option<Duration>,
75> {
76    backoff: B,
77    future_fn: FutureFn,
78
79    retryable_fn: RF,
80    notify_fn: NF,
81    sleep_fn: SF,
82    adjust_fn: AF,
83
84    state: State<T, E, Fut, SF::Sleep>,
85}
86
87impl<B, T, E, Fut, FutureFn> Retry<B, T, E, Fut, FutureFn>
88where
89    B: Backoff,
90    Fut: Future<Output = Result<T, E>>,
91    FutureFn: FnMut() -> Fut,
92{
93    /// Initiate a new retry.
94    fn new(future_fn: FutureFn, backoff: B) -> Self {
95        Retry {
96            backoff,
97            future_fn,
98
99            retryable_fn: |_: &E| true,
100            notify_fn: |_: &E, _: Duration| {},
101            adjust_fn: |_: &E, dur: Option<Duration>| dur,
102            sleep_fn: DefaultSleeper::default(),
103
104            state: State::Idle,
105        }
106    }
107}
108
109impl<B, T, E, Fut, FutureFn, SF, RF, NF, AF> Retry<B, T, E, Fut, FutureFn, SF, RF, NF, AF>
110where
111    B: Backoff,
112    Fut: Future<Output = Result<T, E>>,
113    FutureFn: FnMut() -> Fut,
114    SF: MaybeSleeper,
115    RF: FnMut(&E) -> bool,
116    NF: FnMut(&E, Duration),
117    AF: FnMut(&E, Option<Duration>) -> Option<Duration>,
118{
119    /// Set the sleeper for retrying.
120    ///
121    /// The sleeper should implement the [`Sleeper`] trait. The simplest way is to use a closure that returns a `Future<Output=()>`.
122    ///
123    /// If not specified, we use the [`DefaultSleeper`].
124    ///
125    /// ```no_run
126    /// use std::future::ready;
127    ///
128    /// use anyhow::Result;
129    /// use backon::ExponentialBuilder;
130    /// use backon::Retryable;
131    ///
132    /// async fn fetch() -> Result<String> {
133    ///     Ok(reqwest::get("https://www.rust-lang.org")
134    ///         .await?
135    ///         .text()
136    ///         .await?)
137    /// }
138    ///
139    /// #[tokio::main(flavor = "current_thread")]
140    /// async fn main() -> Result<()> {
141    ///     let content = fetch
142    ///         .retry(ExponentialBuilder::default())
143    ///         .sleep(|_| ready(()))
144    ///         .await?;
145    ///     println!("fetch succeeded: {}", content);
146    ///
147    ///     Ok(())
148    /// }
149    /// ```
150    pub fn sleep<SN: Sleeper>(self, sleep_fn: SN) -> Retry<B, T, E, Fut, FutureFn, SN, RF, NF, AF> {
151        Retry {
152            backoff: self.backoff,
153            retryable_fn: self.retryable_fn,
154            notify_fn: self.notify_fn,
155            future_fn: self.future_fn,
156            sleep_fn,
157            adjust_fn: self.adjust_fn,
158            state: State::Idle,
159        }
160    }
161
162    /// Set the conditions for retrying.
163    ///
164    /// If not specified, all errors are considered retryable.
165    ///
166    /// # Examples
167    ///
168    /// ```no_run
169    /// use anyhow::Result;
170    /// use backon::ExponentialBuilder;
171    /// use backon::Retryable;
172    ///
173    /// async fn fetch() -> Result<String> {
174    ///     Ok(reqwest::get("https://www.rust-lang.org")
175    ///         .await?
176    ///         .text()
177    ///         .await?)
178    /// }
179    ///
180    /// #[tokio::main(flavor = "current_thread")]
181    /// async fn main() -> Result<()> {
182    ///     let content = fetch
183    ///         .retry(ExponentialBuilder::default())
184    ///         .when(|e| e.to_string() == "EOF")
185    ///         .await?;
186    ///     println!("fetch succeeded: {}", content);
187    ///
188    ///     Ok(())
189    /// }
190    /// ```
191    pub fn when<RN: FnMut(&E) -> bool>(
192        self,
193        retryable: RN,
194    ) -> Retry<B, T, E, Fut, FutureFn, SF, RN, NF, AF> {
195        Retry {
196            backoff: self.backoff,
197            retryable_fn: retryable,
198            notify_fn: self.notify_fn,
199            future_fn: self.future_fn,
200            sleep_fn: self.sleep_fn,
201            adjust_fn: self.adjust_fn,
202            state: self.state,
203        }
204    }
205
206    /// Set to notify for all retry attempts.
207    ///
208    /// When a retry happens, the input function will be invoked with the error and the sleep duration before pausing.
209    ///
210    /// If not specified, this operation does nothing.
211    ///
212    /// # Examples
213    ///
214    /// ```no_run
215    /// use core::time::Duration;
216    ///
217    /// use anyhow::Result;
218    /// use backon::ExponentialBuilder;
219    /// use backon::Retryable;
220    ///
221    /// async fn fetch() -> Result<String> {
222    ///     Ok(reqwest::get("https://www.rust-lang.org")
223    ///         .await?
224    ///         .text()
225    ///         .await?)
226    /// }
227    ///
228    /// #[tokio::main(flavor = "current_thread")]
229    /// async fn main() -> Result<()> {
230    ///     let content = fetch
231    ///         .retry(ExponentialBuilder::default())
232    ///         .notify(|err: &anyhow::Error, dur: Duration| {
233    ///             println!("retrying error {:?} with sleeping {:?}", err, dur);
234    ///         })
235    ///         .await?;
236    ///     println!("fetch succeeded: {}", content);
237    ///
238    ///     Ok(())
239    /// }
240    /// ```
241    pub fn notify<NN: FnMut(&E, Duration)>(
242        self,
243        notify: NN,
244    ) -> Retry<B, T, E, Fut, FutureFn, SF, RF, NN, AF> {
245        Retry {
246            backoff: self.backoff,
247            retryable_fn: self.retryable_fn,
248            notify_fn: notify,
249            sleep_fn: self.sleep_fn,
250            future_fn: self.future_fn,
251            adjust_fn: self.adjust_fn,
252            state: self.state,
253        }
254    }
255
256    /// Sets the function to adjust the backoff duration for retry attempts.
257    ///
258    /// When a retry occurs, the provided function will be called with the error and the proposed backoff duration, allowing you to modify the final duration used.
259    ///
260    /// If the function returns `None`, it indicates that no further retries should be made, and the error will be returned regardless of the backoff duration provided by the input.
261    ///
262    /// If no `adjust` function is specified, the original backoff duration from the input will be used without modification.
263    ///
264    /// `adjust` can be used to implement dynamic backoff strategies, such as adjust backoff values from the http `Retry-After` headers.
265    ///
266    /// # Examples
267    ///
268    /// ```no_run
269    /// use core::time::Duration;
270    /// use std::error::Error;
271    /// use std::fmt::Display;
272    /// use std::fmt::Formatter;
273    ///
274    /// use anyhow::Result;
275    /// use backon::ExponentialBuilder;
276    /// use backon::Retryable;
277    /// use reqwest::header::HeaderMap;
278    /// use reqwest::StatusCode;
279    ///
280    /// #[derive(Debug)]
281    /// struct HttpError {
282    ///     headers: HeaderMap,
283    /// }
284    ///
285    /// impl Display for HttpError {
286    ///     fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
287    ///         write!(f, "http error")
288    ///     }
289    /// }
290    ///
291    /// impl Error for HttpError {}
292    ///
293    /// async fn fetch() -> Result<String> {
294    ///     let resp = reqwest::get("https://www.rust-lang.org").await?;
295    ///     if resp.status() != StatusCode::OK {
296    ///         let source = HttpError {
297    ///             headers: resp.headers().clone(),
298    ///         };
299    ///         return Err(anyhow::Error::new(source));
300    ///     }
301    ///     Ok(resp.text().await?)
302    /// }
303    ///
304    /// #[tokio::main(flavor = "current_thread")]
305    /// async fn main() -> Result<()> {
306    ///     let content = fetch
307    ///         .retry(ExponentialBuilder::default())
308    ///         .adjust(|err, dur| {
309    ///             match err.downcast_ref::<HttpError>() {
310    ///                 Some(v) => {
311    ///                     if let Some(retry_after) = v.headers.get("Retry-After") {
312    ///                         // Parse the Retry-After header and adjust the backoff duration
313    ///                         let retry_after = retry_after.to_str().unwrap_or("0");
314    ///                         let retry_after = retry_after.parse::<u64>().unwrap_or(0);
315    ///                         Some(Duration::from_secs(retry_after))
316    ///                     } else {
317    ///                         dur
318    ///                     }
319    ///                 }
320    ///                 None => dur,
321    ///             }
322    ///         })
323    ///         .await?;
324    ///     println!("fetch succeeded: {}", content);
325    ///
326    ///     Ok(())
327    /// }
328    /// ```
329    pub fn adjust<NAF: FnMut(&E, Option<Duration>) -> Option<Duration>>(
330        self,
331        adjust: NAF,
332    ) -> Retry<B, T, E, Fut, FutureFn, SF, RF, NF, NAF> {
333        Retry {
334            backoff: self.backoff,
335            retryable_fn: self.retryable_fn,
336            notify_fn: self.notify_fn,
337            sleep_fn: self.sleep_fn,
338            future_fn: self.future_fn,
339            adjust_fn: adjust,
340            state: self.state,
341        }
342    }
343}
344
/// Internal state machine advanced by `Retry::poll`.
#[derive(Default)]
enum State<T, E, Fut, SleepFut>
where
    Fut: Future<Output = Result<T, E>>,
    SleepFut: Future<Output = ()>,
{
    /// No attempt is in flight; the next poll calls `future_fn`. This is the initial state.
    #[default]
    Idle,
    /// The future of the current attempt is being polled.
    Polling(Fut),
    /// Waiting for the delay chosen after a retryable error.
    Sleeping(SleepFut),
}
353
/// Drives the retry loop.
///
/// Each attempt is created with `future_fn` and polled to completion. On a retryable
/// error, the loop sleeps for the (possibly adjusted) backoff delay and then starts the
/// next attempt. It resolves with the first `Ok`, or with the error that ended the loop.
impl<B, T, E, Fut, FutureFn, SF, RF, NF, AF> Future
    for Retry<B, T, E, Fut, FutureFn, SF, RF, NF, AF>
where
    B: Backoff,
    Fut: Future<Output = Result<T, E>>,
    FutureFn: FnMut() -> Fut,
    SF: Sleeper,
    RF: FnMut(&E) -> bool,
    NF: FnMut(&E, Duration),
    AF: FnMut(&E, Option<Duration>) -> Option<Duration>,
{
    type Output = Result<T, E>;

    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        // SAFETY: `get_unchecked_mut` requires that the returned `&mut Self` is never used
        // to move `Retry`, or any structurally pinned part of it, out of its pinned
        // location. This function only does two things:
        // - it uses the non-pinned fields (`future_fn`, `backoff`, `sleep_fn` and the
        //   `*_fn` hooks) through ordinary references;
        // - it overwrites `this.state` by assignment. Assignment drops the previous state,
        //   including any future inside it, in place; it never moves it.
        // `state` is therefore the only structurally pinned field, which is the same
        // projection `pin_project` would generate.
        let this = unsafe { self.get_unchecked_mut() };

        // Each arm either returns (`Ready`, or `Pending` via `ready!`) or moves to the
        // next state and `continue`s. Several transitions can happen in one poll.
        loop {
            match &mut this.state {
                // No attempt in flight: create one and poll it on the next iteration.
                State::Idle => {
                    let fut = (this.future_fn)();
                    this.state = State::Polling(fut);
                    continue;
                }
                State::Polling(fut) => {
                    // SAFETY: `fut` is stored inside `this.state`, which lives inside the
                    // pinned `Retry`, and this function never moves it. The only way out
                    // of `Polling` is assigning a new value to `this.state`, which drops
                    // `fut` in place.
                    // The builder methods (`when`, `notify`, `adjust`, `sleep`) take `self`
                    // by value, so a pinned `Retry` cannot reach them unless `Retry: Unpin`.
                    // Because `state` holds `Fut` by value, the auto-derived `Unpin`
                    // requires `Fut: Unpin`.
                    // This also relies on there being no manual `Unpin` impl and no `Drop`
                    // impl for `Retry` that moves `state`; none appears in this file.
                    let mut fut = unsafe { Pin::new_unchecked(fut) };

                    // `ready!` returns `Poll::Pending` early and keeps `state` as is, so
                    // the next poll resumes this same attempt.
                    match ready!(fut.as_mut().poll(cx)) {
                        Ok(v) => return Poll::Ready(Ok(v)),
                        Err(err) => {
                            // If input error is not retryable, return error directly.
                            if !(this.retryable_fn)(&err) {
                                return Poll::Ready(Err(err));
                            }
                            // The backoff advances once per retryable error. `adjust_fn`
                            // sees its raw value, which is `None` once exhausted. Only the
                            // adjusted value decides whether to stop or sleep.
                            let adjusted_backoff = (this.adjust_fn)(&err, this.backoff.next());
                            match adjusted_backoff {
                                None => return Poll::Ready(Err(err)),
                                Some(dur) => {
                                    (this.notify_fn)(&err, dur);
                                    // The sleep future is created before the assignment,
                                    // which then drops the finished attempt's future in
                                    // place.
                                    this.state = State::Sleeping(this.sleep_fn.sleep(dur));
                                    continue;
                                }
                            }
                        }
                    }
                    // NOTE(review): after a `Ready` return, `state` still holds the
                    // completed attempt. Polling this `Retry` again would poll that
                    // completed future, which the `Future` contract leaves unspecified.
                }
                State::Sleeping(sl) => {
                    // SAFETY: same reasoning as for `Polling`. `sl` is stored in
                    // `this.state` and is only ever dropped in place: either by the
                    // assignment below or when the `Retry` itself is dropped.
                    let mut sl = unsafe { Pin::new_unchecked(sl) };

                    ready!(sl.as_mut().poll(cx));
                    // The delay has elapsed; the next iteration starts a fresh attempt.
                    this.state = State::Idle;
                    continue;
                }
            }
        }
    }
}
422
#[cfg(test)]
#[cfg(any(feature = "tokio-sleep", feature = "gloo-timers-sleep"))]
mod default_sleeper_tests {
    extern crate alloc;

    use alloc::string::ToString;
    use alloc::vec;
    use alloc::vec::Vec;
    use core::time::Duration;

    use tokio::sync::Mutex;
    #[cfg(not(target_arch = "wasm32"))]
    use tokio::test;
    #[cfg(target_arch = "wasm32")]
    use wasm_bindgen_test::wasm_bindgen_test as test;

    use super::*;
    use crate::ExponentialBuilder;

    /// Fails on every call with the same error; every error is retryable by default.
    async fn always_error() -> anyhow::Result<()> {
        Err(anyhow::anyhow!("test_query meets error"))
    }

    #[test]
    async fn test_retry() {
        let result = always_error
            .retry(ExponentialBuilder::default().with_min_delay(Duration::from_millis(1)))
            .await;

        assert!(result.is_err());
        assert_eq!("test_query meets error", result.unwrap_err().to_string());
    }

    #[test]
    async fn test_retry_with_not_retryable_error() {
        let error_times = Mutex::new(0);

        let f = || async {
            let mut x = error_times.lock().await;
            *x += 1;
            Err::<(), anyhow::Error>(anyhow::anyhow!("not retryable"))
        };

        let backoff = ExponentialBuilder::default().with_min_delay(Duration::from_millis(1));
        let result = f
            .retry(backoff)
            // Only retry if the error message is `retryable`.
            .when(|e| e.to_string() == "retryable")
            .await;

        assert!(result.is_err());
        assert_eq!("not retryable", result.unwrap_err().to_string());
        // `f` always returns error "not retryable", so it should be executed
        // only once.
        assert_eq!(*error_times.lock().await, 1);
    }

    #[test]
    async fn test_retry_with_retryable_error() {
        let error_times = Mutex::new(0);

        let f = || async {
            let mut x = error_times.lock().await;
            *x += 1;
            Err::<(), anyhow::Error>(anyhow::anyhow!("retryable"))
        };

        let backoff = ExponentialBuilder::default().with_min_delay(Duration::from_millis(1));
        let result = f
            .retry(backoff)
            // Only retry if the error message is `retryable`.
            .when(|e| e.to_string() == "retryable")
            .await;

        assert!(result.is_err());
        assert_eq!("retryable", result.unwrap_err().to_string());
        // `f` always returns error "retryable", so it should be executed
        // 4 times (retry 3 times).
        assert_eq!(*error_times.lock().await, 4);
    }

    #[test]
    async fn test_retry_with_adjust() {
        // `adjust` accepts an `FnMut`, so a plain local counter is enough. No lock is
        // needed, which also avoids depending on `std::sync` in this `core`/`alloc` module.
        let mut adjust_calls = 0;

        let f = || async { Err::<(), anyhow::Error>(anyhow::anyhow!("retryable")) };

        let backoff = ExponentialBuilder::default().with_min_delay(Duration::from_millis(1));
        let result = f
            .retry(backoff)
            // Only retry if the error message is `retryable`.
            .when(|e| e.to_string() == "retryable")
            .adjust(|_, dur| {
                adjust_calls += 1;
                dur
            })
            .await;

        assert!(result.is_err());
        assert_eq!("retryable", result.unwrap_err().to_string());
        // `adjust` runs once per failed attempt. 4 attempts (3 retries) means 4 calls;
        // the last one receives `None` because the backoff is exhausted, and returning
        // it unchanged ends the loop.
        assert_eq!(adjust_calls, 4);
    }

    #[test]
    async fn test_fn_mut_when_and_notify() {
        let mut calls_retryable: Vec<()> = vec![];
        let mut calls_notify: Vec<()> = vec![];

        let f = || async { Err::<(), anyhow::Error>(anyhow::anyhow!("retryable")) };

        let backoff = ExponentialBuilder::default().with_min_delay(Duration::from_millis(1));
        let result = f
            .retry(backoff)
            .when(|_| {
                calls_retryable.push(());
                true
            })
            .notify(|_, _| {
                calls_notify.push(());
            })
            .await;

        assert!(result.is_err());
        assert_eq!("retryable", result.unwrap_err().to_string());
        // `f` always returns error "retryable", so it should be executed
        // 4 times (retry 3 times). `when` sees every error; `notify` only runs
        // before each of the 3 sleeps.
        assert_eq!(calls_retryable.len(), 4);
        assert_eq!(calls_notify.len(), 3);
    }
}
556
#[cfg(test)]
mod custom_sleeper_tests {
    extern crate alloc;

    use alloc::string::ToString;
    use core::future::ready;
    use core::time::Duration;

    #[cfg(not(target_arch = "wasm32"))]
    use tokio::test;
    #[cfg(target_arch = "wasm32")]
    use wasm_bindgen_test::wasm_bindgen_test as test;

    use super::*;
    use crate::ExponentialBuilder;

    /// Fails on every call, so the retry loop runs until the backoff is exhausted.
    async fn always_error() -> anyhow::Result<()> {
        Err(anyhow::anyhow!("test_query meets error"))
    }

    #[test]
    async fn test_retry_with_sleep() {
        let backoff = ExponentialBuilder::default().with_min_delay(Duration::from_millis(1));

        // The custom sleeper resolves immediately, so no timer backend is required.
        let err = always_error
            .retry(backoff)
            .sleep(|_| ready(()))
            .await
            .expect_err("always_error never succeeds");

        assert_eq!("test_query meets error", err.to_string());
    }
}