backon/retry_with_context.rs
1use core::future::Future;
2use core::pin::Pin;
3use core::task::ready;
4use core::task::Context;
5use core::task::Poll;
6use core::time::Duration;
7
8use crate::backoff::BackoffBuilder;
9use crate::sleep::MaybeSleeper;
10use crate::Backoff;
11use crate::DefaultSleeper;
12use crate::Sleeper;
13
14/// `RetryableWithContext` adds retry support for functions that produce futures with results
15/// and context.
16///
17/// This means all types implementing `FnMut(Ctx) -> impl Future<Output = (Ctx, Result<T, E>)>`
18/// can use `retry`.
19///
20/// Users must provide context to the function and can receive it back after the retry is completed.
21///
22/// # Example
23///
24/// Without context, we might encounter errors such as the following:
25///
26/// ```shell
27/// error: captured variable cannot escape `FnMut` closure body
28/// --> src/retry.rs:404:27
29/// |
30/// 400 | let mut test = Test;
31/// | -------- variable defined here
32/// ...
33/// 404 | let result = { || async { test.hello().await } }
34/// | - ^^^^^^^^----^^^^^^^^^^^^^^^^
35/// | | | |
36/// | | | variable captured here
37/// | | returns an `async` block that contains a reference to a captured variable, which then escapes the closure body
38/// | inferred to be a `FnMut` closure
39/// |
40/// = note: `FnMut` closures only have access to their captured variables while they are executing...
41/// = note: ...therefore, they cannot allow references to captured variables to escape
42/// ```
43///
44/// However, with context support, we can implement it this way:
45///
46/// ```no_run
47/// use anyhow::anyhow;
48/// use anyhow::Result;
49/// use backon::ExponentialBuilder;
50/// use backon::RetryableWithContext;
51///
52/// struct Test;
53///
54/// impl Test {
55/// async fn hello(&mut self) -> Result<usize> {
56/// Err(anyhow!("not retryable"))
57/// }
58/// }
59///
60/// #[tokio::main(flavor = "current_thread")]
61/// async fn main() -> Result<()> {
62/// let mut test = Test;
63///
64/// // (Test, Result<usize>)
65/// let (_, result) = {
66/// |mut v: Test| async {
67/// let res = v.hello().await;
68/// (v, res)
69/// }
70/// }
71/// .retry(ExponentialBuilder::default())
72/// .context(test)
73/// .await;
74///
75/// Ok(())
76/// }
77/// ```
78pub trait RetryableWithContext<
79 B: BackoffBuilder,
80 T,
81 E,
82 Ctx,
83 Fut: Future<Output = (Ctx, Result<T, E>)>,
84 FutureFn: FnMut(Ctx) -> Fut,
85>
86{
87 /// Generate a new retry
88 fn retry(self, builder: B) -> RetryWithContext<B::Backoff, T, E, Ctx, Fut, FutureFn>;
89}
90
91impl<B, T, E, Ctx, Fut, FutureFn> RetryableWithContext<B, T, E, Ctx, Fut, FutureFn> for FutureFn
92where
93 B: BackoffBuilder,
94 Fut: Future<Output = (Ctx, Result<T, E>)>,
95 FutureFn: FnMut(Ctx) -> Fut,
96{
97 fn retry(self, builder: B) -> RetryWithContext<B::Backoff, T, E, Ctx, Fut, FutureFn> {
98 RetryWithContext::new(self, builder.build())
99 }
100}
101
102/// Retry struct generated by [`RetryableWithContext`].
103pub struct RetryWithContext<
104 B: Backoff,
105 T,
106 E,
107 Ctx,
108 Fut: Future<Output = (Ctx, Result<T, E>)>,
109 FutureFn: FnMut(Ctx) -> Fut,
110 SF: MaybeSleeper = DefaultSleeper,
111 RF = fn(&E) -> bool,
112 NF = fn(&E, Duration),
113> {
114 backoff: B,
115 retryable: RF,
116 notify: NF,
117 future_fn: FutureFn,
118 sleep_fn: SF,
119
120 state: State<T, E, Ctx, Fut, SF::Sleep>,
121}
122
123impl<B, T, E, Ctx, Fut, FutureFn> RetryWithContext<B, T, E, Ctx, Fut, FutureFn>
124where
125 B: Backoff,
126 Fut: Future<Output = (Ctx, Result<T, E>)>,
127 FutureFn: FnMut(Ctx) -> Fut,
128{
129 /// Create a new retry.
130 fn new(future_fn: FutureFn, backoff: B) -> Self {
131 RetryWithContext {
132 backoff,
133 retryable: |_: &E| true,
134 notify: |_: &E, _: Duration| {},
135 future_fn,
136 sleep_fn: DefaultSleeper::default(),
137 state: State::Idle(None),
138 }
139 }
140}
141
142impl<B, T, E, Ctx, Fut, FutureFn, SF, RF, NF>
143 RetryWithContext<B, T, E, Ctx, Fut, FutureFn, SF, RF, NF>
144where
145 B: Backoff,
146 Fut: Future<Output = (Ctx, Result<T, E>)>,
147 FutureFn: FnMut(Ctx) -> Fut,
148 SF: MaybeSleeper,
149 RF: FnMut(&E) -> bool,
150 NF: FnMut(&E, Duration),
151{
152 /// Set the sleeper for retrying.
153 ///
154 /// The sleeper should implement the [`Sleeper`] trait. The simplest way is to use a closure that returns a `Future<Output=()>`.
155 ///
156 /// If not specified, we use the [`DefaultSleeper`].
157 pub fn sleep<SN: Sleeper>(
158 self,
159 sleep_fn: SN,
160 ) -> RetryWithContext<B, T, E, Ctx, Fut, FutureFn, SN, RF, NF> {
161 assert!(
162 matches!(self.state, State::Idle(None)),
163 "sleep must be set before context"
164 );
165
166 RetryWithContext {
167 backoff: self.backoff,
168 retryable: self.retryable,
169 notify: self.notify,
170 future_fn: self.future_fn,
171 sleep_fn,
172 state: State::Idle(None),
173 }
174 }
175
176 /// Set the context for retrying.
177 ///
178 /// Context is used to capture ownership manually to prevent lifetime issues.
179 pub fn context(
180 self,
181 context: Ctx,
182 ) -> RetryWithContext<B, T, E, Ctx, Fut, FutureFn, SF, RF, NF> {
183 RetryWithContext {
184 backoff: self.backoff,
185 retryable: self.retryable,
186 notify: self.notify,
187 future_fn: self.future_fn,
188 sleep_fn: self.sleep_fn,
189 state: State::Idle(Some(context)),
190 }
191 }
192
193 /// Set the conditions for retrying.
194 ///
195 /// If not specified, all errors are considered retryable.
196 ///
197 /// # Examples
198 ///
199 /// ```no_run
200 /// use anyhow::Result;
201 /// use backon::ExponentialBuilder;
202 /// use backon::Retryable;
203 ///
204 /// async fn fetch() -> Result<String> {
205 /// Ok(reqwest::get("https://www.rust-lang.org")
206 /// .await?
207 /// .text()
208 /// .await?)
209 /// }
210 ///
211 /// #[tokio::main(flavor = "current_thread")]
212 /// async fn main() -> Result<()> {
213 /// let content = fetch
214 /// .retry(ExponentialBuilder::default())
215 /// .when(|e| e.to_string() == "EOF")
216 /// .await?;
217 /// println!("fetch succeeded: {}", content);
218 ///
219 /// Ok(())
220 /// }
221 /// ```
222 pub fn when<RN: FnMut(&E) -> bool>(
223 self,
224 retryable: RN,
225 ) -> RetryWithContext<B, T, E, Ctx, Fut, FutureFn, SF, RN, NF> {
226 RetryWithContext {
227 backoff: self.backoff,
228 retryable,
229 notify: self.notify,
230 future_fn: self.future_fn,
231 sleep_fn: self.sleep_fn,
232 state: self.state,
233 }
234 }
235
236 /// Set to notify for all retry attempts.
237 ///
238 /// When a retry happens, the input function will be invoked with the error and the sleep duration before pausing.
239 ///
240 /// If not specified, this operation does nothing.
241 ///
242 /// # Examples
243 ///
244 /// ```no_run
245 /// use core::time::Duration;
246 ///
247 /// use anyhow::Result;
248 /// use backon::ExponentialBuilder;
249 /// use backon::Retryable;
250 ///
251 /// async fn fetch() -> Result<String> {
252 /// Ok(reqwest::get("https://www.rust-lang.org")
253 /// .await?
254 /// .text()
255 /// .await?)
256 /// }
257 ///
258 /// #[tokio::main(flavor = "current_thread")]
259 /// async fn main() -> Result<()> {
260 /// let content = fetch
261 /// .retry(ExponentialBuilder::default())
262 /// .notify(|err: &anyhow::Error, dur: Duration| {
263 /// println!("retrying error {:?} with sleeping {:?}", err, dur);
264 /// })
265 /// .await?;
266 /// println!("fetch succeeded: {}", content);
267 ///
268 /// Ok(())
269 /// }
270 /// ```
271 pub fn notify<NN: FnMut(&E, Duration)>(
272 self,
273 notify: NN,
274 ) -> RetryWithContext<B, T, E, Ctx, Fut, FutureFn, SF, RF, NN> {
275 RetryWithContext {
276 backoff: self.backoff,
277 retryable: self.retryable,
278 notify,
279 future_fn: self.future_fn,
280 sleep_fn: self.sleep_fn,
281 state: self.state,
282 }
283 }
284}
285
/// Internal state machine that drives a `RetryWithContext` future.
enum State<T, E, Ctx, Fut, SleepFut>
where
    Fut: Future<Output = (Ctx, Result<T, E>)>,
    SleepFut: Future<Output = ()>,
{
    /// No attempt is running; holds the context for the next one (`None` until a context is set).
    Idle(Option<Ctx>),
    /// An attempt is in flight.
    Polling(Fut),
    /// Waiting out a backoff delay; holds the context for the following attempt.
    Sleeping((Option<Ctx>, SleepFut)),
}
292
/// Drives the retry loop.
///
/// Each attempt receives the context by value. The context returned by that attempt is carried
/// into the next attempt, or handed back to the caller when retrying stops.
impl<B, T, E, Ctx, Fut, FutureFn, SF, RF, NF> Future
    for RetryWithContext<B, T, E, Ctx, Fut, FutureFn, SF, RF, NF>
where
    B: Backoff,
    Fut: Future<Output = (Ctx, Result<T, E>)>,
    FutureFn: FnMut(Ctx) -> Fut,
    SF: Sleeper,
    RF: FnMut(&E) -> bool,
    NF: FnMut(&E, Duration),
{
    /// The context returned by the last attempt, together with that attempt's result.
    type Output = (Ctx, Result<T, E>);

    /// Advances the state machine until one of these happens:
    ///
    /// - an attempt succeeds;
    /// - an attempt fails with a non-retryable error;
    /// - the backoff is exhausted;
    /// - the current attempt or sleep is still pending.
    ///
    /// # Panics
    ///
    /// Panics with "context must be valid" if no context was supplied via
    /// [`RetryWithContext::context`] before the first poll.
    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        // SAFETY: `this` is only used to access fields in place; the `RetryWithContext` value
        // itself is never moved out of the pin. The only structurally pinned data are the
        // futures stored in `state`. Below they are re-pinned in place, never moved. They are
        // replaced only by assigning to `this.state`, which drops them in place.
        //
        // This mirrors what `pin_project` would generate, without depending on it.
        let this = unsafe { self.get_unchecked_mut() };

        loop {
            match &mut this.state {
                State::Idle(ctx) => {
                    // Hand ownership of the context to the next attempt.
                    let ctx = ctx.take().expect("context must be valid");
                    let fut = (this.future_fn)(ctx);
                    this.state = State::Polling(fut);
                    continue;
                }
                State::Polling(fut) => {
                    // SAFETY: `fut` lives inside `this.state`, which is pinned together with
                    // `self`. It is never moved out; it is only dropped in place when
                    // `this.state` is overwritten below.
                    //
                    // This mirrors what `pin_project` would generate, without depending on it.
                    let mut fut = unsafe { Pin::new_unchecked(fut) };

                    let (ctx, res) = ready!(fut.as_mut().poll(cx));
                    match res {
                        Ok(v) => return Poll::Ready((ctx, Ok(v))),
                        Err(err) => {
                            // If input error is not retryable, return error directly.
                            if !(this.retryable)(&err) {
                                return Poll::Ready((ctx, Err(err)));
                            }
                            match this.backoff.next() {
                                // Backoff exhausted: give up and return the last error.
                                None => return Poll::Ready((ctx, Err(err))),
                                Some(dur) => {
                                    // Notify first, then park the context while sleeping.
                                    (this.notify)(&err, dur);
                                    this.state =
                                        State::Sleeping((Some(ctx), this.sleep_fn.sleep(dur)));
                                    continue;
                                }
                            }
                        }
                    }
                }
                State::Sleeping((ctx, sl)) => {
                    // SAFETY: `sl` lives inside `this.state`, which is pinned together with
                    // `self`. It is never moved out; it is only dropped in place when
                    // `this.state` is overwritten below.
                    //
                    // This mirrors what `pin_project` would generate, without depending on it.
                    let mut sl = unsafe { Pin::new_unchecked(sl) };

                    ready!(sl.as_mut().poll(cx));
                    // Sleep finished: return the context to `Idle` so the next attempt starts.
                    let ctx = ctx.take().expect("context must be valid");
                    this.state = State::Idle(Some(ctx));
                    continue;
                }
            }
        }
    }
}
363
#[cfg(test)]
#[cfg(any(feature = "tokio-sleep", feature = "gloo-timers-sleep",))]
mod tests {
    extern crate alloc;

    use alloc::string::ToString;
    use core::time::Duration;

    use anyhow::anyhow;
    use anyhow::Result;
    use tokio::sync::Mutex;
    #[cfg(not(target_arch = "wasm32"))]
    use tokio::test;
    #[cfg(target_arch = "wasm32")]
    use wasm_bindgen_test::wasm_bindgen_test as test;

    use super::*;
    use crate::ExponentialBuilder;

    struct Test;

    impl Test {
        async fn hello(&mut self) -> Result<usize> {
            Err(anyhow!("not retryable"))
        }
    }

    #[test]
    async fn test_retry_with_not_retryable_error() {
        let error_times = Mutex::new(0);

        let test = Test;

        let backoff = ExponentialBuilder::default().with_min_delay(Duration::from_millis(1));

        let (_, result) = {
            |mut v: Test| async {
                let mut x = error_times.lock().await;
                *x += 1;

                let res = v.hello().await;
                (v, res)
            }
        }
        .retry(backoff)
        .context(test)
        // Only retry if the error message is `retryable`.
        .when(|e| e.to_string() == "retryable")
        .await;

        assert!(result.is_err());
        assert_eq!("not retryable", result.unwrap_err().to_string());
        // `hello` always returns the "not retryable" error, so the function
        // should be executed only once.
        assert_eq!(*error_times.lock().await, 1);
    }

    #[test]
    async fn test_retry_with_context_is_threaded_through_attempts() {
        let backoff = ExponentialBuilder::default()
            .with_min_delay(Duration::from_millis(1))
            .with_max_times(3);

        // The context counts attempts: each attempt increments it and hands it back.
        let (attempts, result) = {
            |attempts: usize| async move {
                let res: Result<usize> = Err(anyhow!("retryable"));
                (attempts + 1, res)
            }
        }
        .retry(backoff)
        .context(0)
        .await;

        assert!(result.is_err());
        // One initial attempt plus three retries. Reaching 4 proves the context
        // returned by each attempt was passed into the next one and then back to us.
        assert_eq!(attempts, 4);
    }
}