mz_persist_client/internal/
watch.rs1use std::sync::Arc;
13
14use mz_persist::location::SeqNo;
15use tokio::sync::broadcast;
16use tracing::debug;
17
18use crate::cache::LockingTypedState;
19use crate::internal::metrics::Metrics;
20
21#[derive(Debug)]
22pub struct StateWatchNotifier {
23 metrics: Arc<Metrics>,
24 tx: broadcast::Sender<SeqNo>,
25}
26
27impl StateWatchNotifier {
28 pub(crate) fn new(metrics: Arc<Metrics>) -> Self {
29 let (tx, _rx) = broadcast::channel(1);
30 StateWatchNotifier { metrics, tx }
31 }
32
33 pub(crate) fn notify(&self, seqno: SeqNo) {
44 match self.tx.send(seqno) {
45 Ok(_) => {
47 self.metrics.watch.notify_sent.inc();
48 }
49 Err(_) => {
51 self.metrics.watch.notify_noop.inc();
52 }
53 }
54 }
55}
56
57#[derive(Debug)]
70pub struct StateWatch<K, V, T, D> {
71 metrics: Arc<Metrics>,
72 state: Arc<LockingTypedState<K, V, T, D>>,
73 seqno_high_water: SeqNo,
74 rx: broadcast::Receiver<SeqNo>,
75}
76
77impl<K, V, T, D> StateWatch<K, V, T, D> {
78 pub(crate) fn new(state: Arc<LockingTypedState<K, V, T, D>>, metrics: Arc<Metrics>) -> Self {
79 let rx = state.notifier().tx.subscribe();
86 let seqno_high_water = state.read_lock(&metrics.locks.watch, |x| x.seqno);
87 StateWatch {
88 metrics,
89 state,
90 seqno_high_water,
91 rx,
92 }
93 }
94
95 pub async fn wait_for_seqno_ge(&mut self, requested: SeqNo) -> &mut Self {
99 self.metrics.watch.notify_wait_started.inc();
100 debug!("wait_for_seqno_ge {} {}", self.state.shard_id(), requested);
101 loop {
102 if self.seqno_high_water >= requested {
103 break;
104 }
105 match self.rx.recv().await {
106 Ok(x) => {
107 self.metrics.watch.notify_recv.inc();
108 assert!(x >= self.seqno_high_water);
109 self.seqno_high_water = x;
110 }
111 Err(broadcast::error::RecvError::Closed) => {
112 unreachable!("we're holding on to a reference to the sender")
113 }
114 Err(broadcast::error::RecvError::Lagged(_)) => {
115 self.metrics.watch.notify_lagged.inc();
116 continue;
122 }
123 }
124 }
125 self.metrics.watch.notify_wait_finished.inc();
126 debug!(
127 "wait_for_seqno_ge {} {} returning",
128 self.state.shard_id(),
129 requested
130 );
131 self
132 }
133}
134
#[cfg(test)]
mod tests {
    use std::future::Future;
    use std::pin::Pin;
    use std::task::Context;
    use std::time::Duration;

    use futures::FutureExt;
    use futures_task::noop_waker;
    use mz_build_info::DUMMY_BUILD_INFO;
    use mz_dyncfg::ConfigUpdates;
    use mz_ore::cast::CastFrom;
    use mz_ore::metrics::MetricsRegistry;
    use mz_ore::{assert_none, assert_ok};
    use timely::progress::Antichain;

    use crate::cache::StateCache;
    use crate::cfg::PersistConfig;
    use crate::internal::machine::{
        NEXT_LISTEN_BATCH_RETRYER_CLAMP, NEXT_LISTEN_BATCH_RETRYER_INITIAL_BACKOFF,
        NEXT_LISTEN_BATCH_RETRYER_MULTIPLIER,
    };
    use crate::internal::state::TypedState;
    use crate::tests::new_test_client;
    use crate::{Diagnostics, ShardId};

    use super::*;

    /// Single-watcher walkthrough: waits resolve exactly when the state's
    /// SeqNo has reached the requested value.
    #[mz_ore::test(tokio::test)]
    async fn state_watch() {
        mz_ore::test::init_logging();
        let metrics = Arc::new(Metrics::new(
            &PersistConfig::new_for_tests(),
            &MetricsRegistry::new(),
        ));
        let cache = StateCache::new_no_metrics();
        let shard_id = ShardId::new();
        let state = cache
            .get::<(), (), u64, i64, _, _>(
                shard_id,
                || async {
                    Ok(TypedState::new(
                        DUMMY_BUILD_INFO.semver_version(),
                        shard_id,
                        "host".to_owned(),
                        0u64,
                    ))
                },
                &Diagnostics::for_tests(),
            )
            .await
            .unwrap();
        // Freshly created state starts at SeqNo 0.
        assert_eq!(
            state.read_lock(&metrics.locks.watch, |typed| typed.seqno),
            SeqNo(0)
        );

        // Waiting for the current SeqNo completes right away.
        let mut w0 = StateWatch::new(Arc::clone(&state), Arc::clone(&metrics));
        let _ = w0.wait_for_seqno_ge(SeqNo(0)).await;

        // Waiting for SeqNo 1 is still pending, since nothing has advanced yet.
        let pending_seqno_1 = w0.wait_for_seqno_ge(SeqNo(1)).map(|_| ()).shared();
        assert_eq!(pending_seqno_1.clone().now_or_never(), None);

        // Advance the state under the write lock; the pending wait must then
        // complete.
        state.write_lock(&metrics.locks.applier_write, |typed| {
            typed.seqno = typed.seqno.next();
        });
        let () = pending_seqno_1.await;

        // A request for an already-passed SeqNo completes right away.
        let _ = w0.wait_for_seqno_ge(SeqNo(0)).await;

        // A watch created after the advance sees SeqNo 1 immediately, but not
        // SeqNo 2.
        let mut w1 = StateWatch::new(Arc::clone(&state), Arc::clone(&metrics));
        let _ = w1.wait_for_seqno_ge(SeqNo(0)).await;
        let _ = w1.wait_for_seqno_ge(SeqNo(1)).await;
        assert_none!(w1.wait_for_seqno_ge(SeqNo(2)).now_or_never());
    }

    /// Many watcher tasks racing against many writer tasks: every watcher
    /// must eventually wake up, and by then the state must be at least at the
    /// SeqNo it asked for. Skipped under Miri.
    #[mz_ore::test(tokio::test(flavor = "multi_thread"))]
    #[cfg_attr(miri, ignore)]
    async fn state_watch_concurrency() {
        mz_ore::test::init_logging();
        let metrics = Arc::new(Metrics::new(
            &PersistConfig::new_for_tests(),
            &MetricsRegistry::new(),
        ));
        let cache = StateCache::new_no_metrics();
        let shard_id = ShardId::new();
        let state = cache
            .get::<(), (), u64, i64, _, _>(
                shard_id,
                || async {
                    Ok(TypedState::new(
                        DUMMY_BUILD_INFO.semver_version(),
                        shard_id,
                        "host".to_owned(),
                        0u64,
                    ))
                },
                &Diagnostics::for_tests(),
            )
            .await
            .unwrap();
        assert_eq!(
            state.read_lock(&metrics.locks.watch, |typed| typed.seqno),
            SeqNo(0)
        );

        const NUM_WATCHES: usize = 100;
        const NUM_WRITES: usize = 20;

        // Spawn the watchers first; each waits for a SeqNo in 1..=NUM_WRITES.
        let watches: Vec<_> = (0..NUM_WATCHES)
            .map(|idx| {
                let state = Arc::clone(&state);
                let metrics = Arc::clone(&metrics);
                mz_ore::task::spawn(|| "watch", async move {
                    let mut watch = StateWatch::new(Arc::clone(&state), Arc::clone(&metrics));
                    let wait_seqno = SeqNo(u64::cast_from(idx % NUM_WRITES + 1));
                    let _ = watch.wait_for_seqno_ge(wait_seqno).await;
                    // Once woken, the real state must have caught up.
                    let observed_seqno =
                        state.read_lock(&metrics.locks.applier_read_noncacheable, |x| x.seqno);
                    assert!(
                        wait_seqno <= observed_seqno,
                        "{} vs {}",
                        wait_seqno,
                        observed_seqno
                    );
                })
            })
            .collect();
        // Then spawn the writers; each bumps the SeqNo exactly once.
        let writes: Vec<_> = (0..NUM_WRITES)
            .map(|_| {
                let state = Arc::clone(&state);
                let metrics = Arc::clone(&metrics);
                mz_ore::task::spawn(|| "write", async move {
                    state.write_lock(&metrics.locks.applier_write, |x| {
                        x.seqno = x.seqno.next();
                    });
                })
            })
            .collect();
        for watch in watches {
            assert_ok!(watch.await);
        }
        for write in writes {
            assert_ok!(write.await);
        }
    }

    /// End-to-end check that a pending listen (and snapshot) is woken by a
    /// write even though the listen retry backoff is set to a huge value.
    /// Skipped under Miri.
    #[mz_persist_proc::test(tokio::test)]
    #[cfg_attr(miri, ignore)]
    async fn state_watch_listen_snapshot(dyncfgs: ConfigUpdates) {
        mz_ore::test::init_logging();
        let waker = noop_waker();
        let mut cx = Context::from_waker(&waker);

        let client = new_test_client(&dyncfgs).await;
        // Push the listen retry backoff out to ~1M seconds (initial backoff and
        // clamp, multiplier 1) — presumably so that the listen below can only
        // make progress via the state watch rather than via retry polling.
        client.cfg.set_config(
            &NEXT_LISTEN_BATCH_RETRYER_INITIAL_BACKOFF,
            Duration::from_secs(1_000_000),
        );
        client
            .cfg
            .set_config(&NEXT_LISTEN_BATCH_RETRYER_MULTIPLIER, 1);
        client.cfg.set_config(
            &NEXT_LISTEN_BATCH_RETRYER_CLAMP,
            Duration::from_secs(1_000_000),
        );

        let (mut write, mut read) = client.expect_open::<(), (), u64, i64>(ShardId::new()).await;

        // Set up a listen and a snapshot, and poll each once by hand to confirm
        // that both are still pending before anything has been written.
        let mut listen = read
            .clone("test")
            .await
            .listen(Antichain::from_elem(0))
            .await
            .unwrap();
        let mut snapshot = Box::pin(read.snapshot(Antichain::from_elem(0)));
        assert!(Pin::new(&mut snapshot).poll(&mut cx).is_pending());
        let mut listen_next_batch = Box::pin(listen.next(None));
        assert!(Pin::new(&mut listen_next_batch).poll(&mut cx).is_pending());

        // Append an empty batch advancing the upper from 0 to 1; the pending
        // listen must now complete. If it were never woken, this await would
        // hang.
        write.expect_compare_and_append(&[], 0, 1).await;
        let _ = listen_next_batch.await;

        // The snapshot must complete as well.
        let _ = snapshot.await;
    }
}