mz_expr/relation/
canonicalize.rs

1// Copyright Materialize, Inc. and contributors. All rights reserved.
2//
3// Use of this software is governed by the Business Source License
4// included in the LICENSE file.
5//
6// As of the Change Date specified in that file, in accordance with
7// the Business Source License, use of this software will be governed
8// by the Apache License, Version 2.0.
9
10//! Utility functions to transform parts of a single `MirRelationExpr`
11//! into canonical form.
12
13use std::cmp::Ordering;
14use std::collections::{BTreeMap, BTreeSet};
15
16use mz_ore::soft_assert_or_log;
17use mz_repr::{ColumnType, ScalarType};
18
19use crate::visit::Visit;
20use crate::{MirScalarExpr, UnaryFunc, VariadicFunc, func};
21
22/// Canonicalize equivalence classes of a join and expressions contained in them.
23///
24/// `input_types` can be the [ColumnType]s of the join or the [ColumnType]s of
25/// the individual inputs of the join in order.
26///
27/// This function:
28/// * simplifies expressions to involve the least number of non-literal nodes.
29///   This ensures that we only replace expressions by "even simpler"
30///   expressions and that repeated substitutions reduce the complexity of
31///   expressions and a fixed point is certain to be reached. Without this
32///   rule, we might repeatedly replace a simple expression with an equivalent
33///   complex expression containing that (or another replaceable) simple
34///   expression, and repeat indefinitely.
35/// * reduces all expressions contained in `equivalences`.
36/// * Does everything that [canonicalize_equivalence_classes] does.
37pub fn canonicalize_equivalences<'a, I>(
38    equivalences: &mut Vec<Vec<MirScalarExpr>>,
39    input_column_types: I,
40) where
41    I: Iterator<Item = &'a Vec<ColumnType>>,
42{
43    let column_types = input_column_types
44        .flat_map(|f| f.clone())
45        .collect::<Vec<_>>();
46    // Calculate the number of non-leaves for each expression.
47    let mut to_reduce = equivalences
48        .drain(..)
49        .filter_map(|mut cls| {
50            let mut result = cls
51                .drain(..)
52                .map(|expr| (rank_complexity(&expr), expr))
53                .collect::<Vec<_>>();
54            result.sort();
55            result.dedup();
56            if result.len() > 1 { Some(result) } else { None }
57        })
58        .collect::<Vec<_>>();
59
60    let mut expressions_rewritten = true;
61    while expressions_rewritten {
62        expressions_rewritten = false;
63        for i in 0..to_reduce.len() {
64            // `to_reduce` will be borrowed as immutable, so in order to modify
65            // elements of `to_reduce[i]`, we are going to pop them out of
66            // `to_reduce[i]` and put the modified version in `new_equivalence`,
67            // which will then replace `to_reduce[i]`.
68            let mut new_equivalence = Vec::with_capacity(to_reduce[i].len());
69            while let Some((_, mut popped_expr)) = to_reduce[i].pop() {
70                #[allow(deprecated)]
71                popped_expr.visit_mut_post_nolimit(&mut |e: &mut MirScalarExpr| {
72                    // If a simpler expression can be found that is equivalent
73                    // to e,
74                    if let Some(simpler_e) = to_reduce.iter().find_map(|cls| {
75                        if cls.iter().skip(1).position(|(_, expr)| e == expr).is_some() {
76                            Some(cls[0].1.clone())
77                        } else {
78                            None
79                        }
80                    }) {
81                        // Replace e with the simpler expression.
82                        *e = simpler_e;
83                        expressions_rewritten = true;
84                    }
85                });
86                popped_expr.reduce(&column_types);
87                new_equivalence.push((rank_complexity(&popped_expr), popped_expr));
88            }
89            new_equivalence.sort();
90            new_equivalence.dedup();
91            to_reduce[i] = new_equivalence;
92        }
93    }
94
95    // Map away the complexity rating.
96    *equivalences = to_reduce
97        .drain(..)
98        .map(|mut cls| cls.drain(..).map(|(_, expr)| expr).collect::<Vec<_>>())
99        .collect::<Vec<_>>();
100
101    canonicalize_equivalence_classes(equivalences);
102}
103
104/// Canonicalize only the equivalence classes of a join.
105///
106/// This function:
107/// * ensures the same expression appears in only one equivalence class.
108/// * ensures the equivalence classes are sorted and dedupped.
109/// ```rust
110/// use mz_expr::MirScalarExpr;
111/// use mz_expr::canonicalize::canonicalize_equivalence_classes;
112///
113/// let mut equivalences = vec![
114///     vec![MirScalarExpr::Column(1), MirScalarExpr::Column(4)],
115///     vec![MirScalarExpr::Column(3), MirScalarExpr::Column(5)],
116///     vec![MirScalarExpr::Column(0), MirScalarExpr::Column(3)],
117///     vec![MirScalarExpr::Column(2), MirScalarExpr::Column(2)],
118/// ];
119/// let expected = vec![
120///     vec![MirScalarExpr::Column(0),
121///         MirScalarExpr::Column(3),
122///         MirScalarExpr::Column(5)],
123///     vec![MirScalarExpr::Column(1), MirScalarExpr::Column(4)],
124/// ];
125/// canonicalize_equivalence_classes(&mut equivalences);
126/// assert_eq!(expected, equivalences)
127/// ````
128pub fn canonicalize_equivalence_classes(equivalences: &mut Vec<Vec<MirScalarExpr>>) {
129    // Fuse equivalence classes containing the same expression.
130    for index in 1..equivalences.len() {
131        for inner in 0..index {
132            if equivalences[index]
133                .iter()
134                .any(|pair| equivalences[inner].contains(pair))
135            {
136                let to_extend = std::mem::replace(&mut equivalences[index], Vec::new());
137                equivalences[inner].extend(to_extend);
138            }
139        }
140    }
141    for equivalence in equivalences.iter_mut() {
142        equivalence.sort();
143        equivalence.dedup();
144    }
145    equivalences.retain(|es| es.len() > 1);
146    equivalences.sort();
147}
148
149/// Gives a relative complexity ranking for an expression. Higher numbers mean
150/// greater complexity.
151///
152/// Currently, this method weighs literals as the least complex and weighs all
153/// other expressions by the number of non-literals. In the future, we can
154/// change how complexity is ranked so that repeated substitutions would result
155/// in arriving at "better" fixed points. For example, we could try to improve
156/// performance by ranking expressions by their estimated computation time.
157///
158/// To ensure we arrive at a fixed point after repeated substitutions, valid
159/// complexity rankings must fulfill the following property:
160/// For any expression `e`, there does not exist a SQL function `f` such
161/// that `complexity(e) >= complexity(f(e))`.
162///
163/// For ease of intuiting the fixed point that we will arrive at after
164/// repeated substitutions, it is nice but not required that complexity
165/// rankings additionally fulfill the following property:
166/// If expressions `e1` and `e2` are such that
167/// `complexity(e1) < complexity(e2)` then for all SQL functions `f`,
168/// `complexity(f(e1)) < complexity(f(e2))`.
169fn rank_complexity(expr: &MirScalarExpr) -> usize {
170    if expr.is_literal() {
171        // literals are the least complex
172        return 0;
173    }
174    let mut non_literal_count = 1;
175    expr.visit_pre(|e| {
176        if !e.is_literal() {
177            non_literal_count += 1
178        }
179    });
180    non_literal_count
181}
182
/// Replaces the contents of `v` with the result of flat-mapping `f` over its
/// elements.
///
/// Each element is moved out of `v` and handed to `f` in order; the items
/// produced by `f` are concatenated, in order, to form the new contents.
fn flat_map_modify<T, I, F>(v: &mut Vec<T>, f: F)
where
    F: FnMut(T) -> I,
    I: IntoIterator<Item = T>,
{
    let expanded: Vec<T> = std::mem::take(v).into_iter().flat_map(f).collect();
    *v = expanded;
}
192
/// Canonicalize predicates of a filter.
///
/// This function reduces and canonicalizes the structure of each individual
/// predicate. Then, it transforms predicates of the form "A and B" into two: "A"
/// and "B". Afterwards, it reduces predicates based on information from other
/// predicates in the set. Finally, it sorts and deduplicates the predicates.
///
/// Additionally, it also removes IS NOT NULL predicates if there is another
/// null rejecting predicate for the same sub-expression.
///
/// `predicates` is rewritten in place. `column_types` is passed to every
/// `typ` and `reduce` call made on the predicates. If any predicate ends up
/// as a literal `false` or `null`, the whole list is replaced by a single
/// literal `false`.
pub fn canonicalize_predicates(predicates: &mut Vec<MirScalarExpr>, column_types: &[ColumnType]) {
    soft_assert_or_log!(
        predicates
            .iter()
            .all(|p| p.typ(column_types).scalar_type == ScalarType::Bool),
        "cannot canonicalize predicates that are not of type bool"
    );

    // 1) Reduce each individual predicate.
    predicates.iter_mut().for_each(|p| p.reduce(column_types));

    // 2) Split "A and B" into two predicates: "A" and "B"
    // Relies on the `reduce` above having flattened nested ANDs.
    flat_map_modify(predicates, |p| {
        if let MirScalarExpr::CallVariadic {
            func: VariadicFunc::And,
            exprs,
        } = p
        {
            exprs
        } else {
            vec![p]
        }
    });

    // 3) Make non-null requirements explicit as predicates in order for
    // step 4) to be able to simplify AND/OR expressions with IS NULL
    // sub-predicates. This redundancy is removed later by step 5).
    let mut non_null_columns = BTreeSet::new();
    for p in predicates.iter() {
        p.non_null_requirements(&mut non_null_columns);
    }
    // Each collected column `c` becomes an explicit `NOT (c IS NULL)` predicate.
    predicates.extend(non_null_columns.iter().map(|c| {
        MirScalarExpr::column(*c)
            .call_unary(UnaryFunc::IsNull(func::IsNull))
            .call_unary(UnaryFunc::Not(func::Not))
    }));

    // 4) Reduce across `predicates`.
    // If a predicate `p` cannot be null, and `f(p)` is a nullable bool
    // then the predicate `p & f(p)` is equal to `p & f(true)`, and
    // `!p & f(p)` is equal to `!p & f(false)`. For any index i, the `Vec` of
    // predicates `[p1, ... pi, ... pn]` is equivalent to the single predicate
    // `pi & (p1 & ... & p(i-1) & p(i+1) ... & pn)`. Thus, if `pi`
    // (resp. `!pi`) cannot be null, it is valid to replace with `true` (resp.
    // `false`) every subexpression in `(p1 & ... & p(i-1) & p(i+1) ... & pn)`
    // that is equal to `pi`.

    // If `p` is null and `q` is a nullable bool, then `p & q` can be either
    // `null` or `false` depending on what `q` is. Our rendering pipeline
    // treats both as "remove this row." Thus, in the specific context of
    // filter predicates, it is acceptable to make the aforementioned
    // substitution even if `pi` can be null.

    // Note that this does some dedupping of predicates since if `p1 = p2`
    // then this reduction process will replace `p1` with true.

    // Maintain respectively:
    // 1) A list of predicates for which we have checked for matching
    // subexpressions
    // 2) A list of predicates for which we have yet to do so.
    let mut completed = Vec::new();
    let mut todo = Vec::new();
    // Seed `todo` with all predicates.
    std::mem::swap(&mut todo, predicates);

    while let Some(predicate_to_apply) = todo.pop() {
        // Helper closure: for each predicate `p`, see if any of the other
        // predicates (a.k.a. the union of todo & completed) contains `p` as a
        // subexpression, and replace the subexpression accordingly.
        // This closure lives inside the loop in order to comply with the
        // Rust rule that only one mutable reference to `todo` can be held at
        // a time.
        let mut replace_subexpr_other_predicates =
            |expr: &MirScalarExpr, constant_bool: &MirScalarExpr| {
                // Do not replace subexpressions equal to `expr` if `expr` is a
                // literal to avoid infinite looping.
                if !expr.is_literal() {
                    for other_predicate in todo.iter_mut() {
                        replace_subexpr_and_reduce(
                            other_predicate,
                            expr,
                            constant_bool,
                            column_types,
                        );
                    }
                    // Walk `completed` back to front so that removing an
                    // element does not shift the indices still to be visited.
                    for other_idx in (0..completed.len()).rev() {
                        if replace_subexpr_and_reduce(
                            &mut completed[other_idx],
                            expr,
                            constant_bool,
                            column_types,
                        ) {
                            // If a predicate in the `completed` list has
                            // been simplified, stick it back into the `todo` list.
                            todo.push(completed.remove(other_idx));
                        }
                    }
                }
            };
        // Meat of loop starts here. If a predicate p is of the form `!q`, replace
        // every instance of `q` in every other predicate with `false.`
        // Otherwise, replace every instance of `p` in every other predicate
        // with `true`.
        if let MirScalarExpr::CallUnary {
            func: UnaryFunc::Not(func::Not),
            expr,
        } = &predicate_to_apply
        {
            replace_subexpr_other_predicates(expr, &MirScalarExpr::literal_false())
        } else {
            replace_subexpr_other_predicates(&predicate_to_apply, &MirScalarExpr::literal_true());
        }
        completed.push(predicate_to_apply);
    }

    // 5) Remove redundant !isnull/isnull predicates after performing the replacements
    // in the loop above.
    // `todo` is empty here (the loop above ran until `pop` returned `None`),
    // so after the swap `todo` holds every processed predicate and
    // `completed` starts out empty again.
    std::mem::swap(&mut todo, &mut completed);
    while let Some(predicate_to_apply) = todo.pop() {
        // Remove redundant !isnull(x) predicates if there is another predicate
        // that evaluates to NULL when `x` is NULL.
        if let Some(operand) = is_not_null(&predicate_to_apply) {
            if todo
                .iter_mut()
                .chain(completed.iter_mut())
                .any(|p| is_null_rejecting_predicate(p, &operand))
            {
                // skip this predicate
                continue;
            }
        } else if let MirScalarExpr::CallUnary {
            func: UnaryFunc::IsNull(func::IsNull),
            expr,
        } = &predicate_to_apply
        {
            // An `isnull(x)` predicate next to a predicate that evaluates to
            // NULL whenever `x` is NULL can never let a row through: record a
            // literal `false` and stop, since the check below then collapses
            // the whole list to `false`.
            if todo
                .iter_mut()
                .chain(completed.iter_mut())
                .any(|p| is_null_rejecting_predicate(p, expr))
            {
                completed.push(MirScalarExpr::literal_false());
                break;
            }
        }
        completed.push(predicate_to_apply);
    }

    if completed.iter().any(|p| {
        (p.is_literal_false() || p.is_literal_null()) &&
        // This extra check is only needed if we determine that the soft-assert
        // at the top of this function would ever fail for a good reason.
        p.typ(column_types).scalar_type == ScalarType::Bool
    }) {
        // all rows get filtered away if any predicate is null or false.
        *predicates = vec![MirScalarExpr::literal_false()]
    } else {
        // Remove any predicates that have been reduced to "true"
        completed.retain(|p| !p.is_literal_true());
        *predicates = completed;
    }

    // 6) Sort and dedup predicates.
    predicates.sort_by(compare_predicates);
    predicates.dedup();
}
368
369/// Replace any matching subexpressions in `predicate`, and if `predicate` has
370/// changed, reduce it. Return whether `predicate` has changed.
371fn replace_subexpr_and_reduce(
372    predicate: &mut MirScalarExpr,
373    replace_if_equal_to: &MirScalarExpr,
374    replace_with: &MirScalarExpr,
375    column_types: &[ColumnType],
376) -> bool {
377    let mut changed = false;
378    #[allow(deprecated)]
379    predicate.visit_mut_pre_post_nolimit(
380        &mut |e| {
381            // The `cond` of an if statement is not visited to prevent `then`
382            // or `els` from being evaluated before `cond`, resulting in a
383            // correctness error.
384            if let MirScalarExpr::If { then, els, .. } = e {
385                return Some(vec![then, els]);
386            }
387            None
388        },
389        &mut |e| {
390            if e == replace_if_equal_to {
391                *e = replace_with.clone();
392                changed = true;
393            } else if let MirScalarExpr::CallBinary {
394                func: r_func,
395                expr1: r_expr1,
396                expr2: r_expr2,
397            } = replace_if_equal_to
398            {
399                if let Some(negation) = r_func.negate() {
400                    if let MirScalarExpr::CallBinary {
401                        func: l_func,
402                        expr1: l_expr1,
403                        expr2: l_expr2,
404                    } = e
405                    {
406                        if negation == *l_func && l_expr1 == r_expr1 && l_expr2 == r_expr2 {
407                            *e = MirScalarExpr::CallUnary {
408                                func: UnaryFunc::Not(func::Not),
409                                expr: Box::new(replace_with.clone()),
410                            };
411                            changed = true;
412                        }
413                    }
414                }
415            }
416        },
417    );
418    if changed {
419        predicate.reduce(column_types);
420    }
421    changed
422}
423
424/// Returns the inner operand if the given predicate is an IS NOT NULL expression.
425fn is_not_null(predicate: &MirScalarExpr) -> Option<MirScalarExpr> {
426    if let MirScalarExpr::CallUnary {
427        func: UnaryFunc::Not(func::Not),
428        expr,
429    } = &predicate
430    {
431        if let MirScalarExpr::CallUnary {
432            func: UnaryFunc::IsNull(func::IsNull),
433            expr,
434        } = &**expr
435        {
436            return Some((**expr).clone());
437        }
438    }
439    None
440}
441
/// Whether the given predicate evaluates to NULL when the given operand expression is NULL.
///
/// This is a thin wrapper that forwards both arguments unchanged to
/// [`propagates_null_from_subexpression`] and returns its result.
#[inline(always)]
fn is_null_rejecting_predicate(predicate: &MirScalarExpr, operand: &MirScalarExpr) -> bool {
    propagates_null_from_subexpression(predicate, operand)
}
447
448fn propagates_null_from_subexpression(expr: &MirScalarExpr, operand: &MirScalarExpr) -> bool {
449    if operand == expr {
450        true
451    } else if let MirScalarExpr::CallVariadic { func, exprs } = &expr {
452        func.propagates_nulls()
453            && (exprs
454                .iter()
455                .any(|e| propagates_null_from_subexpression(e, operand)))
456    } else if let MirScalarExpr::CallBinary { func, expr1, expr2 } = &expr {
457        func.propagates_nulls()
458            && (propagates_null_from_subexpression(expr1, operand)
459                || propagates_null_from_subexpression(expr2, operand))
460    } else if let MirScalarExpr::CallUnary { func, expr } = &expr {
461        func.propagates_nulls() && propagates_null_from_subexpression(expr, operand)
462    } else {
463        false
464    }
465}
466
467/// Comparison method for sorting predicates by their complexity, measured by the total
468/// number of non-literal expression nodes within the expression.
469fn compare_predicates(x: &MirScalarExpr, y: &MirScalarExpr) -> Ordering {
470    (rank_complexity(x), x).cmp(&(rank_complexity(y), y))
471}
472
473/// For each equivalence class, it finds the simplest expression, which will be the canonical one.
474/// Returns a Map that maps from each expression in each equivalence class to the canonical
475/// expression in the same equivalence class.
476pub fn get_canonicalizer_map(
477    equivalences: &Vec<Vec<MirScalarExpr>>,
478) -> BTreeMap<MirScalarExpr, MirScalarExpr> {
479    let mut canonicalizer_map = BTreeMap::new();
480    for equivalence in equivalences {
481        // The unwrap is ok, because a join equivalence class can't be empty.
482        let canonical_expr = equivalence
483            .iter()
484            .min_by(|a, b| compare_predicates(*a, *b))
485            .unwrap();
486        for e in equivalence {
487            if e != canonical_expr {
488                canonicalizer_map.insert(e.clone(), canonical_expr.clone());
489            }
490        }
491    }
492    canonicalizer_map
493}