mz_transform/
redundant_join.rs

1// Copyright Materialize, Inc. and contributors. All rights reserved.
2//
3// Use of this software is governed by the Business Source License
4// included in the LICENSE file.
5//
6// As of the Change Date specified in that file, in accordance with
7// the Business Source License, use of this software will be governed
8// by the Apache License, Version 2.0.
9
10//! Remove redundant collections of distinct elements from joins.
11//!
12//! This analysis looks for joins in which one collection contains distinct
13//! elements, and it can be determined that the join would only restrict the
14//! results, and that the restriction is redundant (the other results would
15//! not be reduced by the join).
16//!
//! This type of optimization shows up often in subqueries: decorrelation
//! introduces collections of distinct elements, and those distinct
//! collections are frequently joined back against the results afterwards.
20
21// If statements seem a bit clearer in this case. Specialized methods
22// that replace simple and common alternatives frustrate developers.
23#![allow(clippy::comparison_chain, clippy::filter_next)]
24
25use std::collections::BTreeMap;
26
27use itertools::Itertools;
28use mz_expr::visit::Visit;
29use mz_expr::{Id, JoinInputMapper, LocalId, MirRelationExpr, MirScalarExpr, RECURSION_LIMIT};
30use mz_ore::stack::{CheckedRecursion, RecursionGuard};
31use mz_ore::{assert_none, soft_panic_or_log};
32
33use crate::{TransformCtx, all};
34
/// Remove redundant collections of distinct elements from joins.
///
/// The transform walks the plan recursively; `recursion_guard` bounds the
/// depth of that walk (the `Default` impl creates it with `RECURSION_LIMIT`).
#[derive(Debug)]
pub struct RedundantJoin {
    /// Guard consulted by `CheckedRecursion::checked_recur` while descending.
    recursion_guard: RecursionGuard,
}
40
41impl Default for RedundantJoin {
42    fn default() -> RedundantJoin {
43        RedundantJoin {
44            recursion_guard: RecursionGuard::with_limit(RECURSION_LIMIT),
45        }
46    }
47}
48
/// Exposes the transform's recursion guard so that `checked_recur` (used by
/// `RedundantJoin::action`) can bound the depth of the recursive descent.
impl CheckedRecursion for RedundantJoin {
    /// Returns the guard created in `Default::default`.
    fn recursion_guard(&self) -> &RecursionGuard {
        &self.recursion_guard
    }
}
54
55impl crate::Transform for RedundantJoin {
56    fn name(&self) -> &'static str {
57        "RedundantJoin"
58    }
59
60    #[mz_ore::instrument(
61        target = "optimizer",
62        level = "debug",
63        fields(path.segment = "redundant_join")
64    )]
65    fn actually_perform_transform(
66        &self,
67        relation: &mut MirRelationExpr,
68        _: &mut TransformCtx,
69    ) -> Result<(), crate::TransformError> {
70        let mut ctx = ProvInfoCtx::default();
71        ctx.extend_uses(relation);
72        let result = self.action(relation, &mut ctx);
73        mz_repr::explain::trace_plan(&*relation);
74        result.map(|_| ())
75    }
76}
77
impl RedundantJoin {
    /// Remove redundant collections of distinct elements from joins.
    ///
    /// This method tracks "provenance" information for each collection,
    /// those being column-wise relationships to identified collections
    /// (either imported collections, or let-bound collections). These
    /// relationships state that when projected on to these columns, the
    /// records of the one collection are contained in the records of the
    /// identified collection.
    ///
    /// This provenance information is then used for the `MirRelationExpr::Join`
    /// variant to remove "redundant" joins, those that can be determined to
    /// neither restrict nor augment one of the input relations. Consult the
    /// `find_redundancy` function and its documentation for more detail.
    ///
    /// Returns the provenance of `relation` after it has been rewritten, with
    /// trivial entries (see `ProvInfo::is_trivial`) removed. `ctx` holds the
    /// provenance of the let bindings currently in scope: an entry is inserted
    /// when a `Let`/`LetRec` is entered and removed when it is left.
    ///
    /// # Errors
    ///
    /// Propagates errors from `checked_recur` (NOTE(review): presumably raised
    /// when the recursion limit is exceeded; confirm against `CheckedRecursion`)
    /// and from `visit_mut_post` while rewriting join equivalences.
    pub fn action(
        &self,
        relation: &mut MirRelationExpr,
        ctx: &mut ProvInfoCtx,
    ) -> Result<Vec<ProvInfo>, crate::TransformError> {
        let mut result = self.checked_recur(|_| {
            match relation {
                MirRelationExpr::Let { id, value, body } => {
                    // Recursively determine provenance of the value.
                    let value_prov = self.action(value, ctx)?;
                    // Clear uses from the just visited binding definition.
                    ctx.remove_uses(value);

                    // Extend the lets context with an entry for this binding.
                    let prov_old = ctx.insert(*id, value_prov);
                    assert_none!(prov_old, "No shadowing");

                    // Determine provenance of the body.
                    let result = self.action(body, ctx)?;
                    ctx.remove_uses(body);

                    // Remove the lets entry for this binding from the context.
                    ctx.remove(id);

                    Ok(result)
                }

                MirRelationExpr::LetRec {
                    ids,
                    values,
                    limits: _,
                    body,
                } => {
                    // As a first approximation, we naively extend the `lets`
                    // context with the empty vec![] for each id.
                    for id in ids.iter() {
                        let prov_old = ctx.insert(*id, vec![]);
                        assert_none!(prov_old, "No shadowing");
                    }

                    // In other words, we don't attempt to derive additional
                    // provenance information for a binding from its `value`.
                    //
                    // We descend into the values and the body with the naively
                    // extended context. The provenance computed for each value
                    // is discarded; only the rewrites performed inside it are kept.
                    for value in values.iter_mut() {
                        self.action(value, ctx)?;
                    }
                    // Clear uses from the just visited recursive binding
                    // definitions.
                    for value in values.iter_mut() {
                        ctx.remove_uses(value);
                    }
                    let result = self.action(body, ctx)?;
                    ctx.remove_uses(body);

                    // Remove the lets entries for all ids.
                    for id in ids.iter() {
                        ctx.remove(id);
                    }

                    Ok(result)
                }

                MirRelationExpr::Get { id, typ, .. } => {
                    if let Id::Local(id) = id {
                        // Extract the value provenance (this should always exist).
                        let mut val_info = ctx.get(id).cloned().unwrap_or_else(|| {
                            soft_panic_or_log!("no ctx entry for LocalId {id}");
                            vec![]
                        });
                        // Add information about being exactly this let binding too.
                        val_info.push(ProvInfo::make_leaf(Id::Local(*id), typ.arity()));
                        Ok(val_info)
                    } else {
                        // Add information about being exactly this GlobalId reference.
                        Ok(vec![ProvInfo::make_leaf(*id, typ.arity())])
                    }
                }

                MirRelationExpr::Join {
                    inputs,
                    equivalences,
                    implementation,
                } => {
                    // This logic first applies what it has learned about its input provenance,
                    // and if it finds a redundant join input it removes it. In that case, it
                    // also fails to produce exciting provenance information, partly out of
                    // laziness and the challenge of ensuring it is correct. Instead, if it is
                    // unable to find a redundant join it produces meaningful provenance information.

                    // Recursively apply transformation, and determine the provenance of inputs.
                    let mut input_prov = Vec::new();
                    for i in inputs.iter_mut() {
                        input_prov.push(self.action(i, ctx)?);
                    }

                    // Determine useful information about the structure of the inputs.
                    let mut input_types = inputs.iter().map(|i| i.typ()).collect::<Vec<_>>();
                    let old_input_mapper = JoinInputMapper::new_from_input_types(&input_types);

                    // If we find an input that can be removed, we should do so!
                    // We only do this once per invocation to keep our sanity, but we could
                    // rewrite it to iterate. We can avoid looking for any relation that
                    // does not have keys, as it cannot be redundant in that case.
                    // Candidates are tried from the last input to the first.
                    if let Some((remove_input_idx, mut bindings)) = (0..input_types.len())
                        .rev()
                        .filter(|i| !input_types[*i].keys.is_empty())
                        .flat_map(|i| {
                            find_redundancy(
                                i,
                                &input_types[i].keys,
                                &old_input_mapper,
                                equivalences,
                                &input_prov[..],
                            )
                            .map(|b| (i, b))
                        })
                        .next()
                    {
                        // `bindings` holds one expression per column of the removed
                        // input. Judging by the offset adjustment below, they are
                        // written in terms of the pre-removal join columns.

                        // Clear uses from the removed input.
                        ctx.remove_uses(&inputs[remove_input_idx]);

                        inputs.remove(remove_input_idx);
                        input_types.remove(remove_input_idx);

                        // Update the column offsets in the binding expressions to catch
                        // up with the removal of `remove_input_idx`.
                        for expr in bindings.iter_mut() {
                            expr.visit_pre_mut(|e| {
                                if let MirScalarExpr::Column(c, _) = e {
                                    let (_local_col, input_relation) =
                                        old_input_mapper.map_column_to_local(*c);
                                    if input_relation > remove_input_idx {
                                        *c -= old_input_mapper.input_arity(remove_input_idx);
                                    }
                                }
                            });
                        }

                        // Replace column references from `remove_input_idx` with the corresponding
                        // binding expression. Update the offsets of the column references
                        // from inputs after `remove_input_idx`.
                        for equivalence in equivalences.iter_mut() {
                            for expr in equivalence.iter_mut() {
                                expr.visit_mut_post(&mut |e| {
                                    if let MirScalarExpr::Column(c, _) = e {
                                        let (local_col, input_relation) =
                                            old_input_mapper.map_column_to_local(*c);
                                        if input_relation == remove_input_idx {
                                            *e = bindings[local_col].clone();
                                        } else if input_relation > remove_input_idx {
                                            *c -= old_input_mapper.input_arity(remove_input_idx);
                                        }
                                    }
                                })?;
                            }
                        }

                        mz_expr::canonicalize::canonicalize_equivalences(
                            equivalences,
                            input_types.iter().map(|t| &t.column_types),
                        );

                        // Build a projection that leaves the binding expressions in the same
                        // position as the columns of the removed join input they are replacing.
                        // The bindings are appended after the surviving join columns by `map`
                        // below, i.e. at positions `new_join_arity..`.
                        let new_input_mapper = JoinInputMapper::new_from_input_types(&input_types);
                        let mut projection = Vec::new();
                        let new_join_arity = new_input_mapper.total_columns();
                        for i in 0..old_input_mapper.total_inputs() {
                            if i != remove_input_idx {
                                projection.extend(
                                    new_input_mapper.global_columns(if i < remove_input_idx {
                                        i
                                    } else {
                                        i - 1
                                    }),
                                );
                            } else {
                                projection.extend(new_join_arity..new_join_arity + bindings.len());
                            }
                        }

                        // Unset implementation, as irrevocably hosed by this transformation.
                        *implementation = mz_expr::JoinImplementation::Unimplemented;

                        // Wrap the smaller join so its output columns match the original join.
                        *relation = relation.take_dangerous().map(bindings).project(projection);
                        // The projection will gum up provenance reasoning anyhow, so don't work hard.
                        // We will return to this expression again with the same analysis.
                        Ok(Vec::new())
                    } else {
                        // Provenance information should be the union of input provenance information,
                        // with columns updated. Because rows may be dropped in the join, all `exact`
                        // bits should be un-set.
                        let mut results = Vec::new();
                        for (input, input_prov) in input_prov.into_iter().enumerate() {
                            for mut prov in input_prov {
                                prov.exact = false;
                                // Re-index the input-local projection into join-global
                                // positions; columns of other inputs stay `None`.
                                let mut projection = vec![None; old_input_mapper.total_columns()];
                                for (local_col, global_col) in
                                    old_input_mapper.global_columns(input).enumerate()
                                {
                                    projection[global_col]
                                        .clone_from(&prov.dereferenced_projection[local_col]);
                                }
                                prov.dereferenced_projection = projection;
                                results.push(prov);
                            }
                        }
                        Ok(results)
                    }
                }

                MirRelationExpr::Filter { input, .. } => {
                    // Filter may drop records, and so we unset `exact`.
                    let mut result = self.action(input, ctx)?;
                    for prov in result.iter_mut() {
                        prov.exact = false;
                    }
                    Ok(result)
                }

                MirRelationExpr::Map { input, scalars } => {
                    // Each new scalar is dereferenced against the projection as
                    // extended so far, so scalars may refer to earlier map outputs.
                    let mut result = self.action(input, ctx)?;
                    for prov in result.iter_mut() {
                        for scalar in scalars.iter() {
                            let dereferenced_scalar = prov.strict_dereference(scalar);
                            prov.dereferenced_projection.push(dereferenced_scalar);
                        }
                    }
                    Ok(result)
                }

                MirRelationExpr::Union { base, inputs } => {
                    let mut prov = self.action(base, ctx)?;
                    for input in inputs {
                        let input_prov = self.action(input, ctx)?;
                        // To merge a new list of provenances, we look at the cross
                        // product of things we might know about each source.
                        // TODO(mcsherry): this can be optimized to use datastructures
                        // keyed by the source identifier.
                        let mut new_prov = Vec::new();
                        for l in prov {
                            new_prov.extend(input_prov.iter().flat_map(|r| l.meet(r)))
                        }
                        prov = new_prov;
                    }
                    Ok(prov)
                }

                // Constants have no source collection to relate to.
                MirRelationExpr::Constant { .. } => Ok(Vec::new()),

                MirRelationExpr::Reduce {
                    input,
                    group_key,
                    aggregates,
                    ..
                } => {
                    // Reduce yields its first few columns as a key, and produces
                    // all key tuples that were present in its input.
                    let mut result = self.action(input, ctx)?;
                    for prov in result.iter_mut() {
                        let mut projection = group_key
                            .iter()
                            .map(|key| prov.strict_dereference(key))
                            .collect_vec();
                        projection.extend((0..aggregates.len()).map(|_| None));
                        prov.dereferenced_projection = projection;
                    }
                    // TODO: For min, max aggregates, we could preserve provenance
                    // if the expression references a column. We would need to un-set
                    // the `exact` bit in that case, and so we would want to keep both
                    // sets of provenance information.
                    Ok(result)
                }

                MirRelationExpr::Threshold { input } => {
                    // Threshold may drop records, and so we unset `exact`.
                    let mut result = self.action(input, ctx)?;
                    for prov in result.iter_mut() {
                        prov.exact = false;
                    }
                    Ok(result)
                }

                MirRelationExpr::TopK { input, .. } => {
                    // TopK may drop records, and so we unset `exact`.
                    let mut result = self.action(input, ctx)?;
                    for prov in result.iter_mut() {
                        prov.exact = false;
                    }
                    Ok(result)
                }

                MirRelationExpr::Project { input, outputs } => {
                    // Projections re-order, drop, and duplicate columns,
                    // but they neither drop rows nor invent values.
                    let mut result = self.action(input, ctx)?;
                    for prov in result.iter_mut() {
                        let projection = outputs
                            .iter()
                            .map(|c| prov.dereference(&MirScalarExpr::column(*c)))
                            .collect_vec();
                        prov.dereferenced_projection = projection;
                    }
                    Ok(result)
                }

                MirRelationExpr::FlatMap {
                    input,
                    func,
                    exprs: _,
                } => {
                    // FlatMap may drop records, and so we unset `exact`.
                    // Its appended output columns have no known provenance.
                    let mut result = self.action(input, ctx)?;
                    for prov in result.iter_mut() {
                        prov.exact = false;
                        prov.dereferenced_projection
                            .extend((0..func.output_arity()).map(|_| None));
                    }
                    Ok(result)
                }

                MirRelationExpr::Negate { input } => {
                    // Negate does not guarantee that the multiplicity of
                    // each source record is at least one. This could have
                    // been a problem in `Union`, where we might report
                    // that the union of positive and negative records is
                    // "exact": cancellations would make this false.
                    let mut result = self.action(input, ctx)?;
                    for prov in result.iter_mut() {
                        prov.exact = false;
                    }
                    Ok(result)
                }

                // Provenance passes through `ArrangeBy` unchanged.
                MirRelationExpr::ArrangeBy { input, .. } => self.action(input, ctx),
            }
        })?;
        // Drop entries that bind no column at all; they carry no information.
        result.retain(|info| !info.is_trivial());

        // Uncomment the following lines to trace the individual steps:
        // NOTE(review): there is no `lets` binding in this scope any more (the
        // let context is `ctx`), so the third line needs updating before use.
        // println!("{}", relation.pretty());
        // println!("result = {result:?}");
        // println!("lets: {lets:?}");
        // println!("---------------------");

        Ok(result)
    }
}
442
/// A relationship between a collection's columns and some source columns.
///
/// An instance of this type indicates that some of the bearer's columns
/// derive from `id`. In particular, the non-`None` elements in
/// `dereferenced_projection` correspond to columns that can be derived
/// from `id`'s projection.
///
/// The guarantee is that projected on to these columns, the distinct values
/// of the bearer are contained in the set of distinct values of projected
/// columns of `id`. In the case that `exact` is set, the two sets are equal.
///
/// Note that the derived `Ord`/`PartialOrd` compare fields in declaration
/// order (`id`, then `dereferenced_projection`, then `exact`).
#[derive(Clone, Debug, Ord, Eq, PartialOrd, PartialEq)]
pub struct ProvInfo {
    /// The Id (local or global) of the source.
    id: Id,
    /// The projection of the bearer written in terms of the columns projected
    /// by the underlying Get operator: entry `i` describes column `i` of the
    /// bearer. Set to `None` for columns that cannot be expressed as a scalar
    /// expression referencing only columns of the underlying Get operator.
    dereferenced_projection: Vec<Option<MirScalarExpr>>,
    /// If true, all distinct projected source rows are present in the rows of
    /// the projection of the current collection. This constraint is lost as soon
    /// as a transformation may drop records.
    exact: bool,
}
467
468impl ProvInfo {
469    fn make_leaf(id: Id, arity: usize) -> Self {
470        Self {
471            id,
472            dereferenced_projection: (0..arity)
473                .map(|c| Some(MirScalarExpr::column(c)))
474                .collect::<Vec<_>>(),
475            exact: true,
476        }
477    }
478
479    /// Rewrite `expr` so it refers to the columns of the original source instead
480    /// of the columns of the projected source.
481    fn dereference(&self, expr: &MirScalarExpr) -> Option<MirScalarExpr> {
482        match expr {
483            MirScalarExpr::Column(c, _) => {
484                if let Some(expr) = &self.dereferenced_projection[*c] {
485                    Some(expr.clone())
486                } else {
487                    None
488                }
489            }
490            MirScalarExpr::CallUnary { func, expr } => self.dereference(expr).and_then(|expr| {
491                Some(MirScalarExpr::CallUnary {
492                    func: func.clone(),
493                    expr: Box::new(expr),
494                })
495            }),
496            MirScalarExpr::CallBinary { func, expr1, expr2 } => {
497                self.dereference(expr1).and_then(|expr1| {
498                    self.dereference(expr2).and_then(|expr2| {
499                        Some(MirScalarExpr::CallBinary {
500                            func: func.clone(),
501                            expr1: Box::new(expr1),
502                            expr2: Box::new(expr2),
503                        })
504                    })
505                })
506            }
507            MirScalarExpr::CallVariadic { func, exprs } => {
508                let new_exprs = exprs.iter().flat_map(|e| self.dereference(e)).collect_vec();
509                if new_exprs.len() == exprs.len() {
510                    Some(MirScalarExpr::CallVariadic {
511                        func: func.clone(),
512                        exprs: new_exprs,
513                    })
514                } else {
515                    None
516                }
517            }
518            MirScalarExpr::Literal(..) | MirScalarExpr::CallUnmaterializable(..) => {
519                Some(expr.clone())
520            }
521            MirScalarExpr::If { cond, then, els } => self.dereference(cond).and_then(|cond| {
522                self.dereference(then).and_then(|then| {
523                    self.dereference(els).and_then(|els| {
524                        Some(MirScalarExpr::If {
525                            cond: Box::new(cond),
526                            then: Box::new(then),
527                            els: Box::new(els),
528                        })
529                    })
530                })
531            }),
532        }
533    }
534
535    /// Like `dereference` but only returns expressions that actually depend on
536    /// the original source.
537    fn strict_dereference(&self, expr: &MirScalarExpr) -> Option<MirScalarExpr> {
538        let derefed = self.dereference(expr);
539        match derefed {
540            Some(ref expr) if !expr.support().is_empty() => derefed,
541            _ => None,
542        }
543    }
544
545    /// Merge two constraints to find a constraint that satisfies both inputs.
546    ///
547    /// This method returns nothing if no columns are in common (either because
548    /// difference sources are identified, or just no columns in common) and it
549    /// intersects bindings and the `exact` bit.
550    fn meet(&self, other: &Self) -> Option<Self> {
551        if self.id == other.id {
552            let resulting_projection = self
553                .dereferenced_projection
554                .iter()
555                .zip(other.dereferenced_projection.iter())
556                .map(|(e1, e2)| if e1 == e2 { e1.clone() } else { None })
557                .collect_vec();
558            if resulting_projection.iter().any(|e| e.is_some()) {
559                Some(ProvInfo {
560                    id: self.id,
561                    dereferenced_projection: resulting_projection,
562                    exact: self.exact && other.exact,
563                })
564            } else {
565                None
566            }
567        } else {
568            None
569        }
570    }
571
572    /// Check if all entries of the dereferenced projection are missing.
573    ///
574    /// If this is the case keeping the `ProvInfo` entry around is meaningless.
575    fn is_trivial(&self) -> bool {
576        all![
577            !self.dereferenced_projection.is_empty(),
578            self.dereferenced_projection.iter().all(|x| x.is_none()),
579        ]
580    }
581}
582
583/// Attempts to find column bindings that make `input` redundant.
584///
585/// This method attempts to determine that `input` may be redundant by searching
586/// the join structure for another relation `other` with provenance that contains some
587/// provenance of `input`, and keys for `input` that are equated by the join to the
588/// corresponding columns of `other` under their provenance. The `input` provenance
589/// must also have its `exact` bit set.
590///
591/// In these circumstances, the claim is that because the key columns are equated and
592/// determine non-key columns, any matches between `input` and
593/// `other` will neither introduce new information to `other`, nor restrict the rows
594/// of `other`, nor alter their multplicity.
595fn find_redundancy(
596    input: usize,
597    keys: &[Vec<usize>],
598    input_mapper: &JoinInputMapper,
599    equivalences: &[Vec<MirScalarExpr>],
600    input_provs: &[Vec<ProvInfo>],
601) -> Option<Vec<MirScalarExpr>> {
602    // Whether the `equivalence` contains an expression that only references
603    // `input` that leads to the same as `root_expr` once dereferenced.
604    let contains_equivalent_expr_from_input = |equivalence: &[MirScalarExpr],
605                                               root_expr: &MirScalarExpr,
606                                               input: usize,
607                                               provenance: &ProvInfo|
608     -> bool {
609        equivalence.iter().any(|expr| {
610            Some(input) == input_mapper.single_input(expr)
611                && provenance
612                    .dereference(&input_mapper.map_expr_to_local(expr.clone()))
613                    .as_ref()
614                    == Some(root_expr)
615        })
616    };
617    for input_prov in input_provs[input].iter() {
618        // We can only elide if the input contains all records, and binds all columns.
619        if input_prov.exact
620            && input_prov
621                .dereferenced_projection
622                .iter()
623                .all(|e| e.is_some())
624        {
625            // examine all *other* inputs that have not been removed...
626            for other in (0..input_mapper.total_inputs()).filter(|other| other != &input) {
627                for other_prov in input_provs[other].iter().filter(|p| p.id == input_prov.id) {
628                    let all_key_columns_equated = |key: &Vec<usize>| {
629                        key.iter().all(|key_col| {
630                            // The root expression behind the key column, ie.
631                            // the expression re-written in terms of elements in
632                            // the projection of the Get operator.
633                            let root_expr =
634                                input_prov.dereference(&MirScalarExpr::column(*key_col));
635                            // Check if there is a join equivalence that joins
636                            // 'input' and 'other' on expressions that lead to
637                            // the same root expression as the key column.
638                            root_expr.as_ref().map_or(false, |root_expr| {
639                                equivalences.iter().any(|equivalence| {
640                                    all![
641                                        contains_equivalent_expr_from_input(
642                                            equivalence,
643                                            root_expr,
644                                            input,
645                                            input_prov,
646                                        ),
647                                        contains_equivalent_expr_from_input(
648                                            equivalence,
649                                            root_expr,
650                                            other,
651                                            other_prov,
652                                        ),
653                                    ]
654                                })
655                            })
656                        })
657                    };
658
659                    // Find an unique key for input that has all columns equated to other.
660                    if keys.iter().any(all_key_columns_equated) {
661                        // Find out whether we can produce input's projection strictly with
662                        // elements in other's projection.
663                        let expressions = input_prov
664                            .dereferenced_projection
665                            .iter()
666                            .enumerate()
667                            .flat_map(|(c, _)| {
668                                // Check if the expression under input's 'c' column can be built
669                                // with elements in other's projection.
670                                input_prov.dereferenced_projection[c].as_ref().map_or(
671                                    None,
672                                    |root_expr| {
673                                        try_build_expression_using_other(
674                                            root_expr,
675                                            other,
676                                            other_prov,
677                                            input_mapper,
678                                        )
679                                    },
680                                )
681                            })
682                            .collect_vec();
683                        if expressions.len() == input_prov.dereferenced_projection.len() {
684                            return Some(expressions);
685                        }
686                    }
687                }
688            }
689        }
690    }
691
692    None
693}
694
695/// Tries to build `root_expr` using elements from other's projection.
696fn try_build_expression_using_other(
697    root_expr: &MirScalarExpr,
698    other: usize,
699    other_prov: &ProvInfo,
700    input_mapper: &JoinInputMapper,
701) -> Option<MirScalarExpr> {
702    if root_expr.is_literal() {
703        return Some(root_expr.clone());
704    }
705
706    // Check if 'other' projects a column that lead to `root_expr`.
707    for (other_col, derefed) in other_prov.dereferenced_projection.iter().enumerate() {
708        if let Some(derefed) = derefed {
709            if derefed == root_expr {
710                return Some(MirScalarExpr::column(
711                    input_mapper.map_column_to_global(other_col, other),
712                ));
713            }
714        }
715    }
716
717    // Otherwise, try to build root_expr's sub-expressions recursively
718    // other's projection.
719    match root_expr {
720        MirScalarExpr::Column(_, _) => None,
721        MirScalarExpr::CallUnary { func, expr } => {
722            try_build_expression_using_other(expr, other, other_prov, input_mapper).and_then(
723                |expr| {
724                    Some(MirScalarExpr::CallUnary {
725                        func: func.clone(),
726                        expr: Box::new(expr),
727                    })
728                },
729            )
730        }
731        MirScalarExpr::CallBinary { func, expr1, expr2 } => {
732            try_build_expression_using_other(expr1, other, other_prov, input_mapper).and_then(
733                |expr1| {
734                    try_build_expression_using_other(expr2, other, other_prov, input_mapper)
735                        .and_then(|expr2| {
736                            Some(MirScalarExpr::CallBinary {
737                                func: func.clone(),
738                                expr1: Box::new(expr1),
739                                expr2: Box::new(expr2),
740                            })
741                        })
742                },
743            )
744        }
745        MirScalarExpr::CallVariadic { func, exprs } => {
746            let new_exprs = exprs
747                .iter()
748                .flat_map(|e| try_build_expression_using_other(e, other, other_prov, input_mapper))
749                .collect_vec();
750            if new_exprs.len() == exprs.len() {
751                Some(MirScalarExpr::CallVariadic {
752                    func: func.clone(),
753                    exprs: new_exprs,
754                })
755            } else {
756                None
757            }
758        }
759        MirScalarExpr::Literal(..) | MirScalarExpr::CallUnmaterializable(..) => {
760            Some(root_expr.clone())
761        }
762        MirScalarExpr::If { cond, then, els } => {
763            try_build_expression_using_other(cond, other, other_prov, input_mapper).and_then(
764                |cond| {
765                    try_build_expression_using_other(then, other, other_prov, input_mapper)
766                        .and_then(|then| {
767                            try_build_expression_using_other(els, other, other_prov, input_mapper)
768                                .and_then(|els| {
769                                    Some(MirScalarExpr::If {
770                                        cond: Box::new(cond),
771                                        then: Box::new(then),
772                                        els: Box::new(els),
773                                    })
774                                })
775                        })
776                },
777            )
778        }
779    }
780}
781
/// A context of `ProvInfo` vectors associated with bindings that might still be
/// referenced.
///
/// The `uses` counters track how many `Get`s of each local binding remain in the
/// part of the tree still to be processed. When `remove_uses` drives a counter
/// to zero, the corresponding `lets` entry is dropped.
#[derive(Debug, Default)]
pub struct ProvInfoCtx {
    /// [`LocalId`] references in the remaining subtree, counted once per
    /// `Get` of that id.
    ///
    /// Entries from the `lets` map that are no longer used can be pruned.
    uses: BTreeMap<LocalId, usize>,
    /// [`ProvInfo`] vectors associated with the let bindings in scope.
    lets: BTreeMap<LocalId, Vec<ProvInfo>>,
}
793
794impl ProvInfoCtx {
795    /// Extend the `uses` map by the `LocalId`s used in `expr`.
796    pub fn extend_uses(&mut self, expr: &MirRelationExpr) {
797        expr.visit_pre(&mut |expr: &MirRelationExpr| match expr {
798            MirRelationExpr::Get {
799                id: Id::Local(id), ..
800            } => {
801                let count = self.uses.entry(id.clone()).or_insert(0_usize);
802                *count += 1;
803            }
804            _ => (),
805        });
806    }
807
808    /// Decrement `uses` entries by the `LocalId`s used in `expr` and remove
809    /// `lets` entries for `uses` that reset to zero.
810    pub fn remove_uses(&mut self, expr: &MirRelationExpr) {
811        let mut worklist = vec![expr];
812        while let Some(expr) = worklist.pop() {
813            if let MirRelationExpr::Get {
814                id: Id::Local(id), ..
815            } = expr
816            {
817                if let Some(count) = self.uses.get_mut(id) {
818                    if *count > 0 {
819                        *count -= 1;
820                    }
821                    if *count == 0 {
822                        if self.lets.remove(id).is_none() {
823                            soft_panic_or_log!("ctx.lets[{id}] should exist");
824                        }
825                    }
826                } else {
827                    soft_panic_or_log!("ctx.uses[{id}] should exist");
828                }
829            }
830            match expr {
831                MirRelationExpr::Let { .. } | MirRelationExpr::LetRec { .. } => {
832                    // When traversing the tree, don't descend into
833                    // `Let`/`LetRec` sub-terms in order to avoid double
834                    // counting (those are handled by remove_uses calls of
835                    // RedundantJoin::action on subterms that were already
836                    // visited because the action works bottom-up).
837                }
838                _ => {
839                    worklist.extend(expr.children().rev());
840                }
841            }
842        }
843    }
844
845    /// Get the `ProvInfo` vector for `id` from the context.
846    pub fn get(&self, id: &LocalId) -> Option<&Vec<ProvInfo>> {
847        self.lets.get(id)
848    }
849
850    /// Extend the context with the `id: prov_infos` entry.
851    pub fn insert(&mut self, id: LocalId, prov_infos: Vec<ProvInfo>) -> Option<Vec<ProvInfo>> {
852        self.lets.insert(id, prov_infos)
853    }
854
855    /// Remove the entry identified by `id` from the context.
856    pub fn remove(&mut self, id: &LocalId) -> Option<Vec<ProvInfo>> {
857        self.lets.remove(id)
858    }
859}