mz_expr/relation/join_input_mapper.rs
1// Copyright Materialize, Inc. and contributors. All rights reserved.
2//
3// Use of this software is governed by the Business Source License
4// included in the LICENSE file.
5//
6// As of the Change Date specified in that file, in accordance with
7// the Business Source License, use of this software will be governed
8// by the Apache License, Version 2.0.
9
10use std::collections::BTreeSet;
11use std::ops::Range;
12
13use itertools::Itertools;
14use mz_repr::SqlRelationType;
15
16use crate::visit::Visit;
17use crate::{MirRelationExpr, MirScalarExpr, VariadicFunc};
18
/// Translates column references between the two coordinate systems of a join.
///
/// Every column that takes part in a join can be addressed in two ways:
/// 1. *globally*, by its position in the result of the join, and
/// 2. *locally*, by its position within the specific input it came from.
///
/// This utility takes columns and expressions stated in one of these contexts
/// and re-expresses them in the other.
///
/// Methods on this type that take an `equivalences` argument are only
/// guaranteed to return a correct answer if the equivalence classes are in
/// canonical form.
/// (See [`crate::relation::canonicalize::canonicalize_equivalences`].)
#[derive(Debug)]
pub struct JoinInputMapper {
    /// Column count of every input, in input order. The remaining fields are
    /// all derived from this one.
    arities: Vec<usize>,
    /// For every global column, the index of the input that column belongs to.
    /// Derived from `arities` and cached so it is not recomputed on each lookup.
    input_relation: Vec<usize>,
    /// For every input, the summed arity of all inputs that precede it in the
    /// join. Derived from `arities` and cached so it is not recomputed.
    prior_arities: Vec<usize>,
}
41
42impl JoinInputMapper {
43 /// Creates a new `JoinInputMapper` and calculates the mapping of global context
44 /// columns to local context columns.
45 pub fn new(inputs: &[MirRelationExpr]) -> Self {
46 Self::new_from_input_arities(inputs.iter().map(|i| i.arity()))
47 }
48
49 /// Creates a new `JoinInputMapper` and calculates the mapping of global context
50 /// columns to local context columns. Using this method saves is more
51 /// efficient if input types have been pre-calculated
52 pub fn new_from_input_types(types: &[SqlRelationType]) -> Self {
53 Self::new_from_input_arities(types.iter().map(|t| t.column_types.len()))
54 }
55
56 /// Creates a new `JoinInputMapper` and calculates the mapping of global context
57 /// columns to local context columns. Using this method saves is more
58 /// efficient if input arities have been pre-calculated
59 pub fn new_from_input_arities<I>(arities: I) -> Self
60 where
61 I: IntoIterator<Item = usize>,
62 {
63 let arities = arities.into_iter().collect::<Vec<usize>>();
64 let mut offset = 0;
65 let mut prior_arities = Vec::new();
66 for input in 0..arities.len() {
67 prior_arities.push(offset);
68 offset += arities[input];
69 }
70
71 let input_relation = arities
72 .iter()
73 .enumerate()
74 .flat_map(|(r, a)| std::iter::repeat(r).take(*a))
75 .collect::<Vec<_>>();
76
77 JoinInputMapper {
78 arities,
79 input_relation,
80 prior_arities,
81 }
82 }
83
84 /// reports sum of the number of columns of each input
85 pub fn total_columns(&self) -> usize {
86 self.arities.iter().sum()
87 }
88
89 /// reports total numbers of inputs in the join
90 pub fn total_inputs(&self) -> usize {
91 self.arities.len()
92 }
93
94 /// Using the keys that came from each local input,
95 /// figures out which keys remain unique in the larger join
96 /// Currently, we only figure out a small subset of the keys that
97 /// can remain unique.
98 pub fn global_keys<'a, I>(
99 &self,
100 mut local_keys: I,
101 equivalences: &[Vec<MirScalarExpr>],
102 ) -> Vec<Vec<usize>>
103 where
104 I: Iterator<Item = &'a Vec<Vec<usize>>>,
105 {
106 // A relation's uniqueness constraint holds if there is a
107 // sequence of the other relations such that each one has
108 // a uniqueness constraint whose columns are used in join
109 // constraints with relations prior in the sequence.
110 //
111 // Currently, we only:
112 // 1. test for whether the uniqueness constraints for the first input will hold
113 // 2. try one sequence, namely the inputs in order
114 // 3. check that the column themselves are used in the join constraints
115 // Technically uniqueness constraint would still hold if a 1-to-1
116 // expression on a unique key is used in the join constraint.
117
118 // for inputs `1..self.total_inputs()`, store a set of columns from that
119 // input that exist in join constraints that have expressions belonging to
120 // earlier inputs.
121 let mut column_with_prior_bound_by_input = vec![BTreeSet::new(); self.total_inputs() - 1];
122 for equivalence in equivalences {
123 // do a scan to find the first input represented in the constraint
124 let min_bound_input = equivalence
125 .iter()
126 .flat_map(|expr| self.lookup_inputs(expr).max())
127 .min();
128 if let Some(min_bound_input) = min_bound_input {
129 for expr in equivalence {
130 // then store all columns in the constraint that don't come
131 // from the first input
132 if let MirScalarExpr::Column(c, _name) = expr {
133 let (col, input) = self.map_column_to_local(*c);
134 if input > min_bound_input {
135 column_with_prior_bound_by_input[input - 1].insert(col);
136 }
137 }
138 }
139 }
140 }
141
142 if self.total_inputs() > 0 {
143 let first_input_keys = local_keys.next().unwrap().clone();
144 // for inputs `1..self.total_inputs()`, checks the keys belong to each
145 // input against the storage of columns that exist in join constraints
146 // that have expressions belonging to earlier inputs.
147 let remains_unique = local_keys.enumerate().all(|(index, keys)| {
148 keys.iter().any(|ks| {
149 ks.iter()
150 .all(|k| column_with_prior_bound_by_input[index].contains(k))
151 })
152 });
153
154 if remains_unique {
155 return first_input_keys;
156 }
157 }
158 vec![]
159 }
160
161 /// returns the arity for a particular input
162 #[inline]
163 pub fn input_arity(&self, index: usize) -> usize {
164 self.arities[index]
165 }
166
167 /// All column numbers in order for a particular input in the local context
168 #[inline]
169 pub fn local_columns(&self, index: usize) -> Range<usize> {
170 0..self.arities[index]
171 }
172
173 /// All column numbers in order for a particular input in the global context
174 #[inline]
175 pub fn global_columns(&self, index: usize) -> Range<usize> {
176 self.prior_arities[index]..(self.prior_arities[index] + self.arities[index])
177 }
178
179 /// Takes an expression from the global context and creates a new version
180 /// where column references have been remapped to the local context.
181 /// Assumes that all columns in `expr` are from the same input.
182 pub fn map_expr_to_local(&self, mut expr: MirScalarExpr) -> MirScalarExpr {
183 expr.visit_pre_mut(|e| {
184 if let MirScalarExpr::Column(c, _name) = e {
185 *c -= self.prior_arities[self.input_relation[*c]];
186 }
187 });
188 expr
189 }
190
191 /// Takes an expression from the local context of the `index`th input and
192 /// creates a new version where column references have been remapped to the
193 /// global context.
194 pub fn map_expr_to_global(&self, mut expr: MirScalarExpr, index: usize) -> MirScalarExpr {
195 expr.visit_pre_mut(|e| {
196 if let MirScalarExpr::Column(c, _name) = e {
197 *c += self.prior_arities[index];
198 }
199 });
200 expr
201 }
202
203 /// Remap column numbers from the global to the local context.
204 /// Returns a 2-tuple `(<new column number>, <index of input>)`
205 pub fn map_column_to_local(&self, column: usize) -> (usize, usize) {
206 let index = self.input_relation[column];
207 (column - self.prior_arities[index], index)
208 }
209
210 /// Remap a column number from a local context to the global context.
211 pub fn map_column_to_global(&self, column: usize, index: usize) -> usize {
212 column + self.prior_arities[index]
213 }
214
215 /// Takes a sequence of columns in the global context and splits it into
216 /// a `Vec` containing `self.total_inputs()` `BTreeSet`s, each containing
217 /// the localized columns that belong to the particular input.
218 pub fn split_column_set_by_input<'a, I>(&self, columns: I) -> Vec<BTreeSet<usize>>
219 where
220 I: Iterator<Item = &'a usize>,
221 {
222 let mut new_columns = vec![BTreeSet::new(); self.total_inputs()];
223 for column in columns {
224 let (new_col, input) = self.map_column_to_local(*column);
225 new_columns[input].extend(std::iter::once(new_col));
226 }
227 new_columns
228 }
229
230 /// Find the sorted, dedupped set of inputs an expression references
231 pub fn lookup_inputs(&self, expr: &MirScalarExpr) -> impl Iterator<Item = usize> + use<> {
232 expr.support()
233 .iter()
234 .map(|c| self.input_relation[*c])
235 .sorted()
236 .dedup()
237 }
238
239 /// Returns the index of the only input referenced in the given expression.
240 pub fn single_input(&self, expr: &MirScalarExpr) -> Option<usize> {
241 let mut inputs = self.lookup_inputs(expr);
242 if let Some(first_input) = inputs.next() {
243 if inputs.next().is_none() {
244 return Some(first_input);
245 }
246 }
247 None
248 }
249
250 /// Returns whether the given expr refers to columns of only the `index`th input.
251 pub fn is_localized(&self, expr: &MirScalarExpr, index: usize) -> bool {
252 if let Some(single_input) = self.single_input(expr) {
253 if single_input == index {
254 return true;
255 }
256 }
257 false
258 }
259
260 /// Takes an expression in the global context and looks in `equivalences`
261 /// for an equivalent expression (also expressed in the global context) that
262 /// belongs to one or more of the inputs in `bound_inputs`
263 ///
264 /// # Examples
265 ///
266 /// ```
267 /// use mz_repr::{Datum, SqlColumnType, SqlRelationType, SqlScalarType};
268 /// use mz_expr::{JoinInputMapper, MirRelationExpr, MirScalarExpr};
269 ///
270 /// // A two-column schema common to each of the three inputs
271 /// let schema = SqlRelationType::new(vec![
272 /// SqlScalarType::Int32.nullable(false),
273 /// SqlScalarType::Int32.nullable(false),
274 /// ]);
275 ///
276 /// // the specific data are not important here.
277 /// let data = vec![Datum::Int32(0), Datum::Int32(1)];
278 /// let input0 = MirRelationExpr::constant(vec![data.clone()], schema.clone());
279 /// let input1 = MirRelationExpr::constant(vec![data.clone()], schema.clone());
280 /// let input2 = MirRelationExpr::constant(vec![data.clone()], schema.clone());
281 ///
282 /// // [input0(#0) = input2(#1)], [input0(#1) = input1(#0) = input2(#0)]
283 /// let equivalences = vec![
284 /// vec![MirScalarExpr::column(0), MirScalarExpr::column(5)],
285 /// vec![MirScalarExpr::column(1), MirScalarExpr::column(2), MirScalarExpr::column(4)],
286 /// ];
287 ///
288 /// let input_mapper = JoinInputMapper::new(&[input0, input1, input2]);
289 /// assert_eq!(
290 /// Some(MirScalarExpr::column(4)),
291 /// input_mapper.find_bound_expr(&MirScalarExpr::column(2), &[2], &equivalences)
292 /// );
293 /// assert_eq!(
294 /// None,
295 /// input_mapper.find_bound_expr(&MirScalarExpr::column(0), &[1], &equivalences)
296 /// );
297 /// ```
298 pub fn find_bound_expr(
299 &self,
300 expr: &MirScalarExpr,
301 bound_inputs: &[usize],
302 equivalences: &[Vec<MirScalarExpr>],
303 ) -> Option<MirScalarExpr> {
304 if let Some(equivalence) = equivalences.iter().find(|equivs| equivs.contains(expr)) {
305 if let Some(bound_expr) = equivalence
306 .iter()
307 .find(|expr| self.lookup_inputs(expr).all(|i| bound_inputs.contains(&i)))
308 {
309 return Some(bound_expr.clone());
310 }
311 }
312 None
313 }
314
315 /// Try to rewrite `expr` from the global context so that all the
316 /// columns point to the `index`th input by replacing subexpressions with their
317 /// bound equivalents in the `index`th input if necessary.
318 /// Returns whether the rewriting was successful.
319 /// If it returns true, then `expr` is in the context of the `index`th input.
320 /// If it returns false, then still some subexpressions might have been rewritten. However,
321 /// `expr` is still in the global context.
322 pub fn try_localize_to_input_with_bound_expr(
323 &self,
324 expr: &mut MirScalarExpr,
325 index: usize,
326 equivalences: &[Vec<MirScalarExpr>],
327 ) -> bool {
328 // TODO (wangandi): Consider changing this code to be post-order
329 // instead of pre-order? `lookup_inputs` traverses all the nodes in
330 // `e` anyway, so we end up visiting nodes in `e` multiple times
331 // here. Alternatively, consider having the future `PredicateKnowledge`
332 // take over the responsibilities of this code?
333 #[allow(deprecated)]
334 expr.visit_mut_pre_post_nolimit(
335 &mut |e| {
336 let mut inputs = self.lookup_inputs(e);
337 if let Some(first_input) = inputs.next() {
338 if inputs.next().is_none() && first_input == index {
339 // there is only one input, and it is equal to index, so we're
340 // good. do not continue the recursion
341 return Some(vec![]);
342 }
343 }
344
345 if let Some(bound_expr) = self.find_bound_expr(e, &[index], equivalences) {
346 // Replace the subexpression with the equivalent one from input `index`
347 *e = bound_expr;
348 // The entire subexpression has been rewritten, so there is
349 // no need to visit any child expressions.
350 Some(vec![])
351 } else {
352 None
353 }
354 },
355 &mut |_| {},
356 );
357 if self.is_localized(expr, index) {
358 // If the localization attempt is successful, all columns in `expr`
359 // should only come from input `index`. Switch to the local context.
360 *expr = self.map_expr_to_local(expr.clone());
361 return true;
362 }
363 false
364 }
365
366 /// Try to find a consequence `c` of the given expression `e` for the given input.
367 ///
368 /// If we return `Some(c)`, that means
369 /// 1. `c` uses only columns from the given input;
370 /// 2. if `c` doesn't hold on a row of the input, then `e` also wouldn't hold;
371 /// 3. if `c` holds on a row of the input, then `e` might or might not hold.
372 /// 1. and 2. means that if we have a join with predicate `e` then we can use `c` for
373 /// pre-filtering a join input before the join. However, 3. means that `e` shouldn't be deleted
374 /// from the join predicates, i.e., we can't do a "traditional" predicate pushdown.
375 ///
376 /// Note that "`c` is a consequence of `e`" is the same thing as 2., see
377 /// <https://en.wikipedia.org/wiki/Contraposition>
378 ///
379 /// Example: For
380 /// `(t1.f2 = 3 AND t2.f2 = 4) OR (t1.f2 = 5 AND t2.f2 = 6)`
381 /// we find
382 /// `t1.f2 = 3 OR t1.f2 = 5` for t1, and
383 /// `t2.f2 = 4 OR t2.f2 = 6` for t2.
384 ///
385 /// Further examples are in TPC-H Q07, Q19, and chbench Q07, Q19.
386 ///
387 /// Parameters:
388 /// - `expr`: The expression `e` from above. `try_localize_to_input_with_bound_expr` should
389 /// be called on `expr` before us!
390 /// - `index`: The index of the join input whose columns we will use.
391 /// - `equivalences`: Join equivalences that we can use for `try_map_to_input_with_bound_expr`.
392 /// If successful, the returned expression is in the local context of the specified input.
393 pub fn consequence_for_input(
394 &self,
395 expr: &MirScalarExpr,
396 index: usize,
397 ) -> Option<MirScalarExpr> {
398 if self.is_localized(expr, index) {
399 Some(self.map_expr_to_local(expr.clone()))
400 } else {
401 match expr {
402 MirScalarExpr::CallVariadic {
403 func: VariadicFunc::Or,
404 exprs: or_args,
405 } => {
406 // Each OR arg should provide a consequence. If they do, we OR them.
407 let consequences_per_arg = or_args
408 .into_iter()
409 .map(|or_arg| {
410 mz_ore::stack::maybe_grow(|| self.consequence_for_input(or_arg, index))
411 })
412 .collect::<Option<Vec<_>>>()?; // return None if any of them are None
413 Some(MirScalarExpr::CallVariadic {
414 func: VariadicFunc::Or,
415 exprs: consequences_per_arg,
416 })
417 }
418 MirScalarExpr::CallVariadic {
419 func: VariadicFunc::And,
420 exprs: and_args,
421 } => {
422 // If any of the AND args provide a consequence, then we take those that do,
423 // and AND them.
424 let consequences_per_arg = and_args
425 .into_iter()
426 .map(|and_arg| {
427 mz_ore::stack::maybe_grow(|| self.consequence_for_input(and_arg, index))
428 })
429 .flat_map(|c| c) // take only those that provide a consequence
430 .collect_vec();
431 if consequences_per_arg.is_empty() {
432 None
433 } else {
434 Some(MirScalarExpr::CallVariadic {
435 func: VariadicFunc::And,
436 exprs: consequences_per_arg,
437 })
438 }
439 }
440 _ => None,
441 }
442 }
443 }
444}
445
#[cfg(test)]
mod tests {
    use mz_repr::{Datum, SqlScalarType};

    use crate::{BinaryFunc, MirScalarExpr, UnaryFunc};

    use super::*;

    /// Runs `try_localize_to_input_with_bound_expr` on a copy of `expr` and
    /// returns the reported outcome together with the (possibly rewritten) copy.
    fn localize(
        mapper: &JoinInputMapper,
        expr: &MirScalarExpr,
        index: usize,
        equivalences: &[Vec<MirScalarExpr>],
    ) -> (bool, MirScalarExpr) {
        let mut rewritten = expr.clone();
        let localized =
            mapper.try_localize_to_input_with_bound_expr(&mut rewritten, index, equivalences);
        (localized, rewritten)
    }

    /// Builds a binary call over copies of the two operands.
    fn binary(func: BinaryFunc, lhs: &MirScalarExpr, rhs: &MirScalarExpr) -> MirScalarExpr {
        MirScalarExpr::CallBinary {
            func,
            expr1: Box::new(lhs.clone()),
            expr2: Box::new(rhs.clone()),
        }
    }

    #[mz_ore::test]
    #[cfg_attr(miri, ignore)] // unsupported operation: can't call foreign function `rust_psm_stack_pointer` on OS `linux`
    fn try_map_to_input_with_bound_expr_test() {
        // Three inputs with 2, 3 and 3 columns respectively.
        let input_mapper = JoinInputMapper {
            arities: vec![2, 3, 3],
            input_relation: vec![0, 0, 1, 1, 1, 2, 2, 2],
            prior_arities: vec![0, 2, 5],
        };

        // keys are numbered by (equivalence class #, input #)
        let key10 = MirScalarExpr::column(0);
        let key12 = MirScalarExpr::column(6);
        let localized_key12 = MirScalarExpr::column(1);

        let mut equivalences = vec![vec![key10.clone(), key12.clone()]];

        // A column that already belongs to the target input is merely
        // converted to the local context.
        let (_, rewritten) = localize(&input_mapper, &key12, 2, &equivalences);
        assert_eq!(MirScalarExpr::column(1), rewritten);

        // A column is replaced by its counterpart in a different input ...
        let (_, rewritten) = localize(&input_mapper, &key12, 0, &equivalences);
        assert_eq!(key10, rewritten);
        // ... unless that input has no counterpart.
        let (localized, _) = localize(&input_mapper, &key12, 1, &equivalences);
        assert!(!localized);

        let key20 = MirScalarExpr::CallUnary {
            func: UnaryFunc::NegInt32(crate::func::NegInt32),
            expr: Box::new(MirScalarExpr::column(1)),
        };
        let key21 = binary(
            BinaryFunc::AddInt32,
            &MirScalarExpr::column(2),
            &MirScalarExpr::literal(Ok(Datum::Int32(4)), SqlScalarType::Int32),
        );
        let key22 = MirScalarExpr::column(5);
        let localized_key22 = MirScalarExpr::column(0);
        equivalences.push(vec![key22.clone(), key20.clone(), key21.clone()]);

        // An expression is replaced by its counterpart in a different input.
        let (_, rewritten) = localize(&input_mapper, &key21, 0, &equivalences);
        assert_eq!(key20, rewritten);
        let (_, rewritten) = localize(&input_mapper, &key21, 2, &equivalences);
        assert_eq!(localized_key22, rewritten);

        // `try_localize_to_input_with_bound_expr` maps several subexpressions
        // to the corresponding expressions bound to a different input.
        let key_comp = binary(BinaryFunc::MulInt32, &key12, &key22);
        let (_, rewritten) = localize(&input_mapper, &key_comp, 0, &equivalences);
        assert_eq!(binary(BinaryFunc::MulInt32, &key10, &key20), rewritten);

        // The function reports `false` when part of the expression can be
        // mapped to an input but the rest can't.
        let (localized, _) = localize(&input_mapper, &key_comp, 1, &equivalences);
        assert!(!localized);

        let key_comp_plus_non_key = binary(BinaryFunc::Eq, &key_comp, &MirScalarExpr::column(7));
        let (localized, _) = localize(&input_mapper, &key_comp_plus_non_key, 0, &equivalences);
        assert!(!localized);

        let key_comp_multi_input = binary(BinaryFunc::Eq, &key12, &key21);
        // The function works when part of the expression already belongs to
        // the target input.
        let (_, rewritten) = localize(&input_mapper, &key_comp_multi_input, 2, &equivalences);
        assert_eq!(
            binary(BinaryFunc::Eq, &localized_key12, &localized_key22),
            rewritten,
        );
        // The function works when parts of the expression come from several
        // inputs.
        let (_, rewritten) = localize(&input_mapper, &key_comp_multi_input, 0, &equivalences);
        assert_eq!(binary(BinaryFunc::Eq, &key10, &key20), rewritten);
        let (localized, _) = localize(&input_mapper, &key_comp_multi_input, 1, &equivalences);
        assert!(!localized);
    }
}