Skip to main content

mz_txn_wal/
operator.rs

1// Copyright Materialize, Inc. and contributors. All rights reserved.
2//
3// Use of this software is governed by the Business Source License
4// included in the LICENSE file.
5//
6// As of the Change Date specified in that file, in accordance with
7// the Business Source License, use of this software will be governed
8// by the Apache License, Version 2.0.
9
10//! Timely operators for the crate
11
12use std::any::Any;
13use std::fmt::Debug;
14use std::future::Future;
15use std::sync::mpsc::TryRecvError;
16use std::sync::{Arc, mpsc};
17use std::time::Duration;
18
19use differential_dataflow::Hashable;
20use differential_dataflow::difference::Monoid;
21use differential_dataflow::lattice::Lattice;
22use futures::StreamExt;
23use mz_dyncfg::{Config, ConfigSet};
24use mz_ore::cast::CastFrom;
25use mz_persist_client::cfg::RetryParameters;
26use mz_persist_client::operators::shard_source::{
27    ErrorHandler, FilterResult, SnapshotMode, shard_source,
28};
29use mz_persist_client::{Diagnostics, PersistClient, ShardId};
30use mz_persist_types::codec_impls::{StringSchema, UnitSchema};
31use mz_persist_types::txn::TxnsCodec;
32use mz_persist_types::{Codec, Codec64, StepForward};
33use mz_timely_util::builder_async::{
34    AsyncInputHandle, Event as AsyncEvent, InputConnection,
35    OperatorBuilder as AsyncOperatorBuilder, PressOnDropButton,
36};
37use timely::container::CapacityContainerBuilder;
38use timely::dataflow::channels::pact::Pipeline;
39use timely::dataflow::operators::capture::Event;
40use timely::dataflow::operators::vec::{Broadcast, Map};
41use timely::dataflow::operators::{Capture, Leave, Probe};
42use timely::dataflow::{ProbeHandle, Scope, StreamVec};
43use timely::order::TotalOrder;
44use timely::progress::{Antichain, Timestamp};
45use timely::worker::Worker;
46use timely::{PartialOrder, WorkerConfig};
47use tracing::debug;
48
49use crate::TxnsCodecDefault;
50use crate::txn_cache::TxnsCache;
51use crate::txn_read::{DataRemapEntry, TxnsRead};
52
53/// An operator for translating physical data shard frontiers into logical ones.
54///
55/// A data shard in the txns set logically advances its upper each time a txn is
56/// committed, but the upper is not physically advanced unless that data shard
57/// was involved in the txn. This means that a shard_source (or any read)
58/// pointed at a data shard would appear to stall at the time of the most recent
59/// write. We fix this for shard_source by flowing its output through a new
60/// `txns_progress` dataflow operator, which ensures that the
61/// frontier/capability is advanced as the txns shard progresses, as long as the
62/// shard_source is up to date with the latest committed write to that data
63/// shard.
64///
65/// Example:
66///
67/// - A data shard has most recently been written to at 3.
68/// - The txns shard's upper is at 6.
69/// - We render a dataflow containing a shard_source with an as_of of 5.
70/// - A txn NOT involving the data shard is committed at 7.
71/// - A txn involving the data shard is committed at 9.
72///
73/// How it works:
74///
75/// - The shard_source operator is rendered. Its single output is hooked up as a
76///   _disconnected_ input to txns_progress. The txns_progress single output is
77///   a stream of the same type, which is used by downstream operators. This
78///   txns_progress operator is targeted at one data_shard; rendering a
79///   shard_source for a second data shard requires a second txns_progress
80///   operator.
81/// - The shard_source operator emits data through 3 and advances the frontier.
82/// - The txns_progress operator passes through these writes and frontier
83///   advancements unchanged. (Recall that it's always correct to read a data
84///   shard "normally", it just might stall.) Because the txns_progress operator
85///   knows there are no writes in `[3,5]`, it then downgrades its own
86///   capability past 5 (to 6). Because the input is disconnected, this means
87///   the overall frontier of the output is downgraded to 6.
88/// - The txns_progress operator learns about the write at 7 (the upper is now
89///   8). Because it knows that the data shard was not involved in this, it's
90///   free to downgrade its capability to 8.
91/// - The txns_progress operator learns about the write at 9 (the upper is now
92///   10). It knows that the data shard _WAS_ involved in this, so it forwards
93///   on data from its input until the input has progressed to 10, at which
94///   point it can itself downgrade to 10.
95pub fn txns_progress<'scope, K, V, T, D, P, C, F>(
96    passthrough: StreamVec<'scope, T, P>,
97    name: &str,
98    ctx: &TxnsContext,
99    client_fn: impl Fn() -> F,
100    txns_id: ShardId,
101    data_id: ShardId,
102    as_of: T,
103    until: Antichain<T>,
104    data_key_schema: Arc<K::Schema>,
105    data_val_schema: Arc<V::Schema>,
106) -> (StreamVec<'scope, T, P>, Vec<PressOnDropButton>)
107where
108    K: Debug + Codec + Send + Sync,
109    V: Debug + Codec + Send + Sync,
110    T: Timestamp + Lattice + TotalOrder + StepForward + Codec64 + Sync,
111    D: Debug + Clone + 'static + Monoid + Ord + Codec64 + Send + Sync,
112    P: Debug + Clone + 'static,
113    C: TxnsCodec + 'static,
114    F: Future<Output = PersistClient> + Send + 'static,
115{
116    let unique_id = (name, passthrough.scope().addr()).hashed();
117    let (remap, source_button) = txns_progress_source_global::<K, V, T, D, P, C>(
118        passthrough.scope(),
119        name,
120        ctx.clone(),
121        client_fn(),
122        txns_id,
123        data_id,
124        as_of,
125        data_key_schema,
126        data_val_schema,
127        unique_id,
128    );
129    // Each of the `txns_frontiers` workers wants the full copy of the remap
130    // information.
131    let remap = remap.broadcast();
132    let (passthrough, frontiers_button) = txns_progress_frontiers::<K, V, T, D, P, C>(
133        remap,
134        passthrough,
135        name,
136        data_id,
137        until,
138        unique_id,
139    );
140    (passthrough, vec![source_button, frontiers_button])
141}
142
143/// TODO: I'd much prefer the communication protocol between the two operators
144/// to be exactly remap as defined in the [reclocking design doc]. However, we
145/// can't quite recover exactly the information necessary to construct that at
146/// the moment. Seems worth doing, but in the meantime, intentionally make this
147/// look fairly different (`Stream` of `DataRemapEntry` instead of
148/// `Collection<FromTime>`) to hopefully minimize confusion. As a performance
149/// optimization, we only re-emit this when the _physical_ upper has changed,
150/// which means that the frontier of the `Stream<DataRemapEntry<T>>` indicates
151/// updates to the logical_upper of the most recent `DataRemapEntry` (i.e. the
152/// one with the largest physical_upper).
153///
154/// [reclocking design doc]:
155///     https://github.com/MaterializeInc/materialize/blob/main/doc/developer/design/20210714_reclocking.md
156fn txns_progress_source_global<'scope, K, V, T, D, P, C>(
157    scope: Scope<'scope, T>,
158    name: &str,
159    ctx: TxnsContext,
160    client: impl Future<Output = PersistClient> + 'static,
161    txns_id: ShardId,
162    data_id: ShardId,
163    as_of: T,
164    data_key_schema: Arc<K::Schema>,
165    data_val_schema: Arc<V::Schema>,
166    unique_id: u64,
167) -> (StreamVec<'scope, T, DataRemapEntry<T>>, PressOnDropButton)
168where
169    K: Debug + Codec + Send + Sync,
170    V: Debug + Codec + Send + Sync,
171    T: Timestamp + Lattice + TotalOrder + StepForward + Codec64 + Sync,
172    D: Debug + Clone + 'static + Monoid + Ord + Codec64 + Send + Sync,
173    P: Debug + Clone + 'static,
174    C: TxnsCodec + 'static,
175{
176    let worker_idx = scope.index();
177    let chosen_worker = usize::cast_from(name.hashed()) % scope.peers();
178    let name = format!("txns_progress_source({})", name);
179    let mut builder = AsyncOperatorBuilder::new(name.clone(), scope);
180    let name = format!("{} [{}] {:.9}", name, unique_id, data_id.to_string());
181    let (remap_output, remap_stream) = builder.new_output::<CapacityContainerBuilder<Vec<_>>>();
182
183    let shutdown_button = builder.build(move |capabilities| async move {
184        if worker_idx != chosen_worker {
185            return;
186        }
187
188        let [mut cap]: [_; 1] = capabilities.try_into().expect("one capability per output");
189        let client = client.await;
190        let txns_read = ctx.get_or_init::<T, C>(&client, txns_id).await;
191
192        let _ = txns_read.update_gt(as_of.clone()).await;
193        let data_write = client
194            .open_writer::<K, V, T, D>(
195                data_id,
196                Arc::clone(&data_key_schema),
197                Arc::clone(&data_val_schema),
198                Diagnostics::from_purpose("data read physical upper"),
199            )
200            .await
201            .expect("schema shouldn't change");
202        let mut rx = txns_read
203            .data_subscribe(data_id, as_of.clone(), data_write)
204            .await;
205        debug!("{} starting as_of={:?}", name, as_of);
206
207        let mut physical_upper = T::minimum();
208
209        while let Some(remap) = rx.recv().await {
210            assert!(physical_upper <= remap.physical_upper);
211            assert!(physical_upper < remap.logical_upper);
212
213            let logical_upper = remap.logical_upper.clone();
214            // As mentioned in the docs on this function, we only
215            // emit updates when the physical upper changes (which
216            // happens to makes the protocol a tiny bit more
217            // remap-like).
218            if remap.physical_upper != physical_upper {
219                physical_upper = remap.physical_upper.clone();
220                debug!("{} emitting {:?}", name, remap);
221                remap_output.give(&cap, remap);
222            } else {
223                debug!("{} not emitting {:?}", name, remap);
224            }
225            cap.downgrade(&logical_upper);
226        }
227    });
228    (remap_stream, shutdown_button.press_on_drop())
229}
230
/// Renders the operator that forwards `passthrough` while translating its
/// frontier using the entries arriving on `remap`.
///
/// Both inputs are registered as disconnected, so the output frontier is
/// driven by the single output capability held in the operator body. The body
/// alternates between two activities:
///
/// - While the capability is at or beyond the current entry's
///   `physical_upper`, it downgrades the capability to that entry's
///   `logical_upper` (if that is an advancement) and waits for the next remap
///   update.
/// - Otherwise it copies data and frontier progress from `passthrough` to the
///   output.
///
/// The operator body returns once the passthrough progress has passed `until`
/// or has reached the empty frontier.
///
/// `name`, `unique_id` and `data_id` are only used to build the prefix of the
/// debug log lines. Returns the output stream and the operator's shutdown
/// button.
fn txns_progress_frontiers<'scope, K, V, T, D, P, C>(
    remap: StreamVec<'scope, T, DataRemapEntry<T>>,
    passthrough: StreamVec<'scope, T, P>,
    name: &str,
    data_id: ShardId,
    until: Antichain<T>,
    unique_id: u64,
) -> (StreamVec<'scope, T, P>, PressOnDropButton)
where
    K: Debug + Codec,
    V: Debug + Codec,
    T: Timestamp + Lattice + TotalOrder + StepForward + Codec64,
    D: Clone + 'static + Monoid + Codec64 + Send + Sync,
    P: Debug + Clone + 'static,
    C: TxnsCodec,
{
    let name = format!("txns_progress_frontiers({})", name);
    let mut builder = AsyncOperatorBuilder::new(name.clone(), passthrough.scope());
    // From here on `name` is the log prefix: operator name, unique id, worker
    // index/peers, and a shortened shard id.
    let name = format!(
        "{} [{}] {}/{} {:.9}",
        name,
        unique_id,
        passthrough.scope().index(),
        passthrough.scope().peers(),
        data_id.to_string(),
    );
    let (passthrough_output, passthrough_stream) =
        builder.new_output::<CapacityContainerBuilder<Vec<_>>>();
    let mut remap_input = builder.new_disconnected_input(remap, Pipeline);
    let mut passthrough_input = builder.new_disconnected_input(passthrough, Pipeline);

    let shutdown_button = builder.build(move |capabilities| async move {
        let [mut cap]: [_; 1] = capabilities.try_into().expect("one capability per output");

        // None is used to indicate that both uppers are the empty antichain.
        let mut remap = Some(DataRemapEntry {
            physical_upper: T::minimum(),
            logical_upper: T::minimum(),
        });
        // NB: The following loop uses `cap.time()` to track how far we've
        // progressed in copying along the passthrough input.
        loop {
            debug!("{} remap {:?}", name, remap);
            if let Some(r) = remap.as_ref() {
                assert!(r.physical_upper <= r.logical_upper);
                // If we've passed through data to at least `physical_upper`,
                // then it means we can artificially advance the upper of the
                // output to `logical_upper`. This also indicates that we need
                // to wait for the next DataRemapEntry. It can either (A) have
                // the same physical upper or (B) have a larger physical upper.
                //
                // - If (A), then we would again satisfy this `physical_upper`
                //   check, again advance the logical upper again, ...
                // - If (B), then we'd fall down to the code below, which copies
                //   the passthrough data until the frontier passes
                //   `physical_upper`, then loops back up here.
                if r.physical_upper.less_equal(cap.time()) {
                    if cap.time() < &r.logical_upper {
                        cap.downgrade(&r.logical_upper);
                    }
                    // Yields `None` once the remap input is closed; after
                    // that the loop only copies the passthrough input.
                    remap = txns_progress_frontiers_read_remap_input(
                        &name,
                        &mut remap_input,
                        r.clone(),
                    )
                    .await;
                    continue;
                }
            }

            // This only returns None when there are no more data left. Turn it
            // into an empty frontier progress so we can re-use the shutdown
            // code below.
            let event = passthrough_input
                .next()
                .await
                .unwrap_or_else(|| AsyncEvent::Progress(Antichain::new()));
            match event {
                // NB: Ignore the data_cap because this input is disconnected.
                AsyncEvent::Data(_data_cap, mut data) => {
                    // NB: Nothing to do here for `until` because both the
                    // `shard_source` (before this operator) and
                    // `mfp_and_decode` (after this operator) do the necessary
                    // filtering.
                    debug!("{} emitting data {:?}", name, data);
                    passthrough_output.give_container(&cap, &mut data);
                }
                AsyncEvent::Progress(new_progress) => {
                    // If `until.less_equal(new_progress)`, it means that all
                    // subsequent batches will contain only times greater or
                    // equal to `until`, which means they can be dropped in
                    // their entirety.
                    //
                    // Ideally this check would live in `txns_progress_source`,
                    // but that turns out to be much more invasive (requires
                    // replacing lots of `T`s with `Antichain<T>`s). Given that
                    // we've been thinking about reworking the operators, do the
                    // easy but more wasteful thing for now.
                    if PartialOrder::less_equal(&until, &new_progress) {
                        debug!(
                            "{} progress {:?} has passed until {:?}",
                            name,
                            new_progress.elements(),
                            until.elements()
                        );
                        return;
                    }
                    // We reached the empty frontier! Shut down.
                    let Some(new_progress) = new_progress.into_option() else {
                        return;
                    };

                    // Recall that any reads of the data shard are always
                    // correct, so given that we've passed through any data
                    // from the input, that means we're free to pass through
                    // frontier updates too.
                    if cap.time() < &new_progress {
                        debug!("{} downgrading cap to {:?}", name, new_progress);
                        cap.downgrade(&new_progress);
                    }
                }
            }
        }
    });
    (passthrough_stream, shutdown_button.press_on_drop())
}
357
358async fn txns_progress_frontiers_read_remap_input<T, C>(
359    name: &str,
360    input: &mut AsyncInputHandle<T, Vec<DataRemapEntry<T>>, C>,
361    mut remap: DataRemapEntry<T>,
362) -> Option<DataRemapEntry<T>>
363where
364    T: Timestamp + TotalOrder,
365    C: InputConnection<T>,
366{
367    while let Some(event) = input.next().await {
368        let xs = match event {
369            AsyncEvent::Progress(logical_upper) => {
370                if let Some(logical_upper) = logical_upper.into_option() {
371                    if remap.logical_upper < logical_upper {
372                        remap.logical_upper = logical_upper;
373                        return Some(remap);
374                    }
375                }
376                continue;
377            }
378            AsyncEvent::Data(_cap, xs) => xs,
379        };
380        for x in xs {
381            debug!("{} got remap {:?}", name, x);
382            // Don't assume anything about the ordering.
383            if remap.logical_upper < x.logical_upper {
384                assert!(
385                    remap.physical_upper <= x.physical_upper,
386                    "previous remap physical upper {:?} is ahead of new remap physical upper {:?}",
387                    remap.physical_upper,
388                    x.physical_upper,
389                );
390                // TODO: If the physical upper has advanced, that's a very
391                // strong hint that the data shard is about to be written to.
392                // Because the data shard's upper advances sparsely (on write,
393                // but not on passage of time) which invalidates the "every 1s"
394                // assumption of the default tuning, we've had to de-tune the
395                // listen sleeps on the paired persist_source. Maybe we use "one
396                // state" to wake it up in case pubsub doesn't and remove the
397                // listen polling entirely? (NB: This would have to happen in
398                // each worker so that it's guaranteed to happen in each
399                // process.)
400                remap = x;
401            }
402        }
403        return Some(remap);
404    }
405    // remap_input is closed, which indicates the data shard is finished.
406    None
407}
408
409/// The process global [`TxnsRead`] that any operator can communicate with.
410#[derive(Default, Debug, Clone)]
411pub struct TxnsContext {
412    read: Arc<tokio::sync::OnceCell<Box<dyn Any + Send + Sync>>>,
413}
414
415impl TxnsContext {
416    async fn get_or_init<T, C>(&self, client: &PersistClient, txns_id: ShardId) -> TxnsRead<T>
417    where
418        T: Timestamp + Lattice + Codec64 + TotalOrder + StepForward + Sync,
419        C: TxnsCodec + 'static,
420    {
421        let read = self
422            .read
423            .get_or_init(|| {
424                let client = client.clone();
425                async move {
426                    let read: Box<dyn Any + Send + Sync> =
427                        Box::new(TxnsRead::<T>::start::<C>(client, txns_id).await);
428                    read
429                }
430            })
431            .await
432            .downcast_ref::<TxnsRead<T>>()
433            .expect("timestamp types should match");
434        // We initially only have one txns shard in the system.
435        assert_eq!(&txns_id, read.txns_id());
436        read.clone()
437    }
438}
439
440// Existing configs use the prefix "persist_txns_" for historical reasons. New
441// configs should use the prefix "txn_wal_".
442
/// Initial backoff used when polling a txns data shard persist_source for new
/// batches. Defaults to 1024ms.
pub(crate) const DATA_SHARD_RETRYER_INITIAL_BACKOFF: Config<Duration> = Config::new(
    "persist_txns_data_shard_retryer_initial_backoff",
    Duration::from_millis(1024),
    "The initial backoff when polling for new batches from a txns data shard persist_source.",
);

/// Backoff multiplier used when polling a txns data shard persist_source for
/// new batches. Defaults to 2.
pub(crate) const DATA_SHARD_RETRYER_MULTIPLIER: Config<u32> = Config::new(
    "persist_txns_data_shard_retryer_multiplier",
    2,
    "The backoff multiplier when polling for new batches from a txns data shard persist_source.",
);

/// Backoff clamp used when polling a txns data shard persist_source for new
/// batches. Defaults to 16s.
pub(crate) const DATA_SHARD_RETRYER_CLAMP: Config<Duration> = Config::new(
    "persist_txns_data_shard_retryer_clamp",
    Duration::from_secs(16),
    "The backoff clamp duration when polling for new batches from a txns data shard persist_source.",
);
460
461/// Retry configuration for txn-wal data shard override of
462/// `next_listen_batch`.
463pub fn txns_data_shard_retry_params(cfg: &ConfigSet) -> RetryParameters {
464    RetryParameters {
465        fixed_sleep: Duration::ZERO,
466        initial_backoff: DATA_SHARD_RETRYER_INITIAL_BACKOFF.get(cfg),
467        multiplier: DATA_SHARD_RETRYER_MULTIPLIER.get(cfg),
468        clamp: DATA_SHARD_RETRYER_CLAMP.get(cfg),
469    }
470}
471
/// A helper for subscribing to a data shard using the timely operators.
///
/// This could instead be a wrapper around a [Subscribe], but it's only used in
/// tests and maelstrom, so do it by wrapping the timely operators to get
/// additional coverage. For the same reason, hardcode the K, V, T, D types.
///
/// [Subscribe]: mz_persist_client::read::Subscribe
pub struct DataSubscribe {
    /// The as_of this subscription was created with.
    pub(crate) as_of: u64,
    /// The timely worker that hosts and steps the dataflow.
    pub(crate) worker: Worker,
    /// Probe on the decoded data stream before it enters `txns_progress`.
    data: ProbeHandle<u64>,
    /// Probe on the stream leaving `txns_progress`; this is the frontier
    /// reported by `progress`.
    txns: ProbeHandle<u64>,
    /// Receiving end of the capture of the dataflow's final stream.
    capture: mpsc::Receiver<Event<u64, Vec<(String, u64, i64)>>>,
    /// Updates drained from `capture` so far.
    output: Vec<(String, u64, i64)>,

    /// Shutdown buttons of the rendered operators, held for as long as the
    /// subscription exists.
    _tokens: Vec<PressOnDropButton>,
}
489
490impl std::fmt::Debug for DataSubscribe {
491    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
492        let DataSubscribe {
493            as_of,
494            worker: _,
495            data,
496            txns,
497            capture: _,
498            output,
499            _tokens: _,
500        } = self;
501        f.debug_struct("DataSubscribe")
502            .field("as_of", as_of)
503            .field("data", data)
504            .field("txns", txns)
505            .field("output", output)
506            .finish_non_exhaustive()
507    }
508}
509
510impl DataSubscribe {
511    /// Creates a new [DataSubscribe].
512    pub fn new(
513        name: &str,
514        client: PersistClient,
515        txns_id: ShardId,
516        data_id: ShardId,
517        as_of: u64,
518        until: Antichain<u64>,
519    ) -> Self {
520        let mut worker = Worker::new(
521            WorkerConfig::default(),
522            timely::communication::Allocator::Thread(
523                timely::communication::allocator::Thread::default(),
524            ),
525            Some(std::time::Instant::now()),
526        );
527        let (data, txns, capture, tokens) = worker.dataflow::<u64, _, _>(|outer| {
528            let (data_stream, shard_source_token) = outer.scoped::<u64, _, _>("hybrid", |scope| {
529                let client = client.clone();
530                let (data_stream, token) = shard_source::<String, (), u64, i64, _, _, _>(
531                    outer,
532                    scope,
533                    name,
534                    move || std::future::ready(client.clone()),
535                    data_id,
536                    Some(Antichain::from_elem(as_of)),
537                    SnapshotMode::Include,
538                    until.clone(),
539                    false.then_some(|_, _, _| unreachable!()),
540                    Arc::new(StringSchema),
541                    Arc::new(UnitSchema),
542                    FilterResult::keep_all,
543                    false.then_some(|| unreachable!()),
544                    async {},
545                    ErrorHandler::Halt("data_subscribe"),
546                );
547                (data_stream.leave(outer), token)
548            });
549            let (data, txns) = (ProbeHandle::new(), ProbeHandle::new());
550            let data_stream = data_stream.flat_map(|part| {
551                let part = part.parse();
552                part.part.map(|((k, ()), t, d)| (k, t, d))
553            });
554            let data_stream = data_stream.probe_with(&data);
555            let (data_stream, mut txns_progress_token) =
556                txns_progress::<String, (), u64, i64, _, TxnsCodecDefault, _>(
557                    data_stream,
558                    name,
559                    &TxnsContext::default(),
560                    || std::future::ready(client.clone()),
561                    txns_id,
562                    data_id,
563                    as_of,
564                    until,
565                    Arc::new(StringSchema),
566                    Arc::new(UnitSchema),
567                );
568            let data_stream = data_stream.probe_with(&txns);
569            let mut tokens = shard_source_token;
570            tokens.append(&mut txns_progress_token);
571            (data, txns, data_stream.capture(), tokens)
572        });
573        Self {
574            as_of,
575            worker,
576            data,
577            txns,
578            capture,
579            output: Vec::new(),
580            _tokens: tokens,
581        }
582    }
583
584    /// Returns the exclusive progress of the dataflow.
585    pub fn progress(&self) -> u64 {
586        self.txns
587            .with_frontier(|f| *f.as_option().unwrap_or(&u64::MAX))
588    }
589
590    /// Steps the dataflow, capturing output.
591    pub fn step(&mut self) {
592        self.worker.step();
593        self.capture_output()
594    }
595
596    pub(crate) fn capture_output(&mut self) {
597        loop {
598            let event = match self.capture.try_recv() {
599                Ok(x) => x,
600                Err(TryRecvError::Empty) | Err(TryRecvError::Disconnected) => break,
601            };
602            match event {
603                Event::Progress(_) => {}
604                Event::Messages(_, mut msgs) => self.output.append(&mut msgs),
605            }
606        }
607    }
608
609    /// Steps the dataflow past the given time, capturing output.
610    #[cfg(test)]
611    pub async fn step_past(&mut self, ts: u64) {
612        while self.txns.less_equal(&ts) {
613            tracing::trace!(
614                "progress at {:?}",
615                self.txns.with_frontier(|x| x.to_owned()).elements()
616            );
617            self.step();
618            tokio::task::yield_now().await;
619        }
620    }
621
622    /// Returns captured output.
623    pub fn output(&self) -> &Vec<(String, u64, i64)> {
624        &self.output
625    }
626}
627
/// A handle to a [DataSubscribe] running in a task.
#[derive(Debug)]
pub struct DataSubscribeTask {
    /// Carries step requests. A `None` timestamp requests one step, a
    /// `Some(ts)` requests stepping until we progress beyond `ts`.
    ///
    /// Each request comes with a oneshot sender on which the task replies
    /// with the newly captured output and its current progress.
    tx: std::sync::mpsc::Sender<(
        Option<u64>,
        tokio::sync::oneshot::Sender<(Vec<(String, u64, i64)>, u64)>,
    )>,
    /// Handle of the blocking task; it resolves to all output the task
    /// produced.
    task: mz_ore::task::JoinHandle<Vec<(String, u64, i64)>>,
    /// Output collected from the replies received so far.
    output: Vec<(String, u64, i64)>,
    /// Progress reported by the most recent reply (0 before the first one).
    progress: u64,
}
641
642impl DataSubscribeTask {
643    /// Creates a new [DataSubscribeTask].
644    pub async fn new(
645        client: PersistClient,
646        txns_id: ShardId,
647        data_id: ShardId,
648        as_of: u64,
649    ) -> Self {
650        let cache = TxnsCache::open(&client, txns_id, Some(data_id)).await;
651        let (tx, rx) = std::sync::mpsc::channel();
652        let task = mz_ore::task::spawn_blocking(
653            || "data_subscribe task",
654            move || Self::task(client, cache, data_id, as_of, rx),
655        );
656        DataSubscribeTask {
657            tx,
658            task,
659            output: Vec::new(),
660            progress: 0,
661        }
662    }
663
664    #[cfg(test)]
665    async fn step(&mut self) {
666        self.send(None).await;
667    }
668
669    /// Steps the dataflow past the given time, capturing output.
670    pub async fn step_past(&mut self, ts: u64) -> u64 {
671        self.send(Some(ts)).await;
672        self.progress
673    }
674
675    /// Returns captured output.
676    pub fn output(&self) -> &Vec<(String, u64, i64)> {
677        &self.output
678    }
679
680    async fn send(&mut self, ts: Option<u64>) {
681        let (tx, rx) = tokio::sync::oneshot::channel();
682        self.tx.send((ts, tx)).expect("task should be running");
683        let (mut new_output, new_progress) = rx.await.expect("task should be running");
684        self.output.append(&mut new_output);
685        assert!(self.progress <= new_progress);
686        self.progress = new_progress;
687    }
688
689    /// Signals for the task to exit, and then waits for this to happen.
690    ///
691    /// _All_ output from the lifetime of the task (not just what was previously
692    /// captured) is returned.
693    pub async fn finish(self) -> Vec<(String, u64, i64)> {
694        // Closing the channel signals the task to exit.
695        drop(self.tx);
696        self.task.await
697    }
698
699    fn task(
700        client: PersistClient,
701        cache: TxnsCache<u64>,
702        data_id: ShardId,
703        as_of: u64,
704        rx: std::sync::mpsc::Receiver<(
705            Option<u64>,
706            tokio::sync::oneshot::Sender<(Vec<(String, u64, i64)>, u64)>,
707        )>,
708    ) -> Vec<(String, u64, i64)> {
709        let mut subscribe = DataSubscribe::new(
710            "DataSubscribeTask",
711            client.clone(),
712            cache.txns_id(),
713            data_id,
714            as_of,
715            Antichain::new(),
716        );
717        let mut output = Vec::new();
718        loop {
719            let (ts, tx) = match rx.try_recv() {
720                Ok(x) => x,
721                Err(TryRecvError::Empty) => {
722                    // No requests, continue stepping so nothing deadlocks.
723                    subscribe.step();
724                    continue;
725                }
726                Err(TryRecvError::Disconnected) => {
727                    // All done! Return our output.
728                    return output;
729                }
730            };
731            // Always step at least once.
732            subscribe.step();
733            // If we got a ts, make sure to step past it.
734            if let Some(ts) = ts {
735                while subscribe.progress() <= ts {
736                    subscribe.step();
737                }
738            }
739            let new_output = std::mem::take(&mut subscribe.output);
740            output.extend(new_output.iter().cloned());
741            let _ = tx.send((new_output, subscribe.progress()));
742        }
743    }
744}
745
746#[cfg(test)]
747mod tests {
748    use itertools::{Either, Itertools};
749
750    use crate::tests::writer;
751    use crate::txns::TxnsHandle;
752
753    use super::*;
754
755    impl<K, V, T, D, C> TxnsHandle<K, V, T, D, C>
756    where
757        K: Debug + Codec,
758        V: Debug + Codec,
759        T: Timestamp + Lattice + TotalOrder + StepForward + Codec64 + Sync,
760        D: Debug + Monoid + Ord + Codec64 + Send + Sync,
761        C: TxnsCodec,
762    {
763        async fn subscribe_task(
764            &self,
765            client: &PersistClient,
766            data_id: ShardId,
767            as_of: u64,
768        ) -> DataSubscribeTask {
769            DataSubscribeTask::new(client.clone(), self.txns_id(), data_id, as_of).await
770        }
771    }
772
773    #[mz_ore::test(tokio::test(flavor = "multi_thread"))]
774    #[cfg_attr(miri, ignore)] // too slow
775    async fn data_subscribe() {
776        async fn step(subs: &mut Vec<DataSubscribeTask>) {
777            for sub in subs.iter_mut() {
778                sub.step().await;
779            }
780        }
781
782        let client = PersistClient::new_for_tests().await;
783        let mut txns = TxnsHandle::expect_open(client.clone()).await;
784        let log = txns.new_log();
785        let d0 = ShardId::new();
786
787        // Start a subscription before the shard gets registered.
788        let mut subs = Vec::new();
789        subs.push(txns.subscribe_task(&client, d0, 5).await);
790        step(&mut subs).await;
791
792        // Now register the shard. Also start a new subscription and step the
793        // previous one (plus repeat this for every later step).
794        txns.register(1, [writer(&client, d0).await]).await.unwrap();
795        subs.push(txns.subscribe_task(&client, d0, 5).await);
796        step(&mut subs).await;
797
798        // Now write something unrelated.
799        let d1 = txns.expect_register(2).await;
800        txns.expect_commit_at(3, d1, &["nope"], &log).await;
801        subs.push(txns.subscribe_task(&client, d0, 5).await);
802        step(&mut subs).await;
803
804        // Now write to our shard before.
805        txns.expect_commit_at(4, d0, &["4"], &log).await;
806        subs.push(txns.subscribe_task(&client, d0, 5).await);
807        step(&mut subs).await;
808
809        // Now write to our shard at the as_of.
810        txns.expect_commit_at(5, d0, &["5"], &log).await;
811        subs.push(txns.subscribe_task(&client, d0, 5).await);
812        step(&mut subs).await;
813
814        // Now write to our shard past the as_of.
815        txns.expect_commit_at(6, d0, &["6"], &log).await;
816        subs.push(txns.subscribe_task(&client, d0, 5).await);
817        step(&mut subs).await;
818
819        // Now write something unrelated again.
820        txns.expect_commit_at(7, d1, &["nope"], &log).await;
821        subs.push(txns.subscribe_task(&client, d0, 5).await);
822        step(&mut subs).await;
823
824        // Verify that the dataflows can progress to the expected point and that
825        // we read the right thing no matter when the dataflow started.
826        for mut sub in subs {
827            let progress = sub.step_past(7).await;
828            assert_eq!(progress, 8);
829            log.assert_eq(d0, 5, 8, sub.finish().await);
830        }
831    }
832
833    #[mz_ore::test(tokio::test(flavor = "multi_thread"))]
834    #[cfg_attr(miri, ignore)] // too slow
835    async fn subscribe_shard_finalize() {
836        let client = PersistClient::new_for_tests().await;
837        let mut txns = TxnsHandle::expect_open(client.clone()).await;
838        let log = txns.new_log();
839        let d0 = txns.expect_register(1).await;
840
841        // Start the operator as_of the register ts.
842        let mut sub = txns.read_cache().expect_subscribe(&client, d0, 1);
843        sub.step_past(1).await;
844
845        // Write to it via txns.
846        txns.expect_commit_at(2, d0, &["foo"], &log).await;
847        sub.step_past(2).await;
848
849        // Unregister it.
850        txns.forget(3, [d0]).await.unwrap();
851        sub.step_past(3).await;
852
853        // TODO: Hard mode, see if we can get the rest of this test to work even
854        // _without_ the txns shard advancing.
855        txns.begin().commit_at(&mut txns, 7).await.unwrap();
856
857        // The operator should continue to emit data written directly even
858        // though it's no longer in the txns set.
859        let mut d0_write = writer(&client, d0).await;
860        let key = "bar".to_owned();
861        crate::small_caa(|| "test", &mut d0_write, &[((&key, &()), &5, 1)], 4, 6)
862            .await
863            .unwrap();
864        log.record((d0, key, 5, 1));
865        sub.step_past(4).await;
866
867        // Now finalize the shard to writes.
868        let () = d0_write
869            .compare_and_append_batch(&mut [], Antichain::from_elem(6), Antichain::new(), true)
870            .await
871            .unwrap()
872            .unwrap();
873        while sub.txns.less_than(&u64::MAX) {
874            sub.step();
875            tokio::task::yield_now().await;
876        }
877
878        // Make sure we read the correct things.
879        log.assert_eq(d0, 1, u64::MAX, sub.output().clone());
880
881        // Also make sure that we can read the right things if we start up after
882        // the forget but before the direct write and ditto after the direct
883        // write.
884        log.assert_subscribe(d0, 4, u64::MAX).await;
885        log.assert_subscribe(d0, 6, u64::MAX).await;
886    }
887
888    #[mz_ore::test(tokio::test(flavor = "multi_thread"))]
889    #[cfg_attr(miri, ignore)] // too slow
890    async fn subscribe_shard_register_forget() {
891        let client = PersistClient::new_for_tests().await;
892        let mut txns = TxnsHandle::expect_open(client.clone()).await;
893        let d0 = ShardId::new();
894
895        // Start a subscription on the data shard.
896        let mut sub = txns.read_cache().expect_subscribe(&client, d0, 0);
897        assert_eq!(sub.progress(), 0);
898
899        // Register the shard at 10.
900        txns.register(10, [writer(&client, d0).await])
901            .await
902            .unwrap();
903        sub.step_past(10).await;
904        assert!(
905            sub.progress() > 10,
906            "operator should advance past 10 when shard is registered"
907        );
908
909        // Forget the shard at 20.
910        txns.forget(20, [d0]).await.unwrap();
911        sub.step_past(20).await;
912        assert!(
913            sub.progress() > 20,
914            "operator should advance past 20 when shard is forgotten"
915        );
916    }
917
918    #[mz_ore::test(tokio::test)]
919    #[cfg_attr(miri, ignore)] // too slow
920    async fn as_of_until() {
921        let client = PersistClient::new_for_tests().await;
922        let mut txns = TxnsHandle::expect_open(client.clone()).await;
923        let log = txns.new_log();
924
925        let d0 = txns.expect_register(1).await;
926        txns.expect_commit_at(2, d0, &["2"], &log).await;
927        txns.expect_commit_at(3, d0, &["3"], &log).await;
928        txns.expect_commit_at(4, d0, &["4"], &log).await;
929        txns.expect_commit_at(5, d0, &["5"], &log).await;
930        txns.expect_commit_at(6, d0, &["6"], &log).await;
931        txns.expect_commit_at(7, d0, &["7"], &log).await;
932
933        let until = 5;
934        let mut sub = DataSubscribe::new(
935            "as_of_until",
936            client,
937            txns.txns_id(),
938            d0,
939            3,
940            Antichain::from_elem(until),
941        );
942        // Manually step the dataflow, instead of going through the
943        // `DataSubscribe` helper because we're interested in all captured
944        // events.
945        while sub.txns.less_equal(&5) {
946            sub.worker.step();
947            tokio::task::yield_now().await;
948            tokio::time::sleep(std::time::Duration::from_millis(100)).await;
949        }
950        let (actual_progresses, actual_events): (Vec<_>, Vec<_>) =
951            sub.capture.into_iter().partition_map(|event| match event {
952                Event::Progress(progress) => Either::Left(progress),
953                Event::Messages(ts, data) => Either::Right((ts, data)),
954            });
955        let expected = vec![
956            (3, vec![("2".to_owned(), 3, 1), ("3".to_owned(), 3, 1)]),
957            (3, vec![("4".to_owned(), 4, 1)]),
958        ];
959        assert_eq!(actual_events, expected);
960
961        // The number and contents of progress messages is not guaranteed and
962        // depends on the downgrade behavior. The only thing we can assert is
963        // the max progress timestamp, if there is one, is less than the until.
964        if let Some(max_progress_ts) = actual_progresses
965            .into_iter()
966            .flatten()
967            .map(|(ts, _diff)| ts)
968            .max()
969        {
970            assert!(max_progress_ts < until, "{max_progress_ts} < {until}");
971        }
972    }
973}