mz_txn_wal/operator.rs

// Copyright Materialize, Inc. and contributors. All rights reserved.
//
// Use of this software is governed by the Business Source License
// included in the LICENSE file.
//
// As of the Change Date specified in that file, in accordance with
// the Business Source License, use of this software will be governed
// by the Apache License, Version 2.0.

//! Timely operators for the crate

use std::any::Any;
use std::fmt::Debug;
use std::future::Future;
use std::sync::mpsc::TryRecvError;
use std::sync::{Arc, mpsc};
use std::time::Duration;

use differential_dataflow::Hashable;
use differential_dataflow::difference::Monoid;
use differential_dataflow::lattice::Lattice;
use futures::StreamExt;
use mz_dyncfg::{Config, ConfigSet};
use mz_ore::cast::CastFrom;
use mz_persist_client::cfg::RetryParameters;
use mz_persist_client::operators::shard_source::{
    ErrorHandler, FilterResult, SnapshotMode, shard_source,
};
use mz_persist_client::{Diagnostics, PersistClient, ShardId};
use mz_persist_types::codec_impls::{StringSchema, UnitSchema};
use mz_persist_types::txn::TxnsCodec;
use mz_persist_types::{Codec, Codec64, StepForward};
use mz_timely_util::builder_async::{
    AsyncInputHandle, Event as AsyncEvent, InputConnection,
    OperatorBuilder as AsyncOperatorBuilder, PressOnDropButton,
};
use timely::container::CapacityContainerBuilder;
use timely::dataflow::channels::pact::Pipeline;
use timely::dataflow::operators::capture::Event;
use timely::dataflow::operators::vec::{Broadcast, Map};
use timely::dataflow::operators::{Capture, Leave, Probe};
use timely::dataflow::{ProbeHandle, Scope, StreamVec};
use timely::order::TotalOrder;
use timely::progress::{Antichain, Timestamp};
use timely::worker::Worker;
use timely::{PartialOrder, WorkerConfig};
use tracing::debug;

use crate::TxnsCodecDefault;
use crate::txn_cache::TxnsCache;
use crate::txn_read::{DataRemapEntry, TxnsRead};

/// An operator for translating physical data shard frontiers into logical ones.
///
/// A data shard in the txns set logically advances its upper each time a txn is
/// committed, but the upper is not physically advanced unless that data shard
/// was involved in the txn. This means that a shard_source (or any read)
/// pointed at a data shard would appear to stall at the time of the most recent
/// write. We fix this for shard_source by flowing its output through a new
/// `txns_progress` dataflow operator, which ensures that the
/// frontier/capability is advanced as the txns shard progresses, as long as the
/// shard_source is up to date with the latest committed write to that data
/// shard.
///
/// Example:
///
/// - A data shard has most recently been written to at 3.
/// - The txns shard's upper is at 6.
/// - We render a dataflow containing a shard_source with an as_of of 5.
/// - A txn NOT involving the data shard is committed at 7.
/// - A txn involving the data shard is committed at 9.
///
/// How it works:
///
/// - The shard_source operator is rendered. Its single output is hooked up as a
///   _disconnected_ input to txns_progress. The txns_progress single output is
///   a stream of the same type, which is used by downstream operators. This
///   txns_progress operator is targeted at one data_shard; rendering a
///   shard_source for a second data shard requires a second txns_progress
///   operator.
/// - The shard_source operator emits data through 3 and advances the frontier.
/// - The txns_progress operator passes through these writes and frontier
///   advancements unchanged. (Recall that it's always correct to read a data
///   shard "normally"; it just might stall.) Because the txns_progress operator
///   knows there are no writes in `[3,5]`, it then downgrades its own
///   capability past 5 (to 6). Because the input is disconnected, this means
///   the overall frontier of the output is downgraded to 6.
/// - The txns_progress operator learns about the write at 7 (the upper is now
///   8). Because it knows that the data shard was not involved in this, it's
///   free to downgrade its capability to 8.
/// - The txns_progress operator learns about the write at 9 (the upper is now
///   10). It knows that the data shard _WAS_ involved in this, so it forwards
///   on data from its input until the input has progressed to 10, at which
///   point it can itself downgrade to 10.
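///
/// A minimal wiring sketch, mirroring the `DataSubscribe` helper below (the
/// concrete `String`/`()`/`u64`/`i64` types and the variable names here are
/// illustrative):
///
/// ```ignore
/// let (data_stream, source_token) =
///     shard_source::<String, (), u64, i64, _, _, _>(/* ... */);
/// let (data_stream, progress_tokens) =
///     txns_progress::<String, (), u64, i64, _, TxnsCodecDefault, _, _>(
///         data_stream,
///         "example",
///         &TxnsContext::default(),
///         || std::future::ready(client.clone()),
///         txns_id,
///         data_id,
///         as_of,
///         until,
///         Arc::new(StringSchema),
///         Arc::new(UnitSchema),
///     );
/// // data_stream now advances as the txns shard advances, even when the data
/// // shard itself is not being written to.
/// ```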
pub fn txns_progress<K, V, T, D, P, C, F, G>(
    passthrough: StreamVec<G, P>,
    name: &str,
    ctx: &TxnsContext,
    client_fn: impl Fn() -> F,
    txns_id: ShardId,
    data_id: ShardId,
    as_of: T,
    until: Antichain<T>,
    data_key_schema: Arc<K::Schema>,
    data_val_schema: Arc<V::Schema>,
) -> (StreamVec<G, P>, Vec<PressOnDropButton>)
where
    K: Debug + Codec + Send + Sync,
    V: Debug + Codec + Send + Sync,
    T: Timestamp + Lattice + TotalOrder + StepForward + Codec64 + Sync,
    D: Debug + Clone + 'static + Monoid + Ord + Codec64 + Send + Sync,
    P: Debug + Clone + 'static,
    C: TxnsCodec + 'static,
    F: Future<Output = PersistClient> + Send + 'static,
    G: Scope<Timestamp = T>,
{
    let unique_id = (name, passthrough.scope().addr()).hashed();
    let (remap, source_button) = txns_progress_source_global::<K, V, T, D, P, C, G>(
        passthrough.scope(),
        name,
        ctx.clone(),
        client_fn(),
        txns_id,
        data_id,
        as_of,
        data_key_schema,
        data_val_schema,
        unique_id,
    );
    // Each of the `txns_progress_frontiers` workers wants the full copy of the
    // remap information.
    let remap = remap.broadcast();
    let (passthrough, frontiers_button) = txns_progress_frontiers::<K, V, T, D, P, C, G>(
        remap,
        passthrough,
        name,
        data_id,
        until,
        unique_id,
    );
    (passthrough, vec![source_button, frontiers_button])
}

/// TODO: I'd much prefer the communication protocol between the two operators
/// to be exactly remap as defined in the [reclocking design doc]. However, we
/// can't quite recover exactly the information necessary to construct that at
/// the moment. Seems worth doing, but in the meantime, intentionally make this
/// look fairly different (`Stream` of `DataRemapEntry` instead of
/// `Collection<FromTime>`) to hopefully minimize confusion. As a performance
/// optimization, we only re-emit this when the _physical_ upper has changed,
/// which means that the frontier of the `Stream<DataRemapEntry<T>>` indicates
/// updates to the logical_upper of the most recent `DataRemapEntry` (i.e. the
/// one with the largest physical_upper).
///
/// [reclocking design doc]:
///     https://github.com/MaterializeInc/materialize/blob/main/doc/developer/design/20210714_reclocking.md
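///
/// Schematically, for the example in the [`txns_progress`] docs (data shard
/// written at 3 and 9, an unrelated txn at 7), the output looks roughly like
/// the following; the exact uppers depend on when the subscribe catches up:
///
/// ```text
/// DataRemapEntry { physical_upper: 4, logical_upper: 6 }
/// // txn at 7, data shard not involved: no new entry is emitted; instead the
/// // stream's frontier advances to 8, i.e. logical_upper = 8 for the entry
/// // above.
/// DataRemapEntry { physical_upper: 10, logical_upper: 10 }
/// ```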
fn txns_progress_source_global<K, V, T, D, P, C, G>(
    scope: G,
    name: &str,
    ctx: TxnsContext,
    client: impl Future<Output = PersistClient> + 'static,
    txns_id: ShardId,
    data_id: ShardId,
    as_of: T,
    data_key_schema: Arc<K::Schema>,
    data_val_schema: Arc<V::Schema>,
    unique_id: u64,
) -> (StreamVec<G, DataRemapEntry<T>>, PressOnDropButton)
where
    K: Debug + Codec + Send + Sync,
    V: Debug + Codec + Send + Sync,
    T: Timestamp + Lattice + TotalOrder + StepForward + Codec64 + Sync,
    D: Debug + Clone + 'static + Monoid + Ord + Codec64 + Send + Sync,
    P: Debug + Clone + 'static,
    C: TxnsCodec + 'static,
    G: Scope<Timestamp = T>,
{
    let worker_idx = scope.index();
    let chosen_worker = usize::cast_from(name.hashed()) % scope.peers();
    let name = format!("txns_progress_source({})", name);
    let mut builder = AsyncOperatorBuilder::new(name.clone(), scope);
    let name = format!("{} [{}] {:.9}", name, unique_id, data_id.to_string());
    let (remap_output, remap_stream) = builder.new_output::<CapacityContainerBuilder<Vec<_>>>();

    let shutdown_button = builder.build(move |capabilities| async move {
        if worker_idx != chosen_worker {
            return;
        }

        let [mut cap]: [_; 1] = capabilities.try_into().expect("one capability per output");
        let client = client.await;
        let txns_read = ctx.get_or_init::<T, C>(&client, txns_id).await;

        let _ = txns_read.update_gt(as_of.clone()).await;
        let data_write = client
            .open_writer::<K, V, T, D>(
                data_id,
                Arc::clone(&data_key_schema),
                Arc::clone(&data_val_schema),
                Diagnostics::from_purpose("data read physical upper"),
            )
            .await
            .expect("schema shouldn't change");
        let mut rx = txns_read
            .data_subscribe(data_id, as_of.clone(), data_write)
            .await;
        debug!("{} starting as_of={:?}", name, as_of);

        let mut physical_upper = T::minimum();

        while let Some(remap) = rx.recv().await {
            assert!(physical_upper <= remap.physical_upper);
            assert!(physical_upper < remap.logical_upper);

            let logical_upper = remap.logical_upper.clone();
            // As mentioned in the docs on this function, we only
            // emit updates when the physical upper changes (which
            // happens to make the protocol a tiny bit more
            // remap-like).
            if remap.physical_upper != physical_upper {
                physical_upper = remap.physical_upper.clone();
                debug!("{} emitting {:?}", name, remap);
                remap_output.give(&cap, remap);
            } else {
                debug!("{} not emitting {:?}", name, remap);
            }
            cap.downgrade(&logical_upper);
        }
    });
    (remap_stream, shutdown_button.press_on_drop())
}

fn txns_progress_frontiers<K, V, T, D, P, C, G>(
    remap: StreamVec<G, DataRemapEntry<T>>,
    passthrough: StreamVec<G, P>,
    name: &str,
    data_id: ShardId,
    until: Antichain<T>,
    unique_id: u64,
) -> (StreamVec<G, P>, PressOnDropButton)
where
    K: Debug + Codec,
    V: Debug + Codec,
    T: Timestamp + Lattice + TotalOrder + StepForward + Codec64,
    D: Clone + 'static + Monoid + Codec64 + Send + Sync,
    P: Debug + Clone + 'static,
    C: TxnsCodec,
    G: Scope<Timestamp = T>,
{
    let name = format!("txns_progress_frontiers({})", name);
    let mut builder = AsyncOperatorBuilder::new(name.clone(), passthrough.scope());
    let name = format!(
        "{} [{}] {}/{} {:.9}",
        name,
        unique_id,
        passthrough.scope().index(),
        passthrough.scope().peers(),
        data_id.to_string(),
    );
    let (passthrough_output, passthrough_stream) =
        builder.new_output::<CapacityContainerBuilder<Vec<_>>>();
    let mut remap_input = builder.new_disconnected_input(remap, Pipeline);
    let mut passthrough_input = builder.new_disconnected_input(passthrough, Pipeline);

    let shutdown_button = builder.build(move |capabilities| async move {
        let [mut cap]: [_; 1] = capabilities.try_into().expect("one capability per output");

        // None is used to indicate that both uppers are the empty antichain.
        let mut remap = Some(DataRemapEntry {
            physical_upper: T::minimum(),
            logical_upper: T::minimum(),
        });
        // NB: The following loop uses `cap.time()` to track how far we've
        // progressed in copying along the passthrough input.
        loop {
            debug!("{} remap {:?}", name, remap);
            if let Some(r) = remap.as_ref() {
                assert!(r.physical_upper <= r.logical_upper);
                // If we've passed through data to at least `physical_upper`,
                // then it means we can artificially advance the upper of the
                // output to `logical_upper`. This also indicates that we need
                // to wait for the next DataRemapEntry. It can either (A) have
                // the same physical upper or (B) have a larger physical upper.
                //
                // - If (A), then we would satisfy this `physical_upper` check
                //   again, advance the logical upper again, and so on.
                // - If (B), then we'd fall down to the code below, which copies
                //   the passthrough data until the frontier passes
                //   `physical_upper`, then loops back up here.
                if r.physical_upper.less_equal(cap.time()) {
                    if cap.time() < &r.logical_upper {
                        cap.downgrade(&r.logical_upper);
                    }
                    remap = txns_progress_frontiers_read_remap_input(
                        &name,
                        &mut remap_input,
                        r.clone(),
                    )
                    .await;
                    continue;
                }
            }

            // This only returns None when there is no more data left. Turn it
            // into an empty frontier progress so we can re-use the shutdown
            // code below.
            let event = passthrough_input
                .next()
                .await
                .unwrap_or_else(|| AsyncEvent::Progress(Antichain::new()));
            match event {
                // NB: Ignore the data_cap because this input is disconnected.
                AsyncEvent::Data(_data_cap, mut data) => {
                    // NB: Nothing to do here for `until` because both the
                    // `shard_source` (before this operator) and
                    // `mfp_and_decode` (after this operator) do the necessary
                    // filtering.
                    debug!("{} emitting data {:?}", name, data);
                    passthrough_output.give_container(&cap, &mut data);
                }
                AsyncEvent::Progress(new_progress) => {
                    // If `until.less_equal(new_progress)`, it means that all
                    // subsequent batches will contain only times greater than
                    // or equal to `until`, which means they can be dropped in
                    // their entirety.
                    //
                    // Ideally this check would live in `txns_progress_source`,
                    // but that turns out to be much more invasive (requires
                    // replacing lots of `T`s with `Antichain<T>`s). Given that
                    // we've been thinking about reworking the operators, do the
                    // easy but more wasteful thing for now.
                    if PartialOrder::less_equal(&until, &new_progress) {
                        debug!(
                            "{} progress {:?} has passed until {:?}",
                            name,
                            new_progress.elements(),
                            until.elements()
                        );
                        return;
                    }
                    // We reached the empty frontier! Shut down.
                    let Some(new_progress) = new_progress.into_option() else {
                        return;
                    };

                    // Recall that any reads of the data shard are always
                    // correct, so given that we've passed through any data
                    // from the input, that means we're free to pass through
                    // frontier updates too.
                    if cap.time() < &new_progress {
                        debug!("{} downgrading cap to {:?}", name, new_progress);
                        cap.downgrade(&new_progress);
                    }
                }
            }
        }
    });
    (passthrough_stream, shutdown_button.press_on_drop())
}

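/// Reads the remap input until the uppers of the given `DataRemapEntry`
/// advance, returning the updated entry.
///
/// Progress on the remap input translates into a new `logical_upper` for the
/// most recent entry; data events carry entirely new entries, in no assumed
/// order. Returns `None` once the remap input is closed, which indicates that
/// the data shard is finished.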
async fn txns_progress_frontiers_read_remap_input<T, C>(
    name: &str,
    input: &mut AsyncInputHandle<T, Vec<DataRemapEntry<T>>, C>,
    mut remap: DataRemapEntry<T>,
) -> Option<DataRemapEntry<T>>
where
    T: Timestamp + TotalOrder,
    C: InputConnection<T>,
{
    while let Some(event) = input.next().await {
        let xs = match event {
            AsyncEvent::Progress(logical_upper) => {
                if let Some(logical_upper) = logical_upper.into_option() {
                    if remap.logical_upper < logical_upper {
                        remap.logical_upper = logical_upper;
                        return Some(remap);
                    }
                }
                continue;
            }
            AsyncEvent::Data(_cap, xs) => xs,
        };
        for x in xs {
            debug!("{} got remap {:?}", name, x);
            // Don't assume anything about the ordering.
            if remap.logical_upper < x.logical_upper {
                assert!(
                    remap.physical_upper <= x.physical_upper,
                    "previous remap physical upper {:?} is ahead of new remap physical upper {:?}",
                    remap.physical_upper,
                    x.physical_upper,
                );
                // TODO: If the physical upper has advanced, that's a very
                // strong hint that the data shard is about to be written to.
                // Because the data shard's upper advances sparsely (on write,
                // but not on passage of time) which invalidates the "every 1s"
                // assumption of the default tuning, we've had to de-tune the
                // listen sleeps on the paired persist_source. Maybe we use "one
                // state" to wake it up in case pubsub doesn't and remove the
                // listen polling entirely? (NB: This would have to happen in
                // each worker so that it's guaranteed to happen in each
                // process.)
                remap = x;
            }
        }
        return Some(remap);
    }
    // remap_input is closed, which indicates the data shard is finished.
    None
}

/// The process global [`TxnsRead`] that any operator can communicate with.
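///
/// A sketch of the sharing this enables (the `u64` timestamp and
/// `TxnsCodecDefault` here are illustrative):
///
/// ```ignore
/// let ctx = TxnsContext::default();
/// // Both calls return clones of the same underlying TxnsRead; whichever
/// // call runs first performs the one-time initialization.
/// let read_a = ctx.get_or_init::<u64, TxnsCodecDefault>(&client, txns_id).await;
/// let read_b = ctx.get_or_init::<u64, TxnsCodecDefault>(&client, txns_id).await;
/// ```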
#[derive(Default, Debug, Clone)]
pub struct TxnsContext {
    read: Arc<tokio::sync::OnceCell<Box<dyn Any + Send + Sync>>>,
}

impl TxnsContext {
    async fn get_or_init<T, C>(&self, client: &PersistClient, txns_id: ShardId) -> TxnsRead<T>
    where
        T: Timestamp + Lattice + Codec64 + TotalOrder + StepForward + Sync,
        C: TxnsCodec + 'static,
    {
        let read = self
            .read
            .get_or_init(|| {
                let client = client.clone();
                async move {
                    let read: Box<dyn Any + Send + Sync> =
                        Box::new(TxnsRead::<T>::start::<C>(client, txns_id).await);
                    read
                }
            })
            .await
            .downcast_ref::<TxnsRead<T>>()
            .expect("timestamp types should match");
        // We initially only have one txns shard in the system.
        assert_eq!(&txns_id, read.txns_id());
        read.clone()
    }
}

// Existing configs use the prefix "persist_txns_" for historical reasons. New
// configs should use the prefix "txn_wal_".

pub(crate) const DATA_SHARD_RETRYER_INITIAL_BACKOFF: Config<Duration> = Config::new(
    "persist_txns_data_shard_retryer_initial_backoff",
    Duration::from_millis(1024),
    "The initial backoff when polling for new batches from a txns data shard persist_source.",
);

pub(crate) const DATA_SHARD_RETRYER_MULTIPLIER: Config<u32> = Config::new(
    "persist_txns_data_shard_retryer_multiplier",
    2,
    "The backoff multiplier when polling for new batches from a txns data shard persist_source.",
);

pub(crate) const DATA_SHARD_RETRYER_CLAMP: Config<Duration> = Config::new(
    "persist_txns_data_shard_retryer_clamp",
    Duration::from_secs(16),
    "The backoff clamp duration when polling for new batches from a txns data shard persist_source.",
);

/// Retry configuration for the txn-wal data shard override of
/// `next_listen_batch`.
pub fn txns_data_shard_retry_params(cfg: &ConfigSet) -> RetryParameters {
    RetryParameters {
        fixed_sleep: Duration::ZERO,
        initial_backoff: DATA_SHARD_RETRYER_INITIAL_BACKOFF.get(cfg),
        multiplier: DATA_SHARD_RETRYER_MULTIPLIER.get(cfg),
        clamp: DATA_SHARD_RETRYER_CLAMP.get(cfg),
    }
}
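
// With the defaults above, a stalled `next_listen_batch` poll backs off as
// 1.024s, 2.048s, 4.096s, 8.192s, and then stays clamped at 16s.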

/// A helper for subscribing to a data shard using the timely operators.
///
/// This could instead be a wrapper around a [Subscribe], but it's only used in
/// tests and maelstrom, so do it by wrapping the timely operators to get
/// additional coverage. For the same reason, hardcode the K, V, T, D types.
///
/// [Subscribe]: mz_persist_client::read::Subscribe
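///
/// A sketch of typical (test) use; the concrete timestamps are illustrative:
///
/// ```ignore
/// let mut sub = DataSubscribe::new("example", client, txns_id, data_id, 5, Antichain::new());
/// while sub.progress() <= 7 {
///     sub.step();
/// }
/// let output: &Vec<(String, u64, i64)> = sub.output();
/// ```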
pub struct DataSubscribe {
    pub(crate) as_of: u64,
    pub(crate) worker: Worker<timely::communication::allocator::Thread>,
    data: ProbeHandle<u64>,
    txns: ProbeHandle<u64>,
    capture: mpsc::Receiver<Event<u64, Vec<(String, u64, i64)>>>,
    output: Vec<(String, u64, i64)>,

    _tokens: Vec<PressOnDropButton>,
}

impl std::fmt::Debug for DataSubscribe {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let DataSubscribe {
            as_of,
            worker: _,
            data,
            txns,
            capture: _,
            output,
            _tokens: _,
        } = self;
        f.debug_struct("DataSubscribe")
            .field("as_of", as_of)
            .field("data", data)
            .field("txns", txns)
            .field("output", output)
            .finish_non_exhaustive()
    }
}

impl DataSubscribe {
    /// Creates a new [DataSubscribe].
    pub fn new(
        name: &str,
        client: PersistClient,
        txns_id: ShardId,
        data_id: ShardId,
        as_of: u64,
        until: Antichain<u64>,
    ) -> Self {
        let mut worker = Worker::new(
            WorkerConfig::default(),
            timely::communication::allocator::Thread::default(),
            Some(std::time::Instant::now()),
        );
        let (data, txns, capture, tokens) = worker.dataflow::<u64, _, _>(|scope| {
            let (data_stream, shard_source_token) = scope.scoped::<u64, _, _>("hybrid", |scope| {
                let client = client.clone();
                let (data_stream, token) = shard_source::<String, (), u64, i64, _, _, _>(
                    scope,
                    name,
                    move || std::future::ready(client.clone()),
                    data_id,
                    Some(Antichain::from_elem(as_of)),
                    SnapshotMode::Include,
                    until.clone(),
                    false.then_some(|_, _, _| unreachable!()),
                    Arc::new(StringSchema),
                    Arc::new(UnitSchema),
                    FilterResult::keep_all,
                    false.then_some(|| unreachable!()),
                    async {},
                    ErrorHandler::Halt("data_subscribe"),
                );
                (data_stream.leave(), token)
            });
            let (data, txns) = (ProbeHandle::new(), ProbeHandle::new());
            let data_stream = data_stream.flat_map(|part| {
                let part = part.parse();
                part.part.map(|((k, ()), t, d)| (k, t, d))
            });
            let data_stream = data_stream.probe_with(&data);
            let (data_stream, mut txns_progress_token) =
                txns_progress::<String, (), u64, i64, _, TxnsCodecDefault, _, _>(
                    data_stream,
                    name,
                    &TxnsContext::default(),
                    || std::future::ready(client.clone()),
                    txns_id,
                    data_id,
                    as_of,
                    until,
                    Arc::new(StringSchema),
                    Arc::new(UnitSchema),
                );
            let data_stream = data_stream.probe_with(&txns);
            let mut tokens = shard_source_token;
            tokens.append(&mut txns_progress_token);
            (data, txns, data_stream.capture(), tokens)
        });
        Self {
            as_of,
            worker,
            data,
            txns,
            capture,
            output: Vec::new(),
            _tokens: tokens,
        }
    }

    /// Returns the exclusive progress of the dataflow.
    pub fn progress(&self) -> u64 {
        self.txns
            .with_frontier(|f| *f.as_option().unwrap_or(&u64::MAX))
    }

    /// Steps the dataflow, capturing output.
    pub fn step(&mut self) {
        self.worker.step();
        self.capture_output()
    }

    pub(crate) fn capture_output(&mut self) {
        loop {
            let event = match self.capture.try_recv() {
                Ok(x) => x,
                Err(TryRecvError::Empty) | Err(TryRecvError::Disconnected) => break,
            };
            match event {
                Event::Progress(_) => {}
                Event::Messages(_, mut msgs) => self.output.append(&mut msgs),
            }
        }
    }

    /// Steps the dataflow past the given time, capturing output.
    #[cfg(test)]
    pub async fn step_past(&mut self, ts: u64) {
        while self.txns.less_equal(&ts) {
            tracing::trace!(
                "progress at {:?}",
                self.txns.with_frontier(|x| x.to_owned()).elements()
            );
            self.step();
            tokio::task::yield_now().await;
        }
    }

    /// Returns captured output.
    pub fn output(&self) -> &Vec<(String, u64, i64)> {
        &self.output
    }
}

/// A handle to a [DataSubscribe] running in a task.
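///
/// A sketch of typical use (the timestamps are illustrative):
///
/// ```ignore
/// let mut task = DataSubscribeTask::new(client, txns_id, data_id, 4).await;
/// let progress = task.step_past(7).await;
/// assert!(progress > 7);
/// let output = task.finish().await;
/// ```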
#[derive(Debug)]
pub struct DataSubscribeTask {
    /// Carries step requests. A `None` timestamp requests one step; a
    /// `Some(ts)` requests stepping until we progress beyond `ts`.
    tx: std::sync::mpsc::Sender<(
        Option<u64>,
        tokio::sync::oneshot::Sender<(Vec<(String, u64, i64)>, u64)>,
    )>,
    task: mz_ore::task::JoinHandle<Vec<(String, u64, i64)>>,
    output: Vec<(String, u64, i64)>,
    progress: u64,
}

impl DataSubscribeTask {
    /// Creates a new [DataSubscribeTask].
    pub async fn new(
        client: PersistClient,
        txns_id: ShardId,
        data_id: ShardId,
        as_of: u64,
    ) -> Self {
        let cache = TxnsCache::open(&client, txns_id, Some(data_id)).await;
        let (tx, rx) = std::sync::mpsc::channel();
        let task = mz_ore::task::spawn_blocking(
            || "data_subscribe task",
            move || Self::task(client, cache, data_id, as_of, rx),
        );
        DataSubscribeTask {
            tx,
            task,
            output: Vec::new(),
            progress: 0,
        }
    }

    #[cfg(test)]
    async fn step(&mut self) {
        self.send(None).await;
    }

    /// Steps the dataflow past the given time, capturing output.
    pub async fn step_past(&mut self, ts: u64) -> u64 {
        self.send(Some(ts)).await;
        self.progress
    }

    /// Returns captured output.
    pub fn output(&self) -> &Vec<(String, u64, i64)> {
        &self.output
    }

    async fn send(&mut self, ts: Option<u64>) {
        let (tx, rx) = tokio::sync::oneshot::channel();
        self.tx.send((ts, tx)).expect("task should be running");
        let (mut new_output, new_progress) = rx.await.expect("task should be running");
        self.output.append(&mut new_output);
        assert!(self.progress <= new_progress);
        self.progress = new_progress;
    }

    /// Signals for the task to exit, and then waits for this to happen.
    ///
    /// _All_ output from the lifetime of the task (not just what was previously
    /// captured) is returned.
    pub async fn finish(self) -> Vec<(String, u64, i64)> {
        // Closing the channel signals the task to exit.
        drop(self.tx);
        self.task.await
    }

    fn task(
        client: PersistClient,
        cache: TxnsCache<u64>,
        data_id: ShardId,
        as_of: u64,
        rx: std::sync::mpsc::Receiver<(
            Option<u64>,
            tokio::sync::oneshot::Sender<(Vec<(String, u64, i64)>, u64)>,
        )>,
    ) -> Vec<(String, u64, i64)> {
        let mut subscribe = DataSubscribe::new(
            "DataSubscribeTask",
            client.clone(),
            cache.txns_id(),
            data_id,
            as_of,
            Antichain::new(),
        );
        let mut output = Vec::new();
        loop {
            let (ts, tx) = match rx.try_recv() {
                Ok(x) => x,
                Err(TryRecvError::Empty) => {
                    // No requests, continue stepping so nothing deadlocks.
                    subscribe.step();
                    continue;
                }
                Err(TryRecvError::Disconnected) => {
                    // All done! Return our output.
                    return output;
                }
            };
            // Always step at least once.
            subscribe.step();
            // If we got a ts, make sure to step past it.
            if let Some(ts) = ts {
                while subscribe.progress() <= ts {
                    subscribe.step();
                }
            }
            let new_output = std::mem::take(&mut subscribe.output);
            output.extend(new_output.iter().cloned());
            let _ = tx.send((new_output, subscribe.progress()));
        }
    }
}

#[cfg(test)]
mod tests {
    use itertools::{Either, Itertools};

    use crate::tests::writer;
    use crate::txns::TxnsHandle;

    use super::*;

    impl<K, V, T, D, C> TxnsHandle<K, V, T, D, C>
    where
        K: Debug + Codec,
        V: Debug + Codec,
        T: Timestamp + Lattice + TotalOrder + StepForward + Codec64 + Sync,
        D: Debug + Monoid + Ord + Codec64 + Send + Sync,
        C: TxnsCodec,
    {
        async fn subscribe_task(
            &self,
            client: &PersistClient,
            data_id: ShardId,
            as_of: u64,
        ) -> DataSubscribeTask {
            DataSubscribeTask::new(client.clone(), self.txns_id(), data_id, as_of).await
        }
    }

    #[mz_ore::test(tokio::test(flavor = "multi_thread"))]
    #[cfg_attr(miri, ignore)] // too slow
    async fn data_subscribe() {
        async fn step(subs: &mut Vec<DataSubscribeTask>) {
            for sub in subs.iter_mut() {
                sub.step().await;
            }
        }

        let client = PersistClient::new_for_tests().await;
        let mut txns = TxnsHandle::expect_open(client.clone()).await;
        let log = txns.new_log();
        let d0 = ShardId::new();

        // Start a subscription before the shard gets registered.
        let mut subs = Vec::new();
        subs.push(txns.subscribe_task(&client, d0, 5).await);
        step(&mut subs).await;

        // Now register the shard. Also start a new subscription and step the
        // previous one (plus repeat this for every later step).
        txns.register(1, [writer(&client, d0).await]).await.unwrap();
        subs.push(txns.subscribe_task(&client, d0, 5).await);
        step(&mut subs).await;

        // Now write something unrelated.
        let d1 = txns.expect_register(2).await;
        txns.expect_commit_at(3, d1, &["nope"], &log).await;
        subs.push(txns.subscribe_task(&client, d0, 5).await);
        step(&mut subs).await;

        // Now write to our shard before the as_of.
        txns.expect_commit_at(4, d0, &["4"], &log).await;
        subs.push(txns.subscribe_task(&client, d0, 5).await);
        step(&mut subs).await;

        // Now write to our shard at the as_of.
        txns.expect_commit_at(5, d0, &["5"], &log).await;
        subs.push(txns.subscribe_task(&client, d0, 5).await);
        step(&mut subs).await;

        // Now write to our shard past the as_of.
        txns.expect_commit_at(6, d0, &["6"], &log).await;
        subs.push(txns.subscribe_task(&client, d0, 5).await);
        step(&mut subs).await;

        // Now write something unrelated again.
        txns.expect_commit_at(7, d1, &["nope"], &log).await;
        subs.push(txns.subscribe_task(&client, d0, 5).await);
        step(&mut subs).await;

        // Verify that the dataflows can progress to the expected point and that
        // we read the right thing no matter when the dataflow started.
        for mut sub in subs {
            let progress = sub.step_past(7).await;
            assert_eq!(progress, 8);
            log.assert_eq(d0, 5, 8, sub.finish().await);
        }
    }

    #[mz_ore::test(tokio::test(flavor = "multi_thread"))]
    #[cfg_attr(miri, ignore)] // too slow
    async fn subscribe_shard_finalize() {
        let client = PersistClient::new_for_tests().await;
        let mut txns = TxnsHandle::expect_open(client.clone()).await;
        let log = txns.new_log();
        let d0 = txns.expect_register(1).await;

        // Start the operator as_of the register ts.
        let mut sub = txns.read_cache().expect_subscribe(&client, d0, 1);
        sub.step_past(1).await;

        // Write to it via txns.
        txns.expect_commit_at(2, d0, &["foo"], &log).await;
        sub.step_past(2).await;

        // Unregister it.
        txns.forget(3, [d0]).await.unwrap();
        sub.step_past(3).await;

        // TODO: Hard mode, see if we can get the rest of this test to work even
        // _without_ the txns shard advancing.
        txns.begin().commit_at(&mut txns, 7).await.unwrap();

        // The operator should continue to emit data written directly even
        // though it's no longer in the txns set.
        let mut d0_write = writer(&client, d0).await;
        let key = "bar".to_owned();
        crate::small_caa(|| "test", &mut d0_write, &[((&key, &()), &5, 1)], 4, 6)
            .await
            .unwrap();
        log.record((d0, key, 5, 1));
        sub.step_past(4).await;

        // Now finalize the shard to writes.
        let () = d0_write
            .compare_and_append_batch(&mut [], Antichain::from_elem(6), Antichain::new(), true)
            .await
            .unwrap()
            .unwrap();
        while sub.txns.less_than(&u64::MAX) {
            sub.step();
            tokio::task::yield_now().await;
        }

        // Make sure we read the correct things.
        log.assert_eq(d0, 1, u64::MAX, sub.output().clone());

        // Also make sure that we can read the right things if we start up after
        // the forget but before the direct write and ditto after the direct
        // write.
        log.assert_subscribe(d0, 4, u64::MAX).await;
        log.assert_subscribe(d0, 6, u64::MAX).await;
    }

    #[mz_ore::test(tokio::test(flavor = "multi_thread"))]
    #[cfg_attr(miri, ignore)] // too slow
    async fn subscribe_shard_register_forget() {
        let client = PersistClient::new_for_tests().await;
        let mut txns = TxnsHandle::expect_open(client.clone()).await;
        let d0 = ShardId::new();

        // Start a subscription on the data shard.
        let mut sub = txns.read_cache().expect_subscribe(&client, d0, 0);
        assert_eq!(sub.progress(), 0);

        // Register the shard at 10.
        txns.register(10, [writer(&client, d0).await])
            .await
            .unwrap();
        sub.step_past(10).await;
        assert!(
            sub.progress() > 10,
            "operator should advance past 10 when shard is registered"
        );

        // Forget the shard at 20.
        txns.forget(20, [d0]).await.unwrap();
        sub.step_past(20).await;
        assert!(
            sub.progress() > 20,
            "operator should advance past 20 when shard is forgotten"
        );
    }

    #[mz_ore::test(tokio::test)]
    #[cfg_attr(miri, ignore)] // too slow
    async fn as_of_until() {
        let client = PersistClient::new_for_tests().await;
        let mut txns = TxnsHandle::expect_open(client.clone()).await;
        let log = txns.new_log();

        let d0 = txns.expect_register(1).await;
        txns.expect_commit_at(2, d0, &["2"], &log).await;
        txns.expect_commit_at(3, d0, &["3"], &log).await;
        txns.expect_commit_at(4, d0, &["4"], &log).await;
        txns.expect_commit_at(5, d0, &["5"], &log).await;
        txns.expect_commit_at(6, d0, &["6"], &log).await;
        txns.expect_commit_at(7, d0, &["7"], &log).await;

        let until = 5;
        let mut sub = DataSubscribe::new(
            "as_of_until",
            client,
            txns.txns_id(),
            d0,
            3,
            Antichain::from_elem(until),
        );
        // Manually step the dataflow, instead of going through the
        // `DataSubscribe` helper, because we're interested in all captured
        // events.
        while sub.txns.less_equal(&5) {
            sub.worker.step();
            tokio::task::yield_now().await;
            tokio::time::sleep(std::time::Duration::from_millis(100)).await;
        }
        let (actual_progresses, actual_events): (Vec<_>, Vec<_>) =
            sub.capture.into_iter().partition_map(|event| match event {
                Event::Progress(progress) => Either::Left(progress),
                Event::Messages(ts, data) => Either::Right((ts, data)),
            });
        let expected = vec![
            (3, vec![("2".to_owned(), 3, 1), ("3".to_owned(), 3, 1)]),
            (3, vec![("4".to_owned(), 4, 1)]),
        ];
        assert_eq!(actual_events, expected);

        // The number and contents of progress messages are not guaranteed and
        // depend on the downgrade behavior. The only thing we can assert is
        // that the max progress timestamp, if there is one, is less than the
        // until.
        if let Some(max_progress_ts) = actual_progresses
            .into_iter()
            .flatten()
            .map(|(ts, _diff)| ts)
            .max()
        {
            assert!(max_progress_ts < until, "{max_progress_ts} < {until}");
        }
    }
}