mz_expr/relation/canonicalize.rs
1// Copyright Materialize, Inc. and contributors. All rights reserved.
2//
3// Use of this software is governed by the Business Source License
4// included in the LICENSE file.
5//
6// As of the Change Date specified in that file, in accordance with
7// the Business Source License, use of this software will be governed
8// by the Apache License, Version 2.0.
9
10//! Utility functions to transform parts of a single `MirRelationExpr`
11//! into canonical form.
12
13use std::cmp::Ordering;
14use std::collections::{BTreeMap, BTreeSet};
15use std::rc::Rc;
16
17use itertools::Itertools;
18use mz_ore::soft_assert_or_log;
19use mz_repr::{ReprColumnType, ReprScalarType};
20
21use crate::visit::Visit;
22use crate::{MirScalarExpr, UnaryFunc, VariadicFunc, func};
23
24/// Canonicalize equivalence classes of a join and expressions contained in them.
25///
26/// `input_types` can be the [ReprColumnType]s of the join or the [ReprColumnType]s of
27/// the individual inputs of the join in order.
28///
29/// This function:
30/// * simplifies expressions to involve the least number of non-literal nodes.
31/// This ensures that we only replace expressions by "even simpler"
32/// expressions and that repeated substitutions reduce the complexity of
33/// expressions and a fixed point is certain to be reached. Without this
34/// rule, we might repeatedly replace a simple expression with an equivalent
35/// complex expression containing that (or another replaceable) simple
36/// expression, and repeat indefinitely.
37/// * reduces all expressions contained in `equivalences`.
38/// * Does everything that [canonicalize_equivalence_classes] does.
39pub fn canonicalize_equivalences<'a, I>(
40 equivalences: &mut Vec<Vec<MirScalarExpr>>,
41 input_column_types: I,
42) where
43 I: Iterator<Item = &'a Vec<ReprColumnType>>,
44{
45 let repr_column_types = input_column_types
46 .flat_map(|f| f.clone())
47 .collect::<Vec<_>>();
48 // Calculate the number of non-leaves for each expression.
49 let mut to_reduce = equivalences
50 .drain(..)
51 .filter_map(|mut cls| {
52 let mut result = cls
53 .drain(..)
54 .map(|expr| (rank_complexity(&expr), expr))
55 .collect::<Vec<_>>();
56 result.sort();
57 result.dedup();
58 if result.len() > 1 { Some(result) } else { None }
59 })
60 .collect::<Vec<_>>();
61
62 let mut expressions_rewritten = true;
63 while expressions_rewritten {
64 expressions_rewritten = false;
65 for i in 0..to_reduce.len() {
66 // `to_reduce` will be borrowed as immutable, so in order to modify
67 // elements of `to_reduce[i]`, we are going to pop them out of
68 // `to_reduce[i]` and put the modified version in `new_equivalence`,
69 // which will then replace `to_reduce[i]`.
70 let mut new_equivalence = Vec::with_capacity(to_reduce[i].len());
71 while let Some((_, mut popped_expr)) = to_reduce[i].pop() {
72 #[allow(deprecated)]
73 popped_expr.visit_mut_post_nolimit(&mut |e: &mut MirScalarExpr| {
74 // If a simpler expression can be found that is equivalent
75 // to e,
76 if let Some(simpler_e) = to_reduce.iter().find_map(|cls| {
77 if cls.iter().skip(1).position(|(_, expr)| e == expr).is_some() {
78 Some(cls[0].1.clone())
79 } else {
80 None
81 }
82 }) {
83 // Replace e with the simpler expression.
84 *e = simpler_e;
85 expressions_rewritten = true;
86 }
87 });
88 popped_expr.reduce(&repr_column_types);
89 new_equivalence.push((rank_complexity(&popped_expr), popped_expr));
90 }
91 new_equivalence.sort();
92 new_equivalence.dedup();
93 to_reduce[i] = new_equivalence;
94 }
95 }
96
97 // Map away the complexity rating.
98 *equivalences = to_reduce
99 .drain(..)
100 .map(|mut cls| cls.drain(..).map(|(_, expr)| expr).collect::<Vec<_>>())
101 .collect::<Vec<_>>();
102
103 canonicalize_equivalence_classes(equivalences);
104}
105
106/// Canonicalize only the equivalence classes of a join.
107///
108/// This function:
109/// * ensures the same expression appears in only one equivalence class.
110/// * ensures the equivalence classes are sorted and dedupped.
111/// ```rust
112/// use mz_expr::MirScalarExpr;
113/// use mz_expr::canonicalize::canonicalize_equivalence_classes;
114///
115/// let mut equivalences = vec![
116/// vec![MirScalarExpr::column(1), MirScalarExpr::column(4)],
117/// vec![MirScalarExpr::column(3), MirScalarExpr::column(5)],
118/// vec![MirScalarExpr::column(0), MirScalarExpr::column(3)],
119/// vec![MirScalarExpr::column(2), MirScalarExpr::column(2)],
120/// ];
121/// let expected = vec![
122/// vec![MirScalarExpr::column(0),
123/// MirScalarExpr::column(3),
124/// MirScalarExpr::column(5)],
125/// vec![MirScalarExpr::column(1), MirScalarExpr::column(4)],
126/// ];
127/// canonicalize_equivalence_classes(&mut equivalences);
128/// assert_eq!(expected, equivalences)
129/// ````
130pub fn canonicalize_equivalence_classes(equivalences: &mut Vec<Vec<MirScalarExpr>>) {
131 let mut uf = BTreeMap::new();
132 for class in equivalences.iter_mut() {
133 let mut iter = class.drain(..);
134 if let Some(first) = iter.next() {
135 let head = Rc::new(first);
136 for rest in iter {
137 uf.union(&head, &Rc::new(rest));
138 }
139 }
140 }
141
142 let mut eqs: BTreeMap<Rc<MirScalarExpr>, BTreeSet<Rc<MirScalarExpr>>> = BTreeMap::new();
143 for (k, v) in uf {
144 eqs.entry(v).or_default().insert(Rc::clone(&k));
145 }
146
147 let classes = eqs.into_values().collect::<Vec<_>>();
148 equivalences.resize(classes.len(), Vec::new());
149 equivalences
150 .iter_mut()
151 .zip_eq(classes)
152 .for_each(|(equivalence, class)| {
153 equivalence.extend(
154 class
155 .into_iter()
156 .map(|e| Rc::try_unwrap(e).expect("there to be only one strong ref")),
157 );
158 });
159
160 equivalences.retain(|es| es.len() > 1);
161 equivalences.sort();
162}
163
164/// Gives a relative complexity ranking for an expression. Higher numbers mean
165/// greater complexity.
166///
167/// Currently, this method weighs literals as the least complex and weighs all
168/// other expressions by the number of non-literals. In the future, we can
169/// change how complexity is ranked so that repeated substitutions would result
170/// in arriving at "better" fixed points. For example, we could try to improve
171/// performance by ranking expressions by their estimated computation time.
172///
173/// To ensure we arrive at a fixed point after repeated substitutions, valid
174/// complexity rankings must fulfill the following property:
175/// For any expression `e`, there does not exist a SQL function `f` such
176/// that `complexity(e) >= complexity(f(e))`.
177///
178/// For ease of intuiting the fixed point that we will arrive at after
179/// repeated substitutions, it is nice but not required that complexity
180/// rankings additionally fulfill the following property:
181/// If expressions `e1` and `e2` are such that
182/// `complexity(e1) < complexity(e2)` then for all SQL functions `f`,
183/// `complexity(f(e1)) < complexity(f(e2))`.
184fn rank_complexity(expr: &MirScalarExpr) -> usize {
185 if expr.is_literal() {
186 // literals are the least complex
187 return 0;
188 }
189 let mut non_literal_count = 1;
190 expr.visit_pre(|e| {
191 if !e.is_literal() {
192 non_literal_count += 1
193 }
194 });
195 non_literal_count
196}
197
/// Replaces the contents of `v` with the concatenation of `f(item)` over all
/// of its items, in order (an in-place `flat_map`).
fn flat_map_modify<T, I, F>(v: &mut Vec<T>, mut f: F)
where
    F: FnMut(T) -> I,
    I: IntoIterator<Item = T>,
{
    let mut expanded = Vec::with_capacity(v.len());
    for item in v.drain(..) {
        expanded.extend(f(item));
    }
    *v = expanded;
}
207
208/// Canonicalize predicates of a filter.
209///
210/// This function reduces and canonicalizes the structure of each individual
211/// predicate. Then, it transforms predicates of the form "A and B" into two: "A"
212/// and "B". Afterwards, it reduces predicates based on information from other
213/// predicates in the set. Finally, it sorts and deduplicates the predicates.
214///
215/// Additionally, it also removes IS NOT NULL predicates if there is another
216/// null rejecting predicate for the same sub-expression.
217pub fn canonicalize_predicates(
218 predicates: &mut Vec<MirScalarExpr>,
219 repr_column_types: &[ReprColumnType],
220) {
221 soft_assert_or_log!(
222 predicates
223 .iter()
224 .all(|p| p.typ(repr_column_types).scalar_type == ReprScalarType::Bool),
225 "cannot canonicalize predicates that are not of type bool"
226 );
227
228 // 1) Reduce each individual predicate.
229 predicates
230 .iter_mut()
231 .for_each(|p| p.reduce(repr_column_types));
232
233 // 2) Split "A and B" into two predicates: "A" and "B"
234 // Relies on the `reduce` above having flattened nested ANDs.
235 flat_map_modify(predicates, |p| {
236 if let MirScalarExpr::CallVariadic {
237 func: VariadicFunc::And(_),
238 exprs,
239 } = p
240 {
241 exprs
242 } else {
243 vec![p]
244 }
245 });
246
247 // 3) Make non-null requirements explicit as predicates in order for
248 // step 4) to be able to simplify AND/OR expressions with IS NULL
249 // sub-predicates. This redundancy is removed later by step 5).
250 let mut non_null_columns = BTreeSet::new();
251 for p in predicates.iter() {
252 p.non_null_requirements(&mut non_null_columns);
253 }
254 predicates.extend(non_null_columns.iter().map(|c| {
255 MirScalarExpr::column(*c)
256 .call_unary(UnaryFunc::IsNull(func::IsNull))
257 .call_unary(UnaryFunc::Not(func::Not))
258 }));
259
260 // 4) Reduce across `predicates`.
261 // If a predicate `p` cannot be null, and `f(p)` is a nullable bool
262 // then the predicate `p & f(p)` is equal to `p & f(true)`, and
263 // `!p & f(p)` is equal to `!p & f(false)`. For any index i, the `Vec` of
264 // predicates `[p1, ... pi, ... pn]` is equivalent to the single predicate
265 // `pi & (p1 & ... & p(i-1) & p(i+1) ... & pn)`. Thus, if `pi`
266 // (resp. `!pi`) cannot be null, it is valid to replace with `true` (resp.
267 // `false`) every subexpression in `(p1 & ... & p(i-1) & p(i+1) ... & pn)`
268 // that is equal to `pi`.
269
270 // If `p` is null and `q` is a nullable bool, then `p & q` can be either
271 // `null` or `false` depending on what `q`. Our rendering pipeline treats
272 // both as "remove this row." Thus, in the specific context of filter
273 // predicates, it is acceptable to make the aforementioned substitution
274 // even if `pi` can be null.
275
276 // Note that this does some dedupping of predicates since if `p1 = p2`
277 // then this reduction process will replace `p1` with true.
278
279 // Maintain respectively:
280 // 1) A list of predicates for which we have checked for matching
281 // subexpressions
282 // 2) A list of predicates for which we have yet to do so.
283 let mut completed = Vec::new();
284 let mut todo = Vec::new();
285 // Seed `todo` with all predicates.
286 std::mem::swap(&mut todo, predicates);
287
288 while let Some(predicate_to_apply) = todo.pop() {
289 // Helper method: for each predicate `p`, see if all other predicates
290 // (a.k.a. the union of todo & completed) contains `p` as a
291 // subexpression, and replace the subexpression accordingly.
292 // This method lives inside the loop because in order to comply with
293 // Rust rules that only one mutable reference to `todo` can be held at a
294 // time.
295 let mut replace_subexpr_other_predicates =
296 |expr: &MirScalarExpr, constant_bool: &MirScalarExpr| {
297 // Do not replace subexpressions equal to `expr` if `expr` is a
298 // literal to avoid infinite looping.
299 if !expr.is_literal() {
300 for other_predicate in todo.iter_mut() {
301 replace_subexpr_and_reduce(
302 other_predicate,
303 expr,
304 constant_bool,
305 repr_column_types,
306 );
307 }
308 for other_idx in (0..completed.len()).rev() {
309 if replace_subexpr_and_reduce(
310 &mut completed[other_idx],
311 expr,
312 constant_bool,
313 repr_column_types,
314 ) {
315 // If a predicate in the `completed` list has
316 // been simplified, stick it back into the `todo` list.
317 todo.push(completed.remove(other_idx));
318 }
319 }
320 }
321 };
322 // Meat of loop starts here. If a predicate p is of the form `!q`, replace
323 // every instance of `q` in every other predicate with `false.`
324 // Otherwise, replace every instance of `p` in every other predicate
325 // with `true`.
326 if let MirScalarExpr::CallUnary {
327 func: UnaryFunc::Not(func::Not),
328 expr,
329 } = &predicate_to_apply
330 {
331 replace_subexpr_other_predicates(expr, &MirScalarExpr::literal_false())
332 } else {
333 replace_subexpr_other_predicates(&predicate_to_apply, &MirScalarExpr::literal_true());
334 }
335 completed.push(predicate_to_apply);
336 }
337
338 // 5) Remove redundant !isnull/isnull predicates after performing the replacements
339 // in the loop above.
340 std::mem::swap(&mut todo, &mut completed);
341 while let Some(predicate_to_apply) = todo.pop() {
342 // Remove redundant !isnull(x) predicates if there is another predicate
343 // that evaluates to NULL when `x` is NULL.
344 if let Some(operand) = is_not_null(&predicate_to_apply) {
345 if todo
346 .iter_mut()
347 .chain(completed.iter_mut())
348 .any(|p| is_null_rejecting_predicate(p, &operand))
349 {
350 // skip this predicate
351 continue;
352 }
353 } else if let MirScalarExpr::CallUnary {
354 func: UnaryFunc::IsNull(func::IsNull),
355 expr,
356 } = &predicate_to_apply
357 {
358 if todo
359 .iter_mut()
360 .chain(completed.iter_mut())
361 .any(|p| is_null_rejecting_predicate(p, expr))
362 {
363 completed.push(MirScalarExpr::literal_false());
364 break;
365 }
366 }
367 completed.push(predicate_to_apply);
368 }
369
370 if completed.iter().any(|p| {
371 (p.is_literal_false() || p.is_literal_null()) &&
372 // This extra check is only needed if we determine that the soft-assert
373 // at the top of this function would ever fail for a good reason.
374 p.typ(repr_column_types).scalar_type == ReprScalarType::Bool
375 }) {
376 // all rows get filtered away if any predicate is null or false.
377 *predicates = vec![MirScalarExpr::literal_false()]
378 } else {
379 // Remove any predicates that have been reduced to "true"
380 completed.retain(|p| !p.is_literal_true());
381 *predicates = completed;
382 }
383
384 // 6) Sort and dedup predicates.
385 predicates.sort_by(compare_predicates);
386 predicates.dedup();
387}
388
389/// Replace any matching subexpressions in `predicate`, and if `predicate` has
390/// changed, reduce it. Return whether `predicate` has changed.
391fn replace_subexpr_and_reduce(
392 predicate: &mut MirScalarExpr,
393 replace_if_equal_to: &MirScalarExpr,
394 replace_with: &MirScalarExpr,
395 repr_column_types: &[ReprColumnType],
396) -> bool {
397 let mut changed = false;
398 #[allow(deprecated)]
399 predicate.visit_mut_pre_post_nolimit(
400 &mut |e| {
401 // The `cond` of an if statement is not visited to prevent `then`
402 // or `els` from being evaluated before `cond`, resulting in a
403 // correctness error.
404 if let MirScalarExpr::If { then, els, .. } = e {
405 return Some(vec![then, els]);
406 }
407 None
408 },
409 &mut |e| {
410 if e == replace_if_equal_to {
411 *e = replace_with.clone();
412 changed = true;
413 } else if let MirScalarExpr::CallBinary {
414 func: r_func,
415 expr1: r_expr1,
416 expr2: r_expr2,
417 } = replace_if_equal_to
418 {
419 if let Some(negation) = r_func.negate() {
420 if let MirScalarExpr::CallBinary {
421 func: l_func,
422 expr1: l_expr1,
423 expr2: l_expr2,
424 } = e
425 {
426 if negation == *l_func && l_expr1 == r_expr1 && l_expr2 == r_expr2 {
427 *e = MirScalarExpr::CallUnary {
428 func: UnaryFunc::Not(func::Not),
429 expr: Box::new(replace_with.clone()),
430 };
431 changed = true;
432 }
433 }
434 }
435 }
436 },
437 );
438 if changed {
439 predicate.reduce(repr_column_types);
440 }
441 changed
442}
443
444/// Returns the inner operand if the given predicate is an IS NOT NULL expression.
445fn is_not_null(predicate: &MirScalarExpr) -> Option<MirScalarExpr> {
446 if let MirScalarExpr::CallUnary {
447 func: UnaryFunc::Not(func::Not),
448 expr,
449 } = &predicate
450 {
451 if let MirScalarExpr::CallUnary {
452 func: UnaryFunc::IsNull(func::IsNull),
453 expr,
454 } = &**expr
455 {
456 return Some((**expr).clone());
457 }
458 }
459 None
460}
461
462/// Whether the given predicate evaluates to NULL when the given operand expression is NULL.
463#[inline(always)]
464fn is_null_rejecting_predicate(predicate: &MirScalarExpr, operand: &MirScalarExpr) -> bool {
465 propagates_null_from_subexpression(predicate, operand)
466}
467
468fn propagates_null_from_subexpression(expr: &MirScalarExpr, operand: &MirScalarExpr) -> bool {
469 if operand == expr {
470 true
471 } else if let MirScalarExpr::CallVariadic { func, exprs } = &expr {
472 func.propagates_nulls()
473 && (exprs
474 .iter()
475 .any(|e| propagates_null_from_subexpression(e, operand)))
476 } else if let MirScalarExpr::CallBinary { func, expr1, expr2 } = &expr {
477 func.propagates_nulls()
478 && (propagates_null_from_subexpression(expr1, operand)
479 || propagates_null_from_subexpression(expr2, operand))
480 } else if let MirScalarExpr::CallUnary { func, expr } = &expr {
481 func.propagates_nulls() && propagates_null_from_subexpression(expr, operand)
482 } else {
483 false
484 }
485}
486
487/// Comparison method for sorting predicates by their complexity, measured by the total
488/// number of non-literal expression nodes within the expression.
489fn compare_predicates(x: &MirScalarExpr, y: &MirScalarExpr) -> Ordering {
490 (rank_complexity(x), x).cmp(&(rank_complexity(y), y))
491}
492
493/// For each equivalence class, it finds the simplest expression, which will be the canonical one.
494/// Returns a Map that maps from each expression in each equivalence class to the canonical
495/// expression in the same equivalence class.
496pub fn get_canonicalizer_map(
497 equivalences: &Vec<Vec<MirScalarExpr>>,
498) -> BTreeMap<MirScalarExpr, MirScalarExpr> {
499 let mut canonicalizer_map = BTreeMap::new();
500 for equivalence in equivalences {
501 // The unwrap is ok, because a join equivalence class can't be empty.
502 let canonical_expr = equivalence
503 .iter()
504 .min_by(|a, b| compare_predicates(*a, *b))
505 .unwrap();
506 for e in equivalence {
507 if e != canonical_expr {
508 canonicalizer_map.insert(e.clone(), canonical_expr.clone());
509 }
510 }
511 }
512 canonicalizer_map
513}
514
/// Operations of a union-find (disjoint-set) structure over elements of type
/// `T`.
pub trait UnionFind<T> {
    /// Looks up the root of the set containing `x`, updating `self[x]` to point
    /// at that root, and returns a reference to the root.
    fn find<'a>(&'a mut self, x: &T) -> Option<&'a T>;
    /// Merges the sets of `x` and `y` so that both end up with the same root.
    fn union(&mut self, x: &T, y: &T);
}
522
523impl<T: Clone + Ord> UnionFind<T> for BTreeMap<T, T> {
524 fn find<'a>(&'a mut self, x: &T) -> Option<&'a T> {
525 if !self.contains_key(x) {
526 None
527 } else {
528 if self[x] != self[&self[x]] {
529 // Path halving
530 let mut y = self[x].clone();
531 while y != self[&y] {
532 let grandparent = self[&self[&y]].clone();
533 *self.get_mut(&y).unwrap() = grandparent;
534 y.clone_from(&self[&y]);
535 }
536 *self.get_mut(x).unwrap() = y;
537 }
538 Some(&self[x])
539 }
540 }
541
542 fn union(&mut self, x: &T, y: &T) {
543 match (self.find(x).is_some(), self.find(y).is_some()) {
544 (true, true) => {
545 if self[x] != self[y] {
546 let root_x = self[x].clone();
547 let root_y = self[y].clone();
548 self.insert(root_x, root_y);
549 }
550 }
551 (false, true) => {
552 self.insert(x.clone(), self[y].clone());
553 }
554 (true, false) => {
555 self.insert(y.clone(), self[x].clone());
556 }
557 (false, false) => {
558 self.insert(x.clone(), x.clone());
559 self.insert(y.clone(), x.clone());
560 }
561 }
562 }
563}