Skip to main content

mz_expr/relation/
join_input_mapper.rs

// Copyright Materialize, Inc. and contributors. All rights reserved.
//
// Use of this software is governed by the Business Source License
// included in the LICENSE file.
//
// As of the Change Date specified in that file, in accordance with
// the Business Source License, use of this software will be governed
// by the Apache License, Version 2.0.
9
10use std::collections::BTreeSet;
11use std::ops::Range;
12
13use itertools::Itertools;
14use mz_repr::ReprRelationType;
15
16use crate::scalar::columns::Columns;
17use crate::scalar::func::variadic::{And, Or};
18use crate::visit::Visit;
19use crate::{MirRelationExpr, MirScalarExpr, VariadicFunc};
20
21/// Any column in a join expression exists in two contexts:
22/// 1) It has a position relative to the result of the join (global)
23/// 2) It has a position relative to the specific input it came from (local)
24/// This utility focuses on taking expressions that are in terms of
25/// the local input and re-expressing them in global terms and vice versa.
26///
27/// Methods in this class that take an argument `equivalences` are only
28/// guaranteed to return a correct answer if equivalence classes are in
29/// canonical form.
30/// (See [`crate::relation::canonicalize::canonicalize_equivalences`].)
31#[derive(Debug)]
32pub struct JoinInputMapper {
33    /// The number of columns per input. All other fields in this struct are
34    /// derived using the information in this field.
35    arities: Vec<usize>,
36    /// Looks up which input each column belongs to. Derived from `arities`.
37    /// Stored as a field to avoid recomputation.
38    input_relation: Vec<usize>,
39    /// The sum of the arities of the previous inputs in the join. Derived from
40    /// `arities`. Stored as a field to avoid recomputation.
41    prior_arities: Vec<usize>,
42}
43
44impl JoinInputMapper {
45    /// Creates a new `JoinInputMapper` and calculates the mapping of global context
46    /// columns to local context columns.
47    pub fn new(inputs: &[MirRelationExpr]) -> Self {
48        Self::new_from_input_arities(inputs.iter().map(|i| i.arity()))
49    }
50
51    /// Creates a new `JoinInputMapper` and calculates the mapping of global context
52    /// columns to local context columns. Using this method is more
53    /// efficient if input repr types have been pre-calculated.
54    pub fn new_from_input_types(types: &[ReprRelationType]) -> Self {
55        Self::new_from_input_arities(types.iter().map(|t| t.arity()))
56    }
57
58    /// Creates a new `JoinInputMapper` and calculates the mapping of global context
59    /// columns to local context columns. Using this method is more
60    /// efficient if input arities have been pre-calculated
61    pub fn new_from_input_arities<I>(arities: I) -> Self
62    where
63        I: IntoIterator<Item = usize>,
64    {
65        let arities = arities.into_iter().collect::<Vec<usize>>();
66        let mut offset = 0;
67        let mut prior_arities = Vec::new();
68        for input in 0..arities.len() {
69            prior_arities.push(offset);
70            offset += arities[input];
71        }
72
73        let input_relation = arities
74            .iter()
75            .enumerate()
76            .flat_map(|(r, a)| std::iter::repeat(r).take(*a))
77            .collect::<Vec<_>>();
78
79        JoinInputMapper {
80            arities,
81            input_relation,
82            prior_arities,
83        }
84    }
85
86    /// reports sum of the number of columns of each input
87    pub fn total_columns(&self) -> usize {
88        self.arities.iter().sum()
89    }
90
91    /// reports total numbers of inputs in the join
92    pub fn total_inputs(&self) -> usize {
93        self.arities.len()
94    }
95
96    /// Using the keys that came from each local input,
97    /// figures out which keys remain unique in the larger join
98    /// Currently, we only figure out a small subset of the keys that
99    /// can remain unique.
100    pub fn global_keys<'a, I>(
101        &self,
102        mut local_keys: I,
103        equivalences: &[Vec<MirScalarExpr>],
104    ) -> Vec<Vec<usize>>
105    where
106        I: Iterator<Item = &'a Vec<Vec<usize>>>,
107    {
108        // A relation's uniqueness constraint holds if there is a
109        // sequence of the other relations such that each one has
110        // a uniqueness constraint whose columns are used in join
111        // constraints with relations prior in the sequence.
112        //
113        // Currently, we only:
114        // 1. test for whether the uniqueness constraints for the first input will hold
115        // 2. try one sequence, namely the inputs in order
116        // 3. check that the column themselves are used in the join constraints
117        //    Technically uniqueness constraint would still hold if a 1-to-1
118        //    expression on a unique key is used in the join constraint.
119
120        // for inputs `1..self.total_inputs()`, store a set of columns from that
121        // input that exist in join constraints that have expressions belonging to
122        // earlier inputs.
123        let mut column_with_prior_bound_by_input = vec![BTreeSet::new(); self.total_inputs() - 1];
124        for equivalence in equivalences {
125            // do a scan to find the first input represented in the constraint
126            let min_bound_input = equivalence
127                .iter()
128                .flat_map(|expr| self.lookup_inputs(expr).max())
129                .min();
130            if let Some(min_bound_input) = min_bound_input {
131                for expr in equivalence {
132                    // then store all columns in the constraint that don't come
133                    // from the first input
134                    if let MirScalarExpr::Column(c, _name) = expr {
135                        let (col, input) = self.map_column_to_local(*c);
136                        if input > min_bound_input {
137                            column_with_prior_bound_by_input[input - 1].insert(col);
138                        }
139                    }
140                }
141            }
142        }
143
144        if self.total_inputs() > 0 {
145            let first_input_keys = local_keys.next().unwrap().clone();
146            // for inputs `1..self.total_inputs()`, checks the keys belong to each
147            // input against the storage of columns that exist in join constraints
148            // that have expressions belonging to earlier inputs.
149            let remains_unique = local_keys.enumerate().all(|(index, keys)| {
150                keys.iter().any(|ks| {
151                    ks.iter()
152                        .all(|k| column_with_prior_bound_by_input[index].contains(k))
153                })
154            });
155
156            if remains_unique {
157                return first_input_keys;
158            }
159        }
160        vec![]
161    }
162
163    /// returns the arity for a particular input
164    #[inline]
165    pub fn input_arity(&self, index: usize) -> usize {
166        self.arities[index]
167    }
168
169    /// All column numbers in order for a particular input in the local context
170    #[inline]
171    pub fn local_columns(&self, index: usize) -> Range<usize> {
172        0..self.arities[index]
173    }
174
175    /// All column numbers in order for a particular input in the global context
176    #[inline]
177    pub fn global_columns(&self, index: usize) -> Range<usize> {
178        self.prior_arities[index]..(self.prior_arities[index] + self.arities[index])
179    }
180
181    /// Takes an expression from the global context and creates a new version
182    /// where column references have been remapped to the local context.
183    /// Assumes that all columns in `expr` are from the same input.
184    pub fn map_expr_to_local<C: Columns + Sized>(&self, mut expr: C) -> C {
185        expr.visit_columns(|c| {
186            *c -= self.prior_arities[self.input_relation[*c]];
187        });
188        expr
189    }
190
191    /// Takes an expression from the local context of the `index`th input and
192    /// creates a new version where column references have been remapped to the
193    /// global context.
194    pub fn map_expr_to_global<C: Columns + Sized>(&self, mut expr: C, index: usize) -> C {
195        expr.visit_columns(|c| {
196            *c += self.prior_arities[index];
197        });
198        expr
199    }
200
201    /// Remap column numbers from the global to the local context.
202    /// Returns a 2-tuple `(<new column number>, <index of input>)`
203    pub fn map_column_to_local(&self, column: usize) -> (usize, usize) {
204        let index = self.input_relation[column];
205        (column - self.prior_arities[index], index)
206    }
207
208    /// Remap a column number from a local context to the global context.
209    pub fn map_column_to_global(&self, column: usize, index: usize) -> usize {
210        column + self.prior_arities[index]
211    }
212
213    /// Takes a sequence of columns in the global context and splits it into
214    /// a `Vec` containing `self.total_inputs()` `BTreeSet`s, each containing
215    /// the localized columns that belong to the particular input.
216    pub fn split_column_set_by_input<'a, I>(&self, columns: I) -> Vec<BTreeSet<usize>>
217    where
218        I: Iterator<Item = &'a usize>,
219    {
220        let mut new_columns = vec![BTreeSet::new(); self.total_inputs()];
221        for column in columns {
222            let (new_col, input) = self.map_column_to_local(*column);
223            new_columns[input].extend(std::iter::once(new_col));
224        }
225        new_columns
226    }
227
228    /// Find the sorted, dedupped set of inputs an expression references
229    pub fn lookup_inputs<C: Columns>(&self, expr: &C) -> impl Iterator<Item = usize> + use<C> {
230        expr.support()
231            .iter()
232            .map(|c| self.input_relation[*c])
233            .sorted()
234            .dedup()
235    }
236
237    /// Returns the index of the only input referenced in the given expression.
238    pub fn single_input(&self, expr: &impl Columns) -> Option<usize> {
239        let mut inputs = self.lookup_inputs(expr);
240        if let Some(first_input) = inputs.next() {
241            if inputs.next().is_none() {
242                return Some(first_input);
243            }
244        }
245        None
246    }
247
248    /// Returns whether the given expr refers to columns of only the `index`th input.
249    pub fn is_localized(&self, expr: &impl Columns, index: usize) -> bool {
250        if let Some(single_input) = self.single_input(expr) {
251            if single_input == index {
252                return true;
253            }
254        }
255        false
256    }
257
258    /// Takes an expression in the global context and looks in `equivalences`
259    /// for an equivalent expression (also expressed in the global context) that
260    /// belongs to one or more of the inputs in `bound_inputs`
261    ///
262    /// # Examples
263    ///
264    /// ```
265    /// use mz_repr::{Datum, ReprColumnType, ReprRelationType, ReprScalarType};
266    /// use mz_expr::{JoinInputMapper, MirRelationExpr, MirScalarExpr};
267    ///
268    /// // A two-column schema common to each of the three inputs
269    /// let schema = ReprRelationType::new(vec![
270    ///   ReprScalarType::Int32.nullable(false),
271    ///   ReprScalarType::Int32.nullable(false),
272    /// ]);
273    ///
274    /// // the specific data are not important here.
275    /// let data = vec![Datum::Int32(0), Datum::Int32(1)];
276    /// let input0 = MirRelationExpr::constant(vec![data.clone()], schema.clone());
277    /// let input1 = MirRelationExpr::constant(vec![data.clone()], schema.clone());
278    /// let input2 = MirRelationExpr::constant(vec![data.clone()], schema.clone());
279    ///
280    /// // [input0(#0) = input2(#1)], [input0(#1) = input1(#0) = input2(#0)]
281    /// let equivalences = vec![
282    ///   vec![MirScalarExpr::column(0), MirScalarExpr::column(5)],
283    ///   vec![MirScalarExpr::column(1), MirScalarExpr::column(2), MirScalarExpr::column(4)],
284    /// ];
285    ///
286    /// let input_mapper = JoinInputMapper::new(&[input0, input1, input2]);
287    /// assert_eq!(
288    ///   Some(MirScalarExpr::column(4)),
289    ///   input_mapper.find_bound_expr(&MirScalarExpr::column(2), &[2], &equivalences)
290    /// );
291    /// assert_eq!(
292    ///   None,
293    ///   input_mapper.find_bound_expr(&MirScalarExpr::column(0), &[1], &equivalences)
294    /// );
295    /// ```
296    pub fn find_bound_expr<C: Columns + Clone + Eq>(
297        &self,
298        expr: &C,
299        bound_inputs: &[usize],
300        equivalences: &[Vec<C>],
301    ) -> Option<C> {
302        if let Some(equivalence) = equivalences.iter().find(|equivs| equivs.contains(expr)) {
303            if let Some(bound_expr) = equivalence
304                .iter()
305                .find(|expr| self.lookup_inputs(*expr).all(|i| bound_inputs.contains(&i)))
306            {
307                return Some(bound_expr.clone());
308            }
309        }
310        None
311    }
312
313    /// Try to rewrite `expr` from the global context so that all the
314    /// columns point to the `index`th input by replacing subexpressions with their
315    /// bound equivalents in the `index`th input if necessary.
316    /// Returns whether the rewriting was successful.
317    /// If it returns true, then `expr` is in the context of the `index`th input.
318    /// If it returns false, then still some subexpressions might have been rewritten. However,
319    /// `expr` is still in the global context.
320    pub fn try_localize_to_input_with_bound_expr(
321        &self,
322        expr: &mut MirScalarExpr,
323        index: usize,
324        equivalences: &[Vec<MirScalarExpr>],
325    ) -> bool {
326        // TODO (wangandi): Consider changing this code to be post-order
327        // instead of pre-order? `lookup_inputs` traverses all the nodes in
328        // `e` anyway, so we end up visiting nodes in `e` multiple times
329        // here. Alternatively, consider having the future `PredicateKnowledge`
330        // take over the responsibilities of this code?
331        #[allow(deprecated)]
332        expr.visit_mut_pre_post_nolimit(
333            &mut |e| {
334                let mut inputs = self.lookup_inputs(e);
335                if let Some(first_input) = inputs.next() {
336                    if inputs.next().is_none() && first_input == index {
337                        // there is only one input, and it is equal to index, so we're
338                        // good. do not continue the recursion
339                        return Some(vec![]);
340                    }
341                }
342
343                if let Some(bound_expr) = self.find_bound_expr(e, &[index], equivalences) {
344                    // Replace the subexpression with the equivalent one from input `index`
345                    *e = bound_expr;
346                    // The entire subexpression has been rewritten, so there is
347                    // no need to visit any child expressions.
348                    Some(vec![])
349                } else {
350                    None
351                }
352            },
353            &mut |_| {},
354        );
355        if self.is_localized(expr, index) {
356            // If the localization attempt is successful, all columns in `expr`
357            // should only come from input `index`. Switch to the local context.
358            *expr = self.map_expr_to_local(expr.clone());
359            return true;
360        }
361        false
362    }
363
364    /// Try to find a consequence `c` of the given expression `e` for the given input.
365    ///
366    /// If we return `Some(c)`, that means
367    ///   1. `c` uses only columns from the given input;
368    ///   2. if `c` doesn't hold on a row of the input, then `e` also wouldn't hold;
369    ///   3. if `c` holds on a row of the input, then `e` might or might not hold.
370    /// 1. and 2. means that if we have a join with predicate `e` then we can use `c` for
371    /// pre-filtering a join input before the join. However, 3. means that `e` shouldn't be deleted
372    /// from the join predicates, i.e., we can't do a "traditional" predicate pushdown.
373    ///
374    /// Note that "`c` is a consequence of `e`" is the same thing as 2., see
375    /// <https://en.wikipedia.org/wiki/Contraposition>
376    ///
377    /// Example: For
378    /// `(t1.f2 = 3 AND t2.f2 = 4) OR (t1.f2 = 5 AND t2.f2 = 6)`
379    /// we find
380    /// `t1.f2 = 3 OR t1.f2 = 5` for t1, and
381    /// `t2.f2 = 4 OR t2.f2 = 6` for t2.
382    ///
383    /// Further examples are in TPC-H Q07, Q19, and chbench Q07, Q19.
384    ///
385    /// Parameters:
386    ///  - `expr`: The expression `e` from above. `try_localize_to_input_with_bound_expr` should
387    ///    be called on `expr` before us!
388    ///  - `index`: The index of the join input whose columns we will use.
389    ///  - `equivalences`: Join equivalences that we can use for `try_map_to_input_with_bound_expr`.
390    /// If successful, the returned expression is in the local context of the specified input.
391    pub fn consequence_for_input(
392        &self,
393        expr: &MirScalarExpr,
394        index: usize,
395    ) -> Option<MirScalarExpr> {
396        if self.is_localized(expr, index) {
397            Some(self.map_expr_to_local(expr.clone()))
398        } else {
399            match expr {
400                MirScalarExpr::CallVariadic {
401                    func: VariadicFunc::Or(_),
402                    exprs: or_args,
403                } => {
404                    // Each OR arg should provide a consequence. If they do, we OR them.
405                    let consequences_per_arg = or_args
406                        .into_iter()
407                        .map(|or_arg| {
408                            mz_ore::stack::maybe_grow(|| self.consequence_for_input(or_arg, index))
409                        })
410                        .collect::<Option<Vec<_>>>()?; // return None if any of them are None
411                    Some(MirScalarExpr::call_variadic(Or, consequences_per_arg))
412                }
413                MirScalarExpr::CallVariadic {
414                    func: VariadicFunc::And(_),
415                    exprs: and_args,
416                } => {
417                    // If any of the AND args provide a consequence, then we take those that do,
418                    // and AND them.
419                    let consequences_per_arg = and_args
420                        .into_iter()
421                        .map(|and_arg| {
422                            mz_ore::stack::maybe_grow(|| self.consequence_for_input(and_arg, index))
423                        })
424                        .flat_map(|c| c) // take only those that provide a consequence
425                        .collect_vec();
426                    if consequences_per_arg.is_empty() {
427                        None
428                    } else {
429                        Some(MirScalarExpr::call_variadic(And, consequences_per_arg))
430                    }
431                }
432                _ => None,
433            }
434        }
435    }
436}
437
#[cfg(test)]
mod tests {
    use mz_repr::{Datum, ReprScalarType};

    use crate::scalar::func;
    use crate::{BinaryFunc, MirScalarExpr, UnaryFunc};

    use super::*;

    /// Exercises `JoinInputMapper::try_localize_to_input_with_bound_expr` on a join of
    /// three inputs with arities `[2, 3, 3]` (global columns `0..2`, `2..5`, and `5..8`).
    #[mz_ore::test]
    #[cfg_attr(miri, ignore)] // unsupported operation: can't call foreign function `rust_psm_stack_pointer` on OS `linux`
    fn try_map_to_input_with_bound_expr_test() {
        let input_mapper = JoinInputMapper {
            arities: vec![2, 3, 3],
            input_relation: vec![0, 0, 1, 1, 1, 2, 2, 2],
            prior_arities: vec![0, 2, 5],
        };
        // The hand-built mapper above must agree with the one the constructor derives.
        let derived = JoinInputMapper::new_from_input_arities([2, 3, 3]);
        assert_eq!(input_mapper.arities, derived.arities);
        assert_eq!(input_mapper.input_relation, derived.input_relation);
        assert_eq!(input_mapper.prior_arities, derived.prior_arities);

        // keys are numbered by (equivalence class #, input #)
        let key10 = MirScalarExpr::column(0);
        let key12 = MirScalarExpr::column(6);
        let localized_key12 = MirScalarExpr::column(1);

        let mut equivalences = vec![vec![key10.clone(), key12.clone()]];

        // when the column is already part of the target input, all that happens
        // is that it gets localized
        let mut cloned = key12.clone();
        assert!(input_mapper.try_localize_to_input_with_bound_expr(
            &mut cloned,
            2,
            &equivalences
        ));
        assert_eq!(MirScalarExpr::column(1), cloned);

        // basic tests that we can find a column's corresponding column in a
        // different input
        let mut cloned = key12.clone();
        assert!(input_mapper.try_localize_to_input_with_bound_expr(
            &mut cloned,
            0,
            &equivalences
        ));
        assert_eq!(key10, cloned);
        // input 1 has no column equivalent to `key12`
        let mut cloned = key12.clone();
        assert!(!input_mapper.try_localize_to_input_with_bound_expr(
            &mut cloned,
            1,
            &equivalences
        ));

        let key20 = MirScalarExpr::CallUnary {
            func: UnaryFunc::NegInt32(crate::func::NegInt32),
            expr: Box::new(MirScalarExpr::column(1)),
        };
        let key21 = MirScalarExpr::CallBinary {
            func: BinaryFunc::AddInt32(func::AddInt32),
            expr1: Box::new(MirScalarExpr::column(2)),
            expr2: Box::new(MirScalarExpr::literal(
                Ok(Datum::Int32(4)),
                ReprScalarType::Int32,
            )),
        };
        let key22 = MirScalarExpr::column(5);
        let localized_key22 = MirScalarExpr::column(0);
        equivalences.push(vec![key22.clone(), key20.clone(), key21.clone()]);

        // basic tests that we can find an expression's corresponding expression in a
        // different input
        let mut cloned = key21.clone();
        assert!(input_mapper.try_localize_to_input_with_bound_expr(
            &mut cloned,
            0,
            &equivalences
        ));
        assert_eq!(key20, cloned);
        let mut cloned = key21.clone();
        assert!(input_mapper.try_localize_to_input_with_bound_expr(
            &mut cloned,
            2,
            &equivalences
        ));
        assert_eq!(localized_key22, cloned);

        // test that `try_localize_to_input_with_bound_expr` will map multiple
        // subexpressions to the corresponding expressions bound to a different input
        let key_comp = MirScalarExpr::CallBinary {
            func: func::MulInt32.into(),
            expr1: Box::new(key12.clone()),
            expr2: Box::new(key22),
        };
        let mut cloned = key_comp.clone();
        assert!(input_mapper.try_localize_to_input_with_bound_expr(
            &mut cloned,
            0,
            &equivalences
        ));
        assert_eq!(
            MirScalarExpr::CallBinary {
                func: func::MulInt32.into(),
                expr1: Box::new(key10.clone()),
                expr2: Box::new(key20.clone()),
            },
            cloned,
        );

        // test that the function returns false when part
        // of the expression can be mapped to an input but the rest can't
        let mut cloned = key_comp.clone();
        assert!(!input_mapper.try_localize_to_input_with_bound_expr(
            &mut cloned,
            1,
            &equivalences
        ));

        // column 7 (input 2) has no equivalent in input 0
        let key_comp_plus_non_key = MirScalarExpr::CallBinary {
            func: func::Eq.into(),
            expr1: Box::new(key_comp),
            expr2: Box::new(MirScalarExpr::column(7)),
        };
        let mut mutab = key_comp_plus_non_key;
        assert!(!input_mapper.try_localize_to_input_with_bound_expr(
            &mut mutab,
            0,
            &equivalences
        ));

        let key_comp_multi_input = MirScalarExpr::CallBinary {
            func: func::Eq.into(),
            expr1: Box::new(key12),
            expr2: Box::new(key21),
        };
        // test that the function works when part of the expression is already
        // part of the target input
        let mut cloned = key_comp_multi_input.clone();
        assert!(input_mapper.try_localize_to_input_with_bound_expr(
            &mut cloned,
            2,
            &equivalences
        ));
        assert_eq!(
            MirScalarExpr::CallBinary {
                func: func::Eq.into(),
                expr1: Box::new(localized_key12),
                expr2: Box::new(localized_key22),
            },
            cloned,
        );
        // test that the function works when parts of the expression come from
        // multiple inputs
        let mut cloned = key_comp_multi_input.clone();
        assert!(input_mapper.try_localize_to_input_with_bound_expr(
            &mut cloned,
            0,
            &equivalences
        ));
        assert_eq!(
            MirScalarExpr::CallBinary {
                func: func::Eq.into(),
                expr1: Box::new(key10),
                expr2: Box::new(key20),
            },
            cloned,
        );
        let mut mutab = key_comp_multi_input;
        assert!(!input_mapper.try_localize_to_input_with_bound_expr(
            &mut mutab,
            1,
            &equivalences
        ));
    }
}