kube_runtime/utils/
delayed_init.rs

1use std::{fmt::Debug, sync::Mutex, task::Poll};
2
3use futures::{channel, Future, FutureExt};
4use thiserror::Error;
5use tracing::trace;
6
7/// The sending counterpart to a [`DelayedInit`]
8pub struct Initializer<T>(channel::oneshot::Sender<T>);
9impl<T> Initializer<T> {
10    /// Sends `value` to the linked [`DelayedInit`].
11    pub fn init(self, value: T) {
12        // oneshot::Sender::send fails if no recipients remain, this is not really a relevant
13        // case to signal for our use case
14        let _ = self.0.send(value);
15    }
16}
17impl<T> Debug for Initializer<T> {
18    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
19        f.debug_struct("delayed_init::Initializer")
20            .finish_non_exhaustive()
21    }
22}
23
24/// A value that must be initialized by an external writer
25///
26/// Can be considered equivalent to a [`channel::oneshot`] channel, except for that
27/// the value produced is retained for subsequent calls to [`Self::get`].
28#[derive(Debug)]
29pub struct DelayedInit<T> {
30    state: Mutex<ReceiverState<T>>,
31}
32#[derive(Debug)]
33enum ReceiverState<T> {
34    Waiting(channel::oneshot::Receiver<T>),
35    Ready(Result<T, InitDropped>),
36}
37impl<T> DelayedInit<T> {
38    /// Returns an empty `DelayedInit` that has no value, along with a linked [`Initializer`]
39    #[must_use]
40    pub fn new() -> (Initializer<T>, Self) {
41        let (tx, rx) = channel::oneshot::channel();
42        (Initializer(tx), DelayedInit {
43            state: Mutex::new(ReceiverState::Waiting(rx)),
44        })
45    }
46}
47impl<T: Clone + Send + Sync> DelayedInit<T> {
48    /// Wait for the value to be available and then return it
49    ///
50    /// Calling `get` again if a value has already been returned is guaranteed to return (a clone of)
51    /// the same value.
52    ///
53    /// # Errors
54    ///
55    /// Fails if the associated [`Initializer`] has been dropped before calling [`Initializer::init`].
56    pub async fn get(&self) -> Result<T, InitDropped> {
57        Get(self).await
58    }
59}
60
61// Using a manually implemented future because we don't want to hold the lock across poll calls
62// since that would mean that an unpolled writer would stall all other tasks from being able to poll it
63struct Get<'a, T>(&'a DelayedInit<T>);
64impl<T> Future for Get<'_, T>
65where
66    T: Clone,
67{
68    type Output = Result<T, InitDropped>;
69
70    #[tracing::instrument(name = "DelayedInit::get", level = "trace", skip(self, cx))]
71    fn poll(
72        self: std::pin::Pin<&mut Self>,
73        cx: &mut std::task::Context<'_>,
74    ) -> std::task::Poll<Self::Output> {
75        let mut state = self.0.state.lock().unwrap();
76        trace!("got lock lock");
77        match &mut *state {
78            ReceiverState::Waiting(rx) => {
79                trace!("channel still active, polling");
80                if let Poll::Ready(value) = rx.poll_unpin(cx).map_err(|_| InitDropped) {
81                    trace!("got value on slow path, memoizing");
82                    *state = ReceiverState::Ready(value.clone());
83                    Poll::Ready(value)
84                } else {
85                    trace!("channel is still pending");
86                    Poll::Pending
87                }
88            }
89            ReceiverState::Ready(v) => {
90                trace!("slow path but value was already initialized, another writer already initialized");
91                Poll::Ready(v.clone())
92            }
93        }
94    }
95}
96
97#[derive(Debug, Error, Clone, Copy, PartialEq, Eq)]
98#[error("initializer was dropped before value was initialized")]
99pub struct InitDropped;
100
#[cfg(test)]
mod tests {
    use std::{pin::pin, task::Poll};

    use super::DelayedInit;
    use futures::poll;
    use tracing::Level;
    use tracing_subscriber::util::SubscriberInitExt;

    /// Installs a `TRACE`-level subscriber writing to the test output; it stays active for as
    /// long as the returned guard is alive.
    fn setup_tracing() -> tracing::dispatcher::DefaultGuard {
        let subscriber = tracing_subscriber::fmt()
            .with_max_level(Level::TRACE)
            .with_test_writer()
            .finish();
        subscriber.set_default()
    }

    #[tokio::test]
    async fn must_allow_single_reader() {
        let _trace_guard = setup_tracing();
        let (writer, reader) = DelayedInit::<u8>::new();
        let mut pending_get = pin!(reader.get());
        assert_eq!(poll!(pending_get.as_mut()), Poll::Pending);
        writer.init(1);
        assert_eq!(poll!(pending_get), Poll::Ready(Ok(1)));
    }

    #[tokio::test]
    async fn must_allow_concurrent_readers_while_waiting() {
        let _trace_guard = setup_tracing();
        let (writer, reader) = DelayedInit::<u8>::new();
        let mut first = pin!(reader.get());
        let mut second = pin!(reader.get());
        let mut third = pin!(reader.get());
        for fut in [first.as_mut(), second.as_mut(), third.as_mut()] {
            assert_eq!(poll!(fut), Poll::Pending);
        }
        writer.init(1);
        for fut in [first, second, third] {
            assert_eq!(poll!(fut), Poll::Ready(Ok(1)));
        }
    }

    #[tokio::test]
    async fn must_allow_reading_after_init() {
        let _trace_guard = setup_tracing();
        let (writer, reader) = DelayedInit::<u8>::new();
        let mut pending_get = pin!(reader.get());
        assert_eq!(poll!(pending_get.as_mut()), Poll::Pending);
        writer.init(1);
        assert_eq!(poll!(pending_get), Poll::Ready(Ok(1)));
        // Later reads must keep returning the memoized value.
        for _ in 0..2 {
            assert_eq!(reader.get().await, Ok(1));
        }
    }

    #[tokio::test]
    async fn must_allow_concurrent_readers_in_any_order() {
        let _trace_guard = setup_tracing();
        let (writer, reader) = DelayedInit::<u8>::new();
        let mut first = pin!(reader.get());
        let mut second = pin!(reader.get());
        let mut third = pin!(reader.get());
        for fut in [first.as_mut(), second.as_mut(), third.as_mut()] {
            assert_eq!(poll!(fut), Poll::Pending);
        }
        writer.init(1);
        // Resolve the readers in the reverse order of their first poll.
        for fut in [third, second, first] {
            assert_eq!(poll!(fut), Poll::Ready(Ok(1)));
        }
    }
}