// tower/buffer/worker.rs

use super::{
    error::{Closed, ServiceError},
    message::Message,
};
use futures_core::ready;
use std::sync::{Arc, Mutex};
use std::{
    future::Future,
    pin::Pin,
    task::{Context, Poll},
};
use tokio::sync::mpsc;
use tower_service::Service;

15pin_project_lite::pin_project! {
16    /// Task that handles processing the buffer. This type should not be used
17    /// directly, instead `Buffer` requires an `Executor` that can accept this task.
18    ///
19    /// The struct is `pub` in the private module and the type is *not* re-exported
20    /// as part of the public API. This is the "sealed" pattern to include "private"
21    /// types in public traits that are not meant for consumers of the library to
22    /// implement (only call).
23    #[derive(Debug)]
24    pub struct Worker<T, Request>
25    where
26        T: Service<Request>,
27    {
28        current_message: Option<Message<Request, T::Future>>,
29        rx: mpsc::Receiver<Message<Request, T::Future>>,
30        service: T,
31        finish: bool,
32        failed: Option<ServiceError>,
33        handle: Handle,
34    }
35}
36
37/// Get the error out
38#[derive(Debug)]
39pub(crate) struct Handle {
40    inner: Arc<Mutex<Option<ServiceError>>>,
41}
42
43impl<T, Request> Worker<T, Request>
44where
45    T: Service<Request>,
46    T::Error: Into<crate::BoxError>,
47{
48    pub(crate) fn new(
49        service: T,
50        rx: mpsc::Receiver<Message<Request, T::Future>>,
51    ) -> (Handle, Worker<T, Request>) {
52        let handle = Handle {
53            inner: Arc::new(Mutex::new(None)),
54        };
55
56        let worker = Worker {
57            current_message: None,
58            finish: false,
59            failed: None,
60            rx,
61            service,
62            handle: handle.clone(),
63        };
64
65        (handle, worker)
66    }
67
68    /// Return the next queued Message that hasn't been canceled.
69    ///
70    /// If a `Message` is returned, the `bool` is true if this is the first time we received this
71    /// message, and false otherwise (i.e., we tried to forward it to the backing service before).
72    fn poll_next_msg(
73        &mut self,
74        cx: &mut Context<'_>,
75    ) -> Poll<Option<(Message<Request, T::Future>, bool)>> {
76        if self.finish {
77            // We've already received None and are shutting down
78            return Poll::Ready(None);
79        }
80
81        tracing::trace!("worker polling for next message");
82        if let Some(msg) = self.current_message.take() {
83            // If the oneshot sender is closed, then the receiver is dropped,
84            // and nobody cares about the response. If this is the case, we
85            // should continue to the next request.
86            if !msg.tx.is_closed() {
87                tracing::trace!("resuming buffered request");
88                return Poll::Ready(Some((msg, false)));
89            }
90
91            tracing::trace!("dropping cancelled buffered request");
92        }
93
94        // Get the next request
95        while let Some(msg) = ready!(Pin::new(&mut self.rx).poll_recv(cx)) {
96            if !msg.tx.is_closed() {
97                tracing::trace!("processing new request");
98                return Poll::Ready(Some((msg, true)));
99            }
100            // Otherwise, request is canceled, so pop the next one.
101            tracing::trace!("dropping cancelled request");
102        }
103
104        Poll::Ready(None)
105    }
106
107    fn failed(&mut self, error: crate::BoxError) {
108        // The underlying service failed when we called `poll_ready` on it with the given `error`. We
109        // need to communicate this to all the `Buffer` handles. To do so, we wrap up the error in
110        // an `Arc`, send that `Arc<E>` to all pending requests, and store it so that subsequent
111        // requests will also fail with the same error.
112
113        // Note that we need to handle the case where some handle is concurrently trying to send us
114        // a request. We need to make sure that *either* the send of the request fails *or* it
115        // receives an error on the `oneshot` it constructed. Specifically, we want to avoid the
116        // case where we send errors to all outstanding requests, and *then* the caller sends its
117        // request. We do this by *first* exposing the error, *then* closing the channel used to
118        // send more requests (so the client will see the error when the send fails), and *then*
119        // sending the error to all outstanding requests.
120        let error = ServiceError::new(error);
121
122        let mut inner = self.handle.inner.lock().unwrap();
123
124        if inner.is_some() {
125            // Future::poll was called after we've already errored out!
126            return;
127        }
128
129        *inner = Some(error.clone());
130        drop(inner);
131
132        self.rx.close();
133
134        // By closing the mpsc::Receiver, we know that poll_next_msg will soon return Ready(None),
135        // which will trigger the `self.finish == true` phase. We just need to make sure that any
136        // requests that we receive before we've exhausted the receiver receive the error:
137        self.failed = Some(error);
138    }
139}
140
141impl<T, Request> Future for Worker<T, Request>
142where
143    T: Service<Request>,
144    T::Error: Into<crate::BoxError>,
145{
146    type Output = ();
147
148    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
149        if self.finish {
150            return Poll::Ready(());
151        }
152
153        loop {
154            match ready!(self.poll_next_msg(cx)) {
155                Some((msg, first)) => {
156                    let _guard = msg.span.enter();
157                    if let Some(ref failed) = self.failed {
158                        tracing::trace!("notifying caller about worker failure");
159                        let _ = msg.tx.send(Err(failed.clone()));
160                        continue;
161                    }
162
163                    // Wait for the service to be ready
164                    tracing::trace!(
165                        resumed = !first,
166                        message = "worker received request; waiting for service readiness"
167                    );
168                    match self.service.poll_ready(cx) {
169                        Poll::Ready(Ok(())) => {
170                            tracing::debug!(service.ready = true, message = "processing request");
171                            let response = self.service.call(msg.request);
172
173                            // Send the response future back to the sender.
174                            //
175                            // An error means the request had been canceled in-between
176                            // our calls, the response future will just be dropped.
177                            tracing::trace!("returning response future");
178                            let _ = msg.tx.send(Ok(response));
179                        }
180                        Poll::Pending => {
181                            tracing::trace!(service.ready = false, message = "delay");
182                            // Put out current message back in its slot.
183                            drop(_guard);
184                            self.current_message = Some(msg);
185                            return Poll::Pending;
186                        }
187                        Poll::Ready(Err(e)) => {
188                            let error = e.into();
189                            tracing::debug!({ %error }, "service failed");
190                            drop(_guard);
191                            self.failed(error);
192                            let _ = msg.tx.send(Err(self
193                                .failed
194                                .as_ref()
195                                .expect("Worker::failed did not set self.failed?")
196                                .clone()));
197                        }
198                    }
199                }
200                None => {
201                    // No more more requests _ever_.
202                    self.finish = true;
203                    return Poll::Ready(());
204                }
205            }
206        }
207    }
208}
209
210impl Handle {
211    pub(crate) fn get_error_on_closed(&self) -> crate::BoxError {
212        self.inner
213            .lock()
214            .unwrap()
215            .as_ref()
216            .map(|svc_err| svc_err.clone().into())
217            .unwrap_or_else(|| Closed::new().into())
218    }
219}
220
221impl Clone for Handle {
222    fn clone(&self) -> Handle {
223        Handle {
224            inner: self.inner.clone(),
225        }
226    }
227}