mz_txn_wal/operator.rs

// Copyright Materialize, Inc. and contributors. All rights reserved.
//
// Use of this software is governed by the Business Source License
// included in the LICENSE file.
//
// As of the Change Date specified in that file, in accordance with
// the Business Source License, use of this software will be governed
// by the Apache License, Version 2.0.

//! Timely operators for the crate

use std::any::Any;
use std::fmt::Debug;
use std::future::Future;
use std::sync::mpsc::TryRecvError;
use std::sync::{Arc, mpsc};
use std::time::Duration;

use differential_dataflow::Hashable;
use differential_dataflow::difference::Monoid;
use differential_dataflow::lattice::Lattice;
use futures::StreamExt;
use mz_dyncfg::{Config, ConfigSet};
use mz_ore::cast::CastFrom;
use mz_persist_client::cfg::RetryParameters;
use mz_persist_client::operators::shard_source::{
    ErrorHandler, FilterResult, SnapshotMode, shard_source,
};
use mz_persist_client::{Diagnostics, PersistClient, ShardId};
use mz_persist_types::codec_impls::{StringSchema, UnitSchema};
use mz_persist_types::txn::TxnsCodec;
use mz_persist_types::{Codec, Codec64, StepForward};
use mz_timely_util::builder_async::{
    AsyncInputHandle, Event as AsyncEvent, InputConnection,
    OperatorBuilder as AsyncOperatorBuilder, PressOnDropButton,
};
use timely::container::CapacityContainerBuilder;
use timely::dataflow::channels::pact::Pipeline;
use timely::dataflow::operators::capture::Event;
use timely::dataflow::operators::{Broadcast, Capture, Leave, Map, Probe};
use timely::dataflow::{ProbeHandle, Scope, Stream};
use timely::order::TotalOrder;
use timely::progress::{Antichain, Timestamp};
use timely::worker::Worker;
use timely::{Data, PartialOrder, WorkerConfig};
use tracing::debug;

use crate::TxnsCodecDefault;
use crate::txn_cache::TxnsCache;
use crate::txn_read::{DataRemapEntry, TxnsRead};

/// An operator for translating physical data shard frontiers into logical ones.
///
/// A data shard in the txns set logically advances its upper each time a txn is
/// committed, but the upper is not physically advanced unless that data shard
/// was involved in the txn. This means that a shard_source (or any read)
/// pointed at a data shard would appear to stall at the time of the most recent
/// write. We fix this for shard_source by flowing its output through a new
/// `txns_progress` dataflow operator, which ensures that the
/// frontier/capability is advanced as the txns shard progresses, as long as the
/// shard_source is up to date with the latest committed write to that data
/// shard.
///
/// Example:
///
/// - A data shard has most recently been written to at 3.
/// - The txns shard's upper is at 6.
/// - We render a dataflow containing a shard_source with an as_of of 5.
/// - A txn NOT involving the data shard is committed at 7.
/// - A txn involving the data shard is committed at 9.
///
/// How it works:
///
/// - The shard_source operator is rendered. Its single output is hooked up as a
///   _disconnected_ input to txns_progress. The txns_progress single output is
///   a stream of the same type, which is used by downstream operators. This
///   txns_progress operator is targeted at one data_shard; rendering a
///   shard_source for a second data shard requires a second txns_progress
///   operator.
/// - The shard_source operator emits data through 3 and advances the frontier.
/// - The txns_progress operator passes through these writes and frontier
///   advancements unchanged. (Recall that it's always correct to read a data
///   shard "normally", it just might stall.) Because the txns_progress operator
///   knows there are no writes in `[3,5]`, it then downgrades its own
///   capability past 5 (to 6). Because the input is disconnected, this means
///   the overall frontier of the output is downgraded to 6.
/// - The txns_progress operator learns about the write at 7 (the upper is now
///   8). Because it knows that the data shard was not involved in this, it's
///   free to downgrade its capability to 8.
/// - The txns_progress operator learns about the write at 9 (the upper is now
///   10). It knows that the data shard _WAS_ involved in this, so it forwards
///   on data from its input until the input has progressed to 10, at which
///   point it can itself downgrade to 10.
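///
/// A hedged sketch of how this is typically wired up (mirroring the usage in
/// [`DataSubscribe`] below; the `shard_source` arguments are elided and the
/// variable names are illustrative):
///
/// ```ignore
/// // Render the physical source, then route its output through
/// // `txns_progress` so downstream operators see logical progress.
/// let (data_stream, source_token) = shard_source::<K, V, T, D, _, _, _>(/* ... */);
/// let (data_stream, progress_tokens) =
///     txns_progress::<K, V, T, D, _, TxnsCodecDefault, _, _>(
///         data_stream,
///         name,
///         &TxnsContext::default(),
///         || std::future::ready(client.clone()),
///         txns_id,
///         data_id,
///         as_of,
///         until,
///         key_schema,
///         val_schema,
///     );
/// ```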
pub fn txns_progress<K, V, T, D, P, C, F, G>(
    passthrough: Stream<G, P>,
    name: &str,
    ctx: &TxnsContext,
    client_fn: impl Fn() -> F,
    txns_id: ShardId,
    data_id: ShardId,
    as_of: T,
    until: Antichain<T>,
    data_key_schema: Arc<K::Schema>,
    data_val_schema: Arc<V::Schema>,
) -> (Stream<G, P>, Vec<PressOnDropButton>)
where
    K: Debug + Codec + Send + Sync,
    V: Debug + Codec + Send + Sync,
    T: Timestamp + Lattice + TotalOrder + StepForward + Codec64 + Sync,
    D: Debug + Data + Monoid + Ord + Codec64 + Send + Sync,
    P: Debug + Data,
    C: TxnsCodec + 'static,
    F: Future<Output = PersistClient> + Send + 'static,
    G: Scope<Timestamp = T>,
{
    let unique_id = (name, passthrough.scope().addr()).hashed();
    let (remap, source_button) = txns_progress_source_global::<K, V, T, D, P, C, G>(
        passthrough.scope(),
        name,
        ctx.clone(),
        client_fn(),
        txns_id,
        data_id,
        as_of,
        data_key_schema,
        data_val_schema,
        unique_id,
    );
    // Each of the `txns_progress_frontiers` workers wants a full copy of the
    // remap information.
    let remap = remap.broadcast();
    let (passthrough, frontiers_button) = txns_progress_frontiers::<K, V, T, D, P, C, G>(
        remap,
        passthrough,
        name,
        data_id,
        until,
        unique_id,
    );
    (passthrough, vec![source_button, frontiers_button])
}

/// TODO: I'd much prefer the communication protocol between the two operators
/// to be exactly remap as defined in the [reclocking design doc]. However, we
/// can't quite recover exactly the information necessary to construct that at
/// the moment. Seems worth doing, but in the meantime, intentionally make this
/// look fairly different (`Stream` of `DataRemapEntry` instead of
/// `Collection<FromTime>`) to hopefully minimize confusion. As a performance
/// optimization, we only re-emit this when the _physical_ upper has changed,
/// which means that the frontier of the `Stream<DataRemapEntry<T>>` indicates
/// updates to the logical_upper of the most recent `DataRemapEntry` (i.e. the
/// one with the largest physical_upper).
///
/// [reclocking design doc]:
///     https://github.com/MaterializeInc/materialize/blob/main/doc/developer/design/20210714_reclocking.md
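///
/// For example (a sketch of the protocol, not captured output): given the
/// sequence of entries
///
/// ```text
/// DataRemapEntry { physical_upper: 3, logical_upper: 5 }
/// DataRemapEntry { physical_upper: 3, logical_upper: 7 }
/// DataRemapEntry { physical_upper: 9, logical_upper: 10 }
/// ```
///
/// only the first and third are emitted as data; the second surfaces solely as
/// the output frontier advancing to 7.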
fn txns_progress_source_global<K, V, T, D, P, C, G>(
    scope: G,
    name: &str,
    ctx: TxnsContext,
    client: impl Future<Output = PersistClient> + 'static,
    txns_id: ShardId,
    data_id: ShardId,
    as_of: T,
    data_key_schema: Arc<K::Schema>,
    data_val_schema: Arc<V::Schema>,
    unique_id: u64,
) -> (Stream<G, DataRemapEntry<T>>, PressOnDropButton)
where
    K: Debug + Codec + Send + Sync,
    V: Debug + Codec + Send + Sync,
    T: Timestamp + Lattice + TotalOrder + StepForward + Codec64 + Sync,
    D: Debug + Data + Monoid + Ord + Codec64 + Send + Sync,
    P: Debug + Data,
    C: TxnsCodec + 'static,
    G: Scope<Timestamp = T>,
{
    let worker_idx = scope.index();
    let chosen_worker = usize::cast_from(name.hashed()) % scope.peers();
    let name = format!("txns_progress_source({})", name);
    let mut builder = AsyncOperatorBuilder::new(name.clone(), scope);
    let name = format!("{} [{}] {:.9}", name, unique_id, data_id.to_string());
    let (remap_output, remap_stream) = builder.new_output::<CapacityContainerBuilder<_>>();

    let shutdown_button = builder.build(move |capabilities| async move {
        if worker_idx != chosen_worker {
            return;
        }

        let [mut cap]: [_; 1] = capabilities.try_into().expect("one capability per output");
        let client = client.await;
        let txns_read = ctx.get_or_init::<T, C>(&client, txns_id).await;

        let _ = txns_read.update_gt(as_of.clone()).await;
        let data_write = client
            .open_writer::<K, V, T, D>(
                data_id,
                Arc::clone(&data_key_schema),
                Arc::clone(&data_val_schema),
                Diagnostics::from_purpose("data read physical upper"),
            )
            .await
            .expect("schema shouldn't change");
        let mut rx = txns_read
            .data_subscribe(data_id, as_of.clone(), data_write)
            .await;
        debug!("{} starting as_of={:?}", name, as_of);

        let mut physical_upper = T::minimum();

        while let Some(remap) = rx.recv().await {
            assert!(physical_upper <= remap.physical_upper);
            assert!(physical_upper < remap.logical_upper);

            let logical_upper = remap.logical_upper.clone();
            // As mentioned in the docs on this function, we only
            // emit updates when the physical upper changes (which
            // happens to make the protocol a tiny bit more
            // remap-like).
            if remap.physical_upper != physical_upper {
                physical_upper = remap.physical_upper.clone();
                debug!("{} emitting {:?}", name, remap);
                remap_output.give(&cap, remap);
            } else {
                debug!("{} not emitting {:?}", name, remap);
            }
            cap.downgrade(&logical_upper);
        }
    });
    (remap_stream, shutdown_button.press_on_drop())
}
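/// Passes through the data in `passthrough` unchanged, while using the
/// `DataRemapEntry`s from [`txns_progress_source_global`] to advance the
/// output frontier: once the passthrough input has caught up to a remap
/// entry's `physical_upper`, the output capability can be downgraded to that
/// entry's `logical_upper`.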
fn txns_progress_frontiers<K, V, T, D, P, C, G>(
    remap: Stream<G, DataRemapEntry<T>>,
    passthrough: Stream<G, P>,
    name: &str,
    data_id: ShardId,
    until: Antichain<T>,
    unique_id: u64,
) -> (Stream<G, P>, PressOnDropButton)
where
    K: Debug + Codec,
    V: Debug + Codec,
    T: Timestamp + Lattice + TotalOrder + StepForward + Codec64,
    D: Data + Monoid + Codec64 + Send + Sync,
    P: Debug + Data,
    C: TxnsCodec,
    G: Scope<Timestamp = T>,
{
    let name = format!("txns_progress_frontiers({})", name);
    let mut builder = AsyncOperatorBuilder::new(name.clone(), passthrough.scope());
    let name = format!(
        "{} [{}] {}/{} {:.9}",
        name,
        unique_id,
        passthrough.scope().index(),
        passthrough.scope().peers(),
        data_id.to_string(),
    );
    let (passthrough_output, passthrough_stream) =
        builder.new_output::<CapacityContainerBuilder<_>>();
    let mut remap_input = builder.new_disconnected_input(&remap, Pipeline);
    let mut passthrough_input = builder.new_disconnected_input(&passthrough, Pipeline);

    let shutdown_button = builder.build(move |capabilities| async move {
        let [mut cap]: [_; 1] = capabilities.try_into().expect("one capability per output");

        // None is used to indicate that both uppers are the empty antichain.
        let mut remap = Some(DataRemapEntry {
            physical_upper: T::minimum(),
            logical_upper: T::minimum(),
        });
        // NB: The following loop uses `cap.time()` to track how far we've
        // progressed in copying along the passthrough input.
        loop {
            debug!("{} remap {:?}", name, remap);
            if let Some(r) = remap.as_ref() {
                assert!(r.physical_upper <= r.logical_upper);
                // If we've passed through data to at least `physical_upper`,
                // then it means we can artificially advance the upper of the
                // output to `logical_upper`. This also indicates that we need
                // to wait for the next DataRemapEntry. It can either (A) have
                // the same physical upper or (B) have a larger physical upper.
                //
                // - If (A), then we again satisfy this `physical_upper` check,
                //   advance the logical upper again, and so on.
                // - If (B), then we fall down to the code below, which copies
                //   the passthrough data until the frontier passes
                //   `physical_upper`, then loops back up here.
                if r.physical_upper.less_equal(cap.time()) {
                    if cap.time() < &r.logical_upper {
                        cap.downgrade(&r.logical_upper);
                    }
                    remap = txns_progress_frontiers_read_remap_input(
                        &name,
                        &mut remap_input,
                        r.clone(),
                    )
                    .await;
                    continue;
                }
            }

            // This only returns None when there is no data left. Turn it
            // into an empty frontier progress so we can re-use the shutdown
            // code below.
            let event = passthrough_input
                .next()
                .await
                .unwrap_or_else(|| AsyncEvent::Progress(Antichain::new()));
            match event {
                // NB: Ignore the data_cap because this input is disconnected.
                AsyncEvent::Data(_data_cap, mut data) => {
                    // NB: Nothing to do here for `until` because both the
                    // `shard_source` (before this operator) and
                    // `mfp_and_decode` (after this operator) do the necessary
                    // filtering.
                    debug!("{} emitting data {:?}", name, data);
                    passthrough_output.give_container(&cap, &mut data);
                }
                AsyncEvent::Progress(new_progress) => {
                    // If `until.less_equal(new_progress)`, it means that all
                    // subsequent batches will contain only times greater than
                    // or equal to `until`, which means they can be dropped in
                    // their entirety.
                    //
                    // Ideally this check would live in `txns_progress_source`,
                    // but that turns out to be much more invasive (requires
                    // replacing lots of `T`s with `Antichain<T>`s). Given that
                    // we've been thinking about reworking the operators, do the
                    // easy but more wasteful thing for now.
                    if PartialOrder::less_equal(&until, &new_progress) {
                        debug!(
                            "{} progress {:?} has passed until {:?}",
                            name,
                            new_progress.elements(),
                            until.elements()
                        );
                        return;
                    }
                    // If the frontier is empty, we're done! Shut down.
                    let Some(new_progress) = new_progress.into_option() else {
                        return;
                    };

                    // Recall that any reads of the data shard are always
                    // correct, so given that we've passed through any data
                    // from the input, that means we're free to pass through
                    // frontier updates too.
                    if cap.time() < &new_progress {
                        debug!("{} downgrading cap to {:?}", name, new_progress);
                        cap.downgrade(&new_progress);
                    }
                }
            }
        }
    });
    (passthrough_stream, shutdown_button.press_on_drop())
}
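/// Reads `remap_input` until the `DataRemapEntry` advances past the given one
/// (in either upper), returning the updated entry. Returns `None` once the
/// input is exhausted, which indicates that the data shard is finished.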
async fn txns_progress_frontiers_read_remap_input<T, C>(
    name: &str,
    input: &mut AsyncInputHandle<T, Vec<DataRemapEntry<T>>, C>,
    mut remap: DataRemapEntry<T>,
) -> Option<DataRemapEntry<T>>
where
    T: Timestamp + TotalOrder,
    C: InputConnection<T>,
{
    while let Some(event) = input.next().await {
        let xs = match event {
            AsyncEvent::Progress(logical_upper) => {
                if let Some(logical_upper) = logical_upper.into_option() {
                    if remap.logical_upper < logical_upper {
                        remap.logical_upper = logical_upper;
                        return Some(remap);
                    }
                }
                continue;
            }
            AsyncEvent::Data(_cap, xs) => xs,
        };
        for x in xs {
            debug!("{} got remap {:?}", name, x);
            // Don't assume anything about the ordering.
            if remap.logical_upper < x.logical_upper {
                assert!(
                    remap.physical_upper <= x.physical_upper,
                    "previous remap physical upper {:?} is ahead of new remap physical upper {:?}",
                    remap.physical_upper,
                    x.physical_upper,
                );
                // TODO: If the physical upper has advanced, that's a very
                // strong hint that the data shard is about to be written to.
                // Because the data shard's upper advances sparsely (on write,
                // but not on passage of time) which invalidates the "every 1s"
                // assumption of the default tuning, we've had to de-tune the
                // listen sleeps on the paired persist_source. Maybe we use "one
                // state" to wake it up in case pubsub doesn't and remove the
                // listen polling entirely? (NB: This would have to happen in
                // each worker so that it's guaranteed to happen in each
                // process.)
                remap = x;
            }
        }
        return Some(remap);
    }
    // remap_input is closed, which indicates the data shard is finished.
    None
}

/// The process-global [`TxnsRead`] that any operator can communicate with.
#[derive(Default, Debug, Clone)]
pub struct TxnsContext {
    read: Arc<tokio::sync::OnceCell<Box<dyn Any + Send + Sync>>>,
}

impl TxnsContext {
    async fn get_or_init<T, C>(&self, client: &PersistClient, txns_id: ShardId) -> TxnsRead<T>
    where
        T: Timestamp + Lattice + Codec64 + TotalOrder + StepForward + Sync,
        C: TxnsCodec + 'static,
    {
        let read = self
            .read
            .get_or_init(|| {
                let client = client.clone();
                async move {
                    let read: Box<dyn Any + Send + Sync> =
                        Box::new(TxnsRead::<T>::start::<C>(client, txns_id).await);
                    read
                }
            })
            .await
            .downcast_ref::<TxnsRead<T>>()
            .expect("timestamp types should match");
        // We initially only have one txns shard in the system.
        assert_eq!(&txns_id, read.txns_id());
        read.clone()
    }
}

// Existing configs use the prefix "persist_txns_" for historical reasons. New
// configs should use the prefix "txn_wal_".

pub(crate) const DATA_SHARD_RETRYER_INITIAL_BACKOFF: Config<Duration> = Config::new(
    "persist_txns_data_shard_retryer_initial_backoff",
    Duration::from_millis(1024),
    "The initial backoff when polling for new batches from a txns data shard persist_source.",
);

pub(crate) const DATA_SHARD_RETRYER_MULTIPLIER: Config<u32> = Config::new(
    "persist_txns_data_shard_retryer_multiplier",
    2,
    "The backoff multiplier when polling for new batches from a txns data shard persist_source.",
);

pub(crate) const DATA_SHARD_RETRYER_CLAMP: Config<Duration> = Config::new(
    "persist_txns_data_shard_retryer_clamp",
    Duration::from_secs(16),
    "The backoff clamp duration when polling for new batches from a txns data shard persist_source.",
);

/// Retry configuration for the txn-wal data shard override of
/// `next_listen_batch`.
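///
/// With the default configs above, the successive backoffs are roughly 1s, 2s,
/// 4s, 8s, and then 16s per retry thereafter (the clamp).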
pub fn txns_data_shard_retry_params(cfg: &ConfigSet) -> RetryParameters {
    RetryParameters {
        fixed_sleep: Duration::ZERO,
        initial_backoff: DATA_SHARD_RETRYER_INITIAL_BACKOFF.get(cfg),
        multiplier: DATA_SHARD_RETRYER_MULTIPLIER.get(cfg),
        clamp: DATA_SHARD_RETRYER_CLAMP.get(cfg),
    }
}

/// A helper for subscribing to a data shard using the timely operators.
///
/// This could instead be a wrapper around a [Subscribe], but it's only used in
/// tests and maelstrom, so do it by wrapping the timely operators to get
/// additional coverage. For the same reason, hardcode the K, V, T, D types.
///
/// [Subscribe]: mz_persist_client::read::Subscribe
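///
/// A hedged usage sketch (the names and timestamps are illustrative):
///
/// ```ignore
/// let mut sub = DataSubscribe::new("example", client, txns_id, data_id, 5, Antichain::new());
/// while sub.progress() <= 7 {
///     sub.step(); // step the worker and capture any new output
/// }
/// let rows: &Vec<(String, u64, i64)> = sub.output();
/// ```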
pub struct DataSubscribe {
    pub(crate) as_of: u64,
    pub(crate) worker: Worker<timely::communication::allocator::Thread>,
    data: ProbeHandle<u64>,
    txns: ProbeHandle<u64>,
    capture: mpsc::Receiver<Event<u64, Vec<(String, u64, i64)>>>,
    output: Vec<(String, u64, i64)>,

    _tokens: Vec<PressOnDropButton>,
}

impl std::fmt::Debug for DataSubscribe {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let DataSubscribe {
            as_of,
            worker: _,
            data,
            txns,
            capture: _,
            output,
            _tokens: _,
        } = self;
        f.debug_struct("DataSubscribe")
            .field("as_of", as_of)
            .field("data", data)
            .field("txns", txns)
            .field("output", output)
            .finish_non_exhaustive()
    }
}

impl DataSubscribe {
    /// Creates a new [DataSubscribe].
    pub fn new(
        name: &str,
        client: PersistClient,
        txns_id: ShardId,
        data_id: ShardId,
        as_of: u64,
        until: Antichain<u64>,
    ) -> Self {
        let mut worker = Worker::new(
            WorkerConfig::default(),
            timely::communication::allocator::Thread::default(),
            Some(std::time::Instant::now()),
        );
        let (data, txns, capture, tokens) = worker.dataflow::<u64, _, _>(|scope| {
            let (data_stream, shard_source_token) = scope.scoped::<u64, _, _>("hybrid", |scope| {
                let client = client.clone();
                let (data_stream, token) = shard_source::<String, (), u64, i64, _, _, _>(
                    scope,
                    name,
                    move || std::future::ready(client.clone()),
                    data_id,
                    Some(Antichain::from_elem(as_of)),
                    SnapshotMode::Include,
                    until.clone(),
                    false.then_some(|_, _: &_, _| unreachable!()),
                    Arc::new(StringSchema),
                    Arc::new(UnitSchema),
                    FilterResult::keep_all,
                    false.then_some(|| unreachable!()),
                    async {},
                    ErrorHandler::Halt("data_subscribe"),
                );
                (data_stream.leave(), token)
            });
            let (data, txns) = (ProbeHandle::new(), ProbeHandle::new());
            let data_stream = data_stream.flat_map(|part| {
                let part = part.parse();
                part.part.map(|((k, v), t, d)| {
                    let (k, ()) = (k.unwrap(), v.unwrap());
                    (k, t, d)
                })
            });
            let data_stream = data_stream.probe_with(&data);
            let (data_stream, mut txns_progress_token) =
                txns_progress::<String, (), u64, i64, _, TxnsCodecDefault, _, _>(
                    data_stream,
                    name,
                    &TxnsContext::default(),
                    || std::future::ready(client.clone()),
                    txns_id,
                    data_id,
                    as_of,
                    until,
                    Arc::new(StringSchema),
                    Arc::new(UnitSchema),
                );
            let data_stream = data_stream.probe_with(&txns);
            let mut tokens = shard_source_token;
            tokens.append(&mut txns_progress_token);
            (data, txns, data_stream.capture(), tokens)
        });
        Self {
            as_of,
            worker,
            data,
            txns,
            capture,
            output: Vec::new(),
            _tokens: tokens,
        }
    }

    /// Returns the exclusive progress of the dataflow.
    pub fn progress(&self) -> u64 {
        self.txns
            .with_frontier(|f| *f.as_option().unwrap_or(&u64::MAX))
    }

    /// Steps the dataflow, capturing output.
    pub fn step(&mut self) {
        self.worker.step();
        self.capture_output()
    }

    pub(crate) fn capture_output(&mut self) {
        loop {
            let event = match self.capture.try_recv() {
                Ok(x) => x,
                Err(TryRecvError::Empty) | Err(TryRecvError::Disconnected) => break,
            };
            match event {
                Event::Progress(_) => {}
                Event::Messages(_, mut msgs) => self.output.append(&mut msgs),
            }
        }
    }

    /// Steps the dataflow past the given time, capturing output.
    #[cfg(test)]
    pub async fn step_past(&mut self, ts: u64) {
        while self.txns.less_equal(&ts) {
            tracing::trace!(
                "progress at {:?}",
                self.txns.with_frontier(|x| x.to_owned()).elements()
            );
            self.step();
            tokio::task::yield_now().await;
        }
    }

    /// Returns captured output.
    pub fn output(&self) -> &Vec<(String, u64, i64)> {
        &self.output
    }
}

/// A handle to a [DataSubscribe] running in a task.
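///
/// A hedged usage sketch (the timestamps are illustrative):
///
/// ```ignore
/// let mut task = DataSubscribeTask::new(client, txns_id, data_id, 5).await;
/// let progress = task.step_past(7).await; // steps until progress exceeds 7
/// let output = task.finish().await; // all output from the task's lifetime
/// ```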
#[derive(Debug)]
pub struct DataSubscribeTask {
    /// Carries step requests. A `None` timestamp requests one step, a
    /// `Some(ts)` requests stepping until we progress beyond `ts`.
    tx: std::sync::mpsc::Sender<(
        Option<u64>,
        tokio::sync::oneshot::Sender<(Vec<(String, u64, i64)>, u64)>,
    )>,
    task: mz_ore::task::JoinHandle<Vec<(String, u64, i64)>>,
    output: Vec<(String, u64, i64)>,
    progress: u64,
}

impl DataSubscribeTask {
    /// Creates a new [DataSubscribeTask].
    pub async fn new(
        client: PersistClient,
        txns_id: ShardId,
        data_id: ShardId,
        as_of: u64,
    ) -> Self {
        let cache = TxnsCache::open(&client, txns_id, Some(data_id)).await;
        let (tx, rx) = std::sync::mpsc::channel();
        let task = mz_ore::task::spawn_blocking(
            || "data_subscribe task",
            move || Self::task(client, cache, data_id, as_of, rx),
        );
        DataSubscribeTask {
            tx,
            task,
            output: Vec::new(),
            progress: 0,
        }
    }

    #[cfg(test)]
    async fn step(&mut self) {
        self.send(None).await;
    }

    /// Steps the dataflow past the given time, capturing output.
    pub async fn step_past(&mut self, ts: u64) -> u64 {
        self.send(Some(ts)).await;
        self.progress
    }

    /// Returns captured output.
    pub fn output(&self) -> &Vec<(String, u64, i64)> {
        &self.output
    }

    async fn send(&mut self, ts: Option<u64>) {
        let (tx, rx) = tokio::sync::oneshot::channel();
        self.tx.send((ts, tx)).expect("task should be running");
        let (mut new_output, new_progress) = rx.await.expect("task should be running");
        self.output.append(&mut new_output);
        assert!(self.progress <= new_progress);
        self.progress = new_progress;
    }

    /// Signals for the task to exit, and then waits for this to happen.
    ///
    /// _All_ output from the lifetime of the task (not just what was previously
    /// captured) is returned.
    pub async fn finish(self) -> Vec<(String, u64, i64)> {
        // Closing the channel signals the task to exit.
        drop(self.tx);
        self.task.await
    }

    fn task(
        client: PersistClient,
        cache: TxnsCache<u64>,
        data_id: ShardId,
        as_of: u64,
        rx: std::sync::mpsc::Receiver<(
            Option<u64>,
            tokio::sync::oneshot::Sender<(Vec<(String, u64, i64)>, u64)>,
        )>,
    ) -> Vec<(String, u64, i64)> {
        let mut subscribe = DataSubscribe::new(
            "DataSubscribeTask",
            client.clone(),
            cache.txns_id(),
            data_id,
            as_of,
            Antichain::new(),
        );
        let mut output = Vec::new();
        loop {
            let (ts, tx) = match rx.try_recv() {
                Ok(x) => x,
                Err(TryRecvError::Empty) => {
                    // No requests, continue stepping so nothing deadlocks.
                    subscribe.step();
                    continue;
                }
                Err(TryRecvError::Disconnected) => {
                    // All done! Return our output.
                    return output;
                }
            };
            // Always step at least once.
            subscribe.step();
            // If we got a ts, make sure to step past it.
            if let Some(ts) = ts {
                while subscribe.progress() <= ts {
                    subscribe.step();
                }
            }
            let new_output = std::mem::take(&mut subscribe.output);
            output.extend(new_output.iter().cloned());
            let _ = tx.send((new_output, subscribe.progress()));
        }
    }
}

#[cfg(test)]
mod tests {
    use itertools::{Either, Itertools};
    use mz_persist_types::Opaque;

    use crate::tests::writer;
    use crate::txns::TxnsHandle;

    use super::*;

    impl<K, V, T, D, O, C> TxnsHandle<K, V, T, D, O, C>
    where
        K: Debug + Codec,
        V: Debug + Codec,
        T: Timestamp + Lattice + TotalOrder + StepForward + Codec64 + Sync,
        D: Debug + Monoid + Ord + Codec64 + Send + Sync,
        O: Opaque + Debug + Codec64,
        C: TxnsCodec,
    {
        async fn subscribe_task(
            &self,
            client: &PersistClient,
            data_id: ShardId,
            as_of: u64,
        ) -> DataSubscribeTask {
            DataSubscribeTask::new(client.clone(), self.txns_id(), data_id, as_of).await
        }
    }

    #[mz_ore::test(tokio::test(flavor = "multi_thread"))]
    #[cfg_attr(miri, ignore)] // too slow
    async fn data_subscribe() {
        async fn step(subs: &mut Vec<DataSubscribeTask>) {
            for sub in subs.iter_mut() {
                sub.step().await;
            }
        }

        let client = PersistClient::new_for_tests().await;
        let mut txns = TxnsHandle::expect_open(client.clone()).await;
        let log = txns.new_log();
        let d0 = ShardId::new();

        // Start a subscription before the shard gets registered.
        let mut subs = Vec::new();
        subs.push(txns.subscribe_task(&client, d0, 5).await);
        step(&mut subs).await;

        // Now register the shard. Also start a new subscription and step the
        // previous ones (plus repeat this for every later step).
        txns.register(1, [writer(&client, d0).await]).await.unwrap();
        subs.push(txns.subscribe_task(&client, d0, 5).await);
        step(&mut subs).await;

        // Now write something unrelated.
        let d1 = txns.expect_register(2).await;
        txns.expect_commit_at(3, d1, &["nope"], &log).await;
        subs.push(txns.subscribe_task(&client, d0, 5).await);
        step(&mut subs).await;

        // Now write to our shard before the as_of.
        txns.expect_commit_at(4, d0, &["4"], &log).await;
        subs.push(txns.subscribe_task(&client, d0, 5).await);
        step(&mut subs).await;

        // Now write to our shard at the as_of.
        txns.expect_commit_at(5, d0, &["5"], &log).await;
        subs.push(txns.subscribe_task(&client, d0, 5).await);
        step(&mut subs).await;

        // Now write to our shard past the as_of.
        txns.expect_commit_at(6, d0, &["6"], &log).await;
        subs.push(txns.subscribe_task(&client, d0, 5).await);
        step(&mut subs).await;

        // Now write something unrelated again.
        txns.expect_commit_at(7, d1, &["nope"], &log).await;
        subs.push(txns.subscribe_task(&client, d0, 5).await);
        step(&mut subs).await;

        // Verify that the dataflows can progress to the expected point and that
        // we read the right thing no matter when the dataflow started.
        for mut sub in subs {
            let progress = sub.step_past(7).await;
            assert_eq!(progress, 8);
            log.assert_eq(d0, 5, 8, sub.finish().await);
        }
    }

    #[mz_ore::test(tokio::test(flavor = "multi_thread"))]
    #[cfg_attr(miri, ignore)] // too slow
    async fn subscribe_shard_finalize() {
        let client = PersistClient::new_for_tests().await;
        let mut txns = TxnsHandle::expect_open(client.clone()).await;
        let log = txns.new_log();
        let d0 = txns.expect_register(1).await;

        // Start the operator as_of the register ts.
        let mut sub = txns.read_cache().expect_subscribe(&client, d0, 1);
        sub.step_past(1).await;

        // Write to it via txns.
        txns.expect_commit_at(2, d0, &["foo"], &log).await;
        sub.step_past(2).await;

        // Unregister it.
        txns.forget(3, [d0]).await.unwrap();
        sub.step_past(3).await;

        // TODO: Hard mode, see if we can get the rest of this test to work even
        // _without_ the txns shard advancing.
        txns.begin().commit_at(&mut txns, 7).await.unwrap();

        // The operator should continue to emit data written directly even
        // though it's no longer in the txns set.
        let mut d0_write = writer(&client, d0).await;
        let key = "bar".to_owned();
        crate::small_caa(|| "test", &mut d0_write, &[((&key, &()), &5, 1)], 4, 6)
            .await
            .unwrap();
        log.record((d0, key, 5, 1));
        sub.step_past(4).await;

        // Now finalize the shard, closing it to further writes.
        let () = d0_write
            .compare_and_append_batch(&mut [], Antichain::from_elem(6), Antichain::new(), true)
            .await
            .unwrap()
            .unwrap();
        while sub.txns.less_than(&u64::MAX) {
            sub.step();
            tokio::task::yield_now().await;
        }

        // Make sure we read the correct things.
        log.assert_eq(d0, 1, u64::MAX, sub.output().clone());

        // Also make sure that we can read the right things if we start up after
        // the forget but before the direct write and ditto after the direct
        // write.
        log.assert_subscribe(d0, 4, u64::MAX).await;
        log.assert_subscribe(d0, 6, u64::MAX).await;
    }

    #[mz_ore::test(tokio::test(flavor = "multi_thread"))]
    #[cfg_attr(miri, ignore)] // too slow
    async fn subscribe_shard_register_forget() {
        let client = PersistClient::new_for_tests().await;
        let mut txns = TxnsHandle::expect_open(client.clone()).await;
        let d0 = ShardId::new();

        // Start a subscription on the data shard.
        let mut sub = txns.read_cache().expect_subscribe(&client, d0, 0);
        assert_eq!(sub.progress(), 0);

        // Register the shard at 10.
        txns.register(10, [writer(&client, d0).await])
            .await
            .unwrap();
        sub.step_past(10).await;
        assert!(
            sub.progress() > 10,
            "operator should advance past 10 when shard is registered"
        );

        // Forget the shard at 20.
        txns.forget(20, [d0]).await.unwrap();
        sub.step_past(20).await;
        assert!(
            sub.progress() > 20,
            "operator should advance past 20 when shard is forgotten"
        );
    }

    #[mz_ore::test(tokio::test)]
    #[cfg_attr(miri, ignore)] // too slow
    async fn as_of_until() {
        let client = PersistClient::new_for_tests().await;
        let mut txns = TxnsHandle::expect_open(client.clone()).await;
        let log = txns.new_log();

        let d0 = txns.expect_register(1).await;
        txns.expect_commit_at(2, d0, &["2"], &log).await;
        txns.expect_commit_at(3, d0, &["3"], &log).await;
        txns.expect_commit_at(4, d0, &["4"], &log).await;
        txns.expect_commit_at(5, d0, &["5"], &log).await;
        txns.expect_commit_at(6, d0, &["6"], &log).await;
        txns.expect_commit_at(7, d0, &["7"], &log).await;

        let until = 5;
        let mut sub = DataSubscribe::new(
            "as_of_until",
            client,
            txns.txns_id(),
            d0,
            3,
            Antichain::from_elem(until),
        );
        // Manually step the dataflow, instead of going through the
        // `DataSubscribe` helper, because we're interested in all captured
        // events.
        while sub.txns.less_equal(&5) {
            sub.worker.step();
            tokio::task::yield_now().await;
            tokio::time::sleep(std::time::Duration::from_millis(100)).await;
        }
        let (actual_progresses, actual_events): (Vec<_>, Vec<_>) =
            sub.capture.into_iter().partition_map(|event| match event {
                Event::Progress(progress) => Either::Left(progress),
                Event::Messages(ts, data) => Either::Right((ts, data)),
            });
        let expected = vec![
            (3, vec![("2".to_owned(), 3, 1), ("3".to_owned(), 3, 1)]),
            (3, vec![("4".to_owned(), 4, 1)]),
        ];
        assert_eq!(actual_events, expected);

        // The number and contents of progress messages are not guaranteed and
        // depend on the downgrade behavior. The only thing we can assert is
        // that the max progress timestamp, if there is one, is less than the
        // until.
        if let Some(max_progress_ts) = actual_progresses
            .into_iter()
            .flatten()
            .map(|(ts, _diff)| ts)
            .max()
        {
            assert!(max_progress_ts < until, "{max_progress_ts} < {until}");
        }
    }
}