backon/
retry_with_context.rs

1use core::future::Future;
2use core::pin::Pin;
3use core::task::ready;
4use core::task::Context;
5use core::task::Poll;
6use core::time::Duration;
7
8use crate::backoff::BackoffBuilder;
9use crate::sleep::MaybeSleeper;
10use crate::Backoff;
11use crate::DefaultSleeper;
12use crate::Sleeper;
13
14/// `RetryableWithContext` adds retry support for functions that produce futures with results
15/// and context.
16///
17/// This means all types implementing `FnMut(Ctx) -> impl Future<Output = (Ctx, Result<T, E>)>`
18/// can use `retry`.
19///
20/// Users must provide context to the function and can receive it back after the retry is completed.
21///
22/// # Example
23///
24/// Without context, we might encounter errors such as the following:
25///
26/// ```shell
27/// error: captured variable cannot escape `FnMut` closure body
28///    --> src/retry.rs:404:27
29///     |
30/// 400 |         let mut test = Test;
31///     |             -------- variable defined here
32/// ...
33/// 404 |         let result = { || async { test.hello().await } }
34///     |                         - ^^^^^^^^----^^^^^^^^^^^^^^^^
35///     |                         | |       |
36///     |                         | |       variable captured here
37///     |                         | returns an `async` block that contains a reference to a captured variable, which then escapes the closure body
38///     |                         inferred to be a `FnMut` closure
39///     |
40///     = note: `FnMut` closures only have access to their captured variables while they are executing...
41///     = note: ...therefore, they cannot allow references to captured variables to escape
42/// ```
43///
44/// However, with context support, we can implement it this way:
45///
46/// ```no_run
47/// use anyhow::anyhow;
48/// use anyhow::Result;
49/// use backon::ExponentialBuilder;
50/// use backon::RetryableWithContext;
51///
52/// struct Test;
53///
54/// impl Test {
55///     async fn hello(&mut self) -> Result<usize> {
56///         Err(anyhow!("not retryable"))
57///     }
58/// }
59///
60/// #[tokio::main(flavor = "current_thread")]
61/// async fn main() -> Result<()> {
62///     let mut test = Test;
63///
64///     // (Test, Result<usize>)
65///     let (_, result) = {
66///         |mut v: Test| async {
67///             let res = v.hello().await;
68///             (v, res)
69///         }
70///     }
71///     .retry(ExponentialBuilder::default())
72///     .context(test)
73///     .await;
74///
75///     Ok(())
76/// }
77/// ```
78pub trait RetryableWithContext<
79    B: BackoffBuilder,
80    T,
81    E,
82    Ctx,
83    Fut: Future<Output = (Ctx, Result<T, E>)>,
84    FutureFn: FnMut(Ctx) -> Fut,
85>
86{
87    /// Generate a new retry
88    fn retry(self, builder: B) -> RetryWithContext<B::Backoff, T, E, Ctx, Fut, FutureFn>;
89}
90
91impl<B, T, E, Ctx, Fut, FutureFn> RetryableWithContext<B, T, E, Ctx, Fut, FutureFn> for FutureFn
92where
93    B: BackoffBuilder,
94    Fut: Future<Output = (Ctx, Result<T, E>)>,
95    FutureFn: FnMut(Ctx) -> Fut,
96{
97    fn retry(self, builder: B) -> RetryWithContext<B::Backoff, T, E, Ctx, Fut, FutureFn> {
98        RetryWithContext::new(self, builder.build())
99    }
100}
101
102/// Retry struct generated by [`RetryableWithContext`].
103pub struct RetryWithContext<
104    B: Backoff,
105    T,
106    E,
107    Ctx,
108    Fut: Future<Output = (Ctx, Result<T, E>)>,
109    FutureFn: FnMut(Ctx) -> Fut,
110    SF: MaybeSleeper = DefaultSleeper,
111    RF = fn(&E) -> bool,
112    NF = fn(&E, Duration),
113> {
114    backoff: B,
115    retryable: RF,
116    notify: NF,
117    future_fn: FutureFn,
118    sleep_fn: SF,
119
120    state: State<T, E, Ctx, Fut, SF::Sleep>,
121}
122
123impl<B, T, E, Ctx, Fut, FutureFn> RetryWithContext<B, T, E, Ctx, Fut, FutureFn>
124where
125    B: Backoff,
126    Fut: Future<Output = (Ctx, Result<T, E>)>,
127    FutureFn: FnMut(Ctx) -> Fut,
128{
129    /// Create a new retry.
130    fn new(future_fn: FutureFn, backoff: B) -> Self {
131        RetryWithContext {
132            backoff,
133            retryable: |_: &E| true,
134            notify: |_: &E, _: Duration| {},
135            future_fn,
136            sleep_fn: DefaultSleeper::default(),
137            state: State::Idle(None),
138        }
139    }
140}
141
142impl<B, T, E, Ctx, Fut, FutureFn, SF, RF, NF>
143    RetryWithContext<B, T, E, Ctx, Fut, FutureFn, SF, RF, NF>
144where
145    B: Backoff,
146    Fut: Future<Output = (Ctx, Result<T, E>)>,
147    FutureFn: FnMut(Ctx) -> Fut,
148    SF: MaybeSleeper,
149    RF: FnMut(&E) -> bool,
150    NF: FnMut(&E, Duration),
151{
152    /// Set the sleeper for retrying.
153    ///
154    /// The sleeper should implement the [`Sleeper`] trait. The simplest way is to use a closure that returns a `Future<Output=()>`.
155    ///
156    /// If not specified, we use the [`DefaultSleeper`].
157    pub fn sleep<SN: Sleeper>(
158        self,
159        sleep_fn: SN,
160    ) -> RetryWithContext<B, T, E, Ctx, Fut, FutureFn, SN, RF, NF> {
161        assert!(
162            matches!(self.state, State::Idle(None)),
163            "sleep must be set before context"
164        );
165
166        RetryWithContext {
167            backoff: self.backoff,
168            retryable: self.retryable,
169            notify: self.notify,
170            future_fn: self.future_fn,
171            sleep_fn,
172            state: State::Idle(None),
173        }
174    }
175
176    /// Set the context for retrying.
177    ///
178    /// Context is used to capture ownership manually to prevent lifetime issues.
179    pub fn context(
180        self,
181        context: Ctx,
182    ) -> RetryWithContext<B, T, E, Ctx, Fut, FutureFn, SF, RF, NF> {
183        RetryWithContext {
184            backoff: self.backoff,
185            retryable: self.retryable,
186            notify: self.notify,
187            future_fn: self.future_fn,
188            sleep_fn: self.sleep_fn,
189            state: State::Idle(Some(context)),
190        }
191    }
192
193    /// Set the conditions for retrying.
194    ///
195    /// If not specified, all errors are considered retryable.
196    ///
197    /// # Examples
198    ///
199    /// ```no_run
200    /// use anyhow::Result;
201    /// use backon::ExponentialBuilder;
202    /// use backon::Retryable;
203    ///
204    /// async fn fetch() -> Result<String> {
205    ///     Ok(reqwest::get("https://www.rust-lang.org")
206    ///         .await?
207    ///         .text()
208    ///         .await?)
209    /// }
210    ///
211    /// #[tokio::main(flavor = "current_thread")]
212    /// async fn main() -> Result<()> {
213    ///     let content = fetch
214    ///         .retry(ExponentialBuilder::default())
215    ///         .when(|e| e.to_string() == "EOF")
216    ///         .await?;
217    ///     println!("fetch succeeded: {}", content);
218    ///
219    ///     Ok(())
220    /// }
221    /// ```
222    pub fn when<RN: FnMut(&E) -> bool>(
223        self,
224        retryable: RN,
225    ) -> RetryWithContext<B, T, E, Ctx, Fut, FutureFn, SF, RN, NF> {
226        RetryWithContext {
227            backoff: self.backoff,
228            retryable,
229            notify: self.notify,
230            future_fn: self.future_fn,
231            sleep_fn: self.sleep_fn,
232            state: self.state,
233        }
234    }
235
236    /// Set to notify for all retry attempts.
237    ///
238    /// When a retry happens, the input function will be invoked with the error and the sleep duration before pausing.
239    ///
240    /// If not specified, this operation does nothing.
241    ///
242    /// # Examples
243    ///
244    /// ```no_run
245    /// use core::time::Duration;
246    ///
247    /// use anyhow::Result;
248    /// use backon::ExponentialBuilder;
249    /// use backon::Retryable;
250    ///
251    /// async fn fetch() -> Result<String> {
252    ///     Ok(reqwest::get("https://www.rust-lang.org")
253    ///         .await?
254    ///         .text()
255    ///         .await?)
256    /// }
257    ///
258    /// #[tokio::main(flavor = "current_thread")]
259    /// async fn main() -> Result<()> {
260    ///     let content = fetch
261    ///         .retry(ExponentialBuilder::default())
262    ///         .notify(|err: &anyhow::Error, dur: Duration| {
263    ///             println!("retrying error {:?} with sleeping {:?}", err, dur);
264    ///         })
265    ///         .await?;
266    ///     println!("fetch succeeded: {}", content);
267    ///
268    ///     Ok(())
269    /// }
270    /// ```
271    pub fn notify<NN: FnMut(&E, Duration)>(
272        self,
273        notify: NN,
274    ) -> RetryWithContext<B, T, E, Ctx, Fut, FutureFn, SF, RF, NN> {
275        RetryWithContext {
276            backoff: self.backoff,
277            retryable: self.retryable,
278            notify,
279            future_fn: self.future_fn,
280            sleep_fn: self.sleep_fn,
281            state: self.state,
282        }
283    }
284}
285
/// Internal state machine of a retry.
enum State<T, E, Ctx, Fut, SleepFut>
where
    Fut: Future<Output = (Ctx, Result<T, E>)>,
    SleepFut: Future<Output = ()>,
{
    /// No attempt is running; holds the context for the next attempt, if one is available.
    Idle(Option<Ctx>),
    /// An attempt is in flight; its output hands the context back with the result.
    Polling(Fut),
    /// Waiting for the sleep future to finish while keeping the context for the next attempt.
    Sleeping((Option<Ctx>, SleepFut)),
}
292
293impl<B, T, E, Ctx, Fut, FutureFn, SF, RF, NF> Future
294    for RetryWithContext<B, T, E, Ctx, Fut, FutureFn, SF, RF, NF>
295where
296    B: Backoff,
297    Fut: Future<Output = (Ctx, Result<T, E>)>,
298    FutureFn: FnMut(Ctx) -> Fut,
299    SF: Sleeper,
300    RF: FnMut(&E) -> bool,
301    NF: FnMut(&E, Duration),
302{
303    type Output = (Ctx, Result<T, E>);
304
305    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
306        // Safety: This is safe because we don't move the `Retry` struct itself,
307        // only its internal state.
308        //
309        // We do the exactly same thing like `pin_project` but without depending on it directly.
310        let this = unsafe { self.get_unchecked_mut() };
311
312        loop {
313            match &mut this.state {
314                State::Idle(ctx) => {
315                    let ctx = ctx.take().expect("context must be valid");
316                    let fut = (this.future_fn)(ctx);
317                    this.state = State::Polling(fut);
318                    continue;
319                }
320                State::Polling(fut) => {
321                    // Safety: This is safe because we don't move the `Retry` struct and this fut,
322                    // only its internal state.
323                    //
324                    // We do the exactly same thing like `pin_project` but without depending on it directly.
325                    let mut fut = unsafe { Pin::new_unchecked(fut) };
326
327                    let (ctx, res) = ready!(fut.as_mut().poll(cx));
328                    match res {
329                        Ok(v) => return Poll::Ready((ctx, Ok(v))),
330                        Err(err) => {
331                            // If input error is not retryable, return error directly.
332                            if !(this.retryable)(&err) {
333                                return Poll::Ready((ctx, Err(err)));
334                            }
335                            match this.backoff.next() {
336                                None => return Poll::Ready((ctx, Err(err))),
337                                Some(dur) => {
338                                    (this.notify)(&err, dur);
339                                    this.state =
340                                        State::Sleeping((Some(ctx), this.sleep_fn.sleep(dur)));
341                                    continue;
342                                }
343                            }
344                        }
345                    }
346                }
347                State::Sleeping((ctx, sl)) => {
348                    // Safety: This is safe because we don't move the `Retry` struct and this fut,
349                    // only its internal state.
350                    //
351                    // We do the exactly same thing like `pin_project` but without depending on it directly.
352                    let mut sl = unsafe { Pin::new_unchecked(sl) };
353
354                    ready!(sl.as_mut().poll(cx));
355                    let ctx = ctx.take().expect("context must be valid");
356                    this.state = State::Idle(Some(ctx));
357                    continue;
358                }
359            }
360        }
361    }
362}
363
#[cfg(test)]
#[cfg(any(feature = "tokio-sleep", feature = "gloo-timers-sleep",))]
mod tests {
    extern crate alloc;

    use alloc::string::ToString;
    use core::time::Duration;

    use anyhow::anyhow;
    use anyhow::Result;
    use tokio::sync::Mutex;
    #[cfg(not(target_arch = "wasm32"))]
    use tokio::test;
    #[cfg(target_arch = "wasm32")]
    use wasm_bindgen_test::wasm_bindgen_test as test;

    use super::*;
    use crate::ExponentialBuilder;

    /// Context type whose only operation always fails with "not retryable".
    struct Test;

    impl Test {
        async fn hello(&mut self) -> Result<usize> {
            Err(anyhow!("not retryable"))
        }
    }

    /// An error rejected by the `when` predicate must be returned after a single attempt.
    #[test]
    async fn test_retry_with_not_retryable_error() {
        let attempts = Mutex::new(0);
        let builder = ExponentialBuilder::default().with_min_delay(Duration::from_millis(1));

        // Counts every invocation, then fails through `Test::hello`.
        let operation = |mut ctx: Test| async {
            let mut counter = attempts.lock().await;
            *counter += 1;

            let outcome = ctx.hello().await;
            (ctx, outcome)
        };

        let (_, result) = operation
            .retry(builder)
            .context(Test)
            // Retry only when the error message is `retryable`.
            .when(|e| e.to_string() == "retryable")
            .await;

        assert!(result.is_err());
        assert_eq!("not retryable", result.unwrap_err().to_string());
        // The operation always fails with "not retryable", so exactly one attempt is made.
        assert_eq!(*attempts.lock().await, 1);
    }
}
420}