use std::cell::RefCell;
use std::collections::BTreeMap;
use std::collections::hash_map::DefaultHasher;
use std::convert::Infallible;
use std::fmt::{Debug, Formatter};
use std::future::Future;
use std::hash::{Hash, Hasher};
use std::pin::pin;
use std::rc::Rc;
use std::sync::Arc;
use std::time::Instant;

use anyhow::anyhow;
use arrow::array::ArrayRef;
use differential_dataflow::Hashable;
use differential_dataflow::difference::Monoid;
use differential_dataflow::lattice::Lattice;
use futures_util::StreamExt;
use mz_ore::cast::CastFrom;
use mz_ore::collections::CollectionExt;
use mz_persist_types::stats::PartStats;
use mz_persist_types::{Codec, Codec64};
use mz_timely_util::builder_async::{
    Event, OperatorBuilder as AsyncOperatorBuilder, PressOnDropButton,
};
use timely::PartialOrder;
use timely::container::CapacityContainerBuilder;
use timely::dataflow::channels::pact::{Exchange, Pipeline};
use timely::dataflow::operators::{CapabilitySet, ConnectLoop, Enter, Feedback, Leave};
use timely::dataflow::scopes::Child;
use timely::dataflow::{Scope, Stream};
use timely::order::TotalOrder;
use timely::progress::frontier::AntichainRef;
use timely::progress::{Antichain, Timestamp, timestamp::Refines};
use tracing::{debug, trace};

use crate::batch::BLOB_TARGET_SIZE;
use crate::cfg::{RetryParameters, USE_CRITICAL_SINCE_SOURCE};
use crate::fetch::{ExchangeableBatchPart, FetchedBlob, Lease};
use crate::internal::state::BatchPart;
use crate::stats::{STATS_AUDIT_PERCENT, STATS_FILTER_ENABLED};
use crate::{Diagnostics, PersistClient, ShardId};

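/// The outcome of applying a caller-provided stats filter to a batch part:
/// keep the part, discard it, or replace its contents with the given constant
/// key/val arrays.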
#[derive(Debug, Clone, PartialEq, Default)]
pub enum FilterResult {
    #[default]
    Keep,
    Discard,
    ReplaceWith {
        key: ArrayRef,
        val: ArrayRef,
    },
}

impl FilterResult {
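    /// A filter that keeps every part, regardless of its stats.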
    pub fn keep_all<T>(_stats: &PartStats, _frontier: AntichainRef<T>) -> FilterResult {
        Self::Keep
    }
}

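/// How `shard_source` reacts to unrecoverable errors: `Halt` aborts the
/// process, tagging the message with the given operator name, while `Signal`
/// hands the error to a caller-provided callback.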
#[derive(Clone)]
pub enum ErrorHandler {
    Halt(&'static str),
    Signal(Rc<dyn Fn(anyhow::Error) + 'static>),
}

impl Debug for ErrorHandler {
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        match self {
            ErrorHandler::Halt(name) => f.debug_tuple("ErrorHandler::Halt").field(name).finish(),
            ErrorHandler::Signal(_) => f.write_str("ErrorHandler::Signal"),
        }
    }
}

impl ErrorHandler {
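    /// Constructs an `ErrorHandler` that forwards errors to the given callback.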
    pub fn signal(signal_fn: impl Fn(anyhow::Error) + 'static) -> Self {
        Self::Signal(Rc::new(signal_fn))
    }

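    /// Reports the given error and never returns: `Halt` halts the process,
    /// while `Signal` invokes the callback and then pends forever, so the
    /// reporting operator stops making progress.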
    pub async fn report_and_stop(&self, error: anyhow::Error) -> ! {
        match self {
            ErrorHandler::Halt(name) => {
                mz_ore::halt!("unhandled error in {name}: {error:#}")
            }
            ErrorHandler::Signal(callback) => {
                let () = callback(error);
                std::future::pending().await
            }
        }
    }
}

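/// Creates a source that reads from a persist shard. A single worker, chosen
/// by hashing `name`, produces batch part descriptions; the optional
/// `desc_transformer` may rewrite that stream; and `shard_source_fetch`
/// downloads each part's data on the worker the description was routed to.
/// Returns the stream of fetched blobs along with tokens that shut the
/// operators down when dropped.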
pub fn shard_source<'g, K, V, T, D, DT, G, C>(
    scope: &mut Child<'g, G, T>,
    name: &str,
    client: impl Fn() -> C,
    shard_id: ShardId,
    as_of: Option<Antichain<G::Timestamp>>,
    snapshot_mode: SnapshotMode,
    until: Antichain<G::Timestamp>,
    desc_transformer: Option<DT>,
    key_schema: Arc<K::Schema>,
    val_schema: Arc<V::Schema>,
    filter_fn: impl FnMut(&PartStats, AntichainRef<G::Timestamp>) -> FilterResult + 'static,
    listen_sleep: Option<impl Fn() -> RetryParameters + 'static>,
    start_signal: impl Future<Output = ()> + 'static,
    error_handler: ErrorHandler,
) -> (
    Stream<Child<'g, G, T>, FetchedBlob<K, V, G::Timestamp, D>>,
    Vec<PressOnDropButton>,
)
where
    K: Debug + Codec,
    V: Debug + Codec,
    D: Monoid + Codec64 + Send + Sync,
    G: Scope,
    G::Timestamp: Timestamp + Lattice + Codec64 + TotalOrder + Sync,
    T: Refines<G::Timestamp>,
    DT: FnOnce(
        Child<'g, G, T>,
        &Stream<Child<'g, G, T>, (usize, ExchangeableBatchPart<G::Timestamp>)>,
        usize,
    ) -> (
        Stream<Child<'g, G, T>, (usize, ExchangeableBatchPart<G::Timestamp>)>,
        Vec<PressOnDropButton>,
    ),
    C: Future<Output = PersistClient> + Send + 'static,
{
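    // Part descriptions are produced on a single, deterministically chosen
    // worker; the fetch operator further down distributes the actual downloads
    // across all workers.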
    let chosen_worker = usize::cast_from(name.hashed()) % scope.peers();

    let mut tokens = vec![];

    let (completed_fetches_feedback_handle, completed_fetches_feedback_stream) =
        scope.feedback(T::Summary::default());

    let is_transient = !until.is_empty();

    let (descs, descs_token) = shard_source_descs::<K, V, D, G>(
        &scope.parent,
        name,
        client(),
        shard_id.clone(),
        as_of,
        snapshot_mode,
        until,
        completed_fetches_feedback_stream.leave(),
        chosen_worker,
        Arc::clone(&key_schema),
        Arc::clone(&val_schema),
        filter_fn,
        listen_sleep,
        start_signal,
        error_handler.clone(),
    );
    tokens.push(descs_token);

    let descs = descs.enter(scope);
    let descs = match desc_transformer {
        Some(desc_transformer) => {
            let (descs, extra_tokens) = desc_transformer(scope.clone(), &descs, chosen_worker);
            tokens.extend(extra_tokens);
            descs
        }
        None => descs,
    };

    let (parts, completed_fetches_stream, fetch_token) = shard_source_fetch(
        &descs,
        name,
        client(),
        shard_id,
        key_schema,
        val_schema,
        is_transient,
        error_handler,
    );
    completed_fetches_stream.connect_loop(completed_fetches_feedback_handle);
    tokens.push(fetch_token);

    (parts, tokens)
}

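/// Whether the source should emit a snapshot of the shard contents at the
/// `as_of` frontier before the stream of subsequent updates, or only the
/// updates.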
#[derive(Debug, Clone, Copy)]
pub enum SnapshotMode {
    Include,
    Exclude,
}

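/// Tracks the lease for each emitted part, keyed by the timestamp at which the
/// part was emitted, so leases can be dropped once downstream fetches have
/// completed past that time.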
#[derive(Debug)]
struct LeaseManager<T> {
    leases: BTreeMap<T, Vec<Lease>>,
}

impl<T: Timestamp + Codec64> LeaseManager<T> {
    fn new() -> Self {
        Self {
            leases: BTreeMap::new(),
        }
    }

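    /// Registers a lease to be held until the completed-fetches frontier has
    /// advanced past `time`.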
    fn push_at(&mut self, time: T, lease: Lease) {
        self.leases.entry(time).or_default().push(lease);
    }

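    /// Drops every lease registered at a time strictly less than `frontier`.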
    fn advance_to(&mut self, frontier: AntichainRef<T>)
    where
        T: TotalOrder,
    {
        while let Some(first) = self.leases.first_entry() {
            if frontier.less_equal(first.key()) {
                break;
            }
            drop(first.remove());
        }
    }
}

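/// The first half of `shard_source`: on the chosen worker, opens a leased read
/// handle on the shard, optionally emits the snapshot parts at `as_of`, and
/// then listens for new batch parts, applying the stats-based `filter_fn`
/// before handing each part (tagged with a destination worker index) to the
/// fetch operator. A companion operator watches the completed-fetches frontier
/// and releases part leases as it advances.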
pub(crate) fn shard_source_descs<K, V, D, G>(
    scope: &G,
    name: &str,
    client: impl Future<Output = PersistClient> + Send + 'static,
    shard_id: ShardId,
    as_of: Option<Antichain<G::Timestamp>>,
    snapshot_mode: SnapshotMode,
    until: Antichain<G::Timestamp>,
    completed_fetches_stream: Stream<G, Infallible>,
    chosen_worker: usize,
    key_schema: Arc<K::Schema>,
    val_schema: Arc<V::Schema>,
    mut filter_fn: impl FnMut(&PartStats, AntichainRef<G::Timestamp>) -> FilterResult + 'static,
    listen_sleep: Option<impl Fn() -> RetryParameters + 'static>,
    start_signal: impl Future<Output = ()> + 'static,
    error_handler: ErrorHandler,
) -> (
    Stream<G, (usize, ExchangeableBatchPart<G::Timestamp>)>,
    PressOnDropButton,
)
where
    K: Debug + Codec,
    V: Debug + Codec,
    D: Monoid + Codec64 + Send + Sync,
    G: Scope,
    G::Timestamp: Timestamp + Lattice + Codec64 + TotalOrder + Sync,
{
    let worker_index = scope.index();
    let num_workers = scope.peers();

    let name_owned = name.to_owned();

    let listen_handle = Rc::new(RefCell::new(None));
    let return_listen_handle = Rc::clone(&listen_handle);

    let (tx, rx) = tokio::sync::oneshot::channel::<Rc<RefCell<LeaseManager<G::Timestamp>>>>();
    let mut builder = AsyncOperatorBuilder::new(
        format!("shard_source_descs_return({})", name),
        scope.clone(),
    );
    let mut completed_fetches = builder.new_disconnected_input(&completed_fetches_stream, Pipeline);
    builder.build(move |_caps| async move {
        let Ok(leases) = rx.await else {
            return;
        };
        while let Some(event) = completed_fetches.next().await {
            let Event::Progress(frontier) = event else {
                continue;
            };
            leases.borrow_mut().advance_to(frontier.borrow());
        }
        drop(return_listen_handle);
    });

    let mut builder =
        AsyncOperatorBuilder::new(format!("shard_source_descs({})", name), scope.clone());
    let (descs_output, descs_stream) = builder.new_output::<CapacityContainerBuilder<_>>();

    #[allow(clippy::await_holding_refcell_ref)]
    let shutdown_button = builder.build(move |caps| async move {
        let mut cap_set = CapabilitySet::from_elem(caps.into_element());

        if worker_index != chosen_worker {
            trace!(
                "We are not the chosen worker ({}), exiting...",
                chosen_worker
            );
            return;
        }

        let mut read = mz_ore::task::spawn(|| format!("shard_source_reader({})", name_owned), {
            let diagnostics = Diagnostics {
                handle_purpose: format!("shard_source({})", name_owned),
                shard_name: name_owned.clone(),
            };
            async move {
                let client = client.await;
                client
                    .open_leased_reader::<K, V, G::Timestamp, D>(
                        shard_id,
                        key_schema,
                        val_schema,
                        diagnostics,
                        USE_CRITICAL_SINCE_SOURCE.get(client.dyncfgs()),
                    )
                    .await
            }
        })
        .await
        .expect("could not open persist shard");

        let () = start_signal.await;

        let cfg = read.cfg.clone();
        let metrics = Arc::clone(&read.metrics);

        let as_of = as_of.unwrap_or_else(|| read.since().clone());

        cap_set.downgrade(as_of.clone());

        let mut snapshot_parts =
            match snapshot_mode {
                SnapshotMode::Include => match read.snapshot(as_of.clone()).await {
                    Ok(parts) => parts,
                    Err(e) => error_handler
                        .report_and_stop(anyhow!(
                            "{name_owned}: {shard_id} cannot serve requested as_of {as_of:?}: {e:?}"
                        ))
                        .await,
                },
                SnapshotMode::Exclude => vec![],
            };

        let leases = Rc::new(RefCell::new(LeaseManager::new()));
        tx.send(Rc::clone(&leases))
            .expect("lease returner exited before desc producer");

        let mut listen = listen_handle.borrow_mut();
        let listen = match read.listen(as_of.clone()).await {
            Ok(handle) => listen.insert(handle),
            Err(e) => {
                error_handler
                    .report_and_stop(anyhow!(
                        "{name_owned}: {shard_id} cannot serve requested as_of {as_of:?}: {e:?}"
                    ))
                    .await
            }
        };

        let listen_retry = listen_sleep.as_ref().map(|retry| retry());

        let listen_head = if !snapshot_parts.is_empty() {
            let (mut parts, progress) = listen.next(listen_retry).await;
            snapshot_parts.append(&mut parts);
            futures::stream::iter(Some((snapshot_parts, progress)))
        } else {
            futures::stream::iter(None)
        };

        let listen_tail = futures::stream::unfold(listen, |listen| async move {
            Some((listen.next(listen_retry).await, listen))
        });

        let mut shard_stream = pin!(listen_head.chain(listen_tail));

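        // Budget (in bytes) for fetching parts that the stats filter said to
        // discard, so the filter's decisions can be audited: it grows as parts
        // are kept and is drawn down by each audit.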
        let mut audit_budget_bytes = u64::cast_from(BLOB_TARGET_SIZE.get(&cfg).saturating_mul(2));

        let mut current_frontier = as_of.clone();

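        // Emit part descriptions until the listen's progress passes `until`.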
        while !PartialOrder::less_equal(&until, &current_frontier) {
            let (parts, progress) = shard_stream.next().await.expect("infinite stream");

            let current_ts = current_frontier
                .as_option()
                .expect("until should always be <= the empty frontier");
            let session_cap = cap_set.delayed(current_ts);

            for mut part_desc in parts {
                if STATS_FILTER_ENABLED.get(&cfg) {
                    let filter_result = match &part_desc.part {
                        BatchPart::Hollow(x) => {
                            let should_fetch =
                                x.stats.as_ref().map_or(FilterResult::Keep, |stats| {
                                    filter_fn(&stats.decode(), current_frontier.borrow())
                                });
                            should_fetch
                        }
                        BatchPart::Inline { .. } => FilterResult::Keep,
                    };
                    let bytes = u64::cast_from(part_desc.encoded_size_bytes());
                    match filter_result {
                        FilterResult::Keep => {
                            audit_budget_bytes = audit_budget_bytes.saturating_add(bytes);
                        }
                        FilterResult::Discard => {
                            metrics.pushdown.parts_filtered_count.inc();
                            metrics.pushdown.parts_filtered_bytes.inc_by(bytes);
                            let should_audit = match &part_desc.part {
                                BatchPart::Hollow(x) => {
                                    let mut h = DefaultHasher::new();
                                    x.key.hash(&mut h);
                                    usize::cast_from(h.finish()) % 100
                                        < STATS_AUDIT_PERCENT.get(&cfg)
                                }
                                BatchPart::Inline { .. } => false,
                            };
                            if should_audit && bytes < audit_budget_bytes {
                                audit_budget_bytes -= bytes;
                                metrics.pushdown.parts_audited_count.inc();
                                metrics.pushdown.parts_audited_bytes.inc_by(bytes);
                                part_desc.request_filter_pushdown_audit();
                            } else {
                                debug!(
                                    "skipping part because of stats filter {:?}",
                                    part_desc.part.stats()
                                );
                                continue;
                            }
                        }
                        FilterResult::ReplaceWith { key, val } => {
                            part_desc.maybe_optimize(&cfg, key, val);
                            audit_budget_bytes = audit_budget_bytes.saturating_add(bytes);
                        }
                    }
                    let bytes = u64::cast_from(part_desc.encoded_size_bytes());
                    if part_desc.part.is_inline() {
                        metrics.pushdown.parts_inline_count.inc();
                        metrics.pushdown.parts_inline_bytes.inc_by(bytes);
                    } else {
                        metrics.pushdown.parts_fetched_count.inc();
                        metrics.pushdown.parts_fetched_bytes.inc_by(bytes);
                    }
                }

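                // Pick a worker for this part; the fetch operator exchanges on
                // this index, so it determines which worker downloads the blob.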
                let worker_idx = usize::cast_from(Instant::now().hashed()) % num_workers;
                let (part, lease) = part_desc.into_exchangeable_part();
                leases.borrow_mut().push_at(current_ts.clone(), lease);
                descs_output.give(&session_cap, (worker_idx, part));
            }

            current_frontier.join_assign(&progress);
            cap_set.downgrade(progress.iter());
        }
    });

    (descs_stream, shutdown_button.press_on_drop())
}

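/// The second half of `shard_source`: receives part descriptions, exchanged by
/// the worker index chosen upstream, fetches the corresponding blobs from
/// persist, and emits them. The second returned stream carries no data and
/// only reports the frontier of completed fetches back to `shard_source_descs`
/// so leases can be released.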
pub(crate) fn shard_source_fetch<K, V, T, D, G>(
    descs: &Stream<G, (usize, ExchangeableBatchPart<T>)>,
    name: &str,
    client: impl Future<Output = PersistClient> + Send + 'static,
    shard_id: ShardId,
    key_schema: Arc<K::Schema>,
    val_schema: Arc<V::Schema>,
    is_transient: bool,
    error_handler: ErrorHandler,
) -> (
    Stream<G, FetchedBlob<K, V, T, D>>,
    Stream<G, Infallible>,
    PressOnDropButton,
)
where
    K: Debug + Codec,
    V: Debug + Codec,
    T: Timestamp + Lattice + Codec64 + Sync,
    D: Monoid + Codec64 + Send + Sync,
    G: Scope,
    G::Timestamp: Refines<T>,
{
    let mut builder =
        AsyncOperatorBuilder::new(format!("shard_source_fetch({})", name), descs.scope());
    let (fetched_output, fetched_stream) = builder.new_output::<CapacityContainerBuilder<_>>();
    let (completed_fetches_output, completed_fetches_stream) =
        builder.new_output::<CapacityContainerBuilder<Vec<Infallible>>>();
    let mut descs_input = builder.new_input_for_many(
        descs,
        Exchange::new(|&(i, _): &(usize, _)| u64::cast_from(i)),
        [&fetched_output, &completed_fetches_output],
    );
    let name_owned = name.to_owned();

    let shutdown_button = builder.build(move |_capabilities| async move {
        let mut fetcher = mz_ore::task::spawn(|| format!("shard_source_fetch({})", name_owned), {
            let diagnostics = Diagnostics {
                shard_name: name_owned.clone(),
                handle_purpose: format!("shard_source_fetch batch fetcher {}", name_owned),
            };
            async move {
                client
                    .await
                    .create_batch_fetcher::<K, V, T, D>(
                        shard_id,
                        key_schema,
                        val_schema,
                        is_transient,
                        diagnostics,
                    )
                    .await
            }
        })
        .await
        .expect("shard codecs should not change");

        while let Some(event) = descs_input.next().await {
            if let Event::Data([fetched_cap, _completed_fetches_cap], data) = event {
                for (_idx, part) in data {
                    let fetched = fetcher
                        .fetch_leased_part(part)
                        .await
                        .expect("shard_id should match across all workers");
                    let fetched = match fetched {
                        Ok(fetched) => fetched,
                        Err(blob_key) => {
                            error_handler
                                .report_and_stop(anyhow!(
                                    "batch fetcher could not fetch batch part {}; lost lease?",
                                    blob_key
                                ))
                                .await
                        }
                    };
                    fetched_output.give(&fetched_cap, fetched);
                }
            }
        }
    });

    (
        fetched_stream,
        completed_fetches_stream,
        shutdown_button.press_on_drop(),
    )
}

#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::Arc;

    use mz_persist::location::SeqNo;
    use timely::dataflow::Scope;
    use timely::dataflow::operators::Leave;
    use timely::dataflow::operators::Probe;
    use timely::progress::Antichain;

    use crate::operators::shard_source::shard_source;
    use crate::{Diagnostics, ShardId};

    #[mz_ore::test]
    fn test_lease_manager() {
        let lease = Lease::new(SeqNo::minimum());
        let mut manager = LeaseManager::new();
        for t in 0u64..10 {
            manager.push_at(t, lease.clone());
        }
        assert_eq!(lease.count(), 11);
        manager.advance_to(AntichainRef::new(&[5]));
        assert_eq!(lease.count(), 6);
        manager.advance_to(AntichainRef::new(&[3]));
        assert_eq!(lease.count(), 6);
        manager.advance_to(AntichainRef::new(&[9]));
        assert_eq!(lease.count(), 2);
        manager.advance_to(AntichainRef::new(&[10]));
        assert_eq!(lease.count(), 1);
    }

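    // With no explicit as_of, the source should derive one from the shard's
    // since; the probe's frontier is expected to settle at `expected_frontier`.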
    #[mz_ore::test(tokio::test(flavor = "multi_thread"))]
    #[cfg_attr(miri, ignore)]
    async fn test_shard_source_implicit_initial_as_of() {
        let persist_client = PersistClient::new_for_tests().await;

        let expected_frontier = 42;
        let shard_id = ShardId::new();

        initialize_shard(
            &persist_client,
            shard_id,
            Antichain::from_elem(expected_frontier),
        )
        .await;

        let res = timely::execute::execute_directly(move |worker| {
            let until = Antichain::new();

            let (probe, _token) = worker.dataflow::<u64, _, _>(|scope| {
                let (stream, token) = scope.scoped::<u64, _, _>("hybrid", |scope| {
                    let transformer = move |_, descs: &Stream<_, _>, _| (descs.clone(), vec![]);
                    let (stream, tokens) = shard_source::<String, String, u64, u64, _, _, _>(
                        scope,
                        "test_source",
                        move || std::future::ready(persist_client.clone()),
                        shard_id,
                        None,
                        SnapshotMode::Include,
                        until,
                        Some(transformer),
                        Arc::new(
                            <std::string::String as mz_persist_types::Codec>::Schema::default(),
                        ),
                        Arc::new(
                            <std::string::String as mz_persist_types::Codec>::Schema::default(),
                        ),
                        FilterResult::keep_all,
                        false.then_some(|| unreachable!()),
                        async {},
                        ErrorHandler::Halt("test"),
                    );
                    (stream.leave(), tokens)
                });

                let probe = stream.probe();

                (probe, token)
            });

            while probe.less_than(&expected_frontier) {
                worker.step();
            }

            let mut probe_frontier = Antichain::new();
            probe.with_frontier(|f| probe_frontier.extend(f.iter().cloned()));

            probe_frontier
        });

        assert_eq!(res, Antichain::from_elem(expected_frontier));
    }

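    // With an explicit as_of equal to the shard's since, the probe's frontier
    // is likewise expected to settle at `expected_frontier`.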
    #[mz_ore::test(tokio::test(flavor = "multi_thread"))]
    #[cfg_attr(miri, ignore)]
    async fn test_shard_source_explicit_initial_as_of() {
        let persist_client = PersistClient::new_for_tests().await;

        let expected_frontier = 42;
        let shard_id = ShardId::new();

        initialize_shard(
            &persist_client,
            shard_id,
            Antichain::from_elem(expected_frontier),
        )
        .await;

        let res = timely::execute::execute_directly(move |worker| {
            let as_of = Antichain::from_elem(expected_frontier);
            let until = Antichain::new();

            let (probe, _token) = worker.dataflow::<u64, _, _>(|scope| {
                let (stream, token) = scope.scoped::<u64, _, _>("hybrid", |scope| {
                    let transformer = move |_, descs: &Stream<_, _>, _| (descs.clone(), vec![]);
                    let (stream, tokens) = shard_source::<String, String, u64, u64, _, _, _>(
                        scope,
                        "test_source",
                        move || std::future::ready(persist_client.clone()),
                        shard_id,
                        Some(as_of),
                        SnapshotMode::Include,
                        until,
                        Some(transformer),
                        Arc::new(
                            <std::string::String as mz_persist_types::Codec>::Schema::default(),
                        ),
                        Arc::new(
                            <std::string::String as mz_persist_types::Codec>::Schema::default(),
                        ),
                        FilterResult::keep_all,
                        false.then_some(|| unreachable!()),
                        async {},
                        ErrorHandler::Halt("test"),
                    );
                    (stream.leave(), tokens)
                });

                let probe = stream.probe();

                (probe, token)
            });

            while probe.less_than(&expected_frontier) {
                worker.step();
            }

            let mut probe_frontier = Antichain::new();
            probe.with_frontier(|f| probe_frontier.extend(f.iter().cloned()));

            probe_frontier
        });

        assert_eq!(res, Antichain::from_elem(expected_frontier));
    }

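    /// Opens a reader on the shard and downgrades its since to `since`.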
    async fn initialize_shard(
        persist_client: &PersistClient,
        shard_id: ShardId,
        since: Antichain<u64>,
    ) {
        let mut read_handle = persist_client
            .open_leased_reader::<String, String, u64, u64>(
                shard_id,
                Arc::new(<std::string::String as mz_persist_types::Codec>::Schema::default()),
                Arc::new(<std::string::String as mz_persist_types::Codec>::Schema::default()),
                Diagnostics::for_tests(),
                true,
            )
            .await
            .expect("invalid usage");

        read_handle.downgrade_since(&since).await;
    }
}