mz_txn_wal/
operator.rs

1// Copyright Materialize, Inc. and contributors. All rights reserved.
2//
3// Use of this software is governed by the Business Source License
4// included in the LICENSE file.
5//
6// As of the Change Date specified in that file, in accordance with
7// the Business Source License, use of this software will be governed
8// by the Apache License, Version 2.0.
9
10//! Timely operators for the crate
11
12use std::any::Any;
13use std::fmt::Debug;
14use std::future::Future;
15use std::sync::mpsc::TryRecvError;
16use std::sync::{Arc, mpsc};
17use std::time::Duration;
18
19use differential_dataflow::Hashable;
20use differential_dataflow::difference::Semigroup;
21use differential_dataflow::lattice::Lattice;
22use futures::StreamExt;
23use mz_dyncfg::{Config, ConfigSet, ConfigUpdates};
24use mz_ore::cast::CastFrom;
25use mz_ore::task::JoinHandleExt;
26use mz_persist_client::cfg::{RetryParameters, USE_GLOBAL_TXN_CACHE_SOURCE};
27use mz_persist_client::operators::shard_source::{
28    ErrorHandler, FilterResult, SnapshotMode, shard_source,
29};
30use mz_persist_client::{Diagnostics, PersistClient, ShardId};
31use mz_persist_types::codec_impls::{StringSchema, UnitSchema};
32use mz_persist_types::txn::TxnsCodec;
33use mz_persist_types::{Codec, Codec64, StepForward};
34use mz_timely_util::builder_async::{
35    AsyncInputHandle, Event as AsyncEvent, InputConnection,
36    OperatorBuilder as AsyncOperatorBuilder, PressOnDropButton,
37};
38use timely::container::CapacityContainerBuilder;
39use timely::dataflow::channels::pact::Pipeline;
40use timely::dataflow::operators::capture::Event;
41use timely::dataflow::operators::{Broadcast, Capture, Leave, Map, Probe};
42use timely::dataflow::{ProbeHandle, Scope, Stream};
43use timely::order::TotalOrder;
44use timely::progress::{Antichain, Timestamp};
45use timely::worker::Worker;
46use timely::{Data, PartialOrder, WorkerConfig};
47use tracing::debug;
48
49use crate::TxnsCodecDefault;
50use crate::txn_cache::TxnsCache;
51use crate::txn_read::{DataListenNext, DataRemapEntry, TxnsRead};
52
53/// An operator for translating physical data shard frontiers into logical ones.
54///
55/// A data shard in the txns set logically advances its upper each time a txn is
56/// committed, but the upper is not physically advanced unless that data shard
57/// was involved in the txn. This means that a shard_source (or any read)
58/// pointed at a data shard would appear to stall at the time of the most recent
59/// write. We fix this for shard_source by flowing its output through a new
60/// `txns_progress` dataflow operator, which ensures that the
61/// frontier/capability is advanced as the txns shard progresses, as long as the
62/// shard_source is up to date with the latest committed write to that data
63/// shard.
64///
65/// Example:
66///
67/// - A data shard has most recently been written to at 3.
68/// - The txns shard's upper is at 6.
69/// - We render a dataflow containing a shard_source with an as_of of 5.
70/// - A txn NOT involving the data shard is committed at 7.
71/// - A txn involving the data shard is committed at 9.
72///
73/// How it works:
74///
75/// - The shard_source operator is rendered. Its single output is hooked up as a
76///   _disconnected_ input to txns_progress. The txns_progress single output is
77///   a stream of the same type, which is used by downstream operators. This
78///   txns_progress operator is targeted at one data_shard; rendering a
79///   shard_source for a second data shard requires a second txns_progress
80///   operator.
81/// - The shard_source operator emits data through 3 and advances the frontier.
82/// - The txns_progress operator passes through these writes and frontier
83///   advancements unchanged. (Recall that it's always correct to read a data
84///   shard "normally", it just might stall.) Because the txns_progress operator
85///   knows there are no writes in `[3,5]`, it then downgrades its own
86///   capability past 5 (to 6). Because the input is disconnected, this means
87///   the overall frontier of the output is downgraded to 6.
88/// - The txns_progress operator learns about the write at 7 (the upper is now
89///   8). Because it knows that the data shard was not involved in this, it's
90///   free to downgrade its capability to 8.
91/// - The txns_progress operator learns about the write at 9 (the upper is now
92///   10). It knows that the data shard _WAS_ involved in this, so it forwards
93///   on data from its input until the input has progressed to 10, at which
94///   point it can itself downgrade to 10.
95pub fn txns_progress<K, V, T, D, P, C, F, G>(
96    passthrough: Stream<G, P>,
97    name: &str,
98    ctx: &TxnsContext,
99    worker_dyncfgs: &ConfigSet,
100    client_fn: impl Fn() -> F,
101    txns_id: ShardId,
102    data_id: ShardId,
103    as_of: T,
104    until: Antichain<T>,
105    data_key_schema: Arc<K::Schema>,
106    data_val_schema: Arc<V::Schema>,
107) -> (Stream<G, P>, Vec<PressOnDropButton>)
108where
109    K: Debug + Codec + Send + Sync,
110    V: Debug + Codec + Send + Sync,
111    T: Timestamp + Lattice + TotalOrder + StepForward + Codec64 + Sync,
112    D: Debug + Data + Semigroup + Ord + Codec64 + Send + Sync,
113    P: Debug + Data,
114    C: TxnsCodec + 'static,
115    F: Future<Output = PersistClient> + Send + 'static,
116    G: Scope<Timestamp = T>,
117{
118    let unique_id = (name, passthrough.scope().addr()).hashed();
119    let (remap, source_button) = if USE_GLOBAL_TXN_CACHE_SOURCE.get(worker_dyncfgs) {
120        txns_progress_source_global::<K, V, T, D, P, C, G>(
121            passthrough.scope(),
122            name,
123            ctx.clone(),
124            client_fn(),
125            txns_id,
126            data_id,
127            as_of,
128            data_key_schema,
129            data_val_schema,
130            unique_id,
131        )
132    } else {
133        txns_progress_source_local::<K, V, T, D, P, C, G>(
134            passthrough.scope(),
135            name,
136            client_fn(),
137            txns_id,
138            data_id,
139            as_of,
140            data_key_schema,
141            data_val_schema,
142            unique_id,
143        )
144    };
145    // Each of the `txns_frontiers` workers wants the full copy of the remap
146    // information.
147    let remap = remap.broadcast();
148    let (passthrough, frontiers_button) = txns_progress_frontiers::<K, V, T, D, P, C, G>(
149        remap,
150        passthrough,
151        name,
152        data_id,
153        until,
154        unique_id,
155    );
156    (passthrough, vec![source_button, frontiers_button])
157}
158
159/// An alternative implementation of [`txns_progress_source_global`] that opens
160/// a new [`TxnsCache`] local to the operator.
161fn txns_progress_source_local<K, V, T, D, P, C, G>(
162    scope: G,
163    name: &str,
164    client: impl Future<Output = PersistClient> + 'static,
165    txns_id: ShardId,
166    data_id: ShardId,
167    as_of: T,
168    data_key_schema: Arc<K::Schema>,
169    data_val_schema: Arc<V::Schema>,
170    unique_id: u64,
171) -> (Stream<G, DataRemapEntry<T>>, PressOnDropButton)
172where
173    K: Debug + Codec + Send + Sync,
174    V: Debug + Codec + Send + Sync,
175    T: Timestamp + Lattice + TotalOrder + StepForward + Codec64 + Sync,
176    D: Debug + Data + Semigroup + Ord + Codec64 + Send + Sync,
177    P: Debug + Data,
178    C: TxnsCodec + 'static,
179    G: Scope<Timestamp = T>,
180{
181    let worker_idx = scope.index();
182    let chosen_worker = usize::cast_from(name.hashed()) % scope.peers();
183    let name = format!("txns_progress_source({})", name);
184    let mut builder = AsyncOperatorBuilder::new(name.clone(), scope);
185    let name = format!("{} [{}] {:.9}", name, unique_id, data_id.to_string());
186    let (remap_output, remap_stream) = builder.new_output();
187
188    let shutdown_button = builder.build(move |capabilities| async move {
189        if worker_idx != chosen_worker {
190            return;
191        }
192
193        let [mut cap]: [_; 1] = capabilities.try_into().expect("one capability per output");
194        let client = client.await;
195        let mut txns_cache = TxnsCache::<T, C>::open(&client, txns_id, Some(data_id)).await;
196
197        let _ = txns_cache.update_gt(&as_of).await;
198        let mut subscribe = txns_cache.data_subscribe(data_id, as_of.clone());
199        let data_write = client
200            .open_writer::<K, V, T, D>(
201                data_id,
202                Arc::clone(&data_key_schema),
203                Arc::clone(&data_val_schema),
204                Diagnostics::from_purpose("data read physical upper"),
205            )
206            .await
207            .expect("schema shouldn't change");
208        if let Some(snapshot) = subscribe.snapshot.take() {
209            snapshot.unblock_read(data_write).await;
210        }
211
212        debug!("{} emitting {:?}", name, subscribe.remap);
213        remap_output.give(&cap, subscribe.remap.clone());
214
215        loop {
216            let _ = txns_cache.update_ge(&subscribe.remap.logical_upper).await;
217            cap.downgrade(&subscribe.remap.logical_upper);
218            let data_listen_next =
219                txns_cache.data_listen_next(&subscribe.data_id, &subscribe.remap.logical_upper);
220            debug!(
221                "{} data_listen_next at {:?}: {:?}",
222                name, subscribe.remap.logical_upper, data_listen_next,
223            );
224            match data_listen_next {
225                // We've caught up to the txns upper and we have to wait for it
226                // to advance before asking again.
227                //
228                // Note that we're asking again with the same input, but once
229                // the cache is past remap.logical_upper (as it will be after
230                // this update_gt call), we're guaranteed to get an answer.
231                DataListenNext::WaitForTxnsProgress => {
232                    let _ = txns_cache.update_gt(&subscribe.remap.logical_upper).await;
233                }
234                // The data shard got a write!
235                DataListenNext::ReadDataTo(new_upper) => {
236                    // A write means both the physical and logical upper advance.
237                    subscribe.remap = DataRemapEntry {
238                        physical_upper: new_upper.clone(),
239                        logical_upper: new_upper,
240                    };
241                    debug!("{} emitting {:?}", name, subscribe.remap);
242                    remap_output.give(&cap, subscribe.remap.clone());
243                }
244                // We know there are no writes in `[logical_upper,
245                // new_progress)`, so advance our output frontier.
246                DataListenNext::EmitLogicalProgress(new_progress) => {
247                    assert!(subscribe.remap.physical_upper < new_progress);
248                    assert!(subscribe.remap.logical_upper < new_progress);
249
250                    subscribe.remap.logical_upper = new_progress;
251                    // As mentioned in the docs on `DataRemapEntry`, we only
252                    // emit updates when the physical upper changes (which
253                    // happens to makes the protocol a tiny bit more
254                    // remap-like).
255                    debug!("{} not emitting {:?}", name, subscribe.remap);
256                }
257            }
258        }
259    });
260    (remap_stream, shutdown_button.press_on_drop())
261}
262
263/// TODO: I'd much prefer the communication protocol between the two operators
264/// to be exactly remap as defined in the [reclocking design doc]. However, we
265/// can't quite recover exactly the information necessary to construct that at
266/// the moment. Seems worth doing, but in the meantime, intentionally make this
267/// look fairly different (`Stream` of `DataRemapEntry` instead of
268/// `Collection<FromTime>`) to hopefully minimize confusion. As a performance
269/// optimization, we only re-emit this when the _physical_ upper has changed,
270/// which means that the frontier of the `Stream<DataRemapEntry<T>>` indicates
271/// updates to the logical_upper of the most recent `DataRemapEntry` (i.e. the
272/// one with the largest physical_upper).
273///
274/// [reclocking design doc]:
275///     https://github.com/MaterializeInc/materialize/blob/main/doc/developer/design/20210714_reclocking.md
276fn txns_progress_source_global<K, V, T, D, P, C, G>(
277    scope: G,
278    name: &str,
279    ctx: TxnsContext,
280    client: impl Future<Output = PersistClient> + 'static,
281    txns_id: ShardId,
282    data_id: ShardId,
283    as_of: T,
284    data_key_schema: Arc<K::Schema>,
285    data_val_schema: Arc<V::Schema>,
286    unique_id: u64,
287) -> (Stream<G, DataRemapEntry<T>>, PressOnDropButton)
288where
289    K: Debug + Codec + Send + Sync,
290    V: Debug + Codec + Send + Sync,
291    T: Timestamp + Lattice + TotalOrder + StepForward + Codec64 + Sync,
292    D: Debug + Data + Semigroup + Ord + Codec64 + Send + Sync,
293    P: Debug + Data,
294    C: TxnsCodec + 'static,
295    G: Scope<Timestamp = T>,
296{
297    let worker_idx = scope.index();
298    let chosen_worker = usize::cast_from(name.hashed()) % scope.peers();
299    let name = format!("txns_progress_source({})", name);
300    let mut builder = AsyncOperatorBuilder::new(name.clone(), scope);
301    let name = format!("{} [{}] {:.9}", name, unique_id, data_id.to_string());
302    let (remap_output, remap_stream) = builder.new_output();
303
304    let shutdown_button = builder.build(move |capabilities| async move {
305        if worker_idx != chosen_worker {
306            return;
307        }
308
309        let [mut cap]: [_; 1] = capabilities.try_into().expect("one capability per output");
310        let client = client.await;
311        let txns_read = ctx.get_or_init::<T, C>(&client, txns_id).await;
312
313        let _ = txns_read.update_gt(as_of.clone()).await;
314        let data_write = client
315            .open_writer::<K, V, T, D>(
316                data_id,
317                Arc::clone(&data_key_schema),
318                Arc::clone(&data_val_schema),
319                Diagnostics::from_purpose("data read physical upper"),
320            )
321            .await
322            .expect("schema shouldn't change");
323        let mut rx = txns_read
324            .data_subscribe(data_id, as_of.clone(), Box::new(data_write))
325            .await;
326        debug!("{} starting as_of={:?}", name, as_of);
327
328        let mut physical_upper = T::minimum();
329
330        while let Some(remap) = rx.recv().await {
331            assert!(physical_upper <= remap.physical_upper);
332            assert!(physical_upper < remap.logical_upper);
333
334            let logical_upper = remap.logical_upper.clone();
335            // As mentioned in the docs on this function, we only
336            // emit updates when the physical upper changes (which
337            // happens to makes the protocol a tiny bit more
338            // remap-like).
339            if remap.physical_upper != physical_upper {
340                physical_upper = remap.physical_upper.clone();
341                debug!("{} emitting {:?}", name, remap);
342                remap_output.give(&cap, remap);
343            } else {
344                debug!("{} not emitting {:?}", name, remap);
345            }
346            cap.downgrade(&logical_upper);
347        }
348    });
349    (remap_stream, shutdown_button.press_on_drop())
350}
351
352fn txns_progress_frontiers<K, V, T, D, P, C, G>(
353    remap: Stream<G, DataRemapEntry<T>>,
354    passthrough: Stream<G, P>,
355    name: &str,
356    data_id: ShardId,
357    until: Antichain<T>,
358    unique_id: u64,
359) -> (Stream<G, P>, PressOnDropButton)
360where
361    K: Debug + Codec,
362    V: Debug + Codec,
363    T: Timestamp + Lattice + TotalOrder + StepForward + Codec64,
364    D: Data + Semigroup + Codec64 + Send + Sync,
365    P: Debug + Data,
366    C: TxnsCodec,
367    G: Scope<Timestamp = T>,
368{
369    let name = format!("txns_progress_frontiers({})", name);
370    let mut builder = AsyncOperatorBuilder::new(name.clone(), passthrough.scope());
371    let name = format!(
372        "{} [{}] {}/{} {:.9}",
373        name,
374        unique_id,
375        passthrough.scope().index(),
376        passthrough.scope().peers(),
377        data_id.to_string(),
378    );
379    let (passthrough_output, passthrough_stream) =
380        builder.new_output::<CapacityContainerBuilder<_>>();
381    let mut remap_input = builder.new_disconnected_input(&remap, Pipeline);
382    let mut passthrough_input = builder.new_disconnected_input(&passthrough, Pipeline);
383
384    let shutdown_button = builder.build(move |capabilities| async move {
385        let [mut cap]: [_; 1] = capabilities.try_into().expect("one capability per output");
386
387        // None is used to indicate that both uppers are the empty antichain.
388        let mut remap = Some(DataRemapEntry {
389            physical_upper: T::minimum(),
390            logical_upper: T::minimum(),
391        });
392        // NB: The following loop uses `cap.time()`` to track how far we've
393        // progressed in copying along the passthrough input.
394        loop {
395            debug!("{} remap {:?}", name, remap);
396            if let Some(r) = remap.as_ref() {
397                assert!(r.physical_upper <= r.logical_upper);
398                // If we've passed through data to at least `physical_upper`,
399                // then it means we can artificially advance the upper of the
400                // output to `logical_upper`. This also indicates that we need
401                // to wait for the next DataRemapEntry. It can either (A) have
402                // the same physical upper or (B) have a larger physical upper.
403                //
404                // - If (A), then we would again satisfy this `physical_upper`
405                //   check, again advance the logical upper again, ...
406                // - If (B), then we'd fall down to the code below, which copies
407                //   the passthrough data until the frontier passes
408                //   `physical_upper`, then loops back up here.
409                if r.physical_upper.less_equal(cap.time()) {
410                    if cap.time() < &r.logical_upper {
411                        cap.downgrade(&r.logical_upper);
412                    }
413                    remap = txns_progress_frontiers_read_remap_input(
414                        &name,
415                        &mut remap_input,
416                        r.clone(),
417                    )
418                    .await;
419                    continue;
420                }
421            }
422
423            // This only returns None when there are no more data left. Turn it
424            // into an empty frontier progress so we can re-use the shutdown
425            // code below.
426            let event = passthrough_input
427                .next()
428                .await
429                .unwrap_or_else(|| AsyncEvent::Progress(Antichain::new()));
430            match event {
431                // NB: Ignore the data_cap because this input is disconnected.
432                AsyncEvent::Data(_data_cap, mut data) => {
433                    // NB: Nothing to do here for `until` because both the
434                    // `shard_source` (before this operator) and
435                    // `mfp_and_decode` (after this operator) do the necessary
436                    // filtering.
437                    debug!("{} emitting data {:?}", name, data);
438                    passthrough_output.give_container(&cap, &mut data);
439                }
440                AsyncEvent::Progress(new_progress) => {
441                    // If `until.less_equal(new_progress)`, it means that all
442                    // subsequent batches will contain only times greater or
443                    // equal to `until`, which means they can be dropped in
444                    // their entirety.
445                    //
446                    // Ideally this check would live in `txns_progress_source`,
447                    // but that turns out to be much more invasive (requires
448                    // replacing lots of `T`s with `Antichain<T>`s). Given that
449                    // we've been thinking about reworking the operators, do the
450                    // easy but more wasteful thing for now.
451                    if PartialOrder::less_equal(&until, &new_progress) {
452                        debug!(
453                            "{} progress {:?} has passed until {:?}",
454                            name,
455                            new_progress.elements(),
456                            until.elements()
457                        );
458                        return;
459                    }
460                    // We reached the empty frontier! Shut down.
461                    let Some(new_progress) = new_progress.into_option() else {
462                        return;
463                    };
464
465                    // Recall that any reads of the data shard are always
466                    // correct, so given that we've passed through any data
467                    // from the input, that means we're free to pass through
468                    // frontier updates too.
469                    if cap.time() < &new_progress {
470                        debug!("{} downgrading cap to {:?}", name, new_progress);
471                        cap.downgrade(&new_progress);
472                    }
473                }
474            }
475        }
476    });
477    (passthrough_stream, shutdown_button.press_on_drop())
478}
479
480async fn txns_progress_frontiers_read_remap_input<T, C>(
481    name: &str,
482    input: &mut AsyncInputHandle<T, Vec<DataRemapEntry<T>>, C>,
483    mut remap: DataRemapEntry<T>,
484) -> Option<DataRemapEntry<T>>
485where
486    T: Timestamp + TotalOrder,
487    C: InputConnection<T>,
488{
489    while let Some(event) = input.next().await {
490        let xs = match event {
491            AsyncEvent::Progress(logical_upper) => {
492                if let Some(logical_upper) = logical_upper.into_option() {
493                    if remap.logical_upper < logical_upper {
494                        remap.logical_upper = logical_upper;
495                        return Some(remap);
496                    }
497                }
498                continue;
499            }
500            AsyncEvent::Data(_cap, xs) => xs,
501        };
502        for x in xs {
503            debug!("{} got remap {:?}", name, x);
504            // Don't assume anything about the ordering.
505            if remap.logical_upper < x.logical_upper {
506                assert!(
507                    remap.physical_upper <= x.physical_upper,
508                    "previous remap physical upper {:?} is ahead of new remap physical upper {:?}",
509                    remap.physical_upper,
510                    x.physical_upper,
511                );
512                // TODO: If the physical upper has advanced, that's a very
513                // strong hint that the data shard is about to be written to.
514                // Because the data shard's upper advances sparsely (on write,
515                // but not on passage of time) which invalidates the "every 1s"
516                // assumption of the default tuning, we've had to de-tune the
517                // listen sleeps on the paired persist_source. Maybe we use "one
518                // state" to wake it up in case pubsub doesn't and remove the
519                // listen polling entirely? (NB: This would have to happen in
520                // each worker so that it's guaranteed to happen in each
521                // process.)
522                remap = x;
523            }
524        }
525        return Some(remap);
526    }
527    // remap_input is closed, which indicates the data shard is finished.
528    None
529}
530
531/// The process global [`TxnsRead`] that any operator can communicate with.
532#[derive(Default, Debug, Clone)]
533pub struct TxnsContext {
534    read: Arc<tokio::sync::OnceCell<Box<dyn Any + Send + Sync>>>,
535}
536
537impl TxnsContext {
538    async fn get_or_init<T, C>(&self, client: &PersistClient, txns_id: ShardId) -> TxnsRead<T>
539    where
540        T: Timestamp + Lattice + Codec64 + TotalOrder + StepForward + Sync,
541        C: TxnsCodec + 'static,
542    {
543        let read = self
544            .read
545            .get_or_init(|| {
546                let client = client.clone();
547                async move {
548                    let read: Box<dyn Any + Send + Sync> =
549                        Box::new(TxnsRead::<T>::start::<C>(client, txns_id).await);
550                    read
551                }
552            })
553            .await
554            .downcast_ref::<TxnsRead<T>>()
555            .expect("timestamp types should match");
556        // We initially only have one txns shard in the system.
557        assert_eq!(&txns_id, read.txns_id());
558        read.clone()
559    }
560}
561
562// Existing configs use the prefix "persist_txns_" for historical reasons. New
563// configs should use the prefix "txn_wal_".
564
565pub(crate) const DATA_SHARD_RETRYER_INITIAL_BACKOFF: Config<Duration> = Config::new(
566    "persist_txns_data_shard_retryer_initial_backoff",
567    Duration::from_millis(1024),
568    "The initial backoff when polling for new batches from a txns data shard persist_source.",
569);
570
571pub(crate) const DATA_SHARD_RETRYER_MULTIPLIER: Config<u32> = Config::new(
572    "persist_txns_data_shard_retryer_multiplier",
573    2,
574    "The backoff multiplier when polling for new batches from a txns data shard persist_source.",
575);
576
577pub(crate) const DATA_SHARD_RETRYER_CLAMP: Config<Duration> = Config::new(
578    "persist_txns_data_shard_retryer_clamp",
579    Duration::from_secs(16),
580    "The backoff clamp duration when polling for new batches from a txns data shard persist_source.",
581);
582
583/// Retry configuration for txn-wal data shard override of
584/// `next_listen_batch`.
585pub fn txns_data_shard_retry_params(cfg: &ConfigSet) -> RetryParameters {
586    RetryParameters {
587        fixed_sleep: Duration::ZERO,
588        initial_backoff: DATA_SHARD_RETRYER_INITIAL_BACKOFF.get(cfg),
589        multiplier: DATA_SHARD_RETRYER_MULTIPLIER.get(cfg),
590        clamp: DATA_SHARD_RETRYER_CLAMP.get(cfg),
591    }
592}
593
594/// A helper for subscribing to a data shard using the timely operators.
595///
596/// This could instead be a wrapper around a [Subscribe], but it's only used in
597/// tests and maelstrom, so do it by wrapping the timely operators to get
598/// additional coverage. For the same reason, hardcode the K, V, T, D types.
599///
600/// [Subscribe]: mz_persist_client::read::Subscribe
601pub struct DataSubscribe {
602    pub(crate) as_of: u64,
603    pub(crate) worker: Worker<timely::communication::allocator::Thread>,
604    data: ProbeHandle<u64>,
605    txns: ProbeHandle<u64>,
606    capture: mpsc::Receiver<Event<u64, Vec<(String, u64, i64)>>>,
607    output: Vec<(String, u64, i64)>,
608
609    _tokens: Vec<PressOnDropButton>,
610}
611
612impl std::fmt::Debug for DataSubscribe {
613    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
614        let DataSubscribe {
615            as_of,
616            worker: _,
617            data,
618            txns,
619            capture: _,
620            output,
621            _tokens: _,
622        } = self;
623        f.debug_struct("DataSubscribe")
624            .field("as_of", as_of)
625            .field("data", data)
626            .field("txns", txns)
627            .field("output", output)
628            .finish_non_exhaustive()
629    }
630}
631
632impl DataSubscribe {
633    /// Creates a new [DataSubscribe].
634    pub fn new(
635        name: &str,
636        client: PersistClient,
637        txns_id: ShardId,
638        data_id: ShardId,
639        as_of: u64,
640        until: Antichain<u64>,
641        use_global_txn_cache: bool,
642    ) -> Self {
643        let mut worker = Worker::new(
644            WorkerConfig::default(),
645            timely::communication::allocator::Thread::default(),
646            Some(std::time::Instant::now()),
647        );
648        let (data, txns, capture, tokens) = worker.dataflow::<u64, _, _>(|scope| {
649            let (data_stream, shard_source_token) = scope.scoped::<u64, _, _>("hybrid", |scope| {
650                let client = client.clone();
651                let (data_stream, token) = shard_source::<String, (), u64, i64, _, _, _>(
652                    scope,
653                    name,
654                    move || std::future::ready(client.clone()),
655                    data_id,
656                    Some(Antichain::from_elem(as_of)),
657                    SnapshotMode::Include,
658                    until.clone(),
659                    false.then_some(|_, _: &_, _| unreachable!()),
660                    Arc::new(StringSchema),
661                    Arc::new(UnitSchema),
662                    FilterResult::keep_all,
663                    false.then_some(|| unreachable!()),
664                    async {},
665                    ErrorHandler::Halt("data_subscribe"),
666                );
667                (data_stream.leave(), token)
668            });
669            let (data, txns) = (ProbeHandle::new(), ProbeHandle::new());
670            let data_stream = data_stream.flat_map(|part| {
671                let part = part.parse();
672                part.part.map(|((k, v), t, d)| {
673                    let (k, ()) = (k.unwrap(), v.unwrap());
674                    (k, t, d)
675                })
676            });
677            let data_stream = data_stream.probe_with(&data);
678            // We purposely do not use the `ConfigSet` in `client` so that
679            // different tests can set different values.
680            let config_set = ConfigSet::default().add(&USE_GLOBAL_TXN_CACHE_SOURCE);
681            let mut updates = ConfigUpdates::default();
682            updates.add(&USE_GLOBAL_TXN_CACHE_SOURCE, use_global_txn_cache);
683            updates.apply(&config_set);
684            let (data_stream, mut txns_progress_token) =
685                txns_progress::<String, (), u64, i64, _, TxnsCodecDefault, _, _>(
686                    data_stream,
687                    name,
688                    &TxnsContext::default(),
689                    &config_set,
690                    || std::future::ready(client.clone()),
691                    txns_id,
692                    data_id,
693                    as_of,
694                    until,
695                    Arc::new(StringSchema),
696                    Arc::new(UnitSchema),
697                );
698            let data_stream = data_stream.probe_with(&txns);
699            let mut tokens = shard_source_token;
700            tokens.append(&mut txns_progress_token);
701            (data, txns, data_stream.capture(), tokens)
702        });
703        Self {
704            as_of,
705            worker,
706            data,
707            txns,
708            capture,
709            output: Vec::new(),
710            _tokens: tokens,
711        }
712    }
713
714    /// Returns the exclusive progress of the dataflow.
715    pub fn progress(&self) -> u64 {
716        self.txns
717            .with_frontier(|f| *f.as_option().unwrap_or(&u64::MAX))
718    }
719
720    /// Steps the dataflow, capturing output.
721    pub fn step(&mut self) {
722        self.worker.step();
723        self.capture_output()
724    }
725
726    pub(crate) fn capture_output(&mut self) {
727        loop {
728            let event = match self.capture.try_recv() {
729                Ok(x) => x,
730                Err(TryRecvError::Empty) | Err(TryRecvError::Disconnected) => break,
731            };
732            match event {
733                Event::Progress(_) => {}
734                Event::Messages(_, mut msgs) => self.output.append(&mut msgs),
735            }
736        }
737    }
738
739    /// Steps the dataflow past the given time, capturing output.
740    #[cfg(test)]
741    pub async fn step_past(&mut self, ts: u64) {
742        while self.txns.less_equal(&ts) {
743            tracing::trace!(
744                "progress at {:?}",
745                self.txns.with_frontier(|x| x.to_owned()).elements()
746            );
747            self.step();
748            tokio::task::yield_now().await;
749        }
750    }
751
752    /// Returns captured output.
753    pub fn output(&self) -> &Vec<(String, u64, i64)> {
754        &self.output
755    }
756}
757
758/// A handle to a [DataSubscribe] running in a task.
759#[derive(Debug)]
760pub struct DataSubscribeTask {
761    /// Carries step requests. A `None` timestamp requests one step, a
762    /// `Some(ts)` requests stepping until we progress beyond `ts`.
763    tx: std::sync::mpsc::Sender<(
764        Option<u64>,
765        tokio::sync::oneshot::Sender<(Vec<(String, u64, i64)>, u64)>,
766    )>,
767    task: mz_ore::task::JoinHandle<Vec<(String, u64, i64)>>,
768    output: Vec<(String, u64, i64)>,
769    progress: u64,
770}
771
772impl DataSubscribeTask {
773    /// Creates a new [DataSubscribeTask].
774    pub async fn new(
775        client: PersistClient,
776        txns_id: ShardId,
777        data_id: ShardId,
778        as_of: u64,
779    ) -> Self {
780        let cache = TxnsCache::open(&client, txns_id, Some(data_id)).await;
781        let (tx, rx) = std::sync::mpsc::channel();
782        let task = mz_ore::task::spawn_blocking(
783            || "data_subscribe task",
784            move || Self::task(client, cache, data_id, as_of, rx),
785        );
786        DataSubscribeTask {
787            tx,
788            task,
789            output: Vec::new(),
790            progress: 0,
791        }
792    }
793
794    #[cfg(test)]
795    async fn step(&mut self) {
796        self.send(None).await;
797    }
798
799    /// Steps the dataflow past the given time, capturing output.
800    pub async fn step_past(&mut self, ts: u64) -> u64 {
801        self.send(Some(ts)).await;
802        self.progress
803    }
804
805    /// Returns captured output.
806    pub fn output(&self) -> &Vec<(String, u64, i64)> {
807        &self.output
808    }
809
810    async fn send(&mut self, ts: Option<u64>) {
811        let (tx, rx) = tokio::sync::oneshot::channel();
812        self.tx.send((ts, tx)).expect("task should be running");
813        let (mut new_output, new_progress) = rx.await.expect("task should be running");
814        self.output.append(&mut new_output);
815        assert!(self.progress <= new_progress);
816        self.progress = new_progress;
817    }
818
819    /// Signals for the task to exit, and then waits for this to happen.
820    ///
821    /// _All_ output from the lifetime of the task (not just what was previously
822    /// captured) is returned.
823    pub async fn finish(self) -> Vec<(String, u64, i64)> {
824        // Closing the channel signals the task to exit.
825        drop(self.tx);
826        self.task.wait_and_assert_finished().await
827    }
828
829    fn task(
830        client: PersistClient,
831        cache: TxnsCache<u64>,
832        data_id: ShardId,
833        as_of: u64,
834        rx: std::sync::mpsc::Receiver<(
835            Option<u64>,
836            tokio::sync::oneshot::Sender<(Vec<(String, u64, i64)>, u64)>,
837        )>,
838    ) -> Vec<(String, u64, i64)> {
839        let mut subscribe = DataSubscribe::new(
840            "DataSubscribeTask",
841            client.clone(),
842            cache.txns_id(),
843            data_id,
844            as_of,
845            Antichain::new(),
846            true,
847        );
848        let mut output = Vec::new();
849        loop {
850            let (ts, tx) = match rx.try_recv() {
851                Ok(x) => x,
852                Err(TryRecvError::Empty) => {
853                    // No requests, continue stepping so nothing deadlocks.
854                    subscribe.step();
855                    continue;
856                }
857                Err(TryRecvError::Disconnected) => {
858                    // All done! Return our output.
859                    return output;
860                }
861            };
862            // Always step at least once.
863            subscribe.step();
864            // If we got a ts, make sure to step past it.
865            if let Some(ts) = ts {
866                while subscribe.progress() <= ts {
867                    subscribe.step();
868                }
869            }
870            let new_output = std::mem::take(&mut subscribe.output);
871            output.extend(new_output.iter().cloned());
872            let _ = tx.send((new_output, subscribe.progress()));
873        }
874    }
875}
876
// Tests for the txns-aware subscribe helpers defined in this module.
#[cfg(test)]
mod tests {
    use itertools::{Either, Itertools};
    use mz_persist_types::Opaque;

    use crate::tests::writer;
    use crate::txns::TxnsHandle;

    use super::*;

    // Test-only convenience for starting a `DataSubscribeTask` against the
    // txns shard of an existing handle.
    impl<K, V, T, D, O, C> TxnsHandle<K, V, T, D, O, C>
    where
        K: Debug + Codec,
        V: Debug + Codec,
        T: Timestamp + Lattice + TotalOrder + StepForward + Codec64 + Sync,
        D: Debug + Semigroup + Ord + Codec64 + Send + Sync,
        O: Opaque + Debug + Codec64,
        C: TxnsCodec,
    {
        /// Spawns a [DataSubscribeTask] reading `data_id` as of `as_of`, using
        /// this handle's txns shard.
        async fn subscribe_task(
            &self,
            client: &PersistClient,
            data_id: ShardId,
            as_of: u64,
        ) -> DataSubscribeTask {
            DataSubscribeTask::new(client.clone(), self.txns_id(), data_id, as_of).await
        }
    }

    // Starts a new subscription (as_of 5) before and after each of a sequence
    // of txns events: registration, unrelated writes, and writes before, at,
    // and past the as_of. Then checks that every subscription, no matter when
    // it started, reaches progress 8 after stepping past 7 and yields the
    // output expected from the commit log.
    #[mz_ore::test(tokio::test(flavor = "multi_thread"))]
    #[cfg_attr(miri, ignore)] // too slow
    async fn data_subscribe() {
        // Steps every subscription once, in order.
        async fn step(subs: &mut Vec<DataSubscribeTask>) {
            for sub in subs.iter_mut() {
                sub.step().await;
            }
        }

        let client = PersistClient::new_for_tests().await;
        let mut txns = TxnsHandle::expect_open(client.clone()).await;
        let log = txns.new_log();
        let d0 = ShardId::new();

        // Start a subscription before the shard gets registered.
        let mut subs = Vec::new();
        subs.push(txns.subscribe_task(&client, d0, 5).await);
        step(&mut subs).await;

        // Now register the shard. Also start a new subscription and step the
        // previous one (plus repeat this for every later step).
        txns.register(1, [writer(&client, d0).await]).await.unwrap();
        subs.push(txns.subscribe_task(&client, d0, 5).await);
        step(&mut subs).await;

        // Now write something unrelated.
        let d1 = txns.expect_register(2).await;
        txns.expect_commit_at(3, d1, &["nope"], &log).await;
        subs.push(txns.subscribe_task(&client, d0, 5).await);
        step(&mut subs).await;

        // Now write to our shard before.
        txns.expect_commit_at(4, d0, &["4"], &log).await;
        subs.push(txns.subscribe_task(&client, d0, 5).await);
        step(&mut subs).await;

        // Now write to our shard at the as_of.
        txns.expect_commit_at(5, d0, &["5"], &log).await;
        subs.push(txns.subscribe_task(&client, d0, 5).await);
        step(&mut subs).await;

        // Now write to our shard past the as_of.
        txns.expect_commit_at(6, d0, &["6"], &log).await;
        subs.push(txns.subscribe_task(&client, d0, 5).await);
        step(&mut subs).await;

        // Now write something unrelated again.
        txns.expect_commit_at(7, d1, &["nope"], &log).await;
        subs.push(txns.subscribe_task(&client, d0, 5).await);
        step(&mut subs).await;

        // Verify that the dataflows can progress to the expected point and that
        // we read the right thing no matter when the dataflow started.
        for mut sub in subs {
            let progress = sub.step_past(7).await;
            assert_eq!(progress, 8);
            log.assert_eq(d0, 5, 8, sub.finish().await);
        }
    }

    // Checks that a subscription keeps emitting writes made directly to a
    // data shard after the shard is forgotten by txns. Also checks that its
    // frontier advances to `u64::MAX` once the shard is finalized, and that
    // subscriptions started later read the same data.
    #[mz_ore::test(tokio::test(flavor = "multi_thread"))]
    #[cfg_attr(miri, ignore)] // too slow
    async fn subscribe_shard_finalize() {
        let client = PersistClient::new_for_tests().await;
        let mut txns = TxnsHandle::expect_open(client.clone()).await;
        let log = txns.new_log();
        let d0 = txns.expect_register(1).await;

        // Start the operator as_of the register ts.
        let mut sub = txns.read_cache().expect_subscribe(&client, d0, 1);
        sub.step_past(1).await;

        // Write to it via txns.
        txns.expect_commit_at(2, d0, &["foo"], &log).await;
        sub.step_past(2).await;

        // Unregister it.
        txns.forget(3, [d0]).await.unwrap();
        sub.step_past(3).await;

        // TODO: Hard mode, see if we can get the rest of this test to work even
        // _without_ the txns shard advancing.
        txns.begin().commit_at(&mut txns, 7).await.unwrap();

        // The operator should continue to emit data written directly even
        // though it's no longer in the txns set.
        let mut d0_write = writer(&client, d0).await;
        let key = "bar".to_owned();
        crate::small_caa(|| "test", &mut d0_write, &[((&key, &()), &5, 1)], 4, 6)
            .await
            .unwrap();
        log.record((d0, key, 5, 1));
        sub.step_past(4).await;

        // Now finalize the shard to writes.
        let () = d0_write
            .compare_and_append_batch(&mut [], Antichain::from_elem(6), Antichain::new(), true)
            .await
            .unwrap()
            .unwrap();
        // Step until the frontier holds nothing below `u64::MAX`.
        while sub.txns.less_than(&u64::MAX) {
            sub.step();
            tokio::task::yield_now().await;
        }

        // Make sure we read the correct things.
        log.assert_eq(d0, 1, u64::MAX, sub.output().clone());

        // Also make sure that we can read the right things if we start up after
        // the forget but before the direct write and ditto after the direct
        // write.
        log.assert_subscribe(d0, 4, u64::MAX).await;
        log.assert_subscribe(d0, 6, u64::MAX).await;
    }

    // Checks that a subscription started (as_of 0) before its shard is
    // registered advances past both the register and the forget timestamps.
    #[mz_ore::test(tokio::test(flavor = "multi_thread"))]
    #[cfg_attr(miri, ignore)] // too slow
    async fn subscribe_shard_register_forget() {
        let client = PersistClient::new_for_tests().await;
        let mut txns = TxnsHandle::expect_open(client.clone()).await;
        let d0 = ShardId::new();

        // Start a subscription on the data shard.
        let mut sub = txns.read_cache().expect_subscribe(&client, d0, 0);
        assert_eq!(sub.progress(), 0);

        // Register the shard at 10.
        txns.register(10, [writer(&client, d0).await])
            .await
            .unwrap();
        sub.step_past(10).await;
        assert!(
            sub.progress() > 10,
            "operator should advance past 10 when shard is registered"
        );

        // Forget the shard at 20.
        txns.forget(20, [d0]).await.unwrap();
        sub.step_past(20).await;
        assert!(
            sub.progress() > 20,
            "operator should advance past 20 when shard is forgotten"
        );
    }

    // Checks the exact data batches emitted with as_of 3 and until 5.
    // - Writes at or before the as_of appear advanced to 3.
    // - The write at 4 is emitted as-is.
    // - Nothing written at or after the until (5, 6, 7) is emitted.
    // Also checks that no progress message reaches the until.
    #[mz_ore::test(tokio::test)]
    #[cfg_attr(miri, ignore)] // too slow
    async fn as_of_until() {
        let client = PersistClient::new_for_tests().await;
        let mut txns = TxnsHandle::expect_open(client.clone()).await;
        let log = txns.new_log();

        let d0 = txns.expect_register(1).await;
        txns.expect_commit_at(2, d0, &["2"], &log).await;
        txns.expect_commit_at(3, d0, &["3"], &log).await;
        txns.expect_commit_at(4, d0, &["4"], &log).await;
        txns.expect_commit_at(5, d0, &["5"], &log).await;
        txns.expect_commit_at(6, d0, &["6"], &log).await;
        txns.expect_commit_at(7, d0, &["7"], &log).await;

        let until = 5;
        let mut sub = DataSubscribe::new(
            "as_of_until",
            client,
            txns.txns_id(),
            d0,
            3,
            Antichain::from_elem(until),
            true,
        );
        // Manually step the dataflow, instead of going through the
        // `DataSubscribe` helper because we're interested in all captured
        // events.
        // NOTE(review): the sleep presumably gives background persist work
        // time to make progress under the single-threaded runtime; confirm.
        while sub.txns.less_equal(&5) {
            sub.worker.step();
            tokio::task::yield_now().await;
            tokio::time::sleep(std::time::Duration::from_millis(100)).await;
        }
        let (actual_progresses, actual_events): (Vec<_>, Vec<_>) =
            sub.capture.into_iter().partition_map(|event| match event {
                Event::Progress(progress) => Either::Left(progress),
                Event::Messages(ts, data) => Either::Right((ts, data)),
            });
        let expected = vec![
            (3, vec![("2".to_owned(), 3, 1), ("3".to_owned(), 3, 1)]),
            (3, vec![("4".to_owned(), 4, 1)]),
        ];
        assert_eq!(actual_events, expected);

        // The number and contents of progress messages is not guaranteed and
        // depends on the downgrade behavior. The only thing we can assert is
        // the max progress timestamp, if there is one, is less than the until.
        if let Some(max_progress_ts) = actual_progresses
            .into_iter()
            .flatten()
            .map(|(ts, _diff)| ts)
            .max()
        {
            assert!(max_progress_ts < until, "{max_progress_ts} < {until}");
        }
    }
}