mz_expr/relation/canonicalize.rs
1// Copyright Materialize, Inc. and contributors. All rights reserved.
2//
3// Use of this software is governed by the Business Source License
4// included in the LICENSE file.
5//
6// As of the Change Date specified in that file, in accordance with
7// the Business Source License, use of this software will be governed
8// by the Apache License, Version 2.0.
9
10//! Utility functions to transform parts of a single `MirRelationExpr`
11//! into canonical form.
12
13use std::cmp::Ordering;
14use std::collections::{BTreeMap, BTreeSet};
15
16use mz_ore::soft_assert_or_log;
17use mz_repr::{ColumnType, ScalarType};
18
19use crate::visit::Visit;
20use crate::{MirScalarExpr, UnaryFunc, VariadicFunc, func};
21
22/// Canonicalize equivalence classes of a join and expressions contained in them.
23///
24/// `input_types` can be the [ColumnType]s of the join or the [ColumnType]s of
25/// the individual inputs of the join in order.
26///
27/// This function:
28/// * simplifies expressions to involve the least number of non-literal nodes.
29/// This ensures that we only replace expressions by "even simpler"
30/// expressions and that repeated substitutions reduce the complexity of
31/// expressions and a fixed point is certain to be reached. Without this
32/// rule, we might repeatedly replace a simple expression with an equivalent
33/// complex expression containing that (or another replaceable) simple
34/// expression, and repeat indefinitely.
35/// * reduces all expressions contained in `equivalences`.
36/// * Does everything that [canonicalize_equivalence_classes] does.
37pub fn canonicalize_equivalences<'a, I>(
38 equivalences: &mut Vec<Vec<MirScalarExpr>>,
39 input_column_types: I,
40) where
41 I: Iterator<Item = &'a Vec<ColumnType>>,
42{
43 let column_types = input_column_types
44 .flat_map(|f| f.clone())
45 .collect::<Vec<_>>();
46 // Calculate the number of non-leaves for each expression.
47 let mut to_reduce = equivalences
48 .drain(..)
49 .filter_map(|mut cls| {
50 let mut result = cls
51 .drain(..)
52 .map(|expr| (rank_complexity(&expr), expr))
53 .collect::<Vec<_>>();
54 result.sort();
55 result.dedup();
56 if result.len() > 1 { Some(result) } else { None }
57 })
58 .collect::<Vec<_>>();
59
60 let mut expressions_rewritten = true;
61 while expressions_rewritten {
62 expressions_rewritten = false;
63 for i in 0..to_reduce.len() {
64 // `to_reduce` will be borrowed as immutable, so in order to modify
65 // elements of `to_reduce[i]`, we are going to pop them out of
66 // `to_reduce[i]` and put the modified version in `new_equivalence`,
67 // which will then replace `to_reduce[i]`.
68 let mut new_equivalence = Vec::with_capacity(to_reduce[i].len());
69 while let Some((_, mut popped_expr)) = to_reduce[i].pop() {
70 #[allow(deprecated)]
71 popped_expr.visit_mut_post_nolimit(&mut |e: &mut MirScalarExpr| {
72 // If a simpler expression can be found that is equivalent
73 // to e,
74 if let Some(simpler_e) = to_reduce.iter().find_map(|cls| {
75 if cls.iter().skip(1).position(|(_, expr)| e == expr).is_some() {
76 Some(cls[0].1.clone())
77 } else {
78 None
79 }
80 }) {
81 // Replace e with the simpler expression.
82 *e = simpler_e;
83 expressions_rewritten = true;
84 }
85 });
86 popped_expr.reduce(&column_types);
87 new_equivalence.push((rank_complexity(&popped_expr), popped_expr));
88 }
89 new_equivalence.sort();
90 new_equivalence.dedup();
91 to_reduce[i] = new_equivalence;
92 }
93 }
94
95 // Map away the complexity rating.
96 *equivalences = to_reduce
97 .drain(..)
98 .map(|mut cls| cls.drain(..).map(|(_, expr)| expr).collect::<Vec<_>>())
99 .collect::<Vec<_>>();
100
101 canonicalize_equivalence_classes(equivalences);
102}
103
104/// Canonicalize only the equivalence classes of a join.
105///
106/// This function:
107/// * ensures the same expression appears in only one equivalence class.
108/// * ensures the equivalence classes are sorted and dedupped.
109/// ```rust
110/// use mz_expr::MirScalarExpr;
111/// use mz_expr::canonicalize::canonicalize_equivalence_classes;
112///
113/// let mut equivalences = vec![
114/// vec![MirScalarExpr::Column(1), MirScalarExpr::Column(4)],
115/// vec![MirScalarExpr::Column(3), MirScalarExpr::Column(5)],
116/// vec![MirScalarExpr::Column(0), MirScalarExpr::Column(3)],
117/// vec![MirScalarExpr::Column(2), MirScalarExpr::Column(2)],
118/// ];
119/// let expected = vec![
120/// vec![MirScalarExpr::Column(0),
121/// MirScalarExpr::Column(3),
122/// MirScalarExpr::Column(5)],
123/// vec![MirScalarExpr::Column(1), MirScalarExpr::Column(4)],
124/// ];
125/// canonicalize_equivalence_classes(&mut equivalences);
126/// assert_eq!(expected, equivalences)
127/// ````
128pub fn canonicalize_equivalence_classes(equivalences: &mut Vec<Vec<MirScalarExpr>>) {
129 // Fuse equivalence classes containing the same expression.
130 for index in 1..equivalences.len() {
131 for inner in 0..index {
132 if equivalences[index]
133 .iter()
134 .any(|pair| equivalences[inner].contains(pair))
135 {
136 let to_extend = std::mem::replace(&mut equivalences[index], Vec::new());
137 equivalences[inner].extend(to_extend);
138 }
139 }
140 }
141 for equivalence in equivalences.iter_mut() {
142 equivalence.sort();
143 equivalence.dedup();
144 }
145 equivalences.retain(|es| es.len() > 1);
146 equivalences.sort();
147}
148
149/// Gives a relative complexity ranking for an expression. Higher numbers mean
150/// greater complexity.
151///
152/// Currently, this method weighs literals as the least complex and weighs all
153/// other expressions by the number of non-literals. In the future, we can
154/// change how complexity is ranked so that repeated substitutions would result
155/// in arriving at "better" fixed points. For example, we could try to improve
156/// performance by ranking expressions by their estimated computation time.
157///
158/// To ensure we arrive at a fixed point after repeated substitutions, valid
159/// complexity rankings must fulfill the following property:
160/// For any expression `e`, there does not exist a SQL function `f` such
161/// that `complexity(e) >= complexity(f(e))`.
162///
163/// For ease of intuiting the fixed point that we will arrive at after
164/// repeated substitutions, it is nice but not required that complexity
165/// rankings additionally fulfill the following property:
166/// If expressions `e1` and `e2` are such that
167/// `complexity(e1) < complexity(e2)` then for all SQL functions `f`,
168/// `complexity(f(e1)) < complexity(f(e2))`.
169fn rank_complexity(expr: &MirScalarExpr) -> usize {
170 if expr.is_literal() {
171 // literals are the least complex
172 return 0;
173 }
174 let mut non_literal_count = 1;
175 expr.visit_pre(|e| {
176 if !e.is_literal() {
177 non_literal_count += 1
178 }
179 });
180 non_literal_count
181}
182
/// Replaces the contents of `v` with the result of flat-mapping `f` over its
/// elements, consuming each original element in order.
fn flat_map_modify<T, I, F>(v: &mut Vec<T>, f: F)
where
    F: FnMut(T) -> I,
    I: IntoIterator<Item = T>,
{
    // Move the elements out (leaving `v` empty), expand them, and store the
    // expanded sequence back into `v`.
    let expanded: Vec<T> = std::mem::take(v).into_iter().flat_map(f).collect();
    *v = expanded;
}
192
193/// Canonicalize predicates of a filter.
194///
195/// This function reduces and canonicalizes the structure of each individual
196/// predicate. Then, it transforms predicates of the form "A and B" into two: "A"
197/// and "B". Afterwards, it reduces predicates based on information from other
198/// predicates in the set. Finally, it sorts and deduplicates the predicates.
199///
200/// Additionally, it also removes IS NOT NULL predicates if there is another
201/// null rejecting predicate for the same sub-expression.
202pub fn canonicalize_predicates(predicates: &mut Vec<MirScalarExpr>, column_types: &[ColumnType]) {
203 soft_assert_or_log!(
204 predicates
205 .iter()
206 .all(|p| p.typ(column_types).scalar_type == ScalarType::Bool),
207 "cannot canonicalize predicates that are not of type bool"
208 );
209
210 // 1) Reduce each individual predicate.
211 predicates.iter_mut().for_each(|p| p.reduce(column_types));
212
213 // 2) Split "A and B" into two predicates: "A" and "B"
214 // Relies on the `reduce` above having flattened nested ANDs.
215 flat_map_modify(predicates, |p| {
216 if let MirScalarExpr::CallVariadic {
217 func: VariadicFunc::And,
218 exprs,
219 } = p
220 {
221 exprs
222 } else {
223 vec![p]
224 }
225 });
226
227 // 3) Make non-null requirements explicit as predicates in order for
228 // step 4) to be able to simplify AND/OR expressions with IS NULL
229 // sub-predicates. This redundancy is removed later by step 5).
230 let mut non_null_columns = BTreeSet::new();
231 for p in predicates.iter() {
232 p.non_null_requirements(&mut non_null_columns);
233 }
234 predicates.extend(non_null_columns.iter().map(|c| {
235 MirScalarExpr::column(*c)
236 .call_unary(UnaryFunc::IsNull(func::IsNull))
237 .call_unary(UnaryFunc::Not(func::Not))
238 }));
239
240 // 4) Reduce across `predicates`.
241 // If a predicate `p` cannot be null, and `f(p)` is a nullable bool
242 // then the predicate `p & f(p)` is equal to `p & f(true)`, and
243 // `!p & f(p)` is equal to `!p & f(false)`. For any index i, the `Vec` of
244 // predicates `[p1, ... pi, ... pn]` is equivalent to the single predicate
245 // `pi & (p1 & ... & p(i-1) & p(i+1) ... & pn)`. Thus, if `pi`
246 // (resp. `!pi`) cannot be null, it is valid to replace with `true` (resp.
247 // `false`) every subexpression in `(p1 & ... & p(i-1) & p(i+1) ... & pn)`
248 // that is equal to `pi`.
249
250 // If `p` is null and `q` is a nullable bool, then `p & q` can be either
251 // `null` or `false` depending on what `q`. Our rendering pipeline treats
252 // both as "remove this row." Thus, in the specific context of filter
253 // predicates, it is acceptable to make the aforementioned substitution
254 // even if `pi` can be null.
255
256 // Note that this does some dedupping of predicates since if `p1 = p2`
257 // then this reduction process will replace `p1` with true.
258
259 // Maintain respectively:
260 // 1) A list of predicates for which we have checked for matching
261 // subexpressions
262 // 2) A list of predicates for which we have yet to do so.
263 let mut completed = Vec::new();
264 let mut todo = Vec::new();
265 // Seed `todo` with all predicates.
266 std::mem::swap(&mut todo, predicates);
267
268 while let Some(predicate_to_apply) = todo.pop() {
269 // Helper method: for each predicate `p`, see if all other predicates
270 // (a.k.a. the union of todo & completed) contains `p` as a
271 // subexpression, and replace the subexpression accordingly.
272 // This method lives inside the loop because in order to comply with
273 // Rust rules that only one mutable reference to `todo` can be held at a
274 // time.
275 let mut replace_subexpr_other_predicates =
276 |expr: &MirScalarExpr, constant_bool: &MirScalarExpr| {
277 // Do not replace subexpressions equal to `expr` if `expr` is a
278 // literal to avoid infinite looping.
279 if !expr.is_literal() {
280 for other_predicate in todo.iter_mut() {
281 replace_subexpr_and_reduce(
282 other_predicate,
283 expr,
284 constant_bool,
285 column_types,
286 );
287 }
288 for other_idx in (0..completed.len()).rev() {
289 if replace_subexpr_and_reduce(
290 &mut completed[other_idx],
291 expr,
292 constant_bool,
293 column_types,
294 ) {
295 // If a predicate in the `completed` list has
296 // been simplified, stick it back into the `todo` list.
297 todo.push(completed.remove(other_idx));
298 }
299 }
300 }
301 };
302 // Meat of loop starts here. If a predicate p is of the form `!q`, replace
303 // every instance of `q` in every other predicate with `false.`
304 // Otherwise, replace every instance of `p` in every other predicate
305 // with `true`.
306 if let MirScalarExpr::CallUnary {
307 func: UnaryFunc::Not(func::Not),
308 expr,
309 } = &predicate_to_apply
310 {
311 replace_subexpr_other_predicates(expr, &MirScalarExpr::literal_false())
312 } else {
313 replace_subexpr_other_predicates(&predicate_to_apply, &MirScalarExpr::literal_true());
314 }
315 completed.push(predicate_to_apply);
316 }
317
318 // 5) Remove redundant !isnull/isnull predicates after performing the replacements
319 // in the loop above.
320 std::mem::swap(&mut todo, &mut completed);
321 while let Some(predicate_to_apply) = todo.pop() {
322 // Remove redundant !isnull(x) predicates if there is another predicate
323 // that evaluates to NULL when `x` is NULL.
324 if let Some(operand) = is_not_null(&predicate_to_apply) {
325 if todo
326 .iter_mut()
327 .chain(completed.iter_mut())
328 .any(|p| is_null_rejecting_predicate(p, &operand))
329 {
330 // skip this predicate
331 continue;
332 }
333 } else if let MirScalarExpr::CallUnary {
334 func: UnaryFunc::IsNull(func::IsNull),
335 expr,
336 } = &predicate_to_apply
337 {
338 if todo
339 .iter_mut()
340 .chain(completed.iter_mut())
341 .any(|p| is_null_rejecting_predicate(p, expr))
342 {
343 completed.push(MirScalarExpr::literal_false());
344 break;
345 }
346 }
347 completed.push(predicate_to_apply);
348 }
349
350 if completed.iter().any(|p| {
351 (p.is_literal_false() || p.is_literal_null()) &&
352 // This extra check is only needed if we determine that the soft-assert
353 // at the top of this function would ever fail for a good reason.
354 p.typ(column_types).scalar_type == ScalarType::Bool
355 }) {
356 // all rows get filtered away if any predicate is null or false.
357 *predicates = vec![MirScalarExpr::literal_false()]
358 } else {
359 // Remove any predicates that have been reduced to "true"
360 completed.retain(|p| !p.is_literal_true());
361 *predicates = completed;
362 }
363
364 // 6) Sort and dedup predicates.
365 predicates.sort_by(compare_predicates);
366 predicates.dedup();
367}
368
369/// Replace any matching subexpressions in `predicate`, and if `predicate` has
370/// changed, reduce it. Return whether `predicate` has changed.
371fn replace_subexpr_and_reduce(
372 predicate: &mut MirScalarExpr,
373 replace_if_equal_to: &MirScalarExpr,
374 replace_with: &MirScalarExpr,
375 column_types: &[ColumnType],
376) -> bool {
377 let mut changed = false;
378 #[allow(deprecated)]
379 predicate.visit_mut_pre_post_nolimit(
380 &mut |e| {
381 // The `cond` of an if statement is not visited to prevent `then`
382 // or `els` from being evaluated before `cond`, resulting in a
383 // correctness error.
384 if let MirScalarExpr::If { then, els, .. } = e {
385 return Some(vec![then, els]);
386 }
387 None
388 },
389 &mut |e| {
390 if e == replace_if_equal_to {
391 *e = replace_with.clone();
392 changed = true;
393 } else if let MirScalarExpr::CallBinary {
394 func: r_func,
395 expr1: r_expr1,
396 expr2: r_expr2,
397 } = replace_if_equal_to
398 {
399 if let Some(negation) = r_func.negate() {
400 if let MirScalarExpr::CallBinary {
401 func: l_func,
402 expr1: l_expr1,
403 expr2: l_expr2,
404 } = e
405 {
406 if negation == *l_func && l_expr1 == r_expr1 && l_expr2 == r_expr2 {
407 *e = MirScalarExpr::CallUnary {
408 func: UnaryFunc::Not(func::Not),
409 expr: Box::new(replace_with.clone()),
410 };
411 changed = true;
412 }
413 }
414 }
415 }
416 },
417 );
418 if changed {
419 predicate.reduce(column_types);
420 }
421 changed
422}
423
424/// Returns the inner operand if the given predicate is an IS NOT NULL expression.
425fn is_not_null(predicate: &MirScalarExpr) -> Option<MirScalarExpr> {
426 if let MirScalarExpr::CallUnary {
427 func: UnaryFunc::Not(func::Not),
428 expr,
429 } = &predicate
430 {
431 if let MirScalarExpr::CallUnary {
432 func: UnaryFunc::IsNull(func::IsNull),
433 expr,
434 } = &**expr
435 {
436 return Some((**expr).clone());
437 }
438 }
439 None
440}
441
442/// Whether the given predicate evaluates to NULL when the given operand expression is NULL.
443#[inline(always)]
444fn is_null_rejecting_predicate(predicate: &MirScalarExpr, operand: &MirScalarExpr) -> bool {
445 propagates_null_from_subexpression(predicate, operand)
446}
447
448fn propagates_null_from_subexpression(expr: &MirScalarExpr, operand: &MirScalarExpr) -> bool {
449 if operand == expr {
450 true
451 } else if let MirScalarExpr::CallVariadic { func, exprs } = &expr {
452 func.propagates_nulls()
453 && (exprs
454 .iter()
455 .any(|e| propagates_null_from_subexpression(e, operand)))
456 } else if let MirScalarExpr::CallBinary { func, expr1, expr2 } = &expr {
457 func.propagates_nulls()
458 && (propagates_null_from_subexpression(expr1, operand)
459 || propagates_null_from_subexpression(expr2, operand))
460 } else if let MirScalarExpr::CallUnary { func, expr } = &expr {
461 func.propagates_nulls() && propagates_null_from_subexpression(expr, operand)
462 } else {
463 false
464 }
465}
466
467/// Comparison method for sorting predicates by their complexity, measured by the total
468/// number of non-literal expression nodes within the expression.
469fn compare_predicates(x: &MirScalarExpr, y: &MirScalarExpr) -> Ordering {
470 (rank_complexity(x), x).cmp(&(rank_complexity(y), y))
471}
472
473/// For each equivalence class, it finds the simplest expression, which will be the canonical one.
474/// Returns a Map that maps from each expression in each equivalence class to the canonical
475/// expression in the same equivalence class.
476pub fn get_canonicalizer_map(
477 equivalences: &Vec<Vec<MirScalarExpr>>,
478) -> BTreeMap<MirScalarExpr, MirScalarExpr> {
479 let mut canonicalizer_map = BTreeMap::new();
480 for equivalence in equivalences {
481 // The unwrap is ok, because a join equivalence class can't be empty.
482 let canonical_expr = equivalence
483 .iter()
484 .min_by(|a, b| compare_predicates(*a, *b))
485 .unwrap();
486 for e in equivalence {
487 if e != canonical_expr {
488 canonicalizer_map.insert(e.clone(), canonical_expr.clone());
489 }
490 }
491 }
492 canonicalizer_map
493}