tower/buffer/
worker.rs

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
use super::{
    error::{Closed, ServiceError},
    message::Message,
};
use futures_core::ready;
use std::sync::{Arc, Mutex, Weak};
use std::{
    future::Future,
    pin::Pin,
    task::{Context, Poll},
};
use tokio::sync::{mpsc, Semaphore};
use tower_service::Service;

pin_project_lite::pin_project! {
    /// Task that handles processing the buffer. This type should not be used
    /// directly; instead, `Buffer` requires an `Executor` that can accept this task.
    ///
    /// The struct is `pub` in the private module and the type is *not* re-exported
    /// as part of the public API. This is the "sealed" pattern to include "private"
    /// types in public traits that are not meant for consumers of the library to
    /// implement (only call).
    #[derive(Debug)]
    pub struct Worker<T, Request>
    where
        T: Service<Request>,
    {
        // Message that was pulled off the channel but could not be dispatched
        // because the service was not ready; it is retried before any new
        // message is received.
        current_message: Option<Message<Request, T::Future>>,
        // Receiving half of the channel on which requests arrive.
        rx: mpsc::UnboundedReceiver<Message<Request, T::Future>>,
        // The wrapped service that requests are forwarded to.
        service: T,
        // Set once the channel has yielded `None`; the worker is done after that.
        finish: bool,
        // Error recorded when the service's `poll_ready` failed; every message
        // received afterwards is answered with a clone of it.
        failed: Option<ServiceError>,
        // Shared slot through which the failure is published to handle holders.
        handle: Handle,
        // Weak reference to the buffer's semaphore; taken (left as `None`) the
        // first time the semaphore is closed.
        close: Option<Weak<Semaphore>>,
    }

    impl<T: Service<Request>, Request> PinnedDrop for Worker<T, Request>
    {
        // Make sure the semaphore is closed when the worker goes away, so that
        // tasks still pending on it are woken.
        fn drop(mut this: Pin<&mut Self>) {
            this.as_mut().close_semaphore();
        }
    }
}

/// Shared slot used to get the worker's error out.
///
/// The worker stores its `ServiceError` here when the inner service's
/// `poll_ready` fails; any clone of the handle can read it back through
/// `get_error_on_closed`.
#[derive(Debug)]
pub(crate) struct Handle {
    // `None` until the worker records a failure. Clones of the handle share
    // the same `Arc`, so they all observe the same value.
    inner: Arc<Mutex<Option<ServiceError>>>,
}

impl<T, Request> Worker<T, Request>
where
    T: Service<Request>,
{
    /// Closes the buffer's semaphore, waking any tasks pending on it.
    ///
    /// The stored weak reference is consumed by this call, so later calls
    /// (and calls made after the semaphore itself has been dropped) only log
    /// that the buffer is already closed.
    fn close_semaphore(&mut self) {
        // `take` leaves `None` behind regardless of whether the upgrade works.
        let semaphore = self.close.take().and_then(|weak| weak.upgrade());

        match semaphore {
            Some(semaphore) => {
                tracing::debug!("buffer closing; waking pending tasks");
                semaphore.close();
            }
            None => {
                tracing::trace!("buffer already closed");
            }
        }
    }
}

impl<T, Request> Worker<T, Request>
where
    T: Service<Request>,
    T::Error: Into<crate::BoxError>,
{
    pub(crate) fn new(
        service: T,
        rx: mpsc::UnboundedReceiver<Message<Request, T::Future>>,
        semaphore: &Arc<Semaphore>,
    ) -> (Handle, Worker<T, Request>) {
        let handle = Handle {
            inner: Arc::new(Mutex::new(None)),
        };

        let semaphore = Arc::downgrade(semaphore);
        let worker = Worker {
            current_message: None,
            finish: false,
            failed: None,
            rx,
            service,
            handle: handle.clone(),
            close: Some(semaphore),
        };

        (handle, worker)
    }

    /// Return the next queued Message that hasn't been canceled.
    ///
    /// If a `Message` is returned, the `bool` is true if this is the first time we received this
    /// message, and false otherwise (i.e., we tried to forward it to the backing service before).
    fn poll_next_msg(
        &mut self,
        cx: &mut Context<'_>,
    ) -> Poll<Option<(Message<Request, T::Future>, bool)>> {
        if self.finish {
            // We've already received None and are shutting down
            return Poll::Ready(None);
        }

        tracing::trace!("worker polling for next message");
        if let Some(msg) = self.current_message.take() {
            // If the oneshot sender is closed, then the receiver is dropped,
            // and nobody cares about the response. If this is the case, we
            // should continue to the next request.
            if !msg.tx.is_closed() {
                tracing::trace!("resuming buffered request");
                return Poll::Ready(Some((msg, false)));
            }

            tracing::trace!("dropping cancelled buffered request");
        }

        // Get the next request
        while let Some(msg) = ready!(Pin::new(&mut self.rx).poll_recv(cx)) {
            if !msg.tx.is_closed() {
                tracing::trace!("processing new request");
                return Poll::Ready(Some((msg, true)));
            }
            // Otherwise, request is canceled, so pop the next one.
            tracing::trace!("dropping cancelled request");
        }

        Poll::Ready(None)
    }

    fn failed(&mut self, error: crate::BoxError) {
        // The underlying service failed when we called `poll_ready` on it with the given `error`. We
        // need to communicate this to all the `Buffer` handles. To do so, we wrap up the error in
        // an `Arc`, send that `Arc<E>` to all pending requests, and store it so that subsequent
        // requests will also fail with the same error.

        // Note that we need to handle the case where some handle is concurrently trying to send us
        // a request. We need to make sure that *either* the send of the request fails *or* it
        // receives an error on the `oneshot` it constructed. Specifically, we want to avoid the
        // case where we send errors to all outstanding requests, and *then* the caller sends its
        // request. We do this by *first* exposing the error, *then* closing the channel used to
        // send more requests (so the client will see the error when the send fails), and *then*
        // sending the error to all outstanding requests.
        let error = ServiceError::new(error);

        let mut inner = self.handle.inner.lock().unwrap();

        if inner.is_some() {
            // Future::poll was called after we've already errored out!
            return;
        }

        *inner = Some(error.clone());
        drop(inner);

        self.rx.close();

        // By closing the mpsc::Receiver, we know that poll_next_msg will soon return Ready(None),
        // which will trigger the `self.finish == true` phase. We just need to make sure that any
        // requests that we receive before we've exhausted the receiver receive the error:
        self.failed = Some(error);
    }
}

impl<T, Request> Future for Worker<T, Request>
where
    T: Service<Request>,
    T::Error: Into<crate::BoxError>,
{
    type Output = ();

    /// Drives the worker: repeatedly takes the next live message and, once the
    /// service is ready, calls it and sends the response future back to the
    /// requester.
    ///
    /// Returns `Poll::Pending` when there is no message available yet or when
    /// the service is not ready (the message in hand is parked in
    /// `current_message`). Returns `Poll::Ready(())` once no more messages
    /// will ever arrive.
    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        // Already shut down on an earlier poll.
        if self.finish {
            return Poll::Ready(());
        }

        loop {
            match ready!(self.poll_next_msg(cx)) {
                Some((msg, first)) => {
                    // Run the handling of this message inside the request's span.
                    let _guard = msg.span.enter();
                    if let Some(ref failed) = self.failed {
                        // The service has already failed: answer with the
                        // stored error instead of touching the service again.
                        tracing::trace!("notifying caller about worker failure");
                        let _ = msg.tx.send(Err(failed.clone()));
                        continue;
                    }

                    // Wait for the service to be ready
                    tracing::trace!(
                        resumed = !first,
                        message = "worker received request; waiting for service readiness"
                    );
                    match self.service.poll_ready(cx) {
                        Poll::Ready(Ok(())) => {
                            tracing::debug!(service.ready = true, message = "processing request");
                            let response = self.service.call(msg.request);

                            // Send the response future back to the sender.
                            //
                            // An error means the request had been canceled in-between
                            // our calls, the response future will just be dropped.
                            tracing::trace!("returning response future");
                            let _ = msg.tx.send(Ok(response));
                        }
                        Poll::Pending => {
                            tracing::trace!(service.ready = false, message = "delay");
                            // Put our current message back in its slot. The span
                            // guard borrows from `msg`, so it has to be released
                            // before `msg` can be moved.
                            drop(_guard);
                            self.current_message = Some(msg);
                            return Poll::Pending;
                        }
                        Poll::Ready(Err(e)) => {
                            let error = e.into();
                            tracing::debug!({ %error }, "service failed");
                            // Leave the request's span before recording the failure.
                            drop(_guard);
                            self.failed(error);
                            // Fail the request that triggered the readiness check
                            // with the error that was just recorded.
                            let _ = msg.tx.send(Err(self
                                .failed
                                .as_ref()
                                .expect("Worker::failed did not set self.failed?")
                                .clone()));
                            // Wake any tasks waiting on channel capacity.
                            self.close_semaphore();
                        }
                    }
                }
                None => {
                    // No more requests _ever_.
                    self.finish = true;
                    return Poll::Ready(());
                }
            }
        }
    }
}

impl Handle {
    pub(crate) fn get_error_on_closed(&self) -> crate::BoxError {
        self.inner
            .lock()
            .unwrap()
            .as_ref()
            .map(|svc_err| svc_err.clone().into())
            .unwrap_or_else(|| Closed::new().into())
    }
}

impl Clone for Handle {
    /// Produces another handle that shares the same underlying error slot.
    fn clone(&self) -> Handle {
        let Handle { inner } = self;
        Handle {
            inner: Arc::clone(inner),
        }
    }
}