1use std::any::Any;
13use std::fmt::Debug;
14use std::future::Future;
15use std::sync::mpsc::TryRecvError;
16use std::sync::{Arc, mpsc};
17use std::time::Duration;
18
19use differential_dataflow::Hashable;
20use differential_dataflow::difference::Semigroup;
21use differential_dataflow::lattice::Lattice;
22use futures::StreamExt;
23use mz_dyncfg::{Config, ConfigSet, ConfigUpdates};
24use mz_ore::cast::CastFrom;
25use mz_ore::task::JoinHandleExt;
26use mz_persist_client::cfg::{RetryParameters, USE_GLOBAL_TXN_CACHE_SOURCE};
27use mz_persist_client::operators::shard_source::{FilterResult, SnapshotMode, shard_source};
28use mz_persist_client::{Diagnostics, PersistClient, ShardId};
29use mz_persist_types::codec_impls::{StringSchema, UnitSchema};
30use mz_persist_types::txn::TxnsCodec;
31use mz_persist_types::{Codec, Codec64, StepForward};
32use mz_timely_util::builder_async::{
33 AsyncInputHandle, Event as AsyncEvent, InputConnection,
34 OperatorBuilder as AsyncOperatorBuilder, PressOnDropButton,
35};
36use timely::container::CapacityContainerBuilder;
37use timely::dataflow::channels::pact::Pipeline;
38use timely::dataflow::operators::capture::Event;
39use timely::dataflow::operators::{Broadcast, Capture, Leave, Map, Probe};
40use timely::dataflow::{ProbeHandle, Scope, Stream};
41use timely::order::TotalOrder;
42use timely::progress::{Antichain, Timestamp};
43use timely::worker::Worker;
44use timely::{Data, PartialOrder, WorkerConfig};
45use tracing::debug;
46
47use crate::TxnsCodecDefault;
48use crate::txn_cache::TxnsCache;
49use crate::txn_read::{DataListenNext, DataRemapEntry, TxnsRead};
50
51pub fn txns_progress<K, V, T, D, P, C, F, G>(
94 passthrough: Stream<G, P>,
95 name: &str,
96 ctx: &TxnsContext,
97 worker_dyncfgs: &ConfigSet,
98 client_fn: impl Fn() -> F,
99 txns_id: ShardId,
100 data_id: ShardId,
101 as_of: T,
102 until: Antichain<T>,
103 data_key_schema: Arc<K::Schema>,
104 data_val_schema: Arc<V::Schema>,
105) -> (Stream<G, P>, Vec<PressOnDropButton>)
106where
107 K: Debug + Codec + Send + Sync,
108 V: Debug + Codec + Send + Sync,
109 T: Timestamp + Lattice + TotalOrder + StepForward + Codec64 + Sync,
110 D: Debug + Data + Semigroup + Ord + Codec64 + Send + Sync,
111 P: Debug + Data,
112 C: TxnsCodec + 'static,
113 F: Future<Output = PersistClient> + Send + 'static,
114 G: Scope<Timestamp = T>,
115{
116 let unique_id = (name, passthrough.scope().addr()).hashed();
117 let (remap, source_button) = if USE_GLOBAL_TXN_CACHE_SOURCE.get(worker_dyncfgs) {
118 txns_progress_source_global::<K, V, T, D, P, C, G>(
119 passthrough.scope(),
120 name,
121 ctx.clone(),
122 client_fn(),
123 txns_id,
124 data_id,
125 as_of,
126 data_key_schema,
127 data_val_schema,
128 unique_id,
129 )
130 } else {
131 txns_progress_source_local::<K, V, T, D, P, C, G>(
132 passthrough.scope(),
133 name,
134 client_fn(),
135 txns_id,
136 data_id,
137 as_of,
138 data_key_schema,
139 data_val_schema,
140 unique_id,
141 )
142 };
143 let remap = remap.broadcast();
146 let (passthrough, frontiers_button) = txns_progress_frontiers::<K, V, T, D, P, C, G>(
147 remap,
148 passthrough,
149 name,
150 data_id,
151 until,
152 unique_id,
153 );
154 (passthrough, vec![source_button, frontiers_button])
155}
156
/// Renders the remap source backed by an operator-private `TxnsCache`.
///
/// Only one worker, chosen by hashing `name` modulo the number of peers, does
/// any work. Every other worker returns immediately and so releases its
/// capability.
///
/// The chosen worker opens a `TxnsCache` for `txns_id` restricted to
/// `data_id` and emits the subscription's initial `DataRemapEntry`. It then
/// loops on `data_listen_next`:
/// - `ReadDataTo` moves both uppers forward and emits the new entry.
/// - `EmitLogicalProgress` only moves the logical upper. That entry is not
///   emitted; the capability downgrade at the top of the next iteration
///   carries the progress instead.
/// - `WaitForTxnsProgress` waits for the cache to move strictly past the
///   current logical upper.
fn txns_progress_source_local<K, V, T, D, P, C, G>(
    scope: G,
    name: &str,
    client: impl Future<Output = PersistClient> + 'static,
    txns_id: ShardId,
    data_id: ShardId,
    as_of: T,
    data_key_schema: Arc<K::Schema>,
    data_val_schema: Arc<V::Schema>,
    unique_id: u64,
) -> (Stream<G, DataRemapEntry<T>>, PressOnDropButton)
where
    K: Debug + Codec + Send + Sync,
    V: Debug + Codec + Send + Sync,
    T: Timestamp + Lattice + TotalOrder + StepForward + Codec64 + Sync,
    D: Debug + Data + Semigroup + Ord + Codec64 + Send + Sync,
    P: Debug + Data,
    C: TxnsCodec + 'static,
    G: Scope<Timestamp = T>,
{
    let worker_idx = scope.index();
    // Derived only from `name` and the peer count, so every worker agrees on
    // which worker is chosen.
    let chosen_worker = usize::cast_from(name.hashed()) % scope.peers();
    let name = format!("txns_progress_source({})", name);
    let mut builder = AsyncOperatorBuilder::new(name.clone(), scope);
    // A more detailed name, used only in debug logging below.
    let name = format!("{} [{}] {:.9}", name, unique_id, data_id.to_string());
    let (remap_output, remap_stream) = builder.new_output();

    let shutdown_button = builder.build(move |capabilities| async move {
        // Non-chosen workers finish immediately, dropping their capability.
        if worker_idx != chosen_worker {
            return;
        }

        let [mut cap]: [_; 1] = capabilities.try_into().expect("one capability per output");
        let client = client.await;
        let mut txns_cache = TxnsCache::<T, C>::open(&client, txns_id, Some(data_id)).await;

        let _ = txns_cache.update_gt(&as_of).await;
        let mut subscribe = txns_cache.data_subscribe(data_id, as_of.clone());
        let data_write = client
            .open_writer::<K, V, T, D>(
                data_id,
                Arc::clone(&data_key_schema),
                Arc::clone(&data_val_schema),
                Diagnostics::from_purpose("data read physical upper"),
            )
            .await
            .expect("schema shouldn't change");
        // If the subscription starts with a snapshot, hand it the data shard
        // writer via `unblock_read`.
        // NOTE(review): what `unblock_read` does is defined in `txn_read`;
        // confirm there.
        if let Some(snapshot) = subscribe.snapshot.take() {
            snapshot.unblock_read(data_write).await;
        }

        debug!("{} emitting {:?}", name, subscribe.remap);
        remap_output.give(&cap, subscribe.remap.clone());

        loop {
            let _ = txns_cache.update_ge(&subscribe.remap.logical_upper).await;
            // The output frontier always follows the latest logical upper.
            cap.downgrade(&subscribe.remap.logical_upper);
            let data_listen_next =
                txns_cache.data_listen_next(&subscribe.data_id, &subscribe.remap.logical_upper);
            debug!(
                "{} data_listen_next at {:?}: {:?}",
                name, subscribe.remap.logical_upper, data_listen_next,
            );
            match data_listen_next {
                DataListenNext::WaitForTxnsProgress => {
                    // Nothing new is known yet: wait for the cache to move
                    // strictly past the current logical upper, then ask again.
                    let _ = txns_cache.update_gt(&subscribe.remap.logical_upper).await;
                }
                DataListenNext::ReadDataTo(new_upper) => {
                    // Both uppers jump to `new_upper`. The entry is emitted so
                    // downstream knows how far the data must physically be read.
                    subscribe.remap = DataRemapEntry {
                        physical_upper: new_upper.clone(),
                        logical_upper: new_upper,
                    };
                    debug!("{} emitting {:?}", name, subscribe.remap);
                    remap_output.give(&cap, subscribe.remap.clone());
                }
                DataListenNext::EmitLogicalProgress(new_progress) => {
                    // Logical progress must be strictly ahead of both current
                    // uppers.
                    assert!(subscribe.remap.physical_upper < new_progress);
                    assert!(subscribe.remap.logical_upper < new_progress);

                    subscribe.remap.logical_upper = new_progress;
                    // Deliberately not emitted: the `cap.downgrade` at the top
                    // of the next iteration communicates this progress.
                    debug!("{} not emitting {:?}", name, subscribe.remap);
                }
            }
        }
    });
    (remap_stream, shutdown_button.press_on_drop())
}
260
/// Renders the remap source backed by the `TxnsRead` shared through `ctx`.
///
/// Only one worker, chosen by hashing `name` modulo the number of peers, does
/// any work. Every other worker returns immediately.
///
/// The chosen worker subscribes to `data_id` through the shared `TxnsRead`
/// and receives a sequence of `DataRemapEntry`s. For each entry it:
/// - emits the entry only when its physical upper differs from the last
///   emitted one, and
/// - always downgrades its capability to the entry's logical upper.
///
/// The operator finishes when the subscription channel closes.
fn txns_progress_source_global<K, V, T, D, P, C, G>(
    scope: G,
    name: &str,
    ctx: TxnsContext,
    client: impl Future<Output = PersistClient> + 'static,
    txns_id: ShardId,
    data_id: ShardId,
    as_of: T,
    data_key_schema: Arc<K::Schema>,
    data_val_schema: Arc<V::Schema>,
    unique_id: u64,
) -> (Stream<G, DataRemapEntry<T>>, PressOnDropButton)
where
    K: Debug + Codec + Send + Sync,
    V: Debug + Codec + Send + Sync,
    T: Timestamp + Lattice + TotalOrder + StepForward + Codec64 + Sync,
    D: Debug + Data + Semigroup + Ord + Codec64 + Send + Sync,
    P: Debug + Data,
    C: TxnsCodec + 'static,
    G: Scope<Timestamp = T>,
{
    let worker_idx = scope.index();
    // Derived only from `name` and the peer count, so every worker agrees on
    // which worker is chosen.
    let chosen_worker = usize::cast_from(name.hashed()) % scope.peers();
    let name = format!("txns_progress_source({})", name);
    let mut builder = AsyncOperatorBuilder::new(name.clone(), scope);
    // A more detailed name, used only in debug logging below.
    let name = format!("{} [{}] {:.9}", name, unique_id, data_id.to_string());
    let (remap_output, remap_stream) = builder.new_output();

    let shutdown_button = builder.build(move |capabilities| async move {
        // Non-chosen workers finish immediately, dropping their capability.
        if worker_idx != chosen_worker {
            return;
        }

        let [mut cap]: [_; 1] = capabilities.try_into().expect("one capability per output");
        let client = client.await;
        // Shared across every operator that uses the same `ctx`. Panics if it
        // was initialized for a different `txns_id`.
        let txns_read = ctx.get_or_init::<T, C>(&client, txns_id).await;

        let _ = txns_read.update_gt(as_of.clone()).await;
        let data_write = client
            .open_writer::<K, V, T, D>(
                data_id,
                Arc::clone(&data_key_schema),
                Arc::clone(&data_val_schema),
                Diagnostics::from_purpose("data read physical upper"),
            )
            .await
            .expect("schema shouldn't change");
        let mut rx = txns_read
            .data_subscribe(data_id, as_of.clone(), Box::new(data_write))
            .await;
        debug!("{} starting as_of={:?}", name, as_of);

        // The physical upper of the most recently emitted entry.
        let mut physical_upper = T::minimum();

        while let Some(remap) = rx.recv().await {
            // Entries must arrive with non-decreasing physical uppers and a
            // logical upper strictly ahead of what has been emitted.
            assert!(physical_upper <= remap.physical_upper);
            assert!(physical_upper < remap.logical_upper);

            let logical_upper = remap.logical_upper.clone();
            if remap.physical_upper != physical_upper {
                // Only a change in the physical upper is worth emitting.
                physical_upper = remap.physical_upper.clone();
                debug!("{} emitting {:?}", name, remap);
                remap_output.give(&cap, remap);
            } else {
                // Logical-only progress: the downgrade below conveys it.
                debug!("{} not emitting {:?}", name, remap);
            }
            cap.downgrade(&logical_upper);
        }
    });
    (remap_stream, shutdown_button.press_on_drop())
}
349
/// Forwards `passthrough` (its data and its progress) while using `remap`
/// entries to advance the output frontier further.
///
/// The single output capability `cap` tracks output progress. It follows the
/// passthrough frontier. Whenever `cap.time()` has reached the current remap
/// entry's `physical_upper`, the capability is also advanced to that entry's
/// `logical_upper`, and the operator then waits for newer remap information.
///
/// The operator exits when passthrough progress reaches `until`, or when the
/// passthrough input is exhausted or its frontier becomes empty.
fn txns_progress_frontiers<K, V, T, D, P, C, G>(
    remap: Stream<G, DataRemapEntry<T>>,
    passthrough: Stream<G, P>,
    name: &str,
    data_id: ShardId,
    until: Antichain<T>,
    unique_id: u64,
) -> (Stream<G, P>, PressOnDropButton)
where
    K: Debug + Codec,
    V: Debug + Codec,
    T: Timestamp + Lattice + TotalOrder + StepForward + Codec64,
    D: Data + Semigroup + Codec64 + Send + Sync,
    P: Debug + Data,
    C: TxnsCodec,
    G: Scope<Timestamp = T>,
{
    let name = format!("txns_progress_frontiers({})", name);
    let mut builder = AsyncOperatorBuilder::new(name.clone(), passthrough.scope());
    // A more detailed name (including worker index/peers), used only in debug
    // logging below.
    let name = format!(
        "{} [{}] {}/{} {:.9}",
        name,
        unique_id,
        passthrough.scope().index(),
        passthrough.scope().peers(),
        data_id.to_string(),
    );
    let (passthrough_output, passthrough_stream) =
        builder.new_output::<CapacityContainerBuilder<_>>();
    // Both inputs are disconnected from the output, so output progress is
    // driven by explicit downgrades of `cap` below.
    let mut remap_input = builder.new_disconnected_input(&remap, Pipeline);
    let mut passthrough_input = builder.new_disconnected_input(&passthrough, Pipeline);

    let shutdown_button = builder.build(move |capabilities| async move {
        let [mut cap]: [_; 1] = capabilities.try_into().expect("one capability per output");

        // The latest known remap entry. It becomes `None` once the remap input
        // is exhausted, after which only the passthrough drives progress.
        let mut remap = Some(DataRemapEntry {
            physical_upper: T::minimum(),
            logical_upper: T::minimum(),
        });
        loop {
            debug!("{} remap {:?}", name, remap);
            if let Some(r) = remap.as_ref() {
                assert!(r.physical_upper <= r.logical_upper);
                // The output has caught up to the physical upper, so it may
                // jump to the logical upper. Then wait for newer remap info
                // before doing anything else.
                if r.physical_upper.less_equal(cap.time()) {
                    if cap.time() < &r.logical_upper {
                        cap.downgrade(&r.logical_upper);
                    }
                    remap = txns_progress_frontiers_read_remap_input(
                        &name,
                        &mut remap_input,
                        r.clone(),
                    )
                    .await;
                    continue;
                }
            }

            // An exhausted passthrough input is mapped to an empty-frontier
            // progress event, which takes the shutdown path below.
            let event = passthrough_input
                .next()
                .await
                .unwrap_or_else(|| AsyncEvent::Progress(Antichain::new()));
            match event {
                // The input is disconnected, so its data capability is unused;
                // data is re-emitted at `cap`.
                AsyncEvent::Data(_data_cap, mut data) => {
                    debug!("{} emitting data {:?}", name, data);
                    passthrough_output.give_container(&cap, &mut data);
                }
                AsyncEvent::Progress(new_progress) => {
                    // Once passthrough progress reaches `until`, any further
                    // input would be at times not before `until`, so stop.
                    if PartialOrder::less_equal(&until, &new_progress) {
                        debug!(
                            "{} progress {:?} has passed until {:?}",
                            name,
                            new_progress.elements(),
                            until.elements()
                        );
                        return;
                    }
                    // An empty passthrough frontier means there is nothing
                    // left to forward.
                    let Some(new_progress) = new_progress.into_option() else {
                        return;
                    };

                    // Follow the passthrough (physical) progress; the logical
                    // upper is only applied via the remap branch above.
                    if cap.time() < &new_progress {
                        debug!("{} downgrading cap to {:?}", name, new_progress);
                        cap.downgrade(&new_progress);
                    }
                }
            }
        }
    });
    (passthrough_stream, shutdown_button.press_on_drop())
}
477
478async fn txns_progress_frontiers_read_remap_input<T, C>(
479 name: &str,
480 input: &mut AsyncInputHandle<T, Vec<DataRemapEntry<T>>, C>,
481 mut remap: DataRemapEntry<T>,
482) -> Option<DataRemapEntry<T>>
483where
484 T: Timestamp + TotalOrder,
485 C: InputConnection<T>,
486{
487 while let Some(event) = input.next().await {
488 let xs = match event {
489 AsyncEvent::Progress(logical_upper) => {
490 if let Some(logical_upper) = logical_upper.into_option() {
491 if remap.logical_upper < logical_upper {
492 remap.logical_upper = logical_upper;
493 return Some(remap);
494 }
495 }
496 continue;
497 }
498 AsyncEvent::Data(_cap, xs) => xs,
499 };
500 for x in xs {
501 debug!("{} got remap {:?}", name, x);
502 if remap.logical_upper < x.logical_upper {
504 assert!(
505 remap.physical_upper <= x.physical_upper,
506 "previous remap physical upper {:?} is ahead of new remap physical upper {:?}",
507 remap.physical_upper,
508 x.physical_upper,
509 );
510 remap = x;
521 }
522 }
523 return Some(remap);
524 }
525 None
527}
528
/// Shared, lazily-initialized state for the global txns remap source.
///
/// Holds at most one `TxnsRead`, type-erased as `Box<dyn Any + Send + Sync>`
/// so the context itself is not generic over the timestamp type. The cell is
/// behind an `Arc`, so clones of a context share the same state.
#[derive(Default, Debug, Clone)]
pub struct TxnsContext {
    read: Arc<tokio::sync::OnceCell<Box<dyn Any + Send + Sync>>>,
}
534
535impl TxnsContext {
536 async fn get_or_init<T, C>(&self, client: &PersistClient, txns_id: ShardId) -> TxnsRead<T>
537 where
538 T: Timestamp + Lattice + Codec64 + TotalOrder + StepForward + Sync,
539 C: TxnsCodec + 'static,
540 {
541 let read = self
542 .read
543 .get_or_init(|| {
544 let client = client.clone();
545 async move {
546 let read: Box<dyn Any + Send + Sync> =
547 Box::new(TxnsRead::<T>::start::<C>(client, txns_id).await);
548 read
549 }
550 })
551 .await
552 .downcast_ref::<TxnsRead<T>>()
553 .expect("timestamp types should match");
554 assert_eq!(&txns_id, read.txns_id());
556 read.clone()
557 }
558}
559
/// Initial backoff used by `txns_data_shard_retry_params` when polling a txns
/// data shard for new batches.
pub(crate) const DATA_SHARD_RETRYER_INITIAL_BACKOFF: Config<Duration> = Config::new(
    "persist_txns_data_shard_retryer_initial_backoff",
    Duration::from_millis(1024),
    "The initial backoff when polling for new batches from a txns data shard persist_source.",
);

/// Backoff multiplier used by `txns_data_shard_retry_params` when polling a
/// txns data shard for new batches.
pub(crate) const DATA_SHARD_RETRYER_MULTIPLIER: Config<u32> = Config::new(
    "persist_txns_data_shard_retryer_multiplier",
    2,
    "The backoff multiplier when polling for new batches from a txns data shard persist_source.",
);

/// Upper clamp on the backoff used by `txns_data_shard_retry_params` when
/// polling a txns data shard for new batches.
pub(crate) const DATA_SHARD_RETRYER_CLAMP: Config<Duration> = Config::new(
    "persist_txns_data_shard_retryer_clamp",
    Duration::from_secs(16),
    "The backoff clamp duration when polling for new batches from a txns data shard persist_source.",
);
580
581pub fn txns_data_shard_retry_params(cfg: &ConfigSet) -> RetryParameters {
584 RetryParameters {
585 fixed_sleep: Duration::ZERO,
586 initial_backoff: DATA_SHARD_RETRYER_INITIAL_BACKOFF.get(cfg),
587 multiplier: DATA_SHARD_RETRYER_MULTIPLIER.get(cfg),
588 clamp: DATA_SHARD_RETRYER_CLAMP.get(cfg),
589 }
590}
591
/// Reads a data shard through `shard_source` followed by `txns_progress`, in
/// a single-threaded timely worker, capturing the output for inspection.
pub struct DataSubscribe {
    /// The `as_of` the dataflow was rendered with.
    pub(crate) as_of: u64,
    /// The timely worker running the dataflow; it only advances when stepped.
    pub(crate) worker: Worker<timely::communication::allocator::Thread>,
    /// Probe on the decoded `shard_source` output (before `txns_progress`).
    data: ProbeHandle<u64>,
    /// Probe on the `txns_progress` output.
    txns: ProbeHandle<u64>,
    /// Receives captured events from the `txns_progress` output.
    capture: mpsc::Receiver<Event<u64, Vec<(String, u64, i64)>>>,
    /// Data drained from `capture` so far.
    output: Vec<(String, u64, i64)>,

    /// Buttons returned by the rendered operators, held for the lifetime of
    /// this struct.
    _tokens: Vec<PressOnDropButton>,
}
609
610impl std::fmt::Debug for DataSubscribe {
611 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
612 let DataSubscribe {
613 as_of,
614 worker: _,
615 data,
616 txns,
617 capture: _,
618 output,
619 _tokens: _,
620 } = self;
621 f.debug_struct("DataSubscribe")
622 .field("as_of", as_of)
623 .field("data", data)
624 .field("txns", txns)
625 .field("output", output)
626 .finish_non_exhaustive()
627 }
628}
629
impl DataSubscribe {
    /// Renders the read dataflow on a fresh single-threaded timely worker.
    ///
    /// The dataflow reads `data_id` with `shard_source` (from `as_of`,
    /// bounded by `until`), decodes it, and passes it through `txns_progress`
    /// for `txns_id`. `use_global_txn_cache` is written into the
    /// `USE_GLOBAL_TXN_CACHE_SOURCE` dyncfg that `txns_progress` consults.
    /// Nothing runs until the worker is stepped.
    pub fn new(
        name: &str,
        client: PersistClient,
        txns_id: ShardId,
        data_id: ShardId,
        as_of: u64,
        until: Antichain<u64>,
        use_global_txn_cache: bool,
    ) -> Self {
        let mut worker = Worker::new(
            WorkerConfig::default(),
            timely::communication::allocator::Thread::default(),
        );
        let (data, txns, capture, tokens) = worker.dataflow::<u64, _, _>(|scope| {
            // `shard_source` is rendered in a nested "hybrid" scope and its
            // output `leave`s back to the outer scope.
            let (data_stream, shard_source_token) = scope.scoped::<u64, _, _>("hybrid", |scope| {
                let client = client.clone();
                let (data_stream, token) = shard_source::<String, (), u64, i64, _, _, _>(
                    scope,
                    name,
                    move || std::future::ready(client.clone()),
                    data_id,
                    Some(Antichain::from_elem(as_of)),
                    SnapshotMode::Include,
                    until.clone(),
                    // `false.then_some(..)` is always `None`; the closure only
                    // pins down the parameter's type.
                    false.then_some(|_, _: &_, _| unreachable!()),
                    Arc::new(StringSchema),
                    Arc::new(UnitSchema),
                    FilterResult::keep_all,
                    // Likewise always `None`.
                    false.then_some(|| unreachable!()),
                    async {},
                    |error| panic!("data_subscribe: {error}"),
                );
                (data_stream.leave(), token)
            });
            let (data, txns) = (ProbeHandle::new(), ProbeHandle::new());
            // Parse each part into `(key, time, diff)` updates, unwrapping the
            // decoded key and (unit) value.
            let data_stream = data_stream.flat_map(|part| {
                let part = part.parse();
                part.part.map(|((k, v), t, d)| {
                    let (k, ()) = (k.unwrap(), v.unwrap());
                    (k, t, d)
                })
            });
            let data_stream = data_stream.probe_with(&data);
            // A local config set with only `USE_GLOBAL_TXN_CACHE_SOURCE`
            // registered, set from the argument.
            let config_set = ConfigSet::default().add(&USE_GLOBAL_TXN_CACHE_SOURCE);
            let mut updates = ConfigUpdates::default();
            updates.add(&USE_GLOBAL_TXN_CACHE_SOURCE, use_global_txn_cache);
            updates.apply(&config_set);
            let (data_stream, mut txns_progress_token) =
                txns_progress::<String, (), u64, i64, _, TxnsCodecDefault, _, _>(
                    data_stream,
                    name,
                    // A fresh context per `DataSubscribe`, so no `TxnsRead` is
                    // shared between instances.
                    &TxnsContext::default(),
                    &config_set,
                    || std::future::ready(client.clone()),
                    txns_id,
                    data_id,
                    as_of,
                    until,
                    Arc::new(StringSchema),
                    Arc::new(UnitSchema),
                );
            let data_stream = data_stream.probe_with(&txns);
            let mut tokens = shard_source_token;
            tokens.append(&mut txns_progress_token);
            (data, txns, data_stream.capture(), tokens)
        });
        Self {
            as_of,
            worker,
            data,
            txns,
            capture,
            output: Vec::new(),
            _tokens: tokens,
        }
    }

    /// Returns the `txns_progress` output frontier as a single timestamp,
    /// using `u64::MAX` to stand in for the empty frontier.
    pub fn progress(&self) -> u64 {
        self.txns
            .with_frontier(|f| *f.as_option().unwrap_or(&u64::MAX))
    }

    /// Steps the timely worker once, then drains captured output into
    /// `output`.
    pub fn step(&mut self) {
        self.worker.step();
        self.capture_output()
    }

    /// Moves all currently available captured data into `output`.
    ///
    /// Progress events are ignored. Returns once the capture channel is empty
    /// or disconnected; it never blocks.
    pub(crate) fn capture_output(&mut self) {
        loop {
            let event = match self.capture.try_recv() {
                Ok(x) => x,
                Err(TryRecvError::Empty) | Err(TryRecvError::Disconnected) => break,
            };
            match event {
                Event::Progress(_) => {}
                Event::Messages(_, mut msgs) => self.output.append(&mut msgs),
            }
        }
    }

    /// Steps until the `txns_progress` output frontier is strictly past `ts`.
    /// Yields to the async runtime between steps.
    #[cfg(test)]
    pub async fn step_past(&mut self, ts: u64) {
        while self.txns.less_equal(&ts) {
            tracing::trace!(
                "progress at {:?}",
                self.txns.with_frontier(|x| x.to_owned()).elements()
            );
            self.step();
            tokio::task::yield_now().await;
        }
    }

    /// Returns all data drained from the capture so far.
    pub fn output(&self) -> &Vec<(String, u64, i64)> {
        &self.output
    }
}
754
/// Handle to a `DataSubscribe` that is driven on a blocking task and
/// controlled through requests sent over `tx`.
#[derive(Debug)]
pub struct DataSubscribeTask {
    /// Requests to the task. Each request carries an optional timestamp to
    /// step past (`None` asks for one step) and a oneshot on which the task
    /// replies with its newly captured output and current progress.
    tx: std::sync::mpsc::Sender<(
        Option<u64>,
        tokio::sync::oneshot::Sender<(Vec<(String, u64, i64)>, u64)>,
    )>,
    /// The blocking task. Once `tx` is dropped it returns all output it
    /// captured.
    task: mz_ore::task::JoinHandle<Vec<(String, u64, i64)>>,
    /// All output received from the task so far.
    output: Vec<(String, u64, i64)>,
    /// The latest progress reported by the task.
    progress: u64,
}
768
769impl DataSubscribeTask {
770 pub async fn new(
772 client: PersistClient,
773 txns_id: ShardId,
774 data_id: ShardId,
775 as_of: u64,
776 ) -> Self {
777 let cache = TxnsCache::open(&client, txns_id, Some(data_id)).await;
778 let (tx, rx) = std::sync::mpsc::channel();
779 let task = mz_ore::task::spawn_blocking(
780 || "data_subscribe task",
781 move || Self::task(client, cache, data_id, as_of, rx),
782 );
783 DataSubscribeTask {
784 tx,
785 task,
786 output: Vec::new(),
787 progress: 0,
788 }
789 }
790
791 #[cfg(test)]
792 async fn step(&mut self) {
793 self.send(None).await;
794 }
795
796 pub async fn step_past(&mut self, ts: u64) -> u64 {
798 self.send(Some(ts)).await;
799 self.progress
800 }
801
802 pub fn output(&self) -> &Vec<(String, u64, i64)> {
804 &self.output
805 }
806
807 async fn send(&mut self, ts: Option<u64>) {
808 let (tx, rx) = tokio::sync::oneshot::channel();
809 self.tx.send((ts, tx)).expect("task should be running");
810 let (mut new_output, new_progress) = rx.await.expect("task should be running");
811 self.output.append(&mut new_output);
812 assert!(self.progress <= new_progress);
813 self.progress = new_progress;
814 }
815
816 pub async fn finish(self) -> Vec<(String, u64, i64)> {
821 drop(self.tx);
823 self.task.wait_and_assert_finished().await
824 }
825
826 fn task(
827 client: PersistClient,
828 cache: TxnsCache<u64>,
829 data_id: ShardId,
830 as_of: u64,
831 rx: std::sync::mpsc::Receiver<(
832 Option<u64>,
833 tokio::sync::oneshot::Sender<(Vec<(String, u64, i64)>, u64)>,
834 )>,
835 ) -> Vec<(String, u64, i64)> {
836 let mut subscribe = DataSubscribe::new(
837 "DataSubscribeTask",
838 client.clone(),
839 cache.txns_id(),
840 data_id,
841 as_of,
842 Antichain::new(),
843 true,
844 );
845 let mut output = Vec::new();
846 loop {
847 let (ts, tx) = match rx.try_recv() {
848 Ok(x) => x,
849 Err(TryRecvError::Empty) => {
850 subscribe.step();
852 continue;
853 }
854 Err(TryRecvError::Disconnected) => {
855 return output;
857 }
858 };
859 subscribe.step();
861 if let Some(ts) = ts {
863 while subscribe.progress() <= ts {
864 subscribe.step();
865 }
866 }
867 let new_output = std::mem::take(&mut subscribe.output);
868 output.extend(new_output.iter().cloned());
869 let _ = tx.send((new_output, subscribe.progress()));
870 }
871 }
872}
873
#[cfg(test)]
mod tests {
    use itertools::{Either, Itertools};
    use mz_persist_types::Opaque;

    use crate::tests::writer;
    use crate::txns::TxnsHandle;

    use super::*;

    impl<K, V, T, D, O, C> TxnsHandle<K, V, T, D, O, C>
    where
        K: Debug + Codec,
        V: Debug + Codec,
        T: Timestamp + Lattice + TotalOrder + StepForward + Codec64 + Sync,
        D: Debug + Semigroup + Ord + Codec64 + Send + Sync,
        O: Opaque + Debug + Codec64,
        C: TxnsCodec,
    {
        /// Test helper: spawns a `DataSubscribeTask` for `data_id` on this
        /// handle's txns shard.
        async fn subscribe_task(
            &self,
            client: &PersistClient,
            data_id: ShardId,
            as_of: u64,
        ) -> DataSubscribeTask {
            DataSubscribeTask::new(client.clone(), self.txns_id(), data_id, as_of).await
        }
    }

    // Creates a subscribe of d0 (as_of 5) after each step of a sequence of
    // txns operations. Then checks that every subscribe, no matter when it was
    // created, reaches the same progress and output.
    #[mz_ore::test(tokio::test(flavor = "multi_thread"))]
    #[cfg_attr(miri, ignore)]
    async fn data_subscribe() {
        // Asks every subscribe task to step.
        async fn step(subs: &mut Vec<DataSubscribeTask>) {
            for sub in subs.iter_mut() {
                sub.step().await;
            }
        }

        let client = PersistClient::new_for_tests().await;
        let mut txns = TxnsHandle::expect_open(client.clone()).await;
        let log = txns.new_log();
        let d0 = ShardId::new();

        // Subscribe before d0 is registered.
        let mut subs = Vec::new();
        subs.push(txns.subscribe_task(&client, d0, 5).await);
        step(&mut subs).await;

        // Register d0 at 1.
        txns.register(1, [writer(&client, d0).await]).await.unwrap();
        subs.push(txns.subscribe_task(&client, d0, 5).await);
        step(&mut subs).await;

        // Register a second shard d1 at 2 and commit to it (not d0) at 3.
        let d1 = txns.expect_register(2).await;
        txns.expect_commit_at(3, d1, &["nope"], &log).await;
        subs.push(txns.subscribe_task(&client, d0, 5).await);
        step(&mut subs).await;

        // Commit to d0 at 4, before the as_of.
        txns.expect_commit_at(4, d0, &["4"], &log).await;
        subs.push(txns.subscribe_task(&client, d0, 5).await);
        step(&mut subs).await;

        // Commit to d0 at 5, the as_of.
        txns.expect_commit_at(5, d0, &["5"], &log).await;
        subs.push(txns.subscribe_task(&client, d0, 5).await);
        step(&mut subs).await;

        // Commit to d0 at 6, after the as_of.
        txns.expect_commit_at(6, d0, &["6"], &log).await;
        subs.push(txns.subscribe_task(&client, d0, 5).await);
        step(&mut subs).await;

        // Commit to d1 (not d0) at 7.
        txns.expect_commit_at(7, d1, &["nope"], &log).await;
        subs.push(txns.subscribe_task(&client, d0, 5).await);
        step(&mut subs).await;

        // Every subscribe should report progress 8 after stepping past 7, and
        // its output should match the log for d0 over [5, 8).
        for mut sub in subs {
            let progress = sub.step_past(7).await;
            assert_eq!(progress, 8);
            log.assert_eq(d0, 5, 8, sub.finish().await);
        }
    }

    // Subscribes to d0, forgets it from txns, writes to it directly, and then
    // closes it. Checks that the subscribe keeps following d0 until its
    // frontier no longer trails u64::MAX.
    #[mz_ore::test(tokio::test(flavor = "multi_thread"))]
    #[cfg_attr(miri, ignore)]
    async fn subscribe_shard_finalize() {
        let client = PersistClient::new_for_tests().await;
        let mut txns = TxnsHandle::expect_open(client.clone()).await;
        let log = txns.new_log();
        let d0 = txns.expect_register(1).await;

        // Start a subscribe of d0 as of 1.
        let mut sub = txns.read_cache().expect_subscribe(&client, d0, 1);
        sub.step_past(1).await;

        // Write to d0 through txns at 2.
        txns.expect_commit_at(2, d0, &["foo"], &log).await;
        sub.step_past(2).await;

        // Forget d0 at 3.
        txns.forget(3, [d0]).await.unwrap();
        sub.step_past(3).await;

        // Commit an empty txn at 7.
        txns.begin().commit_at(&mut txns, 7).await.unwrap();

        // Write "bar" at time 5 directly to d0, outside of txns, and record it
        // in the log.
        let mut d0_write = writer(&client, d0).await;
        let key = "bar".to_owned();
        crate::small_caa(|| "test", &mut d0_write, &[((&key, &()), &5, 1)], 4, 6)
            .await
            .unwrap();
        log.record((d0, key, 5, 1));
        sub.step_past(4).await;

        // Advance d0's upper from 6 to the empty antichain, then step until the
        // subscribe's frontier is no longer below u64::MAX.
        let () = d0_write
            .compare_and_append_batch(&mut [], Antichain::from_elem(6), Antichain::new())
            .await
            .unwrap()
            .unwrap();
        while sub.txns.less_than(&u64::MAX) {
            sub.step();
            tokio::task::yield_now().await;
        }

        log.assert_eq(d0, 1, u64::MAX, sub.output().clone());

        log.assert_subscribe(d0, 4, u64::MAX).await;
        log.assert_subscribe(d0, 6, u64::MAX).await;
    }

    // Checks that a subscribe of a not-yet-registered shard advances past
    // both its registration and its forget.
    #[mz_ore::test(tokio::test(flavor = "multi_thread"))]
    #[cfg_attr(miri, ignore)]
    async fn subscribe_shard_register_forget() {
        let client = PersistClient::new_for_tests().await;
        let mut txns = TxnsHandle::expect_open(client.clone()).await;
        let d0 = ShardId::new();

        // Subscribe as of 0, before d0 is registered.
        let mut sub = txns.read_cache().expect_subscribe(&client, d0, 0);
        assert_eq!(sub.progress(), 0);

        // Register d0 at 10.
        txns.register(10, [writer(&client, d0).await])
            .await
            .unwrap();
        sub.step_past(10).await;
        assert!(
            sub.progress() > 10,
            "operator should advance past 10 when shard is registered"
        );

        // Forget d0 at 20.
        txns.forget(20, [d0]).await.unwrap();
        sub.step_past(20).await;
        assert!(
            sub.progress() > 20,
            "operator should advance past 20 when shard is forgotten"
        );
    }

    // Checks that `as_of` and `until` are respected both in the emitted data
    // and in the reported progress.
    #[mz_ore::test(tokio::test)]
    #[cfg_attr(miri, ignore)]
    async fn as_of_until() {
        let client = PersistClient::new_for_tests().await;
        let mut txns = TxnsHandle::expect_open(client.clone()).await;
        let log = txns.new_log();

        let d0 = txns.expect_register(1).await;
        txns.expect_commit_at(2, d0, &["2"], &log).await;
        txns.expect_commit_at(3, d0, &["3"], &log).await;
        txns.expect_commit_at(4, d0, &["4"], &log).await;
        txns.expect_commit_at(5, d0, &["5"], &log).await;
        txns.expect_commit_at(6, d0, &["6"], &log).await;
        txns.expect_commit_at(7, d0, &["7"], &log).await;

        let until = 5;
        let mut sub = DataSubscribe::new(
            "as_of_until",
            client,
            txns.txns_id(),
            d0,
            3,
            Antichain::from_elem(until),
            true,
        );
        // Step the worker directly rather than through `step`. `step` would
        // drain `capture` into `output`, and the raw events are inspected below.
        while sub.txns.less_equal(&5) {
            sub.worker.step();
            tokio::task::yield_now().await;
            tokio::time::sleep(std::time::Duration::from_millis(100)).await;
        }
        let (actual_progresses, actual_events): (Vec<_>, Vec<_>) =
            sub.capture.into_iter().partition_map(|event| match event {
                Event::Progress(progress) => Either::Left(progress),
                Event::Messages(ts, data) => Either::Right((ts, data)),
            });
        // Writes at or before the as_of (3) show up at time 3. Nothing at or
        // past `until` (5) shows up.
        let expected = vec![
            (3, vec![("2".to_owned(), 3, 1), ("3".to_owned(), 3, 1)]),
            (3, vec![("4".to_owned(), 4, 1)]),
        ];
        assert_eq!(actual_events, expected);

        // No progress update may mention a time at or beyond `until`.
        if let Some(max_progress_ts) = actual_progresses
            .into_iter()
            .flatten()
            .map(|(ts, _diff)| ts)
            .max()
        {
            assert!(max_progress_ts < until, "{max_progress_ts} < {until}");
        }
    }
}