1use differential_dataflow::consolidation::ConsolidatingContainerBuilder;
13use mz_dyncfg::ConfigSet;
14use std::convert::Infallible;
15use std::fmt::Debug;
16use std::future::Future;
17use std::hash::Hash;
18use std::sync::Arc;
19use std::time::Instant;
20
21use differential_dataflow::lattice::Lattice;
22use futures::{StreamExt, future::Either};
23use mz_expr::{ColumnSpecs, Interpreter, MfpPlan, ResultSpec, UnmaterializableFunc};
24use mz_ore::cast::CastFrom;
25use mz_ore::collections::CollectionExt;
26use mz_persist_client::cache::PersistClientCache;
27use mz_persist_client::cfg::{PersistConfig, RetryParameters};
28use mz_persist_client::fetch::{ExchangeableBatchPart, ShardSourcePart};
29use mz_persist_client::fetch::{FetchedBlob, FetchedPart};
30use mz_persist_client::operators::shard_source::{
31    ErrorHandler, FilterResult, SnapshotMode, shard_source,
32};
33use mz_persist_client::stats::STATS_AUDIT_PANIC;
34use mz_persist_types::Codec64;
35use mz_persist_types::codec_impls::UnitSchema;
36use mz_persist_types::columnar::{ColumnEncoder, Schema};
37use mz_repr::{Datum, DatumVec, Diff, GlobalId, RelationDesc, Row, RowArena, Timestamp};
38use mz_storage_types::StorageDiff;
39use mz_storage_types::controller::{CollectionMetadata, TxnsCodecRow};
40use mz_storage_types::errors::DataflowError;
41use mz_storage_types::sources::SourceData;
42use mz_storage_types::stats::RelationPartStats;
43use mz_timely_util::builder_async::{
44    Event, OperatorBuilder as AsyncOperatorBuilder, PressOnDropButton,
45};
46use mz_timely_util::probe::ProbeNotify;
47use mz_txn_wal::operator::{TxnsContext, txns_progress};
48use serde::{Deserialize, Serialize};
49use timely::PartialOrder;
50use timely::communication::Push;
51use timely::dataflow::ScopeParent;
52use timely::dataflow::channels::Message;
53use timely::dataflow::channels::pact::Pipeline;
54use timely::dataflow::operators::generic::OutputHandleCore;
55use timely::dataflow::operators::generic::builder_rc::OperatorBuilder;
56use timely::dataflow::operators::{Capability, Leave, OkErr};
57use timely::dataflow::operators::{CapabilitySet, ConnectLoop, Feedback};
58use timely::dataflow::scopes::Child;
59use timely::dataflow::{Scope, Stream};
60use timely::order::TotalOrder;
61use timely::progress::Antichain;
62use timely::progress::Timestamp as TimelyTimestamp;
63use timely::progress::timestamp::PathSummary;
64use timely::scheduling::Activator;
65use tokio::sync::mpsc::UnboundedSender;
66use tracing::{error, trace};
67
68use crate::metrics::BackpressureMetrics;
69
70#[derive(
78    Copy, Clone, PartialEq, Default, Eq, PartialOrd, Ord, Debug, Serialize, Deserialize, Hash,
79)]
80pub struct Subtime(u64);
81
82impl PartialOrder for Subtime {
83    fn less_equal(&self, other: &Self) -> bool {
84        self.0.less_equal(&other.0)
85    }
86}
87
// Marker impl: `Subtime` is ordered exactly like the single `u64` it wraps (its
// `less_equal` delegates to the `u64` one), so any two values are comparable.
impl TotalOrder for Subtime {}
89
90impl PathSummary<Subtime> for Subtime {
91    fn results_in(&self, src: &Subtime) -> Option<Subtime> {
92        self.0.results_in(&src.0).map(Subtime)
93    }
94
95    fn followed_by(&self, other: &Self) -> Option<Self> {
96        self.0.followed_by(&other.0).map(Subtime)
97    }
98}
99
100impl TimelyTimestamp for Subtime {
101    type Summary = Subtime;
102
103    fn minimum() -> Self {
104        Subtime(0)
105    }
106}
107
108impl Subtime {
109    pub const fn least_summary() -> Self {
111        Subtime(1)
112    }
113}
114
/// Creates a source that reads the persist shard `metadata.data_shard` and returns its
/// contents as one stream of rows and one stream of errors.
///
/// The read itself is performed by [`persist_source_core`] inside a nested scope named
/// `granular_backpressure(<source_id>)`, whose timestamps are `(Timestamp, Subtime)`
/// pairs. The resulting stream leaves that scope again, and the `Subtime` component of
/// every update's time is dropped before the update is handed to the caller.
///
/// Notable arguments:
/// * `read_schema`, `snapshot_mode`, `map_filter_project`, `start_signal` and
///   `error_handler` are forwarded unchanged to [`persist_source_core`].
/// * `as_of` and `until` are forwarded to [`persist_source_core`] and, when
///   `metadata.txns_shard` is set, also to `txns_progress`.
/// * `max_inflight_bytes`: when `Some`, a [`FlowControl`] is built whose progress
///   stream is fed by a probe attached to the output of the core source; when `None`,
///   no flow control is configured.
/// * `txns_ctx` and `worker_dyncfgs` are only used when `metadata.txns_shard` is set.
///
/// Returns `(oks, errs, tokens)`, where `tokens` collects the `PressOnDropButton`s
/// handed back by the operators created here.
///
/// # Panics
///
/// When `metadata.txns_shard` is set, panics if `as_of` is `None` or is the empty
/// antichain. The client-opening closure passed to `txns_progress` panics if opening
/// the persist location fails.
pub fn persist_source<G>(
    scope: &mut G,
    source_id: GlobalId,
    persist_clients: Arc<PersistClientCache>,
    txns_ctx: &TxnsContext,
    worker_dyncfgs: &ConfigSet,
    metadata: CollectionMetadata,
    read_schema: Option<RelationDesc>,
    as_of: Option<Antichain<Timestamp>>,
    snapshot_mode: SnapshotMode,
    until: Antichain<Timestamp>,
    map_filter_project: Option<&mut MfpPlan>,
    max_inflight_bytes: Option<usize>,
    start_signal: impl Future<Output = ()> + 'static,
    error_handler: ErrorHandler,
) -> (
    Stream<G, (Row, Timestamp, Diff)>,
    Stream<G, (DataflowError, Timestamp, Diff)>,
    Vec<PressOnDropButton>,
)
where
    G: Scope<Timestamp = mz_repr::Timestamp>,
{
    let shard_metrics = persist_clients.shard_metrics(&metadata.data_shard, &source_id.to_string());

    let mut tokens = vec![];

    // Everything in this closure is built in a child scope with `(Timestamp, Subtime)`
    // timestamps; the closure's return value is the stream after leaving that scope.
    let stream = scope.scoped(&format!("granular_backpressure({})", source_id), |scope| {
        // Flow control is only set up when a byte limit was requested. The probe handle
        // is kept so it can be attached to the output of the core source below.
        let (flow_control, flow_control_probe) = match max_inflight_bytes {
            Some(max_inflight_bytes) => {
                let backpressure_metrics = BackpressureMetrics {
                    emitted_bytes: Arc::clone(&shard_metrics.backpressure_emitted_bytes),
                    last_backpressured_bytes: Arc::clone(
                        &shard_metrics.backpressure_last_backpressured_bytes,
                    ),
                    retired_bytes: Arc::clone(&shard_metrics.backpressure_retired_bytes),
                };

                let probe = mz_timely_util::probe::Handle::default();
                let progress_stream = mz_timely_util::probe::source(
                    scope.clone(),
                    format!("decode_backpressure_probe({source_id})"),
                    probe.clone(),
                );
                let flow_control = FlowControl {
                    progress_stream,
                    max_inflight_bytes,
                    // Default summary for the outer timestamp, `least_summary()` for the
                    // `Subtime` component.
                    summary: (Default::default(), Subtime::least_summary()),
                    metrics: Some(backpressure_metrics),
                };
                (Some(flow_control), Some(probe))
            }
            None => (None, None),
        };

        let cfg = Arc::clone(&persist_clients.cfg().configs);
        // Only shards that are part of a txns shard get custom retry parameters; they
        // are passed to the core source as its `listen_sleep` argument.
        let subscribe_sleep = match metadata.txns_shard {
            Some(_) => Some(move || mz_txn_wal::operator::txns_data_shard_retry_params(&cfg)),
            None => None,
        };

        let (stream, source_tokens) = persist_source_core(
            scope,
            source_id,
            Arc::clone(&persist_clients),
            metadata.clone(),
            read_schema,
            as_of.clone(),
            snapshot_mode,
            until.clone(),
            map_filter_project,
            flow_control,
            subscribe_sleep,
            start_signal,
            error_handler,
        );
        tokens.extend(source_tokens);

        // Attach the flow control probe (if any) to the decoded output, so that the
        // flow control progress stream follows the progress of this stream.
        let stream = match flow_control_probe {
            Some(probe) => stream.probe_notify_with(vec![probe]),
            None => stream,
        };

        stream.leave()
    });

    // For shards managed through a txns shard, route the stream through
    // `txns_progress`; otherwise pass it through untouched with no extra tokens.
    let (stream, txns_tokens) = match metadata.txns_shard {
        Some(txns_shard) => txns_progress::<SourceData, (), Timestamp, i64, _, TxnsCodecRow, _, _>(
            stream,
            &source_id.to_string(),
            txns_ctx,
            worker_dyncfgs,
            move || {
                let (c, l) = (
                    Arc::clone(&persist_clients),
                    metadata.persist_location.clone(),
                );
                async move { c.open(l).await.expect("location is valid") }
            },
            txns_shard,
            metadata.data_shard,
            as_of
                .expect("as_of is provided for table sources")
                .into_option()
                .expect("shard is not closed"),
            until,
            Arc::new(metadata.relation_desc),
            Arc::new(UnitSchema),
        ),
        None => (stream, vec![]),
    };
    tokens.extend(txns_tokens);
    // Split rows from errors, keeping only the outer timestamp (`t.0`) of each update.
    let (ok_stream, err_stream) = stream.ok_err(|(d, t, r)| match d {
        Ok(row) => Ok((row, t.0, r)),
        Err(err) => Err((err, t.0, r)),
    });
    (ok_stream, err_stream, tokens)
}
266
/// A child scope of `G` whose timestamps pair the parent scope's timestamp with a
/// [`Subtime`].
type RefinedScope<'g, G> = Child<'g, G, (<G as ScopeParent>::Timestamp, Subtime)>;
268
/// Reads a persist shard inside a [`RefinedScope`] and returns the decoded updates
/// together with the tokens of the operators that were created.
///
/// The function wires together two stages:
/// 1. `shard_source`, which produces fetched blobs for `metadata.data_shard`. It is
///    given a filter callback that inspects each part's statistics and, via
///    [`filter_result`], decides whether the part is kept, discarded, or replaced.
/// 2. [`decode_and_mfp`], which decodes the fetched blobs and applies
///    `map_filter_project` (if any) to the rows.
///
/// Notable arguments:
/// * `read_schema`: the relation description to read with; `metadata.relation_desc`
///   is used when it is `None`.
/// * `until`: handed to `shard_source` and [`decode_and_mfp`]; also used to discard
///   parts whose lower bound is already greater than it.
/// * `map_filter_project`: a clone of the plan is used for part filtering, and the
///   mutable reference is forwarded to [`decode_and_mfp`].
/// * `flow_control`: when `Some`, the part descriptions produced by `shard_source` are
///   routed through [`backpressure`] before being fetched.
/// * `listen_sleep`, `start_signal`, `error_handler`, `as_of`, `snapshot_mode`:
///   forwarded unchanged to `shard_source`.
///
/// # Panics
///
/// The client-opening closure passed to `shard_source` unwraps the result of opening
/// the persist location.
#[allow(clippy::needless_borrow)]
pub fn persist_source_core<'g, G>(
    scope: &RefinedScope<'g, G>,
    source_id: GlobalId,
    persist_clients: Arc<PersistClientCache>,
    metadata: CollectionMetadata,
    read_schema: Option<RelationDesc>,
    as_of: Option<Antichain<Timestamp>>,
    snapshot_mode: SnapshotMode,
    until: Antichain<Timestamp>,
    map_filter_project: Option<&mut MfpPlan>,
    flow_control: Option<FlowControl<RefinedScope<'g, G>>>,
    listen_sleep: Option<impl Fn() -> RetryParameters + 'static>,
    start_signal: impl Future<Output = ()> + 'static,
    error_handler: ErrorHandler,
) -> (
    Stream<
        RefinedScope<'g, G>,
        (
            Result<Row, DataflowError>,
            (mz_repr::Timestamp, Subtime),
            Diff,
        ),
    >,
    Vec<PressOnDropButton>,
)
where
    G: Scope<Timestamp = mz_repr::Timestamp>,
{
    let cfg = persist_clients.cfg().clone();
    let name = source_id.to_string();
    // Keep an owned copy of the plan for the part filter below; the original mutable
    // reference is handed to `decode_and_mfp` at the end.
    let filter_plan = map_filter_project.as_ref().map(|p| (*p).clone());

    let read_desc = match read_schema {
        Some(desc) => desc,
        None => metadata.relation_desc,
    };

    // With flow control enabled, wrap the stream of part descriptions in the
    // backpressure operator (without a test probe).
    let desc_transformer = match flow_control {
        Some(flow_control) => Some(move |mut scope: _, descs: &Stream<_, _>, chosen_worker| {
            let (stream, token) = backpressure(
                &mut scope,
                &format!("backpressure({source_id})"),
                descs,
                flow_control,
                chosen_worker,
                None,
            );
            (stream, vec![token])
        }),
        None => None,
    };

    let metrics = Arc::clone(persist_clients.metrics());
    let filter_name = name.clone();
    // `until` as a single timestamp; `Timestamp::MAX` when `as_option` yields nothing.
    let upper = until.as_option().cloned().unwrap_or(Timestamp::MAX);
    let (fetched, token) = shard_source(
        &mut scope.clone(),
        &name,
        move || {
            let (c, l) = (
                Arc::clone(&persist_clients),
                metadata.persist_location.clone(),
            );
            async move { c.open(l).await.unwrap() }
        },
        metadata.data_shard,
        as_of,
        snapshot_mode,
        until.clone(),
        desc_transformer,
        Arc::new(read_desc.clone()),
        Arc::new(UnitSchema),
        // Part filter: decides per part, from its stats and the current frontier,
        // whether it needs to be fetched at all.
        move |stats, frontier| {
            let Some(lower) = frontier.as_option().copied() else {
                // No frontier element to bound the part's times from below: discard.
                return FilterResult::Discard;
            };

            if lower > upper {
                // The frontier is already past `upper`: discard.
                return FilterResult::Discard;
            }

            // The range of times the part's updates may take, used as the possible
            // values of `mz_now()` when interpreting the plan's filter.
            let time_range =
                ResultSpec::value_between(Datum::MzTimestamp(lower), Datum::MzTimestamp(upper));
            if let Some(plan) = &filter_plan {
                let metrics = &metrics.pushdown.part_stats;
                let stats = RelationPartStats::new(&filter_name, metrics, &read_desc, stats);
                filter_result(&read_desc, time_range, stats, plan)
            } else {
                // Without a plan there is nothing to filter on.
                FilterResult::Keep
            }
        },
        listen_sleep,
        start_signal,
        error_handler,
    );
    let rows = decode_and_mfp(cfg, &fetched, &name, until, map_filter_project);
    (rows, token)
}
383
384fn filter_result(
385    relation_desc: &RelationDesc,
386    time_range: ResultSpec,
387    stats: RelationPartStats,
388    plan: &MfpPlan,
389) -> FilterResult {
390    let arena = RowArena::new();
391    let mut ranges = ColumnSpecs::new(relation_desc.typ(), &arena);
392    ranges.push_unmaterializable(UnmaterializableFunc::MzNow, time_range);
393
394    let may_error = stats.err_count().map_or(true, |count| count > 0);
395
396    for (pos, (idx, _, _)) in relation_desc.iter_all().enumerate() {
399        let result_spec = stats.col_stats(idx, &arena);
400        ranges.push_column(pos, result_spec);
401    }
402    let result = ranges.mfp_plan_filter(plan).range;
403    let may_error = may_error || result.may_fail();
404    let may_keep = result.may_contain(Datum::True);
405    let may_skip = result.may_contain(Datum::False) || result.may_contain(Datum::Null);
406    if relation_desc.len() == 0 && !may_error && !may_skip {
407        let Ok(mut key) = <RelationDesc as Schema<SourceData>>::encoder(relation_desc) else {
408            return FilterResult::Keep;
409        };
410        key.append(&SourceData(Ok(Row::default())));
411        let key = key.finish();
412        let Ok(mut val) = <UnitSchema as Schema<()>>::encoder(&UnitSchema) else {
413            return FilterResult::Keep;
414        };
415        val.append(&());
416        let val = val.finish();
417
418        FilterResult::ReplaceWith {
419            key: Arc::new(key),
420            val: Arc::new(val),
421        }
422    } else if may_error || may_keep {
423        FilterResult::Keep
424    } else {
425        FilterResult::Discard
426    }
427}
428
/// Builds the `persist_source::decode_and_mfp(<name>)` operator, which decodes fetched
/// blobs into rows and errors and applies `map_filter_project` to the rows.
///
/// Arguments:
/// * `cfg`: consulted for the audit-panic flag and, on every scheduling, for the
///   decode fuel.
/// * `fetched`: the stream of fetched blobs to decode.
/// * `until`: updates at times for which `until.less_equal(time)` holds are dropped.
/// * `map_filter_project`: if provided, the plan is moved out of the caller's
///   reference with `MfpPlan::take` and applied to every decoded row.
///   NOTE(review): what `take` leaves behind in the caller's plan is not visible
///   here. Confirm against `MfpPlan`.
///
/// The operator works cooperatively: each scheduling processes pending parts until the
/// fuel is used up, and it re-activates itself while work remains.
pub fn decode_and_mfp<G>(
    cfg: PersistConfig,
    fetched: &Stream<G, FetchedBlob<SourceData, (), Timestamp, StorageDiff>>,
    name: &str,
    until: Antichain<Timestamp>,
    mut map_filter_project: Option<&mut MfpPlan>,
) -> Stream<G, (Result<Row, DataflowError>, G::Timestamp, Diff)>
where
    G: Scope<Timestamp = (mz_repr::Timestamp, Subtime)>,
{
    let scope = fetched.scope();
    let mut builder = OperatorBuilder::new(
        format!("persist_source::decode_and_mfp({})", name),
        scope.clone(),
    );
    let operator_info = builder.operator_info();

    let mut fetched_input = builder.new_input(fetched, Pipeline);
    let (mut updates_output, updates_stream) =
        builder.new_output::<ConsolidatingContainerBuilder<_>>();

    // Scratch buffers that are reused across all rows processed by this operator.
    let mut datum_vec = mz_repr::DatumVec::new();
    let mut row_builder = Row::default();

    // Move the plan out of the caller's reference so the operator closure can own it.
    let map_filter_project = map_filter_project.as_mut().map(|mfp| mfp.take());

    builder.build(move |_caps| {
        let name = name.to_owned();
        let activations = scope.activations();
        // Used to reschedule this operator when it yields with work left over.
        let activator = Activator::new(operator_info.address, activations);
        // Parts that have been received but not yet fully decoded, in arrival order.
        let mut pending_work = std::collections::VecDeque::new();
        let panic_on_audit_failure = STATS_AUDIT_PANIC.handle(&cfg);

        move |_frontier| {
            // Queue every incoming blob together with a capability for its time.
            fetched_input.for_each(|time, data| {
                let capability = time.retain();
                for fetched_blob in data.drain(..) {
                    pending_work.push_back(PendingWork {
                        panic_on_audit_failure: panic_on_audit_failure.get(),
                        capability: capability.clone(),
                        part: PendingPart::Unparsed(fetched_blob),
                    })
                }
            });

            // Read the fuel once per scheduling. Yielding is decided purely by the
            // amount of work done; the start time argument is ignored.
            let yield_fuel = cfg.storage_source_decode_fuel();
            let yield_fn = |_, work| work >= yield_fuel;

            let mut work = 0;
            let start_time = Instant::now();
            let mut output = updates_output.activate();
            // Always work on the oldest pending part; drop it once fully processed.
            while !pending_work.is_empty() && !yield_fn(start_time, work) {
                let done = pending_work.front_mut().unwrap().do_work(
                    &mut work,
                    &name,
                    start_time,
                    yield_fn,
                    &until,
                    map_filter_project.as_ref(),
                    &mut datum_vec,
                    &mut row_builder,
                    &mut output,
                );
                if done {
                    pending_work.pop_front();
                }
            }
            // Out of fuel with work remaining: ask to be scheduled again.
            if !pending_work.is_empty() {
                activator.activate();
            }
        }
    });

    updates_stream
}
510
/// A fetched part queued in `decode_and_mfp`, together with what is needed to emit its
/// updates.
struct PendingWork {
    /// Whether a filter pushdown audit violation should panic (it is always logged).
    panic_on_audit_failure: bool,
    /// The capability retained for the input time at which the part arrived; all of
    /// the part's updates are emitted in a session opened with it.
    capability: Capability<(mz_repr::Timestamp, Subtime)>,
    /// The part itself, parsed lazily on first use.
    part: PendingPart,
}
520
/// A fetched part that is either still an unparsed blob or has already been parsed.
enum PendingPart {
    /// The blob as received from the fetch stage; not parsed yet.
    Unparsed(FetchedBlob<SourceData, (), Timestamp, StorageDiff>),
    /// The result of calling `parse` on the blob.
    Parsed {
        part: ShardSourcePart<SourceData, (), Timestamp, StorageDiff>,
    },
}
527
528impl PendingPart {
529    fn part_mut(&mut self) -> &mut FetchedPart<SourceData, (), Timestamp, StorageDiff> {
536        match self {
537            PendingPart::Unparsed(x) => {
538                *self = PendingPart::Parsed { part: x.parse() };
539                self.part_mut()
541            }
542            PendingPart::Parsed { part } => &mut part.part,
543        }
544    }
545}
546
impl PendingWork {
    /// Decodes updates from this part and writes them to `output` until either the
    /// part is exhausted or `yield_fn` asks to stop.
    ///
    /// Arguments:
    /// * `work`: running work counter. It is incremented once per row fed into the MFP
    ///   and once per update given to the output.
    /// * `name`: only used in the audit-violation log and panic messages.
    /// * `start_time`: passed through to `yield_fn`.
    /// * `yield_fn`: called after each decoded update with `start_time` and the
    ///   current `work`; returning `true` makes this function return early.
    /// * `until`: updates whose time satisfies `until.less_equal(time)` are dropped,
    ///   both before and after applying the MFP.
    /// * `map_filter_project`: if present, applied to every successfully decoded row;
    ///   otherwise rows are emitted as they are.
    /// * `datum_vec`, `row_builder`: scratch space reused between calls.
    ///
    /// Returns `true` if the part has been fully processed and `false` if the function
    /// yielded with updates remaining.
    ///
    /// # Panics
    ///
    /// Panics with "decoding failed" if a key or value fails to decode, and, when
    /// `panic_on_audit_failure` is set, on a filter pushdown audit violation.
    fn do_work<P, YFn>(
        &mut self,
        work: &mut usize,
        name: &str,
        start_time: Instant,
        yield_fn: YFn,
        until: &Antichain<Timestamp>,
        map_filter_project: Option<&MfpPlan>,
        datum_vec: &mut DatumVec,
        row_builder: &mut Row,
        output: &mut OutputHandleCore<
            '_,
            (mz_repr::Timestamp, Subtime),
            ConsolidatingContainerBuilder<
                Vec<(
                    Result<Row, DataflowError>,
                    (mz_repr::Timestamp, Subtime),
                    Diff,
                )>,
            >,
            P,
        >,
    ) -> bool
    where
        P: Push<
            Message<
                (mz_repr::Timestamp, Subtime),
                Vec<(
                    Result<Row, DataflowError>,
                    (mz_repr::Timestamp, Subtime),
                    Diff,
                )>,
            >,
        >,
        YFn: Fn(Instant, usize) -> bool,
    {
        let mut session = output.session_with_builder(&self.capability);
        let fetched_part = self.part.part_mut();
        // `Some(stats)` marks the part as a pushdown audit part; any MFP output
        // produced from it is reported as a violation below.
        // NOTE(review): presumably such parts are ones the stats filter would have
        // discarded. Confirm against `FetchedPart`.
        let is_filter_pushdown_audit = fetched_part.is_filter_pushdown_audit();
        // Storage handed back to `next_with_storage`; refilled with each processed row.
        let mut row_buf = None;
        while let Some(((key, val), time, diff)) =
            fetched_part.next_with_storage(&mut row_buf, &mut None)
        {
            // Skip updates at or beyond `until`.
            if until.less_equal(&time) {
                continue;
            }
            match (key, val) {
                (Ok(SourceData(Ok(row))), Ok(())) => {
                    if let Some(mfp) = map_filter_project {
                        // One unit of work for evaluating the MFP on this row.
                        *work += 1;
                        let arena = mz_repr::RowArena::new();
                        let mut datums_local = datum_vec.borrow_with(&row);
                        for result in mfp.evaluate(
                            &mut datums_local,
                            &arena,
                            time,
                            diff.into(),
                            |time| !until.less_equal(time),
                            row_builder,
                        ) {
                            if let Some(stats) = &is_filter_pushdown_audit {
                                // Tag the report so it can be told apart in Sentry.
                                sentry::with_scope(
                                    |scope| {
                                        scope
                                            .set_tag("alert_id", "persist_pushdown_audit_violation")
                                    },
                                    || {
                                        error!(
                                            ?stats,
                                            name,
                                            ?mfp,
                                            ?result,
                                            "persist filter pushdown correctness violation!"
                                        );
                                        if self.panic_on_audit_failure {
                                            panic!(
                                                "persist filter pushdown correctness violation! {}",
                                                name
                                            );
                                        }
                                    },
                                );
                            }
                            match result {
                                Ok((row, time, diff)) => {
                                    // The MFP may have changed the update's time, so
                                    // check it against `until` again.
                                    if !until.less_equal(&time) {
                                        // Keep the capability's `Subtime`, but use the
                                        // update's own time as the outer timestamp.
                                        let mut emit_time = *self.capability.time();
                                        emit_time.0 = time;
                                        session.give((Ok(row), emit_time, diff));
                                        *work += 1;
                                    }
                                }
                                Err((err, time, diff)) => {
                                    if !until.less_equal(&time) {
                                        let mut emit_time = *self.capability.time();
                                        emit_time.0 = time;
                                        session.give((Err(err), emit_time, diff));
                                        *work += 1;
                                    }
                                }
                            }
                        }
                        // Release the borrow of `row` before handing the row back as
                        // storage for the next update.
                        drop(datums_local);
                        row_buf.replace(SourceData(Ok(row)));
                    } else {
                        let mut emit_time = *self.capability.time();
                        emit_time.0 = time;
                        session.give((Ok(row.clone()), emit_time, diff.into()));
                        // The original row goes back into the storage buffer; the
                        // clone above is what was emitted.
                        row_buf.replace(SourceData(Ok(row)));
                        *work += 1;
                    }
                }
                (Ok(SourceData(Err(err))), Ok(())) => {
                    // Errors stored in the shard bypass the MFP and are emitted as is.
                    let mut emit_time = *self.capability.time();
                    emit_time.0 = time;
                    session.give((Err(err), emit_time, diff.into()));
                    *work += 1;
                }
                (Err(_), Ok(_)) | (Ok(_), Err(_)) | (Err(_), Err(_)) => {
                    panic!("decoding failed")
                }
            }
            // Give the caller a chance to yield; the part stays queued if we stop here.
            if yield_fn(start_time, *work) {
                return false;
            }
        }
        true
    }
}
698
/// A piece of data that can flow through the [`backpressure`] operator.
pub trait Backpressureable: Clone + 'static {
    /// The size in bytes that this item counts towards the in-flight byte limit.
    fn byte_size(&self) -> usize;
}
704
705impl<T: Clone + 'static> Backpressureable for (usize, ExchangeableBatchPart<T>) {
706    fn byte_size(&self) -> usize {
707        self.1.encoded_size_bytes()
708    }
709}
710
/// Configuration for the [`backpressure`] operator.
#[derive(Debug)]
pub struct FlowControl<G: Scope> {
    /// A stream that [`backpressure`] connects into a feedback loop. Its item type is
    /// `Infallible`, and the operator treats any data on it as unreachable, so only
    /// its progress is used.
    pub progress_stream: Stream<G, Infallible>,
    /// The limit against which the total size of emitted-but-not-yet-retired parts is
    /// compared when deciding whether to emit more.
    pub max_inflight_bytes: usize,
    /// The summary applied on the feedback edge. It is also used to compute, for each
    /// emitted part, the time that the flow control frontier must pass before the part
    /// counts as retired.
    pub summary: <G::Timestamp as TimelyTimestamp>::Summary,

    /// Optional metrics updated as bytes are emitted, backpressured, and retired.
    pub metrics: Option<BackpressureMetrics>,
}
729
/// Builds an operator that forwards `data` while limiting the number of bytes that are
/// "in flight", i.e. emitted but not yet acknowledged through the flow control stream.
///
/// Each forwarded item is emitted at its own timestamp: its input time with the
/// `Subtime` component replaced by a running part number. The operator remembers the
/// byte size of every emitted item until the frontier of the flow control feedback
/// input has moved past the item's time (as advanced by `flow_control.summary`). While
/// the remembered sizes add up to `flow_control.max_inflight_bytes` or more, and the
/// flow control frontier is less than or equal to the output frontier, the operator
/// waits for flow control progress instead of emitting.
///
/// Arguments:
/// * `scope`: the scope to build the operator and the feedback loop in.
/// * `name`: used in the operator's name.
/// * `data`: the items to forward.
/// * `flow_control`: progress stream, byte limit, feedback summary and metrics.
/// * `chosen_worker`: the only worker on which the operator does anything; on every
///   other worker the operator's task returns immediately.
/// * `probe`: if provided, receives `(flow control frontier, number of in-flight parts
///   before retiring, number of parts retired)` each time flow control progress is
///   processed.
///
/// Returns the forwarded stream and a button that shuts the operator down when dropped.
pub fn backpressure<T, G, O>(
    scope: &mut G,
    name: &str,
    data: &Stream<G, O>,
    flow_control: FlowControl<G>,
    chosen_worker: usize,
    probe: Option<UnboundedSender<(Antichain<(T, Subtime)>, usize, usize)>>,
) -> (Stream<G, O>, PressOnDropButton)
where
    T: TimelyTimestamp + Lattice + Codec64 + TotalOrder,
    G: Scope<Timestamp = (T, Subtime)>,
    O: Backpressureable + std::fmt::Debug,
{
    let worker_index = scope.index();

    let (flow_control_stream, flow_control_max_bytes, metrics) = (
        flow_control.progress_stream,
        flow_control.max_inflight_bytes,
        flow_control.metrics,
    );

    // Close the loop: the flow control stream is fed back into this operator, with
    // its times advanced by the summary.
    let (handle, summaried_flow) = scope.feedback(flow_control.summary.clone());
    flow_control_stream.connect_loop(handle);

    let mut builder = AsyncOperatorBuilder::new(
        format!("persist_source_backpressure({})", name),
        scope.clone(),
    );
    let (data_output, data_stream) = builder.new_output();

    // Both inputs are disconnected from the output, so the output frontier is driven
    // only by the capabilities managed explicitly below.
    let mut data_input = builder.new_disconnected_input(data, Pipeline);
    let mut flow_control_input = builder.new_disconnected_input(&summaried_flow, Pipeline);

    // Assigns the next part number to `time` and returns:
    // * the resulting emission time,
    // * `frontier` with that emission time inserted (the frontier while the part is
    //   being emitted), and
    // * `frontier` with the same time at the following part number inserted (the
    //   frontier after the part has been emitted).
    fn synthesize_frontiers<T: PartialOrder + Clone>(
        mut frontier: Antichain<(T, Subtime)>,
        mut time: (T, Subtime),
        part_number: &mut u64,
    ) -> (
        (T, Subtime),
        Antichain<(T, Subtime)>,
        Antichain<(T, Subtime)>,
    ) {
        let mut next_frontier = frontier.clone();
        time.1 = Subtime(*part_number);
        frontier.insert(time.clone());
        *part_number += 1;
        let mut next_time = time.clone();
        next_time.1 = Subtime(*part_number);
        next_frontier.insert(next_time);
        (time, frontier, next_frontier)
    }

    // Turns the raw input into a stream of `Either::Right(part with frontiers)` and
    // `Either::Left(input frontier)` events.
    let data_input = async_stream::stream!({
        let mut part_number = 0;
        // Parts received but not yet released.
        let mut parts: Vec<((T, Subtime), O)> = Vec::new();
        loop {
            match data_input.next().await {
                None => {
                    // The input is done: release everything that is left, in time
                    // order, relative to the empty frontier.
                    let empty = Antichain::new();
                    parts.sort_by_key(|val| val.0.clone());
                    for (part_time, d) in parts.drain(..) {
                        let (part_time, frontier, next_frontier) = synthesize_frontiers(
                            empty.clone(),
                            part_time.clone(),
                            &mut part_number,
                        );
                        yield Either::Right((part_time, d, frontier, next_frontier))
                    }
                    break;
                }
                Some(Event::Data(time, data)) => {
                    for d in data {
                        parts.push((time.clone(), d));
                    }
                }
                Some(Event::Progress(prog)) => {
                    // Release, in time order, the parts whose time is no longer
                    // covered by the new input frontier, then report the frontier.
                    parts.sort_by_key(|val| val.0.clone());
                    for (part_time, d) in parts.extract_if(.., |p| !prog.less_equal(&p.0)) {
                        let (part_time, frontier, next_frontier) =
                            synthesize_frontiers(prog.clone(), part_time.clone(), &mut part_number);
                        yield Either::Right((part_time, d, frontier, next_frontier))
                    }
                    yield Either::Left(prog)
                }
            }
        }
    });
    let shutdown_button = builder.build(move |caps| async move {
        let mut cap_set = CapabilitySet::from_elem(caps.into_element());

        // The frontier this operator currently holds its output at.
        let mut output_frontier = Antichain::from_elem(TimelyTimestamp::minimum());
        // The most recently observed frontier of the flow control input.
        let mut flow_control_frontier = Antichain::from_elem(TimelyTimestamp::minimum());

        // `(retirement time, byte size)` of every emitted part that is not yet retired.
        let mut inflight_parts = Vec::new();
        // Parts taken from the input but held back because the byte limit was reached.
        let mut pending_parts = std::collections::VecDeque::new();

        if worker_index != chosen_worker {
            // Returning drops `cap_set`, and with it this worker's capabilities.
            trace!(
                "We are not the chosen worker ({}), exiting...",
                chosen_worker
            );
            return;
        }
        tokio::pin!(data_input);
        'emitting_parts: loop {
            let inflight_bytes: usize = inflight_parts.iter().map(|(_, size)| size).sum();

            // Emit if there is room under the byte limit, or if the flow control
            // frontier is not less than or equal to the output frontier.
            if inflight_bytes < flow_control_max_bytes
                || !PartialOrder::less_equal(&flow_control_frontier, &output_frontier)
            {
                // Held-back parts go first; otherwise pull the next event from the input.
                let (time, part, next_frontier) =
                    if let Some((time, part, next_frontier)) = pending_parts.pop_front() {
                        (time, part, next_frontier)
                    } else {
                        match data_input.next().await {
                            Some(Either::Right((time, part, frontier, next_frontier))) => {
                                output_frontier = frontier;
                                cap_set.downgrade(output_frontier.iter());

                                // The new output frontier may mean we are no longer
                                // allowed to emit: hold the part back and re-evaluate.
                                if inflight_bytes >= flow_control_max_bytes
                                    && !PartialOrder::less_than(
                                        &output_frontier,
                                        &flow_control_frontier,
                                    )
                                {
                                    pending_parts.push_back((time, part, next_frontier));
                                    continue 'emitting_parts;
                                }
                                (time, part, next_frontier)
                            }
                            Some(Either::Left(prog)) => {
                                // Pure progress: just follow the input frontier.
                                output_frontier = prog;
                                cap_set.downgrade(output_frontier.iter());
                                continue 'emitting_parts;
                            }
                            None => {
                                // NOTE(review): this arm is only reached after
                                // `pending_parts.pop_front()` returned `None` above,
                                // so `pending_parts` is empty here and the `continue`
                                // branch is never taken.
                                if pending_parts.is_empty() {
                                    break 'emitting_parts;
                                } else {
                                    continue 'emitting_parts;
                                }
                            }
                        }
                    };

                let byte_size = part.byte_size();
                // Track the part until the flow control frontier passes the time the
                // part's emission time maps to under the feedback summary. If the
                // summary yields no such time, the part is not tracked.
                if let Some(emission_ts) = flow_control.summary.results_in(&time) {
                    inflight_parts.push((emission_ts, byte_size));
                }

                data_output.give(&cap_set.delayed(&time), part);

                if let Some(metrics) = &metrics {
                    metrics.emitted_bytes.inc_by(u64::cast_from(byte_size))
                }

                // The part is out: move the output frontier past its emission time.
                output_frontier = next_frontier;
                cap_set.downgrade(output_frontier.iter())
            } else {
                // Backpressured: record how much is in flight and wait for the flow
                // control frontier to advance.
                if let Some(metrics) = &metrics {
                    metrics
                        .last_backpressured_bytes
                        .set(u64::cast_from(inflight_bytes))
                }
                let parts_count = inflight_parts.len();
                let new_flow_control_frontier = match flow_control_input.next().await {
                    Some(Event::Progress(frontier)) => frontier,
                    Some(Event::Data(_, _)) => {
                        unreachable!("flow_control_input should not contain data")
                    }
                    // A finished flow control input is treated as the empty frontier.
                    None => Antichain::new(),
                };

                flow_control_frontier.clone_from(&new_flow_control_frontier);

                // Retire every part whose retirement time the frontier has passed.
                let retired_parts = inflight_parts
                    .extract_if(.., |(ts, _size)| !flow_control_frontier.less_equal(ts));
                let (retired_size, retired_count): (usize, usize) = retired_parts
                    .fold((0, 0), |(accum_size, accum_count), (_ts, size)| {
                        (accum_size + size, accum_count + 1)
                    });
                trace!(
                    "returning {} parts with {} bytes, frontier: {:?}",
                    retired_count, retired_size, flow_control_frontier,
                );

                if let Some(metrics) = &metrics {
                    metrics.retired_bytes.inc_by(u64::cast_from(retired_size))
                }

                if let Some(probe) = probe.as_ref() {
                    // A failed send is deliberately ignored.
                    let _ = probe.send((new_flow_control_frontier, parts_count, retired_count));
                }
            }
        }
    });
    (data_stream, shutdown_button.press_on_drop())
}
989
990#[cfg(test)]
991mod tests {
992    use timely::container::CapacityContainerBuilder;
993    use timely::dataflow::operators::{Enter, Probe};
994    use tokio::sync::mpsc::unbounded_channel;
995    use tokio::sync::oneshot;
996
997    use super::*;
998
999    #[mz_ore::test]
1000    fn test_backpressure_non_granular() {
1001        use Step::*;
1002        backpressure_runner(
1003            vec![(50, Part(101)), (50, Part(102)), (100, Part(1))],
1004            100,
1005            (1, Subtime(0)),
1006            vec![
1007                AssertOutputFrontier((50, Subtime(2))),
1010                AssertBackpressured {
1011                    frontier: (1, Subtime(0)),
1012                    inflight_parts: 1,
1013                    retired_parts: 0,
1014                },
1015                AssertBackpressured {
1016                    frontier: (51, Subtime(0)),
1017                    inflight_parts: 1,
1018                    retired_parts: 0,
1019                },
1020                ProcessXParts(2),
1021                AssertBackpressured {
1022                    frontier: (101, Subtime(0)),
1023                    inflight_parts: 2,
1024                    retired_parts: 2,
1025                },
1026                AssertOutputFrontier((100, Subtime(3))),
1029            ],
1030            true,
1031        );
1032
1033        backpressure_runner(
1034            vec![
1035                (50, Part(10)),
1036                (50, Part(10)),
1037                (51, Part(100)),
1038                (52, Part(1000)),
1039            ],
1040            50,
1041            (1, Subtime(0)),
1042            vec![
1043                AssertOutputFrontier((51, Subtime(3))),
1045                AssertBackpressured {
1046                    frontier: (1, Subtime(0)),
1047                    inflight_parts: 3,
1048                    retired_parts: 0,
1049                },
1050                ProcessXParts(3),
1051                AssertBackpressured {
1052                    frontier: (52, Subtime(0)),
1053                    inflight_parts: 3,
1054                    retired_parts: 2,
1055                },
1056                AssertBackpressured {
1057                    frontier: (53, Subtime(0)),
1058                    inflight_parts: 1,
1059                    retired_parts: 1,
1060                },
1061                AssertOutputFrontier((52, Subtime(4))),
1064            ],
1065            true,
1066        );
1067
1068        backpressure_runner(
1069            vec![
1070                (50, Part(98)),
1071                (50, Part(1)),
1072                (51, Part(10)),
1073                (52, Part(100)),
1074                (52, Part(10)),
1076                (52, Part(10)),
1077                (52, Part(10)),
1078                (52, Part(100)),
1079                (100, Part(100)),
1081            ],
1082            100,
1083            (1, Subtime(0)),
1084            vec![
1085                AssertOutputFrontier((51, Subtime(3))),
1086                AssertBackpressured {
1090                    frontier: (1, Subtime(0)),
1091                    inflight_parts: 3,
1092                    retired_parts: 0,
1093                },
1094                AssertBackpressured {
1095                    frontier: (51, Subtime(0)),
1096                    inflight_parts: 3,
1097                    retired_parts: 0,
1098                },
1099                ProcessXParts(1),
1100                AssertOutputFrontier((51, Subtime(3))),
1103                ProcessXParts(1),
1107                AssertOutputFrontier((52, Subtime(4))),
1108                AssertBackpressured {
1109                    frontier: (52, Subtime(0)),
1110                    inflight_parts: 3,
1111                    retired_parts: 2,
1112                },
1113                ProcessXParts(1),
1117                AssertBackpressured {
1121                    frontier: (53, Subtime(0)),
1122                    inflight_parts: 2,
1123                    retired_parts: 1,
1124                },
1125                ProcessXParts(5),
1127                AssertBackpressured {
1128                    frontier: (101, Subtime(0)),
1129                    inflight_parts: 5,
1130                    retired_parts: 5,
1131                },
1132                AssertOutputFrontier((100, Subtime(9))),
1133            ],
1134            true,
1135        );
1136    }
1137
1138    #[mz_ore::test]
1139    fn test_backpressure_granular() {
1140        use Step::*;
1141        backpressure_runner(
1142            vec![(50, Part(101)), (50, Part(101))],
1143            100,
1144            (0, Subtime(1)),
1145            vec![
1146                AssertOutputFrontier((50, Subtime(1))),
1148                AssertBackpressured {
1151                    frontier: (0, Subtime(1)),
1152                    inflight_parts: 1,
1153                    retired_parts: 0,
1154                },
1155                AssertBackpressured {
1156                    frontier: (50, Subtime(1)),
1157                    inflight_parts: 1,
1158                    retired_parts: 0,
1159                },
1160                ProcessXParts(1),
1162                AssertBackpressured {
1164                    frontier: (50, Subtime(2)),
1165                    inflight_parts: 1,
1166                    retired_parts: 1,
1167                },
1168                AssertOutputFrontier((50, Subtime(2))),
1170            ],
1171            false,
1172        );
1173
1174        backpressure_runner(
1175            vec![
1176                (50, Part(10)),
1177                (50, Part(10)),
1178                (51, Part(35)),
1179                (52, Part(100)),
1180            ],
1181            50,
1182            (0, Subtime(1)),
1183            vec![
1184                AssertOutputFrontier((51, Subtime(3))),
1186                AssertBackpressured {
1187                    frontier: (0, Subtime(1)),
1188                    inflight_parts: 3,
1189                    retired_parts: 0,
1190                },
1191                AssertBackpressured {
1192                    frontier: (50, Subtime(1)),
1193                    inflight_parts: 3,
1194                    retired_parts: 0,
1195                },
1196                ProcessXParts(1),
1198                AssertBackpressured {
1199                    frontier: (50, Subtime(2)),
1200                    inflight_parts: 3,
1201                    retired_parts: 1,
1202                },
1203                AssertOutputFrontier((52, Subtime(4))),
1206                ProcessXParts(2),
1207                AssertBackpressured {
1208                    frontier: (52, Subtime(4)),
1209                    inflight_parts: 3,
1210                    retired_parts: 2,
1211                },
1212            ],
1213            false,
1214        );
1215    }
1216
1217    type Time = (u64, Subtime);
1218    #[derive(Clone, Debug)]
1219    struct Part(usize);
1220    impl Backpressureable for Part {
1221        fn byte_size(&self) -> usize {
1222            self.0
1223        }
1224    }
1225
1226    enum Step {
1228        AssertOutputFrontier(Time),
1231        AssertBackpressured {
1235            frontier: Time,
1236            inflight_parts: usize,
1237            retired_parts: usize,
1238        },
1239        ProcessXParts(usize),
1241    }
1242
1243    fn backpressure_runner(
1245        input: Vec<(u64, Part)>,
1247        max_inflight_bytes: usize,
1249        summary: Time,
1251        steps: Vec<Step>,
1253        non_granular_consumer: bool,
1256    ) {
1257        timely::execute::execute_directly(move |worker| {
1258            let (backpressure_probe, consumer_tx, mut backpressure_status_rx, finalizer_tx, _token) =
1259                worker.dataflow::<u64, _, _>(|scope| {
1261                    let (non_granular_feedback_handle, non_granular_feedback) =
1262                        if non_granular_consumer {
1263                            let (h, f) = scope.feedback(Default::default());
1264                            (Some(h), Some(f))
1265                        } else {
1266                            (None, None)
1267                        };
1268                    let (
1269                        backpressure_probe,
1270                        consumer_tx,
1271                        backpressure_status_rx,
1272                        token,
1273                        backpressured,
1274                        finalizer_tx,
1275                    ) = scope.scoped::<(u64, Subtime), _, _>("hybrid", |scope| {
1276                        let (input, finalizer_tx) =
1277                            iterator_operator(scope.clone(), input.into_iter());
1278
1279                        let (flow_control, granular_feedback_handle) = if non_granular_consumer {
1280                            (
1281                                FlowControl {
1282                                    progress_stream: non_granular_feedback.unwrap().enter(scope),
1283                                    max_inflight_bytes,
1284                                    summary,
1285                                    metrics: None
1286                                },
1287                                None,
1288                            )
1289                        } else {
1290                            let (granular_feedback_handle, granular_feedback) =
1291                                scope.feedback(Default::default());
1292                            (
1293                                FlowControl {
1294                                    progress_stream: granular_feedback,
1295                                    max_inflight_bytes,
1296                                    summary,
1297                                    metrics: None,
1298                                },
1299                                Some(granular_feedback_handle),
1300                            )
1301                        };
1302
1303                        let (backpressure_status_tx, backpressure_status_rx) = unbounded_channel();
1304
1305                        let (backpressured, token) = backpressure(
1306                            scope,
1307                            "test",
1308                            &input,
1309                            flow_control,
1310                            0,
1311                            Some(backpressure_status_tx),
1312                        );
1313
1314                        let tx = if !non_granular_consumer {
1316                            Some(consumer_operator(
1317                                scope.clone(),
1318                                &backpressured,
1319                                granular_feedback_handle.unwrap(),
1320                            ))
1321                        } else {
1322                            None
1323                        };
1324
1325                        (
1326                            backpressured.probe(),
1327                            tx,
1328                            backpressure_status_rx,
1329                            token,
1330                            backpressured.leave(),
1331                            finalizer_tx,
1332                        )
1333                    });
1334
1335                    let consumer_tx = if non_granular_consumer {
1337                        consumer_operator(
1338                            scope.clone(),
1339                            &backpressured,
1340                            non_granular_feedback_handle.unwrap(),
1341                        )
1342                    } else {
1343                        consumer_tx.unwrap()
1344                    };
1345
1346                    (
1347                        backpressure_probe,
1348                        consumer_tx,
1349                        backpressure_status_rx,
1350                        finalizer_tx,
1351                        token,
1352                    )
1353                });
1354
1355            use Step::*;
1356            for step in steps {
1357                match step {
1358                    AssertOutputFrontier(time) => {
1359                        eprintln!("checking advance to {time:?}");
1360                        backpressure_probe.with_frontier(|front| {
1361                            eprintln!("current backpressure output frontier: {front:?}");
1362                        });
1363                        while backpressure_probe.less_than(&time) {
1364                            worker.step();
1365                            backpressure_probe.with_frontier(|front| {
1366                                eprintln!("current backpressure output frontier: {front:?}");
1367                            });
1368                            std::thread::sleep(std::time::Duration::from_millis(25));
1369                        }
1370                    }
1371                    ProcessXParts(parts) => {
1372                        eprintln!("processing {parts:?} parts");
1373                        for _ in 0..parts {
1374                            consumer_tx.send(()).unwrap();
1375                        }
1376                    }
1377                    AssertBackpressured {
1378                        frontier,
1379                        inflight_parts,
1380                        retired_parts,
1381                    } => {
1382                        let frontier = Antichain::from_elem(frontier);
1383                        eprintln!(
1384                            "asserting backpressured at {frontier:?}, with {inflight_parts:?} inflight parts \
1385                            and {retired_parts:?} retired"
1386                        );
1387                        let (new_frontier, new_count, new_retired_count) = loop {
1388                            if let Ok(val) = backpressure_status_rx.try_recv() {
1389                                break val;
1390                            }
1391                            worker.step();
1392                            std::thread::sleep(std::time::Duration::from_millis(25));
1393                        };
1394                        assert_eq!(
1395                            (frontier, inflight_parts, retired_parts),
1396                            (new_frontier, new_count, new_retired_count)
1397                        );
1398                    }
1399                }
1400            }
1401            let _ = finalizer_tx.send(());
1403        });
1404    }
1405
1406    fn iterator_operator<
1409        G: Scope<Timestamp = (u64, Subtime)>,
1410        I: Iterator<Item = (u64, Part)> + 'static,
1411    >(
1412        scope: G,
1413        mut input: I,
1414    ) -> (Stream<G, Part>, oneshot::Sender<()>) {
1415        let (finalizer_tx, finalizer_rx) = oneshot::channel();
1416        let mut iterator = AsyncOperatorBuilder::new("iterator".to_string(), scope);
1417        let (output_handle, output) = iterator.new_output::<CapacityContainerBuilder<Vec<Part>>>();
1418
1419        iterator.build(|mut caps| async move {
1420            let mut capability = Some(caps.pop().unwrap());
1421            let mut last = None;
1422            while let Some(element) = input.next() {
1423                let time = element.0.clone();
1424                let part = element.1;
1425                last = Some((time, Subtime(0)));
1426                output_handle.give(&capability.as_ref().unwrap().delayed(&last.unwrap()), part);
1427            }
1428            if let Some(last) = last {
1429                capability
1430                    .as_mut()
1431                    .unwrap()
1432                    .downgrade(&(last.0 + 1, last.1));
1433            }
1434
1435            let _ = finalizer_rx.await;
1436            capability.take();
1437        });
1438
1439        (output, finalizer_tx)
1440    }
1441
1442    fn consumer_operator<G: Scope, O: Backpressureable + std::fmt::Debug>(
1446        scope: G,
1447        input: &Stream<G, O>,
1448        feedback: timely::dataflow::operators::feedback::Handle<G, Vec<std::convert::Infallible>>,
1449    ) -> UnboundedSender<()> {
1450        let (tx, mut rx) = unbounded_channel::<()>();
1451        let mut consumer = AsyncOperatorBuilder::new("consumer".to_string(), scope);
1452        let (output_handle, output) =
1453            consumer.new_output::<CapacityContainerBuilder<Vec<std::convert::Infallible>>>();
1454        let mut input = consumer.new_input_for(input, Pipeline, &output_handle);
1455
1456        consumer.build(|_caps| async move {
1457            while let Some(()) = rx.recv().await {
1458                while let Some(Event::Progress(_)) = input.next().await {}
1460            }
1461        });
1462        output.connect_loop(feedback);
1463
1464        tx
1465    }
1466}