mz_transform/
redundant_join.rs

1// Copyright Materialize, Inc. and contributors. All rights reserved.
2//
3// Use of this software is governed by the Business Source License
4// included in the LICENSE file.
5//
6// As of the Change Date specified in that file, in accordance with
7// the Business Source License, use of this software will be governed
8// by the Apache License, Version 2.0.
9
10//! Remove redundant collections of distinct elements from joins.
11//!
12//! This analysis looks for joins in which one collection contains distinct
13//! elements, and it can be determined that the join would only restrict the
14//! results, and that the restriction is redundant (the other results would
15//! not be reduced by the join).
16//!
//! This type of optimization shows up often in subqueries, where distinct
//! collections are used in decorrelation and are afterwards often joined
//! against the results.
20
21// If statements seem a bit clearer in this case. Specialized methods
22// that replace simple and common alternatives frustrate developers.
23#![allow(clippy::comparison_chain, clippy::filter_next)]
24
25use std::collections::BTreeMap;
26
27use itertools::Itertools;
28use mz_expr::visit::Visit;
29use mz_expr::{Id, JoinInputMapper, LocalId, MirRelationExpr, MirScalarExpr, RECURSION_LIMIT};
30use mz_ore::stack::{CheckedRecursion, RecursionGuard};
31use mz_ore::{assert_none, soft_panic_or_log};
32
33use crate::{TransformCtx, all};
34
35/// Remove redundant collections of distinct elements from joins.
36#[derive(Debug)]
37pub struct RedundantJoin {
38    recursion_guard: RecursionGuard,
39}
40
41impl Default for RedundantJoin {
42    fn default() -> RedundantJoin {
43        RedundantJoin {
44            recursion_guard: RecursionGuard::with_limit(RECURSION_LIMIT),
45        }
46    }
47}
48
49impl CheckedRecursion for RedundantJoin {
50    fn recursion_guard(&self) -> &RecursionGuard {
51        &self.recursion_guard
52    }
53}
54
55impl crate::Transform for RedundantJoin {
56    fn name(&self) -> &'static str {
57        "RedundantJoin"
58    }
59
60    #[mz_ore::instrument(
61        target = "optimizer",
62        level = "debug",
63        fields(path.segment = "redundant_join")
64    )]
65    fn actually_perform_transform(
66        &self,
67        relation: &mut MirRelationExpr,
68        _: &mut TransformCtx,
69    ) -> Result<(), crate::TransformError> {
70        let mut ctx = ProvInfoCtx::default();
71        ctx.extend_uses(relation);
72        let result = self.action(relation, &mut ctx);
73        mz_repr::explain::trace_plan(&*relation);
74        result.map(|_| ())
75    }
76}
77
78impl RedundantJoin {
79    /// Remove redundant collections of distinct elements from joins.
80    ///
81    /// This method tracks "provenance" information for each collections,
82    /// those being column-wise relationships to identified collections
83    /// (either imported collections, or let-bound collections). These
84    /// relationships state that when projected on to these columns, the
85    /// records of the one collection are contained in the records of the
86    /// identified collection.
87    ///
88    /// This provenance information is then used for the `MirRelationExpr::Join`
89    /// variant to remove "redundant" joins, those that can be determined to
90    /// neither restrict nor augment one of the input relations. Consult the
91    /// `find_redundancy` method and its documentation for more detail.
92    pub fn action(
93        &self,
94        relation: &mut MirRelationExpr,
95        ctx: &mut ProvInfoCtx,
96    ) -> Result<Vec<ProvInfo>, crate::TransformError> {
97        let mut result = self.checked_recur(|_| {
98            match relation {
99                MirRelationExpr::Let { id, value, body } => {
100                    // Recursively determine provenance of the value.
101                    let value_prov = self.action(value, ctx)?;
102                    // Clear uses from the just visited binding definition.
103                    ctx.remove_uses(value);
104
105                    // Extend the lets context with an entry for this binding.
106                    let prov_old = ctx.insert(*id, value_prov);
107                    assert_none!(prov_old, "No shadowing");
108
109                    // Determine provenance of the body.
110                    let result = self.action(body, ctx)?;
111                    ctx.remove_uses(body);
112
113                    // Remove the lets entry for this binding from the context.
114                    ctx.remove(id);
115
116                    Ok(result)
117                }
118
119                MirRelationExpr::LetRec {
120                    ids,
121                    values,
122                    limits: _,
123                    body,
124                } => {
125                    // As a first approximation, we naively extend the `lets`
126                    // context with the empty vec![] for each id.
127                    for id in ids.iter() {
128                        let prov_old = ctx.insert(*id, vec![]);
129                        assert_none!(prov_old, "No shadowing");
130                    }
131
132                    // In other words, we don't attempt to derive additional
133                    // provenance information for a binding from its `value`.
134                    //
135                    // We descend into the values and the body with the naively
136                    // extended context.
137                    for value in values.iter_mut() {
138                        self.action(value, ctx)?;
139                    }
140                    // Clear uses from the just visited recursive binding
141                    // definitions.
142                    for value in values.iter_mut() {
143                        ctx.remove_uses(value);
144                    }
145                    let result = self.action(body, ctx)?;
146                    ctx.remove_uses(body);
147
148                    // Remove the lets entries for all ids.
149                    for id in ids.iter() {
150                        ctx.remove(id);
151                    }
152
153                    Ok(result)
154                }
155
156                MirRelationExpr::Get { id, typ, .. } => {
157                    if let Id::Local(id) = id {
158                        // Extract the value provenance (this should always exist).
159                        let mut val_info = ctx.get(id).cloned().unwrap_or_else(|| {
160                            soft_panic_or_log!("no ctx entry for LocalId {id}");
161                            vec![]
162                        });
163                        // Add information about being exactly this let binding too.
164                        val_info.push(ProvInfo::make_leaf(Id::Local(*id), typ.arity()));
165                        Ok(val_info)
166                    } else {
167                        // Add information about being exactly this GlobalId reference.
168                        Ok(vec![ProvInfo::make_leaf(*id, typ.arity())])
169                    }
170                }
171
172                MirRelationExpr::Join {
173                    inputs,
174                    equivalences,
175                    implementation,
176                } => {
177                    // This logic first applies what it has learned about its input provenance,
178                    // and if it finds a redundant join input it removes it. In that case, it
179                    // also fails to produce exciting provenance information, partly out of
180                    // laziness and the challenge of ensuring it is correct. Instead, if it is
181                    // unable to find a redundant join it produces meaningful provenance information.
182
183                    // Recursively apply transformation, and determine the provenance of inputs.
184                    let mut input_prov = Vec::new();
185                    for i in inputs.iter_mut() {
186                        input_prov.push(self.action(i, ctx)?);
187                    }
188
189                    // Determine useful information about the structure of the inputs.
190                    let mut input_types = inputs.iter().map(|i| i.typ()).collect::<Vec<_>>();
191                    let old_input_mapper = JoinInputMapper::new_from_input_types(&input_types);
192
193                    // If we find an input that can be removed, we should do so!
194                    // We only do this once per invocation to keep our sanity, but we could
195                    // rewrite it to iterate. We can avoid looking for any relation that
196                    // does not have keys, as it cannot be redundant in that case.
197                    if let Some((remove_input_idx, mut bindings)) = (0..input_types.len())
198                        .rev()
199                        .filter(|i| !input_types[*i].keys.is_empty())
200                        .flat_map(|i| {
201                            find_redundancy(
202                                i,
203                                &input_types[i].keys,
204                                &old_input_mapper,
205                                equivalences,
206                                &input_prov[..],
207                            )
208                            .map(|b| (i, b))
209                        })
210                        .next()
211                    {
212                        // Clear uses from the removed input.
213                        ctx.remove_uses(&inputs[remove_input_idx]);
214
215                        inputs.remove(remove_input_idx);
216                        input_types.remove(remove_input_idx);
217
218                        // Update the column offsets in the binding expressions to catch
219                        // up with the removal of `remove_input_idx`.
220                        for expr in bindings.iter_mut() {
221                            expr.visit_pre_mut(|e| {
222                                if let MirScalarExpr::Column(c) = e {
223                                    let (_local_col, input_relation) =
224                                        old_input_mapper.map_column_to_local(*c);
225                                    if input_relation > remove_input_idx {
226                                        *c -= old_input_mapper.input_arity(remove_input_idx);
227                                    }
228                                }
229                            });
230                        }
231
232                        // Replace column references from `remove_input_idx` with the corresponding
233                        // binding expression. Update the offsets of the column references
234                        // from inputs after `remove_input_idx`.
235                        for equivalence in equivalences.iter_mut() {
236                            for expr in equivalence.iter_mut() {
237                                expr.visit_mut_post(&mut |e| {
238                                    if let MirScalarExpr::Column(c) = e {
239                                        let (local_col, input_relation) =
240                                            old_input_mapper.map_column_to_local(*c);
241                                        if input_relation == remove_input_idx {
242                                            *e = bindings[local_col].clone();
243                                        } else if input_relation > remove_input_idx {
244                                            *c -= old_input_mapper.input_arity(remove_input_idx);
245                                        }
246                                    }
247                                })?;
248                            }
249                        }
250
251                        mz_expr::canonicalize::canonicalize_equivalences(
252                            equivalences,
253                            input_types.iter().map(|t| &t.column_types),
254                        );
255
256                        // Build a projection that leaves the binding expressions in the same
257                        // position as the columns of the removed join input they are replacing.
258                        let new_input_mapper = JoinInputMapper::new_from_input_types(&input_types);
259                        let mut projection = Vec::new();
260                        let new_join_arity = new_input_mapper.total_columns();
261                        for i in 0..old_input_mapper.total_inputs() {
262                            if i != remove_input_idx {
263                                projection.extend(
264                                    new_input_mapper.global_columns(if i < remove_input_idx {
265                                        i
266                                    } else {
267                                        i - 1
268                                    }),
269                                );
270                            } else {
271                                projection.extend(new_join_arity..new_join_arity + bindings.len());
272                            }
273                        }
274
275                        // Unset implementation, as irrevocably hosed by this transformation.
276                        *implementation = mz_expr::JoinImplementation::Unimplemented;
277
278                        *relation = relation.take_dangerous().map(bindings).project(projection);
279                        // The projection will gum up provenance reasoning anyhow, so don't work hard.
280                        // We will return to this expression again with the same analysis.
281                        Ok(Vec::new())
282                    } else {
283                        // Provenance information should be the union of input provenance information,
284                        // with columns updated. Because rows may be dropped in the join, all `exact`
285                        // bits should be un-set.
286                        let mut results = Vec::new();
287                        for (input, input_prov) in input_prov.into_iter().enumerate() {
288                            for mut prov in input_prov {
289                                prov.exact = false;
290                                let mut projection = vec![None; old_input_mapper.total_columns()];
291                                for (local_col, global_col) in
292                                    old_input_mapper.global_columns(input).enumerate()
293                                {
294                                    projection[global_col]
295                                        .clone_from(&prov.dereferenced_projection[local_col]);
296                                }
297                                prov.dereferenced_projection = projection;
298                                results.push(prov);
299                            }
300                        }
301                        Ok(results)
302                    }
303                }
304
305                MirRelationExpr::Filter { input, .. } => {
306                    // Filter may drop records, and so we unset `exact`.
307                    let mut result = self.action(input, ctx)?;
308                    for prov in result.iter_mut() {
309                        prov.exact = false;
310                    }
311                    Ok(result)
312                }
313
314                MirRelationExpr::Map { input, scalars } => {
315                    let mut result = self.action(input, ctx)?;
316                    for prov in result.iter_mut() {
317                        for scalar in scalars.iter() {
318                            let dereferenced_scalar = prov.strict_dereference(scalar);
319                            prov.dereferenced_projection.push(dereferenced_scalar);
320                        }
321                    }
322                    Ok(result)
323                }
324
325                MirRelationExpr::Union { base, inputs } => {
326                    let mut prov = self.action(base, ctx)?;
327                    for input in inputs {
328                        let input_prov = self.action(input, ctx)?;
329                        // To merge a new list of provenances, we look at the cross
330                        // produce of things we might know about each source.
331                        // TODO(mcsherry): this can be optimized to use datastructures
332                        // keyed by the source identifier.
333                        let mut new_prov = Vec::new();
334                        for l in prov {
335                            new_prov.extend(input_prov.iter().flat_map(|r| l.meet(r)))
336                        }
337                        prov = new_prov;
338                    }
339                    Ok(prov)
340                }
341
342                MirRelationExpr::Constant { .. } => Ok(Vec::new()),
343
344                MirRelationExpr::Reduce {
345                    input,
346                    group_key,
347                    aggregates,
348                    ..
349                } => {
350                    // Reduce yields its first few columns as a key, and produces
351                    // all key tuples that were present in its input.
352                    let mut result = self.action(input, ctx)?;
353                    for prov in result.iter_mut() {
354                        let mut projection = group_key
355                            .iter()
356                            .map(|key| prov.strict_dereference(key))
357                            .collect_vec();
358                        projection.extend((0..aggregates.len()).map(|_| None));
359                        prov.dereferenced_projection = projection;
360                    }
361                    // TODO: For min, max aggregates, we could preserve provenance
362                    // if the expression references a column. We would need to un-set
363                    // the `exact` bit in that case, and so we would want to keep both
364                    // sets of provenance information.
365                    Ok(result)
366                }
367
368                MirRelationExpr::Threshold { input } => {
369                    // Threshold may drop records, and so we unset `exact`.
370                    let mut result = self.action(input, ctx)?;
371                    for prov in result.iter_mut() {
372                        prov.exact = false;
373                    }
374                    Ok(result)
375                }
376
377                MirRelationExpr::TopK { input, .. } => {
378                    // TopK may drop records, and so we unset `exact`.
379                    let mut result = self.action(input, ctx)?;
380                    for prov in result.iter_mut() {
381                        prov.exact = false;
382                    }
383                    Ok(result)
384                }
385
386                MirRelationExpr::Project { input, outputs } => {
387                    // Projections re-order, drop, and duplicate columns,
388                    // but they neither drop rows nor invent values.
389                    let mut result = self.action(input, ctx)?;
390                    for prov in result.iter_mut() {
391                        let projection = outputs
392                            .iter()
393                            .map(|c| prov.dereference(&MirScalarExpr::Column(*c)))
394                            .collect_vec();
395                        prov.dereferenced_projection = projection;
396                    }
397                    Ok(result)
398                }
399
400                MirRelationExpr::FlatMap { input, func, .. } => {
401                    // FlatMap may drop records, and so we unset `exact`.
402                    let mut result = self.action(input, ctx)?;
403                    for prov in result.iter_mut() {
404                        prov.exact = false;
405                        prov.dereferenced_projection
406                            .extend((0..func.output_type().column_types.len()).map(|_| None));
407                    }
408                    Ok(result)
409                }
410
411                MirRelationExpr::Negate { input } => {
412                    // Negate does not guarantee that the multiplicity of
413                    // each source record it at least one. This could have
414                    // been a problem in `Union`, where we might report
415                    // that the union of positive and negative records is
416                    // "exact": cancellations would make this false.
417                    let mut result = self.action(input, ctx)?;
418                    for prov in result.iter_mut() {
419                        prov.exact = false;
420                    }
421                    Ok(result)
422                }
423
424                MirRelationExpr::ArrangeBy { input, .. } => self.action(input, ctx),
425            }
426        })?;
427        result.retain(|info| !info.is_trivial());
428
429        // Uncomment the following lines to trace the individual steps:
430        // println!("{}", relation.pretty());
431        // println!("result = {result:?}");
432        // println!("lets: {lets:?}");
433        // println!("---------------------");
434
435        Ok(result)
436    }
437}
438
439/// A relationship between a collections columns and some source columns.
440///
441/// An instance of this type indicates that some of the bearer's columns
442/// derive from `id`. In particular, the non-`None` elements in
443/// `dereferenced_projection` correspond to columns that can be derived
444/// from `id`'s projection.
445///
446/// The guarantee is that projected on to these columns, the distinct values
447/// of the bearer are contained in the set of distinct values of projected
448/// columns of `id`. In the case that `exact` is set, the two sets are equal.
449#[derive(Clone, Debug, Ord, Eq, PartialOrd, PartialEq)]
450pub struct ProvInfo {
451    /// The Id (local or global) of the source.
452    id: Id,
453    /// The projection of the bearer written in terms of the columns projected
454    /// by the underlying Get operator. Set to `None` for columns that cannot
455    /// be expressed as scalar expression referencing only columns of the
456    /// underlying Get operator.
457    dereferenced_projection: Vec<Option<MirScalarExpr>>,
458    /// If true, all distinct projected source rows are present in the rows of
459    /// the projection of the current collection. This constraint is lost as soon
460    /// as a transformation may drop records.
461    exact: bool,
462}
463
464impl ProvInfo {
465    fn make_leaf(id: Id, arity: usize) -> Self {
466        Self {
467            id,
468            dereferenced_projection: (0..arity)
469                .map(|c| Some(MirScalarExpr::column(c)))
470                .collect::<Vec<_>>(),
471            exact: true,
472        }
473    }
474
475    /// Rewrite `expr` so it refers to the columns of the original source instead
476    /// of the columns of the projected source.
477    fn dereference(&self, expr: &MirScalarExpr) -> Option<MirScalarExpr> {
478        match expr {
479            MirScalarExpr::Column(c) => {
480                if let Some(expr) = &self.dereferenced_projection[*c] {
481                    Some(expr.clone())
482                } else {
483                    None
484                }
485            }
486            MirScalarExpr::CallUnary { func, expr } => self.dereference(expr).and_then(|expr| {
487                Some(MirScalarExpr::CallUnary {
488                    func: func.clone(),
489                    expr: Box::new(expr),
490                })
491            }),
492            MirScalarExpr::CallBinary { func, expr1, expr2 } => {
493                self.dereference(expr1).and_then(|expr1| {
494                    self.dereference(expr2).and_then(|expr2| {
495                        Some(MirScalarExpr::CallBinary {
496                            func: func.clone(),
497                            expr1: Box::new(expr1),
498                            expr2: Box::new(expr2),
499                        })
500                    })
501                })
502            }
503            MirScalarExpr::CallVariadic { func, exprs } => {
504                let new_exprs = exprs.iter().flat_map(|e| self.dereference(e)).collect_vec();
505                if new_exprs.len() == exprs.len() {
506                    Some(MirScalarExpr::CallVariadic {
507                        func: func.clone(),
508                        exprs: new_exprs,
509                    })
510                } else {
511                    None
512                }
513            }
514            MirScalarExpr::Literal(..) | MirScalarExpr::CallUnmaterializable(..) => {
515                Some(expr.clone())
516            }
517            MirScalarExpr::If { cond, then, els } => self.dereference(cond).and_then(|cond| {
518                self.dereference(then).and_then(|then| {
519                    self.dereference(els).and_then(|els| {
520                        Some(MirScalarExpr::If {
521                            cond: Box::new(cond),
522                            then: Box::new(then),
523                            els: Box::new(els),
524                        })
525                    })
526                })
527            }),
528        }
529    }
530
531    /// Like `dereference` but only returns expressions that actually depend on
532    /// the original source.
533    fn strict_dereference(&self, expr: &MirScalarExpr) -> Option<MirScalarExpr> {
534        let derefed = self.dereference(expr);
535        match derefed {
536            Some(ref expr) if !expr.support().is_empty() => derefed,
537            _ => None,
538        }
539    }
540
541    /// Merge two constraints to find a constraint that satisfies both inputs.
542    ///
543    /// This method returns nothing if no columns are in common (either because
544    /// difference sources are identified, or just no columns in common) and it
545    /// intersects bindings and the `exact` bit.
546    fn meet(&self, other: &Self) -> Option<Self> {
547        if self.id == other.id {
548            let resulting_projection = self
549                .dereferenced_projection
550                .iter()
551                .zip(other.dereferenced_projection.iter())
552                .map(|(e1, e2)| if e1 == e2 { e1.clone() } else { None })
553                .collect_vec();
554            if resulting_projection.iter().any(|e| e.is_some()) {
555                Some(ProvInfo {
556                    id: self.id,
557                    dereferenced_projection: resulting_projection,
558                    exact: self.exact && other.exact,
559                })
560            } else {
561                None
562            }
563        } else {
564            None
565        }
566    }
567
568    /// Check if all entries of the dereferenced projection are missing.
569    ///
570    /// If this is the case keeping the `ProvInfo` entry around is meaningless.
571    fn is_trivial(&self) -> bool {
572        all![
573            !self.dereferenced_projection.is_empty(),
574            self.dereferenced_projection.iter().all(|x| x.is_none()),
575        ]
576    }
577}
578
579/// Attempts to find column bindings that make `input` redundant.
580///
581/// This method attempts to determine that `input` may be redundant by searching
582/// the join structure for another relation `other` with provenance that contains some
583/// provenance of `input`, and keys for `input` that are equated by the join to the
584/// corresponding columns of `other` under their provenance. The `input` provenance
585/// must also have its `exact` bit set.
586///
587/// In these circumstances, the claim is that because the key columns are equated and
588/// determine non-key columns, any matches between `input` and
589/// `other` will neither introduce new information to `other`, nor restrict the rows
590/// of `other`, nor alter their multplicity.
591fn find_redundancy(
592    input: usize,
593    keys: &[Vec<usize>],
594    input_mapper: &JoinInputMapper,
595    equivalences: &[Vec<MirScalarExpr>],
596    input_provs: &[Vec<ProvInfo>],
597) -> Option<Vec<MirScalarExpr>> {
598    // Whether the `equivalence` contains an expression that only references
599    // `input` that leads to the same as `root_expr` once dereferenced.
600    let contains_equivalent_expr_from_input = |equivalence: &[MirScalarExpr],
601                                               root_expr: &MirScalarExpr,
602                                               input: usize,
603                                               provenance: &ProvInfo|
604     -> bool {
605        equivalence.iter().any(|expr| {
606            Some(input) == input_mapper.single_input(expr)
607                && provenance
608                    .dereference(&input_mapper.map_expr_to_local(expr.clone()))
609                    .as_ref()
610                    == Some(root_expr)
611        })
612    };
613    for input_prov in input_provs[input].iter() {
614        // We can only elide if the input contains all records, and binds all columns.
615        if input_prov.exact
616            && input_prov
617                .dereferenced_projection
618                .iter()
619                .all(|e| e.is_some())
620        {
621            // examine all *other* inputs that have not been removed...
622            for other in (0..input_mapper.total_inputs()).filter(|other| other != &input) {
623                for other_prov in input_provs[other].iter().filter(|p| p.id == input_prov.id) {
624                    let all_key_columns_equated = |key: &Vec<usize>| {
625                        key.iter().all(|key_col| {
626                            // The root expression behind the key column, ie.
627                            // the expression re-written in terms of elements in
628                            // the projection of the Get operator.
629                            let root_expr =
630                                input_prov.dereference(&MirScalarExpr::column(*key_col));
631                            // Check if there is a join equivalence that joins
632                            // 'input' and 'other' on expressions that lead to
633                            // the same root expression as the key column.
634                            root_expr.as_ref().map_or(false, |root_expr| {
635                                equivalences.iter().any(|equivalence| {
636                                    all![
637                                        contains_equivalent_expr_from_input(
638                                            equivalence,
639                                            root_expr,
640                                            input,
641                                            input_prov,
642                                        ),
643                                        contains_equivalent_expr_from_input(
644                                            equivalence,
645                                            root_expr,
646                                            other,
647                                            other_prov,
648                                        ),
649                                    ]
650                                })
651                            })
652                        })
653                    };
654
655                    // Find an unique key for input that has all columns equated to other.
656                    if keys.iter().any(all_key_columns_equated) {
657                        // Find out whether we can produce input's projection strictly with
658                        // elements in other's projection.
659                        let expressions = input_prov
660                            .dereferenced_projection
661                            .iter()
662                            .enumerate()
663                            .flat_map(|(c, _)| {
664                                // Check if the expression under input's 'c' column can be built
665                                // with elements in other's projection.
666                                input_prov.dereferenced_projection[c].as_ref().map_or(
667                                    None,
668                                    |root_expr| {
669                                        try_build_expression_using_other(
670                                            root_expr,
671                                            other,
672                                            other_prov,
673                                            input_mapper,
674                                        )
675                                    },
676                                )
677                            })
678                            .collect_vec();
679                        if expressions.len() == input_prov.dereferenced_projection.len() {
680                            return Some(expressions);
681                        }
682                    }
683                }
684            }
685        }
686    }
687
688    None
689}
690
691/// Tries to build `root_expr` using elements from other's projection.
692fn try_build_expression_using_other(
693    root_expr: &MirScalarExpr,
694    other: usize,
695    other_prov: &ProvInfo,
696    input_mapper: &JoinInputMapper,
697) -> Option<MirScalarExpr> {
698    if root_expr.is_literal() {
699        return Some(root_expr.clone());
700    }
701
702    // Check if 'other' projects a column that lead to `root_expr`.
703    for (other_col, derefed) in other_prov.dereferenced_projection.iter().enumerate() {
704        if let Some(derefed) = derefed {
705            if derefed == root_expr {
706                return Some(MirScalarExpr::Column(
707                    input_mapper.map_column_to_global(other_col, other),
708                ));
709            }
710        }
711    }
712
713    // Otherwise, try to build root_expr's sub-expressions recursively
714    // other's projection.
715    match root_expr {
716        MirScalarExpr::Column(_) => None,
717        MirScalarExpr::CallUnary { func, expr } => {
718            try_build_expression_using_other(expr, other, other_prov, input_mapper).and_then(
719                |expr| {
720                    Some(MirScalarExpr::CallUnary {
721                        func: func.clone(),
722                        expr: Box::new(expr),
723                    })
724                },
725            )
726        }
727        MirScalarExpr::CallBinary { func, expr1, expr2 } => {
728            try_build_expression_using_other(expr1, other, other_prov, input_mapper).and_then(
729                |expr1| {
730                    try_build_expression_using_other(expr2, other, other_prov, input_mapper)
731                        .and_then(|expr2| {
732                            Some(MirScalarExpr::CallBinary {
733                                func: func.clone(),
734                                expr1: Box::new(expr1),
735                                expr2: Box::new(expr2),
736                            })
737                        })
738                },
739            )
740        }
741        MirScalarExpr::CallVariadic { func, exprs } => {
742            let new_exprs = exprs
743                .iter()
744                .flat_map(|e| try_build_expression_using_other(e, other, other_prov, input_mapper))
745                .collect_vec();
746            if new_exprs.len() == exprs.len() {
747                Some(MirScalarExpr::CallVariadic {
748                    func: func.clone(),
749                    exprs: new_exprs,
750                })
751            } else {
752                None
753            }
754        }
755        MirScalarExpr::Literal(..) | MirScalarExpr::CallUnmaterializable(..) => {
756            Some(root_expr.clone())
757        }
758        MirScalarExpr::If { cond, then, els } => {
759            try_build_expression_using_other(cond, other, other_prov, input_mapper).and_then(
760                |cond| {
761                    try_build_expression_using_other(then, other, other_prov, input_mapper)
762                        .and_then(|then| {
763                            try_build_expression_using_other(els, other, other_prov, input_mapper)
764                                .and_then(|els| {
765                                    Some(MirScalarExpr::If {
766                                        cond: Box::new(cond),
767                                        then: Box::new(then),
768                                        els: Box::new(els),
769                                    })
770                                })
771                        })
772                },
773            )
774        }
775    }
776}
777
778/// A context of `ProvInfo` vectors associated with bindings that might still be
779/// referenced.
780#[derive(Debug, Default)]
781pub struct ProvInfoCtx {
782    /// [`LocalId`] references in the remaining subtree.
783    ///
784    /// Entries from the `lets` map that are no longer used can be pruned.
785    uses: BTreeMap<LocalId, usize>,
786    /// [`ProvInfo`] vectors associated with let binding in scope.
787    lets: BTreeMap<LocalId, Vec<ProvInfo>>,
788}
789
790impl ProvInfoCtx {
791    /// Extend the `uses` map by the `LocalId`s used in `expr`.
792    pub fn extend_uses(&mut self, expr: &MirRelationExpr) {
793        expr.visit_pre(&mut |expr: &MirRelationExpr| match expr {
794            MirRelationExpr::Get {
795                id: Id::Local(id), ..
796            } => {
797                let count = self.uses.entry(id.clone()).or_insert(0_usize);
798                *count += 1;
799            }
800            _ => (),
801        });
802    }
803
804    /// Decrement `uses` entries by the `LocalId`s used in `expr` and remove
805    /// `lets` entries for `uses` that reset to zero.
806    pub fn remove_uses(&mut self, expr: &MirRelationExpr) {
807        let mut worklist = vec![expr];
808        while let Some(expr) = worklist.pop() {
809            if let MirRelationExpr::Get {
810                id: Id::Local(id), ..
811            } = expr
812            {
813                if let Some(count) = self.uses.get_mut(id) {
814                    if *count > 0 {
815                        *count -= 1;
816                    }
817                    if *count == 0 {
818                        if self.lets.remove(id).is_none() {
819                            soft_panic_or_log!("ctx.lets[{id}] should exist");
820                        }
821                    }
822                } else {
823                    soft_panic_or_log!("ctx.uses[{id}] should exist");
824                }
825            }
826            match expr {
827                MirRelationExpr::Let { .. } | MirRelationExpr::LetRec { .. } => {
828                    // When traversing the tree, don't descend into
829                    // `Let`/`LetRec` sub-terms in order to avoid double
830                    // counting (those are handled by remove_uses calls of
831                    // RedundantJoin::action on subterms that were already
832                    // visited because the action works bottom-up).
833                }
834                _ => {
835                    worklist.extend(expr.children().rev());
836                }
837            }
838        }
839    }
840
841    /// Get the `ProvInfo` vector for `id` from the context.
842    pub fn get(&self, id: &LocalId) -> Option<&Vec<ProvInfo>> {
843        self.lets.get(id)
844    }
845
846    /// Extend the context with the `id: prov_infos` entry.
847    pub fn insert(&mut self, id: LocalId, prov_infos: Vec<ProvInfo>) -> Option<Vec<ProvInfo>> {
848        self.lets.insert(id, prov_infos)
849    }
850
851    /// Remove the entry identified by `id` from the context.
852    pub fn remove(&mut self, id: &LocalId) -> Option<Vec<ProvInfo>> {
853        self.lets.remove(id)
854    }
855}