kube_runtime/utils/
delayed_init.rs
1use std::{fmt::Debug, sync::Mutex, task::Poll};
2
3use futures::{channel, Future, FutureExt};
4use thiserror::Error;
5use tracing::trace;
6
7pub struct Initializer<T>(channel::oneshot::Sender<T>);
9impl<T> Initializer<T> {
10 pub fn init(self, value: T) {
12 let _ = self.0.send(value);
15 }
16}
17impl<T> Debug for Initializer<T> {
18 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
19 f.debug_struct("delayed_init::Initializer")
20 .finish_non_exhaustive()
21 }
22}
23
24#[derive(Debug)]
29pub struct DelayedInit<T> {
30 state: Mutex<ReceiverState<T>>,
31}
32#[derive(Debug)]
33enum ReceiverState<T> {
34 Waiting(channel::oneshot::Receiver<T>),
35 Ready(Result<T, InitDropped>),
36}
37impl<T> DelayedInit<T> {
38 #[must_use]
40 pub fn new() -> (Initializer<T>, Self) {
41 let (tx, rx) = channel::oneshot::channel();
42 (Initializer(tx), DelayedInit {
43 state: Mutex::new(ReceiverState::Waiting(rx)),
44 })
45 }
46}
47impl<T: Clone + Send + Sync> DelayedInit<T> {
48 pub async fn get(&self) -> Result<T, InitDropped> {
57 Get(self).await
58 }
59}
60
61struct Get<'a, T>(&'a DelayedInit<T>);
64impl<T> Future for Get<'_, T>
65where
66 T: Clone,
67{
68 type Output = Result<T, InitDropped>;
69
70 #[tracing::instrument(name = "DelayedInit::get", level = "trace", skip(self, cx))]
71 fn poll(
72 self: std::pin::Pin<&mut Self>,
73 cx: &mut std::task::Context<'_>,
74 ) -> std::task::Poll<Self::Output> {
75 let mut state = self.0.state.lock().unwrap();
76 trace!("got lock lock");
77 match &mut *state {
78 ReceiverState::Waiting(rx) => {
79 trace!("channel still active, polling");
80 if let Poll::Ready(value) = rx.poll_unpin(cx).map_err(|_| InitDropped) {
81 trace!("got value on slow path, memoizing");
82 *state = ReceiverState::Ready(value.clone());
83 Poll::Ready(value)
84 } else {
85 trace!("channel is still pending");
86 Poll::Pending
87 }
88 }
89 ReceiverState::Ready(v) => {
90 trace!("slow path but value was already initialized, another writer already initialized");
91 Poll::Ready(v.clone())
92 }
93 }
94 }
95}
96
97#[derive(Debug, Error, Clone, Copy, PartialEq, Eq)]
98#[error("initializer was dropped before value was initialized")]
99pub struct InitDropped;
100
#[cfg(test)]
mod tests {
    use std::{pin::pin, task::Poll};

    use super::DelayedInit;
    use futures::poll;
    use tracing::Level;
    use tracing_subscriber::util::SubscriberInitExt;

    /// Routes TRACE-level events to the test writer for as long as the returned guard lives.
    fn setup_tracing() -> tracing::dispatcher::DefaultGuard {
        let subscriber = tracing_subscriber::fmt()
            .with_max_level(Level::TRACE)
            .with_test_writer()
            .finish();
        subscriber.set_default()
    }

    #[tokio::test]
    async fn must_allow_single_reader() {
        let _guard = setup_tracing();
        let (initializer, delayed) = DelayedInit::<u8>::new();
        let mut reader = pin!(delayed.get());
        // Nothing has been sent yet, so the reader must park.
        assert_eq!(poll!(reader.as_mut()), Poll::Pending);
        initializer.init(1);
        assert_eq!(poll!(reader), Poll::Ready(Ok(1)));
    }

    #[tokio::test]
    async fn must_allow_concurrent_readers_while_waiting() {
        let _guard = setup_tracing();
        let (initializer, delayed) = DelayedInit::<u8>::new();
        let mut first = pin!(delayed.get());
        let mut second = pin!(delayed.get());
        let mut third = pin!(delayed.get());
        assert_eq!(poll!(first.as_mut()), Poll::Pending);
        assert_eq!(poll!(second.as_mut()), Poll::Pending);
        assert_eq!(poll!(third.as_mut()), Poll::Pending);
        initializer.init(1);
        // Resolve in the same order the readers started waiting.
        assert_eq!(poll!(first), Poll::Ready(Ok(1)));
        assert_eq!(poll!(second), Poll::Ready(Ok(1)));
        assert_eq!(poll!(third), Poll::Ready(Ok(1)));
    }

    #[tokio::test]
    async fn must_allow_reading_after_init() {
        let _guard = setup_tracing();
        let (initializer, delayed) = DelayedInit::<u8>::new();
        let mut reader = pin!(delayed.get());
        assert_eq!(poll!(reader.as_mut()), Poll::Pending);
        initializer.init(1);
        assert_eq!(poll!(reader), Poll::Ready(Ok(1)));
        // Later calls are served from the memoized value.
        for _ in 0..2 {
            assert_eq!(delayed.get().await, Ok(1));
        }
    }

    #[tokio::test]
    async fn must_allow_concurrent_readers_in_any_order() {
        let _guard = setup_tracing();
        let (initializer, delayed) = DelayedInit::<u8>::new();
        let mut first = pin!(delayed.get());
        let mut second = pin!(delayed.get());
        let mut third = pin!(delayed.get());
        assert_eq!(poll!(first.as_mut()), Poll::Pending);
        assert_eq!(poll!(second.as_mut()), Poll::Pending);
        assert_eq!(poll!(third.as_mut()), Poll::Pending);
        initializer.init(1);
        // Resolve in reverse order to show completion does not depend on polling order.
        assert_eq!(poll!(third), Poll::Ready(Ok(1)));
        assert_eq!(poll!(second), Poll::Ready(Ok(1)));
        assert_eq!(poll!(first), Poll::Ready(Ok(1)));
    }
}