1use std::cell::RefCell;
13use std::collections::BTreeMap;
14use std::collections::hash_map::DefaultHasher;
15use std::convert::Infallible;
16use std::fmt::{Debug, Formatter};
17use std::future::Future;
18use std::hash::{Hash, Hasher};
19use std::pin::pin;
20use std::rc::Rc;
21use std::sync::Arc;
22use std::time::Instant;
23
24use anyhow::anyhow;
25use arrow::array::ArrayRef;
26use differential_dataflow::Hashable;
27use differential_dataflow::difference::Semigroup;
28use differential_dataflow::lattice::Lattice;
29use futures_util::StreamExt;
30use mz_ore::cast::CastFrom;
31use mz_ore::collections::CollectionExt;
32use mz_persist_types::stats::PartStats;
33use mz_persist_types::{Codec, Codec64};
34use mz_timely_util::builder_async::{
35 Event, OperatorBuilder as AsyncOperatorBuilder, PressOnDropButton,
36};
37use timely::PartialOrder;
38use timely::container::CapacityContainerBuilder;
39use timely::dataflow::channels::pact::{Exchange, Pipeline};
40use timely::dataflow::operators::{CapabilitySet, ConnectLoop, Enter, Feedback, Leave};
41use timely::dataflow::scopes::Child;
42use timely::dataflow::{Scope, Stream};
43use timely::order::TotalOrder;
44use timely::progress::frontier::AntichainRef;
45use timely::progress::{Antichain, Timestamp, timestamp::Refines};
46use tracing::{debug, trace};
47
48use crate::batch::BLOB_TARGET_SIZE;
49use crate::cfg::{RetryParameters, USE_CRITICAL_SINCE_SOURCE};
50use crate::fetch::{ExchangeableBatchPart, FetchedBlob, Lease};
51use crate::internal::state::BatchPart;
52use crate::stats::{STATS_AUDIT_PERCENT, STATS_FILTER_ENABLED};
53use crate::{Diagnostics, PersistClient, ShardId};
54
55#[derive(Debug, Clone, PartialEq, Default)]
57pub enum FilterResult {
58 #[default]
60 Keep,
61 Discard,
63 ReplaceWith {
66 key: ArrayRef,
68 val: ArrayRef,
70 },
71}
72
73impl FilterResult {
74 pub fn keep_all<T>(_stats: &PartStats, _frontier: AntichainRef<T>) -> FilterResult {
76 Self::Keep
77 }
78}
79
80#[derive(Clone)]
91pub enum ErrorHandler {
92 Halt(&'static str),
94 Signal(Rc<dyn Fn(anyhow::Error) + 'static>),
96}
97
98impl Debug for ErrorHandler {
99 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
100 match self {
101 ErrorHandler::Halt(name) => f.debug_tuple("ErrorHandler::Halt").field(name).finish(),
102 ErrorHandler::Signal(_) => f.write_str("ErrorHandler::Signal"),
103 }
104 }
105}
106
107impl ErrorHandler {
108 pub fn signal(signal_fn: impl Fn(anyhow::Error) + 'static) -> Self {
110 Self::Signal(Rc::new(signal_fn))
111 }
112
113 pub async fn report_and_stop(&self, error: anyhow::Error) -> ! {
117 match self {
118 ErrorHandler::Halt(name) => {
119 mz_ore::halt!("unhandled error in {name}: {error:#}")
120 }
121 ErrorHandler::Signal(callback) => {
122 let () = callback(error);
123 std::future::pending().await
124 }
125 }
126 }
127}
128
129pub fn shard_source<'g, K, V, T, D, DT, G, C>(
143 scope: &mut Child<'g, G, T>,
144 name: &str,
145 client: impl Fn() -> C,
146 shard_id: ShardId,
147 as_of: Option<Antichain<G::Timestamp>>,
148 snapshot_mode: SnapshotMode,
149 until: Antichain<G::Timestamp>,
150 desc_transformer: Option<DT>,
151 key_schema: Arc<K::Schema>,
152 val_schema: Arc<V::Schema>,
153 filter_fn: impl FnMut(&PartStats, AntichainRef<G::Timestamp>) -> FilterResult + 'static,
154 listen_sleep: Option<impl Fn() -> RetryParameters + 'static>,
156 start_signal: impl Future<Output = ()> + 'static,
157 error_handler: ErrorHandler,
158) -> (
159 Stream<Child<'g, G, T>, FetchedBlob<K, V, G::Timestamp, D>>,
160 Vec<PressOnDropButton>,
161)
162where
163 K: Debug + Codec,
164 V: Debug + Codec,
165 D: Semigroup + Codec64 + Send + Sync,
166 G: Scope,
167 G::Timestamp: Timestamp + Lattice + Codec64 + TotalOrder + Sync,
169 T: Refines<G::Timestamp>,
170 DT: FnOnce(
171 Child<'g, G, T>,
172 &Stream<Child<'g, G, T>, (usize, ExchangeableBatchPart<G::Timestamp>)>,
173 usize,
174 ) -> (
175 Stream<Child<'g, G, T>, (usize, ExchangeableBatchPart<G::Timestamp>)>,
176 Vec<PressOnDropButton>,
177 ),
178 C: Future<Output = PersistClient> + Send + 'static,
179{
180 let chosen_worker = usize::cast_from(name.hashed()) % scope.peers();
196
197 let mut tokens = vec![];
198
199 let (completed_fetches_feedback_handle, completed_fetches_feedback_stream) =
202 scope.feedback(T::Summary::default());
203
204 let is_transient = !until.is_empty();
208
209 let (descs, descs_token) = shard_source_descs::<K, V, D, G>(
210 &scope.parent,
211 name,
212 client(),
213 shard_id.clone(),
214 as_of,
215 snapshot_mode,
216 until,
217 completed_fetches_feedback_stream.leave(),
218 chosen_worker,
219 Arc::clone(&key_schema),
220 Arc::clone(&val_schema),
221 filter_fn,
222 listen_sleep,
223 start_signal,
224 error_handler.clone(),
225 );
226 tokens.push(descs_token);
227
228 let descs = descs.enter(scope);
229 let descs = match desc_transformer {
230 Some(desc_transformer) => {
231 let (descs, extra_tokens) = desc_transformer(scope.clone(), &descs, chosen_worker);
232 tokens.extend(extra_tokens);
233 descs
234 }
235 None => descs,
236 };
237
238 let (parts, completed_fetches_stream, fetch_token) = shard_source_fetch(
239 &descs,
240 name,
241 client(),
242 shard_id,
243 key_schema,
244 val_schema,
245 is_transient,
246 error_handler,
247 );
248 completed_fetches_stream.connect_loop(completed_fetches_feedback_handle);
249 tokens.push(fetch_token);
250
251 (parts, tokens)
252}
253
254#[derive(Debug, Clone, Copy)]
256pub enum SnapshotMode {
257 Include,
259 Exclude,
261}
262
263#[derive(Debug)]
264struct LeaseManager<T> {
265 leases: BTreeMap<T, Vec<Lease>>,
266}
267
268impl<T: Timestamp + Codec64> LeaseManager<T> {
269 fn new() -> Self {
270 Self {
271 leases: BTreeMap::new(),
272 }
273 }
274
275 fn push_at(&mut self, time: T, lease: Lease) {
277 self.leases.entry(time).or_default().push(lease);
278 }
279
280 fn advance_to(&mut self, frontier: AntichainRef<T>)
282 where
283 T: TotalOrder,
285 {
286 while let Some(first) = self.leases.first_entry() {
287 if frontier.less_equal(first.key()) {
288 break; }
290 drop(first.remove());
291 }
292 }
293}
294
295pub(crate) fn shard_source_descs<K, V, D, G>(
296 scope: &G,
297 name: &str,
298 client: impl Future<Output = PersistClient> + Send + 'static,
299 shard_id: ShardId,
300 as_of: Option<Antichain<G::Timestamp>>,
301 snapshot_mode: SnapshotMode,
302 until: Antichain<G::Timestamp>,
303 completed_fetches_stream: Stream<G, Infallible>,
304 chosen_worker: usize,
305 key_schema: Arc<K::Schema>,
306 val_schema: Arc<V::Schema>,
307 mut filter_fn: impl FnMut(&PartStats, AntichainRef<G::Timestamp>) -> FilterResult + 'static,
308 listen_sleep: Option<impl Fn() -> RetryParameters + 'static>,
310 start_signal: impl Future<Output = ()> + 'static,
311 error_handler: ErrorHandler,
312) -> (
313 Stream<G, (usize, ExchangeableBatchPart<G::Timestamp>)>,
314 PressOnDropButton,
315)
316where
317 K: Debug + Codec,
318 V: Debug + Codec,
319 D: Semigroup + Codec64 + Send + Sync,
320 G: Scope,
321 G::Timestamp: Timestamp + Lattice + Codec64 + TotalOrder + Sync,
323{
324 let worker_index = scope.index();
325 let num_workers = scope.peers();
326
327 let name_owned = name.to_owned();
330
331 let listen_handle = Rc::new(RefCell::new(None));
333 let return_listen_handle = Rc::clone(&listen_handle);
334
335 let (tx, rx) = tokio::sync::oneshot::channel::<Rc<RefCell<LeaseManager<G::Timestamp>>>>();
337 let mut builder = AsyncOperatorBuilder::new(
338 format!("shard_source_descs_return({})", name),
339 scope.clone(),
340 );
341 let mut completed_fetches = builder.new_disconnected_input(&completed_fetches_stream, Pipeline);
342 builder.build(move |_caps| async move {
345 let Ok(leases) = rx.await else {
346 return;
349 };
350 while let Some(event) = completed_fetches.next().await {
351 let Event::Progress(frontier) = event else {
352 continue;
353 };
354 leases.borrow_mut().advance_to(frontier.borrow());
355 }
356 drop(return_listen_handle);
358 });
359
360 let mut builder =
361 AsyncOperatorBuilder::new(format!("shard_source_descs({})", name), scope.clone());
362 let (descs_output, descs_stream) = builder.new_output();
363
364 #[allow(clippy::await_holding_refcell_ref)]
365 let shutdown_button = builder.build(move |caps| async move {
366 let mut cap_set = CapabilitySet::from_elem(caps.into_element());
367
368 if worker_index != chosen_worker {
370 trace!(
371 "We are not the chosen worker ({}), exiting...",
372 chosen_worker
373 );
374 return;
375 }
376
377 let mut read = mz_ore::task::spawn(|| format!("shard_source_reader({})", name_owned), {
387 let diagnostics = Diagnostics {
388 handle_purpose: format!("shard_source({})", name_owned),
389 shard_name: name_owned.clone(),
390 };
391 async move {
392 let client = client.await;
393 client
394 .open_leased_reader::<K, V, G::Timestamp, D>(
395 shard_id,
396 key_schema,
397 val_schema,
398 diagnostics,
399 USE_CRITICAL_SINCE_SOURCE.get(client.dyncfgs()),
400 )
401 .await
402 }
403 })
404 .await
405 .expect("reader creation shouldn't panic")
406 .expect("could not open persist shard");
407
408 let () = start_signal.await;
412
413 let cfg = read.cfg.clone();
414 let metrics = Arc::clone(&read.metrics);
415
416 let as_of = as_of.unwrap_or_else(|| read.since().clone());
417
418 cap_set.downgrade(as_of.clone());
434
435 let mut snapshot_parts =
436 match snapshot_mode {
437 SnapshotMode::Include => match read.snapshot(as_of.clone()).await {
438 Ok(parts) => parts,
439 Err(e) => error_handler
440 .report_and_stop(anyhow!(
441 "{name_owned}: {shard_id} cannot serve requested as_of {as_of:?}: {e:?}"
442 ))
443 .await,
444 },
445 SnapshotMode::Exclude => vec![],
446 };
447
448 let leases = Rc::new(RefCell::new(LeaseManager::new()));
452 tx.send(Rc::clone(&leases))
453 .expect("lease returner exited before desc producer");
454
455 let mut listen = listen_handle.borrow_mut();
458 let listen = match read.listen(as_of.clone()).await {
459 Ok(handle) => listen.insert(handle),
460 Err(e) => {
461 error_handler
462 .report_and_stop(anyhow!(
463 "{name_owned}: {shard_id} cannot serve requested as_of {as_of:?}: {e:?}"
464 ))
465 .await
466 }
467 };
468
469 let listen_retry = listen_sleep.as_ref().map(|retry| retry());
470
471 let listen_head = if !snapshot_parts.is_empty() {
473 let (mut parts, progress) = listen.next(listen_retry).await;
474 snapshot_parts.append(&mut parts);
475 futures::stream::iter(Some((snapshot_parts, progress)))
476 } else {
477 futures::stream::iter(None)
478 };
479
480 let listen_tail = futures::stream::unfold(listen, |listen| async move {
482 Some((listen.next(listen_retry).await, listen))
483 });
484
485 let mut shard_stream = pin!(listen_head.chain(listen_tail));
486
487 let mut audit_budget_bytes = u64::cast_from(BLOB_TARGET_SIZE.get(&cfg).saturating_mul(2));
491
492 let mut current_frontier = as_of.clone();
494
495 while !PartialOrder::less_equal(&until, ¤t_frontier) {
498 let (parts, progress) = shard_stream.next().await.expect("infinite stream");
499
500 let current_ts = current_frontier
503 .as_option()
504 .expect("until should always be <= the empty frontier");
505 let session_cap = cap_set.delayed(current_ts);
506
507 for mut part_desc in parts {
508 if STATS_FILTER_ENABLED.get(&cfg) {
511 let filter_result = match &part_desc.part {
512 BatchPart::Hollow(x) => {
513 let should_fetch =
514 x.stats.as_ref().map_or(FilterResult::Keep, |stats| {
515 filter_fn(&stats.decode(), current_frontier.borrow())
516 });
517 should_fetch
518 }
519 BatchPart::Inline { .. } => FilterResult::Keep,
520 };
521 let bytes = u64::cast_from(part_desc.encoded_size_bytes());
523 match filter_result {
524 FilterResult::Keep => {
525 audit_budget_bytes = audit_budget_bytes.saturating_add(bytes);
526 }
527 FilterResult::Discard => {
528 metrics.pushdown.parts_filtered_count.inc();
529 metrics.pushdown.parts_filtered_bytes.inc_by(bytes);
530 let should_audit = match &part_desc.part {
531 BatchPart::Hollow(x) => {
532 let mut h = DefaultHasher::new();
533 x.key.hash(&mut h);
534 usize::cast_from(h.finish()) % 100
535 < STATS_AUDIT_PERCENT.get(&cfg)
536 }
537 BatchPart::Inline { .. } => false,
538 };
539 if should_audit && bytes < audit_budget_bytes {
540 audit_budget_bytes -= bytes;
541 metrics.pushdown.parts_audited_count.inc();
542 metrics.pushdown.parts_audited_bytes.inc_by(bytes);
543 part_desc.request_filter_pushdown_audit();
544 } else {
545 debug!(
546 "skipping part because of stats filter {:?}",
547 part_desc.part.stats()
548 );
549 continue;
550 }
551 }
552 FilterResult::ReplaceWith { key, val } => {
553 part_desc.maybe_optimize(&cfg, key, val);
554 audit_budget_bytes = audit_budget_bytes.saturating_add(bytes);
555 }
556 }
557 let bytes = u64::cast_from(part_desc.encoded_size_bytes());
558 if part_desc.part.is_inline() {
559 metrics.pushdown.parts_inline_count.inc();
560 metrics.pushdown.parts_inline_bytes.inc_by(bytes);
561 } else {
562 metrics.pushdown.parts_fetched_count.inc();
563 metrics.pushdown.parts_fetched_bytes.inc_by(bytes);
564 }
565 }
566
567 let worker_idx = usize::cast_from(Instant::now().hashed()) % num_workers;
574 let (part, lease) = part_desc.into_exchangeable_part();
575 leases.borrow_mut().push_at(current_ts.clone(), lease);
576 descs_output.give(&session_cap, (worker_idx, part));
577 }
578
579 current_frontier.join_assign(&progress);
580 cap_set.downgrade(progress.iter());
581 }
582 });
583
584 (descs_stream, shutdown_button.press_on_drop())
585}
586
587pub(crate) fn shard_source_fetch<K, V, T, D, G>(
588 descs: &Stream<G, (usize, ExchangeableBatchPart<T>)>,
589 name: &str,
590 client: impl Future<Output = PersistClient> + Send + 'static,
591 shard_id: ShardId,
592 key_schema: Arc<K::Schema>,
593 val_schema: Arc<V::Schema>,
594 is_transient: bool,
595 error_handler: ErrorHandler,
596) -> (
597 Stream<G, FetchedBlob<K, V, T, D>>,
598 Stream<G, Infallible>,
599 PressOnDropButton,
600)
601where
602 K: Debug + Codec,
603 V: Debug + Codec,
604 T: Timestamp + Lattice + Codec64 + Sync,
605 D: Semigroup + Codec64 + Send + Sync,
606 G: Scope,
607 G::Timestamp: Refines<T>,
608{
609 let mut builder =
610 AsyncOperatorBuilder::new(format!("shard_source_fetch({})", name), descs.scope());
611 let (fetched_output, fetched_stream) = builder.new_output();
612 let (completed_fetches_output, completed_fetches_stream) =
613 builder.new_output::<CapacityContainerBuilder<Vec<Infallible>>>();
614 let mut descs_input = builder.new_input_for_many(
615 descs,
616 Exchange::new(|&(i, _): &(usize, _)| u64::cast_from(i)),
617 [&fetched_output, &completed_fetches_output],
618 );
619 let name_owned = name.to_owned();
620
621 let shutdown_button = builder.build(move |_capabilities| async move {
622 let mut fetcher = mz_ore::task::spawn(|| format!("shard_source_fetch({})", name_owned), {
623 let diagnostics = Diagnostics {
624 shard_name: name_owned.clone(),
625 handle_purpose: format!("shard_source_fetch batch fetcher {}", name_owned),
626 };
627 async move {
628 client
629 .await
630 .create_batch_fetcher::<K, V, T, D>(
631 shard_id,
632 key_schema,
633 val_schema,
634 is_transient,
635 diagnostics,
636 )
637 .await
638 }
639 })
640 .await
641 .expect("fetcher creation shouldn't panic")
642 .expect("shard codecs should not change");
643
644 while let Some(event) = descs_input.next().await {
645 if let Event::Data([fetched_cap, _completed_fetches_cap], data) = event {
646 for (_idx, part) in data {
649 let fetched = fetcher
650 .fetch_leased_part(part)
651 .await
652 .expect("shard_id should match across all workers");
653 let fetched = match fetched {
654 Ok(fetched) => fetched,
655 Err(blob_key) => {
656 error_handler
665 .report_and_stop(anyhow!(
666 "batch fetcher could not fetch batch part {}; lost lease?",
667 blob_key
668 ))
669 .await
670 }
671 };
672 {
673 fetched_output.give(&fetched_cap, fetched);
679 }
680 }
681 }
682 }
683 });
684
685 (
686 fetched_stream,
687 completed_fetches_stream,
688 shutdown_button.press_on_drop(),
689 )
690}
691
692#[cfg(test)]
693mod tests {
694 use super::*;
695 use std::sync::Arc;
696
697 use mz_persist::location::SeqNo;
698 use timely::dataflow::Scope;
699 use timely::dataflow::operators::Leave;
700 use timely::dataflow::operators::Probe;
701 use timely::progress::Antichain;
702
703 use crate::operators::shard_source::shard_source;
704 use crate::{Diagnostics, ShardId};
705
706 #[mz_ore::test]
707 fn test_lease_manager() {
708 let lease = Lease::new(SeqNo::minimum());
709 let mut manager = LeaseManager::new();
710 for t in 0u64..10 {
711 manager.push_at(t, lease.clone());
712 }
713 assert_eq!(lease.count(), 11);
714 manager.advance_to(AntichainRef::new(&[5]));
715 assert_eq!(lease.count(), 6);
716 manager.advance_to(AntichainRef::new(&[3]));
717 assert_eq!(lease.count(), 6);
718 manager.advance_to(AntichainRef::new(&[9]));
719 assert_eq!(lease.count(), 2);
720 manager.advance_to(AntichainRef::new(&[10]));
721 assert_eq!(lease.count(), 1);
722 }
723
724 #[mz_ore::test(tokio::test(flavor = "multi_thread"))]
732 #[cfg_attr(miri, ignore)] async fn test_shard_source_implicit_initial_as_of() {
734 let persist_client = PersistClient::new_for_tests().await;
735
736 let expected_frontier = 42;
737 let shard_id = ShardId::new();
738
739 initialize_shard(
740 &persist_client,
741 shard_id,
742 Antichain::from_elem(expected_frontier),
743 )
744 .await;
745
746 let res = timely::execute::execute_directly(move |worker| {
747 let until = Antichain::new();
748
749 let (probe, _token) = worker.dataflow::<u64, _, _>(|scope| {
750 let (stream, token) = scope.scoped::<u64, _, _>("hybrid", |scope| {
751 let transformer = move |_, descs: &Stream<_, _>, _| (descs.clone(), vec![]);
752 let (stream, tokens) = shard_source::<String, String, u64, u64, _, _, _>(
753 scope,
754 "test_source",
755 move || std::future::ready(persist_client.clone()),
756 shard_id,
757 None, SnapshotMode::Include,
759 until,
760 Some(transformer),
761 Arc::new(
762 <std::string::String as mz_persist_types::Codec>::Schema::default(),
763 ),
764 Arc::new(
765 <std::string::String as mz_persist_types::Codec>::Schema::default(),
766 ),
767 FilterResult::keep_all,
768 false.then_some(|| unreachable!()),
769 async {},
770 ErrorHandler::Halt("test"),
771 );
772 (stream.leave(), tokens)
773 });
774
775 let probe = stream.probe();
776
777 (probe, token)
778 });
779
780 while probe.less_than(&expected_frontier) {
781 worker.step();
782 }
783
784 let mut probe_frontier = Antichain::new();
785 probe.with_frontier(|f| probe_frontier.extend(f.iter().cloned()));
786
787 probe_frontier
788 });
789
790 assert_eq!(res, Antichain::from_elem(expected_frontier));
791 }
792
793 #[mz_ore::test(tokio::test(flavor = "multi_thread"))]
800 #[cfg_attr(miri, ignore)] async fn test_shard_source_explicit_initial_as_of() {
802 let persist_client = PersistClient::new_for_tests().await;
803
804 let expected_frontier = 42;
805 let shard_id = ShardId::new();
806
807 initialize_shard(
808 &persist_client,
809 shard_id,
810 Antichain::from_elem(expected_frontier),
811 )
812 .await;
813
814 let res = timely::execute::execute_directly(move |worker| {
815 let as_of = Antichain::from_elem(expected_frontier);
816 let until = Antichain::new();
817
818 let (probe, _token) = worker.dataflow::<u64, _, _>(|scope| {
819 let (stream, token) = scope.scoped::<u64, _, _>("hybrid", |scope| {
820 let transformer = move |_, descs: &Stream<_, _>, _| (descs.clone(), vec![]);
821 let (stream, tokens) = shard_source::<String, String, u64, u64, _, _, _>(
822 scope,
823 "test_source",
824 move || std::future::ready(persist_client.clone()),
825 shard_id,
826 Some(as_of), SnapshotMode::Include,
828 until,
829 Some(transformer),
830 Arc::new(
831 <std::string::String as mz_persist_types::Codec>::Schema::default(),
832 ),
833 Arc::new(
834 <std::string::String as mz_persist_types::Codec>::Schema::default(),
835 ),
836 FilterResult::keep_all,
837 false.then_some(|| unreachable!()),
838 async {},
839 ErrorHandler::Halt("test"),
840 );
841 (stream.leave(), tokens)
842 });
843
844 let probe = stream.probe();
845
846 (probe, token)
847 });
848
849 while probe.less_than(&expected_frontier) {
850 worker.step();
851 }
852
853 let mut probe_frontier = Antichain::new();
854 probe.with_frontier(|f| probe_frontier.extend(f.iter().cloned()));
855
856 probe_frontier
857 });
858
859 assert_eq!(res, Antichain::from_elem(expected_frontier));
860 }
861
862 async fn initialize_shard(
863 persist_client: &PersistClient,
864 shard_id: ShardId,
865 since: Antichain<u64>,
866 ) {
867 let mut read_handle = persist_client
868 .open_leased_reader::<String, String, u64, u64>(
869 shard_id,
870 Arc::new(<std::string::String as mz_persist_types::Codec>::Schema::default()),
871 Arc::new(<std::string::String as mz_persist_types::Codec>::Schema::default()),
872 Diagnostics::for_tests(),
873 true,
874 )
875 .await
876 .expect("invalid usage");
877
878 read_handle.downgrade_since(&since).await;
879 }
880}