
mz_persist_client/internal/watch.rs

// Copyright Materialize, Inc. and contributors. All rights reserved.
//
// Use of this software is governed by the Business Source License
// included in the LICENSE file.
//
// As of the Change Date specified in that file, in accordance with
// the Business Source License, use of this software will be governed
// by the Apache License, Version 2.0.

//! Notifications for state changes.

use mz_persist::location::SeqNo;
use std::fmt::{Debug, Formatter};
use std::ops::Deref;
use std::sync::{Arc, RwLock};
use tokio::sync::{Notify, broadcast};
use tracing::debug;

use crate::cache::LockingTypedState;
use crate::internal::metrics::Metrics;

#[derive(Debug)]
pub struct StateWatchNotifier {
    metrics: Arc<Metrics>,
    tx: broadcast::Sender<SeqNo>,
}

impl StateWatchNotifier {
    pub(crate) fn new(metrics: Arc<Metrics>) -> Self {
        let (tx, _rx) = broadcast::channel(1);
        StateWatchNotifier { metrics, tx }
    }

    /// Wake up any watchers of this state.
    ///
    /// This must be called while under the same lock that modified the state to
    /// avoid any potential for out of order SeqNos in the broadcast channel.
    ///
    /// This restriction can be lifted (i.e. we could notify after releasing the
    /// write lock), but we'd have to reason about out of order SeqNos in the
    /// broadcast channel. In particular, if we see `RecvError::Lagged` then
    /// it's possible we lost X+1 and got X, so if X isn't sufficient to return,
    /// we'd need to grab the read lock and verify the real SeqNo.
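    ///
    /// An illustrative sketch of the intended call pattern (hypothetical names;
    /// the real caller is expected to hold the state's write lock, per the
    /// requirement above):
    ///
    /// ```ignore
    /// // Assumption for the sketch: `guard` is the exclusive guard over the
    /// // state and `notifier` is the co-located StateWatchNotifier.
    /// guard.seqno = guard.seqno.next();
    /// notifier.notify(guard.seqno);
    /// // Only now is the write lock released.
    /// drop(guard);
    /// ```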
    pub(crate) fn notify(&self, seqno: SeqNo) {
        match self.tx.send(seqno) {
            // Someone got woken up.
            Ok(_) => {
                self.metrics.watch.notify_sent.inc();
            }
            // No one is listening, that's also fine.
            Err(_) => {
                self.metrics.watch.notify_noop.inc();
            }
        }
    }
}

/// A reactive subscription to changes in [LockingTypedState].
///
/// Invariants:
/// - The `state.seqno` only advances (never regresses). This is guaranteed by
///   LockingTypedState.
/// - `seqno_high_water` is always <= `state.seqno`.
/// - If `seqno_high_water` is < `state.seqno`, then we'll get a notification on
///   `rx`. This is maintained by notifying new seqnos under the same lock which
///   adds them.
/// - `seqno_high_water` always holds the highest value received in the channel.
///   This is maintained by `wait_for_seqno_ge` taking an exclusive reference to
///   `self`.
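///
/// A rough usage sketch (mirroring the tests at the bottom of this file;
/// `state` and `metrics` are assumed to come from the surrounding machinery):
///
/// ```ignore
/// let mut watch = StateWatch::new(Arc::clone(&state), Arc::clone(&metrics));
/// // Blocks (async, cancel-safely) until the state reaches at least SeqNo(1).
/// let _ = watch.wait_for_seqno_ge(SeqNo(1)).await;
/// ```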
#[derive(Debug)]
pub struct StateWatch<K, V, T, D> {
    metrics: Arc<Metrics>,
    state: Arc<LockingTypedState<K, V, T, D>>,
    seqno_high_water: SeqNo,
    rx: broadcast::Receiver<SeqNo>,
}

impl<K, V, T, D> StateWatch<K, V, T, D> {
    pub(crate) fn new(state: Arc<LockingTypedState<K, V, T, D>>, metrics: Arc<Metrics>) -> Self {
        // Important! We have to subscribe to the broadcast channel _before_ we
        // grab the current seqno. Otherwise, we could race with a write to
        // state and miss a notification. Tokio guarantees that "the returned
        // Receiver will receive values sent after the call to subscribe", and
        // the read_lock linearizes the subscribe to be _before_ whatever
        // seqno_high_water we get here.
        let rx = state.notifier().tx.subscribe();
        let seqno_high_water = state.read_lock(&metrics.locks.watch, |x| x.seqno);
        StateWatch {
            metrics,
            state,
            seqno_high_water,
            rx,
        }
    }

    /// Blocks until the State has a SeqNo >= the requested one.
    ///
    /// This method is cancel-safe.
    pub async fn wait_for_seqno_ge(&mut self, requested: SeqNo) -> &mut Self {
        self.metrics.watch.notify_wait_started.inc();
        debug!("wait_for_seqno_ge {} {}", self.state.shard_id(), requested);
        loop {
            if self.seqno_high_water >= requested {
                break;
            }
            match self.rx.recv().await {
                Ok(x) => {
                    self.metrics.watch.notify_recv.inc();
                    assert!(x >= self.seqno_high_water);
                    self.seqno_high_water = x;
                }
                Err(broadcast::error::RecvError::Closed) => {
                    unreachable!("we're holding on to a reference to the sender")
                }
                Err(broadcast::error::RecvError::Lagged(_)) => {
                    self.metrics.watch.notify_lagged.inc();
                    // This is just a hint that our buffer (of size 1) filled
                    // up, which is totally fine. The broadcast channel
                    // guarantees that the most recent N (again, =1 here) are
                    // kept, so just loop around. This branch means we should be
                    // able to read a new value immediately.
                    continue;
                }
            }
        }
        self.metrics.watch.notify_wait_finished.inc();
        debug!(
            "wait_for_seqno_ge {} {} returning",
            self.state.shard_id(),
            requested
        );
        self
    }
}

/// A concurrent state - one which allows reading, writing, and waiting for changes made by
/// another concurrent writer.
///
/// This is morally similar to a mutex with a condvar, but allowing asynchronous waits and with
/// access methods that make it a little trickier to accidentally hold a lock across a yield point.
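///
/// An illustrative sketch (it mirrors the `wait_on_awaitable_state` test below):
///
/// ```ignore
/// let state = AwaitableState::new(0u64);
/// // A writer bumps the value, waking any waiters...
/// state.modify(|v| *v += 1);
/// // ...and a reader can (asynchronously) block until the value is large enough.
/// state.wait_while(|v| *v < 1).await;
/// assert_eq!(state.read(|v| *v), 1);
/// ```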
pub(crate) struct AwaitableState<T> {
    state: Arc<RwLock<T>>,
    /// NB: we can't wrap the [Notify] in the lock since the signature of [Notify::notified]
    /// doesn't allow it, but this is only accessed while holding the lock.
    notify: Arc<Notify>,
}

impl<T: Debug> Debug for AwaitableState<T> {
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        self.state.read().fmt(f)
    }
}

impl<T> Clone for AwaitableState<T> {
    fn clone(&self) -> Self {
        Self {
            state: Arc::clone(&self.state),
            notify: Arc::clone(&self.notify),
        }
    }
}

/// A wrapper around a mutable ref that tracks whether it's ever accessed mutably. See
/// [AwaitableState::maybe_modify] for usage.
pub struct ModifyGuard<'a, T> {
    mut_ref: &'a mut T,
    modified: bool,
}

impl<'a, T> Deref for ModifyGuard<'a, T> {
    type Target = T;

    fn deref(&self) -> &Self::Target {
        &*self.mut_ref
    }
}

impl<'a, T> ModifyGuard<'a, T> {
    pub fn get_mut(&mut self) -> &mut T {
        self.modified = true;
        &mut *self.mut_ref
    }
}

impl<T> AwaitableState<T> {
    pub fn new(value: T) -> Self {
        Self {
            state: Arc::new(RwLock::new(value)),
            notify: Arc::new(Notify::new()),
        }
    }

    #[allow(dead_code)]
    pub fn read<A>(&self, read_fn: impl FnOnce(&T) -> A) -> A {
        let guard = self.state.read().expect("not poisoned");
        let state = &*guard;
        read_fn(state)
    }

    /// Conditionally modify the state. This method passes a guard to the provided function,
    /// which only allows mutable access to the data via [ModifyGuard::get_mut]. If that method
    /// is not called, waiters will not be woken up.
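    ///
    /// An illustrative sketch; only the branch that calls `get_mut` wakes waiters:
    ///
    /// ```ignore
    /// let state = AwaitableState::new(0u64);
    /// state.maybe_modify(|guard| {
    ///     if **guard < 10 {
    ///         // Mutable access: waiters are notified before the lock is released.
    ///         *guard.get_mut() += 1;
    ///     }
    ///     // Otherwise this was a read-only access, and no notification is sent.
    /// });
    /// ```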
    pub fn maybe_modify<A>(&self, write_fn: impl FnOnce(&mut ModifyGuard<T>) -> A) -> A {
        let mut guard = self.state.write().expect("not poisoned");
        let mut state = ModifyGuard {
            mut_ref: &mut *guard,
            modified: false,
        };
        let result = write_fn(&mut state);
        // Notify everyone while holding the guard. This guarantees that all waiters will observe
        // the just-updated state, assuming the state was accessed mutably.
        if state.modified {
            self.notify.notify_waiters();
        }
        drop(guard);
        result
    }

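    /// Modify the state, always waking any waiters. Equivalent to calling
    /// [Self::maybe_modify] with an unconditional [ModifyGuard::get_mut].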
    pub fn modify<A>(&self, write_fn: impl FnOnce(&mut T) -> A) -> A {
        self.maybe_modify(|guard| write_fn(guard.get_mut()))
    }

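    /// Wait until `wait_fn` returns `Some`, returning that value. The closure is
    /// re-checked under the read lock after every notification from a writer, so
    /// a modification made between checks is never missed.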
    pub async fn wait_for<A>(&self, mut wait_fn: impl FnMut(&T) -> Option<A>) -> A {
        loop {
            let notified = {
                let guard = self.state.read().expect("not poisoned");
                let state = &*guard;
                if let Some(result) = wait_fn(state) {
                    return result;
                }
                // Grab the notified future while holding the guard. This ensures that we will see any
                // future modifications to this state, even if they happen before the first poll.
                let notified = self.notify.notified();
                drop(guard);
                notified
            };

            notified.await;
        }
    }

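    /// Wait until `wait_fn` returns `false`. See [Self::wait_for].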
    pub async fn wait_while(&self, mut wait_fn: impl FnMut(&T) -> bool) {
        self.wait_for(|s| (!wait_fn(s)).then_some(())).await
    }
}

#[cfg(test)]
mod tests {
    use std::future::Future;
    use std::pin::Pin;
    use std::task::Context;
    use std::time::Duration;

    use futures::FutureExt;
    use futures_task::noop_waker;
    use itertools::Itertools;
    use mz_build_info::DUMMY_BUILD_INFO;
    use mz_dyncfg::ConfigUpdates;
    use mz_ore::assert_none;
    use mz_ore::cast::CastFrom;
    use mz_ore::metrics::MetricsRegistry;
    use rand::prelude::SliceRandom;
    use timely::progress::Antichain;
    use tokio::task::JoinSet;

    use crate::cache::StateCache;
    use crate::cfg::PersistConfig;
    use crate::internal::machine::{
        NEXT_LISTEN_BATCH_RETRYER_CLAMP, NEXT_LISTEN_BATCH_RETRYER_INITIAL_BACKOFF,
        NEXT_LISTEN_BATCH_RETRYER_MULTIPLIER,
    };
    use crate::internal::state::TypedState;
    use crate::tests::new_test_client;
    use crate::{Diagnostics, ShardId};

    use super::*;

    #[mz_ore::test(tokio::test)]
    async fn state_watch() {
        mz_ore::test::init_logging();
        let metrics = Arc::new(Metrics::new(
            &PersistConfig::new_for_tests(),
            &MetricsRegistry::new(),
        ));
        let cache = StateCache::new_no_metrics();
        let shard_id = ShardId::new();
        let state = cache
            .get::<(), (), u64, i64, _, _>(
                shard_id,
                || async {
                    Ok(TypedState::new(
                        DUMMY_BUILD_INFO.semver_version(),
                        shard_id,
                        "host".to_owned(),
                        0u64,
                    ))
                },
                &Diagnostics::for_tests(),
            )
            .await
            .unwrap();
        assert_eq!(state.read_lock(&metrics.locks.watch, |x| x.seqno), SeqNo(0));

        // A watch for 0 resolves immediately.
        let mut w0 = StateWatch::new(Arc::clone(&state), Arc::clone(&metrics));
        let _ = w0.wait_for_seqno_ge(SeqNo(0)).await;

        // A watch for 1 does not yet resolve.
        let w0s1 = w0.wait_for_seqno_ge(SeqNo(1)).map(|_| ()).shared();
        assert_eq!(w0s1.clone().now_or_never(), None);

        // After mutating state, the watch for 1 does resolve.
        state.write_lock(&metrics.locks.applier_write, |state| {
            state.seqno = state.seqno.next()
        });
        let () = w0s1.await;

        // A watch for an old seqno immediately resolves.
        let _ = w0.wait_for_seqno_ge(SeqNo(0)).await;

        // We can create a new watch and it also behaves.
        let mut w1 = StateWatch::new(Arc::clone(&state), Arc::clone(&metrics));
        let _ = w1.wait_for_seqno_ge(SeqNo(0)).await;
        let _ = w1.wait_for_seqno_ge(SeqNo(1)).await;
        assert_none!(w1.wait_for_seqno_ge(SeqNo(2)).now_or_never());
    }

    #[mz_ore::test(tokio::test(flavor = "multi_thread"))]
    #[cfg_attr(miri, ignore)] // error: unsupported operation: integer-to-pointer casts and `ptr::from_exposed_addr` are not supported with `-Zmiri-strict-provenance`
    async fn state_watch_concurrency() {
        mz_ore::test::init_logging();
        let metrics = Arc::new(Metrics::new(
            &PersistConfig::new_for_tests(),
            &MetricsRegistry::new(),
        ));
        let cache = StateCache::new_no_metrics();
        let shard_id = ShardId::new();
        let state = cache
            .get::<(), (), u64, i64, _, _>(
                shard_id,
                || async {
                    Ok(TypedState::new(
                        DUMMY_BUILD_INFO.semver_version(),
                        shard_id,
                        "host".to_owned(),
                        0u64,
                    ))
                },
                &Diagnostics::for_tests(),
            )
            .await
            .unwrap();
        assert_eq!(state.read_lock(&metrics.locks.watch, |x| x.seqno), SeqNo(0));

        const NUM_WATCHES: usize = 100;
        const NUM_WRITES: usize = 20;

        let watches = (0..NUM_WATCHES)
            .map(|idx| {
                let state = Arc::clone(&state);
                let metrics = Arc::clone(&metrics);
                mz_ore::task::spawn(|| "watch", async move {
                    let mut watch = StateWatch::new(Arc::clone(&state), Arc::clone(&metrics));
                    // We started at 0, so N writes means N+1 seqnos.
                    let wait_seqno = SeqNo(u64::cast_from(idx % NUM_WRITES + 1));
                    let _ = watch.wait_for_seqno_ge(wait_seqno).await;
                    let observed_seqno =
                        state.read_lock(&metrics.locks.applier_read_noncacheable, |x| x.seqno);
                    assert!(
                        wait_seqno <= observed_seqno,
                        "{} vs {}",
                        wait_seqno,
                        observed_seqno
                    );
                })
            })
            .collect::<Vec<_>>();
        let writes = (0..NUM_WRITES)
            .map(|_| {
                let state = Arc::clone(&state);
                let metrics = Arc::clone(&metrics);
                mz_ore::task::spawn(|| "write", async move {
                    state.write_lock(&metrics.locks.applier_write, |x| {
                        x.seqno = x.seqno.next();
                    });
                })
            })
            .collect::<Vec<_>>();
        for watch in watches {
            watch.await;
        }
        for write in writes {
            write.await;
        }
    }

    #[mz_persist_proc::test(tokio::test)]
    #[cfg_attr(miri, ignore)] // unsupported operation: returning ready events from epoll_wait is not yet implemented
    async fn state_watch_listen_snapshot(dyncfgs: ConfigUpdates) {
        mz_ore::test::init_logging();
        let waker = noop_waker();
        let mut cx = Context::from_waker(&waker);

        let client = new_test_client(&dyncfgs).await;
        // Override the listen poll so that it's useless.
        client.cfg.set_config(
            &NEXT_LISTEN_BATCH_RETRYER_INITIAL_BACKOFF,
            Duration::from_secs(1_000_000),
        );
        client
            .cfg
            .set_config(&NEXT_LISTEN_BATCH_RETRYER_MULTIPLIER, 1);
        client.cfg.set_config(
            &NEXT_LISTEN_BATCH_RETRYER_CLAMP,
            Duration::from_secs(1_000_000),
        );

        let (mut write, mut read) = client.expect_open::<(), (), u64, i64>(ShardId::new()).await;

        // Grab a snapshot for 0, which doesn't resolve yet. Also grab a listen
        // for 0, which resolves but doesn't yet resolve the next batch.
        let mut listen = read
            .clone("test")
            .await
            .listen(Antichain::from_elem(0))
            .await
            .unwrap();
        let mut snapshot = Box::pin(read.snapshot(Antichain::from_elem(0)));
        assert!(Pin::new(&mut snapshot).poll(&mut cx).is_pending());
        let mut listen_next_batch = Box::pin(listen.next(None));
        assert!(Pin::new(&mut listen_next_batch).poll(&mut cx).is_pending());

        // Now update the frontier, which should allow the snapshot to resolve
        // and the listen to resolve its next batch. Because we disabled the
        // polling, the listen_next_batch future will block forever and time out
        // the test if the watch doesn't work.
        write.expect_compare_and_append(&[], 0, 1).await;
        let _ = listen_next_batch.await;

        // For good measure, also resolve the snapshot, though we haven't broken
        // the polling on this.
        let _ = snapshot.await;
    }

    #[mz_ore::test(tokio::test(flavor = "multi_thread"))]
    #[allow(clippy::disallowed_methods)] // For JoinSet.
    #[cfg_attr(miri, ignore)]
    async fn wait_on_awaitable_state() {
        const TASKS: usize = 1000;
        // Launch a bunch of tasks in a random order, have them all wait for a specific number,
        // then increment it by one. Lost notifications would cause this test to time out.
        let mut set = JoinSet::new();
        let state = AwaitableState::new(0);
        let mut tasks = (0..TASKS).collect_vec();
        let mut rng = rand::rng();
        tasks.shuffle(&mut rng);
        for i in tasks {
            set.spawn({
                let state = state.clone();
                async move {
                    state.wait_while(|v| *v != i).await;
                    state.modify(|v| *v += 1);
                }
            });
        }
        set.join_all().await;
        assert_eq!(state.read(|i| *i), TASKS);
    }
469}