use super::future_hash_map::FutureHashMap;
use crate::scheduler::{ScheduleRequest, Scheduler};
use futures::{FutureExt, Stream, StreamExt};
use pin_project::pin_project;
use std::{
convert::Infallible,
future::{self, Future},
hash::Hash,
pin::Pin,
task::{Context, Poll},
};
use thiserror::Error;
/// Error emitted by a [`Runner`] stream.
#[derive(Debug, Error)]
pub enum Error<ReadyErr> {
    /// The readiness gate installed via [`Runner::delay_tasks_until`] resolved
    /// with an error; the runner emits this once and then terminates.
    #[error("readiness gate failed to become ready")]
    Readiness(#[source] ReadyErr),
}
/// Executes a future (built by `run_msg`) for each message emitted by the
/// wrapped [`Scheduler`], deduplicating messages that are already in flight.
#[pin_project]
pub struct Runner<T, R, F, MkF, Ready = future::Ready<Result<(), Infallible>>> {
    /// Source of messages; decides when each message becomes available.
    #[pin]
    scheduler: Scheduler<T, R>,
    /// Factory invoked to build the future for each message taken from the scheduler.
    run_msg: MkF,
    /// Currently-running futures, keyed by message. A key present here is
    /// "in flight" and is held back by the scheduler until it completes.
    slots: FutureHashMap<T, F>,
    /// Readiness gate: no task is started until this resolves with `Ok(())`.
    /// Fused so it is never polled again after completion.
    #[pin]
    ready_to_execute_after: futures::future::Fuse<Ready>,
    /// Latched to `true` once `ready_to_execute_after` resolved successfully.
    is_ready_to_execute: bool,
    /// Set after a readiness failure was emitted; the stream then ends.
    stopped: bool,
    /// Maximum number of concurrently running futures; `0` disables the limit
    /// (see the capacity check in `poll_next`).
    max_concurrent_executions: u16,
}
impl<T, R, F, MkF> Runner<T, R, F, MkF>
where
    F: Future + Unpin,
    MkF: FnMut(&T) -> F,
{
    /// Creates a [`Runner`] that invokes `run_msg` for every message emitted
    /// by `scheduler`, running at most `max_concurrent_executions` futures at
    /// once (`0` disables the limit).
    pub fn new(scheduler: Scheduler<T, R>, max_concurrent_executions: u16, run_msg: MkF) -> Self {
        Self {
            scheduler,
            run_msg,
            slots: FutureHashMap::default(),
            // Default gate: already resolved, so execution may begin as soon
            // as `poll_next` observes it.
            ready_to_execute_after: future::ready(Ok(())).fuse(),
            is_ready_to_execute: false,
            stopped: false,
            max_concurrent_executions,
        }
    }

    /// Delays starting any task until `ready_to_execute_after` completes with
    /// `Ok(())`. If the gate fails, the stream emits [`Error::Readiness`]
    /// once and then terminates.
    pub fn delay_tasks_until<Ready, ReadyErr>(
        self,
        ready_to_execute_after: Ready,
    ) -> Runner<T, R, F, MkF, Ready>
    where
        Ready: Future<Output = Result<(), ReadyErr>>,
    {
        // Carry over the scheduler state and any queued work; only the
        // readiness bookkeeping is reset for the new gate.
        let Self {
            scheduler,
            run_msg,
            slots,
            max_concurrent_executions,
            ..
        } = self;
        Runner {
            scheduler,
            run_msg,
            slots,
            ready_to_execute_after: ready_to_execute_after.fuse(),
            is_ready_to_execute: false,
            stopped: false,
            max_concurrent_executions,
        }
    }
}
#[allow(clippy::match_wildcard_for_single_variants)]
impl<T, R, F, MkF, Ready, ReadyErr> Stream for Runner<T, R, F, MkF, Ready>
where
    T: Eq + Hash + Clone + Unpin,
    R: Stream<Item = ScheduleRequest<T>>,
    F: Future + Unpin,
    MkF: FnMut(&T) -> F,
    Ready: Future<Output = Result<(), ReadyErr>>,
{
    /// Each item is the output of a completed task future, or the error
    /// produced by a failed readiness gate.
    type Item = Result<F::Output, Error<ReadyErr>>;

    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        let mut this = self.project();
        // After a readiness error was emitted, the stream stays terminated.
        if *this.stopped {
            return Poll::Ready(None);
        }
        let slots = this.slots;
        let scheduler = &mut this.scheduler;
        // Drain a completed task first (this also drives all in-flight
        // futures). `Pending` means tasks are still running, which keeps the
        // stream alive even if the scheduler itself has ended.
        let has_active_slots = match slots.poll_next_unpin(cx) {
            Poll::Ready(Some(result)) => return Poll::Ready(Some(Ok(result))),
            Poll::Ready(None) => false,
            Poll::Pending => true,
        };
        // Poll the fused readiness gate; on failure, emit the error once and
        // terminate on the next poll.
        match this.ready_to_execute_after.poll(cx) {
            Poll::Ready(Ok(())) => *this.is_ready_to_execute = true,
            Poll::Ready(Err(err)) => {
                *this.stopped = true;
                return Poll::Ready(Some(Err(Error::Readiness(err))));
            }
            Poll::Pending => {}
        }
        loop {
            // While saturated (capacity of 0 means unlimited) or not yet
            // ready, hold back all messages — but still poll the scheduler so
            // incoming schedule requests are absorbed and our waker stays
            // registered for when a message would become available.
            if (*this.max_concurrent_executions > 0
                && slots.len() >= *this.max_concurrent_executions as usize)
                || !*this.is_ready_to_execute
            {
                match scheduler.as_mut().hold().poll_next_unpin(cx) {
                    Poll::Pending | Poll::Ready(None) => break Poll::Pending,
                    // `hold()` holds every message, so it can never yield one.
                    _ => unreachable!(),
                };
            };
            // Take the next due message whose key is not already running
            // (in-flight duplicates stay parked in the scheduler).
            let next_msg_poll = scheduler
                .as_mut()
                .hold_unless(|msg| !slots.contains_key(msg))
                .poll_next_unpin(cx);
            match next_msg_poll {
                Poll::Ready(Some(msg)) => {
                    let msg_fut = (this.run_msg)(&msg);
                    // `hold_unless` guarantees the key is not in `slots`.
                    assert!(
                        slots.insert(msg, msg_fut).is_none(),
                        "Runner tried to replace a running future.. please report this as a kube-rs bug!"
                    );
                    // The freshly inserted future has not been polled yet;
                    // request an immediate re-poll so it gets started.
                    cx.waker().wake_by_ref();
                }
                Poll::Ready(None) => {
                    // Scheduler exhausted: finish only once every in-flight
                    // task has completed.
                    break if has_active_slots {
                        Poll::Pending
                    } else {
                        Poll::Ready(None)
                    };
                }
                Poll::Pending => break Poll::Pending,
            }
        }
    }
}
#[cfg(test)]
mod tests {
use super::{Error, Runner};
use crate::{
scheduler::{scheduler, ScheduleRequest},
utils::delayed_init::{self, DelayedInit},
};
use futures::{
channel::{mpsc, oneshot},
future, poll, stream, Future, SinkExt, StreamExt, TryStreamExt,
};
use std::{
cell::RefCell,
collections::HashMap,
pin::Pin,
sync::{Arc, Mutex},
task::{Context, Poll},
time::Duration,
};
use tokio::{
runtime::Handle,
task::yield_now,
time::{advance, pause, sleep, timeout, Instant},
};
    #[tokio::test]
    async fn runner_should_never_run_two_instances_at_once() {
        pause();
        // `RefCell::borrow_mut` panics on overlapping borrows, so a second
        // concurrent execution of the same message would abort the test.
        let rc = RefCell::new(());
        let mut count = 0;
        let (mut sched_tx, sched_rx) = mpsc::unbounded();
        let mut runner = Box::pin(
            // Limit of 0 = unlimited concurrency; serialization here relies
            // purely on in-flight deduplication of the message key.
            Runner::new(scheduler(sched_rx), 0, |_| {
                count += 1;
                // Hold the borrow for the task future's whole lifetime.
                let mutex_ref = rc.borrow_mut();
                Box::pin(async move {
                    sleep(Duration::from_secs(1)).await;
                    drop(mutex_ref);
                })
            })
            .for_each(|_| async {}),
        );
        // Schedule and start the first run.
        sched_tx
            .send(ScheduleRequest {
                message: (),
                run_at: Instant::now(),
            })
            .await
            .unwrap();
        assert!(poll!(runner.as_mut()).is_pending());
        // Schedule the same message again while the first run is in flight.
        sched_tx
            .send(ScheduleRequest {
                message: (),
                run_at: Instant::now(),
            })
            .await
            .unwrap();
        // Drive the runner to completion, closing the scheduler after both
        // runs have had time to finish (paused clock auto-advances on sleep).
        future::join(
            async {
                tokio::time::sleep(Duration::from_secs(5)).await;
                drop(sched_tx);
            },
            runner,
        )
        .await;
        // Both requests must have executed, one after the other.
        assert_eq!(count, 2);
    }
    #[tokio::test(flavor = "current_thread")]
    async fn runner_should_wake_when_scheduling_messages() {
        // The runner task must be woken when a message is scheduled *after*
        // it has already returned `Pending` (waker-registration check).
        let (mut sched_tx, sched_rx) = mpsc::unbounded();
        let (result_tx, result_rx) = oneshot::channel();
        let mut runner = Runner::new(scheduler(sched_rx), 0, |msg: &u8| futures::future::ready(*msg));
        // Drive the runner on a background task; it will park on `Pending`.
        Handle::current().spawn(async move { result_tx.send(runner.next().await).unwrap() });
        // Let the spawned task poll (and park) before we send anything.
        yield_now().await;
        sched_tx
            .send(ScheduleRequest {
                message: 8,
                run_at: Instant::now(),
            })
            .await
            .unwrap();
        // If the waker was not registered correctly, this would time out.
        assert_eq!(
            timeout(Duration::from_secs(1), result_rx)
                .await
                .unwrap()
                .unwrap()
                .transpose()
                .unwrap(),
            Some(8)
        );
    }
    #[tokio::test]
    async fn runner_should_wait_for_readiness() {
        let is_ready = Mutex::new(false);
        let (delayed_init, ready) = DelayedInit::<()>::new();
        let mut runner = Box::pin(
            Runner::new(
                scheduler(
                    // One request, then an endless stream so the scheduler
                    // (and therefore the runner) never terminates on its own.
                    stream::iter([ScheduleRequest {
                        message: 1u8,
                        run_at: Instant::now(),
                    }])
                    .chain(stream::pending()),
                ),
                0,
                |msg| {
                    // Must only run once the readiness gate has opened.
                    assert!(*is_ready.lock().unwrap());
                    future::ready(*msg)
                },
            )
            .delay_tasks_until(ready.get()),
        );
        // The message is due, but the gate is closed: nothing may run yet.
        assert!(poll!(runner.next()).is_pending());
        *is_ready.lock().unwrap() = true;
        delayed_init.init(());
        // With the gate open, the queued message is finally executed.
        assert_eq!(runner.next().await.transpose().unwrap(), Some(1));
    }
    #[tokio::test]
    async fn runner_should_dedupe_while_waiting_for_readiness() {
        let is_ready = Mutex::new(false);
        let (delayed_init, ready) = DelayedInit::<()>::new();
        let mut runner = Box::pin(
            Runner::new(
                scheduler(
                    // 'a' is scheduled twice while the gate is closed; the
                    // scheduler must collapse the duplicate request.
                    stream::iter([
                        ScheduleRequest {
                            message: 'a',
                            run_at: Instant::now(),
                        },
                        ScheduleRequest {
                            message: 'b',
                            run_at: Instant::now(),
                        },
                        ScheduleRequest {
                            message: 'a',
                            run_at: Instant::now(),
                        },
                    ])
                    .chain(stream::pending()),
                ),
                0,
                |msg| {
                    assert!(*is_ready.lock().unwrap());
                    future::ready(*msg)
                },
            )
            .delay_tasks_until(ready.get()),
        );
        // Absorb the schedule requests while still gated.
        assert!(poll!(runner.next()).is_pending());
        *is_ready.lock().unwrap() = true;
        delayed_init.init(());
        let mut message_counts = HashMap::new();
        // The runner never finishes (endless scheduler stream), so drive it
        // under a timeout and inspect what actually ran.
        assert!(timeout(
            Duration::from_secs(1),
            runner.try_for_each(|msg| {
                *message_counts.entry(msg).or_default() += 1;
                async { Ok(()) }
            })
        )
        .await
        .is_err());
        // Each distinct message ran exactly once.
        assert_eq!(message_counts, HashMap::from([('a', 1), ('b', 1)]));
    }
    #[tokio::test]
    async fn runner_should_report_readiness_errors() {
        let (delayed_init, ready) = DelayedInit::<()>::new();
        let mut runner = Box::pin(
            Runner::new(
                scheduler(
                    stream::iter([ScheduleRequest {
                        message: (),
                        run_at: Instant::now(),
                    }])
                    .chain(stream::pending()),
                ),
                0,
                |()| {
                    panic!("run_msg should never be invoked if readiness gate fails");
                    // Unreachable, but gives the closure a concrete future
                    // return type so `Runner::new` type-checks.
                    #[allow(unreachable_code)]
                    future::ready(())
                },
            )
            .delay_tasks_until(ready.get()),
        );
        assert!(poll!(runner.next()).is_pending());
        // Dropping the initializer fails the readiness gate...
        drop(delayed_init);
        // ...which must surface as a single `Error::Readiness` item.
        assert!(matches!(
            runner.try_collect::<Vec<_>>().await.unwrap_err(),
            Error::Readiness(delayed_init::InitDropped)
        ));
    }
    /// Test helper: a future that stays pending (re-waking itself on every
    /// poll) until a fixed duration has elapsed since its creation.
    struct DurationalFuture {
        // Creation time, measured with tokio's (pausable) clock.
        start: Instant,
        // How long after `start` the future becomes ready.
        ready_after: Duration,
    }
impl DurationalFuture {
fn new(expires_in: Duration) -> Self {
let start = Instant::now();
DurationalFuture {
start,
ready_after: expires_in,
}
}
}
impl Future for DurationalFuture {
type Output = ();
fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
let now = Instant::now();
if now.duration_since(self.start) > self.ready_after {
Poll::Ready(())
} else {
cx.waker().wake_by_ref();
Poll::Pending
}
}
}
    #[tokio::test]
    async fn runner_should_respect_max_concurrent_executions() {
        pause();
        let count = Arc::new(Mutex::new(0));
        let (mut sched_tx, sched_rx) = mpsc::unbounded();
        let mut runner = Box::pin(
            // Cap of 2 concurrent executions; `count` tracks how many task
            // futures have been started so far.
            Runner::new(scheduler(sched_rx), 2, |_| {
                let mut num = count.lock().unwrap();
                *num += 1;
                DurationalFuture::new(Duration::from_secs(2))
            })
            .for_each(|_| async {}),
        );
        sched_tx
            .send(ScheduleRequest {
                message: 1,
                run_at: Instant::now(),
            })
            .await
            .unwrap();
        assert!(poll!(runner.as_mut()).is_pending());
        sched_tx
            .send(ScheduleRequest {
                message: 2,
                run_at: Instant::now(),
            })
            .await
            .unwrap();
        assert!(poll!(runner.as_mut()).is_pending());
        sched_tx
            .send(ScheduleRequest {
                message: 3,
                run_at: Instant::now(),
            })
            .await
            .unwrap();
        assert!(poll!(runner.as_mut()).is_pending());
        // Only two of the three scheduled tasks may have started.
        assert_eq!(*count.lock().unwrap(), 2);
        // Expire the first two tasks, freeing a slot for the third.
        advance(Duration::from_secs(3)).await;
        assert!(poll!(runner.as_mut()).is_pending());
        assert_eq!(*count.lock().unwrap(), 3);
        advance(Duration::from_secs(3)).await;
        assert!(poll!(runner.as_mut()).is_pending());
        // The same message may run again once its previous execution ended.
        sched_tx
            .send(ScheduleRequest {
                message: 3,
                run_at: Instant::now(),
            })
            .await
            .unwrap();
        advance(Duration::from_secs(3)).await;
        assert!(poll!(runner.as_mut()).is_pending());
        assert_eq!(*count.lock().unwrap(), 4);
        // Second scenario: cap of 1 with the scheduler closed while a task is
        // still in flight — the runner must stay pending, not end early.
        let (mut sched_tx, sched_rx) = mpsc::unbounded();
        let mut runner = Box::pin(
            Runner::new(scheduler(sched_rx), 1, |_| {
                DurationalFuture::new(Duration::from_secs(2))
            })
            .for_each(|_| async {}),
        );
        sched_tx
            .send(ScheduleRequest {
                message: 1,
                run_at: Instant::now(),
            })
            .await
            .unwrap();
        assert!(poll!(runner.as_mut()).is_pending());
        drop(sched_tx);
        assert_eq!(poll!(runner.as_mut()), Poll::Pending);
    }
}