kube_runtime/utils/
delayed_init.rsuse std::{fmt::Debug, sync::Mutex, task::Poll};
use derivative::Derivative;
use futures::{channel, Future, FutureExt};
use thiserror::Error;
use tracing::trace;
/// The writing half of a [`DelayedInit`], consumed by [`Initializer::init`].
pub struct Initializer<T>(channel::oneshot::Sender<T>);

impl<T> Initializer<T> {
    /// Hands `value` to the linked [`DelayedInit`].
    ///
    /// Takes `self` by value, so a value can be supplied at most once.
    pub fn init(self, value: T) {
        let Self(sender) = self;
        // `send` only fails when the receiving side has already been dropped. Nobody is left to
        // observe the value then, so the failure is intentionally discarded.
        let _ = sender.send(value);
    }
}

impl<T> Debug for Initializer<T> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Nothing inside is printed, which also keeps this impl free of a `T: Debug` bound.
        let mut builder = f.debug_struct("delayed_init::Initializer");
        builder.finish_non_exhaustive()
    }
}
/// A value that is supplied later, at most once, by its paired [`Initializer`].
///
/// Readers wait for it with [`DelayedInit::get`]. Once the outcome is known it is memoized, so
/// every subsequent `get` resolves to a clone of the same result.
#[derive(Derivative)]
#[derivative(Debug)]
pub struct DelayedInit<T> {
    // Behind a mutex so that concurrent `get` futures share one receiver and its memoized outcome.
    state: Mutex<ReceiverState<T>>,
}
#[derive(Debug)]
enum ReceiverState<T> {
Waiting(channel::oneshot::Receiver<T>),
Ready(Result<T, InitDropped>),
}
impl<T> DelayedInit<T> {
#[must_use]
pub fn new() -> (Initializer<T>, Self) {
let (tx, rx) = channel::oneshot::channel();
(Initializer(tx), DelayedInit {
state: Mutex::new(ReceiverState::Waiting(rx)),
})
}
}
impl<T: Clone + Send + Sync> DelayedInit<T> {
pub async fn get(&self) -> Result<T, InitDropped> {
Get(self).await
}
}
struct Get<'a, T>(&'a DelayedInit<T>);
impl<'a, T> Future for Get<'a, T>
where
T: Clone,
{
type Output = Result<T, InitDropped>;
#[tracing::instrument(name = "DelayedInit::get", level = "trace", skip(self, cx))]
fn poll(
self: std::pin::Pin<&mut Self>,
cx: &mut std::task::Context<'_>,
) -> std::task::Poll<Self::Output> {
let mut state = self.0.state.lock().unwrap();
trace!("got lock lock");
match &mut *state {
ReceiverState::Waiting(rx) => {
trace!("channel still active, polling");
if let Poll::Ready(value) = rx.poll_unpin(cx).map_err(|_| InitDropped) {
trace!("got value on slow path, memoizing");
*state = ReceiverState::Ready(value.clone());
Poll::Ready(value)
} else {
trace!("channel is still pending");
Poll::Pending
}
}
ReceiverState::Ready(v) => {
trace!("slow path but value was already initialized, another writer already initialized");
Poll::Ready(v.clone())
}
}
}
}
#[derive(Debug, Error, Clone, Copy, PartialEq, Eq)]
#[error("initializer was dropped before value was initialized")]
pub struct InitDropped;
#[cfg(test)]
mod tests {
    use std::{pin::pin, task::Poll};

    use futures::poll;
    use tracing::Level;
    use tracing_subscriber::util::SubscriberInitExt;

    use super::DelayedInit;

    /// Installs a TRACE-level test subscriber; dropping the returned guard uninstalls it.
    fn setup_tracing() -> tracing::dispatcher::DefaultGuard {
        let subscriber = tracing_subscriber::fmt()
            .with_max_level(Level::TRACE)
            .with_test_writer()
            .finish();
        subscriber.set_default()
    }

    #[tokio::test]
    async fn must_allow_single_reader() {
        let _tracing = setup_tracing();
        let (initializer, delayed) = DelayedInit::<u8>::new();
        let mut reader = pin!(delayed.get());
        assert_eq!(poll!(reader.as_mut()), Poll::Pending);
        initializer.init(1);
        assert_eq!(poll!(reader), Poll::Ready(Ok(1)));
    }

    #[tokio::test]
    async fn must_allow_concurrent_readers_while_waiting() {
        let _tracing = setup_tracing();
        let (initializer, delayed) = DelayedInit::<u8>::new();
        let mut readers = [
            Box::pin(delayed.get()),
            Box::pin(delayed.get()),
            Box::pin(delayed.get()),
        ];
        for reader in &mut readers {
            assert_eq!(poll!(reader.as_mut()), Poll::Pending);
        }
        initializer.init(1);
        for reader in &mut readers {
            assert_eq!(poll!(reader.as_mut()), Poll::Ready(Ok(1)));
        }
    }

    #[tokio::test]
    async fn must_allow_reading_after_init() {
        let _tracing = setup_tracing();
        let (initializer, delayed) = DelayedInit::<u8>::new();
        let mut first = pin!(delayed.get());
        assert_eq!(poll!(first.as_mut()), Poll::Pending);
        initializer.init(1);
        assert_eq!(poll!(first), Poll::Ready(Ok(1)));
        // Later reads hit the memoized value.
        for _ in 0..2 {
            assert_eq!(delayed.get().await, Ok(1));
        }
    }

    #[tokio::test]
    async fn must_allow_concurrent_readers_in_any_order() {
        let _tracing = setup_tracing();
        let (initializer, delayed) = DelayedInit::<u8>::new();
        let mut readers = [
            Box::pin(delayed.get()),
            Box::pin(delayed.get()),
            Box::pin(delayed.get()),
        ];
        for reader in &mut readers {
            assert_eq!(poll!(reader.as_mut()), Poll::Pending);
        }
        initializer.init(1);
        // Resolve in the reverse of the order they were first polled.
        for reader in readers.iter_mut().rev() {
            assert_eq!(poll!(reader.as_mut()), Poll::Ready(Ok(1)));
        }
    }
}