mz_transform/redundant_join.rs
1// Copyright Materialize, Inc. and contributors. All rights reserved.
2//
3// Use of this software is governed by the Business Source License
4// included in the LICENSE file.
5//
6// As of the Change Date specified in that file, in accordance with
7// the Business Source License, use of this software will be governed
8// by the Apache License, Version 2.0.
9
10//! Remove redundant collections of distinct elements from joins.
11//!
12//! This analysis looks for joins in which one collection contains distinct
13//! elements, and it can be determined that the join would only restrict the
14//! results, and that the restriction is redundant (the other results would
15//! not be reduced by the join).
16//!
17//! This type of optimization shows up often in subqueries, where distinct
18//! collections are used in decorrelation, and afterwards often distinct
19//! collections are then joined against the results.
20
21// If statements seem a bit clearer in this case. Specialized methods
22// that replace simple and common alternatives frustrate developers.
23#![allow(clippy::comparison_chain, clippy::filter_next)]
24
25use std::collections::BTreeMap;
26
27use itertools::Itertools;
28use mz_expr::visit::Visit;
29use mz_expr::{Id, JoinInputMapper, LocalId, MirRelationExpr, MirScalarExpr, RECURSION_LIMIT};
30use mz_ore::stack::{CheckedRecursion, RecursionGuard};
31use mz_ore::{assert_none, soft_panic_or_log};
32
33use crate::{TransformCtx, all};
34
/// Remove redundant collections of distinct elements from joins.
///
/// The transform itself is stateless apart from the guard that is consulted
/// on every recursive step of the plan traversal.
#[derive(Debug)]
pub struct RedundantJoin {
    /// Guard handed out through the `CheckedRecursion` implementation; it is
    /// created with `RECURSION_LIMIT` by `Default`.
    recursion_guard: RecursionGuard,
}
40
41impl Default for RedundantJoin {
42 fn default() -> RedundantJoin {
43 RedundantJoin {
44 recursion_guard: RecursionGuard::with_limit(RECURSION_LIMIT),
45 }
46 }
47}
48
49impl CheckedRecursion for RedundantJoin {
50 fn recursion_guard(&self) -> &RecursionGuard {
51 &self.recursion_guard
52 }
53}
54
55impl crate::Transform for RedundantJoin {
56 fn name(&self) -> &'static str {
57 "RedundantJoin"
58 }
59
60 #[mz_ore::instrument(
61 target = "optimizer",
62 level = "debug",
63 fields(path.segment = "redundant_join")
64 )]
65 fn actually_perform_transform(
66 &self,
67 relation: &mut MirRelationExpr,
68 _: &mut TransformCtx,
69 ) -> Result<(), crate::TransformError> {
70 let mut ctx = ProvInfoCtx::default();
71 ctx.extend_uses(relation);
72 let result = self.action(relation, &mut ctx);
73 mz_repr::explain::trace_plan(&*relation);
74 result.map(|_| ())
75 }
76}
77
impl RedundantJoin {
    /// Remove redundant collections of distinct elements from joins.
    ///
    /// This method tracks "provenance" information for each collection,
    /// those being column-wise relationships to identified collections
    /// (either imported collections, or let-bound collections). These
    /// relationships state that when projected on to these columns, the
    /// records of the one collection are contained in the records of the
    /// identified collection.
    ///
    /// This provenance information is then used for the `MirRelationExpr::Join`
    /// variant to remove "redundant" joins, those that can be determined to
    /// neither restrict nor augment one of the input relations. Consult the
    /// `find_redundancy` method and its documentation for more detail.
    ///
    /// `relation` is rewritten in place. `ctx` carries the provenance of the
    /// let bindings in scope together with the number of remaining references
    /// to each of them.
    ///
    /// Returns the provenance derived for `relation`, with trivial entries
    /// (see `ProvInfo::is_trivial`) filtered out. Errors raised by the
    /// recursion guard or by the scalar expression visitor are propagated.
    pub fn action(
        &self,
        relation: &mut MirRelationExpr,
        ctx: &mut ProvInfoCtx,
    ) -> Result<Vec<ProvInfo>, crate::TransformError> {
        let mut result = self.checked_recur(|_| {
            match relation {
                MirRelationExpr::Let { id, value, body } => {
                    // Recursively determine provenance of the value.
                    let value_prov = self.action(value, ctx)?;
                    // Clear uses from the just visited binding definition.
                    ctx.remove_uses(value);

                    // Extend the lets context with an entry for this binding.
                    let prov_old = ctx.insert(*id, value_prov);
                    assert_none!(prov_old, "No shadowing");

                    // Determine provenance of the body.
                    let result = self.action(body, ctx)?;
                    ctx.remove_uses(body);

                    // Remove the lets entry for this binding from the context.
                    ctx.remove(id);

                    Ok(result)
                }

                MirRelationExpr::LetRec {
                    ids,
                    values,
                    limits: _,
                    body,
                } => {
                    // As a first approximation, we naively extend the `lets`
                    // context with the empty vec![] for each id.
                    for id in ids.iter() {
                        let prov_old = ctx.insert(*id, vec![]);
                        assert_none!(prov_old, "No shadowing");
                    }

                    // In other words, we don't attempt to derive additional
                    // provenance information for a binding from its `value`.
                    //
                    // We descend into the values and the body with the naively
                    // extended context.
                    for value in values.iter_mut() {
                        self.action(value, ctx)?;
                    }
                    // Clear uses from the just visited recursive binding
                    // definitions.
                    for value in values.iter_mut() {
                        ctx.remove_uses(value);
                    }
                    let result = self.action(body, ctx)?;
                    ctx.remove_uses(body);

                    // Remove the lets entries for all ids.
                    for id in ids.iter() {
                        ctx.remove(id);
                    }

                    Ok(result)
                }

                MirRelationExpr::Get { id, typ, .. } => {
                    if let Id::Local(id) = id {
                        // Extract the value provenance (this should always exist).
                        let mut val_info = ctx.get(id).cloned().unwrap_or_else(|| {
                            soft_panic_or_log!("no ctx entry for LocalId {id}");
                            vec![]
                        });
                        // Add information about being exactly this let binding too.
                        val_info.push(ProvInfo::make_leaf(Id::Local(*id), typ.arity()));
                        Ok(val_info)
                    } else {
                        // Add information about being exactly this GlobalId reference.
                        Ok(vec![ProvInfo::make_leaf(*id, typ.arity())])
                    }
                }

                MirRelationExpr::Join {
                    inputs,
                    equivalences,
                    implementation,
                } => {
                    // This logic first applies what it has learned about its input provenance,
                    // and if it finds a redundant join input it removes it. In that case, it
                    // also fails to produce exciting provenance information, partly out of
                    // laziness and the challenge of ensuring it is correct. Instead, if it is
                    // unable to find a redundant join it produces meaningful provenance information.

                    // Recursively apply transformation, and determine the provenance of inputs.
                    let mut input_prov = Vec::new();
                    for i in inputs.iter_mut() {
                        input_prov.push(self.action(i, ctx)?);
                    }

                    // Determine useful information about the structure of the inputs.
                    let mut input_types = inputs.iter().map(|i| i.typ()).collect::<Vec<_>>();
                    let old_input_mapper = JoinInputMapper::new_from_input_types(&input_types);

                    // If we find an input that can be removed, we should do so!
                    // We only do this once per invocation to keep our sanity, but we could
                    // rewrite it to iterate. We can avoid looking for any relation that
                    // does not have keys, as it cannot be redundant in that case.
                    //
                    // Candidates are examined from the last input to the first, and the
                    // first candidate for which bindings can be found is taken.
                    if let Some((remove_input_idx, mut bindings)) = (0..input_types.len())
                        .rev()
                        .filter(|i| !input_types[*i].keys.is_empty())
                        .flat_map(|i| {
                            find_redundancy(
                                i,
                                &input_types[i].keys,
                                &old_input_mapper,
                                equivalences,
                                &input_prov[..],
                            )
                            .map(|b| (i, b))
                        })
                        .next()
                    {
                        // Clear uses from the removed input.
                        ctx.remove_uses(&inputs[remove_input_idx]);

                        inputs.remove(remove_input_idx);
                        input_types.remove(remove_input_idx);

                        // Update the column offsets in the binding expressions to catch
                        // up with the removal of `remove_input_idx`.
                        for expr in bindings.iter_mut() {
                            expr.visit_pre_mut(|e| {
                                if let MirScalarExpr::Column(c, _) = e {
                                    let (_local_col, input_relation) =
                                        old_input_mapper.map_column_to_local(*c);
                                    if input_relation > remove_input_idx {
                                        *c -= old_input_mapper.input_arity(remove_input_idx);
                                    }
                                }
                            });
                        }

                        // Replace column references from `remove_input_idx` with the corresponding
                        // binding expression. Update the offsets of the column references
                        // from inputs after `remove_input_idx`.
                        for equivalence in equivalences.iter_mut() {
                            for expr in equivalence.iter_mut() {
                                expr.visit_mut_post(&mut |e| {
                                    if let MirScalarExpr::Column(c, _) = e {
                                        let (local_col, input_relation) =
                                            old_input_mapper.map_column_to_local(*c);
                                        if input_relation == remove_input_idx {
                                            *e = bindings[local_col].clone();
                                        } else if input_relation > remove_input_idx {
                                            *c -= old_input_mapper.input_arity(remove_input_idx);
                                        }
                                    }
                                })?;
                            }
                        }

                        mz_expr::canonicalize::canonicalize_equivalences(
                            equivalences,
                            input_types.iter().map(|t| &t.column_types),
                        );

                        // Build a projection that leaves the binding expressions in the same
                        // position as the columns of the removed join input they are replacing.
                        // The bindings are appended by the `map` below, so they occupy the
                        // columns directly after those of the reduced join.
                        let new_input_mapper = JoinInputMapper::new_from_input_types(&input_types);
                        let mut projection = Vec::new();
                        let new_join_arity = new_input_mapper.total_columns();
                        for i in 0..old_input_mapper.total_inputs() {
                            if i != remove_input_idx {
                                projection.extend(
                                    new_input_mapper.global_columns(if i < remove_input_idx {
                                        i
                                    } else {
                                        i - 1
                                    }),
                                );
                            } else {
                                projection.extend(new_join_arity..new_join_arity + bindings.len());
                            }
                        }

                        // Unset implementation, as irrevocably hosed by this transformation.
                        *implementation = mz_expr::JoinImplementation::Unimplemented;

                        *relation = relation.take_dangerous().map(bindings).project(projection);
                        // The projection will gum up provenance reasoning anyhow, so don't work hard.
                        // We will return to this expression again with the same analysis.
                        Ok(Vec::new())
                    } else {
                        // Provenance information should be the union of input provenance information,
                        // with columns updated. Because rows may be dropped in the join, all `exact`
                        // bits should be un-set.
                        let mut results = Vec::new();
                        for (input, input_prov) in input_prov.into_iter().enumerate() {
                            for mut prov in input_prov {
                                prov.exact = false;
                                // Widen the projection to the full join arity: only the columns
                                // contributed by `input` keep their bindings.
                                let mut projection = vec![None; old_input_mapper.total_columns()];
                                for (local_col, global_col) in
                                    old_input_mapper.global_columns(input).enumerate()
                                {
                                    projection[global_col]
                                        .clone_from(&prov.dereferenced_projection[local_col]);
                                }
                                prov.dereferenced_projection = projection;
                                results.push(prov);
                            }
                        }
                        Ok(results)
                    }
                }

                MirRelationExpr::Filter { input, .. } => {
                    // Filter may drop records, and so we unset `exact`.
                    let mut result = self.action(input, ctx)?;
                    for prov in result.iter_mut() {
                        prov.exact = false;
                    }
                    Ok(result)
                }

                MirRelationExpr::Map { input, scalars } => {
                    // Each mapped scalar appends one column; it keeps a binding only
                    // if it can be expressed over the source and depends on it.
                    let mut result = self.action(input, ctx)?;
                    for prov in result.iter_mut() {
                        for scalar in scalars.iter() {
                            let dereferenced_scalar = prov.strict_dereference(scalar);
                            prov.dereferenced_projection.push(dereferenced_scalar);
                        }
                    }
                    Ok(result)
                }

                MirRelationExpr::Union { base, inputs } => {
                    let mut prov = self.action(base, ctx)?;
                    for input in inputs {
                        let input_prov = self.action(input, ctx)?;
                        // To merge a new list of provenances, we look at the cross
                        // product of things we might know about each source.
                        // TODO(mcsherry): this can be optimized to use datastructures
                        // keyed by the source identifier.
                        let mut new_prov = Vec::new();
                        for l in prov {
                            new_prov.extend(input_prov.iter().flat_map(|r| l.meet(r)))
                        }
                        prov = new_prov;
                    }
                    Ok(prov)
                }

                MirRelationExpr::Constant { .. } => Ok(Vec::new()),

                MirRelationExpr::Reduce {
                    input,
                    group_key,
                    aggregates,
                    ..
                } => {
                    // Reduce yields its first few columns as a key, and produces
                    // all key tuples that were present in its input.
                    let mut result = self.action(input, ctx)?;
                    for prov in result.iter_mut() {
                        let mut projection = group_key
                            .iter()
                            .map(|key| prov.strict_dereference(key))
                            .collect_vec();
                        // Aggregate columns carry no provenance.
                        projection.extend((0..aggregates.len()).map(|_| None));
                        prov.dereferenced_projection = projection;
                    }
                    // TODO: For min, max aggregates, we could preserve provenance
                    // if the expression references a column. We would need to un-set
                    // the `exact` bit in that case, and so we would want to keep both
                    // sets of provenance information.
                    Ok(result)
                }

                MirRelationExpr::Threshold { input } => {
                    // Threshold may drop records, and so we unset `exact`.
                    let mut result = self.action(input, ctx)?;
                    for prov in result.iter_mut() {
                        prov.exact = false;
                    }
                    Ok(result)
                }

                MirRelationExpr::TopK { input, .. } => {
                    // TopK may drop records, and so we unset `exact`.
                    let mut result = self.action(input, ctx)?;
                    for prov in result.iter_mut() {
                        prov.exact = false;
                    }
                    Ok(result)
                }

                MirRelationExpr::Project { input, outputs } => {
                    // Projections re-order, drop, and duplicate columns,
                    // but they neither drop rows nor invent values.
                    let mut result = self.action(input, ctx)?;
                    for prov in result.iter_mut() {
                        let projection = outputs
                            .iter()
                            .map(|c| prov.dereference(&MirScalarExpr::column(*c)))
                            .collect_vec();
                        prov.dereferenced_projection = projection;
                    }
                    Ok(result)
                }

                MirRelationExpr::FlatMap {
                    input,
                    func,
                    exprs: _,
                } => {
                    // FlatMap may drop records, and so we unset `exact`.
                    let mut result = self.action(input, ctx)?;
                    for prov in result.iter_mut() {
                        prov.exact = false;
                        // The columns produced by the table function carry no provenance.
                        prov.dereferenced_projection
                            .extend((0..func.output_arity()).map(|_| None));
                    }
                    Ok(result)
                }

                MirRelationExpr::Negate { input } => {
                    // Negate does not guarantee that the multiplicity of
                    // each source record is at least one. This could have
                    // been a problem in `Union`, where we might report
                    // that the union of positive and negative records is
                    // "exact": cancellations would make this false.
                    let mut result = self.action(input, ctx)?;
                    for prov in result.iter_mut() {
                        prov.exact = false;
                    }
                    Ok(result)
                }

                // An arrangement passes the provenance of its input through unchanged.
                MirRelationExpr::ArrangeBy { input, .. } => self.action(input, ctx),
            }
        })?;
        // Entries without a single bound column carry no information; drop them.
        result.retain(|info| !info.is_trivial());

        // Uncomment the following lines to trace the individual steps:
        // println!("{}", relation.pretty());
        // println!("result = {result:?}");
        // println!("ctx: {ctx:?}");
        // println!("---------------------");

        Ok(result)
    }
}
442
/// A relationship between a collection's columns and some source columns.
///
/// An instance of this type indicates that some of the bearer's columns
/// derive from `id`. In particular, the non-`None` elements in
/// `dereferenced_projection` correspond to columns that can be derived
/// from `id`'s projection.
///
/// The guarantee is that projected on to these columns, the distinct values
/// of the bearer are contained in the set of distinct values of projected
/// columns of `id`. In the case that `exact` is set, the two sets are equal.
#[derive(Clone, Debug, Ord, Eq, PartialOrd, PartialEq)]
pub struct ProvInfo {
    /// The Id (local or global) of the source.
    id: Id,
    /// The projection of the bearer written in terms of the columns projected
    /// by the underlying Get operator. Set to `None` for columns that cannot
    /// be expressed as a scalar expression referencing only columns of the
    /// underlying Get operator. There is one entry per column of the bearer.
    dereferenced_projection: Vec<Option<MirScalarExpr>>,
    /// If true, all distinct projected source rows are present in the rows of
    /// the projection of the current collection. This constraint is lost as soon
    /// as a transformation may drop records.
    exact: bool,
}
467
468impl ProvInfo {
469 fn make_leaf(id: Id, arity: usize) -> Self {
470 Self {
471 id,
472 dereferenced_projection: (0..arity)
473 .map(|c| Some(MirScalarExpr::column(c)))
474 .collect::<Vec<_>>(),
475 exact: true,
476 }
477 }
478
479 /// Rewrite `expr` so it refers to the columns of the original source instead
480 /// of the columns of the projected source.
481 fn dereference(&self, expr: &MirScalarExpr) -> Option<MirScalarExpr> {
482 match expr {
483 MirScalarExpr::Column(c, _) => {
484 if let Some(expr) = &self.dereferenced_projection[*c] {
485 Some(expr.clone())
486 } else {
487 None
488 }
489 }
490 MirScalarExpr::CallUnary { func, expr } => self.dereference(expr).and_then(|expr| {
491 Some(MirScalarExpr::CallUnary {
492 func: func.clone(),
493 expr: Box::new(expr),
494 })
495 }),
496 MirScalarExpr::CallBinary { func, expr1, expr2 } => {
497 self.dereference(expr1).and_then(|expr1| {
498 self.dereference(expr2).and_then(|expr2| {
499 Some(MirScalarExpr::CallBinary {
500 func: func.clone(),
501 expr1: Box::new(expr1),
502 expr2: Box::new(expr2),
503 })
504 })
505 })
506 }
507 MirScalarExpr::CallVariadic { func, exprs } => {
508 let new_exprs = exprs.iter().flat_map(|e| self.dereference(e)).collect_vec();
509 if new_exprs.len() == exprs.len() {
510 Some(MirScalarExpr::CallVariadic {
511 func: func.clone(),
512 exprs: new_exprs,
513 })
514 } else {
515 None
516 }
517 }
518 MirScalarExpr::Literal(..) | MirScalarExpr::CallUnmaterializable(..) => {
519 Some(expr.clone())
520 }
521 MirScalarExpr::If { cond, then, els } => self.dereference(cond).and_then(|cond| {
522 self.dereference(then).and_then(|then| {
523 self.dereference(els).and_then(|els| {
524 Some(MirScalarExpr::If {
525 cond: Box::new(cond),
526 then: Box::new(then),
527 els: Box::new(els),
528 })
529 })
530 })
531 }),
532 }
533 }
534
535 /// Like `dereference` but only returns expressions that actually depend on
536 /// the original source.
537 fn strict_dereference(&self, expr: &MirScalarExpr) -> Option<MirScalarExpr> {
538 let derefed = self.dereference(expr);
539 match derefed {
540 Some(ref expr) if !expr.support().is_empty() => derefed,
541 _ => None,
542 }
543 }
544
545 /// Merge two constraints to find a constraint that satisfies both inputs.
546 ///
547 /// This method returns nothing if no columns are in common (either because
548 /// difference sources are identified, or just no columns in common) and it
549 /// intersects bindings and the `exact` bit.
550 fn meet(&self, other: &Self) -> Option<Self> {
551 if self.id == other.id {
552 let resulting_projection = self
553 .dereferenced_projection
554 .iter()
555 .zip(other.dereferenced_projection.iter())
556 .map(|(e1, e2)| if e1 == e2 { e1.clone() } else { None })
557 .collect_vec();
558 if resulting_projection.iter().any(|e| e.is_some()) {
559 Some(ProvInfo {
560 id: self.id,
561 dereferenced_projection: resulting_projection,
562 exact: self.exact && other.exact,
563 })
564 } else {
565 None
566 }
567 } else {
568 None
569 }
570 }
571
572 /// Check if all entries of the dereferenced projection are missing.
573 ///
574 /// If this is the case keeping the `ProvInfo` entry around is meaningless.
575 fn is_trivial(&self) -> bool {
576 all![
577 !self.dereferenced_projection.is_empty(),
578 self.dereferenced_projection.iter().all(|x| x.is_none()),
579 ]
580 }
581}
582
583/// Attempts to find column bindings that make `input` redundant.
584///
585/// This method attempts to determine that `input` may be redundant by searching
586/// the join structure for another relation `other` with provenance that contains some
587/// provenance of `input`, and keys for `input` that are equated by the join to the
588/// corresponding columns of `other` under their provenance. The `input` provenance
589/// must also have its `exact` bit set.
590///
591/// In these circumstances, the claim is that because the key columns are equated and
592/// determine non-key columns, any matches between `input` and
593/// `other` will neither introduce new information to `other`, nor restrict the rows
594/// of `other`, nor alter their multplicity.
595fn find_redundancy(
596 input: usize,
597 keys: &[Vec<usize>],
598 input_mapper: &JoinInputMapper,
599 equivalences: &[Vec<MirScalarExpr>],
600 input_provs: &[Vec<ProvInfo>],
601) -> Option<Vec<MirScalarExpr>> {
602 // Whether the `equivalence` contains an expression that only references
603 // `input` that leads to the same as `root_expr` once dereferenced.
604 let contains_equivalent_expr_from_input = |equivalence: &[MirScalarExpr],
605 root_expr: &MirScalarExpr,
606 input: usize,
607 provenance: &ProvInfo|
608 -> bool {
609 equivalence.iter().any(|expr| {
610 Some(input) == input_mapper.single_input(expr)
611 && provenance
612 .dereference(&input_mapper.map_expr_to_local(expr.clone()))
613 .as_ref()
614 == Some(root_expr)
615 })
616 };
617 for input_prov in input_provs[input].iter() {
618 // We can only elide if the input contains all records, and binds all columns.
619 if input_prov.exact
620 && input_prov
621 .dereferenced_projection
622 .iter()
623 .all(|e| e.is_some())
624 {
625 // examine all *other* inputs that have not been removed...
626 for other in (0..input_mapper.total_inputs()).filter(|other| other != &input) {
627 for other_prov in input_provs[other].iter().filter(|p| p.id == input_prov.id) {
628 let all_key_columns_equated = |key: &Vec<usize>| {
629 key.iter().all(|key_col| {
630 // The root expression behind the key column, ie.
631 // the expression re-written in terms of elements in
632 // the projection of the Get operator.
633 let root_expr =
634 input_prov.dereference(&MirScalarExpr::column(*key_col));
635 // Check if there is a join equivalence that joins
636 // 'input' and 'other' on expressions that lead to
637 // the same root expression as the key column.
638 root_expr.as_ref().map_or(false, |root_expr| {
639 equivalences.iter().any(|equivalence| {
640 all![
641 contains_equivalent_expr_from_input(
642 equivalence,
643 root_expr,
644 input,
645 input_prov,
646 ),
647 contains_equivalent_expr_from_input(
648 equivalence,
649 root_expr,
650 other,
651 other_prov,
652 ),
653 ]
654 })
655 })
656 })
657 };
658
659 // Find an unique key for input that has all columns equated to other.
660 if keys.iter().any(all_key_columns_equated) {
661 // Find out whether we can produce input's projection strictly with
662 // elements in other's projection.
663 let expressions = input_prov
664 .dereferenced_projection
665 .iter()
666 .enumerate()
667 .flat_map(|(c, _)| {
668 // Check if the expression under input's 'c' column can be built
669 // with elements in other's projection.
670 input_prov.dereferenced_projection[c].as_ref().map_or(
671 None,
672 |root_expr| {
673 try_build_expression_using_other(
674 root_expr,
675 other,
676 other_prov,
677 input_mapper,
678 )
679 },
680 )
681 })
682 .collect_vec();
683 if expressions.len() == input_prov.dereferenced_projection.len() {
684 return Some(expressions);
685 }
686 }
687 }
688 }
689 }
690 }
691
692 None
693}
694
695/// Tries to build `root_expr` using elements from other's projection.
696fn try_build_expression_using_other(
697 root_expr: &MirScalarExpr,
698 other: usize,
699 other_prov: &ProvInfo,
700 input_mapper: &JoinInputMapper,
701) -> Option<MirScalarExpr> {
702 if root_expr.is_literal() {
703 return Some(root_expr.clone());
704 }
705
706 // Check if 'other' projects a column that lead to `root_expr`.
707 for (other_col, derefed) in other_prov.dereferenced_projection.iter().enumerate() {
708 if let Some(derefed) = derefed {
709 if derefed == root_expr {
710 return Some(MirScalarExpr::column(
711 input_mapper.map_column_to_global(other_col, other),
712 ));
713 }
714 }
715 }
716
717 // Otherwise, try to build root_expr's sub-expressions recursively
718 // other's projection.
719 match root_expr {
720 MirScalarExpr::Column(_, _) => None,
721 MirScalarExpr::CallUnary { func, expr } => {
722 try_build_expression_using_other(expr, other, other_prov, input_mapper).and_then(
723 |expr| {
724 Some(MirScalarExpr::CallUnary {
725 func: func.clone(),
726 expr: Box::new(expr),
727 })
728 },
729 )
730 }
731 MirScalarExpr::CallBinary { func, expr1, expr2 } => {
732 try_build_expression_using_other(expr1, other, other_prov, input_mapper).and_then(
733 |expr1| {
734 try_build_expression_using_other(expr2, other, other_prov, input_mapper)
735 .and_then(|expr2| {
736 Some(MirScalarExpr::CallBinary {
737 func: func.clone(),
738 expr1: Box::new(expr1),
739 expr2: Box::new(expr2),
740 })
741 })
742 },
743 )
744 }
745 MirScalarExpr::CallVariadic { func, exprs } => {
746 let new_exprs = exprs
747 .iter()
748 .flat_map(|e| try_build_expression_using_other(e, other, other_prov, input_mapper))
749 .collect_vec();
750 if new_exprs.len() == exprs.len() {
751 Some(MirScalarExpr::CallVariadic {
752 func: func.clone(),
753 exprs: new_exprs,
754 })
755 } else {
756 None
757 }
758 }
759 MirScalarExpr::Literal(..) | MirScalarExpr::CallUnmaterializable(..) => {
760 Some(root_expr.clone())
761 }
762 MirScalarExpr::If { cond, then, els } => {
763 try_build_expression_using_other(cond, other, other_prov, input_mapper).and_then(
764 |cond| {
765 try_build_expression_using_other(then, other, other_prov, input_mapper)
766 .and_then(|then| {
767 try_build_expression_using_other(els, other, other_prov, input_mapper)
768 .and_then(|els| {
769 Some(MirScalarExpr::If {
770 cond: Box::new(cond),
771 then: Box::new(then),
772 els: Box::new(els),
773 })
774 })
775 })
776 },
777 )
778 }
779 }
780}
781
/// A context of `ProvInfo` vectors associated with bindings that might still be
/// referenced.
#[derive(Debug, Default)]
pub struct ProvInfoCtx {
    /// Number of [`LocalId`] references in the remaining subtree, per id.
    ///
    /// Entries from the `lets` map that are no longer used can be pruned.
    uses: BTreeMap<LocalId, usize>,
    /// [`ProvInfo`] vectors associated with the let bindings in scope.
    lets: BTreeMap<LocalId, Vec<ProvInfo>>,
}
793
794impl ProvInfoCtx {
795 /// Extend the `uses` map by the `LocalId`s used in `expr`.
796 pub fn extend_uses(&mut self, expr: &MirRelationExpr) {
797 expr.visit_pre(&mut |expr: &MirRelationExpr| match expr {
798 MirRelationExpr::Get {
799 id: Id::Local(id), ..
800 } => {
801 let count = self.uses.entry(id.clone()).or_insert(0_usize);
802 *count += 1;
803 }
804 _ => (),
805 });
806 }
807
808 /// Decrement `uses` entries by the `LocalId`s used in `expr` and remove
809 /// `lets` entries for `uses` that reset to zero.
810 pub fn remove_uses(&mut self, expr: &MirRelationExpr) {
811 let mut worklist = vec![expr];
812 while let Some(expr) = worklist.pop() {
813 if let MirRelationExpr::Get {
814 id: Id::Local(id), ..
815 } = expr
816 {
817 if let Some(count) = self.uses.get_mut(id) {
818 if *count > 0 {
819 *count -= 1;
820 }
821 if *count == 0 {
822 if self.lets.remove(id).is_none() {
823 soft_panic_or_log!("ctx.lets[{id}] should exist");
824 }
825 }
826 } else {
827 soft_panic_or_log!("ctx.uses[{id}] should exist");
828 }
829 }
830 match expr {
831 MirRelationExpr::Let { .. } | MirRelationExpr::LetRec { .. } => {
832 // When traversing the tree, don't descend into
833 // `Let`/`LetRec` sub-terms in order to avoid double
834 // counting (those are handled by remove_uses calls of
835 // RedundantJoin::action on subterms that were already
836 // visited because the action works bottom-up).
837 }
838 _ => {
839 worklist.extend(expr.children().rev());
840 }
841 }
842 }
843 }
844
845 /// Get the `ProvInfo` vector for `id` from the context.
846 pub fn get(&self, id: &LocalId) -> Option<&Vec<ProvInfo>> {
847 self.lets.get(id)
848 }
849
850 /// Extend the context with the `id: prov_infos` entry.
851 pub fn insert(&mut self, id: LocalId, prov_infos: Vec<ProvInfo>) -> Option<Vec<ProvInfo>> {
852 self.lets.insert(id, prov_infos)
853 }
854
855 /// Remove the entry identified by `id` from the context.
856 pub fn remove(&mut self, id: &LocalId) -> Option<Vec<ProvInfo>> {
857 self.lets.remove(id)
858 }
859}