mz_transform/redundant_join.rs
1// Copyright Materialize, Inc. and contributors. All rights reserved.
2//
3// Use of this software is governed by the Business Source License
4// included in the LICENSE file.
5//
6// As of the Change Date specified in that file, in accordance with
7// the Business Source License, use of this software will be governed
8// by the Apache License, Version 2.0.
9
10//! Remove redundant collections of distinct elements from joins.
11//!
12//! This analysis looks for joins in which one collection contains distinct
13//! elements, and it can be determined that the join would only restrict the
14//! results, and that the restriction is redundant (the other results would
15//! not be reduced by the join).
16//!
17//! This type of optimization shows up often in subqueries, where distinct
18//! collections are used in decorrelation, and afterwards often distinct
19//! collections are then joined against the results.
20
21// If statements seem a bit clearer in this case. Specialized methods
22// that replace simple and common alternatives frustrate developers.
23#![allow(clippy::comparison_chain, clippy::filter_next)]
24
25use std::collections::BTreeMap;
26
27use itertools::Itertools;
28use mz_expr::visit::Visit;
29use mz_expr::{
30 Columns, Id, JoinInputMapper, LocalId, MirRelationExpr, MirScalarExpr, RECURSION_LIMIT,
31};
32use mz_ore::stack::{CheckedRecursion, RecursionGuard};
33use mz_ore::{assert_none, soft_panic_or_log};
34
35use crate::{TransformCtx, all};
36
/// Remove redundant collections of distinct elements from joins.
///
/// The transform itself is stateless apart from the recursion guard that
/// bounds how deep [`RedundantJoin::action`] may recurse.
#[derive(Debug)]
pub struct RedundantJoin {
    /// Guard exposed through the `CheckedRecursion` implementation; `Default`
    /// constructs it with `RECURSION_LIMIT`.
    recursion_guard: RecursionGuard,
}
42
43impl Default for RedundantJoin {
44 fn default() -> RedundantJoin {
45 RedundantJoin {
46 recursion_guard: RecursionGuard::with_limit(RECURSION_LIMIT),
47 }
48 }
49}
50
// Lets `RedundantJoin::action` recurse through `checked_recur`.
impl CheckedRecursion for RedundantJoin {
    /// Returns the guard stored in this transform.
    fn recursion_guard(&self) -> &RecursionGuard {
        &self.recursion_guard
    }
}
56
57impl crate::Transform for RedundantJoin {
58 fn name(&self) -> &'static str {
59 "RedundantJoin"
60 }
61
62 #[mz_ore::instrument(
63 target = "optimizer",
64 level = "debug",
65 fields(path.segment = "redundant_join")
66 )]
67 fn actually_perform_transform(
68 &self,
69 relation: &mut MirRelationExpr,
70 _: &mut TransformCtx,
71 ) -> Result<(), crate::TransformError> {
72 let mut ctx = ProvInfoCtx::default();
73 ctx.extend_uses(relation);
74 let result = self.action(relation, &mut ctx);
75 mz_repr::explain::trace_plan(&*relation);
76 result.map(|_| ())
77 }
78}
79
impl RedundantJoin {
    /// Remove redundant collections of distinct elements from joins.
    ///
    /// This method tracks "provenance" information for each collection,
    /// those being column-wise relationships to identified collections
    /// (either imported collections, or let-bound collections). These
    /// relationships state that when projected on to these columns, the
    /// records of the one collection are contained in the records of the
    /// identified collection.
    ///
    /// This provenance information is then used for the `MirRelationExpr::Join`
    /// variant to remove "redundant" joins, those that can be determined to
    /// neither restrict nor augment one of the input relations. Consult the
    /// `find_redundancy` method and its documentation for more detail.
    ///
    /// `relation` is rewritten in place. `ctx` carries the provenance of the
    /// let bindings currently in scope together with the number of remaining
    /// references to each of them.
    ///
    /// Returns the provenance derived for `relation`, with trivial entries
    /// (see `ProvInfo::is_trivial`) filtered out.
    ///
    /// # Errors
    ///
    /// Propagates any error produced by `checked_recur`, by the recursive
    /// calls on sub-expressions, or by `visit_mut_post` while rewriting join
    /// equivalences.
    pub fn action(
        &self,
        relation: &mut MirRelationExpr,
        ctx: &mut ProvInfoCtx,
    ) -> Result<Vec<ProvInfo>, crate::TransformError> {
        let mut result = self.checked_recur(|_| {
            match relation {
                MirRelationExpr::Let { id, value, body } => {
                    // Recursively determine provenance of the value.
                    let value_prov = self.action(value, ctx)?;
                    // Clear uses from the just visited binding definition.
                    ctx.remove_uses(value);

                    // Extend the lets context with an entry for this binding.
                    let prov_old = ctx.insert(*id, value_prov);
                    assert_none!(prov_old, "No shadowing");

                    // Determine provenance of the body.
                    let result = self.action(body, ctx)?;
                    ctx.remove_uses(body);

                    // Remove the lets entry for this binding from the context.
                    // (`remove_uses` may already have dropped it once its use
                    // count reached zero; the returned value is ignored.)
                    ctx.remove(id);

                    Ok(result)
                }

                MirRelationExpr::LetRec {
                    ids,
                    values,
                    limits: _,
                    body,
                } => {
                    // As a first approximation, we naively extend the `lets`
                    // context with the empty vec![] for each id.
                    for id in ids.iter() {
                        let prov_old = ctx.insert(*id, vec![]);
                        assert_none!(prov_old, "No shadowing");
                    }

                    // In other words, we don't attempt to derive additional
                    // provenance information for a binding from its `value`.
                    //
                    // We descend into the values and the body with the naively
                    // extended context.
                    for value in values.iter_mut() {
                        self.action(value, ctx)?;
                    }
                    // Clear uses from the just visited recursive binding
                    // definitions.
                    for value in values.iter_mut() {
                        ctx.remove_uses(value);
                    }
                    let result = self.action(body, ctx)?;
                    ctx.remove_uses(body);

                    // Remove the lets entries for all ids.
                    for id in ids.iter() {
                        ctx.remove(id);
                    }

                    Ok(result)
                }

                MirRelationExpr::Get { id, typ, .. } => {
                    if let Id::Local(id) = id {
                        // Extract the value provenance (this should always exist).
                        // If it does not, report it softly and continue with
                        // no inherited provenance.
                        let mut val_info = ctx.get(id).cloned().unwrap_or_else(|| {
                            soft_panic_or_log!("no ctx entry for LocalId {id}");
                            vec![]
                        });
                        // Add information about being exactly this let binding too.
                        val_info.push(ProvInfo::make_leaf(Id::Local(*id), typ.arity()));
                        Ok(val_info)
                    } else {
                        // Add information about being exactly this GlobalId reference.
                        Ok(vec![ProvInfo::make_leaf(*id, typ.arity())])
                    }
                }

                MirRelationExpr::Join {
                    inputs,
                    equivalences,
                    implementation,
                } => {
                    // This logic first applies what it has learned about its input provenance,
                    // and if it finds a redundant join input it removes it. In that case, it
                    // also fails to produce exciting provenance information, partly out of
                    // laziness and the challenge of ensuring it is correct. Instead, if it is
                    // unable to find a redundant join it produces meaningful provenance information.

                    // Recursively apply transformation, and determine the provenance of inputs.
                    let mut input_prov = Vec::new();
                    for i in inputs.iter_mut() {
                        input_prov.push(self.action(i, ctx)?);
                    }

                    // Determine useful information about the structure of the inputs.
                    let mut input_types = inputs.iter().map(|i| i.typ()).collect::<Vec<_>>();
                    let old_input_mapper = JoinInputMapper::new_from_input_types(&input_types);

                    // If we find an input that can be removed, we should do so!
                    // We only do this once per invocation to keep our sanity, but we could
                    // rewrite it to iterate. We can avoid looking for any relation that
                    // does not have keys, as it cannot be redundant in that case.
                    //
                    // Candidates are examined from the last input to the first, and the
                    // first one for which `find_redundancy` succeeds is removed.
                    if let Some((remove_input_idx, mut bindings)) = (0..input_types.len())
                        .rev()
                        .filter(|i| !input_types[*i].keys.is_empty())
                        .flat_map(|i| {
                            find_redundancy(
                                i,
                                &input_types[i].keys,
                                &old_input_mapper,
                                equivalences,
                                &input_prov[..],
                            )
                            .map(|b| (i, b))
                        })
                        .next()
                    {
                        // Clear uses from the removed input.
                        ctx.remove_uses(&inputs[remove_input_idx]);

                        inputs.remove(remove_input_idx);
                        input_types.remove(remove_input_idx);

                        // Update the column offsets in the binding expressions to catch
                        // up with the removal of `remove_input_idx`.
                        for expr in bindings.iter_mut() {
                            expr.visit_pre_mut(|e| {
                                if let MirScalarExpr::Column(c, _) = e {
                                    let (_local_col, input_relation) =
                                        old_input_mapper.map_column_to_local(*c);
                                    if input_relation > remove_input_idx {
                                        *c -= old_input_mapper.input_arity(remove_input_idx);
                                    }
                                }
                            });
                        }

                        // Replace column references from `remove_input_idx` with the corresponding
                        // binding expression. Update the offsets of the column references
                        // from inputs after `remove_input_idx`.
                        for equivalence in equivalences.iter_mut() {
                            for expr in equivalence.iter_mut() {
                                expr.visit_mut_post(&mut |e| {
                                    if let MirScalarExpr::Column(c, _) = e {
                                        let (local_col, input_relation) =
                                            old_input_mapper.map_column_to_local(*c);
                                        if input_relation == remove_input_idx {
                                            // `bindings` was already shifted above, so the
                                            // substituted expression needs no further fix-up.
                                            *e = bindings[local_col].clone();
                                        } else if input_relation > remove_input_idx {
                                            *c -= old_input_mapper.input_arity(remove_input_idx);
                                        }
                                    }
                                })?;
                            }
                        }

                        mz_expr::canonicalize::canonicalize_equivalences(
                            equivalences,
                            input_types.iter().map(|t| &t.column_types),
                        );

                        // Build a projection that leaves the binding expressions in the same
                        // position as the columns of the removed join input they are replacing.
                        // The bindings are appended by the `map` below, directly after the
                        // columns of the reduced join, hence the `new_join_arity..` range.
                        let new_input_mapper = JoinInputMapper::new_from_input_types(&input_types);
                        let mut projection = Vec::new();
                        let new_join_arity = new_input_mapper.total_columns();
                        for i in 0..old_input_mapper.total_inputs() {
                            if i != remove_input_idx {
                                projection.extend(
                                    new_input_mapper.global_columns(if i < remove_input_idx {
                                        i
                                    } else {
                                        i - 1
                                    }),
                                );
                            } else {
                                projection.extend(new_join_arity..new_join_arity + bindings.len());
                            }
                        }

                        // Unset implementation, as irrevocably hosed by this transformation.
                        *implementation = mz_expr::JoinImplementation::Unimplemented;

                        *relation = relation.take_dangerous().map(bindings).project(projection);
                        // The projection will gum up provenance reasoning anyhow, so don't work hard.
                        // We will return to this expression again with the same analysis.
                        Ok(Vec::new())
                    } else {
                        // Provenance information should be the union of input provenance information,
                        // with columns updated. Because rows may be dropped in the join, all `exact`
                        // bits should be un-set.
                        let mut results = Vec::new();
                        for (input, input_prov) in input_prov.into_iter().enumerate() {
                            for mut prov in input_prov {
                                prov.exact = false;
                                // Widen the projection to the full join arity; columns that
                                // belong to other inputs stay `None`.
                                let mut projection = vec![None; old_input_mapper.total_columns()];
                                for (local_col, global_col) in
                                    old_input_mapper.global_columns(input).enumerate()
                                {
                                    projection[global_col]
                                        .clone_from(&prov.dereferenced_projection[local_col]);
                                }
                                prov.dereferenced_projection = projection;
                                results.push(prov);
                            }
                        }
                        Ok(results)
                    }
                }

                MirRelationExpr::Filter { input, .. } => {
                    // Filter may drop records, and so we unset `exact`.
                    let mut result = self.action(input, ctx)?;
                    for prov in result.iter_mut() {
                        prov.exact = false;
                    }
                    Ok(result)
                }

                MirRelationExpr::Map { input, scalars } => {
                    let mut result = self.action(input, ctx)?;
                    for prov in result.iter_mut() {
                        // Each scalar is dereferenced against the projection extended so
                        // far, so the entries pushed for earlier scalars are already
                        // present when a later one is resolved.
                        for scalar in scalars.iter() {
                            let dereferenced_scalar = prov.strict_dereference(scalar);
                            prov.dereferenced_projection.push(dereferenced_scalar);
                        }
                    }
                    Ok(result)
                }

                MirRelationExpr::Union { base, inputs } => {
                    let mut prov = self.action(base, ctx)?;
                    for input in inputs {
                        let input_prov = self.action(input, ctx)?;
                        // To merge a new list of provenances, we look at the cross
                        // product of things we might know about each source.
                        // TODO(mcsherry): this can be optimized to use datastructures
                        // keyed by the source identifier.
                        let mut new_prov = Vec::new();
                        for l in prov {
                            new_prov.extend(input_prov.iter().flat_map(|r| l.meet(r)))
                        }
                        prov = new_prov;
                    }
                    Ok(prov)
                }

                // Constants yield no provenance.
                MirRelationExpr::Constant { .. } => Ok(Vec::new()),

                MirRelationExpr::Reduce {
                    input,
                    group_key,
                    aggregates,
                    ..
                } => {
                    // Reduce yields its first few columns as a key, and produces
                    // all key tuples that were present in its input.
                    let mut result = self.action(input, ctx)?;
                    for prov in result.iter_mut() {
                        let mut projection = group_key
                            .iter()
                            .map(|key| prov.strict_dereference(key))
                            .collect_vec();
                        // Aggregate outputs carry no provenance.
                        projection.extend((0..aggregates.len()).map(|_| None));
                        prov.dereferenced_projection = projection;
                    }
                    // TODO: For min, max aggregates, we could preserve provenance
                    // if the expression references a column. We would need to un-set
                    // the `exact` bit in that case, and so we would want to keep both
                    // sets of provenance information.
                    Ok(result)
                }

                MirRelationExpr::Threshold { input } => {
                    // Threshold may drop records, and so we unset `exact`.
                    let mut result = self.action(input, ctx)?;
                    for prov in result.iter_mut() {
                        prov.exact = false;
                    }
                    Ok(result)
                }

                MirRelationExpr::TopK { input, .. } => {
                    // TopK may drop records, and so we unset `exact`.
                    let mut result = self.action(input, ctx)?;
                    for prov in result.iter_mut() {
                        prov.exact = false;
                    }
                    Ok(result)
                }

                MirRelationExpr::Project { input, outputs } => {
                    // Projections re-order, drop, and duplicate columns,
                    // but they neither drop rows nor invent values.
                    let mut result = self.action(input, ctx)?;
                    for prov in result.iter_mut() {
                        let projection = outputs
                            .iter()
                            .map(|c| prov.dereference(&MirScalarExpr::column(*c)))
                            .collect_vec();
                        prov.dereferenced_projection = projection;
                    }
                    Ok(result)
                }

                MirRelationExpr::FlatMap {
                    input,
                    func,
                    exprs: _,
                } => {
                    // FlatMap may drop records, and so we unset `exact`.
                    let mut result = self.action(input, ctx)?;
                    for prov in result.iter_mut() {
                        prov.exact = false;
                        // The columns produced by the table function carry no provenance.
                        prov.dereferenced_projection
                            .extend((0..func.output_arity()).map(|_| None));
                    }
                    Ok(result)
                }

                MirRelationExpr::Negate { input } => {
                    // Negate does not guarantee that the multiplicity of
                    // each source record is at least one. This could have
                    // been a problem in `Union`, where we might report
                    // that the union of positive and negative records is
                    // "exact": cancellations would make this false.
                    let mut result = self.action(input, ctx)?;
                    for prov in result.iter_mut() {
                        prov.exact = false;
                    }
                    Ok(result)
                }

                // Provenance of the input is passed through unchanged.
                MirRelationExpr::ArrangeBy { input, .. } => self.action(input, ctx),
            }
        })?;
        // Entries that no longer bind any column are useless to callers.
        result.retain(|info| !info.is_trivial());

        // Uncomment the following lines to trace the individual steps:
        // println!("{}", relation.pretty());
        // println!("result = {result:?}");
        // println!("ctx = {ctx:?}");
        // println!("---------------------");

        Ok(result)
    }
}
444
/// A relationship between a collection's columns and some source columns.
///
/// An instance of this type indicates that some of the bearer's columns
/// derive from `id`. In particular, the non-`None` elements in
/// `dereferenced_projection` correspond to columns that can be derived
/// from `id`'s projection.
///
/// The guarantee is that projected on to these columns, the distinct values
/// of the bearer are contained in the set of distinct values of projected
/// columns of `id`. In the case that `exact` is set, the two sets are equal.
#[derive(Clone, Debug, Ord, Eq, PartialOrd, PartialEq)]
pub struct ProvInfo {
    /// The Id (local or global) of the source.
    id: Id,
    /// The projection of the bearer written in terms of the columns projected
    /// by the underlying Get operator. Set to `None` for columns that cannot
    /// be expressed as a scalar expression referencing only columns of the
    /// underlying Get operator. Holds one entry per column of the bearer.
    dereferenced_projection: Vec<Option<MirScalarExpr>>,
    /// If true, all distinct projected source rows are present in the rows of
    /// the projection of the current collection. This constraint is lost as soon
    /// as a transformation may drop records.
    exact: bool,
}
469
470impl ProvInfo {
471 fn make_leaf(id: Id, arity: usize) -> Self {
472 Self {
473 id,
474 dereferenced_projection: (0..arity)
475 .map(|c| Some(MirScalarExpr::column(c)))
476 .collect::<Vec<_>>(),
477 exact: true,
478 }
479 }
480
481 /// Rewrite `expr` so it refers to the columns of the original source instead
482 /// of the columns of the projected source.
483 fn dereference(&self, expr: &MirScalarExpr) -> Option<MirScalarExpr> {
484 match expr {
485 MirScalarExpr::Column(c, _) => {
486 if let Some(expr) = &self.dereferenced_projection[*c] {
487 Some(expr.clone())
488 } else {
489 None
490 }
491 }
492 MirScalarExpr::CallUnary { func, expr } => self.dereference(expr).and_then(|expr| {
493 Some(MirScalarExpr::CallUnary {
494 func: func.clone(),
495 expr: Box::new(expr),
496 })
497 }),
498 MirScalarExpr::CallBinary { func, expr1, expr2 } => {
499 self.dereference(expr1).and_then(|expr1| {
500 self.dereference(expr2).and_then(|expr2| {
501 Some(MirScalarExpr::CallBinary {
502 func: func.clone(),
503 expr1: Box::new(expr1),
504 expr2: Box::new(expr2),
505 })
506 })
507 })
508 }
509 MirScalarExpr::CallVariadic { func, exprs } => {
510 let new_exprs = exprs.iter().flat_map(|e| self.dereference(e)).collect_vec();
511 if new_exprs.len() == exprs.len() {
512 Some(MirScalarExpr::CallVariadic {
513 func: func.clone(),
514 exprs: new_exprs,
515 })
516 } else {
517 None
518 }
519 }
520 MirScalarExpr::Literal(..) | MirScalarExpr::CallUnmaterializable(..) => {
521 Some(expr.clone())
522 }
523 MirScalarExpr::If { cond, then, els } => self.dereference(cond).and_then(|cond| {
524 self.dereference(then).and_then(|then| {
525 self.dereference(els).and_then(|els| {
526 Some(MirScalarExpr::If {
527 cond: Box::new(cond),
528 then: Box::new(then),
529 els: Box::new(els),
530 })
531 })
532 })
533 }),
534 }
535 }
536
537 /// Like `dereference` but only returns expressions that actually depend on
538 /// the original source.
539 fn strict_dereference(&self, expr: &MirScalarExpr) -> Option<MirScalarExpr> {
540 let derefed = self.dereference(expr);
541 match derefed {
542 Some(ref expr) if !expr.support().is_empty() => derefed,
543 _ => None,
544 }
545 }
546
547 /// Merge two constraints to find a constraint that satisfies both inputs.
548 ///
549 /// This method returns nothing if no columns are in common (either because
550 /// difference sources are identified, or just no columns in common) and it
551 /// intersects bindings and the `exact` bit.
552 fn meet(&self, other: &Self) -> Option<Self> {
553 if self.id == other.id {
554 let resulting_projection = self
555 .dereferenced_projection
556 .iter()
557 .zip_eq(other.dereferenced_projection.iter())
558 .map(|(e1, e2)| if e1 == e2 { e1.clone() } else { None })
559 .collect_vec();
560 if resulting_projection.iter().any(|e| e.is_some()) {
561 Some(ProvInfo {
562 id: self.id,
563 dereferenced_projection: resulting_projection,
564 exact: self.exact && other.exact,
565 })
566 } else {
567 None
568 }
569 } else {
570 None
571 }
572 }
573
574 /// Check if all entries of the dereferenced projection are missing.
575 ///
576 /// If this is the case keeping the `ProvInfo` entry around is meaningless.
577 fn is_trivial(&self) -> bool {
578 all![
579 !self.dereferenced_projection.is_empty(),
580 self.dereferenced_projection.iter().all(|x| x.is_none()),
581 ]
582 }
583}
584
585/// Attempts to find column bindings that make `input` redundant.
586///
587/// This method attempts to determine that `input` may be redundant by searching
588/// the join structure for another relation `other` with provenance that contains some
589/// provenance of `input`, and keys for `input` that are equated by the join to the
590/// corresponding columns of `other` under their provenance. The `input` provenance
591/// must also have its `exact` bit set.
592///
593/// In these circumstances, the claim is that because the key columns are equated and
594/// determine non-key columns, any matches between `input` and
595/// `other` will neither introduce new information to `other`, nor restrict the rows
596/// of `other`, nor alter their multplicity.
597fn find_redundancy(
598 input: usize,
599 keys: &[Vec<usize>],
600 input_mapper: &JoinInputMapper,
601 equivalences: &[Vec<MirScalarExpr>],
602 input_provs: &[Vec<ProvInfo>],
603) -> Option<Vec<MirScalarExpr>> {
604 // Whether the `equivalence` contains an expression that only references
605 // `input` that leads to the same as `root_expr` once dereferenced.
606 let contains_equivalent_expr_from_input = |equivalence: &[MirScalarExpr],
607 root_expr: &MirScalarExpr,
608 input: usize,
609 provenance: &ProvInfo|
610 -> bool {
611 equivalence.iter().any(|expr| {
612 Some(input) == input_mapper.single_input(expr)
613 && provenance
614 .dereference(&input_mapper.map_expr_to_local(expr.clone()))
615 .as_ref()
616 == Some(root_expr)
617 })
618 };
619 for input_prov in input_provs[input].iter() {
620 // We can only elide if the input contains all records, and binds all columns.
621 if input_prov.exact
622 && input_prov
623 .dereferenced_projection
624 .iter()
625 .all(|e| e.is_some())
626 {
627 // examine all *other* inputs that have not been removed...
628 for other in (0..input_mapper.total_inputs()).filter(|other| other != &input) {
629 for other_prov in input_provs[other].iter().filter(|p| p.id == input_prov.id) {
630 let all_key_columns_equated = |key: &Vec<usize>| {
631 key.iter().all(|key_col| {
632 // The root expression behind the key column, ie.
633 // the expression re-written in terms of elements in
634 // the projection of the Get operator.
635 let root_expr =
636 input_prov.dereference(&MirScalarExpr::column(*key_col));
637 // Check if there is a join equivalence that joins
638 // 'input' and 'other' on expressions that lead to
639 // the same root expression as the key column.
640 root_expr.as_ref().map_or(false, |root_expr| {
641 equivalences.iter().any(|equivalence| {
642 all![
643 contains_equivalent_expr_from_input(
644 equivalence,
645 root_expr,
646 input,
647 input_prov,
648 ),
649 contains_equivalent_expr_from_input(
650 equivalence,
651 root_expr,
652 other,
653 other_prov,
654 ),
655 ]
656 })
657 })
658 })
659 };
660
661 // Find an unique key for input that has all columns equated to other.
662 if keys.iter().any(all_key_columns_equated) {
663 // Find out whether we can produce input's projection strictly with
664 // elements in other's projection.
665 let expressions = input_prov
666 .dereferenced_projection
667 .iter()
668 .enumerate()
669 .flat_map(|(c, _)| {
670 // Check if the expression under input's 'c' column can be built
671 // with elements in other's projection.
672 input_prov.dereferenced_projection[c].as_ref().map_or(
673 None,
674 |root_expr| {
675 try_build_expression_using_other(
676 root_expr,
677 other,
678 other_prov,
679 input_mapper,
680 )
681 },
682 )
683 })
684 .collect_vec();
685 if expressions.len() == input_prov.dereferenced_projection.len() {
686 return Some(expressions);
687 }
688 }
689 }
690 }
691 }
692 }
693
694 None
695}
696
697/// Tries to build `root_expr` using elements from other's projection.
698fn try_build_expression_using_other(
699 root_expr: &MirScalarExpr,
700 other: usize,
701 other_prov: &ProvInfo,
702 input_mapper: &JoinInputMapper,
703) -> Option<MirScalarExpr> {
704 if root_expr.is_literal() {
705 return Some(root_expr.clone());
706 }
707
708 // Check if 'other' projects a column that lead to `root_expr`.
709 for (other_col, derefed) in other_prov.dereferenced_projection.iter().enumerate() {
710 if let Some(derefed) = derefed {
711 if derefed == root_expr {
712 return Some(MirScalarExpr::column(
713 input_mapper.map_column_to_global(other_col, other),
714 ));
715 }
716 }
717 }
718
719 // Otherwise, try to build root_expr's sub-expressions recursively
720 // other's projection.
721 match root_expr {
722 MirScalarExpr::Column(_, _) => None,
723 MirScalarExpr::CallUnary { func, expr } => {
724 try_build_expression_using_other(expr, other, other_prov, input_mapper).and_then(
725 |expr| {
726 Some(MirScalarExpr::CallUnary {
727 func: func.clone(),
728 expr: Box::new(expr),
729 })
730 },
731 )
732 }
733 MirScalarExpr::CallBinary { func, expr1, expr2 } => {
734 try_build_expression_using_other(expr1, other, other_prov, input_mapper).and_then(
735 |expr1| {
736 try_build_expression_using_other(expr2, other, other_prov, input_mapper)
737 .and_then(|expr2| {
738 Some(MirScalarExpr::CallBinary {
739 func: func.clone(),
740 expr1: Box::new(expr1),
741 expr2: Box::new(expr2),
742 })
743 })
744 },
745 )
746 }
747 MirScalarExpr::CallVariadic { func, exprs } => {
748 let new_exprs = exprs
749 .iter()
750 .flat_map(|e| try_build_expression_using_other(e, other, other_prov, input_mapper))
751 .collect_vec();
752 if new_exprs.len() == exprs.len() {
753 Some(MirScalarExpr::CallVariadic {
754 func: func.clone(),
755 exprs: new_exprs,
756 })
757 } else {
758 None
759 }
760 }
761 MirScalarExpr::Literal(..) | MirScalarExpr::CallUnmaterializable(..) => {
762 Some(root_expr.clone())
763 }
764 MirScalarExpr::If { cond, then, els } => {
765 try_build_expression_using_other(cond, other, other_prov, input_mapper).and_then(
766 |cond| {
767 try_build_expression_using_other(then, other, other_prov, input_mapper)
768 .and_then(|then| {
769 try_build_expression_using_other(els, other, other_prov, input_mapper)
770 .and_then(|els| {
771 Some(MirScalarExpr::If {
772 cond: Box::new(cond),
773 then: Box::new(then),
774 els: Box::new(els),
775 })
776 })
777 })
778 },
779 )
780 }
781 }
782}
783
/// A context of `ProvInfo` vectors associated with bindings that might still be
/// referenced.
#[derive(Debug, Default)]
pub struct ProvInfoCtx {
    /// [`LocalId`] references in the remaining subtree.
    ///
    /// Maps each id to the number of `Get` references that have not yet been
    /// accounted for by `remove_uses`. Entries from the `lets` map that are no
    /// longer used can be pruned.
    uses: BTreeMap<LocalId, usize>,
    /// [`ProvInfo`] vectors associated with the let bindings in scope.
    lets: BTreeMap<LocalId, Vec<ProvInfo>>,
}
795
796impl ProvInfoCtx {
797 /// Extend the `uses` map by the `LocalId`s used in `expr`.
798 pub fn extend_uses(&mut self, expr: &MirRelationExpr) {
799 expr.visit_pre(&mut |expr: &MirRelationExpr| match expr {
800 MirRelationExpr::Get {
801 id: Id::Local(id), ..
802 } => {
803 let count = self.uses.entry(id.clone()).or_insert(0_usize);
804 *count += 1;
805 }
806 _ => (),
807 });
808 }
809
810 /// Decrement `uses` entries by the `LocalId`s used in `expr` and remove
811 /// `lets` entries for `uses` that reset to zero.
812 pub fn remove_uses(&mut self, expr: &MirRelationExpr) {
813 let mut worklist = vec![expr];
814 while let Some(expr) = worklist.pop() {
815 if let MirRelationExpr::Get {
816 id: Id::Local(id), ..
817 } = expr
818 {
819 if let Some(count) = self.uses.get_mut(id) {
820 if *count > 0 {
821 *count -= 1;
822 }
823 if *count == 0 {
824 if self.lets.remove(id).is_none() {
825 soft_panic_or_log!("ctx.lets[{id}] should exist");
826 }
827 }
828 } else {
829 soft_panic_or_log!("ctx.uses[{id}] should exist");
830 }
831 }
832 match expr {
833 MirRelationExpr::Let { .. } | MirRelationExpr::LetRec { .. } => {
834 // When traversing the tree, don't descend into
835 // `Let`/`LetRec` sub-terms in order to avoid double
836 // counting (those are handled by remove_uses calls of
837 // RedundantJoin::action on subterms that were already
838 // visited because the action works bottom-up).
839 }
840 _ => {
841 worklist.extend(expr.children().rev());
842 }
843 }
844 }
845 }
846
847 /// Get the `ProvInfo` vector for `id` from the context.
848 pub fn get(&self, id: &LocalId) -> Option<&Vec<ProvInfo>> {
849 self.lets.get(id)
850 }
851
852 /// Extend the context with the `id: prov_infos` entry.
853 pub fn insert(&mut self, id: LocalId, prov_infos: Vec<ProvInfo>) -> Option<Vec<ProvInfo>> {
854 self.lets.insert(id, prov_infos)
855 }
856
857 /// Remove the entry identified by `id` from the context.
858 pub fn remove(&mut self, id: &LocalId) -> Option<Vec<ProvInfo>> {
859 self.lets.remove(id)
860 }
861}