1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
// Copyright Materialize, Inc. and contributors. All rights reserved.
//
// Use of this software is governed by the Business Source License
// included in the LICENSE file.
//
// As of the Change Date specified in that file, in accordance with
// the Business Source License, use of this software will be governed
// by the Apache License, Version 2.0.

//! Notifications for state changes.

use std::sync::Arc;

use mz_persist::location::SeqNo;
use tokio::sync::broadcast;
use tracing::debug;

use crate::cache::LockingTypedState;
use crate::internal::metrics::Metrics;

/// The sending side of the state-change notification channel.
///
/// `notify` publishes new [SeqNo]s on `tx`, and `StateWatch::new` subscribes to
/// `tx` to receive them.
#[derive(Debug)]
pub struct StateWatchNotifier {
    // Used by `notify` to count sent vs. no-op notifications.
    metrics: Arc<Metrics>,
    // Created with a capacity of 1 in `new`.
    tx: broadcast::Sender<SeqNo>,
}

impl StateWatchNotifier {
    /// Builds a notifier that initially has no subscribers.
    ///
    /// The backing broadcast channel is created with a capacity of 1; watchers
    /// attach to it later by subscribing to `tx`.
    pub(crate) fn new(metrics: Arc<Metrics>) -> Self {
        // Only the sending half is retained; the receiver created alongside it
        // is dropped when this function returns.
        let (tx, _rx) = broadcast::channel(1);
        Self { metrics, tx }
    }

    /// Publishes `seqno` to everything currently watching this state.
    ///
    /// Callers must invoke this while still holding the lock under which the
    /// state was modified. Doing so rules out SeqNos arriving out of order in
    /// the broadcast channel.
    ///
    /// Notifying after the write lock is released would be possible, but then
    /// out of order SeqNos in the channel would have to be reasoned about. In
    /// particular, on `RecvError::Lagged` a watcher might have lost X+1 and
    /// received X; if X were not enough to return, it would have to take the
    /// read lock and check the real SeqNo.
    pub(crate) fn notify(&self, seqno: SeqNo) {
        let watch_metrics = &self.metrics.watch;
        if self.tx.send(seqno).is_ok() {
            // The send reached at least one subscriber.
            watch_metrics.notify_sent.inc();
        } else {
            // Nobody is subscribed right now, which is perfectly fine.
            watch_metrics.notify_noop.inc();
        }
    }
}

/// A reactive subscription to changes in [LockingTypedState].
///
/// Invariants:
/// - The `state.seqno` only advances (never regresses). This is guaranteed by
///   LockingTypedState.
/// - `seqno_high_water` is always <= `state.seqno`.
/// - If `seqno_high_water` is < `state.seqno`, then we'll get a notification on
///   `rx`. This is maintained by notifying new seqnos under the same lock which
///   adds them.
/// - `seqno_high_water` always holds the highest value received in the channel.
///   This is maintained by `wait_for_seqno_ge` taking an exclusive reference to
///   self.
#[derive(Debug)]
pub struct StateWatch<K, V, T, D> {
    metrics: Arc<Metrics>,
    state: Arc<LockingTypedState<K, V, T, D>>,
    // Initialized from `state.seqno` in `new`, then overwritten by each value
    // received on `rx`.
    seqno_high_water: SeqNo,
    // Subscribed in `new` _before_ the initial seqno is read.
    rx: broadcast::Receiver<SeqNo>,
}

impl<K, V, T, D> StateWatch<K, V, T, D> {
    /// Creates a watch over `state`, starting from the state's current SeqNo.
    pub(crate) fn new(state: Arc<LockingTypedState<K, V, T, D>>, metrics: Arc<Metrics>) -> Self {
        // Ordering matters here! The subscription to the broadcast channel has
        // to happen _before_ the current seqno is read, or a concurrent write
        // to state could slip in between and its notification would be missed.
        // Tokio promises that "the returned Receiver will receive values sent
        // after the call to subscribe", and taking the read lock linearizes the
        // subscribe _before_ whichever seqno we observe below.
        let receiver = state.notifier().tx.subscribe();
        let initial_seqno = state.read_lock(&metrics.locks.watch, |s| s.seqno);
        Self {
            metrics,
            state,
            seqno_high_water: initial_seqno,
            rx: receiver,
        }
    }

    /// Waits until the State has reached a SeqNo >= `requested`.
    ///
    /// Returns immediately if the high-water mark already satisfies the
    /// request. This method is cancel-safe.
    pub async fn wait_for_seqno_ge(&mut self, requested: SeqNo) -> &mut Self {
        self.metrics.watch.notify_wait_started.inc();
        debug!("wait_for_seqno_ge {} {}", self.state.shard_id(), requested);
        // Keep consuming notifications until the high-water mark catches up.
        while self.seqno_high_water < requested {
            match self.rx.recv().await {
                Err(broadcast::error::RecvError::Lagged(_)) => {
                    // Merely a hint that our (size 1) buffer overflowed, which
                    // is harmless. The broadcast channel retains the most
                    // recent N (=1 here) values, so going around the loop
                    // again should yield a fresh value right away.
                    self.metrics.watch.notify_lagged.inc();
                }
                Err(broadcast::error::RecvError::Closed) => {
                    unreachable!("we're holding on to a reference to the sender")
                }
                Ok(x) => {
                    self.metrics.watch.notify_recv.inc();
                    // Notifications must never move backwards.
                    assert!(x >= self.seqno_high_water);
                    self.seqno_high_water = x;
                }
            }
        }
        self.metrics.watch.notify_wait_finished.inc();
        debug!(
            "wait_for_seqno_ge {} {} returning",
            self.state.shard_id(),
            requested
        );
        self
    }
}

#[cfg(test)]
mod tests {
    use std::future::Future;
    use std::pin::Pin;
    use std::task::Context;
    use std::time::Duration;

    use futures::FutureExt;
    use futures_task::noop_waker;
    use mz_build_info::DUMMY_BUILD_INFO;
    use mz_dyncfg::ConfigUpdates;
    use mz_ore::cast::CastFrom;
    use mz_ore::metrics::MetricsRegistry;
    use mz_ore::{assert_none, assert_ok};
    use timely::progress::Antichain;

    use crate::cache::StateCache;
    use crate::cfg::PersistConfig;
    use crate::internal::machine::{
        NEXT_LISTEN_BATCH_RETRYER_CLAMP, NEXT_LISTEN_BATCH_RETRYER_INITIAL_BACKOFF,
        NEXT_LISTEN_BATCH_RETRYER_MULTIPLIER,
    };
    use crate::internal::state::TypedState;
    use crate::tests::new_test_client;
    use crate::{Diagnostics, ShardId};

    use super::*;

    // Single-watcher smoke test: waits for already-reached seqnos resolve
    // immediately, a wait for a future seqno stays pending until the state's
    // seqno is bumped via `write_lock`, and a newly created watch behaves the
    // same way.
    #[mz_ore::test(tokio::test)]
    async fn state_watch() {
        mz_ore::test::init_logging();
        let metrics = Arc::new(Metrics::new(
            &PersistConfig::new_for_tests(),
            &MetricsRegistry::new(),
        ));
        let cache = StateCache::new_no_metrics();
        let shard_id = ShardId::new();
        let state = cache
            .get::<(), (), u64, i64, _, _>(
                shard_id,
                || async {
                    Ok(TypedState::new(
                        DUMMY_BUILD_INFO.semver_version(),
                        shard_id,
                        "host".to_owned(),
                        0u64,
                    ))
                },
                &Diagnostics::for_tests(),
            )
            .await
            .unwrap();
        // Sanity check: the freshly created state starts at seqno 0.
        assert_eq!(state.read_lock(&metrics.locks.watch, |x| x.seqno), SeqNo(0));

        // A watch for 0 resolves immediately.
        let mut w0 = StateWatch::new(Arc::clone(&state), Arc::clone(&metrics));
        let _ = w0.wait_for_seqno_ge(SeqNo(0)).await;

        // A watch for 1 does not yet resolve. The future is made `shared` so
        // it can be polled once here via a clone and awaited again below.
        let w0s1 = w0.wait_for_seqno_ge(SeqNo(1)).map(|_| ()).shared();
        assert_eq!(w0s1.clone().now_or_never(), None);

        // After mutating state, the watch for 1 does resolve.
        state.write_lock(&metrics.locks.applier_write, |state| {
            state.seqno = state.seqno.next()
        });
        let () = w0s1.await;

        // A watch for an old seqno immediately resolves.
        let _ = w0.wait_for_seqno_ge(SeqNo(0)).await;

        // We can create a new watch and it also behaves.
        let mut w1 = StateWatch::new(Arc::clone(&state), Arc::clone(&metrics));
        let _ = w1.wait_for_seqno_ge(SeqNo(0)).await;
        let _ = w1.wait_for_seqno_ge(SeqNo(1)).await;
        // Nothing has written seqno 2, so this wait must still be pending.
        assert_none!(w1.wait_for_seqno_ge(SeqNo(2)).now_or_never());
    }

    // Stress test: many watcher tasks and many writer tasks run concurrently
    // against one shared state, and every watcher checks that the state has
    // really reached the seqno it waited for by the time it wakes up.
    #[mz_ore::test(tokio::test(flavor = "multi_thread"))]
    #[cfg_attr(miri, ignore)] // error: unsupported operation: integer-to-pointer casts and `ptr::from_exposed_addr` are not supported with `-Zmiri-strict-provenance`
    async fn state_watch_concurrency() {
        const NUM_WATCHES: usize = 100;
        const NUM_WRITES: usize = 20;

        mz_ore::test::init_logging();
        let metrics = Arc::new(Metrics::new(
            &PersistConfig::new_for_tests(),
            &MetricsRegistry::new(),
        ));
        let cache = StateCache::new_no_metrics();
        let shard_id = ShardId::new();
        let state = cache
            .get::<(), (), u64, i64, _, _>(
                shard_id,
                || async {
                    Ok(TypedState::new(
                        DUMMY_BUILD_INFO.semver_version(),
                        shard_id,
                        "host".to_owned(),
                        0u64,
                    ))
                },
                &Diagnostics::for_tests(),
            )
            .await
            .unwrap();
        assert_eq!(state.read_lock(&metrics.locks.watch, |x| x.seqno), SeqNo(0));

        // Launch every watcher before any writer is spawned.
        let mut watch_tasks = Vec::with_capacity(NUM_WATCHES);
        for idx in 0..NUM_WATCHES {
            let state = Arc::clone(&state);
            let metrics = Arc::clone(&metrics);
            let task = mz_ore::task::spawn(|| "watch", async move {
                // We started at 0, so N writes means N+1 seqnos (0 through N);
                // each watcher targets one of the seqnos 1 through N.
                let target = SeqNo(u64::cast_from(idx % NUM_WRITES + 1));
                let mut watch = StateWatch::new(Arc::clone(&state), Arc::clone(&metrics));
                let _ = watch.wait_for_seqno_ge(target).await;
                // Once woken, the real state must be at or past the target.
                let current =
                    state.read_lock(&metrics.locks.applier_read_noncacheable, |x| x.seqno);
                assert!(target <= current, "{} vs {}", target, current);
            });
            watch_tasks.push(task);
        }

        // Each writer bumps the seqno exactly once.
        let mut write_tasks = Vec::with_capacity(NUM_WRITES);
        for _ in 0..NUM_WRITES {
            let state = Arc::clone(&state);
            let metrics = Arc::clone(&metrics);
            let task = mz_ore::task::spawn(|| "write", async move {
                state.write_lock(&metrics.locks.applier_write, |x| {
                    x.seqno = x.seqno.next();
                });
            });
            write_tasks.push(task);
        }

        // Join everything, surfacing any panic raised inside a task.
        for watch in watch_tasks {
            assert_ok!(watch.await);
        }
        for write in write_tasks {
            assert_ok!(write.await);
        }
    }

    // End-to-end check that a pending `listen.next` gets woken up after a
    // write even though the listen retry backoff has been configured to be so
    // long that polling is effectively disabled.
    #[mz_persist_proc::test(tokio::test)]
    #[cfg_attr(miri, ignore)] // unsupported operation: returning ready events from epoll_wait is not yet implemented
    async fn state_watch_listen_snapshot(dyncfgs: ConfigUpdates) {
        mz_ore::test::init_logging();
        // A no-op waker lets us poll futures by hand below and check that they
        // are still pending.
        let waker = noop_waker();
        let mut cx = Context::from_waker(&waker);

        let client = new_test_client(&dyncfgs).await;
        // Override the listen poll so that it's useless.
        client.cfg.set_config(
            &NEXT_LISTEN_BATCH_RETRYER_INITIAL_BACKOFF,
            Duration::from_secs(1_000_000),
        );
        client
            .cfg
            .set_config(&NEXT_LISTEN_BATCH_RETRYER_MULTIPLIER, 1);
        client.cfg.set_config(
            &NEXT_LISTEN_BATCH_RETRYER_CLAMP,
            Duration::from_secs(1_000_000),
        );

        let (mut write, mut read) = client.expect_open::<(), (), u64, i64>(ShardId::new()).await;

        // Grab a snapshot as of 0, which doesn't resolve yet. Also grab a
        // listen as of 0, which resolves but doesn't yet resolve the next
        // batch.
        let mut listen = read
            .clone("test")
            .await
            .listen(Antichain::from_elem(0))
            .await
            .unwrap();
        let mut snapshot = Box::pin(read.snapshot(Antichain::from_elem(0)));
        assert!(Pin::new(&mut snapshot).poll(&mut cx).is_pending());
        let mut listen_next_batch = Box::pin(listen.next(None));
        assert!(Pin::new(&mut listen_next_batch).poll(&mut cx).is_pending());

        // Now update the frontier, which should allow the snapshot to resolve
        // and the listen to resolve its next batch. Because we disabled the
        // polling, the listen_next_batch future will block forever and time
        // out the test if the watch doesn't work.
        write.expect_compare_and_append(&[], 0, 1).await;
        let _ = listen_next_batch.await;

        // For good measure, also resolve the snapshot, though we haven't broken
        // the polling on this.
        let _ = snapshot.await;
    }
}