use std::cell::RefCell;
use std::collections::BTreeMap;
use std::collections::hash_map::DefaultHasher;
use std::convert::Infallible;
use std::fmt::Debug;
use std::future::{self, Future};
use std::hash::{Hash, Hasher};
use std::pin::{Pin, pin};
use std::rc::Rc;
use std::sync::Arc;
use std::time::Instant;

use arrow::array::ArrayRef;
use differential_dataflow::Hashable;
use differential_dataflow::difference::Semigroup;
use differential_dataflow::lattice::Lattice;
use futures_util::StreamExt;
use mz_ore::cast::CastFrom;
use mz_ore::collections::CollectionExt;
use mz_persist_types::stats::PartStats;
use mz_persist_types::{Codec, Codec64};
use mz_timely_util::builder_async::{
    Event, OperatorBuilder as AsyncOperatorBuilder, PressOnDropButton,
};
use timely::PartialOrder;
use timely::container::CapacityContainerBuilder;
use timely::dataflow::channels::pact::{Exchange, Pipeline};
use timely::dataflow::operators::{CapabilitySet, ConnectLoop, Enter, Feedback, Leave};
use timely::dataflow::scopes::Child;
use timely::dataflow::{Scope, Stream};
use timely::order::TotalOrder;
use timely::progress::frontier::AntichainRef;
use timely::progress::{Antichain, Timestamp, timestamp::Refines};
use tracing::{debug, trace};

use crate::batch::BLOB_TARGET_SIZE;
use crate::cfg::{RetryParameters, USE_CRITICAL_SINCE_SOURCE};
use crate::fetch::{FetchedBlob, Lease, SerdeLeasedBatchPart};
use crate::internal::state::BatchPart;
use crate::stats::{STATS_AUDIT_PERCENT, STATS_FILTER_ENABLED};
use crate::{Diagnostics, PersistClient, ShardId};

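/// The result of applying a stats-based filter to a batch part: fetch it,
/// skip it, or substitute constant data for its contents.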
#[derive(Debug, Clone, PartialEq, Default)]
pub enum FilterResult {
    /// Fetch the part as-is.
    #[default]
    Keep,
    /// Skip the part entirely; its stats show it contains no needed data.
    Discard,
    /// The part's timestamps and diffs are still needed, but its data may be
    /// replaced with the given single-value key and value arrays.
    ReplaceWith {
        /// A single-element array with the replacement key.
        key: ArrayRef,
        /// A single-element array with the replacement value.
        val: ArrayRef,
    },
}

impl FilterResult {
    /// A filter that keeps every part, regardless of its stats.
    pub fn keep_all<T>(_stats: &PartStats, _frontier: AntichainRef<T>) -> FilterResult {
        Self::Keep
    }
}

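/// Creates a new source that reads from a persist shard, distributing the
/// work of reading data to all timely workers.
///
/// Internally this wires up two operators: `shard_source_descs`, which runs
/// on a single chosen worker and turns the shard's snapshot and listen into a
/// stream of part descriptions, and `shard_source_fetch`, which runs on all
/// workers and fetches the data for the parts it is assigned. A feedback edge
/// from the fetch operator back to the descs operator reports completed
/// fetches so that part leases can be released.
///
/// A minimal usage sketch, adapted from the tests at the bottom of this file
/// (not compiled here; `persist_client`, `shard_id`, `transformer`, and the
/// schemas are assumed to be in scope):
///
/// ```ignore
/// let (stream, tokens) = shard_source::<String, String, u64, u64, _, _, _>(
///     scope,
///     "example",
///     move || std::future::ready(persist_client.clone()),
///     shard_id,
///     None, // as_of: default to the shard's since
///     SnapshotMode::Include,
///     Antichain::new(), // until: read forever
///     Some(transformer),
///     key_schema,
///     val_schema,
///     FilterResult::keep_all,
///     false.then_some(|| unreachable!()), // no listen-retry override
///     async {}, // start immediately
///     |error| panic!("{error}"),
/// );
/// ```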
pub fn shard_source<'g, K, V, T, D, DT, G, C>(
    scope: &mut Child<'g, G, T>,
    name: &str,
    client: impl Fn() -> C,
    shard_id: ShardId,
    as_of: Option<Antichain<G::Timestamp>>,
    snapshot_mode: SnapshotMode,
    until: Antichain<G::Timestamp>,
    desc_transformer: Option<DT>,
    key_schema: Arc<K::Schema>,
    val_schema: Arc<V::Schema>,
    filter_fn: impl FnMut(&PartStats, AntichainRef<G::Timestamp>) -> FilterResult + 'static,
    // If `Some`, overrides the default retry parameters used while listening.
    listen_sleep: Option<impl Fn() -> RetryParameters + 'static>,
    start_signal: impl Future<Output = ()> + 'static,
    error_handler: impl FnOnce(String) -> Pin<Box<dyn Future<Output = ()>>> + 'static,
) -> (
    Stream<Child<'g, G, T>, FetchedBlob<K, V, G::Timestamp, D>>,
    Vec<PressOnDropButton>,
)
where
    K: Debug + Codec,
    V: Debug + Codec,
    D: Semigroup + Codec64 + Send + Sync,
    G: Scope,
    G::Timestamp: Timestamp + Lattice + Codec64 + TotalOrder + Sync,
    T: Refines<G::Timestamp>,
    DT: FnOnce(
        Child<'g, G, T>,
        &Stream<Child<'g, G, T>, (usize, SerdeLeasedBatchPart)>,
        usize,
    ) -> (
        Stream<Child<'g, G, T>, (usize, SerdeLeasedBatchPart)>,
        Vec<PressOnDropButton>,
    ),
    C: Future<Output = PersistClient> + Send + 'static,
{
    // Choose a worker (deterministically, from the source name) to produce
    // the part descriptions for all other workers.
    let chosen_worker = usize::cast_from(name.hashed()) % scope.peers();

    // Tokens for all operators; the source stays alive until every one of
    // them is dropped.
    let mut tokens = vec![];

    // A feedback edge that tells `shard_source_descs` which fetches have
    // completed, so it knows when it is safe to release part leases.
    let (completed_fetches_feedback_handle, completed_fetches_feedback_stream) =
        scope.feedback(T::Summary::default());

    // A non-empty `until` indicates a transient dataflow that will only read
    // a bounded amount of data.
    let is_transient = !until.is_empty();

    let (descs, descs_token) = shard_source_descs::<K, V, D, G>(
        &scope.parent,
        name,
        client(),
        shard_id.clone(),
        as_of,
        snapshot_mode,
        until,
        completed_fetches_feedback_stream.leave(),
        chosen_worker,
        Arc::clone(&key_schema),
        Arc::clone(&val_schema),
        filter_fn,
        listen_sleep,
        start_signal,
        error_handler,
    );
    tokens.push(descs_token);

    // Optionally let the caller rewrite the stream of part descriptions
    // before the parts are fetched.
    let descs = descs.enter(scope);
    let descs = match desc_transformer {
        Some(desc_transformer) => {
            let (descs, extra_tokens) = desc_transformer(scope.clone(), &descs, chosen_worker);
            tokens.extend(extra_tokens);
            descs
        }
        None => descs,
    };

    let (parts, completed_fetches_stream, fetch_token) = shard_source_fetch(
        &descs,
        name,
        client(),
        shard_id,
        key_schema,
        val_schema,
        is_transient,
    );
    completed_fetches_stream.connect_loop(completed_fetches_feedback_handle);
    tokens.push(fetch_token);

    (parts, tokens)
}

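/// Whether the `shard_source` should emit a snapshot of the shard's contents
/// at the `as_of` frontier, or only the updates that follow it.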
#[derive(Debug, Clone, Copy)]
pub enum SnapshotMode {
    /// Emit the snapshot at the `as_of`, followed by later updates.
    Include,
    /// Skip the snapshot and emit only subsequent updates.
    Exclude,
}

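/// Tracks the leases held for in-flight batch parts, keyed by the time at
/// which each part was emitted, so that leases can be dropped once the
/// completed-fetches frontier passes that time.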
#[derive(Debug)]
struct LeaseManager<T> {
    leases: BTreeMap<T, Vec<Lease>>,
}

impl<T: Timestamp + Codec64> LeaseManager<T> {
    fn new() -> Self {
        Self {
            leases: BTreeMap::new(),
        }
    }

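    /// Track a lease to be dropped once the frontier passes `time`.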
    fn push_at(&mut self, time: T, lease: Lease) {
        self.leases.entry(time).or_default().push(lease);
    }

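    /// Drop the leases for every time not greater or equal to `frontier`.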
    fn advance_to(&mut self, frontier: AntichainRef<T>)
    where
        // With a total order, the first entry in the map is always the
        // earliest outstanding time, so leases can be dropped in order.
        T: TotalOrder,
    {
        while let Some(first) = self.leases.first_entry() {
            if frontier.less_equal(first.key()) {
                break;
            }
            drop(first.remove());
        }
    }
}

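/// Creates a stream of part descriptions for the given shard.
///
/// Runs entirely on a single chosen worker: it opens a leased reader, emits
/// the snapshot parts (if requested) followed by parts from the ongoing
/// listen, and assigns each part to a worker. The lease for each part is held
/// until the `completed_fetches_stream` frontier shows its fetch completed.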
pub(crate) fn shard_source_descs<K, V, D, G>(
    scope: &G,
    name: &str,
    client: impl Future<Output = PersistClient> + Send + 'static,
    shard_id: ShardId,
    as_of: Option<Antichain<G::Timestamp>>,
    snapshot_mode: SnapshotMode,
    until: Antichain<G::Timestamp>,
    completed_fetches_stream: Stream<G, Infallible>,
    chosen_worker: usize,
    key_schema: Arc<K::Schema>,
    val_schema: Arc<V::Schema>,
    mut filter_fn: impl FnMut(&PartStats, AntichainRef<G::Timestamp>) -> FilterResult + 'static,
    // If `Some`, overrides the default retry parameters used while listening.
    listen_sleep: Option<impl Fn() -> RetryParameters + 'static>,
    start_signal: impl Future<Output = ()> + 'static,
    error_handler: impl FnOnce(String) -> Pin<Box<dyn Future<Output = ()>>> + 'static,
) -> (Stream<G, (usize, SerdeLeasedBatchPart)>, PressOnDropButton)
where
    K: Debug + Codec,
    V: Debug + Codec,
    D: Semigroup + Codec64 + Send + Sync,
    G: Scope,
    G::Timestamp: Timestamp + Lattice + Codec64 + TotalOrder + Sync,
{
    let worker_index = scope.index();
    let num_workers = scope.peers();

    let name_owned = name.to_owned();

    // A shared slot for the listen handle, so that it is dropped only once
    // the lease-returning operator below has also shut down.
    let listen_handle = Rc::new(RefCell::new(None));
    let return_listen_handle = Rc::clone(&listen_handle);

    // A companion operator that watches the completed-fetches feedback stream
    // and releases part leases as the fetch frontier advances.
    let (tx, rx) = tokio::sync::oneshot::channel::<Rc<RefCell<LeaseManager<G::Timestamp>>>>();
    let mut builder = AsyncOperatorBuilder::new(
        format!("shard_source_descs_return({})", name),
        scope.clone(),
    );
    let mut completed_fetches = builder.new_disconnected_input(&completed_fetches_stream, Pipeline);
    builder.build(move |_caps| async move {
        let Ok(leases) = rx.await else {
            // The descs operator exited before handing us the lease manager;
            // there is nothing to return.
            return;
        };
        while let Some(event) = completed_fetches.next().await {
            let Event::Progress(frontier) = event else {
                continue;
            };
            leases.borrow_mut().advance_to(frontier.borrow());
        }
        // Make it explicit that the listen is dropped here, only after all
        // completed fetches have been accounted for.
        drop(return_listen_handle);
    });

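    // Wraps the caller-provided error handler so that reporting an error also
    // parks the operator forever instead of returning control to it.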
    struct ErrorHandler<H: FnOnce(String) -> Pin<Box<dyn Future<Output = ()>>> + 'static> {
        inner: H,
    }
    impl<H: FnOnce(String) -> Pin<Box<dyn Future<Output = ()>>> + 'static> ErrorHandler<H> {
        async fn report_and_stop(self, error: String) -> ! {
            // Report the error to the caller-provided handler.
            (self.inner)(error).await;

            // The handler is expected to tear down the dataflow, so never
            // produce any more output; just wait to be dropped.
            future::pending().await
        }
    }
    let error_handler = ErrorHandler {
        inner: error_handler,
    };

    let mut builder =
        AsyncOperatorBuilder::new(format!("shard_source_descs({})", name), scope.clone());
    let (descs_output, descs_stream) = builder.new_output();

    #[allow(clippy::await_holding_refcell_ref)]
    let shutdown_button = builder.build(move |caps| async move {
        let mut cap_set = CapabilitySet::from_elem(caps.into_element());

        // Only the chosen worker produces part descriptions.
        if worker_index != chosen_worker {
            trace!(
                "We are not the chosen worker ({}), exiting...",
                chosen_worker
            );
            return;
        }

        // Open a leased reader for the shard; its lease holds back the
        // shard's since frontier while this operator runs.
        let mut read = mz_ore::task::spawn(|| format!("shard_source_reader({})", name_owned), {
            let diagnostics = Diagnostics {
                handle_purpose: format!("shard_source({})", name_owned),
                shard_name: name_owned.clone(),
            };
            async move {
                let client = client.await;
                client
                    .open_leased_reader::<K, V, G::Timestamp, D>(
                        shard_id,
                        key_schema,
                        val_schema,
                        diagnostics,
                        USE_CRITICAL_SINCE_SOURCE.get(client.dyncfgs()),
                    )
                    .await
            }
        })
        .await
        .expect("reader creation shouldn't panic")
        .expect("could not open persist shard");

        // Wait for the start signal only after the reader is open, so that
        // its since hold already protects the data we are going to read.
        let () = start_signal.await;

        let cfg = read.cfg.clone();
        let metrics = Arc::clone(&read.metrics);

        // If no explicit as_of was provided, read from the shard's since.
        let as_of = as_of.unwrap_or_else(|| read.since().clone());

        // Eagerly downgrade our capability to the as_of, so that downstream
        // operators can learn the source's frontier before the first batch of
        // data (which may take a while to fetch) is emitted.
        cap_set.downgrade(as_of.clone());

        let mut snapshot_parts = match snapshot_mode {
            SnapshotMode::Include => match read.snapshot(as_of.clone()).await {
                Ok(parts) => parts,
                Err(e) => error_handler
                    .report_and_stop(format!(
                        "{name_owned}: {shard_id} cannot serve requested as_of {as_of:?}: {e:?}"
                    ))
                    .await,
            },
            SnapshotMode::Exclude => vec![],
        };

        // Hand the lease manager to the companion operator above, which
        // returns leases as the completed-fetches frontier advances.
        let leases = Rc::new(RefCell::new(LeaseManager::new()));
        tx.send(Rc::clone(&leases))
            .expect("lease returner exited before desc producer");

        // Store the listen in the shared handle, so the operator above
        // controls when it is dropped.
        let mut listen = listen_handle.borrow_mut();
        let listen = match read.listen(as_of.clone()).await {
            Ok(handle) => listen.insert(handle),
            Err(e) => {
                error_handler
                    .report_and_stop(format!(
                        "{name_owned}: {shard_id} cannot serve requested as_of {as_of:?}: {e:?}"
                    ))
                    .await
            }
        };

        let listen_retry = listen_sleep.as_ref().map(|retry| retry());

        // If there is a snapshot, fold it into the first listen step so both
        // are emitted under the same initial capability.
        let listen_head = if !snapshot_parts.is_empty() {
            let (mut parts, progress) = listen.next(listen_retry).await;
            snapshot_parts.append(&mut parts);
            futures::stream::iter(Some((snapshot_parts, progress)))
        } else {
            futures::stream::iter(None)
        };

        // The rest of the stream is driven by repeatedly polling the listen.
        let listen_tail = futures::stream::unfold(listen, |listen| async move {
            Some((listen.next(listen_retry).await, listen))
        });

        let mut shard_stream = pin!(listen_head.chain(listen_tail));

        // A budget (in bytes) for re-fetching filtered parts to audit the
        // stats-based filter, seeded with twice the target blob size and
        // topped up by every part we keep.
        let mut audit_budget_bytes = u64::cast_from(BLOB_TARGET_SIZE.get(&cfg).saturating_mul(2));

        let mut current_frontier = as_of.clone();

        // Read from the subscription until the frontier passes `until`, at
        // which point all further updates would be suppressed anyway.
        while !PartialOrder::less_equal(&until, &current_frontier) {
            let (parts, progress) = shard_stream.next().await.expect("infinite stream");

            // Emit the parts at the time of the current frontier, which is
            // always less or equal to any time contained in the parts.
            let current_ts = current_frontier
                .as_option()
                .expect("until should always be <= the empty frontier");
            let session_cap = cap_set.delayed(current_ts);

            for mut part_desc in parts {
                if STATS_FILTER_ENABLED.get(&cfg) {
                    let filter_result = match &part_desc.part {
                        BatchPart::Hollow(x) => {
                            x.stats.as_ref().map_or(FilterResult::Keep, |stats| {
                                filter_fn(&stats.decode(), current_frontier.borrow())
                            })
                        }
                        BatchPart::Inline { .. } => FilterResult::Keep,
                    };
                    let bytes = u64::cast_from(part_desc.encoded_size_bytes());
                    match filter_result {
                        FilterResult::Keep => {
                            audit_budget_bytes = audit_budget_bytes.saturating_add(bytes);
                        }
                        FilterResult::Discard => {
                            metrics.pushdown.parts_filtered_count.inc();
                            metrics.pushdown.parts_filtered_bytes.inc_by(bytes);
                            // Fetch a deterministic sample of the filtered
                            // parts anyway, to audit the filter decision.
                            let should_audit = match &part_desc.part {
                                BatchPart::Hollow(x) => {
                                    let mut h = DefaultHasher::new();
                                    x.key.hash(&mut h);
                                    usize::cast_from(h.finish()) % 100
                                        < STATS_AUDIT_PERCENT.get(&cfg)
                                }
                                BatchPart::Inline { .. } => false,
                            };
                            if should_audit && bytes < audit_budget_bytes {
                                audit_budget_bytes -= bytes;
                                metrics.pushdown.parts_audited_count.inc();
                                metrics.pushdown.parts_audited_bytes.inc_by(bytes);
                                part_desc.request_filter_pushdown_audit();
                            } else {
                                debug!(
                                    "skipping part because of stats filter {:?}",
                                    part_desc.part.stats()
                                );
                                continue;
                            }
                        }
                        FilterResult::ReplaceWith { key, val } => {
                            part_desc.maybe_optimize(&cfg, key, val);
                            audit_budget_bytes = audit_budget_bytes.saturating_add(bytes);
                        }
                    }
                    let bytes = u64::cast_from(part_desc.encoded_size_bytes());
                    if part_desc.part.is_inline() {
                        metrics.pushdown.parts_inline_count.inc();
                        metrics.pushdown.parts_inline_bytes.inc_by(bytes);
                    } else {
                        metrics.pushdown.parts_fetched_count.inc();
                        metrics.pushdown.parts_fetched_bytes.inc_by(bytes);
                    }
                }

                // Distribute parts across workers pseudo-randomly rather than
                // round-robin, to avoid skew when part sizes alternate.
                let worker_idx = usize::cast_from(Instant::now().hashed()) % num_workers;
                let (part, lease) = part_desc.into_exchangeable_part();
                if let Some(lease) = lease {
                    leases.borrow_mut().push_at(current_ts.clone(), lease);
                }
                descs_output.give(&session_cap, (worker_idx, part));
            }

            current_frontier.join_assign(&progress);
            cap_set.downgrade(progress.iter());
        }
    });

    (descs_stream, shutdown_button.press_on_drop())
}

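/// Fetches the blobs for the part descriptions it receives and emits the
/// fetched data. The second, data-less output communicates the frontier of
/// completed fetches for the feedback edge back to `shard_source_descs`.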
pub(crate) fn shard_source_fetch<K, V, T, D, G>(
    descs: &Stream<G, (usize, SerdeLeasedBatchPart)>,
    name: &str,
    client: impl Future<Output = PersistClient> + Send + 'static,
    shard_id: ShardId,
    key_schema: Arc<K::Schema>,
    val_schema: Arc<V::Schema>,
    is_transient: bool,
) -> (
    Stream<G, FetchedBlob<K, V, T, D>>,
    Stream<G, Infallible>,
    PressOnDropButton,
)
where
    K: Debug + Codec,
    V: Debug + Codec,
    T: Timestamp + Lattice + Codec64 + Sync,
    D: Semigroup + Codec64 + Send + Sync,
    G: Scope,
    G::Timestamp: Refines<T>,
{
    let mut builder =
        AsyncOperatorBuilder::new(format!("shard_source_fetch({})", name), descs.scope());
    let (fetched_output, fetched_stream) = builder.new_output();
    let (completed_fetches_output, completed_fetches_stream) =
        builder.new_output::<CapacityContainerBuilder<Vec<Infallible>>>();
    // Route each part to the worker index chosen by `shard_source_descs`.
    let mut descs_input = builder.new_input_for_many(
        descs,
        Exchange::new(|&(i, _): &(usize, _)| u64::cast_from(i)),
        [&fetched_output, &completed_fetches_output],
    );
    let name_owned = name.to_owned();

    let shutdown_button = builder.build(move |_capabilities| async move {
        let mut fetcher = mz_ore::task::spawn(|| format!("shard_source_fetch({})", name_owned), {
            let diagnostics = Diagnostics {
                shard_name: name_owned.clone(),
                handle_purpose: format!("shard_source_fetch batch fetcher {}", name_owned),
            };
            async move {
                client
                    .await
                    .create_batch_fetcher::<K, V, T, D>(
                        shard_id,
                        key_schema,
                        val_schema,
                        is_transient,
                        diagnostics,
                    )
                    .await
            }
        })
        .await
        .expect("fetcher creation shouldn't panic")
        .expect("shard codecs should not change");

        while let Some(event) = descs_input.next().await {
            if let Event::Data([fetched_cap, _completed_fetches_cap], data) = event {
                // Fetch the data for each part assigned to this worker and
                // re-emit it under the part's capability.
                for (_idx, part) in data {
                    let leased_part = fetcher.leased_part_from_exchangeable(part);
                    let fetched = fetcher
                        .fetch_leased_part(&leased_part)
                        .await
                        .expect("shard_id should match across all workers");
                    {
                        // Use a very fine-grained output session, so that we
                        // never hold an activated output across an await
                        // point, which would keep messages from being flushed
                        // from the shared timely output buffer.
                        fetched_output.give(&fetched_cap, fetched);
                    }
                }
            }
        }
    });

    (
        fetched_stream,
        completed_fetches_stream,
        shutdown_button.press_on_drop(),
    )
}

#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::Arc;

    use timely::dataflow::Scope;
    use timely::dataflow::operators::Leave;
    use timely::dataflow::operators::Probe;
    use timely::progress::Antichain;

    use crate::operators::shard_source::shard_source;
    use crate::{Diagnostics, ShardId};

    #[mz_ore::test]
    fn test_lease_manager() {
        let lease = Lease::default();
        let mut manager = LeaseManager::new();
        for t in 0u64..10 {
            manager.push_at(t, lease.clone());
        }
        // One count for the original lease plus one per tracked clone.
        assert_eq!(lease.count(), 11);
        manager.advance_to(AntichainRef::new(&[5]));
        assert_eq!(lease.count(), 6);
        // Advancing to a frontier in the past must not drop anything.
        manager.advance_to(AntichainRef::new(&[3]));
        assert_eq!(lease.count(), 6);
        manager.advance_to(AntichainRef::new(&[9]));
        assert_eq!(lease.count(), 2);
        // All tracked leases are dropped; only the original remains.
        manager.advance_to(AntichainRef::new(&[10]));
        assert_eq!(lease.count(), 1);
    }

    // Verifies that `shard_source` derives its as_of from the shard's since
    // when no explicit as_of is given.
    #[mz_ore::test(tokio::test(flavor = "multi_thread"))]
    #[cfg_attr(miri, ignore)]
    async fn test_shard_source_implicit_initial_as_of() {
        let persist_client = PersistClient::new_for_tests().await;

        let expected_frontier = 42;
        let shard_id = ShardId::new();

        initialize_shard(
            &persist_client,
            shard_id,
            Antichain::from_elem(expected_frontier),
        )
        .await;

        let res = timely::execute::execute_directly(move |worker| {
            let until = Antichain::new();

            let (probe, _token) = worker.dataflow::<u64, _, _>(|scope| {
                let (stream, token) = scope.scoped::<u64, _, _>("hybrid", |scope| {
                    let transformer = move |_, descs: &Stream<_, _>, _| (descs.clone(), vec![]);
                    let (stream, tokens) = shard_source::<String, String, u64, u64, _, _, _>(
                        scope,
                        "test_source",
                        move || std::future::ready(persist_client.clone()),
                        shard_id,
                        None, // No explicit as_of: expect it to default to the since.
                        SnapshotMode::Include,
                        until,
                        Some(transformer),
                        Arc::new(
                            <std::string::String as mz_persist_types::Codec>::Schema::default(),
                        ),
                        Arc::new(
                            <std::string::String as mz_persist_types::Codec>::Schema::default(),
                        ),
                        FilterResult::keep_all,
                        false.then_some(|| unreachable!()),
                        async {},
                        |error| panic!("test: {error}"),
                    );
                    (stream.leave(), tokens)
                });

                let probe = stream.probe();

                (probe, token)
            });

            while probe.less_than(&expected_frontier) {
                worker.step();
            }

            let mut probe_frontier = Antichain::new();
            probe.with_frontier(|f| probe_frontier.extend(f.iter().cloned()));

            probe_frontier
        });

        assert_eq!(res, Antichain::from_elem(expected_frontier));
    }

    // Verifies that `shard_source` reads from an explicitly provided as_of.
    #[mz_ore::test(tokio::test(flavor = "multi_thread"))]
    #[cfg_attr(miri, ignore)]
    async fn test_shard_source_explicit_initial_as_of() {
        let persist_client = PersistClient::new_for_tests().await;

        let expected_frontier = 42;
        let shard_id = ShardId::new();

        initialize_shard(
            &persist_client,
            shard_id,
            Antichain::from_elem(expected_frontier),
        )
        .await;

        let res = timely::execute::execute_directly(move |worker| {
            let as_of = Antichain::from_elem(expected_frontier);
            let until = Antichain::new();

            let (probe, _token) = worker.dataflow::<u64, _, _>(|scope| {
                let (stream, token) = scope.scoped::<u64, _, _>("hybrid", |scope| {
                    let transformer = move |_, descs: &Stream<_, _>, _| (descs.clone(), vec![]);
                    let (stream, tokens) = shard_source::<String, String, u64, u64, _, _, _>(
                        scope,
                        "test_source",
                        move || std::future::ready(persist_client.clone()),
                        shard_id,
                        Some(as_of), // Explicit as_of to read from.
                        SnapshotMode::Include,
                        until,
                        Some(transformer),
                        Arc::new(
                            <std::string::String as mz_persist_types::Codec>::Schema::default(),
                        ),
                        Arc::new(
                            <std::string::String as mz_persist_types::Codec>::Schema::default(),
                        ),
                        FilterResult::keep_all,
                        false.then_some(|| unreachable!()),
                        async {},
                        |error| panic!("test: {error}"),
                    );
                    (stream.leave(), tokens)
                });

                let probe = stream.probe();

                (probe, token)
            });

            while probe.less_than(&expected_frontier) {
                worker.step();
            }

            let mut probe_frontier = Antichain::new();
            probe.with_frontier(|f| probe_frontier.extend(f.iter().cloned()));

            probe_frontier
        });

        assert_eq!(res, Antichain::from_elem(expected_frontier));
    }

    async fn initialize_shard(
        persist_client: &PersistClient,
        shard_id: ShardId,
        since: Antichain<u64>,
    ) {
        let mut read_handle = persist_client
            .open_leased_reader::<String, String, u64, u64>(
                shard_id,
                Arc::new(<std::string::String as mz_persist_types::Codec>::Schema::default()),
                Arc::new(<std::string::String as mz_persist_types::Codec>::Schema::default()),
                Diagnostics::for_tests(),
                true,
            )
            .await
            .expect("invalid usage");

        read_handle.downgrade_since(&since).await;
    }
}