Skip to main content

mz_expr/relation/
join_input_mapper.rs

1// Copyright Materialize, Inc. and contributors. All rights reserved.
2//
3// Use of this software is governed by the Business Source License
4// included in the LICENSE file.
5//
6// As of the Change Date specified in that file, in accordance with
7// the Business Source License, use of this software will be governed
8// by the Apache License, Version 2.0.
9
10use std::collections::BTreeSet;
11use std::ops::Range;
12
13use itertools::Itertools;
14use mz_repr::ReprRelationType;
15
16use crate::scalar::func::variadic::{And, Or};
17use crate::visit::Visit;
18use crate::{MirRelationExpr, MirScalarExpr, VariadicFunc};
19
/// Translates column references between the two coordinate systems of a join.
///
/// Every column taking part in a join can be addressed in two ways:
/// 1) globally, by its position in the result of the join, and
/// 2) locally, by its position within the particular input it came from.
///
/// The methods on this type rewrite columns and expressions from one of these
/// contexts into the other.
///
/// Methods that accept an `equivalences` argument only promise a correct
/// answer when the equivalence classes are in canonical form
/// (see [`crate::relation::canonicalize::canonicalize_equivalences`]).
#[derive(Debug)]
pub struct JoinInputMapper {
    /// Column count of each input, in join order. This is the single source of
    /// truth: every other field below is computed from it.
    arities: Vec<usize>,
    /// For each global column, the index of the input that owns it. Computed
    /// from `arities` and cached so lookups need not recompute it.
    input_relation: Vec<usize>,
    /// For each input, the summed arity of all inputs that precede it in the
    /// join. Computed from `arities` and cached so lookups need not recompute it.
    prior_arities: Vec<usize>,
}
42
43impl JoinInputMapper {
44    /// Creates a new `JoinInputMapper` and calculates the mapping of global context
45    /// columns to local context columns.
46    pub fn new(inputs: &[MirRelationExpr]) -> Self {
47        Self::new_from_input_arities(inputs.iter().map(|i| i.arity()))
48    }
49
50    /// Creates a new `JoinInputMapper` and calculates the mapping of global context
51    /// columns to local context columns. Using this method is more
52    /// efficient if input repr types have been pre-calculated.
53    pub fn new_from_input_types(types: &[ReprRelationType]) -> Self {
54        Self::new_from_input_arities(types.iter().map(|t| t.arity()))
55    }
56
57    /// Creates a new `JoinInputMapper` and calculates the mapping of global context
58    /// columns to local context columns. Using this method is more
59    /// efficient if input arities have been pre-calculated
60    pub fn new_from_input_arities<I>(arities: I) -> Self
61    where
62        I: IntoIterator<Item = usize>,
63    {
64        let arities = arities.into_iter().collect::<Vec<usize>>();
65        let mut offset = 0;
66        let mut prior_arities = Vec::new();
67        for input in 0..arities.len() {
68            prior_arities.push(offset);
69            offset += arities[input];
70        }
71
72        let input_relation = arities
73            .iter()
74            .enumerate()
75            .flat_map(|(r, a)| std::iter::repeat(r).take(*a))
76            .collect::<Vec<_>>();
77
78        JoinInputMapper {
79            arities,
80            input_relation,
81            prior_arities,
82        }
83    }
84
85    /// reports sum of the number of columns of each input
86    pub fn total_columns(&self) -> usize {
87        self.arities.iter().sum()
88    }
89
90    /// reports total numbers of inputs in the join
91    pub fn total_inputs(&self) -> usize {
92        self.arities.len()
93    }
94
95    /// Using the keys that came from each local input,
96    /// figures out which keys remain unique in the larger join
97    /// Currently, we only figure out a small subset of the keys that
98    /// can remain unique.
99    pub fn global_keys<'a, I>(
100        &self,
101        mut local_keys: I,
102        equivalences: &[Vec<MirScalarExpr>],
103    ) -> Vec<Vec<usize>>
104    where
105        I: Iterator<Item = &'a Vec<Vec<usize>>>,
106    {
107        // A relation's uniqueness constraint holds if there is a
108        // sequence of the other relations such that each one has
109        // a uniqueness constraint whose columns are used in join
110        // constraints with relations prior in the sequence.
111        //
112        // Currently, we only:
113        // 1. test for whether the uniqueness constraints for the first input will hold
114        // 2. try one sequence, namely the inputs in order
115        // 3. check that the column themselves are used in the join constraints
116        //    Technically uniqueness constraint would still hold if a 1-to-1
117        //    expression on a unique key is used in the join constraint.
118
119        // for inputs `1..self.total_inputs()`, store a set of columns from that
120        // input that exist in join constraints that have expressions belonging to
121        // earlier inputs.
122        let mut column_with_prior_bound_by_input = vec![BTreeSet::new(); self.total_inputs() - 1];
123        for equivalence in equivalences {
124            // do a scan to find the first input represented in the constraint
125            let min_bound_input = equivalence
126                .iter()
127                .flat_map(|expr| self.lookup_inputs(expr).max())
128                .min();
129            if let Some(min_bound_input) = min_bound_input {
130                for expr in equivalence {
131                    // then store all columns in the constraint that don't come
132                    // from the first input
133                    if let MirScalarExpr::Column(c, _name) = expr {
134                        let (col, input) = self.map_column_to_local(*c);
135                        if input > min_bound_input {
136                            column_with_prior_bound_by_input[input - 1].insert(col);
137                        }
138                    }
139                }
140            }
141        }
142
143        if self.total_inputs() > 0 {
144            let first_input_keys = local_keys.next().unwrap().clone();
145            // for inputs `1..self.total_inputs()`, checks the keys belong to each
146            // input against the storage of columns that exist in join constraints
147            // that have expressions belonging to earlier inputs.
148            let remains_unique = local_keys.enumerate().all(|(index, keys)| {
149                keys.iter().any(|ks| {
150                    ks.iter()
151                        .all(|k| column_with_prior_bound_by_input[index].contains(k))
152                })
153            });
154
155            if remains_unique {
156                return first_input_keys;
157            }
158        }
159        vec![]
160    }
161
162    /// returns the arity for a particular input
163    #[inline]
164    pub fn input_arity(&self, index: usize) -> usize {
165        self.arities[index]
166    }
167
168    /// All column numbers in order for a particular input in the local context
169    #[inline]
170    pub fn local_columns(&self, index: usize) -> Range<usize> {
171        0..self.arities[index]
172    }
173
174    /// All column numbers in order for a particular input in the global context
175    #[inline]
176    pub fn global_columns(&self, index: usize) -> Range<usize> {
177        self.prior_arities[index]..(self.prior_arities[index] + self.arities[index])
178    }
179
180    /// Takes an expression from the global context and creates a new version
181    /// where column references have been remapped to the local context.
182    /// Assumes that all columns in `expr` are from the same input.
183    pub fn map_expr_to_local(&self, mut expr: MirScalarExpr) -> MirScalarExpr {
184        expr.visit_pre_mut(|e| {
185            if let MirScalarExpr::Column(c, _name) = e {
186                *c -= self.prior_arities[self.input_relation[*c]];
187            }
188        });
189        expr
190    }
191
192    /// Takes an expression from the local context of the `index`th input and
193    /// creates a new version where column references have been remapped to the
194    /// global context.
195    pub fn map_expr_to_global(&self, mut expr: MirScalarExpr, index: usize) -> MirScalarExpr {
196        expr.visit_pre_mut(|e| {
197            if let MirScalarExpr::Column(c, _name) = e {
198                *c += self.prior_arities[index];
199            }
200        });
201        expr
202    }
203
204    /// Remap column numbers from the global to the local context.
205    /// Returns a 2-tuple `(<new column number>, <index of input>)`
206    pub fn map_column_to_local(&self, column: usize) -> (usize, usize) {
207        let index = self.input_relation[column];
208        (column - self.prior_arities[index], index)
209    }
210
211    /// Remap a column number from a local context to the global context.
212    pub fn map_column_to_global(&self, column: usize, index: usize) -> usize {
213        column + self.prior_arities[index]
214    }
215
216    /// Takes a sequence of columns in the global context and splits it into
217    /// a `Vec` containing `self.total_inputs()` `BTreeSet`s, each containing
218    /// the localized columns that belong to the particular input.
219    pub fn split_column_set_by_input<'a, I>(&self, columns: I) -> Vec<BTreeSet<usize>>
220    where
221        I: Iterator<Item = &'a usize>,
222    {
223        let mut new_columns = vec![BTreeSet::new(); self.total_inputs()];
224        for column in columns {
225            let (new_col, input) = self.map_column_to_local(*column);
226            new_columns[input].extend(std::iter::once(new_col));
227        }
228        new_columns
229    }
230
231    /// Find the sorted, dedupped set of inputs an expression references
232    pub fn lookup_inputs(&self, expr: &MirScalarExpr) -> impl Iterator<Item = usize> + use<> {
233        expr.support()
234            .iter()
235            .map(|c| self.input_relation[*c])
236            .sorted()
237            .dedup()
238    }
239
240    /// Returns the index of the only input referenced in the given expression.
241    pub fn single_input(&self, expr: &MirScalarExpr) -> Option<usize> {
242        let mut inputs = self.lookup_inputs(expr);
243        if let Some(first_input) = inputs.next() {
244            if inputs.next().is_none() {
245                return Some(first_input);
246            }
247        }
248        None
249    }
250
251    /// Returns whether the given expr refers to columns of only the `index`th input.
252    pub fn is_localized(&self, expr: &MirScalarExpr, index: usize) -> bool {
253        if let Some(single_input) = self.single_input(expr) {
254            if single_input == index {
255                return true;
256            }
257        }
258        false
259    }
260
261    /// Takes an expression in the global context and looks in `equivalences`
262    /// for an equivalent expression (also expressed in the global context) that
263    /// belongs to one or more of the inputs in `bound_inputs`
264    ///
265    /// # Examples
266    ///
267    /// ```
268    /// use mz_repr::{Datum, ReprColumnType, ReprRelationType, ReprScalarType};
269    /// use mz_expr::{JoinInputMapper, MirRelationExpr, MirScalarExpr};
270    ///
271    /// // A two-column schema common to each of the three inputs
272    /// let schema = ReprRelationType::new(vec![
273    ///   ReprScalarType::Int32.nullable(false),
274    ///   ReprScalarType::Int32.nullable(false),
275    /// ]);
276    ///
277    /// // the specific data are not important here.
278    /// let data = vec![Datum::Int32(0), Datum::Int32(1)];
279    /// let input0 = MirRelationExpr::constant(vec![data.clone()], schema.clone());
280    /// let input1 = MirRelationExpr::constant(vec![data.clone()], schema.clone());
281    /// let input2 = MirRelationExpr::constant(vec![data.clone()], schema.clone());
282    ///
283    /// // [input0(#0) = input2(#1)], [input0(#1) = input1(#0) = input2(#0)]
284    /// let equivalences = vec![
285    ///   vec![MirScalarExpr::column(0), MirScalarExpr::column(5)],
286    ///   vec![MirScalarExpr::column(1), MirScalarExpr::column(2), MirScalarExpr::column(4)],
287    /// ];
288    ///
289    /// let input_mapper = JoinInputMapper::new(&[input0, input1, input2]);
290    /// assert_eq!(
291    ///   Some(MirScalarExpr::column(4)),
292    ///   input_mapper.find_bound_expr(&MirScalarExpr::column(2), &[2], &equivalences)
293    /// );
294    /// assert_eq!(
295    ///   None,
296    ///   input_mapper.find_bound_expr(&MirScalarExpr::column(0), &[1], &equivalences)
297    /// );
298    /// ```
299    pub fn find_bound_expr(
300        &self,
301        expr: &MirScalarExpr,
302        bound_inputs: &[usize],
303        equivalences: &[Vec<MirScalarExpr>],
304    ) -> Option<MirScalarExpr> {
305        if let Some(equivalence) = equivalences.iter().find(|equivs| equivs.contains(expr)) {
306            if let Some(bound_expr) = equivalence
307                .iter()
308                .find(|expr| self.lookup_inputs(expr).all(|i| bound_inputs.contains(&i)))
309            {
310                return Some(bound_expr.clone());
311            }
312        }
313        None
314    }
315
316    /// Try to rewrite `expr` from the global context so that all the
317    /// columns point to the `index`th input by replacing subexpressions with their
318    /// bound equivalents in the `index`th input if necessary.
319    /// Returns whether the rewriting was successful.
320    /// If it returns true, then `expr` is in the context of the `index`th input.
321    /// If it returns false, then still some subexpressions might have been rewritten. However,
322    /// `expr` is still in the global context.
323    pub fn try_localize_to_input_with_bound_expr(
324        &self,
325        expr: &mut MirScalarExpr,
326        index: usize,
327        equivalences: &[Vec<MirScalarExpr>],
328    ) -> bool {
329        // TODO (wangandi): Consider changing this code to be post-order
330        // instead of pre-order? `lookup_inputs` traverses all the nodes in
331        // `e` anyway, so we end up visiting nodes in `e` multiple times
332        // here. Alternatively, consider having the future `PredicateKnowledge`
333        // take over the responsibilities of this code?
334        #[allow(deprecated)]
335        expr.visit_mut_pre_post_nolimit(
336            &mut |e| {
337                let mut inputs = self.lookup_inputs(e);
338                if let Some(first_input) = inputs.next() {
339                    if inputs.next().is_none() && first_input == index {
340                        // there is only one input, and it is equal to index, so we're
341                        // good. do not continue the recursion
342                        return Some(vec![]);
343                    }
344                }
345
346                if let Some(bound_expr) = self.find_bound_expr(e, &[index], equivalences) {
347                    // Replace the subexpression with the equivalent one from input `index`
348                    *e = bound_expr;
349                    // The entire subexpression has been rewritten, so there is
350                    // no need to visit any child expressions.
351                    Some(vec![])
352                } else {
353                    None
354                }
355            },
356            &mut |_| {},
357        );
358        if self.is_localized(expr, index) {
359            // If the localization attempt is successful, all columns in `expr`
360            // should only come from input `index`. Switch to the local context.
361            *expr = self.map_expr_to_local(expr.clone());
362            return true;
363        }
364        false
365    }
366
367    /// Try to find a consequence `c` of the given expression `e` for the given input.
368    ///
369    /// If we return `Some(c)`, that means
370    ///   1. `c` uses only columns from the given input;
371    ///   2. if `c` doesn't hold on a row of the input, then `e` also wouldn't hold;
372    ///   3. if `c` holds on a row of the input, then `e` might or might not hold.
373    /// 1. and 2. means that if we have a join with predicate `e` then we can use `c` for
374    /// pre-filtering a join input before the join. However, 3. means that `e` shouldn't be deleted
375    /// from the join predicates, i.e., we can't do a "traditional" predicate pushdown.
376    ///
377    /// Note that "`c` is a consequence of `e`" is the same thing as 2., see
378    /// <https://en.wikipedia.org/wiki/Contraposition>
379    ///
380    /// Example: For
381    /// `(t1.f2 = 3 AND t2.f2 = 4) OR (t1.f2 = 5 AND t2.f2 = 6)`
382    /// we find
383    /// `t1.f2 = 3 OR t1.f2 = 5` for t1, and
384    /// `t2.f2 = 4 OR t2.f2 = 6` for t2.
385    ///
386    /// Further examples are in TPC-H Q07, Q19, and chbench Q07, Q19.
387    ///
388    /// Parameters:
389    ///  - `expr`: The expression `e` from above. `try_localize_to_input_with_bound_expr` should
390    ///    be called on `expr` before us!
391    ///  - `index`: The index of the join input whose columns we will use.
392    ///  - `equivalences`: Join equivalences that we can use for `try_map_to_input_with_bound_expr`.
393    /// If successful, the returned expression is in the local context of the specified input.
394    pub fn consequence_for_input(
395        &self,
396        expr: &MirScalarExpr,
397        index: usize,
398    ) -> Option<MirScalarExpr> {
399        if self.is_localized(expr, index) {
400            Some(self.map_expr_to_local(expr.clone()))
401        } else {
402            match expr {
403                MirScalarExpr::CallVariadic {
404                    func: VariadicFunc::Or(_),
405                    exprs: or_args,
406                } => {
407                    // Each OR arg should provide a consequence. If they do, we OR them.
408                    let consequences_per_arg = or_args
409                        .into_iter()
410                        .map(|or_arg| {
411                            mz_ore::stack::maybe_grow(|| self.consequence_for_input(or_arg, index))
412                        })
413                        .collect::<Option<Vec<_>>>()?; // return None if any of them are None
414                    Some(MirScalarExpr::call_variadic(Or, consequences_per_arg))
415                }
416                MirScalarExpr::CallVariadic {
417                    func: VariadicFunc::And(_),
418                    exprs: and_args,
419                } => {
420                    // If any of the AND args provide a consequence, then we take those that do,
421                    // and AND them.
422                    let consequences_per_arg = and_args
423                        .into_iter()
424                        .map(|and_arg| {
425                            mz_ore::stack::maybe_grow(|| self.consequence_for_input(and_arg, index))
426                        })
427                        .flat_map(|c| c) // take only those that provide a consequence
428                        .collect_vec();
429                    if consequences_per_arg.is_empty() {
430                        None
431                    } else {
432                        Some(MirScalarExpr::call_variadic(And, consequences_per_arg))
433                    }
434                }
435                _ => None,
436            }
437        }
438    }
439}
440
#[cfg(test)]
mod tests {
    use mz_repr::{Datum, ReprScalarType};

    use crate::scalar::func;
    use crate::{BinaryFunc, MirScalarExpr, UnaryFunc};

    use super::*;

    /// Clones `expr`, runs `try_localize_to_input_with_bound_expr` on the clone,
    /// and returns the success flag together with the (possibly rewritten) clone.
    fn localize(
        mapper: &JoinInputMapper,
        expr: &MirScalarExpr,
        index: usize,
        equivalences: &[Vec<MirScalarExpr>],
    ) -> (bool, MirScalarExpr) {
        let mut expr = expr.clone();
        let success = mapper.try_localize_to_input_with_bound_expr(&mut expr, index, equivalences);
        (success, expr)
    }

    #[mz_ore::test]
    fn new_from_input_arities_test() {
        let input_mapper = JoinInputMapper::new_from_input_arities([2, 3, 3]);
        assert_eq!(input_mapper.arities, vec![2, 3, 3]);
        assert_eq!(input_mapper.input_relation, vec![0, 0, 1, 1, 1, 2, 2, 2]);
        assert_eq!(input_mapper.prior_arities, vec![0, 2, 5]);
        assert_eq!(input_mapper.total_inputs(), 3);
        assert_eq!(input_mapper.total_columns(), 8);
        assert_eq!(input_mapper.global_columns(1), 2..5);
        assert_eq!(input_mapper.local_columns(1), 0..3);
        assert_eq!(input_mapper.map_column_to_local(6), (1, 2));
        assert_eq!(input_mapper.map_column_to_global(1, 2), 6);

        // Zero-arity inputs own no global columns but still occupy an input slot.
        let input_mapper = JoinInputMapper::new_from_input_arities([0, 2, 0]);
        assert_eq!(input_mapper.input_relation, vec![1, 1]);
        assert_eq!(input_mapper.prior_arities, vec![0, 0, 2]);

        // An empty join has neither inputs nor columns.
        let input_mapper = JoinInputMapper::new_from_input_arities(Vec::new());
        assert_eq!(input_mapper.total_inputs(), 0);
        assert_eq!(input_mapper.total_columns(), 0);
    }

    #[mz_ore::test]
    #[cfg_attr(miri, ignore)] // conservatively skipped under miri, like the test below
    fn global_keys_test() {
        let input_mapper = JoinInputMapper::new_from_input_arities([2, 2]);
        let local_keys = vec![vec![vec![0]], vec![vec![0]]];

        // input0(#0) = input1(#0): the second input's key is bound by the first
        // input, so the first input's key survives the join.
        let equivalences = vec![vec![MirScalarExpr::column(0), MirScalarExpr::column(2)]];
        assert_eq!(
            vec![vec![0]],
            input_mapper.global_keys(local_keys.iter(), &equivalences)
        );

        // Without that constraint, nothing can be concluded.
        assert_eq!(
            Vec::<Vec<usize>>::new(),
            input_mapper.global_keys(local_keys.iter(), &[])
        );
    }

    #[mz_ore::test]
    #[cfg_attr(miri, ignore)] // unsupported operation: can't call foreign function `rust_psm_stack_pointer` on OS `linux`
    fn try_localize_to_input_with_bound_expr_test() {
        let input_mapper = JoinInputMapper {
            arities: vec![2, 3, 3],
            input_relation: vec![0, 0, 1, 1, 1, 2, 2, 2],
            prior_arities: vec![0, 2, 5],
        };

        // keys are numbered by (equivalence class #, input #)
        let key10 = MirScalarExpr::column(0);
        let key12 = MirScalarExpr::column(6);
        let localized_key12 = MirScalarExpr::column(1);

        let mut equivalences = vec![vec![key10.clone(), key12.clone()]];

        // when the column is already part of the target input, all that happens
        // is that it gets localized
        assert_eq!(
            (true, localized_key12.clone()),
            localize(&input_mapper, &key12, 2, &equivalences)
        );

        // basic tests that we can find a column's corresponding column in a
        // different input, and that localization fails when there is none
        assert_eq!(
            (true, key10.clone()),
            localize(&input_mapper, &key12, 0, &equivalences)
        );
        assert!(!localize(&input_mapper, &key12, 1, &equivalences).0);

        let key20 = MirScalarExpr::CallUnary {
            func: UnaryFunc::NegInt32(crate::func::NegInt32),
            expr: Box::new(MirScalarExpr::column(1)),
        };
        let key21 = MirScalarExpr::CallBinary {
            func: BinaryFunc::AddInt32(func::AddInt32),
            expr1: Box::new(MirScalarExpr::column(2)),
            expr2: Box::new(MirScalarExpr::literal(
                Ok(Datum::Int32(4)),
                ReprScalarType::Int32,
            )),
        };
        let key22 = MirScalarExpr::column(5);
        let localized_key22 = MirScalarExpr::column(0);
        equivalences.push(vec![key22.clone(), key20.clone(), key21.clone()]);

        // basic tests that we can find an expression's corresponding expression in a
        // different input
        assert_eq!(
            (true, key20.clone()),
            localize(&input_mapper, &key21, 0, &equivalences)
        );
        assert_eq!(
            (true, localized_key22.clone()),
            localize(&input_mapper, &key21, 2, &equivalences)
        );

        // test that `try_localize_to_input_with_bound_expr` will map multiple
        // subexpressions to the corresponding expressions bound to a different input
        let key_comp = MirScalarExpr::CallBinary {
            func: func::MulInt32.into(),
            expr1: Box::new(key12.clone()),
            expr2: Box::new(key22),
        };
        assert_eq!(
            (
                true,
                MirScalarExpr::CallBinary {
                    func: func::MulInt32.into(),
                    expr1: Box::new(key10.clone()),
                    expr2: Box::new(key20.clone()),
                }
            ),
            localize(&input_mapper, &key_comp, 0, &equivalences)
        );

        // test that the function returns false when part of the expression can
        // be mapped to an input but the rest can't
        assert!(!localize(&input_mapper, &key_comp, 1, &equivalences).0);

        let key_comp_plus_non_key = MirScalarExpr::CallBinary {
            func: func::Eq.into(),
            expr1: Box::new(key_comp),
            expr2: Box::new(MirScalarExpr::column(7)),
        };
        assert!(!localize(&input_mapper, &key_comp_plus_non_key, 0, &equivalences).0);

        let key_comp_multi_input = MirScalarExpr::CallBinary {
            func: func::Eq.into(),
            expr1: Box::new(key12),
            expr2: Box::new(key21),
        };
        // test that the function works when part of the expression is already
        // part of the target input
        assert_eq!(
            (
                true,
                MirScalarExpr::CallBinary {
                    func: func::Eq.into(),
                    expr1: Box::new(localized_key12),
                    expr2: Box::new(localized_key22),
                }
            ),
            localize(&input_mapper, &key_comp_multi_input, 2, &equivalences)
        );
        // test that the function works when parts of the expression come from
        // multiple inputs
        assert_eq!(
            (
                true,
                MirScalarExpr::CallBinary {
                    func: func::Eq.into(),
                    expr1: Box::new(key10),
                    expr2: Box::new(key20),
                }
            ),
            localize(&input_mapper, &key_comp_multi_input, 0, &equivalences)
        );
        assert!(!localize(&input_mapper, &key_comp_multi_input, 1, &equivalences).0);
    }
}