mz_transform/redundant_join.rs
1// Copyright Materialize, Inc. and contributors. All rights reserved.
2//
3// Use of this software is governed by the Business Source License
4// included in the LICENSE file.
5//
6// As of the Change Date specified in that file, in accordance with
7// the Business Source License, use of this software will be governed
8// by the Apache License, Version 2.0.
9
10//! Remove redundant collections of distinct elements from joins.
11//!
12//! This analysis looks for joins in which one collection contains distinct
13//! elements, and it can be determined that the join would only restrict the
14//! results, and that the restriction is redundant (the other results would
15//! not be reduced by the join).
16//!
//! This type of optimization shows up often in subqueries, where distinct
//! collections are used in decorrelation and are afterwards joined against
//! the results.
20
21// If statements seem a bit clearer in this case. Specialized methods
22// that replace simple and common alternatives frustrate developers.
23#![allow(clippy::comparison_chain, clippy::filter_next)]
24
25use std::collections::BTreeMap;
26
27use itertools::Itertools;
28use mz_expr::visit::Visit;
29use mz_expr::{Id, JoinInputMapper, LocalId, MirRelationExpr, MirScalarExpr, RECURSION_LIMIT};
30use mz_ore::stack::{CheckedRecursion, RecursionGuard};
31use mz_ore::{assert_none, soft_panic_or_log};
32
33use crate::{TransformCtx, all};
34
35/// Remove redundant collections of distinct elements from joins.
36#[derive(Debug)]
37pub struct RedundantJoin {
38 recursion_guard: RecursionGuard,
39}
40
41impl Default for RedundantJoin {
42 fn default() -> RedundantJoin {
43 RedundantJoin {
44 recursion_guard: RecursionGuard::with_limit(RECURSION_LIMIT),
45 }
46 }
47}
48
49impl CheckedRecursion for RedundantJoin {
50 fn recursion_guard(&self) -> &RecursionGuard {
51 &self.recursion_guard
52 }
53}
54
55impl crate::Transform for RedundantJoin {
56 fn name(&self) -> &'static str {
57 "RedundantJoin"
58 }
59
60 #[mz_ore::instrument(
61 target = "optimizer",
62 level = "debug",
63 fields(path.segment = "redundant_join")
64 )]
65 fn actually_perform_transform(
66 &self,
67 relation: &mut MirRelationExpr,
68 _: &mut TransformCtx,
69 ) -> Result<(), crate::TransformError> {
70 let mut ctx = ProvInfoCtx::default();
71 ctx.extend_uses(relation);
72 let result = self.action(relation, &mut ctx);
73 mz_repr::explain::trace_plan(&*relation);
74 result.map(|_| ())
75 }
76}
77
78impl RedundantJoin {
79 /// Remove redundant collections of distinct elements from joins.
80 ///
81 /// This method tracks "provenance" information for each collections,
82 /// those being column-wise relationships to identified collections
83 /// (either imported collections, or let-bound collections). These
84 /// relationships state that when projected on to these columns, the
85 /// records of the one collection are contained in the records of the
86 /// identified collection.
87 ///
88 /// This provenance information is then used for the `MirRelationExpr::Join`
89 /// variant to remove "redundant" joins, those that can be determined to
90 /// neither restrict nor augment one of the input relations. Consult the
91 /// `find_redundancy` method and its documentation for more detail.
92 pub fn action(
93 &self,
94 relation: &mut MirRelationExpr,
95 ctx: &mut ProvInfoCtx,
96 ) -> Result<Vec<ProvInfo>, crate::TransformError> {
97 let mut result = self.checked_recur(|_| {
98 match relation {
99 MirRelationExpr::Let { id, value, body } => {
100 // Recursively determine provenance of the value.
101 let value_prov = self.action(value, ctx)?;
102 // Clear uses from the just visited binding definition.
103 ctx.remove_uses(value);
104
105 // Extend the lets context with an entry for this binding.
106 let prov_old = ctx.insert(*id, value_prov);
107 assert_none!(prov_old, "No shadowing");
108
109 // Determine provenance of the body.
110 let result = self.action(body, ctx)?;
111 ctx.remove_uses(body);
112
113 // Remove the lets entry for this binding from the context.
114 ctx.remove(id);
115
116 Ok(result)
117 }
118
119 MirRelationExpr::LetRec {
120 ids,
121 values,
122 limits: _,
123 body,
124 } => {
125 // As a first approximation, we naively extend the `lets`
126 // context with the empty vec![] for each id.
127 for id in ids.iter() {
128 let prov_old = ctx.insert(*id, vec![]);
129 assert_none!(prov_old, "No shadowing");
130 }
131
132 // In other words, we don't attempt to derive additional
133 // provenance information for a binding from its `value`.
134 //
135 // We descend into the values and the body with the naively
136 // extended context.
137 for value in values.iter_mut() {
138 self.action(value, ctx)?;
139 }
140 // Clear uses from the just visited recursive binding
141 // definitions.
142 for value in values.iter_mut() {
143 ctx.remove_uses(value);
144 }
145 let result = self.action(body, ctx)?;
146 ctx.remove_uses(body);
147
148 // Remove the lets entries for all ids.
149 for id in ids.iter() {
150 ctx.remove(id);
151 }
152
153 Ok(result)
154 }
155
156 MirRelationExpr::Get { id, typ, .. } => {
157 if let Id::Local(id) = id {
158 // Extract the value provenance (this should always exist).
159 let mut val_info = ctx.get(id).cloned().unwrap_or_else(|| {
160 soft_panic_or_log!("no ctx entry for LocalId {id}");
161 vec![]
162 });
163 // Add information about being exactly this let binding too.
164 val_info.push(ProvInfo::make_leaf(Id::Local(*id), typ.arity()));
165 Ok(val_info)
166 } else {
167 // Add information about being exactly this GlobalId reference.
168 Ok(vec![ProvInfo::make_leaf(*id, typ.arity())])
169 }
170 }
171
172 MirRelationExpr::Join {
173 inputs,
174 equivalences,
175 implementation,
176 } => {
177 // This logic first applies what it has learned about its input provenance,
178 // and if it finds a redundant join input it removes it. In that case, it
179 // also fails to produce exciting provenance information, partly out of
180 // laziness and the challenge of ensuring it is correct. Instead, if it is
181 // unable to find a redundant join it produces meaningful provenance information.
182
183 // Recursively apply transformation, and determine the provenance of inputs.
184 let mut input_prov = Vec::new();
185 for i in inputs.iter_mut() {
186 input_prov.push(self.action(i, ctx)?);
187 }
188
189 // Determine useful information about the structure of the inputs.
190 let mut input_types = inputs.iter().map(|i| i.typ()).collect::<Vec<_>>();
191 let old_input_mapper = JoinInputMapper::new_from_input_types(&input_types);
192
193 // If we find an input that can be removed, we should do so!
194 // We only do this once per invocation to keep our sanity, but we could
195 // rewrite it to iterate. We can avoid looking for any relation that
196 // does not have keys, as it cannot be redundant in that case.
197 if let Some((remove_input_idx, mut bindings)) = (0..input_types.len())
198 .rev()
199 .filter(|i| !input_types[*i].keys.is_empty())
200 .flat_map(|i| {
201 find_redundancy(
202 i,
203 &input_types[i].keys,
204 &old_input_mapper,
205 equivalences,
206 &input_prov[..],
207 )
208 .map(|b| (i, b))
209 })
210 .next()
211 {
212 // Clear uses from the removed input.
213 ctx.remove_uses(&inputs[remove_input_idx]);
214
215 inputs.remove(remove_input_idx);
216 input_types.remove(remove_input_idx);
217
218 // Update the column offsets in the binding expressions to catch
219 // up with the removal of `remove_input_idx`.
220 for expr in bindings.iter_mut() {
221 expr.visit_pre_mut(|e| {
222 if let MirScalarExpr::Column(c) = e {
223 let (_local_col, input_relation) =
224 old_input_mapper.map_column_to_local(*c);
225 if input_relation > remove_input_idx {
226 *c -= old_input_mapper.input_arity(remove_input_idx);
227 }
228 }
229 });
230 }
231
232 // Replace column references from `remove_input_idx` with the corresponding
233 // binding expression. Update the offsets of the column references
234 // from inputs after `remove_input_idx`.
235 for equivalence in equivalences.iter_mut() {
236 for expr in equivalence.iter_mut() {
237 expr.visit_mut_post(&mut |e| {
238 if let MirScalarExpr::Column(c) = e {
239 let (local_col, input_relation) =
240 old_input_mapper.map_column_to_local(*c);
241 if input_relation == remove_input_idx {
242 *e = bindings[local_col].clone();
243 } else if input_relation > remove_input_idx {
244 *c -= old_input_mapper.input_arity(remove_input_idx);
245 }
246 }
247 })?;
248 }
249 }
250
251 mz_expr::canonicalize::canonicalize_equivalences(
252 equivalences,
253 input_types.iter().map(|t| &t.column_types),
254 );
255
256 // Build a projection that leaves the binding expressions in the same
257 // position as the columns of the removed join input they are replacing.
258 let new_input_mapper = JoinInputMapper::new_from_input_types(&input_types);
259 let mut projection = Vec::new();
260 let new_join_arity = new_input_mapper.total_columns();
261 for i in 0..old_input_mapper.total_inputs() {
262 if i != remove_input_idx {
263 projection.extend(
264 new_input_mapper.global_columns(if i < remove_input_idx {
265 i
266 } else {
267 i - 1
268 }),
269 );
270 } else {
271 projection.extend(new_join_arity..new_join_arity + bindings.len());
272 }
273 }
274
275 // Unset implementation, as irrevocably hosed by this transformation.
276 *implementation = mz_expr::JoinImplementation::Unimplemented;
277
278 *relation = relation.take_dangerous().map(bindings).project(projection);
279 // The projection will gum up provenance reasoning anyhow, so don't work hard.
280 // We will return to this expression again with the same analysis.
281 Ok(Vec::new())
282 } else {
283 // Provenance information should be the union of input provenance information,
284 // with columns updated. Because rows may be dropped in the join, all `exact`
285 // bits should be un-set.
286 let mut results = Vec::new();
287 for (input, input_prov) in input_prov.into_iter().enumerate() {
288 for mut prov in input_prov {
289 prov.exact = false;
290 let mut projection = vec![None; old_input_mapper.total_columns()];
291 for (local_col, global_col) in
292 old_input_mapper.global_columns(input).enumerate()
293 {
294 projection[global_col]
295 .clone_from(&prov.dereferenced_projection[local_col]);
296 }
297 prov.dereferenced_projection = projection;
298 results.push(prov);
299 }
300 }
301 Ok(results)
302 }
303 }
304
305 MirRelationExpr::Filter { input, .. } => {
306 // Filter may drop records, and so we unset `exact`.
307 let mut result = self.action(input, ctx)?;
308 for prov in result.iter_mut() {
309 prov.exact = false;
310 }
311 Ok(result)
312 }
313
314 MirRelationExpr::Map { input, scalars } => {
315 let mut result = self.action(input, ctx)?;
316 for prov in result.iter_mut() {
317 for scalar in scalars.iter() {
318 let dereferenced_scalar = prov.strict_dereference(scalar);
319 prov.dereferenced_projection.push(dereferenced_scalar);
320 }
321 }
322 Ok(result)
323 }
324
325 MirRelationExpr::Union { base, inputs } => {
326 let mut prov = self.action(base, ctx)?;
327 for input in inputs {
328 let input_prov = self.action(input, ctx)?;
329 // To merge a new list of provenances, we look at the cross
330 // produce of things we might know about each source.
331 // TODO(mcsherry): this can be optimized to use datastructures
332 // keyed by the source identifier.
333 let mut new_prov = Vec::new();
334 for l in prov {
335 new_prov.extend(input_prov.iter().flat_map(|r| l.meet(r)))
336 }
337 prov = new_prov;
338 }
339 Ok(prov)
340 }
341
342 MirRelationExpr::Constant { .. } => Ok(Vec::new()),
343
344 MirRelationExpr::Reduce {
345 input,
346 group_key,
347 aggregates,
348 ..
349 } => {
350 // Reduce yields its first few columns as a key, and produces
351 // all key tuples that were present in its input.
352 let mut result = self.action(input, ctx)?;
353 for prov in result.iter_mut() {
354 let mut projection = group_key
355 .iter()
356 .map(|key| prov.strict_dereference(key))
357 .collect_vec();
358 projection.extend((0..aggregates.len()).map(|_| None));
359 prov.dereferenced_projection = projection;
360 }
361 // TODO: For min, max aggregates, we could preserve provenance
362 // if the expression references a column. We would need to un-set
363 // the `exact` bit in that case, and so we would want to keep both
364 // sets of provenance information.
365 Ok(result)
366 }
367
368 MirRelationExpr::Threshold { input } => {
369 // Threshold may drop records, and so we unset `exact`.
370 let mut result = self.action(input, ctx)?;
371 for prov in result.iter_mut() {
372 prov.exact = false;
373 }
374 Ok(result)
375 }
376
377 MirRelationExpr::TopK { input, .. } => {
378 // TopK may drop records, and so we unset `exact`.
379 let mut result = self.action(input, ctx)?;
380 for prov in result.iter_mut() {
381 prov.exact = false;
382 }
383 Ok(result)
384 }
385
386 MirRelationExpr::Project { input, outputs } => {
387 // Projections re-order, drop, and duplicate columns,
388 // but they neither drop rows nor invent values.
389 let mut result = self.action(input, ctx)?;
390 for prov in result.iter_mut() {
391 let projection = outputs
392 .iter()
393 .map(|c| prov.dereference(&MirScalarExpr::Column(*c)))
394 .collect_vec();
395 prov.dereferenced_projection = projection;
396 }
397 Ok(result)
398 }
399
400 MirRelationExpr::FlatMap { input, func, .. } => {
401 // FlatMap may drop records, and so we unset `exact`.
402 let mut result = self.action(input, ctx)?;
403 for prov in result.iter_mut() {
404 prov.exact = false;
405 prov.dereferenced_projection
406 .extend((0..func.output_type().column_types.len()).map(|_| None));
407 }
408 Ok(result)
409 }
410
411 MirRelationExpr::Negate { input } => {
412 // Negate does not guarantee that the multiplicity of
413 // each source record it at least one. This could have
414 // been a problem in `Union`, where we might report
415 // that the union of positive and negative records is
416 // "exact": cancellations would make this false.
417 let mut result = self.action(input, ctx)?;
418 for prov in result.iter_mut() {
419 prov.exact = false;
420 }
421 Ok(result)
422 }
423
424 MirRelationExpr::ArrangeBy { input, .. } => self.action(input, ctx),
425 }
426 })?;
427 result.retain(|info| !info.is_trivial());
428
429 // Uncomment the following lines to trace the individual steps:
430 // println!("{}", relation.pretty());
431 // println!("result = {result:?}");
432 // println!("lets: {lets:?}");
433 // println!("---------------------");
434
435 Ok(result)
436 }
437}
438
439/// A relationship between a collections columns and some source columns.
440///
441/// An instance of this type indicates that some of the bearer's columns
442/// derive from `id`. In particular, the non-`None` elements in
443/// `dereferenced_projection` correspond to columns that can be derived
444/// from `id`'s projection.
445///
446/// The guarantee is that projected on to these columns, the distinct values
447/// of the bearer are contained in the set of distinct values of projected
448/// columns of `id`. In the case that `exact` is set, the two sets are equal.
449#[derive(Clone, Debug, Ord, Eq, PartialOrd, PartialEq)]
450pub struct ProvInfo {
451 /// The Id (local or global) of the source.
452 id: Id,
453 /// The projection of the bearer written in terms of the columns projected
454 /// by the underlying Get operator. Set to `None` for columns that cannot
455 /// be expressed as scalar expression referencing only columns of the
456 /// underlying Get operator.
457 dereferenced_projection: Vec<Option<MirScalarExpr>>,
458 /// If true, all distinct projected source rows are present in the rows of
459 /// the projection of the current collection. This constraint is lost as soon
460 /// as a transformation may drop records.
461 exact: bool,
462}
463
464impl ProvInfo {
465 fn make_leaf(id: Id, arity: usize) -> Self {
466 Self {
467 id,
468 dereferenced_projection: (0..arity)
469 .map(|c| Some(MirScalarExpr::column(c)))
470 .collect::<Vec<_>>(),
471 exact: true,
472 }
473 }
474
475 /// Rewrite `expr` so it refers to the columns of the original source instead
476 /// of the columns of the projected source.
477 fn dereference(&self, expr: &MirScalarExpr) -> Option<MirScalarExpr> {
478 match expr {
479 MirScalarExpr::Column(c) => {
480 if let Some(expr) = &self.dereferenced_projection[*c] {
481 Some(expr.clone())
482 } else {
483 None
484 }
485 }
486 MirScalarExpr::CallUnary { func, expr } => self.dereference(expr).and_then(|expr| {
487 Some(MirScalarExpr::CallUnary {
488 func: func.clone(),
489 expr: Box::new(expr),
490 })
491 }),
492 MirScalarExpr::CallBinary { func, expr1, expr2 } => {
493 self.dereference(expr1).and_then(|expr1| {
494 self.dereference(expr2).and_then(|expr2| {
495 Some(MirScalarExpr::CallBinary {
496 func: func.clone(),
497 expr1: Box::new(expr1),
498 expr2: Box::new(expr2),
499 })
500 })
501 })
502 }
503 MirScalarExpr::CallVariadic { func, exprs } => {
504 let new_exprs = exprs.iter().flat_map(|e| self.dereference(e)).collect_vec();
505 if new_exprs.len() == exprs.len() {
506 Some(MirScalarExpr::CallVariadic {
507 func: func.clone(),
508 exprs: new_exprs,
509 })
510 } else {
511 None
512 }
513 }
514 MirScalarExpr::Literal(..) | MirScalarExpr::CallUnmaterializable(..) => {
515 Some(expr.clone())
516 }
517 MirScalarExpr::If { cond, then, els } => self.dereference(cond).and_then(|cond| {
518 self.dereference(then).and_then(|then| {
519 self.dereference(els).and_then(|els| {
520 Some(MirScalarExpr::If {
521 cond: Box::new(cond),
522 then: Box::new(then),
523 els: Box::new(els),
524 })
525 })
526 })
527 }),
528 }
529 }
530
531 /// Like `dereference` but only returns expressions that actually depend on
532 /// the original source.
533 fn strict_dereference(&self, expr: &MirScalarExpr) -> Option<MirScalarExpr> {
534 let derefed = self.dereference(expr);
535 match derefed {
536 Some(ref expr) if !expr.support().is_empty() => derefed,
537 _ => None,
538 }
539 }
540
541 /// Merge two constraints to find a constraint that satisfies both inputs.
542 ///
543 /// This method returns nothing if no columns are in common (either because
544 /// difference sources are identified, or just no columns in common) and it
545 /// intersects bindings and the `exact` bit.
546 fn meet(&self, other: &Self) -> Option<Self> {
547 if self.id == other.id {
548 let resulting_projection = self
549 .dereferenced_projection
550 .iter()
551 .zip(other.dereferenced_projection.iter())
552 .map(|(e1, e2)| if e1 == e2 { e1.clone() } else { None })
553 .collect_vec();
554 if resulting_projection.iter().any(|e| e.is_some()) {
555 Some(ProvInfo {
556 id: self.id,
557 dereferenced_projection: resulting_projection,
558 exact: self.exact && other.exact,
559 })
560 } else {
561 None
562 }
563 } else {
564 None
565 }
566 }
567
568 /// Check if all entries of the dereferenced projection are missing.
569 ///
570 /// If this is the case keeping the `ProvInfo` entry around is meaningless.
571 fn is_trivial(&self) -> bool {
572 all![
573 !self.dereferenced_projection.is_empty(),
574 self.dereferenced_projection.iter().all(|x| x.is_none()),
575 ]
576 }
577}
578
579/// Attempts to find column bindings that make `input` redundant.
580///
581/// This method attempts to determine that `input` may be redundant by searching
582/// the join structure for another relation `other` with provenance that contains some
583/// provenance of `input`, and keys for `input` that are equated by the join to the
584/// corresponding columns of `other` under their provenance. The `input` provenance
585/// must also have its `exact` bit set.
586///
587/// In these circumstances, the claim is that because the key columns are equated and
588/// determine non-key columns, any matches between `input` and
589/// `other` will neither introduce new information to `other`, nor restrict the rows
590/// of `other`, nor alter their multplicity.
591fn find_redundancy(
592 input: usize,
593 keys: &[Vec<usize>],
594 input_mapper: &JoinInputMapper,
595 equivalences: &[Vec<MirScalarExpr>],
596 input_provs: &[Vec<ProvInfo>],
597) -> Option<Vec<MirScalarExpr>> {
598 // Whether the `equivalence` contains an expression that only references
599 // `input` that leads to the same as `root_expr` once dereferenced.
600 let contains_equivalent_expr_from_input = |equivalence: &[MirScalarExpr],
601 root_expr: &MirScalarExpr,
602 input: usize,
603 provenance: &ProvInfo|
604 -> bool {
605 equivalence.iter().any(|expr| {
606 Some(input) == input_mapper.single_input(expr)
607 && provenance
608 .dereference(&input_mapper.map_expr_to_local(expr.clone()))
609 .as_ref()
610 == Some(root_expr)
611 })
612 };
613 for input_prov in input_provs[input].iter() {
614 // We can only elide if the input contains all records, and binds all columns.
615 if input_prov.exact
616 && input_prov
617 .dereferenced_projection
618 .iter()
619 .all(|e| e.is_some())
620 {
621 // examine all *other* inputs that have not been removed...
622 for other in (0..input_mapper.total_inputs()).filter(|other| other != &input) {
623 for other_prov in input_provs[other].iter().filter(|p| p.id == input_prov.id) {
624 let all_key_columns_equated = |key: &Vec<usize>| {
625 key.iter().all(|key_col| {
626 // The root expression behind the key column, ie.
627 // the expression re-written in terms of elements in
628 // the projection of the Get operator.
629 let root_expr =
630 input_prov.dereference(&MirScalarExpr::column(*key_col));
631 // Check if there is a join equivalence that joins
632 // 'input' and 'other' on expressions that lead to
633 // the same root expression as the key column.
634 root_expr.as_ref().map_or(false, |root_expr| {
635 equivalences.iter().any(|equivalence| {
636 all![
637 contains_equivalent_expr_from_input(
638 equivalence,
639 root_expr,
640 input,
641 input_prov,
642 ),
643 contains_equivalent_expr_from_input(
644 equivalence,
645 root_expr,
646 other,
647 other_prov,
648 ),
649 ]
650 })
651 })
652 })
653 };
654
655 // Find an unique key for input that has all columns equated to other.
656 if keys.iter().any(all_key_columns_equated) {
657 // Find out whether we can produce input's projection strictly with
658 // elements in other's projection.
659 let expressions = input_prov
660 .dereferenced_projection
661 .iter()
662 .enumerate()
663 .flat_map(|(c, _)| {
664 // Check if the expression under input's 'c' column can be built
665 // with elements in other's projection.
666 input_prov.dereferenced_projection[c].as_ref().map_or(
667 None,
668 |root_expr| {
669 try_build_expression_using_other(
670 root_expr,
671 other,
672 other_prov,
673 input_mapper,
674 )
675 },
676 )
677 })
678 .collect_vec();
679 if expressions.len() == input_prov.dereferenced_projection.len() {
680 return Some(expressions);
681 }
682 }
683 }
684 }
685 }
686 }
687
688 None
689}
690
691/// Tries to build `root_expr` using elements from other's projection.
692fn try_build_expression_using_other(
693 root_expr: &MirScalarExpr,
694 other: usize,
695 other_prov: &ProvInfo,
696 input_mapper: &JoinInputMapper,
697) -> Option<MirScalarExpr> {
698 if root_expr.is_literal() {
699 return Some(root_expr.clone());
700 }
701
702 // Check if 'other' projects a column that lead to `root_expr`.
703 for (other_col, derefed) in other_prov.dereferenced_projection.iter().enumerate() {
704 if let Some(derefed) = derefed {
705 if derefed == root_expr {
706 return Some(MirScalarExpr::Column(
707 input_mapper.map_column_to_global(other_col, other),
708 ));
709 }
710 }
711 }
712
713 // Otherwise, try to build root_expr's sub-expressions recursively
714 // other's projection.
715 match root_expr {
716 MirScalarExpr::Column(_) => None,
717 MirScalarExpr::CallUnary { func, expr } => {
718 try_build_expression_using_other(expr, other, other_prov, input_mapper).and_then(
719 |expr| {
720 Some(MirScalarExpr::CallUnary {
721 func: func.clone(),
722 expr: Box::new(expr),
723 })
724 },
725 )
726 }
727 MirScalarExpr::CallBinary { func, expr1, expr2 } => {
728 try_build_expression_using_other(expr1, other, other_prov, input_mapper).and_then(
729 |expr1| {
730 try_build_expression_using_other(expr2, other, other_prov, input_mapper)
731 .and_then(|expr2| {
732 Some(MirScalarExpr::CallBinary {
733 func: func.clone(),
734 expr1: Box::new(expr1),
735 expr2: Box::new(expr2),
736 })
737 })
738 },
739 )
740 }
741 MirScalarExpr::CallVariadic { func, exprs } => {
742 let new_exprs = exprs
743 .iter()
744 .flat_map(|e| try_build_expression_using_other(e, other, other_prov, input_mapper))
745 .collect_vec();
746 if new_exprs.len() == exprs.len() {
747 Some(MirScalarExpr::CallVariadic {
748 func: func.clone(),
749 exprs: new_exprs,
750 })
751 } else {
752 None
753 }
754 }
755 MirScalarExpr::Literal(..) | MirScalarExpr::CallUnmaterializable(..) => {
756 Some(root_expr.clone())
757 }
758 MirScalarExpr::If { cond, then, els } => {
759 try_build_expression_using_other(cond, other, other_prov, input_mapper).and_then(
760 |cond| {
761 try_build_expression_using_other(then, other, other_prov, input_mapper)
762 .and_then(|then| {
763 try_build_expression_using_other(els, other, other_prov, input_mapper)
764 .and_then(|els| {
765 Some(MirScalarExpr::If {
766 cond: Box::new(cond),
767 then: Box::new(then),
768 els: Box::new(els),
769 })
770 })
771 })
772 },
773 )
774 }
775 }
776}
777
778/// A context of `ProvInfo` vectors associated with bindings that might still be
779/// referenced.
780#[derive(Debug, Default)]
781pub struct ProvInfoCtx {
782 /// [`LocalId`] references in the remaining subtree.
783 ///
784 /// Entries from the `lets` map that are no longer used can be pruned.
785 uses: BTreeMap<LocalId, usize>,
786 /// [`ProvInfo`] vectors associated with let binding in scope.
787 lets: BTreeMap<LocalId, Vec<ProvInfo>>,
788}
789
790impl ProvInfoCtx {
791 /// Extend the `uses` map by the `LocalId`s used in `expr`.
792 pub fn extend_uses(&mut self, expr: &MirRelationExpr) {
793 expr.visit_pre(&mut |expr: &MirRelationExpr| match expr {
794 MirRelationExpr::Get {
795 id: Id::Local(id), ..
796 } => {
797 let count = self.uses.entry(id.clone()).or_insert(0_usize);
798 *count += 1;
799 }
800 _ => (),
801 });
802 }
803
804 /// Decrement `uses` entries by the `LocalId`s used in `expr` and remove
805 /// `lets` entries for `uses` that reset to zero.
806 pub fn remove_uses(&mut self, expr: &MirRelationExpr) {
807 let mut worklist = vec![expr];
808 while let Some(expr) = worklist.pop() {
809 if let MirRelationExpr::Get {
810 id: Id::Local(id), ..
811 } = expr
812 {
813 if let Some(count) = self.uses.get_mut(id) {
814 if *count > 0 {
815 *count -= 1;
816 }
817 if *count == 0 {
818 if self.lets.remove(id).is_none() {
819 soft_panic_or_log!("ctx.lets[{id}] should exist");
820 }
821 }
822 } else {
823 soft_panic_or_log!("ctx.uses[{id}] should exist");
824 }
825 }
826 match expr {
827 MirRelationExpr::Let { .. } | MirRelationExpr::LetRec { .. } => {
828 // When traversing the tree, don't descend into
829 // `Let`/`LetRec` sub-terms in order to avoid double
830 // counting (those are handled by remove_uses calls of
831 // RedundantJoin::action on subterms that were already
832 // visited because the action works bottom-up).
833 }
834 _ => {
835 worklist.extend(expr.children().rev());
836 }
837 }
838 }
839 }
840
841 /// Get the `ProvInfo` vector for `id` from the context.
842 pub fn get(&self, id: &LocalId) -> Option<&Vec<ProvInfo>> {
843 self.lets.get(id)
844 }
845
846 /// Extend the context with the `id: prov_infos` entry.
847 pub fn insert(&mut self, id: LocalId, prov_infos: Vec<ProvInfo>) -> Option<Vec<ProvInfo>> {
848 self.lets.insert(id, prov_infos)
849 }
850
851 /// Remove the entry identified by `id` from the context.
852 pub fn remove(&mut self, id: &LocalId) -> Option<Vec<ProvInfo>> {
853 self.lets.remove(id)
854 }
855}