mz_persist_client/internal/
watch.rs

1// Copyright Materialize, Inc. and contributors. All rights reserved.
2//
3// Use of this software is governed by the Business Source License
4// included in the LICENSE file.
5//
6// As of the Change Date specified in that file, in accordance with
7// the Business Source License, use of this software will be governed
8// by the Apache License, Version 2.0.
9
10//! Notifications for state changes.
11
12use std::sync::Arc;
13
14use mz_persist::location::SeqNo;
15use tokio::sync::broadcast;
16use tracing::debug;
17
18use crate::cache::LockingTypedState;
19use crate::internal::metrics::Metrics;
20
/// The sending side of state-change notifications.
///
/// Each new SeqNo is broadcast over `tx`; a [StateWatch] receives them by
/// subscribing to that sender.
#[derive(Debug)]
pub struct StateWatchNotifier {
    /// Metrics handle; `notify` bumps the `watch.notify_sent` and
    /// `watch.notify_noop` counters on it.
    metrics: Arc<Metrics>,
    /// Broadcast sender carrying the SeqNo passed to each `notify` call.
    tx: broadcast::Sender<SeqNo>,
}
26
27impl StateWatchNotifier {
28    pub(crate) fn new(metrics: Arc<Metrics>) -> Self {
29        let (tx, _rx) = broadcast::channel(1);
30        StateWatchNotifier { metrics, tx }
31    }
32
33    /// Wake up any watchers of this state.
34    ///
35    /// This must be called while under the same lock that modified the state to
36    /// avoid any potential for out of order SeqNos in the broadcast channel.
37    ///
38    /// This restriction can be lifted (i.e. we could notify after releasing the
39    /// write lock), but we'd have to reason about out of order SeqNos in the
40    /// broadcast channel. In particular, if we see `RecvError::Lagged` then
41    /// it's possible we lost X+1 and got X, so if X isn't sufficient to return,
42    /// we'd need to grab the read lock and verify the real SeqNo.
43    pub(crate) fn notify(&self, seqno: SeqNo) {
44        match self.tx.send(seqno) {
45            // Someone got woken up.
46            Ok(_) => {
47                self.metrics.watch.notify_sent.inc();
48            }
49            // No one is listening, that's also fine.
50            Err(_) => {
51                self.metrics.watch.notify_noop.inc();
52            }
53        }
54    }
55}
56
/// A reactive subscription to changes in [LockingTypedState].
///
/// Invariants:
/// - The `state.seqno` only advances (never regresses). This is guaranteed by
///   LockingTypedState.
/// - `seqno_high_water` is always <= `state.seqno`.
/// - If `seqno_high_water` is < `state.seqno`, then we'll get a notification on
///   `rx`. This is maintained by notifying new seqnos under the same lock which
///   adds them.
/// - `seqno_high_water` always holds the highest value received in the channel.
///   This is maintained by `wait_for_seqno_ge` taking an exclusive reference to
///   self.
#[derive(Debug)]
pub struct StateWatch<K, V, T, D> {
    /// Metrics handle; the `watch.*` counters and the `locks.watch` lock
    /// metrics are read from it.
    metrics: Arc<Metrics>,
    /// The state being watched.
    state: Arc<LockingTypedState<K, V, T, D>>,
    /// The highest SeqNo this watch has observed so far, either from the
    /// initial read of the state or from `rx`.
    seqno_high_water: SeqNo,
    /// Receiver subscribed to the state's notifier.
    rx: broadcast::Receiver<SeqNo>,
}
76
77impl<K, V, T, D> StateWatch<K, V, T, D> {
78    pub(crate) fn new(state: Arc<LockingTypedState<K, V, T, D>>, metrics: Arc<Metrics>) -> Self {
79        // Important! We have to subscribe to the broadcast channel _before_ we
80        // grab the current seqno. Otherwise, we could race with a write to
81        // state and miss a notification. Tokio guarantees that "the returned
82        // Receiver will receive values sent after the call to subscribe", and
83        // the read_lock linearizes the subscribe to be _before_ whatever
84        // seqno_high_water we get here.
85        let rx = state.notifier().tx.subscribe();
86        let seqno_high_water = state.read_lock(&metrics.locks.watch, |x| x.seqno);
87        StateWatch {
88            metrics,
89            state,
90            seqno_high_water,
91            rx,
92        }
93    }
94
95    /// Blocks until the State has a SeqNo >= the requested one.
96    ///
97    /// This method is cancel-safe.
98    pub async fn wait_for_seqno_ge(&mut self, requested: SeqNo) -> &mut Self {
99        self.metrics.watch.notify_wait_started.inc();
100        debug!("wait_for_seqno_ge {} {}", self.state.shard_id(), requested);
101        loop {
102            if self.seqno_high_water >= requested {
103                break;
104            }
105            match self.rx.recv().await {
106                Ok(x) => {
107                    self.metrics.watch.notify_recv.inc();
108                    assert!(x >= self.seqno_high_water);
109                    self.seqno_high_water = x;
110                }
111                Err(broadcast::error::RecvError::Closed) => {
112                    unreachable!("we're holding on to a reference to the sender")
113                }
114                Err(broadcast::error::RecvError::Lagged(_)) => {
115                    self.metrics.watch.notify_lagged.inc();
116                    // This is just a hint that our buffer (of size 1) filled
117                    // up, which is totally fine. The broadcast channel
118                    // guarantees that the most recent N (again, =1 here) are
119                    // kept, so just loop around. This branch means we should be
120                    // able to read a new value immediately.
121                    continue;
122                }
123            }
124        }
125        self.metrics.watch.notify_wait_finished.inc();
126        debug!(
127            "wait_for_seqno_ge {} {} returning",
128            self.state.shard_id(),
129            requested
130        );
131        self
132    }
133}
134
// Tests for `StateWatch`: a single-watch walkthrough, a concurrent
// watchers-vs-writers check, and an end-to-end listen/snapshot wakeup.
#[cfg(test)]
mod tests {
    use std::future::Future;
    use std::pin::Pin;
    use std::task::Context;
    use std::time::Duration;

    use futures::FutureExt;
    use futures_task::noop_waker;
    use mz_build_info::DUMMY_BUILD_INFO;
    use mz_dyncfg::ConfigUpdates;
    use mz_ore::cast::CastFrom;
    use mz_ore::metrics::MetricsRegistry;
    use mz_ore::{assert_none, assert_ok};
    use timely::progress::Antichain;

    use crate::cache::StateCache;
    use crate::cfg::PersistConfig;
    use crate::internal::machine::{
        NEXT_LISTEN_BATCH_RETRYER_CLAMP, NEXT_LISTEN_BATCH_RETRYER_INITIAL_BACKOFF,
        NEXT_LISTEN_BATCH_RETRYER_MULTIPLIER,
    };
    use crate::internal::state::TypedState;
    use crate::tests::new_test_client;
    use crate::{Diagnostics, ShardId};

    use super::*;

    /// Walks a single state from SeqNo(0) to SeqNo(1) and checks that waits
    /// for already-reached seqnos resolve immediately while a wait for a
    /// future seqno only resolves once the state has been advanced.
    #[mz_ore::test(tokio::test)]
    async fn state_watch() {
        mz_ore::test::init_logging();
        let metrics = Arc::new(Metrics::new(
            &PersistConfig::new_for_tests(),
            &MetricsRegistry::new(),
        ));
        let cache = StateCache::new_no_metrics();
        let shard_id = ShardId::new();
        let state = cache
            .get::<(), (), u64, i64, _, _>(
                shard_id,
                || async {
                    Ok(TypedState::new(
                        DUMMY_BUILD_INFO.semver_version(),
                        shard_id,
                        "host".to_owned(),
                        0u64,
                    ))
                },
                &Diagnostics::for_tests(),
            )
            .await
            .unwrap();
        // Sanity check: the freshly created state starts at SeqNo(0).
        assert_eq!(state.read_lock(&metrics.locks.watch, |x| x.seqno), SeqNo(0));

        // A watch for 0 resolves immediately.
        let mut w0 = StateWatch::new(Arc::clone(&state), Arc::clone(&metrics));
        let _ = w0.wait_for_seqno_ge(SeqNo(0)).await;

        // A watch for 1 does not yet resolve. The future is made `shared` so
        // a clone can be polled once here and the original awaited below.
        let w0s1 = w0.wait_for_seqno_ge(SeqNo(1)).map(|_| ()).shared();
        assert_eq!(w0s1.clone().now_or_never(), None);

        // After mutating state, the watch for 1 does resolve.
        state.write_lock(&metrics.locks.applier_write, |state| {
            state.seqno = state.seqno.next()
        });
        let () = w0s1.await;

        // A watch for an old seqno immediately resolves.
        let _ = w0.wait_for_seqno_ge(SeqNo(0)).await;

        // We can create a new watch and it also behaves.
        let mut w1 = StateWatch::new(Arc::clone(&state), Arc::clone(&metrics));
        let _ = w1.wait_for_seqno_ge(SeqNo(0)).await;
        let _ = w1.wait_for_seqno_ge(SeqNo(1)).await;
        assert_none!(w1.wait_for_seqno_ge(SeqNo(2)).now_or_never());
    }

    /// Spawns many watcher tasks and several writer tasks concurrently and
    /// checks that every watcher, once its wait returns, observes a state
    /// seqno at least as large as the one it waited for.
    #[mz_ore::test(tokio::test(flavor = "multi_thread"))]
    #[cfg_attr(miri, ignore)] // error: unsupported operation: integer-to-pointer casts and `ptr::from_exposed_addr` are not supported with `-Zmiri-strict-provenance`
    async fn state_watch_concurrency() {
        mz_ore::test::init_logging();
        let metrics = Arc::new(Metrics::new(
            &PersistConfig::new_for_tests(),
            &MetricsRegistry::new(),
        ));
        let cache = StateCache::new_no_metrics();
        let shard_id = ShardId::new();
        let state = cache
            .get::<(), (), u64, i64, _, _>(
                shard_id,
                || async {
                    Ok(TypedState::new(
                        DUMMY_BUILD_INFO.semver_version(),
                        shard_id,
                        "host".to_owned(),
                        0u64,
                    ))
                },
                &Diagnostics::for_tests(),
            )
            .await
            .unwrap();
        // Sanity check: the freshly created state starts at SeqNo(0).
        assert_eq!(state.read_lock(&metrics.locks.watch, |x| x.seqno), SeqNo(0));

        const NUM_WATCHES: usize = 100;
        const NUM_WRITES: usize = 20;

        let watches = (0..NUM_WATCHES)
            .map(|idx| {
                let state = Arc::clone(&state);
                let metrics = Arc::clone(&metrics);
                mz_ore::task::spawn(|| "watch", async move {
                    let mut watch = StateWatch::new(Arc::clone(&state), Arc::clone(&metrics));
                    // We started at 0, so N writes means N+1 seqnos.
                    let wait_seqno = SeqNo(u64::cast_from(idx % NUM_WRITES + 1));
                    let _ = watch.wait_for_seqno_ge(wait_seqno).await;
                    // Once woken, the real state must be at least as far
                    // along as the seqno we waited for.
                    let observed_seqno =
                        state.read_lock(&metrics.locks.applier_read_noncacheable, |x| x.seqno);
                    assert!(
                        wait_seqno <= observed_seqno,
                        "{} vs {}",
                        wait_seqno,
                        observed_seqno
                    );
                })
            })
            .collect::<Vec<_>>();
        // Each writer task advances the seqno by exactly one.
        let writes = (0..NUM_WRITES)
            .map(|_| {
                let state = Arc::clone(&state);
                let metrics = Arc::clone(&metrics);
                mz_ore::task::spawn(|| "write", async move {
                    state.write_lock(&metrics.locks.applier_write, |x| {
                        x.seqno = x.seqno.next();
                    });
                })
            })
            .collect::<Vec<_>>();
        for watch in watches {
            assert_ok!(watch.await);
        }
        for write in writes {
            assert_ok!(write.await);
        }
    }

    /// End-to-end check that a pending listen is woken up by a write, with
    /// the listen retry backoff configured so large that polling alone would
    /// not resolve it before the test times out.
    #[mz_persist_proc::test(tokio::test)]
    #[cfg_attr(miri, ignore)] // unsupported operation: returning ready events from epoll_wait is not yet implemented
    async fn state_watch_listen_snapshot(dyncfgs: ConfigUpdates) {
        mz_ore::test::init_logging();
        // A no-op waker lets us poll futures by hand to check they're pending.
        let waker = noop_waker();
        let mut cx = Context::from_waker(&waker);

        let client = new_test_client(&dyncfgs).await;
        // Override the listen poll so that it's useless.
        client.cfg.set_config(
            &NEXT_LISTEN_BATCH_RETRYER_INITIAL_BACKOFF,
            Duration::from_secs(1_000_000),
        );
        client
            .cfg
            .set_config(&NEXT_LISTEN_BATCH_RETRYER_MULTIPLIER, 1);
        client.cfg.set_config(
            &NEXT_LISTEN_BATCH_RETRYER_CLAMP,
            Duration::from_secs(1_000_000),
        );

        let (mut write, mut read) = client.expect_open::<(), (), u64, i64>(ShardId::new()).await;

        // Grab a snapshot for 1, which doesn't resolve yet. Also grab a listen
        // for 0, which resolves but doesn't yet resolve the next batch.
        // NOTE(review): the comment above says "snapshot for 1", but the
        // snapshot below is requested as of 0 — confirm which is intended.
        let mut listen = read
            .clone("test")
            .await
            .listen(Antichain::from_elem(0))
            .await
            .unwrap();
        let mut snapshot = Box::pin(read.snapshot(Antichain::from_elem(0)));
        assert!(Pin::new(&mut snapshot).poll(&mut cx).is_pending());
        let mut listen_next_batch = Box::pin(listen.next(None));
        assert!(Pin::new(&mut listen_next_batch).poll(&mut cx).is_pending());

        // Now update the frontier, which should allow the snapshot to resolve
        // and the listen to resolve its next batch. Because we disabled the
        // polling, the listen_next_batch future will block forever and timeout
        // the test if the watch doesn't work.
        write.expect_compare_and_append(&[], 0, 1).await;
        let _ = listen_next_batch.await;

        // For good measure, also resolve the snapshot, though we haven't broken
        // the polling on this.
        let _ = snapshot.await;
    }
}