Skip to main content

mz_transform/
redundant_join.rs

// Copyright Materialize, Inc. and contributors. All rights reserved.
//
// Use of this software is governed by the Business Source License
// included in the LICENSE file.
//
// As of the Change Date specified in that file, in accordance with
// the Business Source License, use of this software will be governed
// by the Apache License, Version 2.0.
9
//! Remove redundant collections of distinct elements from joins.
//!
//! This analysis looks for joins in which one input collection contains
//! distinct elements and it can be determined that joining with it would only
//! restrict the results, and moreover that this restriction is redundant (the
//! other results would not be reduced by the join).
//!
//! This type of optimization shows up often in subqueries: distinct
//! collections are used during decorrelation, and distinct collections are
//! often joined against the results afterwards.
20
// If statements seem a bit clearer in this case. Specialized methods
// that replace simple and common alternatives frustrate developers.
23#![allow(clippy::comparison_chain, clippy::filter_next)]
24
25use std::collections::BTreeMap;
26
27use itertools::Itertools;
28use mz_expr::visit::Visit;
29use mz_expr::{
30    Columns, Id, JoinInputMapper, LocalId, MirRelationExpr, MirScalarExpr, RECURSION_LIMIT,
31};
32use mz_ore::stack::{CheckedRecursion, RecursionGuard};
33use mz_ore::{assert_none, soft_panic_or_log};
34
35use crate::{TransformCtx, all};
36
37/// Remove redundant collections of distinct elements from joins.
38#[derive(Debug)]
39pub struct RedundantJoin {
40    recursion_guard: RecursionGuard,
41}
42
43impl Default for RedundantJoin {
44    fn default() -> RedundantJoin {
45        RedundantJoin {
46            recursion_guard: RecursionGuard::with_limit(RECURSION_LIMIT),
47        }
48    }
49}
50
51impl CheckedRecursion for RedundantJoin {
52    fn recursion_guard(&self) -> &RecursionGuard {
53        &self.recursion_guard
54    }
55}
56
57impl crate::Transform for RedundantJoin {
58    fn name(&self) -> &'static str {
59        "RedundantJoin"
60    }
61
62    #[mz_ore::instrument(
63        target = "optimizer",
64        level = "debug",
65        fields(path.segment = "redundant_join")
66    )]
67    fn actually_perform_transform(
68        &self,
69        relation: &mut MirRelationExpr,
70        _: &mut TransformCtx,
71    ) -> Result<(), crate::TransformError> {
72        let mut ctx = ProvInfoCtx::default();
73        ctx.extend_uses(relation);
74        let result = self.action(relation, &mut ctx);
75        mz_repr::explain::trace_plan(&*relation);
76        result.map(|_| ())
77    }
78}
79
80impl RedundantJoin {
81    /// Remove redundant collections of distinct elements from joins.
82    ///
83    /// This method tracks "provenance" information for each collections,
84    /// those being column-wise relationships to identified collections
85    /// (either imported collections, or let-bound collections). These
86    /// relationships state that when projected on to these columns, the
87    /// records of the one collection are contained in the records of the
88    /// identified collection.
89    ///
90    /// This provenance information is then used for the `MirRelationExpr::Join`
91    /// variant to remove "redundant" joins, those that can be determined to
92    /// neither restrict nor augment one of the input relations. Consult the
93    /// `find_redundancy` method and its documentation for more detail.
94    pub fn action(
95        &self,
96        relation: &mut MirRelationExpr,
97        ctx: &mut ProvInfoCtx,
98    ) -> Result<Vec<ProvInfo>, crate::TransformError> {
99        let mut result = self.checked_recur(|_| {
100            match relation {
101                MirRelationExpr::Let { id, value, body } => {
102                    // Recursively determine provenance of the value.
103                    let value_prov = self.action(value, ctx)?;
104                    // Clear uses from the just visited binding definition.
105                    ctx.remove_uses(value);
106
107                    // Extend the lets context with an entry for this binding.
108                    let prov_old = ctx.insert(*id, value_prov);
109                    assert_none!(prov_old, "No shadowing");
110
111                    // Determine provenance of the body.
112                    let result = self.action(body, ctx)?;
113                    ctx.remove_uses(body);
114
115                    // Remove the lets entry for this binding from the context.
116                    ctx.remove(id);
117
118                    Ok(result)
119                }
120
121                MirRelationExpr::LetRec {
122                    ids,
123                    values,
124                    limits: _,
125                    body,
126                } => {
127                    // As a first approximation, we naively extend the `lets`
128                    // context with the empty vec![] for each id.
129                    for id in ids.iter() {
130                        let prov_old = ctx.insert(*id, vec![]);
131                        assert_none!(prov_old, "No shadowing");
132                    }
133
134                    // In other words, we don't attempt to derive additional
135                    // provenance information for a binding from its `value`.
136                    //
137                    // We descend into the values and the body with the naively
138                    // extended context.
139                    for value in values.iter_mut() {
140                        self.action(value, ctx)?;
141                    }
142                    // Clear uses from the just visited recursive binding
143                    // definitions.
144                    for value in values.iter_mut() {
145                        ctx.remove_uses(value);
146                    }
147                    let result = self.action(body, ctx)?;
148                    ctx.remove_uses(body);
149
150                    // Remove the lets entries for all ids.
151                    for id in ids.iter() {
152                        ctx.remove(id);
153                    }
154
155                    Ok(result)
156                }
157
158                MirRelationExpr::Get { id, typ, .. } => {
159                    if let Id::Local(id) = id {
160                        // Extract the value provenance (this should always exist).
161                        let mut val_info = ctx.get(id).cloned().unwrap_or_else(|| {
162                            soft_panic_or_log!("no ctx entry for LocalId {id}");
163                            vec![]
164                        });
165                        // Add information about being exactly this let binding too.
166                        val_info.push(ProvInfo::make_leaf(Id::Local(*id), typ.arity()));
167                        Ok(val_info)
168                    } else {
169                        // Add information about being exactly this GlobalId reference.
170                        Ok(vec![ProvInfo::make_leaf(*id, typ.arity())])
171                    }
172                }
173
174                MirRelationExpr::Join {
175                    inputs,
176                    equivalences,
177                    implementation,
178                } => {
179                    // This logic first applies what it has learned about its input provenance,
180                    // and if it finds a redundant join input it removes it. In that case, it
181                    // also fails to produce exciting provenance information, partly out of
182                    // laziness and the challenge of ensuring it is correct. Instead, if it is
183                    // unable to find a redundant join it produces meaningful provenance information.
184
185                    // Recursively apply transformation, and determine the provenance of inputs.
186                    let mut input_prov = Vec::new();
187                    for i in inputs.iter_mut() {
188                        input_prov.push(self.action(i, ctx)?);
189                    }
190
191                    // Determine useful information about the structure of the inputs.
192                    let mut input_types = inputs.iter().map(|i| i.typ()).collect::<Vec<_>>();
193                    let old_input_mapper = JoinInputMapper::new_from_input_types(&input_types);
194
195                    // If we find an input that can be removed, we should do so!
196                    // We only do this once per invocation to keep our sanity, but we could
197                    // rewrite it to iterate. We can avoid looking for any relation that
198                    // does not have keys, as it cannot be redundant in that case.
199                    if let Some((remove_input_idx, mut bindings)) = (0..input_types.len())
200                        .rev()
201                        .filter(|i| !input_types[*i].keys.is_empty())
202                        .flat_map(|i| {
203                            find_redundancy(
204                                i,
205                                &input_types[i].keys,
206                                &old_input_mapper,
207                                equivalences,
208                                &input_prov[..],
209                            )
210                            .map(|b| (i, b))
211                        })
212                        .next()
213                    {
214                        // Clear uses from the removed input.
215                        ctx.remove_uses(&inputs[remove_input_idx]);
216
217                        inputs.remove(remove_input_idx);
218                        input_types.remove(remove_input_idx);
219
220                        // Update the column offsets in the binding expressions to catch
221                        // up with the removal of `remove_input_idx`.
222                        for expr in bindings.iter_mut() {
223                            expr.visit_pre_mut(|e| {
224                                if let MirScalarExpr::Column(c, _) = e {
225                                    let (_local_col, input_relation) =
226                                        old_input_mapper.map_column_to_local(*c);
227                                    if input_relation > remove_input_idx {
228                                        *c -= old_input_mapper.input_arity(remove_input_idx);
229                                    }
230                                }
231                            });
232                        }
233
234                        // Replace column references from `remove_input_idx` with the corresponding
235                        // binding expression. Update the offsets of the column references
236                        // from inputs after `remove_input_idx`.
237                        for equivalence in equivalences.iter_mut() {
238                            for expr in equivalence.iter_mut() {
239                                expr.visit_mut_post(&mut |e| {
240                                    if let MirScalarExpr::Column(c, _) = e {
241                                        let (local_col, input_relation) =
242                                            old_input_mapper.map_column_to_local(*c);
243                                        if input_relation == remove_input_idx {
244                                            *e = bindings[local_col].clone();
245                                        } else if input_relation > remove_input_idx {
246                                            *c -= old_input_mapper.input_arity(remove_input_idx);
247                                        }
248                                    }
249                                })?;
250                            }
251                        }
252
253                        mz_expr::canonicalize::canonicalize_equivalences(
254                            equivalences,
255                            input_types.iter().map(|t| &t.column_types),
256                        );
257
258                        // Build a projection that leaves the binding expressions in the same
259                        // position as the columns of the removed join input they are replacing.
260                        let new_input_mapper = JoinInputMapper::new_from_input_types(&input_types);
261                        let mut projection = Vec::new();
262                        let new_join_arity = new_input_mapper.total_columns();
263                        for i in 0..old_input_mapper.total_inputs() {
264                            if i != remove_input_idx {
265                                projection.extend(
266                                    new_input_mapper.global_columns(if i < remove_input_idx {
267                                        i
268                                    } else {
269                                        i - 1
270                                    }),
271                                );
272                            } else {
273                                projection.extend(new_join_arity..new_join_arity + bindings.len());
274                            }
275                        }
276
277                        // Unset implementation, as irrevocably hosed by this transformation.
278                        *implementation = mz_expr::JoinImplementation::Unimplemented;
279
280                        *relation = relation.take_dangerous().map(bindings).project(projection);
281                        // The projection will gum up provenance reasoning anyhow, so don't work hard.
282                        // We will return to this expression again with the same analysis.
283                        Ok(Vec::new())
284                    } else {
285                        // Provenance information should be the union of input provenance information,
286                        // with columns updated. Because rows may be dropped in the join, all `exact`
287                        // bits should be un-set.
288                        let mut results = Vec::new();
289                        for (input, input_prov) in input_prov.into_iter().enumerate() {
290                            for mut prov in input_prov {
291                                prov.exact = false;
292                                let mut projection = vec![None; old_input_mapper.total_columns()];
293                                for (local_col, global_col) in
294                                    old_input_mapper.global_columns(input).enumerate()
295                                {
296                                    projection[global_col]
297                                        .clone_from(&prov.dereferenced_projection[local_col]);
298                                }
299                                prov.dereferenced_projection = projection;
300                                results.push(prov);
301                            }
302                        }
303                        Ok(results)
304                    }
305                }
306
307                MirRelationExpr::Filter { input, .. } => {
308                    // Filter may drop records, and so we unset `exact`.
309                    let mut result = self.action(input, ctx)?;
310                    for prov in result.iter_mut() {
311                        prov.exact = false;
312                    }
313                    Ok(result)
314                }
315
316                MirRelationExpr::Map { input, scalars } => {
317                    let mut result = self.action(input, ctx)?;
318                    for prov in result.iter_mut() {
319                        for scalar in scalars.iter() {
320                            let dereferenced_scalar = prov.strict_dereference(scalar);
321                            prov.dereferenced_projection.push(dereferenced_scalar);
322                        }
323                    }
324                    Ok(result)
325                }
326
327                MirRelationExpr::Union { base, inputs } => {
328                    let mut prov = self.action(base, ctx)?;
329                    for input in inputs {
330                        let input_prov = self.action(input, ctx)?;
331                        // To merge a new list of provenances, we look at the cross
332                        // produce of things we might know about each source.
333                        // TODO(mcsherry): this can be optimized to use datastructures
334                        // keyed by the source identifier.
335                        let mut new_prov = Vec::new();
336                        for l in prov {
337                            new_prov.extend(input_prov.iter().flat_map(|r| l.meet(r)))
338                        }
339                        prov = new_prov;
340                    }
341                    Ok(prov)
342                }
343
344                MirRelationExpr::Constant { .. } => Ok(Vec::new()),
345
346                MirRelationExpr::Reduce {
347                    input,
348                    group_key,
349                    aggregates,
350                    ..
351                } => {
352                    // Reduce yields its first few columns as a key, and produces
353                    // all key tuples that were present in its input.
354                    let mut result = self.action(input, ctx)?;
355                    for prov in result.iter_mut() {
356                        let mut projection = group_key
357                            .iter()
358                            .map(|key| prov.strict_dereference(key))
359                            .collect_vec();
360                        projection.extend((0..aggregates.len()).map(|_| None));
361                        prov.dereferenced_projection = projection;
362                    }
363                    // TODO: For min, max aggregates, we could preserve provenance
364                    // if the expression references a column. We would need to un-set
365                    // the `exact` bit in that case, and so we would want to keep both
366                    // sets of provenance information.
367                    Ok(result)
368                }
369
370                MirRelationExpr::Threshold { input } => {
371                    // Threshold may drop records, and so we unset `exact`.
372                    let mut result = self.action(input, ctx)?;
373                    for prov in result.iter_mut() {
374                        prov.exact = false;
375                    }
376                    Ok(result)
377                }
378
379                MirRelationExpr::TopK { input, .. } => {
380                    // TopK may drop records, and so we unset `exact`.
381                    let mut result = self.action(input, ctx)?;
382                    for prov in result.iter_mut() {
383                        prov.exact = false;
384                    }
385                    Ok(result)
386                }
387
388                MirRelationExpr::Project { input, outputs } => {
389                    // Projections re-order, drop, and duplicate columns,
390                    // but they neither drop rows nor invent values.
391                    let mut result = self.action(input, ctx)?;
392                    for prov in result.iter_mut() {
393                        let projection = outputs
394                            .iter()
395                            .map(|c| prov.dereference(&MirScalarExpr::column(*c)))
396                            .collect_vec();
397                        prov.dereferenced_projection = projection;
398                    }
399                    Ok(result)
400                }
401
402                MirRelationExpr::FlatMap {
403                    input,
404                    func,
405                    exprs: _,
406                } => {
407                    // FlatMap may drop records, and so we unset `exact`.
408                    let mut result = self.action(input, ctx)?;
409                    for prov in result.iter_mut() {
410                        prov.exact = false;
411                        prov.dereferenced_projection
412                            .extend((0..func.output_arity()).map(|_| None));
413                    }
414                    Ok(result)
415                }
416
417                MirRelationExpr::Negate { input } => {
418                    // Negate does not guarantee that the multiplicity of
419                    // each source record it at least one. This could have
420                    // been a problem in `Union`, where we might report
421                    // that the union of positive and negative records is
422                    // "exact": cancellations would make this false.
423                    let mut result = self.action(input, ctx)?;
424                    for prov in result.iter_mut() {
425                        prov.exact = false;
426                    }
427                    Ok(result)
428                }
429
430                MirRelationExpr::ArrangeBy { input, .. } => self.action(input, ctx),
431            }
432        })?;
433        result.retain(|info| !info.is_trivial());
434
435        // Uncomment the following lines to trace the individual steps:
436        // println!("{}", relation.pretty());
437        // println!("result = {result:?}");
438        // println!("lets: {lets:?}");
439        // println!("---------------------");
440
441        Ok(result)
442    }
443}
444
445/// A relationship between a collections columns and some source columns.
446///
447/// An instance of this type indicates that some of the bearer's columns
448/// derive from `id`. In particular, the non-`None` elements in
449/// `dereferenced_projection` correspond to columns that can be derived
450/// from `id`'s projection.
451///
452/// The guarantee is that projected on to these columns, the distinct values
453/// of the bearer are contained in the set of distinct values of projected
454/// columns of `id`. In the case that `exact` is set, the two sets are equal.
455#[derive(Clone, Debug, Ord, Eq, PartialOrd, PartialEq)]
456pub struct ProvInfo {
457    /// The Id (local or global) of the source.
458    id: Id,
459    /// The projection of the bearer written in terms of the columns projected
460    /// by the underlying Get operator. Set to `None` for columns that cannot
461    /// be expressed as scalar expression referencing only columns of the
462    /// underlying Get operator.
463    dereferenced_projection: Vec<Option<MirScalarExpr>>,
464    /// If true, all distinct projected source rows are present in the rows of
465    /// the projection of the current collection. This constraint is lost as soon
466    /// as a transformation may drop records.
467    exact: bool,
468}
469
470impl ProvInfo {
471    fn make_leaf(id: Id, arity: usize) -> Self {
472        Self {
473            id,
474            dereferenced_projection: (0..arity)
475                .map(|c| Some(MirScalarExpr::column(c)))
476                .collect::<Vec<_>>(),
477            exact: true,
478        }
479    }
480
481    /// Rewrite `expr` so it refers to the columns of the original source instead
482    /// of the columns of the projected source.
483    fn dereference(&self, expr: &MirScalarExpr) -> Option<MirScalarExpr> {
484        match expr {
485            MirScalarExpr::Column(c, _) => {
486                if let Some(expr) = &self.dereferenced_projection[*c] {
487                    Some(expr.clone())
488                } else {
489                    None
490                }
491            }
492            MirScalarExpr::CallUnary { func, expr } => self.dereference(expr).and_then(|expr| {
493                Some(MirScalarExpr::CallUnary {
494                    func: func.clone(),
495                    expr: Box::new(expr),
496                })
497            }),
498            MirScalarExpr::CallBinary { func, expr1, expr2 } => {
499                self.dereference(expr1).and_then(|expr1| {
500                    self.dereference(expr2).and_then(|expr2| {
501                        Some(MirScalarExpr::CallBinary {
502                            func: func.clone(),
503                            expr1: Box::new(expr1),
504                            expr2: Box::new(expr2),
505                        })
506                    })
507                })
508            }
509            MirScalarExpr::CallVariadic { func, exprs } => {
510                let new_exprs = exprs.iter().flat_map(|e| self.dereference(e)).collect_vec();
511                if new_exprs.len() == exprs.len() {
512                    Some(MirScalarExpr::CallVariadic {
513                        func: func.clone(),
514                        exprs: new_exprs,
515                    })
516                } else {
517                    None
518                }
519            }
520            MirScalarExpr::Literal(..) | MirScalarExpr::CallUnmaterializable(..) => {
521                Some(expr.clone())
522            }
523            MirScalarExpr::If { cond, then, els } => self.dereference(cond).and_then(|cond| {
524                self.dereference(then).and_then(|then| {
525                    self.dereference(els).and_then(|els| {
526                        Some(MirScalarExpr::If {
527                            cond: Box::new(cond),
528                            then: Box::new(then),
529                            els: Box::new(els),
530                        })
531                    })
532                })
533            }),
534        }
535    }
536
537    /// Like `dereference` but only returns expressions that actually depend on
538    /// the original source.
539    fn strict_dereference(&self, expr: &MirScalarExpr) -> Option<MirScalarExpr> {
540        let derefed = self.dereference(expr);
541        match derefed {
542            Some(ref expr) if !expr.support().is_empty() => derefed,
543            _ => None,
544        }
545    }
546
547    /// Merge two constraints to find a constraint that satisfies both inputs.
548    ///
549    /// This method returns nothing if no columns are in common (either because
550    /// difference sources are identified, or just no columns in common) and it
551    /// intersects bindings and the `exact` bit.
552    fn meet(&self, other: &Self) -> Option<Self> {
553        if self.id == other.id {
554            let resulting_projection = self
555                .dereferenced_projection
556                .iter()
557                .zip_eq(other.dereferenced_projection.iter())
558                .map(|(e1, e2)| if e1 == e2 { e1.clone() } else { None })
559                .collect_vec();
560            if resulting_projection.iter().any(|e| e.is_some()) {
561                Some(ProvInfo {
562                    id: self.id,
563                    dereferenced_projection: resulting_projection,
564                    exact: self.exact && other.exact,
565                })
566            } else {
567                None
568            }
569        } else {
570            None
571        }
572    }
573
574    /// Check if all entries of the dereferenced projection are missing.
575    ///
576    /// If this is the case keeping the `ProvInfo` entry around is meaningless.
577    fn is_trivial(&self) -> bool {
578        all![
579            !self.dereferenced_projection.is_empty(),
580            self.dereferenced_projection.iter().all(|x| x.is_none()),
581        ]
582    }
583}
584
585/// Attempts to find column bindings that make `input` redundant.
586///
587/// This method attempts to determine that `input` may be redundant by searching
588/// the join structure for another relation `other` with provenance that contains some
589/// provenance of `input`, and keys for `input` that are equated by the join to the
590/// corresponding columns of `other` under their provenance. The `input` provenance
591/// must also have its `exact` bit set.
592///
593/// In these circumstances, the claim is that because the key columns are equated and
594/// determine non-key columns, any matches between `input` and
595/// `other` will neither introduce new information to `other`, nor restrict the rows
596/// of `other`, nor alter their multplicity.
597fn find_redundancy(
598    input: usize,
599    keys: &[Vec<usize>],
600    input_mapper: &JoinInputMapper,
601    equivalences: &[Vec<MirScalarExpr>],
602    input_provs: &[Vec<ProvInfo>],
603) -> Option<Vec<MirScalarExpr>> {
604    // Whether the `equivalence` contains an expression that only references
605    // `input` that leads to the same as `root_expr` once dereferenced.
606    let contains_equivalent_expr_from_input = |equivalence: &[MirScalarExpr],
607                                               root_expr: &MirScalarExpr,
608                                               input: usize,
609                                               provenance: &ProvInfo|
610     -> bool {
611        equivalence.iter().any(|expr| {
612            Some(input) == input_mapper.single_input(expr)
613                && provenance
614                    .dereference(&input_mapper.map_expr_to_local(expr.clone()))
615                    .as_ref()
616                    == Some(root_expr)
617        })
618    };
619    for input_prov in input_provs[input].iter() {
620        // We can only elide if the input contains all records, and binds all columns.
621        if input_prov.exact
622            && input_prov
623                .dereferenced_projection
624                .iter()
625                .all(|e| e.is_some())
626        {
627            // examine all *other* inputs that have not been removed...
628            for other in (0..input_mapper.total_inputs()).filter(|other| other != &input) {
629                for other_prov in input_provs[other].iter().filter(|p| p.id == input_prov.id) {
630                    let all_key_columns_equated = |key: &Vec<usize>| {
631                        key.iter().all(|key_col| {
632                            // The root expression behind the key column, ie.
633                            // the expression re-written in terms of elements in
634                            // the projection of the Get operator.
635                            let root_expr =
636                                input_prov.dereference(&MirScalarExpr::column(*key_col));
637                            // Check if there is a join equivalence that joins
638                            // 'input' and 'other' on expressions that lead to
639                            // the same root expression as the key column.
640                            root_expr.as_ref().map_or(false, |root_expr| {
641                                equivalences.iter().any(|equivalence| {
642                                    all![
643                                        contains_equivalent_expr_from_input(
644                                            equivalence,
645                                            root_expr,
646                                            input,
647                                            input_prov,
648                                        ),
649                                        contains_equivalent_expr_from_input(
650                                            equivalence,
651                                            root_expr,
652                                            other,
653                                            other_prov,
654                                        ),
655                                    ]
656                                })
657                            })
658                        })
659                    };
660
661                    // Find an unique key for input that has all columns equated to other.
662                    if keys.iter().any(all_key_columns_equated) {
663                        // Find out whether we can produce input's projection strictly with
664                        // elements in other's projection.
665                        let expressions = input_prov
666                            .dereferenced_projection
667                            .iter()
668                            .enumerate()
669                            .flat_map(|(c, _)| {
670                                // Check if the expression under input's 'c' column can be built
671                                // with elements in other's projection.
672                                input_prov.dereferenced_projection[c].as_ref().map_or(
673                                    None,
674                                    |root_expr| {
675                                        try_build_expression_using_other(
676                                            root_expr,
677                                            other,
678                                            other_prov,
679                                            input_mapper,
680                                        )
681                                    },
682                                )
683                            })
684                            .collect_vec();
685                        if expressions.len() == input_prov.dereferenced_projection.len() {
686                            return Some(expressions);
687                        }
688                    }
689                }
690            }
691        }
692    }
693
694    None
695}
696
697/// Tries to build `root_expr` using elements from other's projection.
698fn try_build_expression_using_other(
699    root_expr: &MirScalarExpr,
700    other: usize,
701    other_prov: &ProvInfo,
702    input_mapper: &JoinInputMapper,
703) -> Option<MirScalarExpr> {
704    if root_expr.is_literal() {
705        return Some(root_expr.clone());
706    }
707
708    // Check if 'other' projects a column that lead to `root_expr`.
709    for (other_col, derefed) in other_prov.dereferenced_projection.iter().enumerate() {
710        if let Some(derefed) = derefed {
711            if derefed == root_expr {
712                return Some(MirScalarExpr::column(
713                    input_mapper.map_column_to_global(other_col, other),
714                ));
715            }
716        }
717    }
718
719    // Otherwise, try to build root_expr's sub-expressions recursively
720    // other's projection.
721    match root_expr {
722        MirScalarExpr::Column(_, _) => None,
723        MirScalarExpr::CallUnary { func, expr } => {
724            try_build_expression_using_other(expr, other, other_prov, input_mapper).and_then(
725                |expr| {
726                    Some(MirScalarExpr::CallUnary {
727                        func: func.clone(),
728                        expr: Box::new(expr),
729                    })
730                },
731            )
732        }
733        MirScalarExpr::CallBinary { func, expr1, expr2 } => {
734            try_build_expression_using_other(expr1, other, other_prov, input_mapper).and_then(
735                |expr1| {
736                    try_build_expression_using_other(expr2, other, other_prov, input_mapper)
737                        .and_then(|expr2| {
738                            Some(MirScalarExpr::CallBinary {
739                                func: func.clone(),
740                                expr1: Box::new(expr1),
741                                expr2: Box::new(expr2),
742                            })
743                        })
744                },
745            )
746        }
747        MirScalarExpr::CallVariadic { func, exprs } => {
748            let new_exprs = exprs
749                .iter()
750                .flat_map(|e| try_build_expression_using_other(e, other, other_prov, input_mapper))
751                .collect_vec();
752            if new_exprs.len() == exprs.len() {
753                Some(MirScalarExpr::CallVariadic {
754                    func: func.clone(),
755                    exprs: new_exprs,
756                })
757            } else {
758                None
759            }
760        }
761        MirScalarExpr::Literal(..) | MirScalarExpr::CallUnmaterializable(..) => {
762            Some(root_expr.clone())
763        }
764        MirScalarExpr::If { cond, then, els } => {
765            try_build_expression_using_other(cond, other, other_prov, input_mapper).and_then(
766                |cond| {
767                    try_build_expression_using_other(then, other, other_prov, input_mapper)
768                        .and_then(|then| {
769                            try_build_expression_using_other(els, other, other_prov, input_mapper)
770                                .and_then(|els| {
771                                    Some(MirScalarExpr::If {
772                                        cond: Box::new(cond),
773                                        then: Box::new(then),
774                                        els: Box::new(els),
775                                    })
776                                })
777                        })
778                },
779            )
780        }
781    }
782}
783
/// A context of `ProvInfo` vectors associated with bindings that might still be
/// referenced.
///
/// `uses` counts the `Get` references to each local binding. When
/// [`ProvInfoCtx::remove_uses`] drops a count to zero, the matching `lets`
/// entry is discarded.
#[derive(Debug, Default)]
pub struct ProvInfoCtx {
    /// Number of `Get` references to each [`LocalId`] in the remaining subtree.
    ///
    /// Entries from the `lets` map that are no longer used can be pruned.
    uses: BTreeMap<LocalId, usize>,
    /// [`ProvInfo`] vectors associated with let bindings in scope.
    lets: BTreeMap<LocalId, Vec<ProvInfo>>,
}
795
796impl ProvInfoCtx {
797    /// Extend the `uses` map by the `LocalId`s used in `expr`.
798    pub fn extend_uses(&mut self, expr: &MirRelationExpr) {
799        expr.visit_pre(&mut |expr: &MirRelationExpr| match expr {
800            MirRelationExpr::Get {
801                id: Id::Local(id), ..
802            } => {
803                let count = self.uses.entry(id.clone()).or_insert(0_usize);
804                *count += 1;
805            }
806            _ => (),
807        });
808    }
809
810    /// Decrement `uses` entries by the `LocalId`s used in `expr` and remove
811    /// `lets` entries for `uses` that reset to zero.
812    pub fn remove_uses(&mut self, expr: &MirRelationExpr) {
813        let mut worklist = vec![expr];
814        while let Some(expr) = worklist.pop() {
815            if let MirRelationExpr::Get {
816                id: Id::Local(id), ..
817            } = expr
818            {
819                if let Some(count) = self.uses.get_mut(id) {
820                    if *count > 0 {
821                        *count -= 1;
822                    }
823                    if *count == 0 {
824                        if self.lets.remove(id).is_none() {
825                            soft_panic_or_log!("ctx.lets[{id}] should exist");
826                        }
827                    }
828                } else {
829                    soft_panic_or_log!("ctx.uses[{id}] should exist");
830                }
831            }
832            match expr {
833                MirRelationExpr::Let { .. } | MirRelationExpr::LetRec { .. } => {
834                    // When traversing the tree, don't descend into
835                    // `Let`/`LetRec` sub-terms in order to avoid double
836                    // counting (those are handled by remove_uses calls of
837                    // RedundantJoin::action on subterms that were already
838                    // visited because the action works bottom-up).
839                }
840                _ => {
841                    worklist.extend(expr.children().rev());
842                }
843            }
844        }
845    }
846
847    /// Get the `ProvInfo` vector for `id` from the context.
848    pub fn get(&self, id: &LocalId) -> Option<&Vec<ProvInfo>> {
849        self.lets.get(id)
850    }
851
852    /// Extend the context with the `id: prov_infos` entry.
853    pub fn insert(&mut self, id: LocalId, prov_infos: Vec<ProvInfo>) -> Option<Vec<ProvInfo>> {
854        self.lets.insert(id, prov_infos)
855    }
856
857    /// Remove the entry identified by `id` from the context.
858    pub fn remove(&mut self, id: &LocalId) -> Option<Vec<ProvInfo>> {
859        self.lets.remove(id)
860    }
861}