mz_expr/relation/join_input_mapper.rs
// Copyright Materialize, Inc. and contributors. All rights reserved.
//
// Use of this software is governed by the Business Source License
// included in the LICENSE file.
//
// As of the Change Date specified in that file, in accordance with
// the Business Source License, use of this software will be governed
// by the Apache License, Version 2.0.
9
10use std::collections::BTreeSet;
11use std::ops::Range;
12
13use itertools::Itertools;
14use mz_repr::ReprRelationType;
15
16use crate::scalar::columns::Columns;
17use crate::scalar::func::variadic::{And, Or};
18use crate::visit::Visit;
19use crate::{MirRelationExpr, MirScalarExpr, VariadicFunc};
20
/// Translates column references between the two coordinate systems of a join.
///
/// Every column that takes part in a join can be addressed in two ways:
/// 1. *globally*, by its position in the output of the join;
/// 2. *locally*, by its position within the input it originates from.
///
/// This helper re-expresses columns and expressions given in one of these
/// contexts in terms of the other.
///
/// Methods of this type that accept an `equivalences` argument only guarantee
/// a correct answer when the equivalence classes are in canonical form
/// (see [`crate::relation::canonicalize::canonicalize_equivalences`]).
#[derive(Debug)]
pub struct JoinInputMapper {
    /// Column count of each input, in input order. This is the source of
    /// truth: both other fields are derived from it.
    arities: Vec<usize>,
    /// For every global column, the index of the input it belongs to.
    /// Derived from `arities` and cached so it need not be recomputed.
    input_relation: Vec<usize>,
    /// For every input, the combined arity of all inputs that precede it,
    /// i.e. the global position of that input's first column. Derived from
    /// `arities` and cached so it need not be recomputed.
    prior_arities: Vec<usize>,
}
43
44impl JoinInputMapper {
45 /// Creates a new `JoinInputMapper` and calculates the mapping of global context
46 /// columns to local context columns.
47 pub fn new(inputs: &[MirRelationExpr]) -> Self {
48 Self::new_from_input_arities(inputs.iter().map(|i| i.arity()))
49 }
50
51 /// Creates a new `JoinInputMapper` and calculates the mapping of global context
52 /// columns to local context columns. Using this method is more
53 /// efficient if input repr types have been pre-calculated.
54 pub fn new_from_input_types(types: &[ReprRelationType]) -> Self {
55 Self::new_from_input_arities(types.iter().map(|t| t.arity()))
56 }
57
58 /// Creates a new `JoinInputMapper` and calculates the mapping of global context
59 /// columns to local context columns. Using this method is more
60 /// efficient if input arities have been pre-calculated
61 pub fn new_from_input_arities<I>(arities: I) -> Self
62 where
63 I: IntoIterator<Item = usize>,
64 {
65 let arities = arities.into_iter().collect::<Vec<usize>>();
66 let mut offset = 0;
67 let mut prior_arities = Vec::new();
68 for input in 0..arities.len() {
69 prior_arities.push(offset);
70 offset += arities[input];
71 }
72
73 let input_relation = arities
74 .iter()
75 .enumerate()
76 .flat_map(|(r, a)| std::iter::repeat(r).take(*a))
77 .collect::<Vec<_>>();
78
79 JoinInputMapper {
80 arities,
81 input_relation,
82 prior_arities,
83 }
84 }
85
86 /// reports sum of the number of columns of each input
87 pub fn total_columns(&self) -> usize {
88 self.arities.iter().sum()
89 }
90
91 /// reports total numbers of inputs in the join
92 pub fn total_inputs(&self) -> usize {
93 self.arities.len()
94 }
95
96 /// Using the keys that came from each local input,
97 /// figures out which keys remain unique in the larger join
98 /// Currently, we only figure out a small subset of the keys that
99 /// can remain unique.
100 pub fn global_keys<'a, I>(
101 &self,
102 mut local_keys: I,
103 equivalences: &[Vec<MirScalarExpr>],
104 ) -> Vec<Vec<usize>>
105 where
106 I: Iterator<Item = &'a Vec<Vec<usize>>>,
107 {
108 // A relation's uniqueness constraint holds if there is a
109 // sequence of the other relations such that each one has
110 // a uniqueness constraint whose columns are used in join
111 // constraints with relations prior in the sequence.
112 //
113 // Currently, we only:
114 // 1. test for whether the uniqueness constraints for the first input will hold
115 // 2. try one sequence, namely the inputs in order
116 // 3. check that the column themselves are used in the join constraints
117 // Technically uniqueness constraint would still hold if a 1-to-1
118 // expression on a unique key is used in the join constraint.
119
120 // for inputs `1..self.total_inputs()`, store a set of columns from that
121 // input that exist in join constraints that have expressions belonging to
122 // earlier inputs.
123 let mut column_with_prior_bound_by_input = vec![BTreeSet::new(); self.total_inputs() - 1];
124 for equivalence in equivalences {
125 // do a scan to find the first input represented in the constraint
126 let min_bound_input = equivalence
127 .iter()
128 .flat_map(|expr| self.lookup_inputs(expr).max())
129 .min();
130 if let Some(min_bound_input) = min_bound_input {
131 for expr in equivalence {
132 // then store all columns in the constraint that don't come
133 // from the first input
134 if let MirScalarExpr::Column(c, _name) = expr {
135 let (col, input) = self.map_column_to_local(*c);
136 if input > min_bound_input {
137 column_with_prior_bound_by_input[input - 1].insert(col);
138 }
139 }
140 }
141 }
142 }
143
144 if self.total_inputs() > 0 {
145 let first_input_keys = local_keys.next().unwrap().clone();
146 // for inputs `1..self.total_inputs()`, checks the keys belong to each
147 // input against the storage of columns that exist in join constraints
148 // that have expressions belonging to earlier inputs.
149 let remains_unique = local_keys.enumerate().all(|(index, keys)| {
150 keys.iter().any(|ks| {
151 ks.iter()
152 .all(|k| column_with_prior_bound_by_input[index].contains(k))
153 })
154 });
155
156 if remains_unique {
157 return first_input_keys;
158 }
159 }
160 vec![]
161 }
162
163 /// returns the arity for a particular input
164 #[inline]
165 pub fn input_arity(&self, index: usize) -> usize {
166 self.arities[index]
167 }
168
169 /// All column numbers in order for a particular input in the local context
170 #[inline]
171 pub fn local_columns(&self, index: usize) -> Range<usize> {
172 0..self.arities[index]
173 }
174
175 /// All column numbers in order for a particular input in the global context
176 #[inline]
177 pub fn global_columns(&self, index: usize) -> Range<usize> {
178 self.prior_arities[index]..(self.prior_arities[index] + self.arities[index])
179 }
180
181 /// Takes an expression from the global context and creates a new version
182 /// where column references have been remapped to the local context.
183 /// Assumes that all columns in `expr` are from the same input.
184 pub fn map_expr_to_local<C: Columns + Sized>(&self, mut expr: C) -> C {
185 expr.visit_columns(|c| {
186 *c -= self.prior_arities[self.input_relation[*c]];
187 });
188 expr
189 }
190
191 /// Takes an expression from the local context of the `index`th input and
192 /// creates a new version where column references have been remapped to the
193 /// global context.
194 pub fn map_expr_to_global<C: Columns + Sized>(&self, mut expr: C, index: usize) -> C {
195 expr.visit_columns(|c| {
196 *c += self.prior_arities[index];
197 });
198 expr
199 }
200
201 /// Remap column numbers from the global to the local context.
202 /// Returns a 2-tuple `(<new column number>, <index of input>)`
203 pub fn map_column_to_local(&self, column: usize) -> (usize, usize) {
204 let index = self.input_relation[column];
205 (column - self.prior_arities[index], index)
206 }
207
208 /// Remap a column number from a local context to the global context.
209 pub fn map_column_to_global(&self, column: usize, index: usize) -> usize {
210 column + self.prior_arities[index]
211 }
212
213 /// Takes a sequence of columns in the global context and splits it into
214 /// a `Vec` containing `self.total_inputs()` `BTreeSet`s, each containing
215 /// the localized columns that belong to the particular input.
216 pub fn split_column_set_by_input<'a, I>(&self, columns: I) -> Vec<BTreeSet<usize>>
217 where
218 I: Iterator<Item = &'a usize>,
219 {
220 let mut new_columns = vec![BTreeSet::new(); self.total_inputs()];
221 for column in columns {
222 let (new_col, input) = self.map_column_to_local(*column);
223 new_columns[input].extend(std::iter::once(new_col));
224 }
225 new_columns
226 }
227
228 /// Find the sorted, dedupped set of inputs an expression references
229 pub fn lookup_inputs<C: Columns>(&self, expr: &C) -> impl Iterator<Item = usize> + use<C> {
230 expr.support()
231 .iter()
232 .map(|c| self.input_relation[*c])
233 .sorted()
234 .dedup()
235 }
236
237 /// Returns the index of the only input referenced in the given expression.
238 pub fn single_input(&self, expr: &impl Columns) -> Option<usize> {
239 let mut inputs = self.lookup_inputs(expr);
240 if let Some(first_input) = inputs.next() {
241 if inputs.next().is_none() {
242 return Some(first_input);
243 }
244 }
245 None
246 }
247
248 /// Returns whether the given expr refers to columns of only the `index`th input.
249 pub fn is_localized(&self, expr: &impl Columns, index: usize) -> bool {
250 if let Some(single_input) = self.single_input(expr) {
251 if single_input == index {
252 return true;
253 }
254 }
255 false
256 }
257
258 /// Takes an expression in the global context and looks in `equivalences`
259 /// for an equivalent expression (also expressed in the global context) that
260 /// belongs to one or more of the inputs in `bound_inputs`
261 ///
262 /// # Examples
263 ///
264 /// ```
265 /// use mz_repr::{Datum, ReprColumnType, ReprRelationType, ReprScalarType};
266 /// use mz_expr::{JoinInputMapper, MirRelationExpr, MirScalarExpr};
267 ///
268 /// // A two-column schema common to each of the three inputs
269 /// let schema = ReprRelationType::new(vec![
270 /// ReprScalarType::Int32.nullable(false),
271 /// ReprScalarType::Int32.nullable(false),
272 /// ]);
273 ///
274 /// // the specific data are not important here.
275 /// let data = vec![Datum::Int32(0), Datum::Int32(1)];
276 /// let input0 = MirRelationExpr::constant(vec![data.clone()], schema.clone());
277 /// let input1 = MirRelationExpr::constant(vec![data.clone()], schema.clone());
278 /// let input2 = MirRelationExpr::constant(vec![data.clone()], schema.clone());
279 ///
280 /// // [input0(#0) = input2(#1)], [input0(#1) = input1(#0) = input2(#0)]
281 /// let equivalences = vec![
282 /// vec![MirScalarExpr::column(0), MirScalarExpr::column(5)],
283 /// vec![MirScalarExpr::column(1), MirScalarExpr::column(2), MirScalarExpr::column(4)],
284 /// ];
285 ///
286 /// let input_mapper = JoinInputMapper::new(&[input0, input1, input2]);
287 /// assert_eq!(
288 /// Some(MirScalarExpr::column(4)),
289 /// input_mapper.find_bound_expr(&MirScalarExpr::column(2), &[2], &equivalences)
290 /// );
291 /// assert_eq!(
292 /// None,
293 /// input_mapper.find_bound_expr(&MirScalarExpr::column(0), &[1], &equivalences)
294 /// );
295 /// ```
296 pub fn find_bound_expr<C: Columns + Clone + Eq>(
297 &self,
298 expr: &C,
299 bound_inputs: &[usize],
300 equivalences: &[Vec<C>],
301 ) -> Option<C> {
302 if let Some(equivalence) = equivalences.iter().find(|equivs| equivs.contains(expr)) {
303 if let Some(bound_expr) = equivalence
304 .iter()
305 .find(|expr| self.lookup_inputs(*expr).all(|i| bound_inputs.contains(&i)))
306 {
307 return Some(bound_expr.clone());
308 }
309 }
310 None
311 }
312
313 /// Try to rewrite `expr` from the global context so that all the
314 /// columns point to the `index`th input by replacing subexpressions with their
315 /// bound equivalents in the `index`th input if necessary.
316 /// Returns whether the rewriting was successful.
317 /// If it returns true, then `expr` is in the context of the `index`th input.
318 /// If it returns false, then still some subexpressions might have been rewritten. However,
319 /// `expr` is still in the global context.
320 pub fn try_localize_to_input_with_bound_expr(
321 &self,
322 expr: &mut MirScalarExpr,
323 index: usize,
324 equivalences: &[Vec<MirScalarExpr>],
325 ) -> bool {
326 // TODO (wangandi): Consider changing this code to be post-order
327 // instead of pre-order? `lookup_inputs` traverses all the nodes in
328 // `e` anyway, so we end up visiting nodes in `e` multiple times
329 // here. Alternatively, consider having the future `PredicateKnowledge`
330 // take over the responsibilities of this code?
331 #[allow(deprecated)]
332 expr.visit_mut_pre_post_nolimit(
333 &mut |e| {
334 let mut inputs = self.lookup_inputs(e);
335 if let Some(first_input) = inputs.next() {
336 if inputs.next().is_none() && first_input == index {
337 // there is only one input, and it is equal to index, so we're
338 // good. do not continue the recursion
339 return Some(vec![]);
340 }
341 }
342
343 if let Some(bound_expr) = self.find_bound_expr(e, &[index], equivalences) {
344 // Replace the subexpression with the equivalent one from input `index`
345 *e = bound_expr;
346 // The entire subexpression has been rewritten, so there is
347 // no need to visit any child expressions.
348 Some(vec![])
349 } else {
350 None
351 }
352 },
353 &mut |_| {},
354 );
355 if self.is_localized(expr, index) {
356 // If the localization attempt is successful, all columns in `expr`
357 // should only come from input `index`. Switch to the local context.
358 *expr = self.map_expr_to_local(expr.clone());
359 return true;
360 }
361 false
362 }
363
364 /// Try to find a consequence `c` of the given expression `e` for the given input.
365 ///
366 /// If we return `Some(c)`, that means
367 /// 1. `c` uses only columns from the given input;
368 /// 2. if `c` doesn't hold on a row of the input, then `e` also wouldn't hold;
369 /// 3. if `c` holds on a row of the input, then `e` might or might not hold.
370 /// 1. and 2. means that if we have a join with predicate `e` then we can use `c` for
371 /// pre-filtering a join input before the join. However, 3. means that `e` shouldn't be deleted
372 /// from the join predicates, i.e., we can't do a "traditional" predicate pushdown.
373 ///
374 /// Note that "`c` is a consequence of `e`" is the same thing as 2., see
375 /// <https://en.wikipedia.org/wiki/Contraposition>
376 ///
377 /// Example: For
378 /// `(t1.f2 = 3 AND t2.f2 = 4) OR (t1.f2 = 5 AND t2.f2 = 6)`
379 /// we find
380 /// `t1.f2 = 3 OR t1.f2 = 5` for t1, and
381 /// `t2.f2 = 4 OR t2.f2 = 6` for t2.
382 ///
383 /// Further examples are in TPC-H Q07, Q19, and chbench Q07, Q19.
384 ///
385 /// Parameters:
386 /// - `expr`: The expression `e` from above. `try_localize_to_input_with_bound_expr` should
387 /// be called on `expr` before us!
388 /// - `index`: The index of the join input whose columns we will use.
389 /// - `equivalences`: Join equivalences that we can use for `try_map_to_input_with_bound_expr`.
390 /// If successful, the returned expression is in the local context of the specified input.
391 pub fn consequence_for_input(
392 &self,
393 expr: &MirScalarExpr,
394 index: usize,
395 ) -> Option<MirScalarExpr> {
396 if self.is_localized(expr, index) {
397 Some(self.map_expr_to_local(expr.clone()))
398 } else {
399 match expr {
400 MirScalarExpr::CallVariadic {
401 func: VariadicFunc::Or(_),
402 exprs: or_args,
403 } => {
404 // Each OR arg should provide a consequence. If they do, we OR them.
405 let consequences_per_arg = or_args
406 .into_iter()
407 .map(|or_arg| {
408 mz_ore::stack::maybe_grow(|| self.consequence_for_input(or_arg, index))
409 })
410 .collect::<Option<Vec<_>>>()?; // return None if any of them are None
411 Some(MirScalarExpr::call_variadic(Or, consequences_per_arg))
412 }
413 MirScalarExpr::CallVariadic {
414 func: VariadicFunc::And(_),
415 exprs: and_args,
416 } => {
417 // If any of the AND args provide a consequence, then we take those that do,
418 // and AND them.
419 let consequences_per_arg = and_args
420 .into_iter()
421 .map(|and_arg| {
422 mz_ore::stack::maybe_grow(|| self.consequence_for_input(and_arg, index))
423 })
424 .flat_map(|c| c) // take only those that provide a consequence
425 .collect_vec();
426 if consequences_per_arg.is_empty() {
427 None
428 } else {
429 Some(MirScalarExpr::call_variadic(And, consequences_per_arg))
430 }
431 }
432 _ => None,
433 }
434 }
435 }
436}
437
#[cfg(test)]
mod tests {
    use mz_repr::{Datum, ReprScalarType};

    use crate::scalar::func;
    use crate::{BinaryFunc, MirScalarExpr, UnaryFunc};

    use super::*;

    /// Exercises `try_localize_to_input_with_bound_expr` on a three-input join
    /// with arities 2, 3 and 3, checking both the returned success flag and
    /// the rewritten expression.
    #[mz_ore::test]
    #[cfg_attr(miri, ignore)] // unsupported operation: can't call foreign function `rust_psm_stack_pointer` on OS `linux`
    fn try_map_to_input_with_bound_expr_test() {
        // Global columns: input 0 -> 0..2, input 1 -> 2..5, input 2 -> 5..8.
        let input_mapper = JoinInputMapper::new_from_input_arities([2, 3, 3]);

        // keys are numbered by (equivalence class #, input #)
        let key10 = MirScalarExpr::column(0);
        let key12 = MirScalarExpr::column(6);
        let localized_key12 = MirScalarExpr::column(1);

        let mut equivalences = vec![vec![key10.clone(), key12.clone()]];

        // when the column is already part of the target input, all that happens
        // is that it gets localized
        let mut cloned = key12.clone();
        assert!(input_mapper.try_localize_to_input_with_bound_expr(&mut cloned, 2, &equivalences));
        assert_eq!(localized_key12, cloned);

        // basic tests that we can find a column's corresponding column in a
        // different input
        let mut cloned = key12.clone();
        assert!(input_mapper.try_localize_to_input_with_bound_expr(&mut cloned, 0, &equivalences));
        assert_eq!(key10, cloned);
        let mut cloned = key12.clone();
        assert!(
            !input_mapper.try_localize_to_input_with_bound_expr(&mut cloned, 1, &equivalences)
        );

        let key20 = MirScalarExpr::CallUnary {
            func: UnaryFunc::NegInt32(crate::func::NegInt32),
            expr: Box::new(MirScalarExpr::column(1)),
        };
        let key21 = MirScalarExpr::CallBinary {
            func: BinaryFunc::AddInt32(func::AddInt32),
            expr1: Box::new(MirScalarExpr::column(2)),
            expr2: Box::new(MirScalarExpr::literal(
                Ok(Datum::Int32(4)),
                ReprScalarType::Int32,
            )),
        };
        let key22 = MirScalarExpr::column(5);
        let localized_key22 = MirScalarExpr::column(0);
        equivalences.push(vec![key22.clone(), key20.clone(), key21.clone()]);

        // basic tests that we can find an expression's corresponding expression in a
        // different input
        let mut cloned = key21.clone();
        assert!(input_mapper.try_localize_to_input_with_bound_expr(&mut cloned, 0, &equivalences));
        assert_eq!(key20, cloned);
        let mut cloned = key21.clone();
        assert!(input_mapper.try_localize_to_input_with_bound_expr(&mut cloned, 2, &equivalences));
        assert_eq!(localized_key22, cloned);

        // test that `try_localize_to_input_with_bound_expr` will map multiple
        // subexpressions to the corresponding expressions bound to a different input
        let key_comp = MirScalarExpr::CallBinary {
            func: func::MulInt32.into(),
            expr1: Box::new(key12.clone()),
            expr2: Box::new(key22),
        };
        let mut cloned = key_comp.clone();
        assert!(input_mapper.try_localize_to_input_with_bound_expr(&mut cloned, 0, &equivalences));
        assert_eq!(
            MirScalarExpr::CallBinary {
                func: func::MulInt32.into(),
                expr1: Box::new(key10.clone()),
                expr2: Box::new(key20.clone()),
            },
            cloned,
        );

        // test that the function returns false when part
        // of the expression can be mapped to an input but the rest can't
        let mut cloned = key_comp.clone();
        assert!(
            !input_mapper.try_localize_to_input_with_bound_expr(&mut cloned, 1, &equivalences)
        );

        let key_comp_plus_non_key = MirScalarExpr::CallBinary {
            func: func::Eq.into(),
            expr1: Box::new(key_comp),
            expr2: Box::new(MirScalarExpr::column(7)),
        };
        let mut mutab = key_comp_plus_non_key;
        assert!(!input_mapper.try_localize_to_input_with_bound_expr(&mut mutab, 0, &equivalences));

        let key_comp_multi_input = MirScalarExpr::CallBinary {
            func: func::Eq.into(),
            expr1: Box::new(key12),
            expr2: Box::new(key21),
        };
        // test that the function works when part of the expression is already
        // part of the target input
        let mut cloned = key_comp_multi_input.clone();
        assert!(input_mapper.try_localize_to_input_with_bound_expr(&mut cloned, 2, &equivalences));
        assert_eq!(
            MirScalarExpr::CallBinary {
                func: func::Eq.into(),
                expr1: Box::new(localized_key12),
                expr2: Box::new(localized_key22),
            },
            cloned,
        );
        // test that the function works when parts of the expression come from
        // multiple inputs
        let mut cloned = key_comp_multi_input.clone();
        assert!(input_mapper.try_localize_to_input_with_bound_expr(&mut cloned, 0, &equivalences));
        assert_eq!(
            MirScalarExpr::CallBinary {
                func: func::Eq.into(),
                expr1: Box::new(key10),
                expr2: Box::new(key20),
            },
            cloned,
        );
        let mut mutab = key_comp_multi_input;
        assert!(!input_mapper.try_localize_to_input_with_bound_expr(&mut mutab, 1, &equivalences));
    }
}
577}