1use std::future::Future;
57use std::pin::Pin;
58use std::task::{Context, Poll};
59use std::{cmp, thread};
60
61use futures::{Stream, StreamExt, ready};
62use pin_project::pin_project;
63use tokio::io::{AsyncRead, ReadBuf};
64use tokio::time::error::Elapsed;
65use tokio::time::{self, Duration, Instant, Sleep};
66
/// The state handed to the retried operation on each try.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct RetryState {
    /// Zero-based index of the current try.
    pub i: usize,
    /// How long the retry machinery will wait before the next try if this one
    /// fails with a retryable error, or `None` if there will be no further
    /// try.
    pub next_backoff: Option<Duration>,
}
78
/// The outcome of a single try of a retried operation.
#[derive(Debug)]
pub enum RetryResult<T, E> {
    /// The try succeeded; the retry operation returns this value.
    Ok(T),
    /// The try failed, but another try may be made if the retry policy still
    /// allows one.
    RetryableErr(E),
    /// The try failed and the retry operation returns this error immediately,
    /// without making further tries.
    FatalErr(E),
}
89
90impl<T, E> From<Result<T, E>> for RetryResult<T, E> {
91 fn from(res: Result<T, E>) -> RetryResult<T, E> {
92 match res {
93 Ok(t) => RetryResult::Ok(t),
94 Err(e) => RetryResult::RetryableErr(e),
95 }
96 }
97}
98
99#[pin_project]
103#[derive(Debug)]
104pub struct Retry {
105 initial_backoff: Duration,
106 factor: f64,
107 clamp_backoff: Duration,
108 max_duration: Duration,
109 max_tries: usize,
110}
111
112impl Retry {
113 pub fn initial_backoff(mut self, initial_backoff: Duration) -> Self {
118 self.initial_backoff = initial_backoff;
119 self
120 }
121
122 pub fn clamp_backoff(mut self, clamp_backoff: Duration) -> Self {
126 self.clamp_backoff = clamp_backoff;
127 self
128 }
129
130 pub fn factor(mut self, factor: f64) -> Self {
135 self.factor = factor;
136 self
137 }
138
139 pub fn max_tries(mut self, max_tries: usize) -> Self {
150 if max_tries == 0 {
151 panic!("max tries must be greater than zero");
152 }
153 self.max_tries = max_tries;
154 self
155 }
156
157 pub fn max_duration(mut self, duration: Duration) -> Self {
166 self.max_duration = duration;
167 self
168 }
169
170 pub fn retry<F, R, T, E>(self, mut f: F) -> Result<T, E>
197 where
198 F: FnMut(RetryState) -> R,
199 R: Into<RetryResult<T, E>>,
200 {
201 let start = Instant::now();
202 let mut i = 0;
203 let mut next_backoff = Some(cmp::min(self.initial_backoff, self.clamp_backoff));
204 loop {
205 let elapsed = start.elapsed();
206 if elapsed > self.max_duration || i + 1 >= self.max_tries {
207 next_backoff = None;
208 } else if elapsed + next_backoff.unwrap() > self.max_duration {
209 next_backoff = Some(self.max_duration - elapsed);
210 }
211 let state = RetryState { i, next_backoff };
212 match f(state).into() {
213 RetryResult::Ok(t) => return Ok(t),
214 RetryResult::FatalErr(e) => return Err(e),
215 RetryResult::RetryableErr(e) => match &mut next_backoff {
216 None => return Err(e),
217 Some(next_backoff) => {
218 thread::sleep(*next_backoff);
219 *next_backoff =
220 cmp::min(next_backoff.mul_f64(self.factor), self.clamp_backoff);
221 }
222 },
223 }
224 i += 1;
225 }
226 }
227
228 pub async fn retry_async<F, U, R, T, E>(self, mut f: F) -> Result<T, E>
230 where
231 F: FnMut(RetryState) -> U,
232 U: Future<Output = R>,
233 R: Into<RetryResult<T, E>>,
234 {
235 let stream = self.into_retry_stream();
236 tokio::pin!(stream);
237 let mut err = None;
238 while let Some(state) = stream.next().await {
239 match f(state).await.into() {
240 RetryResult::Ok(v) => return Ok(v),
241 RetryResult::FatalErr(e) => return Err(e),
242 RetryResult::RetryableErr(e) => err = Some(e),
243 }
244 }
245 Err(err.expect("retry produces at least one element"))
246 }
247
248 pub async fn retry_async_canceling<F, U, T, E>(self, mut f: F) -> Result<T, E>
265 where
266 F: FnMut(RetryState) -> U,
267 U: Future<Output = Result<T, E>>,
268 E: From<Elapsed> + std::fmt::Debug,
269 {
270 let start = Instant::now();
271 let max_duration = self.max_duration;
272 let stream = self.into_retry_stream();
273 tokio::pin!(stream);
274 let mut err = None;
275 while let Some(state) = stream.next().await {
276 let fut = time::timeout(max_duration.saturating_sub(start.elapsed()), f(state));
277 match fut.await {
278 Ok(Ok(t)) => return Ok(t),
279 Ok(Err(e)) => err = Some(e),
280 Err(e) => return Err(err.unwrap_or_else(|| e.into())),
281 }
282 }
283 Err(err.expect("retry produces at least one element"))
284 }
285
286 pub async fn retry_async_with_state<F, S, U, R, T, E>(
288 self,
289 mut user_state: S,
290 mut f: F,
291 ) -> (S, Result<T, E>)
292 where
293 F: FnMut(RetryState, S) -> U,
294 U: Future<Output = (S, R)>,
295 R: Into<RetryResult<T, E>>,
296 {
297 let stream = self.into_retry_stream();
298 tokio::pin!(stream);
299 let mut err = None;
300 while let Some(retry_state) = stream.next().await {
301 let (s, r) = f(retry_state, user_state).await;
302 match r.into() {
303 RetryResult::Ok(v) => return (s, Ok(v)),
304 RetryResult::FatalErr(e) => return (s, Err(e)),
305 RetryResult::RetryableErr(e) => {
306 err = Some(e);
307 user_state = s;
308 }
309 }
310 }
311 (
312 user_state,
313 Err(err.expect("retry produces at least one element")),
314 )
315 }
316
317 pub async fn retry_async_with_state_canceling<F, S, U, R, T, E>(
320 self,
321 mut user_state: S,
322 mut f: F,
323 ) -> Result<T, E>
324 where
325 F: FnMut(RetryState, S) -> U,
326 U: Future<Output = (S, R)>,
327 R: Into<RetryResult<T, E>>,
328 E: From<Elapsed> + std::fmt::Debug,
329 {
330 let start = Instant::now();
331 let max_duration = self.max_duration;
332 let stream = self.into_retry_stream();
333 tokio::pin!(stream);
334 let mut err = None;
335 while let Some(retry_state) = stream.next().await {
336 let fut = time::timeout(
337 max_duration.saturating_sub(start.elapsed()),
338 f(retry_state, user_state),
339 );
340 match fut.await {
341 Ok((s, r)) => match r.into() {
342 RetryResult::Ok(t) => return Ok(t),
343 RetryResult::FatalErr(e) => return Err(e),
344 RetryResult::RetryableErr(e) => {
345 err = Some(e);
346 user_state = s;
347 }
348 },
349 Err(e) => return Err(err.unwrap_or_else(|| e.into())),
350 }
351 }
352 Err(err.expect("retry produces at least one element"))
353 }
354
355 pub fn into_retry_stream(self) -> RetryStream {
357 RetryStream {
358 retry: self,
359 start: Instant::now(),
360 i: 0,
361 next_backoff: None,
362 sleep: time::sleep(Duration::default()),
363 }
364 }
365}
366
367impl Default for Retry {
368 fn default() -> Self {
371 Retry {
372 initial_backoff: Duration::from_millis(125),
373 factor: 2.0,
374 clamp_backoff: Duration::MAX,
375 max_tries: usize::MAX,
376 max_duration: Duration::MAX,
377 }
378 }
379}
380
381#[pin_project]
383#[derive(Debug)]
384pub struct RetryStream {
385 retry: Retry,
386 start: Instant,
387 i: usize,
388 next_backoff: Option<Duration>,
389 #[pin]
390 sleep: Sleep,
391}
392
393impl RetryStream {
394 fn reset(self: Pin<&mut Self>) {
395 let this = self.project();
396 *this.start = Instant::now();
397 *this.i = 0;
398 *this.next_backoff = None;
399 }
400}
401
402impl Stream for RetryStream {
403 type Item = RetryState;
404
405 fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
406 let mut this = self.project();
407 let retry = this.retry;
408
409 match this.next_backoff {
410 None if *this.i == 0 => {
411 *this.next_backoff = Some(cmp::min(retry.initial_backoff, retry.clamp_backoff));
412 }
413 None => return Poll::Ready(None),
414 Some(next_backoff) => {
415 ready!(this.sleep.as_mut().poll(cx));
416 *next_backoff = cmp::min(next_backoff.mul_f64(retry.factor), retry.clamp_backoff);
417 }
418 }
419
420 let elapsed = this.start.elapsed();
421 if elapsed > retry.max_duration || *this.i + 1 >= retry.max_tries {
422 *this.next_backoff = None;
423 } else if elapsed + this.next_backoff.unwrap() > retry.max_duration {
424 *this.next_backoff = Some(retry.max_duration - elapsed);
425 }
426
427 let state = RetryState {
428 i: *this.i,
429 next_backoff: *this.next_backoff,
430 };
431 if let Some(d) = *this.next_backoff {
432 this.sleep.reset(Instant::now() + d);
433 }
434 *this.i += 1;
435 Poll::Ready(Some(state))
436 }
437}
438
439#[pin_project]
442#[derive(Debug)]
443pub struct RetryReader<F, U, R> {
444 factory: F,
445 offset: usize,
446 error: Option<std::io::Error>,
447 #[pin]
448 retry: RetryStream,
449 #[pin]
450 state: RetryReaderState<U, R>,
451}
452
453#[pin_project(project = RetryReaderStateProj)]
454#[derive(Debug)]
455enum RetryReaderState<U, R> {
456 Waiting,
457 Creating(#[pin] U),
458 Reading(#[pin] R),
459}
460
461impl<F, U, R> RetryReader<F, U, R>
462where
463 F: FnMut(RetryState, usize) -> U,
464 U: Future<Output = Result<R, std::io::Error>>,
465 R: AsyncRead,
466{
467 pub fn new(factory: F) -> Self {
474 Self::with_retry(factory, Retry::default())
475 }
476
477 pub fn with_retry(factory: F, retry: Retry) -> Self {
480 Self {
481 factory,
482 offset: 0,
483 error: None,
484 retry: retry.into_retry_stream(),
485 state: RetryReaderState::Waiting,
486 }
487 }
488}
489
490impl<F, U, R> AsyncRead for RetryReader<F, U, R>
491where
492 F: FnMut(RetryState, usize) -> U,
493 U: Future<Output = Result<R, std::io::Error>>,
494 R: AsyncRead,
495{
496 fn poll_read(
497 mut self: Pin<&mut Self>,
498 cx: &mut Context<'_>,
499 buf: &mut ReadBuf<'_>,
500 ) -> Poll<Result<(), std::io::Error>> {
501 loop {
502 let mut this = self.as_mut().project();
503 use RetryReaderState::*;
504 match this.state.as_mut().project() {
505 RetryReaderStateProj::Waiting => match ready!(this.retry.as_mut().poll_next(cx)) {
506 None => {
507 return Poll::Ready(Err(this
508 .error
509 .take()
510 .expect("retry produces at least one element")));
511 }
512 Some(state) => {
513 this.state
514 .set(Creating((*this.factory)(state, *this.offset)));
515 }
516 },
517 RetryReaderStateProj::Creating(reader_fut) => match ready!(reader_fut.poll(cx)) {
518 Ok(reader) => {
519 this.state.set(Reading(reader));
520 }
521 Err(err) => {
522 *this.error = Some(err);
523 this.state.set(Waiting);
524 }
525 },
526 RetryReaderStateProj::Reading(reader) => {
527 let filled_end = buf.filled().len();
528 match ready!(reader.poll_read(cx, buf)) {
529 Ok(()) => {
530 if let Some(_) = this.error.take() {
531 this.retry.reset();
532 }
533 *this.offset += buf.filled().len() - filled_end;
534 return Poll::Ready(Ok(()));
535 }
536 Err(err) => {
537 *this.error = Some(err);
538 this.state.set(Waiting);
539 }
540 }
541 }
542 }
543 }
544 }
545}
546
#[cfg(test)]
mod tests {
    use anyhow::{anyhow, bail};

    use super::*;

    /// Shorthand for the `RetryState` expected at try `i`, with the next
    /// backoff given in whole milliseconds.
    fn expect_state(i: usize, backoff_ms: Option<u64>) -> RetryState {
        RetryState {
            i,
            next_backoff: backoff_ms.map(Duration::from_millis),
        }
    }

    /// The blocking retry stops at the first success and reports doubling
    /// backoffs along the way.
    #[crate::test]
    fn test_retry_success() {
        let mut seen = Vec::new();
        let outcome = Retry::default()
            .initial_backoff(Duration::from_millis(1))
            .retry(|attempt| {
                seen.push(attempt);
                if attempt.i == 2 {
                    Ok(())
                } else {
                    Err::<(), _>("injected")
                }
            });
        assert_eq!(outcome, Ok(()));
        assert_eq!(
            seen,
            [
                expect_state(0, Some(1)),
                expect_state(1, Some(2)),
                expect_state(2, Some(4)),
            ]
        );
    }

    /// Async counterpart of `test_retry_success`.
    #[crate::test(tokio::test)]
    #[cfg_attr(miri, ignore)]
    async fn test_retry_async_success() {
        let mut seen = Vec::new();
        let outcome = Retry::default()
            .initial_backoff(Duration::from_millis(1))
            .retry_async(|attempt| {
                seen.push(attempt);
                async move {
                    if attempt.i == 2 {
                        Ok(())
                    } else {
                        Err::<(), _>("injected")
                    }
                }
            })
            .await;
        assert_eq!(outcome, Ok(()));
        assert_eq!(
            seen,
            [
                expect_state(0, Some(1)),
                expect_state(1, Some(2)),
                expect_state(2, Some(4)),
            ]
        );
    }

    /// A fatal error ends the blocking retry immediately.
    #[crate::test(tokio::test)]
    async fn test_retry_fatal() {
        let mut seen = Vec::new();
        let outcome = Retry::default()
            .initial_backoff(Duration::from_millis(1))
            .retry(|attempt| {
                seen.push(attempt);
                if attempt.i == 0 {
                    RetryResult::RetryableErr::<(), _>("retry me")
                } else {
                    RetryResult::FatalErr("injected")
                }
            });
        assert_eq!(outcome, Err("injected"));
        assert_eq!(seen, [expect_state(0, Some(1)), expect_state(1, Some(2))]);
    }

    /// Async counterpart of `test_retry_fatal`.
    #[crate::test(tokio::test)]
    #[cfg_attr(miri, ignore)]
    async fn test_retry_async_fatal() {
        let mut seen = Vec::new();
        let outcome = Retry::default()
            .initial_backoff(Duration::from_millis(1))
            .retry_async(|attempt| {
                seen.push(attempt);
                async move {
                    if attempt.i == 0 {
                        RetryResult::RetryableErr::<(), _>("retry me")
                    } else {
                        RetryResult::FatalErr("injected")
                    }
                }
            })
            .await;
        assert_eq!(outcome, Err("injected"));
        assert_eq!(seen, [expect_state(0, Some(1)), expect_state(1, Some(2))]);
    }

    /// The blocking retry gives up after `max_tries` tries; the last try sees
    /// no next backoff.
    #[crate::test(tokio::test)]
    #[cfg_attr(miri, ignore)]
    async fn test_retry_fail_max_tries() {
        let mut seen = Vec::new();
        let outcome = Retry::default()
            .initial_backoff(Duration::from_millis(1))
            .max_tries(3)
            .retry(|attempt| {
                seen.push(attempt);
                Err::<(), _>("injected")
            });
        assert_eq!(outcome, Err("injected"));
        assert_eq!(
            seen,
            [
                expect_state(0, Some(1)),
                expect_state(1, Some(2)),
                expect_state(2, None),
            ]
        );
    }

    /// Async counterpart of `test_retry_fail_max_tries`.
    #[crate::test(tokio::test)]
    #[cfg_attr(miri, ignore)]
    async fn test_retry_async_fail_max_tries() {
        let mut seen = Vec::new();
        let outcome = Retry::default()
            .initial_backoff(Duration::from_millis(1))
            .max_tries(3)
            .retry_async(|attempt| {
                seen.push(attempt);
                async { Err::<(), _>("injected") }
            })
            .await;
        assert_eq!(outcome, Err("injected"));
        assert_eq!(
            seen,
            [
                expect_state(0, Some(1)),
                expect_state(1, Some(2)),
                expect_state(2, None),
            ]
        );
    }

    /// The blocking retry shortens the second backoff to fit the maximum
    /// duration and then stops.
    #[crate::test]
    #[cfg_attr(miri, ignore)]
    fn test_retry_fail_max_duration() {
        let mut seen = Vec::new();
        let outcome = Retry::default()
            .initial_backoff(Duration::from_millis(10))
            .max_duration(Duration::from_millis(20))
            .retry(|attempt| {
                seen.push(attempt);
                Err::<(), _>("injected")
            });
        assert_eq!(outcome, Err("injected"));

        // The first try announces the full initial backoff.
        assert_eq!(seen[0], expect_state(0, Some(10)));

        // The second backoff is cut down to whatever is left of the budget.
        assert_eq!(seen[1].i, 1);
        let second_backoff = seen[1].next_backoff.unwrap();
        assert!(
            second_backoff > Duration::from_millis(0) && second_backoff < Duration::from_millis(10)
        );

        // The third try is the last one.
        assert_eq!(seen[2], expect_state(2, None));
    }

    /// Async counterpart of `test_retry_fail_max_duration`.
    // NOTE(review): this test is `#[ignore]`d; the reason is not visible in
    // this file - confirm before re-enabling.
    #[crate::test(tokio::test)]
    #[cfg_attr(miri, ignore)]
    #[ignore]
    async fn test_retry_async_fail_max_duration() {
        let mut seen = Vec::new();
        let outcome = Retry::default()
            .initial_backoff(Duration::from_millis(5))
            .max_duration(Duration::from_millis(10))
            .retry_async(|attempt| {
                seen.push(attempt);
                async { Err::<(), _>("injected") }
            })
            .await;
        assert_eq!(outcome, Err("injected"));

        assert_eq!(seen[0], expect_state(0, Some(5)));

        // The second try either has no backoff left or a shortened one.
        assert_eq!(seen[1].i, 1);
        assert!(seen[1].next_backoff.map_or(true, |backoff| {
            backoff > Duration::from_millis(0) && backoff < Duration::from_millis(5)
        }));

        assert_eq!(seen[2], expect_state(2, None));
    }

    /// With a clamp equal to the initial backoff, the blocking retry never
    /// grows the backoff.
    #[crate::test]
    #[cfg_attr(miri, ignore)]
    fn test_retry_fail_clamp_backoff() {
        let mut seen = Vec::new();
        let outcome = Retry::default()
            .initial_backoff(Duration::from_millis(1))
            .clamp_backoff(Duration::from_millis(1))
            .max_tries(4)
            .retry(|attempt| {
                seen.push(attempt);
                Err::<(), _>("injected")
            });
        assert_eq!(outcome, Err("injected"));
        assert_eq!(
            seen,
            [
                expect_state(0, Some(1)),
                expect_state(1, Some(1)),
                expect_state(2, Some(1)),
                expect_state(3, None),
            ]
        );
    }

    /// Async counterpart of `test_retry_fail_clamp_backoff`.
    #[crate::test(tokio::test)]
    #[cfg_attr(miri, ignore)]
    async fn test_retry_async_fail_clamp_backoff() {
        let mut seen = Vec::new();
        let outcome = Retry::default()
            .initial_backoff(Duration::from_millis(1))
            .clamp_backoff(Duration::from_millis(1))
            .max_tries(4)
            .retry_async(|attempt| {
                seen.push(attempt);
                async { Err::<(), _>("injected") }
            })
            .await;
        assert_eq!(outcome, Err("injected"));
        assert_eq!(
            seen,
            [
                expect_state(0, Some(1)),
                expect_state(1, Some(1)),
                expect_state(2, Some(1)),
                expect_state(3, None),
            ]
        );
    }

    /// When every try fails quickly, the canceling retry returns the
    /// operation's own error.
    #[crate::test(tokio::test)]
    #[cfg_attr(miri, ignore)]
    async fn test_retry_async_canceling_uncanceled_failure() {
        let outcome = Retry::default()
            .max_duration(Duration::from_millis(100))
            .retry_async_canceling(|_| async move { Err::<(), _>(anyhow!("injected")) })
            .await;
        assert_eq!(outcome.unwrap_err().to_string(), "injected");
    }

    /// When a later try is canceled by the timeout, the error from the earlier
    /// failed try is returned.
    #[crate::test(tokio::test)]
    #[cfg_attr(miri, ignore)]
    async fn test_retry_async_canceling_canceled_failure() {
        let outcome = Retry::default()
            .max_duration(Duration::from_millis(100))
            .retry_async_canceling(|attempt| async move {
                if attempt.i == 0 {
                    bail!("injected")
                } else {
                    time::sleep(Duration::MAX).await;
                    Ok(())
                }
            })
            .await;
        assert_eq!(outcome.unwrap_err().to_string(), "injected");
    }

    /// When the very first try is canceled by the timeout, the timeout error
    /// itself is returned.
    #[crate::test(tokio::test)]
    #[cfg_attr(miri, ignore)]
    async fn test_retry_async_canceling_canceled_first_failure() {
        let outcome = Retry::default()
            .max_duration(Duration::from_millis(100))
            .retry_async_canceling(|_| async move {
                time::sleep(Duration::MAX).await;
                Ok::<_, anyhow::Error>(())
            })
            .await;
        assert_eq!(outcome.unwrap_err().to_string(), "deadline has elapsed");
    }

    /// A reader that fails after every byte is transparently recreated until
    /// all 256 bytes have been delivered.
    #[crate::test(tokio::test)]
    #[cfg_attr(miri, ignore)]
    async fn test_retry_reader() {
        use tokio::io::AsyncReadExt;

        /// Yields a single byte and then fails on the following read; yields
        /// nothing once created at offset 256 or beyond.
        struct FlakyReader {
            offset: usize,
            should_error: bool,
        }

        impl AsyncRead for FlakyReader {
            fn poll_read(
                mut self: Pin<&mut Self>,
                _: &mut Context<'_>,
                buf: &mut ReadBuf<'_>,
            ) -> Poll<Result<(), std::io::Error>> {
                if self.should_error {
                    return Poll::Ready(Err(std::io::ErrorKind::ConnectionReset.into()));
                }
                if self.offset < 256 {
                    buf.put_slice(b"A");
                    self.should_error = true;
                }
                Poll::Ready(Ok(()))
            }
        }

        let reader = RetryReader::new(|_state, offset| async move {
            Ok(FlakyReader {
                offset,
                should_error: false,
            })
        });
        tokio::pin!(reader);

        let mut received = Vec::new();
        reader.read_to_end(&mut received).await.unwrap();
        assert_eq!(received, vec![b'A'; 256]);
    }

    /// User state is threaded through the tries: ten tries are one too few for
    /// the counter to pass 10, eleven are enough.
    #[crate::test(tokio::test)]
    #[cfg_attr(miri, ignore)]
    async fn test_retry_async_state() {
        struct S {
            i: i64,
        }
        impl S {
            #[allow(clippy::unused_async)]
            async fn try_inc(&mut self) -> Result<i64, ()> {
                self.i += 1;
                if self.i > 10 { Ok(self.i) } else { Err(()) }
            }
        }

        let counter = S { i: 0 };
        let (_final_state, outcome) = Retry::default()
            .max_tries(10)
            .clamp_backoff(Duration::from_nanos(0))
            .retry_async_with_state(counter, |_, mut s| async {
                let res = s.try_inc().await;
                (s, res)
            })
            .await;
        assert_eq!(outcome, Err(()));

        let counter = S { i: 0 };
        let (_final_state, outcome) = Retry::default()
            .max_tries(11)
            .clamp_backoff(Duration::from_nanos(0))
            .retry_async_with_state(counter, |_, mut s| async {
                let res = s.try_inc().await;
                (s, res)
            })
            .await;
        assert_eq!(outcome, Ok(11));
    }
}