Skip to main content

mz_expr/relation/
canonicalize.rs

1// Copyright Materialize, Inc. and contributors. All rights reserved.
2//
3// Use of this software is governed by the Business Source License
4// included in the LICENSE file.
5//
6// As of the Change Date specified in that file, in accordance with
7// the Business Source License, use of this software will be governed
8// by the Apache License, Version 2.0.
9
10//! Utility functions to transform parts of a single `MirRelationExpr`
11//! into canonical form.
12
13use std::cmp::Ordering;
14use std::collections::{BTreeMap, BTreeSet};
15use std::rc::Rc;
16
17use itertools::Itertools;
18use mz_ore::soft_assert_or_log;
19use mz_repr::{ReprColumnType, ReprScalarType};
20
21use crate::visit::Visit;
22use crate::{MirScalarExpr, UnaryFunc, VariadicFunc, func};
23
24/// Canonicalize equivalence classes of a join and expressions contained in them.
25///
26/// `input_types` can be the [ReprColumnType]s of the join or the [ReprColumnType]s of
27/// the individual inputs of the join in order.
28///
29/// This function:
30/// * simplifies expressions to involve the least number of non-literal nodes.
31///   This ensures that we only replace expressions by "even simpler"
32///   expressions and that repeated substitutions reduce the complexity of
33///   expressions and a fixed point is certain to be reached. Without this
34///   rule, we might repeatedly replace a simple expression with an equivalent
35///   complex expression containing that (or another replaceable) simple
36///   expression, and repeat indefinitely.
37/// * reduces all expressions contained in `equivalences`.
38/// * Does everything that [canonicalize_equivalence_classes] does.
39pub fn canonicalize_equivalences<'a, I>(
40    equivalences: &mut Vec<Vec<MirScalarExpr>>,
41    input_column_types: I,
42) where
43    I: Iterator<Item = &'a Vec<ReprColumnType>>,
44{
45    let repr_column_types = input_column_types
46        .flat_map(|f| f.clone())
47        .collect::<Vec<_>>();
48    // Calculate the number of non-leaves for each expression.
49    let mut to_reduce = equivalences
50        .drain(..)
51        .filter_map(|mut cls| {
52            let mut result = cls
53                .drain(..)
54                .map(|expr| (rank_complexity(&expr), expr))
55                .collect::<Vec<_>>();
56            result.sort();
57            result.dedup();
58            if result.len() > 1 { Some(result) } else { None }
59        })
60        .collect::<Vec<_>>();
61
62    let mut expressions_rewritten = true;
63    while expressions_rewritten {
64        expressions_rewritten = false;
65        for i in 0..to_reduce.len() {
66            // `to_reduce` will be borrowed as immutable, so in order to modify
67            // elements of `to_reduce[i]`, we are going to pop them out of
68            // `to_reduce[i]` and put the modified version in `new_equivalence`,
69            // which will then replace `to_reduce[i]`.
70            let mut new_equivalence = Vec::with_capacity(to_reduce[i].len());
71            while let Some((_, mut popped_expr)) = to_reduce[i].pop() {
72                popped_expr.visit_mut_post(&mut |e: &mut MirScalarExpr| {
73                    // If a simpler expression can be found that is equivalent
74                    // to e,
75                    if let Some(simpler_e) = to_reduce.iter().find_map(|cls| {
76                        if cls.iter().skip(1).position(|(_, expr)| e == expr).is_some() {
77                            Some(cls[0].1.clone())
78                        } else {
79                            None
80                        }
81                    }) {
82                        // Replace e with the simpler expression.
83                        *e = simpler_e;
84                        expressions_rewritten = true;
85                    }
86                });
87                popped_expr.reduce(&repr_column_types);
88                new_equivalence.push((rank_complexity(&popped_expr), popped_expr));
89            }
90            new_equivalence.sort();
91            new_equivalence.dedup();
92            to_reduce[i] = new_equivalence;
93        }
94    }
95
96    // Map away the complexity rating.
97    *equivalences = to_reduce
98        .drain(..)
99        .map(|mut cls| cls.drain(..).map(|(_, expr)| expr).collect::<Vec<_>>())
100        .collect::<Vec<_>>();
101
102    canonicalize_equivalence_classes(equivalences);
103}
104
105/// Canonicalize only the equivalence classes of a join.
106///
107/// This function:
108/// * ensures the same expression appears in only one equivalence class.
109/// * ensures the equivalence classes are sorted and dedupped.
110/// ```rust
111/// use mz_expr::MirScalarExpr;
112/// use mz_expr::canonicalize::canonicalize_equivalence_classes;
113///
114/// let mut equivalences = vec![
115///     vec![MirScalarExpr::column(1), MirScalarExpr::column(4)],
116///     vec![MirScalarExpr::column(3), MirScalarExpr::column(5)],
117///     vec![MirScalarExpr::column(0), MirScalarExpr::column(3)],
118///     vec![MirScalarExpr::column(2), MirScalarExpr::column(2)],
119/// ];
120/// let expected = vec![
121///     vec![MirScalarExpr::column(0),
122///         MirScalarExpr::column(3),
123///         MirScalarExpr::column(5)],
124///     vec![MirScalarExpr::column(1), MirScalarExpr::column(4)],
125/// ];
126/// canonicalize_equivalence_classes(&mut equivalences);
127/// assert_eq!(expected, equivalences)
128/// ````
129pub fn canonicalize_equivalence_classes(equivalences: &mut Vec<Vec<MirScalarExpr>>) {
130    let mut uf = BTreeMap::new();
131    for class in equivalences.iter_mut() {
132        let mut iter = class.drain(..);
133        if let Some(first) = iter.next() {
134            let head = Rc::new(first);
135            for rest in iter {
136                uf.union(&head, &Rc::new(rest));
137            }
138        }
139    }
140
141    let mut eqs: BTreeMap<Rc<MirScalarExpr>, BTreeSet<Rc<MirScalarExpr>>> = BTreeMap::new();
142    for (k, v) in uf {
143        eqs.entry(v).or_default().insert(Rc::clone(&k));
144    }
145
146    let classes = eqs.into_values().collect::<Vec<_>>();
147    equivalences.resize(classes.len(), Vec::new());
148    equivalences
149        .iter_mut()
150        .zip_eq(classes)
151        .for_each(|(equivalence, class)| {
152            equivalence.extend(
153                class
154                    .into_iter()
155                    .map(|e| Rc::try_unwrap(e).expect("there to be only one strong ref")),
156            );
157        });
158
159    equivalences.retain(|es| es.len() > 1);
160    equivalences.sort();
161}
162
163/// Gives a relative complexity ranking for an expression. Higher numbers mean
164/// greater complexity.
165///
166/// Currently, this method weighs literals as the least complex and weighs all
167/// other expressions by the number of non-literals. In the future, we can
168/// change how complexity is ranked so that repeated substitutions would result
169/// in arriving at "better" fixed points. For example, we could try to improve
170/// performance by ranking expressions by their estimated computation time.
171///
172/// To ensure we arrive at a fixed point after repeated substitutions, valid
173/// complexity rankings must fulfill the following property:
174/// For any expression `e`, there does not exist a SQL function `f` such
175/// that `complexity(e) >= complexity(f(e))`.
176///
177/// For ease of intuiting the fixed point that we will arrive at after
178/// repeated substitutions, it is nice but not required that complexity
179/// rankings additionally fulfill the following property:
180/// If expressions `e1` and `e2` are such that
181/// `complexity(e1) < complexity(e2)` then for all SQL functions `f`,
182/// `complexity(f(e1)) < complexity(f(e2))`.
183fn rank_complexity(expr: &MirScalarExpr) -> usize {
184    if expr.is_literal() {
185        // literals are the least complex
186        return 0;
187    }
188    let mut non_literal_count = 1;
189    expr.visit_pre(|e| {
190        if !e.is_literal() {
191            non_literal_count += 1
192        }
193    });
194    non_literal_count
195}
196
/// Replaces the contents of `v` with the result of flat-mapping `f` over its
/// elements, preserving order.
///
/// Every element is moved out of `v` and handed to `f`; the items `f` yields
/// become the new contents.
fn flat_map_modify<T, I, F>(v: &mut Vec<T>, f: F)
where
    F: FnMut(T) -> I,
    I: IntoIterator<Item = T>,
{
    let expanded: Vec<T> = std::mem::take(v).into_iter().flat_map(f).collect();
    *v = expanded;
}
206
207/// Canonicalize predicates of a filter.
208///
209/// This function reduces and canonicalizes the structure of each individual
210/// predicate. Then, it transforms predicates of the form "A and B" into two: "A"
211/// and "B". Afterwards, it reduces predicates based on information from other
212/// predicates in the set. Finally, it sorts and deduplicates the predicates.
213///
214/// Additionally, it also removes IS NOT NULL predicates if there is another
215/// null rejecting predicate for the same sub-expression.
216pub fn canonicalize_predicates(
217    predicates: &mut Vec<MirScalarExpr>,
218    repr_column_types: &[ReprColumnType],
219) {
220    soft_assert_or_log!(
221        predicates
222            .iter()
223            .all(|p| p.typ(repr_column_types).scalar_type == ReprScalarType::Bool),
224        "cannot canonicalize predicates that are not of type bool"
225    );
226
227    // 1) Reduce each individual predicate.
228    predicates
229        .iter_mut()
230        .for_each(|p| p.reduce(repr_column_types));
231
232    // 2) Split "A and B" into two predicates: "A" and "B"
233    // Relies on the `reduce` above having flattened nested ANDs.
234    flat_map_modify(predicates, |p| {
235        if let MirScalarExpr::CallVariadic {
236            func: VariadicFunc::And(_),
237            exprs,
238        } = p
239        {
240            exprs
241        } else {
242            vec![p]
243        }
244    });
245
246    // 3) Make non-null requirements explicit as predicates in order for
247    // step 4) to be able to simplify AND/OR expressions with IS NULL
248    // sub-predicates. This redundancy is removed later by step 5).
249    let mut non_null_columns = BTreeSet::new();
250    for p in predicates.iter() {
251        p.non_null_requirements(&mut non_null_columns);
252    }
253    predicates.extend(non_null_columns.iter().map(|c| {
254        MirScalarExpr::column(*c)
255            .call_unary(UnaryFunc::IsNull(func::IsNull))
256            .call_unary(UnaryFunc::Not(func::Not))
257    }));
258
259    // 4) Reduce across `predicates`.
260    // If a predicate `p` cannot be null, and `f(p)` is a nullable bool
261    // then the predicate `p & f(p)` is equal to `p & f(true)`, and
262    // `!p & f(p)` is equal to `!p & f(false)`. For any index i, the `Vec` of
263    // predicates `[p1, ... pi, ... pn]` is equivalent to the single predicate
264    // `pi & (p1 & ... & p(i-1) & p(i+1) ... & pn)`. Thus, if `pi`
265    // (resp. `!pi`) cannot be null, it is valid to replace with `true` (resp.
266    // `false`) every subexpression in `(p1 & ... & p(i-1) & p(i+1) ... & pn)`
267    // that is equal to `pi`.
268
269    // If `p` is null and `q` is a nullable bool, then `p & q` can be either
270    // `null` or `false` depending on what `q`. Our rendering pipeline treats
271    // both as "remove this row." Thus, in the specific context of filter
272    // predicates, it is acceptable to make the aforementioned substitution
273    // even if `pi` can be null.
274
275    // Note that this does some dedupping of predicates since if `p1 = p2`
276    // then this reduction process will replace `p1` with true.
277
278    // Maintain respectively:
279    // 1) A list of predicates for which we have checked for matching
280    // subexpressions
281    // 2) A list of predicates for which we have yet to do so.
282    let mut completed = Vec::new();
283    let mut todo = Vec::new();
284    // Seed `todo` with all predicates.
285    std::mem::swap(&mut todo, predicates);
286
287    while let Some(predicate_to_apply) = todo.pop() {
288        // Helper method: for each predicate `p`, see if all other predicates
289        // (a.k.a. the union of todo & completed) contains `p` as a
290        // subexpression, and replace the subexpression accordingly.
291        // This method lives inside the loop because in order to comply with
292        // Rust rules that only one mutable reference to `todo` can be held at a
293        // time.
294        let mut replace_subexpr_other_predicates =
295            |expr: &MirScalarExpr, constant_bool: &MirScalarExpr| {
296                // Do not replace subexpressions equal to `expr` if `expr` is a
297                // literal to avoid infinite looping.
298                if !expr.is_literal() {
299                    for other_predicate in todo.iter_mut() {
300                        replace_subexpr_and_reduce(
301                            other_predicate,
302                            expr,
303                            constant_bool,
304                            repr_column_types,
305                        );
306                    }
307                    for other_idx in (0..completed.len()).rev() {
308                        if replace_subexpr_and_reduce(
309                            &mut completed[other_idx],
310                            expr,
311                            constant_bool,
312                            repr_column_types,
313                        ) {
314                            // If a predicate in the `completed` list has
315                            // been simplified, stick it back into the `todo` list.
316                            todo.push(completed.remove(other_idx));
317                        }
318                    }
319                }
320            };
321        // Meat of loop starts here. If a predicate p is of the form `!q`, replace
322        // every instance of `q` in every other predicate with `false.`
323        // Otherwise, replace every instance of `p` in every other predicate
324        // with `true`.
325        if let MirScalarExpr::CallUnary {
326            func: UnaryFunc::Not(func::Not),
327            expr,
328        } = &predicate_to_apply
329        {
330            replace_subexpr_other_predicates(expr, &MirScalarExpr::literal_false())
331        } else {
332            replace_subexpr_other_predicates(&predicate_to_apply, &MirScalarExpr::literal_true());
333        }
334        completed.push(predicate_to_apply);
335    }
336
337    // 5) Remove redundant !isnull/isnull predicates after performing the replacements
338    // in the loop above.
339    std::mem::swap(&mut todo, &mut completed);
340    while let Some(predicate_to_apply) = todo.pop() {
341        // Remove redundant !isnull(x) predicates if there is another predicate
342        // that evaluates to NULL when `x` is NULL.
343        if let Some(operand) = is_not_null(&predicate_to_apply) {
344            if todo
345                .iter_mut()
346                .chain(completed.iter_mut())
347                .any(|p| is_null_rejecting_predicate(p, &operand))
348            {
349                // skip this predicate
350                continue;
351            }
352        } else if let MirScalarExpr::CallUnary {
353            func: UnaryFunc::IsNull(func::IsNull),
354            expr,
355        } = &predicate_to_apply
356        {
357            if todo
358                .iter_mut()
359                .chain(completed.iter_mut())
360                .any(|p| is_null_rejecting_predicate(p, expr))
361            {
362                completed.push(MirScalarExpr::literal_false());
363                break;
364            }
365        }
366        completed.push(predicate_to_apply);
367    }
368
369    if completed.iter().any(|p| {
370        (p.is_literal_false() || p.is_literal_null()) &&
371        // This extra check is only needed if we determine that the soft-assert
372        // at the top of this function would ever fail for a good reason.
373        p.typ(repr_column_types).scalar_type == ReprScalarType::Bool
374    }) {
375        // all rows get filtered away if any predicate is null or false.
376        *predicates = vec![MirScalarExpr::literal_false()]
377    } else {
378        // Remove any predicates that have been reduced to "true"
379        completed.retain(|p| !p.is_literal_true());
380        *predicates = completed;
381    }
382
383    // 6) Sort and dedup predicates.
384    predicates.sort_by(compare_predicates);
385    predicates.dedup();
386}
387
388/// Replace any matching subexpressions in `predicate`, and if `predicate` has
389/// changed, reduce it. Return whether `predicate` has changed.
390fn replace_subexpr_and_reduce(
391    predicate: &mut MirScalarExpr,
392    replace_if_equal_to: &MirScalarExpr,
393    replace_with: &MirScalarExpr,
394    repr_column_types: &[ReprColumnType],
395) -> bool {
396    let mut changed = false;
397    predicate.visit_mut_pre_post(
398        &mut |e| {
399            // The `cond` of an if statement is not visited to prevent `then`
400            // or `els` from being evaluated before `cond`, resulting in a
401            // correctness error.
402            if let MirScalarExpr::If { then, els, .. } = e {
403                return Some(vec![then, els]);
404            }
405            None
406        },
407        &mut |e| {
408            if e == replace_if_equal_to {
409                *e = replace_with.clone();
410                changed = true;
411            } else if let MirScalarExpr::CallBinary {
412                func: r_func,
413                expr1: r_expr1,
414                expr2: r_expr2,
415            } = replace_if_equal_to
416            {
417                if let Some(negation) = r_func.negate() {
418                    if let MirScalarExpr::CallBinary {
419                        func: l_func,
420                        expr1: l_expr1,
421                        expr2: l_expr2,
422                    } = e
423                    {
424                        if negation == *l_func && l_expr1 == r_expr1 && l_expr2 == r_expr2 {
425                            *e = MirScalarExpr::CallUnary {
426                                func: UnaryFunc::Not(func::Not),
427                                expr: Box::new(replace_with.clone()),
428                            };
429                            changed = true;
430                        }
431                    }
432                }
433            }
434        },
435    );
436    if changed {
437        predicate.reduce(repr_column_types);
438    }
439    changed
440}
441
442/// Returns the inner operand if the given predicate is an IS NOT NULL expression.
443fn is_not_null(predicate: &MirScalarExpr) -> Option<MirScalarExpr> {
444    if let MirScalarExpr::CallUnary {
445        func: UnaryFunc::Not(func::Not),
446        expr,
447    } = &predicate
448    {
449        if let MirScalarExpr::CallUnary {
450            func: UnaryFunc::IsNull(func::IsNull),
451            expr,
452        } = &**expr
453        {
454            return Some((**expr).clone());
455        }
456    }
457    None
458}
459
460/// Whether the given predicate evaluates to NULL when the given operand expression is NULL.
461#[inline(always)]
462fn is_null_rejecting_predicate(predicate: &MirScalarExpr, operand: &MirScalarExpr) -> bool {
463    propagates_null_from_subexpression(predicate, operand)
464}
465
466fn propagates_null_from_subexpression(expr: &MirScalarExpr, operand: &MirScalarExpr) -> bool {
467    if operand == expr {
468        true
469    } else if let MirScalarExpr::CallVariadic { func, exprs } = &expr {
470        func.propagates_nulls()
471            && (exprs
472                .iter()
473                .any(|e| propagates_null_from_subexpression(e, operand)))
474    } else if let MirScalarExpr::CallBinary { func, expr1, expr2 } = &expr {
475        func.propagates_nulls()
476            && (propagates_null_from_subexpression(expr1, operand)
477                || propagates_null_from_subexpression(expr2, operand))
478    } else if let MirScalarExpr::CallUnary { func, expr } = &expr {
479        func.propagates_nulls() && propagates_null_from_subexpression(expr, operand)
480    } else {
481        false
482    }
483}
484
485/// Comparison method for sorting predicates by their complexity, measured by the total
486/// number of non-literal expression nodes within the expression.
487fn compare_predicates(x: &MirScalarExpr, y: &MirScalarExpr) -> Ordering {
488    (rank_complexity(x), x).cmp(&(rank_complexity(y), y))
489}
490
491/// For each equivalence class, it finds the simplest expression, which will be the canonical one.
492/// Returns a Map that maps from each expression in each equivalence class to the canonical
493/// expression in the same equivalence class.
494pub fn get_canonicalizer_map(
495    equivalences: &Vec<Vec<MirScalarExpr>>,
496) -> BTreeMap<MirScalarExpr, MirScalarExpr> {
497    let mut canonicalizer_map = BTreeMap::new();
498    for equivalence in equivalences {
499        // The unwrap is ok, because a join equivalence class can't be empty.
500        let canonical_expr = equivalence
501            .iter()
502            .min_by(|a, b| compare_predicates(*a, *b))
503            .unwrap();
504        for e in equivalence {
505            if e != canonical_expr {
506                canonicalizer_map.insert(e.clone(), canonical_expr.clone());
507            }
508        }
509    }
510    canonicalizer_map
511}
512
/// Union-find (disjoint-set) operations on a map from each element to its
/// parent.
pub trait UnionFind<T> {
    /// Sets `self[x]` to the root of `x`'s set and returns a reference to that
    /// root. Returns `None` if `x` is not present.
    fn find<'a>(&'a mut self, x: &T) -> Option<&'a T>;
    /// Ensures that `x` and `y` have the same root, inserting whichever of
    /// them is not present yet.
    fn union(&mut self, x: &T, y: &T);
}

impl<T: Clone + Ord> UnionFind<T> for BTreeMap<T, T> {
    fn find<'a>(&'a mut self, x: &T) -> Option<&'a T> {
        let parent = self.get(x)?.clone();
        if self[&parent] != parent {
            // Path halving: walking up from `x`'s parent, point each visited
            // node at its grandparent and continue from there.
            let mut node = parent;
            loop {
                let up = self[&node].clone();
                if up == node {
                    break;
                }
                let grandparent = self[&up].clone();
                *self.get_mut(&node).unwrap() = grandparent.clone();
                node = grandparent;
            }
            *self.get_mut(x).unwrap() = node;
        }
        self.get(x)
    }

    fn union(&mut self, x: &T, y: &T) {
        // After a successful `find`, `self[x]` / `self[y]` hold the roots.
        let x_known = self.find(x).is_some();
        let y_known = self.find(y).is_some();
        if x_known && y_known {
            let root_x = self[x].clone();
            let root_y = self[y].clone();
            if root_x != root_y {
                // Hang `x`'s root below `y`'s root.
                self.insert(root_x, root_y);
            }
        } else if y_known {
            let root_y = self[y].clone();
            self.insert(x.clone(), root_y);
        } else if x_known {
            let root_x = self[x].clone();
            self.insert(y.clone(), root_x);
        } else {
            // Neither element is known: `x` becomes the root of a new set.
            self.insert(x.clone(), x.clone());
            self.insert(y.clone(), x.clone());
        }
    }
}