mz_expr/relation/join_input_mapper.rs
1// Copyright Materialize, Inc. and contributors. All rights reserved.
2//
3// Use of this software is governed by the Business Source License
4// included in the LICENSE file.
5//
6// As of the Change Date specified in that file, in accordance with
7// the Business Source License, use of this software will be governed
8// by the Apache License, Version 2.0.
9
10use std::collections::BTreeSet;
11use std::ops::Range;
12
13use itertools::Itertools;
14use mz_repr::SqlRelationType;
15
16use crate::visit::Visit;
17use crate::{MirRelationExpr, MirScalarExpr, VariadicFunc};
18
/// Any column in a join expression exists in two contexts:
/// 1) It has a position relative to the result of the join (global).
/// 2) It has a position relative to the specific input it came from (local).
///
/// This utility focuses on taking expressions that are in terms of
/// the local input and re-expressing them in global terms and vice versa.
/// For example, with inputs of arities `[2, 3]`, global column `3` is local
/// column `1` of input `1`.
///
/// Methods of this struct that take an argument `equivalences` are only
/// guaranteed to return a correct answer if equivalence classes are in
/// canonical form.
/// (See [`crate::relation::canonicalize::canonicalize_equivalences`].)
#[derive(Debug)]
pub struct JoinInputMapper {
    /// The number of columns per input, in join order. All other fields in
    /// this struct are derived using the information in this field.
    arities: Vec<usize>,
    /// Looks up which input each column belongs to: `input_relation[c]` is the
    /// index of the input that global column `c` comes from. Derived from
    /// `arities`; stored as a field to avoid recomputation.
    input_relation: Vec<usize>,
    /// The sum of the arities of the previous inputs in the join, i.e.
    /// `prior_arities[i]` is the global position of the first column of input
    /// `i`. Derived from `arities`; stored as a field to avoid recomputation.
    prior_arities: Vec<usize>,
}
41
42impl JoinInputMapper {
43    /// Creates a new `JoinInputMapper` and calculates the mapping of global context
44    /// columns to local context columns.
45    pub fn new(inputs: &[MirRelationExpr]) -> Self {
46        Self::new_from_input_arities(inputs.iter().map(|i| i.arity()))
47    }
48
49    /// Creates a new `JoinInputMapper` and calculates the mapping of global context
50    /// columns to local context columns. Using this method saves is more
51    /// efficient if input types have been pre-calculated
52    pub fn new_from_input_types(types: &[SqlRelationType]) -> Self {
53        Self::new_from_input_arities(types.iter().map(|t| t.column_types.len()))
54    }
55
56    /// Creates a new `JoinInputMapper` and calculates the mapping of global context
57    /// columns to local context columns. Using this method saves is more
58    /// efficient if input arities have been pre-calculated
59    pub fn new_from_input_arities<I>(arities: I) -> Self
60    where
61        I: IntoIterator<Item = usize>,
62    {
63        let arities = arities.into_iter().collect::<Vec<usize>>();
64        let mut offset = 0;
65        let mut prior_arities = Vec::new();
66        for input in 0..arities.len() {
67            prior_arities.push(offset);
68            offset += arities[input];
69        }
70
71        let input_relation = arities
72            .iter()
73            .enumerate()
74            .flat_map(|(r, a)| std::iter::repeat(r).take(*a))
75            .collect::<Vec<_>>();
76
77        JoinInputMapper {
78            arities,
79            input_relation,
80            prior_arities,
81        }
82    }
83
84    /// reports sum of the number of columns of each input
85    pub fn total_columns(&self) -> usize {
86        self.arities.iter().sum()
87    }
88
89    /// reports total numbers of inputs in the join
90    pub fn total_inputs(&self) -> usize {
91        self.arities.len()
92    }
93
94    /// Using the keys that came from each local input,
95    /// figures out which keys remain unique in the larger join
96    /// Currently, we only figure out a small subset of the keys that
97    /// can remain unique.
98    pub fn global_keys<'a, I>(
99        &self,
100        mut local_keys: I,
101        equivalences: &[Vec<MirScalarExpr>],
102    ) -> Vec<Vec<usize>>
103    where
104        I: Iterator<Item = &'a Vec<Vec<usize>>>,
105    {
106        // A relation's uniqueness constraint holds if there is a
107        // sequence of the other relations such that each one has
108        // a uniqueness constraint whose columns are used in join
109        // constraints with relations prior in the sequence.
110        //
111        // Currently, we only:
112        // 1. test for whether the uniqueness constraints for the first input will hold
113        // 2. try one sequence, namely the inputs in order
114        // 3. check that the column themselves are used in the join constraints
115        //    Technically uniqueness constraint would still hold if a 1-to-1
116        //    expression on a unique key is used in the join constraint.
117
118        // for inputs `1..self.total_inputs()`, store a set of columns from that
119        // input that exist in join constraints that have expressions belonging to
120        // earlier inputs.
121        let mut column_with_prior_bound_by_input = vec![BTreeSet::new(); self.total_inputs() - 1];
122        for equivalence in equivalences {
123            // do a scan to find the first input represented in the constraint
124            let min_bound_input = equivalence
125                .iter()
126                .flat_map(|expr| self.lookup_inputs(expr).max())
127                .min();
128            if let Some(min_bound_input) = min_bound_input {
129                for expr in equivalence {
130                    // then store all columns in the constraint that don't come
131                    // from the first input
132                    if let MirScalarExpr::Column(c, _name) = expr {
133                        let (col, input) = self.map_column_to_local(*c);
134                        if input > min_bound_input {
135                            column_with_prior_bound_by_input[input - 1].insert(col);
136                        }
137                    }
138                }
139            }
140        }
141
142        if self.total_inputs() > 0 {
143            let first_input_keys = local_keys.next().unwrap().clone();
144            // for inputs `1..self.total_inputs()`, checks the keys belong to each
145            // input against the storage of columns that exist in join constraints
146            // that have expressions belonging to earlier inputs.
147            let remains_unique = local_keys.enumerate().all(|(index, keys)| {
148                keys.iter().any(|ks| {
149                    ks.iter()
150                        .all(|k| column_with_prior_bound_by_input[index].contains(k))
151                })
152            });
153
154            if remains_unique {
155                return first_input_keys;
156            }
157        }
158        vec![]
159    }
160
161    /// returns the arity for a particular input
162    #[inline]
163    pub fn input_arity(&self, index: usize) -> usize {
164        self.arities[index]
165    }
166
167    /// All column numbers in order for a particular input in the local context
168    #[inline]
169    pub fn local_columns(&self, index: usize) -> Range<usize> {
170        0..self.arities[index]
171    }
172
173    /// All column numbers in order for a particular input in the global context
174    #[inline]
175    pub fn global_columns(&self, index: usize) -> Range<usize> {
176        self.prior_arities[index]..(self.prior_arities[index] + self.arities[index])
177    }
178
179    /// Takes an expression from the global context and creates a new version
180    /// where column references have been remapped to the local context.
181    /// Assumes that all columns in `expr` are from the same input.
182    pub fn map_expr_to_local(&self, mut expr: MirScalarExpr) -> MirScalarExpr {
183        expr.visit_pre_mut(|e| {
184            if let MirScalarExpr::Column(c, _name) = e {
185                *c -= self.prior_arities[self.input_relation[*c]];
186            }
187        });
188        expr
189    }
190
191    /// Takes an expression from the local context of the `index`th input and
192    /// creates a new version where column references have been remapped to the
193    /// global context.
194    pub fn map_expr_to_global(&self, mut expr: MirScalarExpr, index: usize) -> MirScalarExpr {
195        expr.visit_pre_mut(|e| {
196            if let MirScalarExpr::Column(c, _name) = e {
197                *c += self.prior_arities[index];
198            }
199        });
200        expr
201    }
202
203    /// Remap column numbers from the global to the local context.
204    /// Returns a 2-tuple `(<new column number>, <index of input>)`
205    pub fn map_column_to_local(&self, column: usize) -> (usize, usize) {
206        let index = self.input_relation[column];
207        (column - self.prior_arities[index], index)
208    }
209
210    /// Remap a column number from a local context to the global context.
211    pub fn map_column_to_global(&self, column: usize, index: usize) -> usize {
212        column + self.prior_arities[index]
213    }
214
215    /// Takes a sequence of columns in the global context and splits it into
216    /// a `Vec` containing `self.total_inputs()` `BTreeSet`s, each containing
217    /// the localized columns that belong to the particular input.
218    pub fn split_column_set_by_input<'a, I>(&self, columns: I) -> Vec<BTreeSet<usize>>
219    where
220        I: Iterator<Item = &'a usize>,
221    {
222        let mut new_columns = vec![BTreeSet::new(); self.total_inputs()];
223        for column in columns {
224            let (new_col, input) = self.map_column_to_local(*column);
225            new_columns[input].extend(std::iter::once(new_col));
226        }
227        new_columns
228    }
229
230    /// Find the sorted, dedupped set of inputs an expression references
231    pub fn lookup_inputs(&self, expr: &MirScalarExpr) -> impl Iterator<Item = usize> + use<> {
232        expr.support()
233            .iter()
234            .map(|c| self.input_relation[*c])
235            .sorted()
236            .dedup()
237    }
238
239    /// Returns the index of the only input referenced in the given expression.
240    pub fn single_input(&self, expr: &MirScalarExpr) -> Option<usize> {
241        let mut inputs = self.lookup_inputs(expr);
242        if let Some(first_input) = inputs.next() {
243            if inputs.next().is_none() {
244                return Some(first_input);
245            }
246        }
247        None
248    }
249
250    /// Returns whether the given expr refers to columns of only the `index`th input.
251    pub fn is_localized(&self, expr: &MirScalarExpr, index: usize) -> bool {
252        if let Some(single_input) = self.single_input(expr) {
253            if single_input == index {
254                return true;
255            }
256        }
257        false
258    }
259
260    /// Takes an expression in the global context and looks in `equivalences`
261    /// for an equivalent expression (also expressed in the global context) that
262    /// belongs to one or more of the inputs in `bound_inputs`
263    ///
264    /// # Examples
265    ///
266    /// ```
267    /// use mz_repr::{Datum, SqlColumnType, SqlRelationType, SqlScalarType};
268    /// use mz_expr::{JoinInputMapper, MirRelationExpr, MirScalarExpr};
269    ///
270    /// // A two-column schema common to each of the three inputs
271    /// let schema = SqlRelationType::new(vec![
272    ///   SqlScalarType::Int32.nullable(false),
273    ///   SqlScalarType::Int32.nullable(false),
274    /// ]);
275    ///
276    /// // the specific data are not important here.
277    /// let data = vec![Datum::Int32(0), Datum::Int32(1)];
278    /// let input0 = MirRelationExpr::constant(vec![data.clone()], schema.clone());
279    /// let input1 = MirRelationExpr::constant(vec![data.clone()], schema.clone());
280    /// let input2 = MirRelationExpr::constant(vec![data.clone()], schema.clone());
281    ///
282    /// // [input0(#0) = input2(#1)], [input0(#1) = input1(#0) = input2(#0)]
283    /// let equivalences = vec![
284    ///   vec![MirScalarExpr::column(0), MirScalarExpr::column(5)],
285    ///   vec![MirScalarExpr::column(1), MirScalarExpr::column(2), MirScalarExpr::column(4)],
286    /// ];
287    ///
288    /// let input_mapper = JoinInputMapper::new(&[input0, input1, input2]);
289    /// assert_eq!(
290    ///   Some(MirScalarExpr::column(4)),
291    ///   input_mapper.find_bound_expr(&MirScalarExpr::column(2), &[2], &equivalences)
292    /// );
293    /// assert_eq!(
294    ///   None,
295    ///   input_mapper.find_bound_expr(&MirScalarExpr::column(0), &[1], &equivalences)
296    /// );
297    /// ```
298    pub fn find_bound_expr(
299        &self,
300        expr: &MirScalarExpr,
301        bound_inputs: &[usize],
302        equivalences: &[Vec<MirScalarExpr>],
303    ) -> Option<MirScalarExpr> {
304        if let Some(equivalence) = equivalences.iter().find(|equivs| equivs.contains(expr)) {
305            if let Some(bound_expr) = equivalence
306                .iter()
307                .find(|expr| self.lookup_inputs(expr).all(|i| bound_inputs.contains(&i)))
308            {
309                return Some(bound_expr.clone());
310            }
311        }
312        None
313    }
314
315    /// Try to rewrite `expr` from the global context so that all the
316    /// columns point to the `index`th input by replacing subexpressions with their
317    /// bound equivalents in the `index`th input if necessary.
318    /// Returns whether the rewriting was successful.
319    /// If it returns true, then `expr` is in the context of the `index`th input.
320    /// If it returns false, then still some subexpressions might have been rewritten. However,
321    /// `expr` is still in the global context.
322    pub fn try_localize_to_input_with_bound_expr(
323        &self,
324        expr: &mut MirScalarExpr,
325        index: usize,
326        equivalences: &[Vec<MirScalarExpr>],
327    ) -> bool {
328        // TODO (wangandi): Consider changing this code to be post-order
329        // instead of pre-order? `lookup_inputs` traverses all the nodes in
330        // `e` anyway, so we end up visiting nodes in `e` multiple times
331        // here. Alternatively, consider having the future `PredicateKnowledge`
332        // take over the responsibilities of this code?
333        #[allow(deprecated)]
334        expr.visit_mut_pre_post_nolimit(
335            &mut |e| {
336                let mut inputs = self.lookup_inputs(e);
337                if let Some(first_input) = inputs.next() {
338                    if inputs.next().is_none() && first_input == index {
339                        // there is only one input, and it is equal to index, so we're
340                        // good. do not continue the recursion
341                        return Some(vec![]);
342                    }
343                }
344
345                if let Some(bound_expr) = self.find_bound_expr(e, &[index], equivalences) {
346                    // Replace the subexpression with the equivalent one from input `index`
347                    *e = bound_expr;
348                    // The entire subexpression has been rewritten, so there is
349                    // no need to visit any child expressions.
350                    Some(vec![])
351                } else {
352                    None
353                }
354            },
355            &mut |_| {},
356        );
357        if self.is_localized(expr, index) {
358            // If the localization attempt is successful, all columns in `expr`
359            // should only come from input `index`. Switch to the local context.
360            *expr = self.map_expr_to_local(expr.clone());
361            return true;
362        }
363        false
364    }
365
366    /// Try to find a consequence `c` of the given expression `e` for the given input.
367    ///
368    /// If we return `Some(c)`, that means
369    ///   1. `c` uses only columns from the given input;
370    ///   2. if `c` doesn't hold on a row of the input, then `e` also wouldn't hold;
371    ///   3. if `c` holds on a row of the input, then `e` might or might not hold.
372    /// 1. and 2. means that if we have a join with predicate `e` then we can use `c` for
373    /// pre-filtering a join input before the join. However, 3. means that `e` shouldn't be deleted
374    /// from the join predicates, i.e., we can't do a "traditional" predicate pushdown.
375    ///
376    /// Note that "`c` is a consequence of `e`" is the same thing as 2., see
377    /// <https://en.wikipedia.org/wiki/Contraposition>
378    ///
379    /// Example: For
380    /// `(t1.f2 = 3 AND t2.f2 = 4) OR (t1.f2 = 5 AND t2.f2 = 6)`
381    /// we find
382    /// `t1.f2 = 3 OR t1.f2 = 5` for t1, and
383    /// `t2.f2 = 4 OR t2.f2 = 6` for t2.
384    ///
385    /// Further examples are in TPC-H Q07, Q19, and chbench Q07, Q19.
386    ///
387    /// Parameters:
388    ///  - `expr`: The expression `e` from above. `try_localize_to_input_with_bound_expr` should
389    ///    be called on `expr` before us!
390    ///  - `index`: The index of the join input whose columns we will use.
391    ///  - `equivalences`: Join equivalences that we can use for `try_map_to_input_with_bound_expr`.
392    /// If successful, the returned expression is in the local context of the specified input.
393    pub fn consequence_for_input(
394        &self,
395        expr: &MirScalarExpr,
396        index: usize,
397    ) -> Option<MirScalarExpr> {
398        if self.is_localized(expr, index) {
399            Some(self.map_expr_to_local(expr.clone()))
400        } else {
401            match expr {
402                MirScalarExpr::CallVariadic {
403                    func: VariadicFunc::Or,
404                    exprs: or_args,
405                } => {
406                    // Each OR arg should provide a consequence. If they do, we OR them.
407                    let consequences_per_arg = or_args
408                        .into_iter()
409                        .map(|or_arg| {
410                            mz_ore::stack::maybe_grow(|| self.consequence_for_input(or_arg, index))
411                        })
412                        .collect::<Option<Vec<_>>>()?; // return None if any of them are None
413                    Some(MirScalarExpr::CallVariadic {
414                        func: VariadicFunc::Or,
415                        exprs: consequences_per_arg,
416                    })
417                }
418                MirScalarExpr::CallVariadic {
419                    func: VariadicFunc::And,
420                    exprs: and_args,
421                } => {
422                    // If any of the AND args provide a consequence, then we take those that do,
423                    // and AND them.
424                    let consequences_per_arg = and_args
425                        .into_iter()
426                        .map(|and_arg| {
427                            mz_ore::stack::maybe_grow(|| self.consequence_for_input(and_arg, index))
428                        })
429                        .flat_map(|c| c) // take only those that provide a consequence
430                        .collect_vec();
431                    if consequences_per_arg.is_empty() {
432                        None
433                    } else {
434                        Some(MirScalarExpr::CallVariadic {
435                            func: VariadicFunc::And,
436                            exprs: consequences_per_arg,
437                        })
438                    }
439                }
440                _ => None,
441            }
442        }
443    }
444}
445
#[cfg(test)]
mod tests {
    use mz_repr::{Datum, SqlScalarType};

    use crate::scalar::func;
    use crate::{BinaryFunc, MirScalarExpr, UnaryFunc};

    use super::*;

    /// Checks the basic global/local column bookkeeping, including an input
    /// that has no columns at all.
    #[mz_ore::test]
    #[cfg_attr(miri, ignore)] // unsupported operation: can't call foreign function `rust_psm_stack_pointer` on OS `linux`
    fn column_mapping_test() {
        let input_mapper = JoinInputMapper::new_from_input_arities([2, 0, 3]);

        assert_eq!(input_mapper.total_inputs(), 3);
        assert_eq!(input_mapper.total_columns(), 5);
        assert_eq!(input_mapper.input_arity(1), 0);
        assert_eq!(input_mapper.local_columns(2), 0..3);
        assert_eq!(input_mapper.global_columns(0), 0..2);
        assert_eq!(input_mapper.global_columns(1), 2..2);
        assert_eq!(input_mapper.global_columns(2), 2..5);

        assert_eq!(input_mapper.map_column_to_local(1), (1, 0));
        assert_eq!(input_mapper.map_column_to_local(3), (1, 2));
        assert_eq!(input_mapper.map_column_to_global(1, 2), 3);

        assert_eq!(
            input_mapper.split_column_set_by_input([0, 3, 4].iter()),
            vec![
                std::collections::BTreeSet::from([0]),
                std::collections::BTreeSet::new(),
                std::collections::BTreeSet::from([1, 2]),
            ]
        );

        assert_eq!(
            input_mapper.map_expr_to_local(MirScalarExpr::column(3)),
            MirScalarExpr::column(1)
        );
        assert_eq!(
            input_mapper.map_expr_to_global(MirScalarExpr::column(1), 2),
            MirScalarExpr::column(3)
        );

        assert_eq!(input_mapper.single_input(&MirScalarExpr::column(3)), Some(2));
        assert!(input_mapper.is_localized(&MirScalarExpr::column(3), 2));
        assert!(!input_mapper.is_localized(&MirScalarExpr::column(3), 0));
    }

    #[mz_ore::test]
    #[cfg_attr(miri, ignore)] // unsupported operation: can't call foreign function `rust_psm_stack_pointer` on OS `linux`
    fn try_localize_to_input_with_bound_expr_test() {
        let input_mapper = JoinInputMapper {
            arities: vec![2, 3, 3],
            input_relation: vec![0, 0, 1, 1, 1, 2, 2, 2],
            prior_arities: vec![0, 2, 5],
        };

        // keys are numbered by (equivalence class #, input #)
        let key10 = MirScalarExpr::column(0);
        let key12 = MirScalarExpr::column(6);
        let localized_key12 = MirScalarExpr::column(1);

        let mut equivalences = vec![vec![key10.clone(), key12.clone()]];

        // when the column is already part of the target input, all that happens
        // is that it gets localized
        let mut cloned = key12.clone();
        assert!(input_mapper.try_localize_to_input_with_bound_expr(&mut cloned, 2, &equivalences));
        assert_eq!(MirScalarExpr::column(1), cloned);

        // basic tests that we can find a column's corresponding column in a
        // different input
        let mut cloned = key12.clone();
        assert!(input_mapper.try_localize_to_input_with_bound_expr(&mut cloned, 0, &equivalences));
        assert_eq!(key10, cloned);
        let mut cloned = key12.clone();
        assert!(!input_mapper.try_localize_to_input_with_bound_expr(&mut cloned, 1, &equivalences));

        let key20 = MirScalarExpr::CallUnary {
            func: UnaryFunc::NegInt32(crate::func::NegInt32),
            expr: Box::new(MirScalarExpr::column(1)),
        };
        let key21 = MirScalarExpr::CallBinary {
            func: BinaryFunc::AddInt32(func::AddInt32),
            expr1: Box::new(MirScalarExpr::column(2)),
            expr2: Box::new(MirScalarExpr::literal(
                Ok(Datum::Int32(4)),
                SqlScalarType::Int32,
            )),
        };
        let key22 = MirScalarExpr::column(5);
        let localized_key22 = MirScalarExpr::column(0);
        equivalences.push(vec![key22.clone(), key20.clone(), key21.clone()]);

        // basic tests that we can find an expression's corresponding expression in a
        // different input
        let mut cloned = key21.clone();
        assert!(input_mapper.try_localize_to_input_with_bound_expr(&mut cloned, 0, &equivalences));
        assert_eq!(key20, cloned);
        let mut cloned = key21.clone();
        assert!(input_mapper.try_localize_to_input_with_bound_expr(&mut cloned, 2, &equivalences));
        assert_eq!(localized_key22, cloned);

        // test that `try_localize_to_input_with_bound_expr` will map multiple
        // subexpressions to the corresponding expressions bound to a different input
        let key_comp = MirScalarExpr::CallBinary {
            func: func::MulInt32.into(),
            expr1: Box::new(key12.clone()),
            expr2: Box::new(key22),
        };
        let mut cloned = key_comp.clone();
        assert!(input_mapper.try_localize_to_input_with_bound_expr(&mut cloned, 0, &equivalences));
        assert_eq!(
            MirScalarExpr::CallBinary {
                func: func::MulInt32.into(),
                expr1: Box::new(key10.clone()),
                expr2: Box::new(key20.clone()),
            },
            cloned,
        );

        // test that the function returns false when part
        // of the expression can be mapped to an input but the rest can't
        let mut cloned = key_comp.clone();
        assert!(!input_mapper.try_localize_to_input_with_bound_expr(&mut cloned, 1, &equivalences));

        let key_comp_plus_non_key = MirScalarExpr::CallBinary {
            func: func::Eq.into(),
            expr1: Box::new(key_comp),
            expr2: Box::new(MirScalarExpr::column(7)),
        };
        let mut mutab = key_comp_plus_non_key;
        assert!(!input_mapper.try_localize_to_input_with_bound_expr(&mut mutab, 0, &equivalences));

        let key_comp_multi_input = MirScalarExpr::CallBinary {
            func: func::Eq.into(),
            expr1: Box::new(key12),
            expr2: Box::new(key21),
        };
        // test that the function works when part of the expression is already
        // part of the target input
        let mut cloned = key_comp_multi_input.clone();
        assert!(input_mapper.try_localize_to_input_with_bound_expr(&mut cloned, 2, &equivalences));
        assert_eq!(
            MirScalarExpr::CallBinary {
                func: func::Eq.into(),
                expr1: Box::new(localized_key12),
                expr2: Box::new(localized_key22),
            },
            cloned,
        );
        // test that the function works when parts of the expression come from
        // multiple inputs
        let mut cloned = key_comp_multi_input.clone();
        assert!(input_mapper.try_localize_to_input_with_bound_expr(&mut cloned, 0, &equivalences));
        assert_eq!(
            MirScalarExpr::CallBinary {
                func: func::Eq.into(),
                expr1: Box::new(key10),
                expr2: Box::new(key20),
            },
            cloned,
        );
        let mut mutab = key_comp_multi_input;
        assert!(!input_mapper.try_localize_to_input_with_bound_expr(&mut mutab, 1, &equivalences));
    }
}
585}