tower/buffer/worker.rs
use super::{
error::{Closed, ServiceError},
message::Message,
};
use futures_core::ready;
use std::sync::{Arc, Mutex, Weak};
use std::{
future::Future,
pin::Pin,
task::{Context, Poll},
};
use tokio::sync::{mpsc, Semaphore};
use tower_service::Service;
pin_project_lite::pin_project! {
/// Task that handles processing the buffer. This type should not be used
/// directly; instead, `Buffer` requires an `Executor` that can accept this task.
///
/// The struct is `pub` in a private module and the type is *not* re-exported
/// as part of the public API. This is the "sealed" pattern: it lets "private"
/// types appear in public traits that consumers of the library are meant to
/// call, but not implement.
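///
/// A minimal sketch of the pattern itself (module and type names here are
/// illustrative, not part of this crate):
///
/// ```ignore
/// mod private {
///     // `pub`, so it may appear in public signatures...
///     pub struct Task;
/// }
///
/// // ...yet unnameable downstream, because `private` is not `pub`.
/// pub fn spawn(task: private::Task) {
///     drop(task);
/// }
/// ```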
#[derive(Debug)]
pub struct Worker<T, Request>
where
T: Service<Request>,
{
current_message: Option<Message<Request, T::Future>>,
rx: mpsc::UnboundedReceiver<Message<Request, T::Future>>,
service: T,
finish: bool,
failed: Option<ServiceError>,
handle: Handle,
close: Option<Weak<Semaphore>>,
}
impl<T, Request> PinnedDrop for Worker<T, Request>
where
T: Service<Request>,
{
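// If the worker task is dropped (for example, because its runtime shut
// down), wake anyone still waiting on buffer capacity instead of leaving
// them pending forever.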
fn drop(mut this: Pin<&mut Self>) {
this.as_mut().close_semaphore();
}
}
}
/// A handle shared by all `Buffer` clones that provides access to the error
/// that caused the worker to shut down, if any.
#[derive(Debug)]
pub(crate) struct Handle {
inner: Arc<Mutex<Option<ServiceError>>>,
}
impl<T, Request> Worker<T, Request>
where
T: Service<Request>,
{
/// Closes the buffer's semaphore if it is still open, waking any pending
/// tasks.
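///
/// Closing a `tokio::sync::Semaphore` causes all pending and future
/// acquires to fail, which is what wakes those tasks. A minimal sketch of
/// that behavior:
///
/// ```ignore
/// let sem = tokio::sync::Semaphore::new(1);
/// sem.close();
/// // Once closed, acquiring a permit always fails.
/// assert!(sem.try_acquire().is_err());
/// ```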
fn close_semaphore(&mut self) {
if let Some(close) = self.close.take().as_ref().and_then(Weak::upgrade) {
tracing::debug!("buffer closing; waking pending tasks");
close.close();
} else {
tracing::trace!("buffer already closed");
}
}
}
impl<T, Request> Worker<T, Request>
where
T: Service<Request>,
T::Error: Into<crate::BoxError>,
{
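/// Creates a new `Worker` that drives `service`, along with a `Handle` for
/// observing any error the worker encounters.
///
/// Note that only a `Weak` reference to the semaphore is retained, so the
/// worker alone does not keep the buffer's capacity alive.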
pub(crate) fn new(
service: T,
rx: mpsc::UnboundedReceiver<Message<Request, T::Future>>,
semaphore: &Arc<Semaphore>,
) -> (Handle, Worker<T, Request>) {
let handle = Handle {
inner: Arc::new(Mutex::new(None)),
};
let semaphore = Arc::downgrade(semaphore);
let worker = Worker {
current_message: None,
finish: false,
failed: None,
rx,
service,
handle: handle.clone(),
close: Some(semaphore),
};
(handle, worker)
}
/// Return the next queued `Message` that hasn't been canceled.
///
/// If a `Message` is returned, the `bool` is true if this is the first time we received this
/// message, and false otherwise (i.e., we tried to forward it to the backing service before).
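///
/// Cancellation is observed through the response oneshot: once a caller
/// drops its receiver, the corresponding sender reports itself closed. A
/// minimal sketch of that mechanism:
///
/// ```ignore
/// let (tx, rx) = tokio::sync::oneshot::channel::<()>();
/// drop(rx);
/// assert!(tx.is_closed());
/// ```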
fn poll_next_msg(
&mut self,
cx: &mut Context<'_>,
) -> Poll<Option<(Message<Request, T::Future>, bool)>> {
if self.finish {
// We've already received None and are shutting down
return Poll::Ready(None);
}
tracing::trace!("worker polling for next message");
if let Some(msg) = self.current_message.take() {
// If the oneshot sender is closed, then the receiver is dropped,
// and nobody cares about the response. If this is the case, we
// should continue to the next request.
if !msg.tx.is_closed() {
tracing::trace!("resuming buffered request");
return Poll::Ready(Some((msg, false)));
}
tracing::trace!("dropping cancelled buffered request");
}
// Get the next request
while let Some(msg) = ready!(Pin::new(&mut self.rx).poll_recv(cx)) {
if !msg.tx.is_closed() {
tracing::trace!("processing new request");
return Poll::Ready(Some((msg, true)));
}
// Otherwise, request is canceled, so pop the next one.
tracing::trace!("dropping cancelled request");
}
Poll::Ready(None)
}
fn failed(&mut self, error: crate::BoxError) {
// The underlying service failed when we called `poll_ready` on it with the given `error`. We
// need to communicate this to all the `Buffer` handles. To do so, we wrap up the error in
// an `Arc`, send that `Arc<E>` to all pending requests, and store it so that subsequent
// requests will also fail with the same error.
// Note that we need to handle the case where some handle is concurrently trying to send us
// a request. We need to make sure that *either* the send of the request fails *or* it
// receives an error on the `oneshot` it constructed. Specifically, we want to avoid the
// case where we send errors to all outstanding requests, and *then* the caller sends its
// request. We do this by *first* exposing the error, *then* closing the channel used to
// send more requests (so the client will see the error when the send fails), and *then*
// sending the error to all outstanding requests.
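// Roughly, the caller's side of that handshake looks like the following
// (a sketch, not the actual `Buffer` code; `tx` and `handle` stand in for
// the caller's clones of the channel sender and our `Handle`):
//
//     match tx.send(message) {
//         Ok(()) => { /* the worker will reply on the oneshot */ }
//         // The channel is closed only *after* the error is exposed, so
//         // the caller can always retrieve it here.
//         Err(_) => return Err(handle.get_error_on_closed()),
//     }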
let error = ServiceError::new(error);
let mut inner = self.handle.inner.lock().unwrap();
if inner.is_some() {
// Future::poll was called after we've already errored out!
return;
}
*inner = Some(error.clone());
drop(inner);
self.rx.close();
// By closing the mpsc::Receiver, we know that poll_next_msg will soon return Ready(None),
// which will trigger the `self.finish == true` phase. We just need to make sure that any
// requests that we receive before we've exhausted the receiver receive the error:
self.failed = Some(error);
}
}
impl<T, Request> Future for Worker<T, Request>
where
T: Service<Request>,
T::Error: Into<crate::BoxError>,
{
type Output = ();
fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
if self.finish {
return Poll::Ready(());
}
loop {
match ready!(self.poll_next_msg(cx)) {
Some((msg, first)) => {
let _guard = msg.span.enter();
if let Some(ref failed) = self.failed {
tracing::trace!("notifying caller about worker failure");
let _ = msg.tx.send(Err(failed.clone()));
continue;
}
// Wait for the service to be ready
tracing::trace!(
resumed = !first,
message = "worker received request; waiting for service readiness"
);
match self.service.poll_ready(cx) {
Poll::Ready(Ok(())) => {
tracing::debug!(service.ready = true, message = "processing request");
let response = self.service.call(msg.request);
// Send the response future back to the sender.
//
// An error here means the request was canceled in between our calls;
// in that case, the response future is simply dropped.
tracing::trace!("returning response future");
let _ = msg.tx.send(Ok(response));
}
Poll::Pending => {
tracing::trace!(service.ready = false, message = "delay");
// Put our current message back in its slot.
drop(_guard);
self.current_message = Some(msg);
return Poll::Pending;
}
Poll::Ready(Err(e)) => {
let error = e.into();
tracing::debug!({ %error }, "service failed");
drop(_guard);
self.failed(error);
let _ = msg.tx.send(Err(self
.failed
.as_ref()
.expect("Worker::failed did not set self.failed?")
.clone()));
// Wake any tasks waiting on channel capacity.
self.close_semaphore();
}
}
}
None => {
// No more requests _ever_.
self.finish = true;
return Poll::Ready(());
}
}
}
}
}
impl Handle {
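/// Returns the error that caused the worker to shut down, or a generic
/// `Closed` error if the worker was dropped without recording one.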
pub(crate) fn get_error_on_closed(&self) -> crate::BoxError {
self.inner
.lock()
.unwrap()
.as_ref()
.map(|svc_err| svc_err.clone().into())
.unwrap_or_else(|| Closed::new().into())
}
}
impl Clone for Handle {
fn clone(&self) -> Handle {
Handle {
inner: self.inner.clone(),
}
}
}
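// A minimal sketch of how a `Worker` is typically driven (illustrative
// only; the real wiring lives in the `Buffer` constructor, and
// `my_service` is a placeholder for any `Service` whose error type
// converts into `BoxError`):
//
//     let (tx, rx) = tokio::sync::mpsc::unbounded_channel();
//     let semaphore = std::sync::Arc::new(tokio::sync::Semaphore::new(1024));
//     let (_handle, worker) = Worker::new(my_service, rx, &semaphore);
//     tokio::spawn(worker); // `Worker` is a `Future<Output = ()>`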