1use std::future::Future;
57use std::pin::Pin;
58use std::task::{Context, Poll};
59use std::{cmp, thread};
60
61use futures::{Stream, StreamExt, ready};
62use pin_project::pin_project;
63use tokio::io::{AsyncRead, ReadBuf};
64use tokio::time::error::Elapsed;
65use tokio::time::{self, Duration, Instant, Sleep};
66
/// The state of a retry operation, handed to the retried closure on each try.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct RetryState {
    /// The retry counter; it is zero on the first try.
    pub i: usize,
    /// How long the retry operation will wait before the next try if this
    /// try fails.
    ///
    /// `None` means this is the last try: a failure will not be retried.
    pub next_backoff: Option<Duration>,
}
78
/// The outcome of a single try of a retried operation.
///
/// It distinguishes errors that are worth retrying from errors that must end
/// the retry operation immediately.
#[derive(Debug)]
pub enum RetryResult<T, E> {
    /// The try succeeded with the contained value; no further tries happen.
    Ok(T),
    /// The try failed, but the operation may be tried again.
    RetryableErr(E),
    /// The try failed in a way that makes further tries pointless; the
    /// contained error is returned to the caller right away.
    FatalErr(E),
}
89
90impl<T, E> From<Result<T, E>> for RetryResult<T, E> {
91 fn from(res: Result<T, E>) -> RetryResult<T, E> {
92 match res {
93 Ok(t) => RetryResult::Ok(t),
94 Err(e) => RetryResult::RetryableErr(e),
95 }
96 }
97}
98
99#[pin_project]
103#[derive(Debug)]
104pub struct Retry {
105 initial_backoff: Duration,
106 factor: f64,
107 clamp_backoff: Duration,
108 max_duration: Duration,
109 max_tries: usize,
110}
111
112impl Retry {
113 pub fn initial_backoff(mut self, initial_backoff: Duration) -> Self {
118 self.initial_backoff = initial_backoff;
119 self
120 }
121
122 pub fn clamp_backoff(mut self, clamp_backoff: Duration) -> Self {
126 self.clamp_backoff = clamp_backoff;
127 self
128 }
129
130 pub fn factor(mut self, factor: f64) -> Self {
135 self.factor = factor;
136 self
137 }
138
139 pub fn max_tries(mut self, max_tries: usize) -> Self {
150 if max_tries == 0 {
151 panic!("max tries must be greater than zero");
152 }
153 self.max_tries = max_tries;
154 self
155 }
156
157 pub fn max_duration(mut self, duration: Duration) -> Self {
166 self.max_duration = duration;
167 self
168 }
169
170 pub fn retry<F, R, T, E>(self, mut f: F) -> Result<T, E>
197 where
198 F: FnMut(RetryState) -> R,
199 R: Into<RetryResult<T, E>>,
200 {
201 let start = Instant::now();
202 let mut i = 0;
203 let mut next_backoff = Some(cmp::min(self.initial_backoff, self.clamp_backoff));
204 loop {
205 let elapsed = start.elapsed();
206 if elapsed > self.max_duration || i + 1 >= self.max_tries {
207 next_backoff = None;
208 } else if elapsed + next_backoff.unwrap() > self.max_duration {
209 next_backoff = Some(self.max_duration - elapsed);
210 }
211 let state = RetryState { i, next_backoff };
212 match f(state).into() {
213 RetryResult::Ok(t) => return Ok(t),
214 RetryResult::FatalErr(e) => return Err(e),
215 RetryResult::RetryableErr(e) => match &mut next_backoff {
216 None => return Err(e),
217 Some(next_backoff) => {
218 thread::sleep(*next_backoff);
219 *next_backoff =
220 cmp::min(next_backoff.mul_f64(self.factor), self.clamp_backoff);
221 }
222 },
223 }
224 i += 1;
225 }
226 }
227
228 pub async fn retry_async<F, U, R, T, E>(self, mut f: F) -> Result<T, E>
230 where
231 F: FnMut(RetryState) -> U,
232 U: Future<Output = R>,
233 R: Into<RetryResult<T, E>>,
234 {
235 let stream = self.into_retry_stream();
236 tokio::pin!(stream);
237 let mut err = None;
238 while let Some(state) = stream.next().await {
239 match f(state).await.into() {
240 RetryResult::Ok(v) => return Ok(v),
241 RetryResult::FatalErr(e) => return Err(e),
242 RetryResult::RetryableErr(e) => err = Some(e),
243 }
244 }
245 Err(err.expect("retry produces at least one element"))
246 }
247
248 pub async fn retry_async_canceling<F, U, T, E>(self, mut f: F) -> Result<T, E>
265 where
266 F: FnMut(RetryState) -> U,
267 U: Future<Output = Result<T, E>>,
268 E: From<Elapsed> + std::fmt::Debug,
269 {
270 let start = Instant::now();
271 let max_duration = self.max_duration;
272 let stream = self.into_retry_stream();
273 tokio::pin!(stream);
274 let mut err = None;
275 while let Some(state) = stream.next().await {
276 let fut = time::timeout(max_duration.saturating_sub(start.elapsed()), f(state));
277 match fut.await {
278 Ok(Ok(t)) => return Ok(t),
279 Ok(Err(e)) => err = Some(e),
280 Err(e) => return Err(err.unwrap_or_else(|| e.into())),
281 }
282 }
283 Err(err.expect("retry produces at least one element"))
284 }
285
286 pub async fn retry_async_with_state<F, S, U, R, T, E>(
288 self,
289 mut user_state: S,
290 mut f: F,
291 ) -> (S, Result<T, E>)
292 where
293 F: FnMut(RetryState, S) -> U,
294 U: Future<Output = (S, R)>,
295 R: Into<RetryResult<T, E>>,
296 {
297 let stream = self.into_retry_stream();
298 tokio::pin!(stream);
299 let mut err = None;
300 while let Some(retry_state) = stream.next().await {
301 let (s, r) = f(retry_state, user_state).await;
302 match r.into() {
303 RetryResult::Ok(v) => return (s, Ok(v)),
304 RetryResult::FatalErr(e) => return (s, Err(e)),
305 RetryResult::RetryableErr(e) => {
306 err = Some(e);
307 user_state = s;
308 }
309 }
310 }
311 (
312 user_state,
313 Err(err.expect("retry produces at least one element")),
314 )
315 }
316
317 pub fn into_retry_stream(self) -> RetryStream {
319 RetryStream {
320 retry: self,
321 start: Instant::now(),
322 i: 0,
323 next_backoff: None,
324 sleep: time::sleep(Duration::default()),
325 }
326 }
327}
328
329impl Default for Retry {
330 fn default() -> Self {
333 Retry {
334 initial_backoff: Duration::from_millis(125),
335 factor: 2.0,
336 clamp_backoff: Duration::MAX,
337 max_tries: usize::MAX,
338 max_duration: Duration::MAX,
339 }
340 }
341}
342
343#[pin_project]
345#[derive(Debug)]
346pub struct RetryStream {
347 retry: Retry,
348 start: Instant,
349 i: usize,
350 next_backoff: Option<Duration>,
351 #[pin]
352 sleep: Sleep,
353}
354
355impl RetryStream {
356 fn reset(self: Pin<&mut Self>) {
357 let this = self.project();
358 *this.start = Instant::now();
359 *this.i = 0;
360 *this.next_backoff = None;
361 }
362}
363
364impl Stream for RetryStream {
365 type Item = RetryState;
366
367 fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
368 let mut this = self.project();
369 let retry = this.retry;
370
371 match this.next_backoff {
372 None if *this.i == 0 => {
373 *this.next_backoff = Some(cmp::min(retry.initial_backoff, retry.clamp_backoff));
374 }
375 None => return Poll::Ready(None),
376 Some(next_backoff) => {
377 ready!(this.sleep.as_mut().poll(cx));
378 *next_backoff = cmp::min(next_backoff.mul_f64(retry.factor), retry.clamp_backoff);
379 }
380 }
381
382 let elapsed = this.start.elapsed();
383 if elapsed > retry.max_duration || *this.i + 1 >= retry.max_tries {
384 *this.next_backoff = None;
385 } else if elapsed + this.next_backoff.unwrap() > retry.max_duration {
386 *this.next_backoff = Some(retry.max_duration - elapsed);
387 }
388
389 let state = RetryState {
390 i: *this.i,
391 next_backoff: *this.next_backoff,
392 };
393 if let Some(d) = *this.next_backoff {
394 this.sleep.reset(Instant::now() + d);
395 }
396 *this.i += 1;
397 Poll::Ready(Some(state))
398 }
399}
400
401#[pin_project]
404#[derive(Debug)]
405pub struct RetryReader<F, U, R> {
406 factory: F,
407 offset: usize,
408 error: Option<std::io::Error>,
409 #[pin]
410 retry: RetryStream,
411 #[pin]
412 state: RetryReaderState<U, R>,
413}
414
415#[pin_project(project = RetryReaderStateProj)]
416#[derive(Debug)]
417enum RetryReaderState<U, R> {
418 Waiting,
419 Creating(#[pin] U),
420 Reading(#[pin] R),
421}
422
423impl<F, U, R> RetryReader<F, U, R>
424where
425 F: FnMut(RetryState, usize) -> U,
426 U: Future<Output = Result<R, std::io::Error>>,
427 R: AsyncRead,
428{
429 pub fn new(factory: F) -> Self {
436 Self::with_retry(factory, Retry::default())
437 }
438
439 pub fn with_retry(factory: F, retry: Retry) -> Self {
442 Self {
443 factory,
444 offset: 0,
445 error: None,
446 retry: retry.into_retry_stream(),
447 state: RetryReaderState::Waiting,
448 }
449 }
450}
451
452impl<F, U, R> AsyncRead for RetryReader<F, U, R>
453where
454 F: FnMut(RetryState, usize) -> U,
455 U: Future<Output = Result<R, std::io::Error>>,
456 R: AsyncRead,
457{
458 fn poll_read(
459 mut self: Pin<&mut Self>,
460 cx: &mut Context<'_>,
461 buf: &mut ReadBuf<'_>,
462 ) -> Poll<Result<(), std::io::Error>> {
463 loop {
464 let mut this = self.as_mut().project();
465 use RetryReaderState::*;
466 match this.state.as_mut().project() {
467 RetryReaderStateProj::Waiting => match ready!(this.retry.as_mut().poll_next(cx)) {
468 None => {
469 return Poll::Ready(Err(this
470 .error
471 .take()
472 .expect("retry produces at least one element")));
473 }
474 Some(state) => {
475 this.state
476 .set(Creating((*this.factory)(state, *this.offset)));
477 }
478 },
479 RetryReaderStateProj::Creating(reader_fut) => match ready!(reader_fut.poll(cx)) {
480 Ok(reader) => {
481 this.state.set(Reading(reader));
482 }
483 Err(err) => {
484 *this.error = Some(err);
485 this.state.set(Waiting);
486 }
487 },
488 RetryReaderStateProj::Reading(reader) => {
489 let filled_end = buf.filled().len();
490 match ready!(reader.poll_read(cx, buf)) {
491 Ok(()) => {
492 if let Some(_) = this.error.take() {
493 this.retry.reset();
494 }
495 *this.offset += buf.filled().len() - filled_end;
496 return Poll::Ready(Ok(()));
497 }
498 Err(err) => {
499 *this.error = Some(err);
500 this.state.set(Waiting);
501 }
502 }
503 }
504 }
505 }
506 }
507}
508
#[cfg(test)]
mod tests {
    use anyhow::{anyhow, bail};

    use super::*;

    /// Shorthand for the `RetryState` of try `i` whose upcoming backoff is
    /// `backoff_ms` milliseconds (`None` for the final try).
    fn attempt(i: usize, backoff_ms: Option<u64>) -> RetryState {
        RetryState {
            i,
            next_backoff: backoff_ms.map(Duration::from_millis),
        }
    }

    /// The blocking API stops as soon as the operation succeeds.
    #[crate::test]
    fn test_retry_success() {
        let mut seen = vec![];
        let outcome = Retry::default()
            .initial_backoff(Duration::from_millis(1))
            .retry(|state| {
                seen.push(state);
                match state.i {
                    2 => Ok(()),
                    _ => Err("injected"),
                }
            });
        assert_eq!(outcome, Ok(()));
        assert_eq!(
            seen,
            [
                attempt(0, Some(1)),
                attempt(1, Some(2)),
                attempt(2, Some(4)),
            ]
        );
    }

    /// The async API stops as soon as the operation succeeds.
    #[crate::test(tokio::test)]
    #[cfg_attr(miri, ignore)]
    async fn test_retry_async_success() {
        let mut seen = vec![];
        let outcome = Retry::default()
            .initial_backoff(Duration::from_millis(1))
            .retry_async(|state| {
                seen.push(state);
                async move {
                    match state.i {
                        2 => Ok(()),
                        _ => Err("injected"),
                    }
                }
            })
            .await;
        assert_eq!(outcome, Ok(()));
        assert_eq!(
            seen,
            [
                attempt(0, Some(1)),
                attempt(1, Some(2)),
                attempt(2, Some(4)),
            ]
        );
    }

    /// A fatal error ends the blocking retry loop immediately.
    #[crate::test(tokio::test)]
    async fn test_retry_fatal() {
        let mut seen = vec![];
        let outcome = Retry::default()
            .initial_backoff(Duration::from_millis(1))
            .retry(|state| {
                seen.push(state);
                match state.i {
                    0 => RetryResult::RetryableErr::<(), _>("retry me"),
                    _ => RetryResult::FatalErr("injected"),
                }
            });
        assert_eq!(outcome, Err("injected"));
        assert_eq!(seen, [attempt(0, Some(1)), attempt(1, Some(2))]);
    }

    /// A fatal error ends the async retry loop immediately.
    #[crate::test(tokio::test)]
    #[cfg_attr(miri, ignore)]
    async fn test_retry_async_fatal() {
        let mut seen = vec![];
        let outcome = Retry::default()
            .initial_backoff(Duration::from_millis(1))
            .retry_async(|state| {
                seen.push(state);
                async move {
                    match state.i {
                        0 => RetryResult::RetryableErr::<(), _>("retry me"),
                        _ => RetryResult::FatalErr("injected"),
                    }
                }
            })
            .await;
        assert_eq!(outcome, Err("injected"));
        assert_eq!(seen, [attempt(0, Some(1)), attempt(1, Some(2))]);
    }

    /// The blocking API gives up after `max_tries` tries.
    #[crate::test(tokio::test)]
    #[cfg_attr(miri, ignore)]
    async fn test_retry_fail_max_tries() {
        let mut seen = vec![];
        let outcome = Retry::default()
            .initial_backoff(Duration::from_millis(1))
            .max_tries(3)
            .retry(|state| {
                seen.push(state);
                Err::<(), _>("injected")
            });
        assert_eq!(outcome, Err("injected"));
        assert_eq!(
            seen,
            [
                attempt(0, Some(1)),
                attempt(1, Some(2)),
                attempt(2, None),
            ]
        );
    }

    /// The async API gives up after `max_tries` tries.
    #[crate::test(tokio::test)]
    #[cfg_attr(miri, ignore)]
    async fn test_retry_async_fail_max_tries() {
        let mut seen = vec![];
        let outcome = Retry::default()
            .initial_backoff(Duration::from_millis(1))
            .max_tries(3)
            .retry_async(|state| {
                seen.push(state);
                async { Err::<(), _>("injected") }
            })
            .await;
        assert_eq!(outcome, Err("injected"));
        assert_eq!(
            seen,
            [
                attempt(0, Some(1)),
                attempt(1, Some(2)),
                attempt(2, None),
            ]
        );
    }

    /// The blocking API truncates the backoff at `max_duration` and then
    /// gives up.
    #[crate::test]
    #[cfg_attr(miri, ignore)]
    fn test_retry_fail_max_duration() {
        let mut seen = vec![];
        let outcome = Retry::default()
            .initial_backoff(Duration::from_millis(10))
            .max_duration(Duration::from_millis(20))
            .retry(|state| {
                seen.push(state);
                Err::<(), _>("injected")
            });
        assert_eq!(outcome, Err("injected"));

        // The first try announces the full initial backoff.
        assert_eq!(seen[0], attempt(0, Some(10)));

        // The second try announces a backoff cut short by the deadline.
        assert_eq!(seen[1].i, 1);
        let truncated = seen[1].next_backoff.unwrap();
        assert!(truncated > Duration::from_millis(0) && truncated < Duration::from_millis(10));

        // The third try is the last one.
        assert_eq!(seen[2], attempt(2, None));
    }

    /// The async API truncates the backoff at `max_duration` and then gives
    /// up.
    #[crate::test(tokio::test)]
    #[cfg_attr(miri, ignore)]
    #[ignore]
    async fn test_retry_async_fail_max_duration() {
        let mut seen = vec![];
        let outcome = Retry::default()
            .initial_backoff(Duration::from_millis(5))
            .max_duration(Duration::from_millis(10))
            .retry_async(|state| {
                seen.push(state);
                async { Err::<(), _>("injected") }
            })
            .await;
        assert_eq!(outcome, Err("injected"));

        // The first try announces the full initial backoff.
        assert_eq!(seen[0], attempt(0, Some(5)));

        // The second try either is already the last one or announces a
        // backoff cut short by the deadline.
        assert_eq!(seen[1].i, 1);
        assert!(match seen[1].next_backoff {
            Some(truncated) =>
                truncated > Duration::from_millis(0) && truncated < Duration::from_millis(5),
            None => true,
        });

        // The third try is the last one.
        assert_eq!(seen[2], attempt(2, None));
    }

    /// The blocking API never lets the backoff exceed `clamp_backoff`.
    #[crate::test]
    #[cfg_attr(miri, ignore)]
    fn test_retry_fail_clamp_backoff() {
        let mut seen = vec![];
        let outcome = Retry::default()
            .initial_backoff(Duration::from_millis(1))
            .clamp_backoff(Duration::from_millis(1))
            .max_tries(4)
            .retry(|state| {
                seen.push(state);
                Err::<(), _>("injected")
            });
        assert_eq!(outcome, Err("injected"));
        assert_eq!(
            seen,
            [
                attempt(0, Some(1)),
                attempt(1, Some(1)),
                attempt(2, Some(1)),
                attempt(3, None),
            ]
        );
    }

    /// The async API never lets the backoff exceed `clamp_backoff`.
    #[crate::test(tokio::test)]
    #[cfg_attr(miri, ignore)]
    async fn test_retry_async_fail_clamp_backoff() {
        let mut seen = vec![];
        let outcome = Retry::default()
            .initial_backoff(Duration::from_millis(1))
            .clamp_backoff(Duration::from_millis(1))
            .max_tries(4)
            .retry_async(|state| {
                seen.push(state);
                async { Err::<(), _>("injected") }
            })
            .await;
        assert_eq!(outcome, Err("injected"));
        assert_eq!(
            seen,
            [
                attempt(0, Some(1)),
                attempt(1, Some(1)),
                attempt(2, Some(1)),
                attempt(3, None),
            ]
        );
    }

    /// When no try is canceled, the operation's own error is returned.
    #[crate::test(tokio::test)]
    #[cfg_attr(miri, ignore)]
    async fn test_retry_async_canceling_uncanceled_failure() {
        let outcome = Retry::default()
            .max_duration(Duration::from_millis(100))
            .retry_async_canceling(|_| async move { Err::<(), _>(anyhow!("injected")) })
            .await;
        assert_eq!(outcome.unwrap_err().to_string(), "injected");
    }

    /// When a later try is canceled, the error of the earlier failed try is
    /// returned.
    #[crate::test(tokio::test)]
    #[cfg_attr(miri, ignore)]
    async fn test_retry_async_canceling_canceled_failure() {
        let outcome = Retry::default()
            .max_duration(Duration::from_millis(100))
            .retry_async_canceling(|state| async move {
                if state.i == 0 {
                    bail!("injected")
                } else {
                    time::sleep(Duration::MAX).await;
                    Ok(())
                }
            })
            .await;
        assert_eq!(outcome.unwrap_err().to_string(), "injected");
    }

    /// When the very first try is canceled, the timeout error is returned.
    #[crate::test(tokio::test)]
    #[cfg_attr(miri, ignore)]
    async fn test_retry_async_canceling_canceled_first_failure() {
        let outcome = Retry::default()
            .max_duration(Duration::from_millis(100))
            .retry_async_canceling(|_| async move {
                time::sleep(Duration::MAX).await;
                Ok::<_, anyhow::Error>(())
            })
            .await;
        assert_eq!(outcome.unwrap_err().to_string(), "deadline has elapsed");
    }

    /// `RetryReader` recreates a reader that fails after every byte and still
    /// delivers the complete payload.
    #[crate::test(tokio::test)]
    #[cfg_attr(miri, ignore)]
    async fn test_retry_reader() {
        use tokio::io::AsyncReadExt;

        /// Yields a single byte, then fails every subsequent read. Readers
        /// created at offset 256 or later report end-of-file instead.
        struct FlakyReader {
            offset: usize,
            should_error: bool,
        }

        impl AsyncRead for FlakyReader {
            fn poll_read(
                mut self: Pin<&mut Self>,
                _: &mut Context<'_>,
                buf: &mut ReadBuf<'_>,
            ) -> Poll<Result<(), std::io::Error>> {
                if self.should_error {
                    return Poll::Ready(Err(std::io::ErrorKind::ConnectionReset.into()));
                }
                if self.offset < 256 {
                    buf.put_slice(b"A");
                    self.should_error = true;
                }
                Poll::Ready(Ok(()))
            }
        }

        let reader = RetryReader::new(|_state, offset| async move {
            Ok(FlakyReader {
                offset,
                should_error: false,
            })
        });
        tokio::pin!(reader);

        let mut received = Vec::new();
        reader.read_to_end(&mut received).await.unwrap();
        assert_eq!(received, vec![b'A'; 256]);
    }

    /// `retry_async_with_state` threads the user state through every try.
    #[crate::test(tokio::test)]
    #[cfg_attr(miri, ignore)]
    async fn test_retry_async_state() {
        struct S {
            i: i64,
        }
        impl S {
            #[allow(clippy::unused_async)]
            async fn try_inc(&mut self) -> Result<i64, ()> {
                self.i += 1;
                if self.i > 10 { Ok(self.i) } else { Err(()) }
            }
        }

        // `try_inc` first succeeds on its eleventh call, so ten tries are not
        // enough while eleven are.
        for (max_tries, expected) in [(10, Err(())), (11, Ok(11))] {
            let (_new_s, res) = Retry::default()
                .max_tries(max_tries)
                .clamp_backoff(Duration::from_nanos(0))
                .retry_async_with_state(S { i: 0 }, |_, mut s| async {
                    let res = s.try_inc().await;
                    (s, res)
                })
                .await;
            assert_eq!(res, expected);
        }
    }
}