tower/retry/
future.rs

1//! Future types
2
3use super::{Policy, Retry};
4use futures_core::ready;
5use pin_project_lite::pin_project;
6use std::future::Future;
7use std::pin::Pin;
8use std::task::{Context, Poll};
9use tower_service::Service;
10
pin_project! {
    /// The [`Future`] returned by a [`Retry`] service.
    ///
    /// Polling it drives the inner service's response future. Each time that
    /// future completes, and a copy of the request is still held, the policy
    /// is consulted: it either lets the result through or yields a future
    /// whose completion triggers another call with the saved request.
    #[derive(Debug)]
    pub struct ResponseFuture<P, S, Request>
    where
        P: Policy<Request, S::Response, S::Error>,
        S: Service<Request>,
    {
        // Copy of the request held back for a possible retry. `None` means no
        // copy is available, in which case the result of the current attempt
        // is returned without consulting the policy.
        request: Option<Request>,
        // The retry middleware itself: supplies the policy that judges each
        // result and the inner service used to send the request again.
        #[pin]
        retry: Retry<P, S>,
        // Which step of the call -> check -> retry cycle is being polled.
        #[pin]
        state: State<S::Future, P::Future>,
    }
}
26
pin_project! {
    // Internal state machine of `ResponseFuture`.
    //
    // `F` is the inner service's response future and `P` is the future
    // produced by the policy when it decides to retry.
    #[project = StateProj]
    #[derive(Debug)]
    enum State<F, P> {
        // Polling the future from [`Service::call`]. When it completes, its
        // result is either returned or handed to the policy for a verdict.
        Called {
            #[pin]
            future: F
        },
        // Polling the future from [`Policy::retry`]. Its output replaces the
        // policy stored in the `Retry` service.
        Checking {
            #[pin]
            checking: P
        },
        // Polling [`Service::poll_ready`] after [`Checking`] was OK. Once the
        // service is ready the saved request is sent again and the state goes
        // back to `Called`.
        Retrying,
    }
}
45
46impl<P, S, Request> ResponseFuture<P, S, Request>
47where
48    P: Policy<Request, S::Response, S::Error>,
49    S: Service<Request>,
50{
51    pub(crate) fn new(
52        request: Option<Request>,
53        retry: Retry<P, S>,
54        future: S::Future,
55    ) -> ResponseFuture<P, S, Request> {
56        ResponseFuture {
57            request,
58            retry,
59            state: State::Called { future },
60        }
61    }
62}
63
64impl<P, S, Request> Future for ResponseFuture<P, S, Request>
65where
66    P: Policy<Request, S::Response, S::Error> + Clone,
67    S: Service<Request> + Clone,
68{
69    type Output = Result<S::Response, S::Error>;
70
71    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
72        let mut this = self.project();
73
74        loop {
75            match this.state.as_mut().project() {
76                StateProj::Called { future } => {
77                    let result = ready!(future.poll(cx));
78                    if let Some(ref req) = this.request {
79                        match this.retry.policy.retry(req, result.as_ref()) {
80                            Some(checking) => {
81                                this.state.set(State::Checking { checking });
82                            }
83                            None => return Poll::Ready(result),
84                        }
85                    } else {
86                        // request wasn't cloned, so no way to retry it
87                        return Poll::Ready(result);
88                    }
89                }
90                StateProj::Checking { checking } => {
91                    this.retry
92                        .as_mut()
93                        .project()
94                        .policy
95                        .set(ready!(checking.poll(cx)));
96                    this.state.set(State::Retrying);
97                }
98                StateProj::Retrying => {
99                    // NOTE: we assume here that
100                    //
101                    //   this.retry.poll_ready()
102                    //
103                    // is equivalent to
104                    //
105                    //   this.retry.service.poll_ready()
106                    //
107                    // we need to make that assumption to avoid adding an Unpin bound to the Policy
108                    // in Ready to make it Unpin so that we can get &mut Ready as needed to call
109                    // poll_ready on it.
110                    ready!(this.retry.as_mut().project().service.poll_ready(cx))?;
111                    let req = this
112                        .request
113                        .take()
114                        .expect("retrying requires cloned request");
115                    *this.request = this.retry.policy.clone_request(&req);
116                    this.state.set(State::Called {
117                        future: this.retry.as_mut().project().service.call(req),
118                    });
119                }
120            }
121        }
122    }
123}