mz_txn_wal/
operator.rs

// Copyright Materialize, Inc. and contributors. All rights reserved.
//
// Use of this software is governed by the Business Source License
// included in the LICENSE file.
//
// As of the Change Date specified in that file, in accordance with
// the Business Source License, use of this software will be governed
// by the Apache License, Version 2.0.

//! Timely operators for the crate

use std::any::Any;
use std::fmt::Debug;
use std::future::Future;
use std::sync::mpsc::TryRecvError;
use std::sync::{Arc, mpsc};
use std::time::Duration;

use differential_dataflow::Hashable;
use differential_dataflow::difference::Semigroup;
use differential_dataflow::lattice::Lattice;
use futures::StreamExt;
use mz_dyncfg::{Config, ConfigSet, ConfigUpdates};
use mz_ore::cast::CastFrom;
use mz_ore::task::JoinHandleExt;
use mz_persist_client::cfg::{RetryParameters, USE_GLOBAL_TXN_CACHE_SOURCE};
use mz_persist_client::operators::shard_source::{FilterResult, SnapshotMode, shard_source};
use mz_persist_client::{Diagnostics, PersistClient, ShardId};
use mz_persist_types::codec_impls::{StringSchema, UnitSchema};
use mz_persist_types::txn::TxnsCodec;
use mz_persist_types::{Codec, Codec64, StepForward};
use mz_timely_util::builder_async::{
    AsyncInputHandle, Event as AsyncEvent, InputConnection,
    OperatorBuilder as AsyncOperatorBuilder, PressOnDropButton,
};
use timely::container::CapacityContainerBuilder;
use timely::dataflow::channels::pact::Pipeline;
use timely::dataflow::operators::capture::Event;
use timely::dataflow::operators::{Broadcast, Capture, Leave, Map, Probe};
use timely::dataflow::{ProbeHandle, Scope, Stream};
use timely::order::TotalOrder;
use timely::progress::{Antichain, Timestamp};
use timely::worker::Worker;
use timely::{Data, PartialOrder, WorkerConfig};
use tracing::debug;

use crate::TxnsCodecDefault;
use crate::txn_cache::TxnsCache;
use crate::txn_read::{DataListenNext, DataRemapEntry, TxnsRead};

/// An operator for translating physical data shard frontiers into logical ones.
///
/// A data shard in the txns set logically advances its upper each time a txn is
/// committed, but the upper is not physically advanced unless that data shard
/// was involved in the txn. This means that a shard_source (or any read)
/// pointed at a data shard would appear to stall at the time of the most recent
/// write. We fix this for shard_source by flowing its output through a new
/// `txns_progress` dataflow operator, which ensures that the
/// frontier/capability is advanced as the txns shard progresses, as long as the
/// shard_source is up to date with the latest committed write to that data
/// shard.
///
/// Example:
///
/// - A data shard has most recently been written to at 3.
/// - The txns shard's upper is at 6.
/// - We render a dataflow containing a shard_source with an as_of of 5.
/// - A txn NOT involving the data shard is committed at 7.
/// - A txn involving the data shard is committed at 9.
///
/// How it works:
///
/// - The shard_source operator is rendered. Its single output is hooked up as a
///   _disconnected_ input to txns_progress. The txns_progress single output is
///   a stream of the same type, which is used by downstream operators. This
///   txns_progress operator is targeted at one data_shard; rendering a
///   shard_source for a second data shard requires a second txns_progress
///   operator.
/// - The shard_source operator emits data through 3 and advances the frontier.
/// - The txns_progress operator passes through these writes and frontier
///   advancements unchanged. (Recall that it's always correct to read a data
///   shard "normally", it just might stall.) Because the txns_progress operator
///   knows there are no writes in `[3,5]`, it then downgrades its own
///   capability past 5 (to 6). Because the input is disconnected, this means
///   the overall frontier of the output is downgraded to 6.
/// - The txns_progress operator learns about the write at 7 (the upper is now
///   8). Because it knows that the data shard was not involved in this, it's
///   free to downgrade its capability to 8.
/// - The txns_progress operator learns about the write at 9 (the upper is now
///   10). It knows that the data shard _WAS_ involved in this, so it forwards
///   on data from its input until the input has progressed to 10, at which
///   point it can itself downgrade to 10.
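///
/// A minimal usage sketch, modeled on the wiring in [`DataSubscribe::new`]
/// below (the `scope`, `client`, shard ids, and `worker_dyncfgs` are assumed
/// to already exist; the `String`/`()`/`u64`/`i64` types are placeholders):
///
/// ```ignore
/// let (data_stream, source_tokens) = shard_source::<String, (), u64, i64, _, _, _>(
///     // ... the usual shard_source arguments for data_id ...
/// );
/// let (logical_stream, progress_tokens) =
///     txns_progress::<String, (), u64, i64, _, TxnsCodecDefault, _, _>(
///         data_stream,
///         "example",
///         &TxnsContext::default(),
///         &worker_dyncfgs,
///         || std::future::ready(client.clone()),
///         txns_id,
///         data_id,
///         as_of,
///         until,
///         Arc::new(StringSchema),
///         Arc::new(UnitSchema),
///     );
/// // `logical_stream` now advances with the txns shard, not only on
/// // physical writes to the data shard.
/// ```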
pub fn txns_progress<K, V, T, D, P, C, F, G>(
    passthrough: Stream<G, P>,
    name: &str,
    ctx: &TxnsContext,
    worker_dyncfgs: &ConfigSet,
    client_fn: impl Fn() -> F,
    txns_id: ShardId,
    data_id: ShardId,
    as_of: T,
    until: Antichain<T>,
    data_key_schema: Arc<K::Schema>,
    data_val_schema: Arc<V::Schema>,
) -> (Stream<G, P>, Vec<PressOnDropButton>)
where
    K: Debug + Codec + Send + Sync,
    V: Debug + Codec + Send + Sync,
    T: Timestamp + Lattice + TotalOrder + StepForward + Codec64 + Sync,
    D: Debug + Data + Semigroup + Ord + Codec64 + Send + Sync,
    P: Debug + Data,
    C: TxnsCodec + 'static,
    F: Future<Output = PersistClient> + Send + 'static,
    G: Scope<Timestamp = T>,
{
    let unique_id = (name, passthrough.scope().addr()).hashed();
    let (remap, source_button) = if USE_GLOBAL_TXN_CACHE_SOURCE.get(worker_dyncfgs) {
        txns_progress_source_global::<K, V, T, D, P, C, G>(
            passthrough.scope(),
            name,
            ctx.clone(),
            client_fn(),
            txns_id,
            data_id,
            as_of,
            data_key_schema,
            data_val_schema,
            unique_id,
        )
    } else {
        txns_progress_source_local::<K, V, T, D, P, C, G>(
            passthrough.scope(),
            name,
            client_fn(),
            txns_id,
            data_id,
            as_of,
            data_key_schema,
            data_val_schema,
            unique_id,
        )
    };
    // Each of the `txns_frontiers` workers wants a full copy of the remap
    // information.
    let remap = remap.broadcast();
    let (passthrough, frontiers_button) = txns_progress_frontiers::<K, V, T, D, P, C, G>(
        remap,
        passthrough,
        name,
        data_id,
        until,
        unique_id,
    );
    (passthrough, vec![source_button, frontiers_button])
}

/// An alternative implementation of [`txns_progress_source_global`] that opens
/// a new [`TxnsCache`] local to the operator.
fn txns_progress_source_local<K, V, T, D, P, C, G>(
    scope: G,
    name: &str,
    client: impl Future<Output = PersistClient> + 'static,
    txns_id: ShardId,
    data_id: ShardId,
    as_of: T,
    data_key_schema: Arc<K::Schema>,
    data_val_schema: Arc<V::Schema>,
    unique_id: u64,
) -> (Stream<G, DataRemapEntry<T>>, PressOnDropButton)
where
    K: Debug + Codec + Send + Sync,
    V: Debug + Codec + Send + Sync,
    T: Timestamp + Lattice + TotalOrder + StepForward + Codec64 + Sync,
    D: Debug + Data + Semigroup + Ord + Codec64 + Send + Sync,
    P: Debug + Data,
    C: TxnsCodec + 'static,
    G: Scope<Timestamp = T>,
{
    let worker_idx = scope.index();
    let chosen_worker = usize::cast_from(name.hashed()) % scope.peers();
    let name = format!("txns_progress_source({})", name);
    let mut builder = AsyncOperatorBuilder::new(name.clone(), scope);
    let name = format!("{} [{}] {:.9}", name, unique_id, data_id.to_string());
    let (remap_output, remap_stream) = builder.new_output();

    let shutdown_button = builder.build(move |capabilities| async move {
        if worker_idx != chosen_worker {
            return;
        }

        let [mut cap]: [_; 1] = capabilities.try_into().expect("one capability per output");
        let client = client.await;
        let mut txns_cache = TxnsCache::<T, C>::open(&client, txns_id, Some(data_id)).await;

        let _ = txns_cache.update_gt(&as_of).await;
        let mut subscribe = txns_cache.data_subscribe(data_id, as_of.clone());
        let data_write = client
            .open_writer::<K, V, T, D>(
                data_id,
                Arc::clone(&data_key_schema),
                Arc::clone(&data_val_schema),
                Diagnostics::from_purpose("data read physical upper"),
            )
            .await
            .expect("schema shouldn't change");
        if let Some(snapshot) = subscribe.snapshot.take() {
            snapshot.unblock_read(data_write).await;
        }

        debug!("{} emitting {:?}", name, subscribe.remap);
        remap_output.give(&cap, subscribe.remap.clone());

        loop {
            let _ = txns_cache.update_ge(&subscribe.remap.logical_upper).await;
            cap.downgrade(&subscribe.remap.logical_upper);
            let data_listen_next =
                txns_cache.data_listen_next(&subscribe.data_id, &subscribe.remap.logical_upper);
            debug!(
                "{} data_listen_next at {:?}: {:?}",
                name, subscribe.remap.logical_upper, data_listen_next,
            );
            match data_listen_next {
                // We've caught up to the txns upper and we have to wait for it
                // to advance before asking again.
                //
                // Note that we're asking again with the same input, but once
                // the cache is past remap.logical_upper (as it will be after
                // this update_gt call), we're guaranteed to get an answer.
                DataListenNext::WaitForTxnsProgress => {
                    let _ = txns_cache.update_gt(&subscribe.remap.logical_upper).await;
                }
                // The data shard got a write!
                DataListenNext::ReadDataTo(new_upper) => {
                    // A write means both the physical and logical upper advance.
                    subscribe.remap = DataRemapEntry {
                        physical_upper: new_upper.clone(),
                        logical_upper: new_upper,
                    };
                    debug!("{} emitting {:?}", name, subscribe.remap);
                    remap_output.give(&cap, subscribe.remap.clone());
                }
                // We know there are no writes in `[logical_upper,
                // new_progress)`, so advance our output frontier.
                DataListenNext::EmitLogicalProgress(new_progress) => {
                    assert!(subscribe.remap.physical_upper < new_progress);
                    assert!(subscribe.remap.logical_upper < new_progress);

                    subscribe.remap.logical_upper = new_progress;
                    // As mentioned in the docs on `DataRemapEntry`, we only
                    // emit updates when the physical upper changes (which
                    // happens to make the protocol a tiny bit more
                    // remap-like).
                    debug!("{} not emitting {:?}", name, subscribe.remap);
                }
            }
        }
    });
    (remap_stream, shutdown_button.press_on_drop())
}

/// TODO: I'd much prefer the communication protocol between the two operators
/// to be exactly remap as defined in the [reclocking design doc]. However, we
/// can't quite recover exactly the information necessary to construct that at
/// the moment. Seems worth doing, but in the meantime, intentionally make this
/// look fairly different (`Stream` of `DataRemapEntry` instead of
/// `Collection<FromTime>`) to hopefully minimize confusion. As a performance
/// optimization, we only re-emit this when the _physical_ upper has changed,
/// which means that the frontier of the `Stream<DataRemapEntry<T>>` indicates
/// updates to the logical_upper of the most recent `DataRemapEntry` (i.e. the
/// one with the largest physical_upper).
///
/// [reclocking design doc]:
///     https://github.com/MaterializeInc/materialize/blob/main/doc/developer/design/20210714_reclocking.md
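///
/// For example (with hypothetical timestamps), a commit involving the data
/// shard at 5 followed by unrelated txns-shard progress to 8 might look like:
///
/// ```text
/// emit DataRemapEntry { physical_upper: 6, logical_upper: 6 }
/// // The txns shard advances to 8 with no new writes to this data shard:
/// // no new entry is emitted, but the frontier of the remap stream
/// // advances to 8.
/// ```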
fn txns_progress_source_global<K, V, T, D, P, C, G>(
    scope: G,
    name: &str,
    ctx: TxnsContext,
    client: impl Future<Output = PersistClient> + 'static,
    txns_id: ShardId,
    data_id: ShardId,
    as_of: T,
    data_key_schema: Arc<K::Schema>,
    data_val_schema: Arc<V::Schema>,
    unique_id: u64,
) -> (Stream<G, DataRemapEntry<T>>, PressOnDropButton)
where
    K: Debug + Codec + Send + Sync,
    V: Debug + Codec + Send + Sync,
    T: Timestamp + Lattice + TotalOrder + StepForward + Codec64 + Sync,
    D: Debug + Data + Semigroup + Ord + Codec64 + Send + Sync,
    P: Debug + Data,
    C: TxnsCodec + 'static,
    G: Scope<Timestamp = T>,
{
    let worker_idx = scope.index();
    let chosen_worker = usize::cast_from(name.hashed()) % scope.peers();
    let name = format!("txns_progress_source({})", name);
    let mut builder = AsyncOperatorBuilder::new(name.clone(), scope);
    let name = format!("{} [{}] {:.9}", name, unique_id, data_id.to_string());
    let (remap_output, remap_stream) = builder.new_output();

    let shutdown_button = builder.build(move |capabilities| async move {
        if worker_idx != chosen_worker {
            return;
        }

        let [mut cap]: [_; 1] = capabilities.try_into().expect("one capability per output");
        let client = client.await;
        let txns_read = ctx.get_or_init::<T, C>(&client, txns_id).await;

        let _ = txns_read.update_gt(as_of.clone()).await;
        let data_write = client
            .open_writer::<K, V, T, D>(
                data_id,
                Arc::clone(&data_key_schema),
                Arc::clone(&data_val_schema),
                Diagnostics::from_purpose("data read physical upper"),
            )
            .await
            .expect("schema shouldn't change");
        let mut rx = txns_read
            .data_subscribe(data_id, as_of.clone(), Box::new(data_write))
            .await;
        debug!("{} starting as_of={:?}", name, as_of);

        let mut physical_upper = T::minimum();

        while let Some(remap) = rx.recv().await {
            assert!(physical_upper <= remap.physical_upper);
            assert!(physical_upper < remap.logical_upper);

            let logical_upper = remap.logical_upper.clone();
            // As mentioned in the docs on this function, we only
            // emit updates when the physical upper changes (which
            // happens to make the protocol a tiny bit more
            // remap-like).
            if remap.physical_upper != physical_upper {
                physical_upper = remap.physical_upper.clone();
                debug!("{} emitting {:?}", name, remap);
                remap_output.give(&cap, remap);
            } else {
                debug!("{} not emitting {:?}", name, remap);
            }
            cap.downgrade(&logical_upper);
        }
    });
    (remap_stream, shutdown_button.press_on_drop())
}

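/// The per-worker operator that copies along `passthrough` data and uses the
/// broadcast `remap` stream to advance the output frontier from the data
/// shard's physical upper to its logical upper.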
fn txns_progress_frontiers<K, V, T, D, P, C, G>(
    remap: Stream<G, DataRemapEntry<T>>,
    passthrough: Stream<G, P>,
    name: &str,
    data_id: ShardId,
    until: Antichain<T>,
    unique_id: u64,
) -> (Stream<G, P>, PressOnDropButton)
where
    K: Debug + Codec,
    V: Debug + Codec,
    T: Timestamp + Lattice + TotalOrder + StepForward + Codec64,
    D: Data + Semigroup + Codec64 + Send + Sync,
    P: Debug + Data,
    C: TxnsCodec,
    G: Scope<Timestamp = T>,
{
    let name = format!("txns_progress_frontiers({})", name);
    let mut builder = AsyncOperatorBuilder::new(name.clone(), passthrough.scope());
    let name = format!(
        "{} [{}] {}/{} {:.9}",
        name,
        unique_id,
        passthrough.scope().index(),
        passthrough.scope().peers(),
        data_id.to_string(),
    );
    let (passthrough_output, passthrough_stream) =
        builder.new_output::<CapacityContainerBuilder<_>>();
    let mut remap_input = builder.new_disconnected_input(&remap, Pipeline);
    let mut passthrough_input = builder.new_disconnected_input(&passthrough, Pipeline);

    let shutdown_button = builder.build(move |capabilities| async move {
        let [mut cap]: [_; 1] = capabilities.try_into().expect("one capability per output");

        // None is used to indicate that both uppers are the empty antichain.
        let mut remap = Some(DataRemapEntry {
            physical_upper: T::minimum(),
            logical_upper: T::minimum(),
        });
        // NB: The following loop uses `cap.time()` to track how far we've
        // progressed in copying along the passthrough input.
        loop {
            debug!("{} remap {:?}", name, remap);
            if let Some(r) = remap.as_ref() {
                assert!(r.physical_upper <= r.logical_upper);
                // If we've passed through data to at least `physical_upper`,
                // then it means we can artificially advance the upper of the
                // output to `logical_upper`. This also indicates that we need
                // to wait for the next DataRemapEntry. It can either (A) have
                // the same physical upper or (B) have a larger physical upper.
                //
                // - If (A), then we would again satisfy this `physical_upper`
                //   check, advance the logical upper again, ...
                // - If (B), then we'd fall down to the code below, which copies
                //   the passthrough data until the frontier passes
                //   `physical_upper`, then loops back up here.
                if r.physical_upper.less_equal(cap.time()) {
                    if cap.time() < &r.logical_upper {
                        cap.downgrade(&r.logical_upper);
                    }
                    remap = txns_progress_frontiers_read_remap_input(
                        &name,
                        &mut remap_input,
                        r.clone(),
                    )
                    .await;
                    continue;
                }
            }

            // This only returns None when there is no data left. Turn it
            // into an empty frontier progress so we can re-use the shutdown
            // code below.
            let event = passthrough_input
                .next()
                .await
                .unwrap_or_else(|| AsyncEvent::Progress(Antichain::new()));
            match event {
                // NB: Ignore the data_cap because this input is disconnected.
                AsyncEvent::Data(_data_cap, mut data) => {
                    // NB: Nothing to do here for `until` because both the
                    // `shard_source` (before this operator) and
                    // `mfp_and_decode` (after this operator) do the necessary
                    // filtering.
                    debug!("{} emitting data {:?}", name, data);
                    passthrough_output.give_container(&cap, &mut data);
                }
                AsyncEvent::Progress(new_progress) => {
                    // If `until.less_equal(new_progress)`, it means that all
                    // subsequent batches will contain only times greater than
                    // or equal to `until`, which means they can be dropped in
                    // their entirety.
                    //
                    // Ideally this check would live in `txns_progress_source`,
                    // but that turns out to be much more invasive (requires
                    // replacing lots of `T`s with `Antichain<T>`s). Given that
                    // we've been thinking about reworking the operators, do the
                    // easy but more wasteful thing for now.
                    if PartialOrder::less_equal(&until, &new_progress) {
                        debug!(
                            "{} progress {:?} has passed until {:?}",
                            name,
                            new_progress.elements(),
                            until.elements()
                        );
                        return;
                    }
                    // We reached the empty frontier! Shut down.
                    let Some(new_progress) = new_progress.into_option() else {
                        return;
                    };

                    // Recall that any reads of the data shard are always
                    // correct, so given that we've passed through any data
                    // from the input, that means we're free to pass through
                    // frontier updates too.
                    if cap.time() < &new_progress {
                        debug!("{} downgrading cap to {:?}", name, new_progress);
                        cap.downgrade(&new_progress);
                    }
                }
            }
        }
    });
    (passthrough_stream, shutdown_button.press_on_drop())
}

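/// Reads the `remap` input, folding any received entries into `remap`, and
/// returns the updated entry. Returns `None` if the input is exhausted, which
/// indicates that the data shard is finished.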
async fn txns_progress_frontiers_read_remap_input<T, C>(
    name: &str,
    input: &mut AsyncInputHandle<T, Vec<DataRemapEntry<T>>, C>,
    mut remap: DataRemapEntry<T>,
) -> Option<DataRemapEntry<T>>
where
    T: Timestamp + TotalOrder,
    C: InputConnection<T>,
{
    while let Some(event) = input.next().await {
        let xs = match event {
            AsyncEvent::Progress(logical_upper) => {
                if let Some(logical_upper) = logical_upper.into_option() {
                    if remap.logical_upper < logical_upper {
                        remap.logical_upper = logical_upper;
                        return Some(remap);
                    }
                }
                continue;
            }
            AsyncEvent::Data(_cap, xs) => xs,
        };
        for x in xs {
            debug!("{} got remap {:?}", name, x);
            // Don't assume anything about the ordering.
            if remap.logical_upper < x.logical_upper {
                assert!(
                    remap.physical_upper <= x.physical_upper,
                    "previous remap physical upper {:?} is ahead of new remap physical upper {:?}",
                    remap.physical_upper,
                    x.physical_upper,
                );
                // TODO: If the physical upper has advanced, that's a very
                // strong hint that the data shard is about to be written to.
                // Because the data shard's upper advances sparsely (on write,
                // but not on passage of time), which invalidates the "every 1s"
                // assumption of the default tuning, we've had to de-tune the
                // listen sleeps on the paired persist_source. Maybe we use "one
                // state" to wake it up in case pubsub doesn't and remove the
                // listen polling entirely? (NB: This would have to happen in
                // each worker so that it's guaranteed to happen in each
                // process.)
                remap = x;
            }
        }
        return Some(remap);
    }
    // remap_input is closed, which indicates the data shard is finished.
    None
}

/// The process-global [`TxnsRead`] that any operator can communicate with.
#[derive(Default, Debug, Clone)]
pub struct TxnsContext {
    read: Arc<tokio::sync::OnceCell<Box<dyn Any + Send + Sync>>>,
}

impl TxnsContext {
    async fn get_or_init<T, C>(&self, client: &PersistClient, txns_id: ShardId) -> TxnsRead<T>
    where
        T: Timestamp + Lattice + Codec64 + TotalOrder + StepForward + Sync,
        C: TxnsCodec + 'static,
    {
        let read = self
            .read
            .get_or_init(|| {
                let client = client.clone();
                async move {
                    let read: Box<dyn Any + Send + Sync> =
                        Box::new(TxnsRead::<T>::start::<C>(client, txns_id).await);
                    read
                }
            })
            .await
            .downcast_ref::<TxnsRead<T>>()
            .expect("timestamp types should match");
        // We initially only have one txns shard in the system.
        assert_eq!(&txns_id, read.txns_id());
        read.clone()
    }
}

// Existing configs use the prefix "persist_txns_" for historical reasons. New
// configs should use the prefix "txn_wal_".

pub(crate) const DATA_SHARD_RETRYER_INITIAL_BACKOFF: Config<Duration> = Config::new(
    "persist_txns_data_shard_retryer_initial_backoff",
    Duration::from_millis(1024),
    "The initial backoff when polling for new batches from a txns data shard persist_source.",
);

pub(crate) const DATA_SHARD_RETRYER_MULTIPLIER: Config<u32> = Config::new(
    "persist_txns_data_shard_retryer_multiplier",
    2,
    "The backoff multiplier when polling for new batches from a txns data shard persist_source.",
);

pub(crate) const DATA_SHARD_RETRYER_CLAMP: Config<Duration> = Config::new(
    "persist_txns_data_shard_retryer_clamp",
    Duration::from_secs(16),
    "The backoff clamp duration when polling for new batches from a txns data shard persist_source.",
);

/// Retry configuration for txn-wal data shard override of
/// `next_listen_batch`.
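///
/// With the default values above, the resulting backoff sequence is roughly:
///
/// ```text
/// 1.024s, 2.048s, 4.096s, 8.192s, 16s, 16s, ... (clamped at 16s)
/// ```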
pub fn txns_data_shard_retry_params(cfg: &ConfigSet) -> RetryParameters {
    RetryParameters {
        fixed_sleep: Duration::ZERO,
        initial_backoff: DATA_SHARD_RETRYER_INITIAL_BACKOFF.get(cfg),
        multiplier: DATA_SHARD_RETRYER_MULTIPLIER.get(cfg),
        clamp: DATA_SHARD_RETRYER_CLAMP.get(cfg),
    }
}

/// A helper for subscribing to a data shard using the timely operators.
///
/// This could instead be a wrapper around a [Subscribe], but it's only used in
/// tests and maelstrom, so do it by wrapping the timely operators to get
/// additional coverage. For the same reason, hardcode the K, V, T, D types.
///
/// [Subscribe]: mz_persist_client::read::Subscribe
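///
/// A minimal usage sketch (assuming `client`, `txns_id`, and `data_id` already
/// exist and the data shard contains `String` data at `u64` times):
///
/// ```ignore
/// let mut sub = DataSubscribe::new(
///     "example", client, txns_id, data_id, 5, Antichain::new(), true);
/// while sub.progress() <= 5 {
///     sub.step();
/// }
/// println!("{:?}", sub.output());
/// ```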
pub struct DataSubscribe {
    pub(crate) as_of: u64,
    pub(crate) worker: Worker<timely::communication::allocator::Thread>,
    data: ProbeHandle<u64>,
    txns: ProbeHandle<u64>,
    capture: mpsc::Receiver<Event<u64, Vec<(String, u64, i64)>>>,
    output: Vec<(String, u64, i64)>,

    _tokens: Vec<PressOnDropButton>,
}

impl std::fmt::Debug for DataSubscribe {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let DataSubscribe {
            as_of,
            worker: _,
            data,
            txns,
            capture: _,
            output,
            _tokens: _,
        } = self;
        f.debug_struct("DataSubscribe")
            .field("as_of", as_of)
            .field("data", data)
            .field("txns", txns)
            .field("output", output)
            .finish_non_exhaustive()
    }
}

impl DataSubscribe {
    /// Creates a new [DataSubscribe].
    pub fn new(
        name: &str,
        client: PersistClient,
        txns_id: ShardId,
        data_id: ShardId,
        as_of: u64,
        until: Antichain<u64>,
        use_global_txn_cache: bool,
    ) -> Self {
        let mut worker = Worker::new(
            WorkerConfig::default(),
            timely::communication::allocator::Thread::default(),
        );
        let (data, txns, capture, tokens) = worker.dataflow::<u64, _, _>(|scope| {
            let (data_stream, shard_source_token) = scope.scoped::<u64, _, _>("hybrid", |scope| {
                let client = client.clone();
                let (data_stream, token) = shard_source::<String, (), u64, i64, _, _, _>(
                    scope,
                    name,
                    move || std::future::ready(client.clone()),
                    data_id,
                    Some(Antichain::from_elem(as_of)),
                    SnapshotMode::Include,
                    until.clone(),
                    false.then_some(|_, _: &_, _| unreachable!()),
                    Arc::new(StringSchema),
                    Arc::new(UnitSchema),
                    FilterResult::keep_all,
                    false.then_some(|| unreachable!()),
                    async {},
                    |error| panic!("data_subscribe: {error}"),
                );
                (data_stream.leave(), token)
            });
            let (data, txns) = (ProbeHandle::new(), ProbeHandle::new());
            let data_stream = data_stream.flat_map(|part| {
                let part = part.parse();
                part.part.map(|((k, v), t, d)| {
                    let (k, ()) = (k.unwrap(), v.unwrap());
                    (k, t, d)
                })
            });
            let data_stream = data_stream.probe_with(&data);
            // We purposely do not use the `ConfigSet` in `client` so that
            // different tests can set different values.
            let config_set = ConfigSet::default().add(&USE_GLOBAL_TXN_CACHE_SOURCE);
            let mut updates = ConfigUpdates::default();
            updates.add(&USE_GLOBAL_TXN_CACHE_SOURCE, use_global_txn_cache);
            updates.apply(&config_set);
            let (data_stream, mut txns_progress_token) =
                txns_progress::<String, (), u64, i64, _, TxnsCodecDefault, _, _>(
                    data_stream,
                    name,
                    &TxnsContext::default(),
                    &config_set,
                    || std::future::ready(client.clone()),
                    txns_id,
                    data_id,
                    as_of,
                    until,
                    Arc::new(StringSchema),
                    Arc::new(UnitSchema),
                );
            let data_stream = data_stream.probe_with(&txns);
            let mut tokens = shard_source_token;
            tokens.append(&mut txns_progress_token);
            (data, txns, data_stream.capture(), tokens)
        });
        Self {
            as_of,
            worker,
            data,
            txns,
            capture,
            output: Vec::new(),
            _tokens: tokens,
        }
    }

    /// Returns the exclusive progress of the dataflow.
    pub fn progress(&self) -> u64 {
        self.txns
            .with_frontier(|f| *f.as_option().unwrap_or(&u64::MAX))
    }

    /// Steps the dataflow, capturing output.
    pub fn step(&mut self) {
        self.worker.step();
        self.capture_output()
    }

    pub(crate) fn capture_output(&mut self) {
        loop {
            let event = match self.capture.try_recv() {
                Ok(x) => x,
                Err(TryRecvError::Empty) | Err(TryRecvError::Disconnected) => break,
            };
            match event {
                Event::Progress(_) => {}
                Event::Messages(_, mut msgs) => self.output.append(&mut msgs),
            }
        }
    }

    /// Steps the dataflow past the given time, capturing output.
    #[cfg(test)]
    pub async fn step_past(&mut self, ts: u64) {
        while self.txns.less_equal(&ts) {
            tracing::trace!(
                "progress at {:?}",
                self.txns.with_frontier(|x| x.to_owned()).elements()
            );
            self.step();
            tokio::task::yield_now().await;
        }
    }

    /// Returns captured output.
    pub fn output(&self) -> &Vec<(String, u64, i64)> {
        &self.output
    }
}

/// A handle to a [DataSubscribe] running in a task.
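///
/// A minimal usage sketch (assuming `client` and the shard ids already exist):
///
/// ```ignore
/// let mut task = DataSubscribeTask::new(client, txns_id, data_id, 5).await;
/// let progress = task.step_past(7).await;
/// let all_output = task.finish().await;
/// ```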
#[derive(Debug)]
pub struct DataSubscribeTask {
    /// Carries step requests. A `None` timestamp requests one step, a
    /// `Some(ts)` requests stepping until we progress beyond `ts`.
    tx: std::sync::mpsc::Sender<(
        Option<u64>,
        tokio::sync::oneshot::Sender<(Vec<(String, u64, i64)>, u64)>,
    )>,
    task: mz_ore::task::JoinHandle<Vec<(String, u64, i64)>>,
    output: Vec<(String, u64, i64)>,
    progress: u64,
}

impl DataSubscribeTask {
    /// Creates a new [DataSubscribeTask].
    pub async fn new(
        client: PersistClient,
        txns_id: ShardId,
        data_id: ShardId,
        as_of: u64,
    ) -> Self {
        let cache = TxnsCache::open(&client, txns_id, Some(data_id)).await;
        let (tx, rx) = std::sync::mpsc::channel();
        let task = mz_ore::task::spawn_blocking(
            || "data_subscribe task",
            move || Self::task(client, cache, data_id, as_of, rx),
        );
        DataSubscribeTask {
            tx,
            task,
            output: Vec::new(),
            progress: 0,
        }
    }

    #[cfg(test)]
    async fn step(&mut self) {
        self.send(None).await;
    }

    /// Steps the dataflow past the given time, capturing output.
    pub async fn step_past(&mut self, ts: u64) -> u64 {
        self.send(Some(ts)).await;
        self.progress
    }

    /// Returns captured output.
    pub fn output(&self) -> &Vec<(String, u64, i64)> {
        &self.output
    }

    async fn send(&mut self, ts: Option<u64>) {
        let (tx, rx) = tokio::sync::oneshot::channel();
        self.tx.send((ts, tx)).expect("task should be running");
        let (mut new_output, new_progress) = rx.await.expect("task should be running");
        self.output.append(&mut new_output);
        assert!(self.progress <= new_progress);
        self.progress = new_progress;
    }

    /// Signals for the task to exit, and then waits for this to happen.
    ///
    /// _All_ output from the lifetime of the task (not just what was previously
    /// captured) is returned.
    pub async fn finish(self) -> Vec<(String, u64, i64)> {
        // Closing the channel signals the task to exit.
        drop(self.tx);
        self.task.wait_and_assert_finished().await
    }

    fn task(
        client: PersistClient,
        cache: TxnsCache<u64>,
        data_id: ShardId,
        as_of: u64,
        rx: std::sync::mpsc::Receiver<(
            Option<u64>,
            tokio::sync::oneshot::Sender<(Vec<(String, u64, i64)>, u64)>,
        )>,
    ) -> Vec<(String, u64, i64)> {
        let mut subscribe = DataSubscribe::new(
            "DataSubscribeTask",
            client.clone(),
            cache.txns_id(),
            data_id,
            as_of,
            Antichain::new(),
            true,
        );
        let mut output = Vec::new();
        loop {
            let (ts, tx) = match rx.try_recv() {
                Ok(x) => x,
                Err(TryRecvError::Empty) => {
                    // No requests, continue stepping so nothing deadlocks.
                    subscribe.step();
                    continue;
                }
                Err(TryRecvError::Disconnected) => {
                    // All done! Return our output.
                    return output;
                }
            };
            // Always step at least once.
            subscribe.step();
            // If we got a ts, make sure to step past it.
            if let Some(ts) = ts {
                while subscribe.progress() <= ts {
                    subscribe.step();
                }
            }
            let new_output = std::mem::take(&mut subscribe.output);
            output.extend(new_output.iter().cloned());
            let _ = tx.send((new_output, subscribe.progress()));
        }
    }
}

#[cfg(test)]
mod tests {
    use itertools::{Either, Itertools};
    use mz_persist_types::Opaque;

    use crate::tests::writer;
    use crate::txns::TxnsHandle;

    use super::*;

    impl<K, V, T, D, O, C> TxnsHandle<K, V, T, D, O, C>
    where
        K: Debug + Codec,
        V: Debug + Codec,
        T: Timestamp + Lattice + TotalOrder + StepForward + Codec64 + Sync,
        D: Debug + Semigroup + Ord + Codec64 + Send + Sync,
        O: Opaque + Debug + Codec64,
        C: TxnsCodec,
    {
        async fn subscribe_task(
            &self,
            client: &PersistClient,
            data_id: ShardId,
            as_of: u64,
        ) -> DataSubscribeTask {
            DataSubscribeTask::new(client.clone(), self.txns_id(), data_id, as_of).await
        }
    }

    #[mz_ore::test(tokio::test(flavor = "multi_thread"))]
    #[cfg_attr(miri, ignore)] // too slow
    async fn data_subscribe() {
        async fn step(subs: &mut Vec<DataSubscribeTask>) {
            for sub in subs.iter_mut() {
                sub.step().await;
            }
        }

        let client = PersistClient::new_for_tests().await;
        let mut txns = TxnsHandle::expect_open(client.clone()).await;
        let log = txns.new_log();
        let d0 = ShardId::new();

        // Start a subscription before the shard gets registered.
        let mut subs = Vec::new();
        subs.push(txns.subscribe_task(&client, d0, 5).await);
        step(&mut subs).await;

        // Now register the shard. Also start a new subscription and step the
        // previous one (plus repeat this for every later step).
        txns.register(1, [writer(&client, d0).await]).await.unwrap();
        subs.push(txns.subscribe_task(&client, d0, 5).await);
        step(&mut subs).await;

        // Now write something unrelated.
        let d1 = txns.expect_register(2).await;
        txns.expect_commit_at(3, d1, &["nope"], &log).await;
        subs.push(txns.subscribe_task(&client, d0, 5).await);
        step(&mut subs).await;

        // Now write to our shard before the as_of.
        txns.expect_commit_at(4, d0, &["4"], &log).await;
        subs.push(txns.subscribe_task(&client, d0, 5).await);
        step(&mut subs).await;

        // Now write to our shard at the as_of.
        txns.expect_commit_at(5, d0, &["5"], &log).await;
        subs.push(txns.subscribe_task(&client, d0, 5).await);
        step(&mut subs).await;

        // Now write to our shard past the as_of.
        txns.expect_commit_at(6, d0, &["6"], &log).await;
        subs.push(txns.subscribe_task(&client, d0, 5).await);
        step(&mut subs).await;

        // Now write something unrelated again.
        txns.expect_commit_at(7, d1, &["nope"], &log).await;
        subs.push(txns.subscribe_task(&client, d0, 5).await);
        step(&mut subs).await;

        // Verify that the dataflows can progress to the expected point and that
        // we read the right thing no matter when the dataflow started.
        for mut sub in subs {
            let progress = sub.step_past(7).await;
            assert_eq!(progress, 8);
            log.assert_eq(d0, 5, 8, sub.finish().await);
        }
    }

    #[mz_ore::test(tokio::test(flavor = "multi_thread"))]
    #[cfg_attr(miri, ignore)] // too slow
    async fn subscribe_shard_finalize() {
        let client = PersistClient::new_for_tests().await;
        let mut txns = TxnsHandle::expect_open(client.clone()).await;
        let log = txns.new_log();
        let d0 = txns.expect_register(1).await;

        // Start the operator as_of the register ts.
        let mut sub = txns.read_cache().expect_subscribe(&client, d0, 1);
        sub.step_past(1).await;

        // Write to it via txns.
        txns.expect_commit_at(2, d0, &["foo"], &log).await;
        sub.step_past(2).await;

        // Unregister it.
        txns.forget(3, [d0]).await.unwrap();
        sub.step_past(3).await;

        // TODO: Hard mode, see if we can get the rest of this test to work even
        // _without_ the txns shard advancing.
        txns.begin().commit_at(&mut txns, 7).await.unwrap();

        // The operator should continue to emit data written directly even
        // though it's no longer in the txns set.
        let mut d0_write = writer(&client, d0).await;
        let key = "bar".to_owned();
        crate::small_caa(|| "test", &mut d0_write, &[((&key, &()), &5, 1)], 4, 6)
            .await
            .unwrap();
        log.record((d0, key, 5, 1));
        sub.step_past(4).await;

        // Now finalize the shard, closing it to further writes.
        let () = d0_write
            .compare_and_append_batch(&mut [], Antichain::from_elem(6), Antichain::new())
            .await
            .unwrap()
            .unwrap();
        while sub.txns.less_than(&u64::MAX) {
            sub.step();
            tokio::task::yield_now().await;
        }

        // Make sure we read the correct things.
        log.assert_eq(d0, 1, u64::MAX, sub.output().clone());

        // Also make sure that we can read the right things if we start up after
        // the forget but before the direct write and ditto after the direct
        // write.
        log.assert_subscribe(d0, 4, u64::MAX).await;
        log.assert_subscribe(d0, 6, u64::MAX).await;
    }

    #[mz_ore::test(tokio::test(flavor = "multi_thread"))]
    #[cfg_attr(miri, ignore)] // too slow
    async fn subscribe_shard_register_forget() {
        let client = PersistClient::new_for_tests().await;
        let mut txns = TxnsHandle::expect_open(client.clone()).await;
        let d0 = ShardId::new();

        // Start a subscription on the data shard.
        let mut sub = txns.read_cache().expect_subscribe(&client, d0, 0);
        assert_eq!(sub.progress(), 0);

        // Register the shard at 10.
        txns.register(10, [writer(&client, d0).await])
            .await
            .unwrap();
        sub.step_past(10).await;
        assert!(
            sub.progress() > 10,
            "operator should advance past 10 when shard is registered"
        );

        // Forget the shard at 20.
        txns.forget(20, [d0]).await.unwrap();
        sub.step_past(20).await;
        assert!(
            sub.progress() > 20,
            "operator should advance past 20 when shard is forgotten"
        );
    }

    #[mz_ore::test(tokio::test)]
    #[cfg_attr(miri, ignore)] // too slow
    async fn as_of_until() {
        let client = PersistClient::new_for_tests().await;
        let mut txns = TxnsHandle::expect_open(client.clone()).await;
        let log = txns.new_log();

        let d0 = txns.expect_register(1).await;
        txns.expect_commit_at(2, d0, &["2"], &log).await;
        txns.expect_commit_at(3, d0, &["3"], &log).await;
        txns.expect_commit_at(4, d0, &["4"], &log).await;
        txns.expect_commit_at(5, d0, &["5"], &log).await;
        txns.expect_commit_at(6, d0, &["6"], &log).await;
        txns.expect_commit_at(7, d0, &["7"], &log).await;

        let until = 5;
        let mut sub = DataSubscribe::new(
            "as_of_until",
            client,
            txns.txns_id(),
            d0,
            3,
            Antichain::from_elem(until),
            true,
        );
        // Manually step the dataflow, instead of going through the
        // `DataSubscribe` helper, because we're interested in all captured
        // events.
        while sub.txns.less_equal(&5) {
            sub.worker.step();
            tokio::task::yield_now().await;
            tokio::time::sleep(std::time::Duration::from_millis(100)).await;
        }
        let (actual_progresses, actual_events): (Vec<_>, Vec<_>) =
            sub.capture.into_iter().partition_map(|event| match event {
                Event::Progress(progress) => Either::Left(progress),
                Event::Messages(ts, data) => Either::Right((ts, data)),
            });
        let expected = vec![
            (3, vec![("2".to_owned(), 3, 1), ("3".to_owned(), 3, 1)]),
            (3, vec![("4".to_owned(), 4, 1)]),
        ];
        assert_eq!(actual_events, expected);

        // The number and contents of progress messages are not guaranteed and
        // depend on the downgrade behavior. The only thing we can assert is
        // that the max progress timestamp, if there is one, is less than the
        // until.
        if let Some(max_progress_ts) = actual_progresses
            .into_iter()
            .flatten()
            .map(|(ts, _diff)| ts)
            .max()
        {
            assert!(max_progress_ts < until, "{max_progress_ts} < {until}");
        }
    }
}