// mz_persist_client/internal/watch.rs

use mz_persist::location::SeqNo;
13use std::fmt::{Debug, Formatter};
14use std::ops::Deref;
15use std::sync::{Arc, RwLock};
16use tokio::sync::{Notify, broadcast};
17use tracing::debug;
18
19use crate::cache::LockingTypedState;
20use crate::internal::metrics::Metrics;
21
/// Fan-out point for seqno-advance notifications: `notify` broadcasts each
/// new [SeqNo] to every subscribed [StateWatch].
#[derive(Debug)]
pub struct StateWatchNotifier {
    metrics: Arc<Metrics>,
    /// Sender half of the broadcast channel; each `StateWatch` holds a
    /// receiver obtained via `tx.subscribe()`.
    tx: broadcast::Sender<SeqNo>,
}
27
28impl StateWatchNotifier {
29 pub(crate) fn new(metrics: Arc<Metrics>) -> Self {
30 let (tx, _rx) = broadcast::channel(1);
31 StateWatchNotifier { metrics, tx }
32 }
33
34 pub(crate) fn notify(&self, seqno: SeqNo) {
45 match self.tx.send(seqno) {
46 Ok(_) => {
48 self.metrics.watch.notify_sent.inc();
49 }
50 Err(_) => {
52 self.metrics.watch.notify_noop.inc();
53 }
54 }
55 }
56}
57
/// A handle for asynchronously waiting until a shard's state reaches at
/// least some [SeqNo].
#[derive(Debug)]
pub struct StateWatch<K, V, T, D> {
    metrics: Arc<Metrics>,
    state: Arc<LockingTypedState<K, V, T, D>>,
    /// Highest seqno observed so far: initialized from the state in `new`,
    /// advanced by notifications received on `rx`.
    seqno_high_water: SeqNo,
    /// Receiver subscribed to the state's [StateWatchNotifier].
    rx: broadcast::Receiver<SeqNo>,
}
77
impl<K, V, T, D> StateWatch<K, V, T, D> {
    /// Returns a new watch over `state`'s seqno.
    pub(crate) fn new(state: Arc<LockingTypedState<K, V, T, D>>, metrics: Arc<Metrics>) -> Self {
        // Subscribe BEFORE reading the current seqno. Any seqno advance that
        // happens after the read below is then guaranteed to be delivered on
        // `rx`, so the watch cannot miss a notification.
        let rx = state.notifier().tx.subscribe();
        let seqno_high_water = state.read_lock(&metrics.locks.watch, |x| x.seqno);
        StateWatch {
            metrics,
            state,
            seqno_high_water,
            rx,
        }
    }

    /// Blocks until this shard's state seqno is at least `requested`.
    ///
    /// Returns `&mut self` so callers can chain further waits.
    pub async fn wait_for_seqno_ge(&mut self, requested: SeqNo) -> &mut Self {
        self.metrics.watch.notify_wait_started.inc();
        debug!("wait_for_seqno_ge {} {}", self.state.shard_id(), requested);
        loop {
            if self.seqno_high_water >= requested {
                break;
            }
            match self.rx.recv().await {
                Ok(x) => {
                    self.metrics.watch.notify_recv.inc();
                    // NOTE(review): this assumes notifications arrive in
                    // non-decreasing seqno order and that no notification
                    // sent before the seqno read in `new` is still buffered
                    // for this receiver — TODO confirm `notify` is only
                    // called with monotonically increasing seqnos, under a
                    // lock ordered with the read in `new`.
                    assert!(x >= self.seqno_high_water);
                    self.seqno_high_water = x;
                }
                Err(broadcast::error::RecvError::Closed) => {
                    // The sender lives in the notifier owned by `self.state`,
                    // to which we hold an Arc, so it cannot have been dropped.
                    unreachable!("we're holding on to a reference to the sender")
                }
                Err(broadcast::error::RecvError::Lagged(_)) => {
                    // The channel has capacity 1, so falling behind is
                    // normal: we merely missed intermediate seqnos. The next
                    // successful recv returns a newer one, so just retry.
                    self.metrics.watch.notify_lagged.inc();
                    continue;
                }
            }
        }
        self.metrics.watch.notify_wait_finished.inc();
        debug!(
            "wait_for_seqno_ge {} {} returning",
            self.state.shard_id(),
            requested
        );
        self
    }
}
135
/// Shared mutable state of type `T` that tasks can asynchronously wait on.
pub(crate) struct AwaitableState<T> {
    state: Arc<RwLock<T>>,
    /// Woken (via `notify_waiters`) whenever `state` is mutated through
    /// `modify`/`maybe_modify`.
    notify: Arc<Notify>,
}
147
148impl<T: Debug> Debug for AwaitableState<T> {
149 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
150 self.state.read().fmt(f)
151 }
152}
153
154impl<T> Clone for AwaitableState<T> {
155 fn clone(&self) -> Self {
156 Self {
157 state: Arc::clone(&self.state),
158 notify: Arc::clone(&self.notify),
159 }
160 }
161}
162
/// Wrapper handed to `AwaitableState::maybe_modify` closures that records
/// whether the closure actually requested mutable access to the value.
pub struct ModifyGuard<'a, T> {
    mut_ref: &'a mut T,
    /// Set to `true` by `get_mut`; read by `maybe_modify` to decide whether
    /// waiters need to be notified.
    modified: bool,
}
169
170impl<'a, T> Deref for ModifyGuard<'a, T> {
171 type Target = T;
172
173 fn deref(&self) -> &Self::Target {
174 &*self.mut_ref
175 }
176}
177
178impl<'a, T> ModifyGuard<'a, T> {
179 pub fn get_mut(&mut self) -> &mut T {
180 self.modified = true;
181 &mut *self.mut_ref
182 }
183}
184
185impl<T> AwaitableState<T> {
186 pub fn new(value: T) -> Self {
187 Self {
188 state: Arc::new(RwLock::new(value)),
189 notify: Arc::new(Notify::new()),
190 }
191 }
192
193 #[allow(dead_code)]
194 pub fn read<A>(&self, read_fn: impl FnOnce(&T) -> A) -> A {
195 let guard = self.state.read().expect("not poisoned");
196 let state = &*guard;
197 read_fn(state)
198 }
199
200 pub fn maybe_modify<A>(&self, write_fn: impl FnOnce(&mut ModifyGuard<T>) -> A) -> A {
204 let mut guard = self.state.write().expect("not poisoned");
205 let mut state = ModifyGuard {
206 mut_ref: &mut *guard,
207 modified: false,
208 };
209 let result = write_fn(&mut state);
210 if state.modified {
213 self.notify.notify_waiters();
214 }
215 drop(guard);
216 result
217 }
218
219 pub fn modify<A>(&self, write_fn: impl FnOnce(&mut T) -> A) -> A {
220 self.maybe_modify(|guard| write_fn(guard.get_mut()))
221 }
222
223 pub async fn wait_for<A>(&self, mut wait_fn: impl FnMut(&T) -> Option<A>) -> A {
224 loop {
225 let notified = {
226 let guard = self.state.read().expect("not poisoned");
227 let state = &*guard;
228 if let Some(result) = wait_fn(state) {
229 return result;
230 }
231 let notified = self.notify.notified();
234 drop(guard);
235 notified
236 };
237
238 notified.await;
239 }
240 }
241
242 pub async fn wait_while(&self, mut wait_fn: impl FnMut(&T) -> bool) {
243 self.wait_for(|s| (!wait_fn(s)).then_some(())).await
244 }
245}
246
#[cfg(test)]
mod tests {
    use std::future::Future;
    use std::pin::Pin;
    use std::task::Context;
    use std::time::Duration;

    use futures::FutureExt;
    use futures_task::noop_waker;
    use itertools::Itertools;
    use mz_build_info::DUMMY_BUILD_INFO;
    use mz_dyncfg::ConfigUpdates;
    use mz_ore::assert_none;
    use mz_ore::cast::CastFrom;
    use mz_ore::metrics::MetricsRegistry;
    use rand::prelude::SliceRandom;
    use timely::progress::Antichain;
    use tokio::task::JoinSet;

    use crate::cache::StateCache;
    use crate::cfg::PersistConfig;
    use crate::internal::machine::{
        NEXT_LISTEN_BATCH_RETRYER_CLAMP, NEXT_LISTEN_BATCH_RETRYER_INITIAL_BACKOFF,
        NEXT_LISTEN_BATCH_RETRYER_MULTIPLIER,
    };
    use crate::internal::state::TypedState;
    use crate::tests::new_test_client;
    use crate::{Diagnostics, ShardId};

    use super::*;

    /// Basic sanity check of [StateWatch]: an already-satisfied wait resolves
    /// immediately, an unsatisfied one stays pending until the seqno
    /// advances, and a watch created after an advance starts at the new
    /// seqno.
    #[mz_ore::test(tokio::test)]
    async fn state_watch() {
        mz_ore::test::init_logging();
        let metrics = Arc::new(Metrics::new(
            &PersistConfig::new_for_tests(),
            &MetricsRegistry::new(),
        ));
        let cache = StateCache::new_no_metrics();
        let shard_id = ShardId::new();
        let state = cache
            .get::<(), (), u64, i64, _, _>(
                shard_id,
                || async {
                    Ok(TypedState::new(
                        DUMMY_BUILD_INFO.semver_version(),
                        shard_id,
                        "host".to_owned(),
                        0u64,
                    ))
                },
                &Diagnostics::for_tests(),
            )
            .await
            .unwrap();
        assert_eq!(state.read_lock(&metrics.locks.watch, |x| x.seqno), SeqNo(0));

        // A wait for the current seqno (0) resolves immediately.
        let mut w0 = StateWatch::new(Arc::clone(&state), Arc::clone(&metrics));
        let _ = w0.wait_for_seqno_ge(SeqNo(0)).await;

        // A wait for a future seqno is initially pending...
        let w0s1 = w0.wait_for_seqno_ge(SeqNo(1)).map(|_| ()).shared();
        assert_eq!(w0s1.clone().now_or_never(), None);

        // ...and resolves once the state's seqno advances.
        state.write_lock(&metrics.locks.applier_write, |state| {
            state.seqno = state.seqno.next()
        });
        let () = w0s1.await;

        // A wait for an already-passed seqno still resolves immediately.
        let _ = w0.wait_for_seqno_ge(SeqNo(0)).await;

        // A brand new watch starts from the current seqno (now 1), so only a
        // wait for seqno 2 is pending.
        let mut w1 = StateWatch::new(Arc::clone(&state), Arc::clone(&metrics));
        let _ = w1.wait_for_seqno_ge(SeqNo(0)).await;
        let _ = w1.wait_for_seqno_ge(SeqNo(1)).await;
        assert_none!(w1.wait_for_seqno_ge(SeqNo(2)).now_or_never());
    }

    /// Races many concurrent watches against many concurrent seqno advances:
    /// whenever a watch for seqno N resolves, the state's seqno must already
    /// be >= N.
    #[mz_ore::test(tokio::test(flavor = "multi_thread"))]
    #[cfg_attr(miri, ignore)] // too slow under miri
    async fn state_watch_concurrency() {
        mz_ore::test::init_logging();
        let metrics = Arc::new(Metrics::new(
            &PersistConfig::new_for_tests(),
            &MetricsRegistry::new(),
        ));
        let cache = StateCache::new_no_metrics();
        let shard_id = ShardId::new();
        let state = cache
            .get::<(), (), u64, i64, _, _>(
                shard_id,
                || async {
                    Ok(TypedState::new(
                        DUMMY_BUILD_INFO.semver_version(),
                        shard_id,
                        "host".to_owned(),
                        0u64,
                    ))
                },
                &Diagnostics::for_tests(),
            )
            .await
            .unwrap();
        assert_eq!(state.read_lock(&metrics.locks.watch, |x| x.seqno), SeqNo(0));

        const NUM_WATCHES: usize = 100;
        const NUM_WRITES: usize = 20;

        let watches = (0..NUM_WATCHES)
            .map(|idx| {
                let state = Arc::clone(&state);
                let metrics = Arc::clone(&metrics);
                mz_ore::task::spawn(|| "watch", async move {
                    let mut watch = StateWatch::new(Arc::clone(&state), Arc::clone(&metrics));
                    // Each watch waits for one of the seqnos 1..=NUM_WRITES,
                    // all of which will eventually be reached by the writes.
                    let wait_seqno = SeqNo(u64::cast_from(idx % NUM_WRITES + 1));
                    let _ = watch.wait_for_seqno_ge(wait_seqno).await;
                    let observed_seqno =
                        state.read_lock(&metrics.locks.applier_read_noncacheable, |x| x.seqno);
                    assert!(
                        wait_seqno <= observed_seqno,
                        "{} vs {}",
                        wait_seqno,
                        observed_seqno
                    );
                })
            })
            .collect::<Vec<_>>();
        let writes = (0..NUM_WRITES)
            .map(|_| {
                let state = Arc::clone(&state);
                let metrics = Arc::clone(&metrics);
                mz_ore::task::spawn(|| "write", async move {
                    state.write_lock(&metrics.locks.applier_write, |x| {
                        x.seqno = x.seqno.next();
                    });
                })
            })
            .collect::<Vec<_>>();
        for watch in watches {
            watch.await;
        }
        for write in writes {
            write.await;
        }
    }

    /// Verifies that a pending listen and a pending snapshot are woken by a
    /// write. The listen retry backoffs are set absurdly high so that any
    /// wakeup must come from the state watch, not from retry polling.
    #[mz_persist_proc::test(tokio::test)]
    #[cfg_attr(miri, ignore)]
    async fn state_watch_listen_snapshot(dyncfgs: ConfigUpdates) {
        mz_ore::test::init_logging();
        let waker = noop_waker();
        let mut cx = Context::from_waker(&waker);

        let client = new_test_client(&dyncfgs).await;
        // Effectively disable the listen retry loop (~11.5 day backoff).
        client.cfg.set_config(
            &NEXT_LISTEN_BATCH_RETRYER_INITIAL_BACKOFF,
            Duration::from_secs(1_000_000),
        );
        client
            .cfg
            .set_config(&NEXT_LISTEN_BATCH_RETRYER_MULTIPLIER, 1);
        client.cfg.set_config(
            &NEXT_LISTEN_BATCH_RETRYER_CLAMP,
            Duration::from_secs(1_000_000),
        );

        let (mut write, mut read) = client.expect_open::<(), (), u64, i64>(ShardId::new()).await;

        let mut listen = read
            .clone("test")
            .await
            .listen(Antichain::from_elem(0))
            .await
            .unwrap();
        // Neither the snapshot as of 0 nor the first listen batch can
        // resolve yet: poll both once (with a no-op waker) to confirm.
        let mut snapshot = Box::pin(read.snapshot(Antichain::from_elem(0)));
        assert!(Pin::new(&mut snapshot).poll(&mut cx).is_pending());
        let mut listen_next_batch = Box::pin(listen.next(None));
        assert!(Pin::new(&mut listen_next_batch).poll(&mut cx).is_pending());

        // An empty write advancing the upper past 0 unblocks both futures.
        write.expect_compare_and_append(&[], 0, 1).await;
        let _ = listen_next_batch.await;

        let _ = snapshot.await;
    }

    /// Spawns TASKS tasks where task `i` waits until the shared counter
    /// reaches `i` and then increments it, forcing the increments to happen
    /// in exact counter order regardless of spawn/scheduling order.
    #[mz_ore::test(tokio::test(flavor = "multi_thread"))]
    #[allow(clippy::disallowed_methods)]
    #[cfg_attr(miri, ignore)]
    async fn wait_on_awaitable_state() {
        const TASKS: usize = 1000;
        let mut set = JoinSet::new();
        let state = AwaitableState::new(0);
        // NOTE(review): `tasks` is shuffled but never read afterwards; the
        // spawn order below comes from the `(0..TASKS).rev()` range. This
        // looks like dead code — confirm whether the loop was meant to
        // iterate over the shuffled `tasks`.
        let mut tasks = (0..TASKS).collect_vec();
        let mut rng = rand::rng();
        tasks.shuffle(&mut rng);
        for i in (0..TASKS).rev() {
            set.spawn({
                let state = state.clone();
                async move {
                    state.wait_while(|v| *v != i).await;
                    state.modify(|v| *v += 1);
                }
            });
        }
        set.join_all().await;
        assert_eq!(state.read(|i| *i), TASKS);
    }
}