mz_txn_wal/
operator.rs

1// Copyright Materialize, Inc. and contributors. All rights reserved.
2//
3// Use of this software is governed by the Business Source License
4// included in the LICENSE file.
5//
6// As of the Change Date specified in that file, in accordance with
7// the Business Source License, use of this software will be governed
8// by the Apache License, Version 2.0.
9
10//! Timely operators for the crate
11
12use std::any::Any;
13use std::fmt::Debug;
14use std::future::Future;
15use std::sync::mpsc::TryRecvError;
16use std::sync::{Arc, mpsc};
17use std::time::Duration;
18
19use differential_dataflow::Hashable;
20use differential_dataflow::difference::Monoid;
21use differential_dataflow::lattice::Lattice;
22use futures::StreamExt;
23use mz_dyncfg::{Config, ConfigSet};
24use mz_ore::cast::CastFrom;
25use mz_ore::task::JoinHandleExt;
26use mz_persist_client::cfg::RetryParameters;
27use mz_persist_client::operators::shard_source::{
28    ErrorHandler, FilterResult, SnapshotMode, shard_source,
29};
30use mz_persist_client::{Diagnostics, PersistClient, ShardId};
31use mz_persist_types::codec_impls::{StringSchema, UnitSchema};
32use mz_persist_types::txn::TxnsCodec;
33use mz_persist_types::{Codec, Codec64, StepForward};
34use mz_timely_util::builder_async::{
35    AsyncInputHandle, Event as AsyncEvent, InputConnection,
36    OperatorBuilder as AsyncOperatorBuilder, PressOnDropButton,
37};
38use timely::container::CapacityContainerBuilder;
39use timely::dataflow::channels::pact::Pipeline;
40use timely::dataflow::operators::capture::Event;
41use timely::dataflow::operators::{Broadcast, Capture, Leave, Map, Probe};
42use timely::dataflow::{ProbeHandle, Scope, Stream};
43use timely::order::TotalOrder;
44use timely::progress::{Antichain, Timestamp};
45use timely::worker::Worker;
46use timely::{Data, PartialOrder, WorkerConfig};
47use tracing::debug;
48
49use crate::TxnsCodecDefault;
50use crate::txn_cache::TxnsCache;
51use crate::txn_read::{DataRemapEntry, TxnsRead};
52
53/// An operator for translating physical data shard frontiers into logical ones.
54///
55/// A data shard in the txns set logically advances its upper each time a txn is
56/// committed, but the upper is not physically advanced unless that data shard
57/// was involved in the txn. This means that a shard_source (or any read)
58/// pointed at a data shard would appear to stall at the time of the most recent
59/// write. We fix this for shard_source by flowing its output through a new
60/// `txns_progress` dataflow operator, which ensures that the
61/// frontier/capability is advanced as the txns shard progresses, as long as the
62/// shard_source is up to date with the latest committed write to that data
63/// shard.
64///
65/// Example:
66///
67/// - A data shard has most recently been written to at 3.
68/// - The txns shard's upper is at 6.
69/// - We render a dataflow containing a shard_source with an as_of of 5.
70/// - A txn NOT involving the data shard is committed at 7.
71/// - A txn involving the data shard is committed at 9.
72///
73/// How it works:
74///
75/// - The shard_source operator is rendered. Its single output is hooked up as a
76///   _disconnected_ input to txns_progress. The txns_progress single output is
77///   a stream of the same type, which is used by downstream operators. This
78///   txns_progress operator is targeted at one data_shard; rendering a
79///   shard_source for a second data shard requires a second txns_progress
80///   operator.
81/// - The shard_source operator emits data through 3 and advances the frontier.
82/// - The txns_progress operator passes through these writes and frontier
83///   advancements unchanged. (Recall that it's always correct to read a data
84///   shard "normally", it just might stall.) Because the txns_progress operator
85///   knows there are no writes in `[3,5]`, it then downgrades its own
86///   capability past 5 (to 6). Because the input is disconnected, this means
87///   the overall frontier of the output is downgraded to 6.
88/// - The txns_progress operator learns about the write at 7 (the upper is now
89///   8). Because it knows that the data shard was not involved in this, it's
90///   free to downgrade its capability to 8.
91/// - The txns_progress operator learns about the write at 9 (the upper is now
92///   10). It knows that the data shard _WAS_ involved in this, so it forwards
93///   on data from its input until the input has progressed to 10, at which
94///   point it can itself downgrade to 10.
95pub fn txns_progress<K, V, T, D, P, C, F, G>(
96    passthrough: Stream<G, P>,
97    name: &str,
98    ctx: &TxnsContext,
99    client_fn: impl Fn() -> F,
100    txns_id: ShardId,
101    data_id: ShardId,
102    as_of: T,
103    until: Antichain<T>,
104    data_key_schema: Arc<K::Schema>,
105    data_val_schema: Arc<V::Schema>,
106) -> (Stream<G, P>, Vec<PressOnDropButton>)
107where
108    K: Debug + Codec + Send + Sync,
109    V: Debug + Codec + Send + Sync,
110    T: Timestamp + Lattice + TotalOrder + StepForward + Codec64 + Sync,
111    D: Debug + Data + Monoid + Ord + Codec64 + Send + Sync,
112    P: Debug + Data,
113    C: TxnsCodec + 'static,
114    F: Future<Output = PersistClient> + Send + 'static,
115    G: Scope<Timestamp = T>,
116{
117    let unique_id = (name, passthrough.scope().addr()).hashed();
118    let (remap, source_button) = txns_progress_source_global::<K, V, T, D, P, C, G>(
119        passthrough.scope(),
120        name,
121        ctx.clone(),
122        client_fn(),
123        txns_id,
124        data_id,
125        as_of,
126        data_key_schema,
127        data_val_schema,
128        unique_id,
129    );
130    // Each of the `txns_frontiers` workers wants the full copy of the remap
131    // information.
132    let remap = remap.broadcast();
133    let (passthrough, frontiers_button) = txns_progress_frontiers::<K, V, T, D, P, C, G>(
134        remap,
135        passthrough,
136        name,
137        data_id,
138        until,
139        unique_id,
140    );
141    (passthrough, vec![source_button, frontiers_button])
142}
143
/// Renders the operator that produces the stream of [`DataRemapEntry`]s for a
/// single data shard.
///
/// The operator body only runs on one worker, chosen by hashing `name`; every
/// other worker returns immediately. Returns the remap stream and a button
/// that shuts the operator down when dropped.
///
/// TODO: I'd much prefer the communication protocol between the two operators
/// to be exactly remap as defined in the [reclocking design doc]. However, we
/// can't quite recover exactly the information necessary to construct that at
/// the moment. Seems worth doing, but in the meantime, intentionally make this
/// look fairly different (`Stream` of `DataRemapEntry` instead of
/// `Collection<FromTime>`) to hopefully minimize confusion. As a performance
/// optimization, we only re-emit this when the _physical_ upper has changed,
/// which means that the frontier of the `Stream<DataRemapEntry<T>>` indicates
/// updates to the logical_upper of the most recent `DataRemapEntry` (i.e. the
/// one with the largest physical_upper).
///
/// [reclocking design doc]:
///     https://github.com/MaterializeInc/materialize/blob/main/doc/developer/design/20210714_reclocking.md
fn txns_progress_source_global<K, V, T, D, P, C, G>(
    scope: G,
    name: &str,
    ctx: TxnsContext,
    client: impl Future<Output = PersistClient> + 'static,
    txns_id: ShardId,
    data_id: ShardId,
    as_of: T,
    data_key_schema: Arc<K::Schema>,
    data_val_schema: Arc<V::Schema>,
    unique_id: u64,
) -> (Stream<G, DataRemapEntry<T>>, PressOnDropButton)
where
    K: Debug + Codec + Send + Sync,
    V: Debug + Codec + Send + Sync,
    T: Timestamp + Lattice + TotalOrder + StepForward + Codec64 + Sync,
    D: Debug + Data + Monoid + Ord + Codec64 + Send + Sync,
    P: Debug + Data,
    C: TxnsCodec + 'static,
    G: Scope<Timestamp = T>,
{
    let worker_idx = scope.index();
    // Deterministically pick the one worker that does the actual work.
    let chosen_worker = usize::cast_from(name.hashed()) % scope.peers();
    let name = format!("txns_progress_source({})", name);
    let mut builder = AsyncOperatorBuilder::new(name.clone(), scope);
    // The longer name is only used in log lines; `{:.9}` truncates the shard
    // id string to its first 9 characters.
    let name = format!("{} [{}] {:.9}", name, unique_id, data_id.to_string());
    let (remap_output, remap_stream) = builder.new_output::<CapacityContainerBuilder<_>>();

    let shutdown_button = builder.build(move |capabilities| async move {
        if worker_idx != chosen_worker {
            return;
        }

        let [mut cap]: [_; 1] = capabilities.try_into().expect("one capability per output");
        let client = client.await;
        let txns_read = ctx.get_or_init::<T, C>(&client, txns_id).await;

        // NOTE(review): presumably this waits until the txns shard has
        // progressed past `as_of`; the returned value is intentionally
        // ignored. Confirm against `TxnsRead::update_gt`.
        let _ = txns_read.update_gt(as_of.clone()).await;
        // The write handle is handed to `data_subscribe` below. Per its
        // diagnostics purpose, it is only used to read the data shard's
        // physical upper.
        let data_write = client
            .open_writer::<K, V, T, D>(
                data_id,
                Arc::clone(&data_key_schema),
                Arc::clone(&data_val_schema),
                Diagnostics::from_purpose("data read physical upper"),
            )
            .await
            .expect("schema shouldn't change");
        let mut rx = txns_read
            .data_subscribe(data_id, as_of.clone(), data_write)
            .await;
        debug!("{} starting as_of={:?}", name, as_of);

        // The physical upper of the most recently emitted entry.
        let mut physical_upper = T::minimum();

        while let Some(remap) = rx.recv().await {
            // The physical upper must never regress, and the logical upper
            // must be strictly beyond the last physical upper we emitted.
            assert!(physical_upper <= remap.physical_upper);
            assert!(physical_upper < remap.logical_upper);

            let logical_upper = remap.logical_upper.clone();
            // As mentioned in the docs on this function, we only
            // emit updates when the physical upper changes (which
            // happens to make the protocol a tiny bit more
            // remap-like).
            if remap.physical_upper != physical_upper {
                physical_upper = remap.physical_upper.clone();
                debug!("{} emitting {:?}", name, remap);
                remap_output.give(&cap, remap);
            } else {
                debug!("{} not emitting {:?}", name, remap);
            }
            // Whether or not an entry was emitted, the output frontier
            // communicates the latest logical upper to the consumer.
            cap.downgrade(&logical_upper);
        }
    });
    (remap_stream, shutdown_button.press_on_drop())
}
232
/// Renders the operator that forwards `passthrough` data unchanged while
/// translating its frontier from the data shard's physical upper to its
/// logical upper, as described by the entries arriving on `remap`.
///
/// Both inputs are attached as disconnected inputs, so the single capability
/// held by the operator body is what is downgraded to advance the output.
/// The operator shuts down once the passthrough input reaches the empty
/// frontier or has progressed to `until`. Returns the output stream and a
/// button that shuts the operator down when dropped.
fn txns_progress_frontiers<K, V, T, D, P, C, G>(
    remap: Stream<G, DataRemapEntry<T>>,
    passthrough: Stream<G, P>,
    name: &str,
    data_id: ShardId,
    until: Antichain<T>,
    unique_id: u64,
) -> (Stream<G, P>, PressOnDropButton)
where
    K: Debug + Codec,
    V: Debug + Codec,
    T: Timestamp + Lattice + TotalOrder + StepForward + Codec64,
    D: Data + Monoid + Codec64 + Send + Sync,
    P: Debug + Data,
    C: TxnsCodec,
    G: Scope<Timestamp = T>,
{
    let name = format!("txns_progress_frontiers({})", name);
    let mut builder = AsyncOperatorBuilder::new(name.clone(), passthrough.scope());
    // The longer name (with worker index/peers and a truncated shard id) is
    // only used in log lines.
    let name = format!(
        "{} [{}] {}/{} {:.9}",
        name,
        unique_id,
        passthrough.scope().index(),
        passthrough.scope().peers(),
        data_id.to_string(),
    );
    let (passthrough_output, passthrough_stream) =
        builder.new_output::<CapacityContainerBuilder<_>>();
    let mut remap_input = builder.new_disconnected_input(&remap, Pipeline);
    let mut passthrough_input = builder.new_disconnected_input(&passthrough, Pipeline);

    let shutdown_button = builder.build(move |capabilities| async move {
        let [mut cap]: [_; 1] = capabilities.try_into().expect("one capability per output");

        // None is used to indicate that both uppers are the empty antichain.
        let mut remap = Some(DataRemapEntry {
            physical_upper: T::minimum(),
            logical_upper: T::minimum(),
        });
        // NB: The following loop uses `cap.time()` to track how far we've
        // progressed in copying along the passthrough input.
        loop {
            debug!("{} remap {:?}", name, remap);
            if let Some(r) = remap.as_ref() {
                assert!(r.physical_upper <= r.logical_upper);
                // If we've passed through data to at least `physical_upper`,
                // then it means we can artificially advance the upper of the
                // output to `logical_upper`. This also indicates that we need
                // to wait for the next DataRemapEntry. It can either (A) have
                // the same physical upper or (B) have a larger physical upper.
                //
                // - If (A), then we would again satisfy this `physical_upper`
                //   check, again advance the logical upper, ...
                // - If (B), then we'd fall down to the code below, which copies
                //   the passthrough data until the frontier passes
                //   `physical_upper`, then loops back up here.
                if r.physical_upper.less_equal(cap.time()) {
                    if cap.time() < &r.logical_upper {
                        cap.downgrade(&r.logical_upper);
                    }
                    // A `None` result (remap input closed) disables this
                    // branch for all later iterations, so only the
                    // passthrough input is read from then on.
                    remap = txns_progress_frontiers_read_remap_input(
                        &name,
                        &mut remap_input,
                        r.clone(),
                    )
                    .await;
                    continue;
                }
            }

            // This only returns None when there are no more data left. Turn it
            // into an empty frontier progress so we can re-use the shutdown
            // code below.
            let event = passthrough_input
                .next()
                .await
                .unwrap_or_else(|| AsyncEvent::Progress(Antichain::new()));
            match event {
                // NB: Ignore the data_cap because this input is disconnected.
                AsyncEvent::Data(_data_cap, mut data) => {
                    // NB: Nothing to do here for `until` because both the
                    // `shard_source` (before this operator) and
                    // `mfp_and_decode` (after this operator) do the necessary
                    // filtering.
                    debug!("{} emitting data {:?}", name, data);
                    passthrough_output.give_container(&cap, &mut data);
                }
                AsyncEvent::Progress(new_progress) => {
                    // If `until.less_equal(new_progress)`, it means that all
                    // subsequent batches will contain only times greater or
                    // equal to `until`, which means they can be dropped in
                    // their entirety.
                    //
                    // Ideally this check would live in `txns_progress_source`,
                    // but that turns out to be much more invasive (requires
                    // replacing lots of `T`s with `Antichain<T>`s). Given that
                    // we've been thinking about reworking the operators, do the
                    // easy but more wasteful thing for now.
                    if PartialOrder::less_equal(&until, &new_progress) {
                        debug!(
                            "{} progress {:?} has passed until {:?}",
                            name,
                            new_progress.elements(),
                            until.elements()
                        );
                        return;
                    }
                    // We reached the empty frontier! Shut down.
                    let Some(new_progress) = new_progress.into_option() else {
                        return;
                    };

                    // Recall that any reads of the data shard are always
                    // correct, so given that we've passed through any data
                    // from the input, that means we're free to pass through
                    // frontier updates too.
                    if cap.time() < &new_progress {
                        debug!("{} downgrading cap to {:?}", name, new_progress);
                        cap.downgrade(&new_progress);
                    }
                }
            }
        }
    });
    (passthrough_stream, shutdown_button.press_on_drop())
}
360
361async fn txns_progress_frontiers_read_remap_input<T, C>(
362    name: &str,
363    input: &mut AsyncInputHandle<T, Vec<DataRemapEntry<T>>, C>,
364    mut remap: DataRemapEntry<T>,
365) -> Option<DataRemapEntry<T>>
366where
367    T: Timestamp + TotalOrder,
368    C: InputConnection<T>,
369{
370    while let Some(event) = input.next().await {
371        let xs = match event {
372            AsyncEvent::Progress(logical_upper) => {
373                if let Some(logical_upper) = logical_upper.into_option() {
374                    if remap.logical_upper < logical_upper {
375                        remap.logical_upper = logical_upper;
376                        return Some(remap);
377                    }
378                }
379                continue;
380            }
381            AsyncEvent::Data(_cap, xs) => xs,
382        };
383        for x in xs {
384            debug!("{} got remap {:?}", name, x);
385            // Don't assume anything about the ordering.
386            if remap.logical_upper < x.logical_upper {
387                assert!(
388                    remap.physical_upper <= x.physical_upper,
389                    "previous remap physical upper {:?} is ahead of new remap physical upper {:?}",
390                    remap.physical_upper,
391                    x.physical_upper,
392                );
393                // TODO: If the physical upper has advanced, that's a very
394                // strong hint that the data shard is about to be written to.
395                // Because the data shard's upper advances sparsely (on write,
396                // but not on passage of time) which invalidates the "every 1s"
397                // assumption of the default tuning, we've had to de-tune the
398                // listen sleeps on the paired persist_source. Maybe we use "one
399                // state" to wake it up in case pubsub doesn't and remove the
400                // listen polling entirely? (NB: This would have to happen in
401                // each worker so that it's guaranteed to happen in each
402                // process.)
403                remap = x;
404            }
405        }
406        return Some(remap);
407    }
408    // remap_input is closed, which indicates the data shard is finished.
409    None
410}
411
412/// The process global [`TxnsRead`] that any operator can communicate with.
413#[derive(Default, Debug, Clone)]
414pub struct TxnsContext {
415    read: Arc<tokio::sync::OnceCell<Box<dyn Any + Send + Sync>>>,
416}
417
418impl TxnsContext {
419    async fn get_or_init<T, C>(&self, client: &PersistClient, txns_id: ShardId) -> TxnsRead<T>
420    where
421        T: Timestamp + Lattice + Codec64 + TotalOrder + StepForward + Sync,
422        C: TxnsCodec + 'static,
423    {
424        let read = self
425            .read
426            .get_or_init(|| {
427                let client = client.clone();
428                async move {
429                    let read: Box<dyn Any + Send + Sync> =
430                        Box::new(TxnsRead::<T>::start::<C>(client, txns_id).await);
431                    read
432                }
433            })
434            .await
435            .downcast_ref::<TxnsRead<T>>()
436            .expect("timestamp types should match");
437        // We initially only have one txns shard in the system.
438        assert_eq!(&txns_id, read.txns_id());
439        read.clone()
440    }
441}
442
443// Existing configs use the prefix "persist_txns_" for historical reasons. New
444// configs should use the prefix "txn_wal_".
445
/// The initial backoff when polling for new batches from a txns data shard
/// persist_source. Defaults to 1024ms.
pub(crate) const DATA_SHARD_RETRYER_INITIAL_BACKOFF: Config<Duration> = Config::new(
    "persist_txns_data_shard_retryer_initial_backoff",
    Duration::from_millis(1024),
    "The initial backoff when polling for new batches from a txns data shard persist_source.",
);

/// The backoff multiplier when polling for new batches from a txns data shard
/// persist_source. Defaults to 2.
pub(crate) const DATA_SHARD_RETRYER_MULTIPLIER: Config<u32> = Config::new(
    "persist_txns_data_shard_retryer_multiplier",
    2,
    "The backoff multiplier when polling for new batches from a txns data shard persist_source.",
);

/// The backoff clamp duration when polling for new batches from a txns data
/// shard persist_source. Defaults to 16s.
pub(crate) const DATA_SHARD_RETRYER_CLAMP: Config<Duration> = Config::new(
    "persist_txns_data_shard_retryer_clamp",
    Duration::from_secs(16),
    "The backoff clamp duration when polling for new batches from a txns data shard persist_source.",
);
463
464/// Retry configuration for txn-wal data shard override of
465/// `next_listen_batch`.
466pub fn txns_data_shard_retry_params(cfg: &ConfigSet) -> RetryParameters {
467    RetryParameters {
468        fixed_sleep: Duration::ZERO,
469        initial_backoff: DATA_SHARD_RETRYER_INITIAL_BACKOFF.get(cfg),
470        multiplier: DATA_SHARD_RETRYER_MULTIPLIER.get(cfg),
471        clamp: DATA_SHARD_RETRYER_CLAMP.get(cfg),
472    }
473}
474
/// A helper for subscribing to a data shard using the timely operators.
///
/// This could instead be a wrapper around a [Subscribe], but it's only used in
/// tests and maelstrom, so do it by wrapping the timely operators to get
/// additional coverage. For the same reason, hardcode the K, V, T, D types.
///
/// [Subscribe]: mz_persist_client::read::Subscribe
pub struct DataSubscribe {
    /// The as_of the subscription was created with.
    pub(crate) as_of: u64,
    /// The single-threaded timely worker that runs the dataflow.
    pub(crate) worker: Worker<timely::communication::allocator::Thread>,
    /// Probe attached to the decoded shard_source output, i.e. before the
    /// `txns_progress` operator.
    data: ProbeHandle<u64>,
    /// Probe attached to the output of the `txns_progress` operator.
    txns: ProbeHandle<u64>,
    /// Receives the events captured from the dataflow's final output.
    capture: mpsc::Receiver<Event<u64, Vec<(String, u64, i64)>>>,
    /// The updates drained from `capture` so far.
    output: Vec<(String, u64, i64)>,

    /// Shutdown buttons of the rendered operators; held so they live as long
    /// as this struct.
    _tokens: Vec<PressOnDropButton>,
}
492
493impl std::fmt::Debug for DataSubscribe {
494    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
495        let DataSubscribe {
496            as_of,
497            worker: _,
498            data,
499            txns,
500            capture: _,
501            output,
502            _tokens: _,
503        } = self;
504        f.debug_struct("DataSubscribe")
505            .field("as_of", as_of)
506            .field("data", data)
507            .field("txns", txns)
508            .field("output", output)
509            .finish_non_exhaustive()
510    }
511}
512
impl DataSubscribe {
    /// Creates a new [DataSubscribe].
    ///
    /// Builds a single-threaded timely worker and installs a dataflow that
    /// reads `data_id` with `shard_source`, decodes the fetched parts, and
    /// flows them through `txns_progress`. The final output is captured so it
    /// can be collected by [`Self::step`]. Nothing runs until the worker is
    /// stepped.
    pub fn new(
        name: &str,
        client: PersistClient,
        txns_id: ShardId,
        data_id: ShardId,
        as_of: u64,
        until: Antichain<u64>,
    ) -> Self {
        let mut worker = Worker::new(
            WorkerConfig::default(),
            timely::communication::allocator::Thread::default(),
            Some(std::time::Instant::now()),
        );
        let (data, txns, capture, tokens) = worker.dataflow::<u64, _, _>(|scope| {
            // The shard_source is rendered in a nested scope and its output
            // is brought back out with `leave`.
            let (data_stream, shard_source_token) = scope.scoped::<u64, _, _>("hybrid", |scope| {
                let client = client.clone();
                let (data_stream, token) = shard_source::<String, (), u64, i64, _, _, _>(
                    scope,
                    name,
                    move || std::future::ready(client.clone()),
                    data_id,
                    Some(Antichain::from_elem(as_of)),
                    SnapshotMode::Include,
                    until.clone(),
                    // `false.then_some(..)` is always `None`; the closure is
                    // only there to pin down the option's type.
                    false.then_some(|_, _: &_, _| unreachable!()),
                    Arc::new(StringSchema),
                    Arc::new(UnitSchema),
                    FilterResult::keep_all,
                    // Same trick: a typed `None`.
                    false.then_some(|| unreachable!()),
                    async {},
                    ErrorHandler::Halt("data_subscribe"),
                );
                (data_stream.leave(), token)
            });
            let (data, txns) = (ProbeHandle::new(), ProbeHandle::new());
            // Decode each fetched part into `(key, time, diff)` updates; the
            // value is always unit.
            let data_stream = data_stream.flat_map(|part| {
                let part = part.parse();
                part.part.map(|((k, v), t, d)| {
                    let (k, ()) = (k.unwrap(), v.unwrap());
                    (k, t, d)
                })
            });
            // `data` observes the frontier before txns_progress...
            let data_stream = data_stream.probe_with(&data);
            let (data_stream, mut txns_progress_token) =
                txns_progress::<String, (), u64, i64, _, TxnsCodecDefault, _, _>(
                    data_stream,
                    name,
                    &TxnsContext::default(),
                    || std::future::ready(client.clone()),
                    txns_id,
                    data_id,
                    as_of,
                    until,
                    Arc::new(StringSchema),
                    Arc::new(UnitSchema),
                );
            // ...and `txns` observes it after.
            let data_stream = data_stream.probe_with(&txns);
            let mut tokens = shard_source_token;
            tokens.append(&mut txns_progress_token);
            (data, txns, data_stream.capture(), tokens)
        });
        Self {
            as_of,
            worker,
            data,
            txns,
            capture,
            output: Vec::new(),
            _tokens: tokens,
        }
    }

    /// Returns the exclusive progress of the dataflow.
    ///
    /// This is the frontier observed after `txns_progress`, or `u64::MAX` if
    /// that frontier is empty.
    pub fn progress(&self) -> u64 {
        self.txns
            .with_frontier(|f| *f.as_option().unwrap_or(&u64::MAX))
    }

    /// Steps the dataflow, capturing output.
    pub fn step(&mut self) {
        self.worker.step();
        self.capture_output()
    }

    /// Drains every event currently available on the capture channel,
    /// appending captured messages to `output` and ignoring progress events.
    pub(crate) fn capture_output(&mut self) {
        loop {
            let event = match self.capture.try_recv() {
                Ok(x) => x,
                // Stop both when there is nothing more to read right now and
                // when the sending side is gone.
                Err(TryRecvError::Empty) | Err(TryRecvError::Disconnected) => break,
            };
            match event {
                Event::Progress(_) => {}
                Event::Messages(_, mut msgs) => self.output.append(&mut msgs),
            }
        }
    }

    /// Steps the dataflow past the given time, capturing output.
    ///
    /// Yields to the tokio runtime between steps.
    #[cfg(test)]
    pub async fn step_past(&mut self, ts: u64) {
        while self.txns.less_equal(&ts) {
            tracing::trace!(
                "progress at {:?}",
                self.txns.with_frontier(|x| x.to_owned()).elements()
            );
            self.step();
            tokio::task::yield_now().await;
        }
    }

    /// Returns captured output.
    pub fn output(&self) -> &Vec<(String, u64, i64)> {
        &self.output
    }
}
630
/// A handle to a [DataSubscribe] running in a task.
#[derive(Debug)]
pub struct DataSubscribeTask {
    /// Carries step requests. A `None` timestamp requests one step, a
    /// `Some(ts)` requests stepping until we progress beyond `ts`.
    ///
    /// Each request comes with a oneshot sender on which the task replies
    /// with the newly captured output and its current progress.
    tx: std::sync::mpsc::Sender<(
        Option<u64>,
        tokio::sync::oneshot::Sender<(Vec<(String, u64, i64)>, u64)>,
    )>,
    /// The task driving the [DataSubscribe]; it resolves to all output
    /// captured over the task's lifetime.
    task: mz_ore::task::JoinHandle<Vec<(String, u64, i64)>>,
    /// The output received from the task so far.
    output: Vec<(String, u64, i64)>,
    /// The most recent progress reported by the task.
    progress: u64,
}
644
645impl DataSubscribeTask {
646    /// Creates a new [DataSubscribeTask].
647    pub async fn new(
648        client: PersistClient,
649        txns_id: ShardId,
650        data_id: ShardId,
651        as_of: u64,
652    ) -> Self {
653        let cache = TxnsCache::open(&client, txns_id, Some(data_id)).await;
654        let (tx, rx) = std::sync::mpsc::channel();
655        let task = mz_ore::task::spawn_blocking(
656            || "data_subscribe task",
657            move || Self::task(client, cache, data_id, as_of, rx),
658        );
659        DataSubscribeTask {
660            tx,
661            task,
662            output: Vec::new(),
663            progress: 0,
664        }
665    }
666
667    #[cfg(test)]
668    async fn step(&mut self) {
669        self.send(None).await;
670    }
671
672    /// Steps the dataflow past the given time, capturing output.
673    pub async fn step_past(&mut self, ts: u64) -> u64 {
674        self.send(Some(ts)).await;
675        self.progress
676    }
677
678    /// Returns captured output.
679    pub fn output(&self) -> &Vec<(String, u64, i64)> {
680        &self.output
681    }
682
683    async fn send(&mut self, ts: Option<u64>) {
684        let (tx, rx) = tokio::sync::oneshot::channel();
685        self.tx.send((ts, tx)).expect("task should be running");
686        let (mut new_output, new_progress) = rx.await.expect("task should be running");
687        self.output.append(&mut new_output);
688        assert!(self.progress <= new_progress);
689        self.progress = new_progress;
690    }
691
692    /// Signals for the task to exit, and then waits for this to happen.
693    ///
694    /// _All_ output from the lifetime of the task (not just what was previously
695    /// captured) is returned.
696    pub async fn finish(self) -> Vec<(String, u64, i64)> {
697        // Closing the channel signals the task to exit.
698        drop(self.tx);
699        self.task.wait_and_assert_finished().await
700    }
701
702    fn task(
703        client: PersistClient,
704        cache: TxnsCache<u64>,
705        data_id: ShardId,
706        as_of: u64,
707        rx: std::sync::mpsc::Receiver<(
708            Option<u64>,
709            tokio::sync::oneshot::Sender<(Vec<(String, u64, i64)>, u64)>,
710        )>,
711    ) -> Vec<(String, u64, i64)> {
712        let mut subscribe = DataSubscribe::new(
713            "DataSubscribeTask",
714            client.clone(),
715            cache.txns_id(),
716            data_id,
717            as_of,
718            Antichain::new(),
719        );
720        let mut output = Vec::new();
721        loop {
722            let (ts, tx) = match rx.try_recv() {
723                Ok(x) => x,
724                Err(TryRecvError::Empty) => {
725                    // No requests, continue stepping so nothing deadlocks.
726                    subscribe.step();
727                    continue;
728                }
729                Err(TryRecvError::Disconnected) => {
730                    // All done! Return our output.
731                    return output;
732                }
733            };
734            // Always step at least once.
735            subscribe.step();
736            // If we got a ts, make sure to step past it.
737            if let Some(ts) = ts {
738                while subscribe.progress() <= ts {
739                    subscribe.step();
740                }
741            }
742            let new_output = std::mem::take(&mut subscribe.output);
743            output.extend(new_output.iter().cloned());
744            let _ = tx.send((new_output, subscribe.progress()));
745        }
746    }
747}
748
749#[cfg(test)]
750mod tests {
751    use itertools::{Either, Itertools};
752    use mz_persist_types::Opaque;
753
754    use crate::tests::writer;
755    use crate::txns::TxnsHandle;
756
757    use super::*;
758
759    impl<K, V, T, D, O, C> TxnsHandle<K, V, T, D, O, C>
760    where
761        K: Debug + Codec,
762        V: Debug + Codec,
763        T: Timestamp + Lattice + TotalOrder + StepForward + Codec64 + Sync,
764        D: Debug + Monoid + Ord + Codec64 + Send + Sync,
765        O: Opaque + Debug + Codec64,
766        C: TxnsCodec,
767    {
768        async fn subscribe_task(
769            &self,
770            client: &PersistClient,
771            data_id: ShardId,
772            as_of: u64,
773        ) -> DataSubscribeTask {
774            DataSubscribeTask::new(client.clone(), self.txns_id(), data_id, as_of).await
775        }
776    }
777
778    #[mz_ore::test(tokio::test(flavor = "multi_thread"))]
779    #[cfg_attr(miri, ignore)] // too slow
780    async fn data_subscribe() {
781        async fn step(subs: &mut Vec<DataSubscribeTask>) {
782            for sub in subs.iter_mut() {
783                sub.step().await;
784            }
785        }
786
787        let client = PersistClient::new_for_tests().await;
788        let mut txns = TxnsHandle::expect_open(client.clone()).await;
789        let log = txns.new_log();
790        let d0 = ShardId::new();
791
792        // Start a subscription before the shard gets registered.
793        let mut subs = Vec::new();
794        subs.push(txns.subscribe_task(&client, d0, 5).await);
795        step(&mut subs).await;
796
797        // Now register the shard. Also start a new subscription and step the
798        // previous one (plus repeat this for every later step).
799        txns.register(1, [writer(&client, d0).await]).await.unwrap();
800        subs.push(txns.subscribe_task(&client, d0, 5).await);
801        step(&mut subs).await;
802
803        // Now write something unrelated.
804        let d1 = txns.expect_register(2).await;
805        txns.expect_commit_at(3, d1, &["nope"], &log).await;
806        subs.push(txns.subscribe_task(&client, d0, 5).await);
807        step(&mut subs).await;
808
809        // Now write to our shard before.
810        txns.expect_commit_at(4, d0, &["4"], &log).await;
811        subs.push(txns.subscribe_task(&client, d0, 5).await);
812        step(&mut subs).await;
813
814        // Now write to our shard at the as_of.
815        txns.expect_commit_at(5, d0, &["5"], &log).await;
816        subs.push(txns.subscribe_task(&client, d0, 5).await);
817        step(&mut subs).await;
818
819        // Now write to our shard past the as_of.
820        txns.expect_commit_at(6, d0, &["6"], &log).await;
821        subs.push(txns.subscribe_task(&client, d0, 5).await);
822        step(&mut subs).await;
823
824        // Now write something unrelated again.
825        txns.expect_commit_at(7, d1, &["nope"], &log).await;
826        subs.push(txns.subscribe_task(&client, d0, 5).await);
827        step(&mut subs).await;
828
829        // Verify that the dataflows can progress to the expected point and that
830        // we read the right thing no matter when the dataflow started.
831        for mut sub in subs {
832            let progress = sub.step_past(7).await;
833            assert_eq!(progress, 8);
834            log.assert_eq(d0, 5, 8, sub.finish().await);
835        }
836    }
837
838    #[mz_ore::test(tokio::test(flavor = "multi_thread"))]
839    #[cfg_attr(miri, ignore)] // too slow
840    async fn subscribe_shard_finalize() {
841        let client = PersistClient::new_for_tests().await;
842        let mut txns = TxnsHandle::expect_open(client.clone()).await;
843        let log = txns.new_log();
844        let d0 = txns.expect_register(1).await;
845
846        // Start the operator as_of the register ts.
847        let mut sub = txns.read_cache().expect_subscribe(&client, d0, 1);
848        sub.step_past(1).await;
849
850        // Write to it via txns.
851        txns.expect_commit_at(2, d0, &["foo"], &log).await;
852        sub.step_past(2).await;
853
854        // Unregister it.
855        txns.forget(3, [d0]).await.unwrap();
856        sub.step_past(3).await;
857
858        // TODO: Hard mode, see if we can get the rest of this test to work even
859        // _without_ the txns shard advancing.
860        txns.begin().commit_at(&mut txns, 7).await.unwrap();
861
862        // The operator should continue to emit data written directly even
863        // though it's no longer in the txns set.
864        let mut d0_write = writer(&client, d0).await;
865        let key = "bar".to_owned();
866        crate::small_caa(|| "test", &mut d0_write, &[((&key, &()), &5, 1)], 4, 6)
867            .await
868            .unwrap();
869        log.record((d0, key, 5, 1));
870        sub.step_past(4).await;
871
872        // Now finalize the shard to writes.
873        let () = d0_write
874            .compare_and_append_batch(&mut [], Antichain::from_elem(6), Antichain::new(), true)
875            .await
876            .unwrap()
877            .unwrap();
878        while sub.txns.less_than(&u64::MAX) {
879            sub.step();
880            tokio::task::yield_now().await;
881        }
882
883        // Make sure we read the correct things.
884        log.assert_eq(d0, 1, u64::MAX, sub.output().clone());
885
886        // Also make sure that we can read the right things if we start up after
887        // the forget but before the direct write and ditto after the direct
888        // write.
889        log.assert_subscribe(d0, 4, u64::MAX).await;
890        log.assert_subscribe(d0, 6, u64::MAX).await;
891    }
892
893    #[mz_ore::test(tokio::test(flavor = "multi_thread"))]
894    #[cfg_attr(miri, ignore)] // too slow
895    async fn subscribe_shard_register_forget() {
896        let client = PersistClient::new_for_tests().await;
897        let mut txns = TxnsHandle::expect_open(client.clone()).await;
898        let d0 = ShardId::new();
899
900        // Start a subscription on the data shard.
901        let mut sub = txns.read_cache().expect_subscribe(&client, d0, 0);
902        assert_eq!(sub.progress(), 0);
903
904        // Register the shard at 10.
905        txns.register(10, [writer(&client, d0).await])
906            .await
907            .unwrap();
908        sub.step_past(10).await;
909        assert!(
910            sub.progress() > 10,
911            "operator should advance past 10 when shard is registered"
912        );
913
914        // Forget the shard at 20.
915        txns.forget(20, [d0]).await.unwrap();
916        sub.step_past(20).await;
917        assert!(
918            sub.progress() > 20,
919            "operator should advance past 20 when shard is forgotten"
920        );
921    }
922
923    #[mz_ore::test(tokio::test)]
924    #[cfg_attr(miri, ignore)] // too slow
925    async fn as_of_until() {
926        let client = PersistClient::new_for_tests().await;
927        let mut txns = TxnsHandle::expect_open(client.clone()).await;
928        let log = txns.new_log();
929
930        let d0 = txns.expect_register(1).await;
931        txns.expect_commit_at(2, d0, &["2"], &log).await;
932        txns.expect_commit_at(3, d0, &["3"], &log).await;
933        txns.expect_commit_at(4, d0, &["4"], &log).await;
934        txns.expect_commit_at(5, d0, &["5"], &log).await;
935        txns.expect_commit_at(6, d0, &["6"], &log).await;
936        txns.expect_commit_at(7, d0, &["7"], &log).await;
937
938        let until = 5;
939        let mut sub = DataSubscribe::new(
940            "as_of_until",
941            client,
942            txns.txns_id(),
943            d0,
944            3,
945            Antichain::from_elem(until),
946        );
947        // Manually step the dataflow, instead of going through the
948        // `DataSubscribe` helper because we're interested in all captured
949        // events.
950        while sub.txns.less_equal(&5) {
951            sub.worker.step();
952            tokio::task::yield_now().await;
953            tokio::time::sleep(std::time::Duration::from_millis(100)).await;
954        }
955        let (actual_progresses, actual_events): (Vec<_>, Vec<_>) =
956            sub.capture.into_iter().partition_map(|event| match event {
957                Event::Progress(progress) => Either::Left(progress),
958                Event::Messages(ts, data) => Either::Right((ts, data)),
959            });
960        let expected = vec![
961            (3, vec![("2".to_owned(), 3, 1), ("3".to_owned(), 3, 1)]),
962            (3, vec![("4".to_owned(), 4, 1)]),
963        ];
964        assert_eq!(actual_events, expected);
965
966        // The number and contents of progress messages is not guaranteed and
967        // depends on the downgrade behavior. The only thing we can assert is
968        // the max progress timestamp, if there is one, is less than the until.
969        if let Some(max_progress_ts) = actual_progresses
970            .into_iter()
971            .flatten()
972            .map(|(ts, _diff)| ts)
973            .max()
974        {
975            assert!(max_progress_ts < until, "{max_progress_ts} < {until}");
976        }
977    }
978}