mz_transform/analysis/equivalences.rs
// Copyright Materialize, Inc. and contributors. All rights reserved.
//
// Use of this software is governed by the Business Source License
// included in the LICENSE file.
//
// As of the Change Date specified in that file, in accordance with
// the Business Source License, use of this software will be governed
// by the Apache License, Version 2.0.

//! An analysis that reports all known-equivalent expressions for each relation.
//!
//! Expressions are equivalent at a relation if they are certain to evaluate to
//! the same `Datum` for all records in the relation.
//!
//! Equivalences are recorded in an `EquivalenceClasses`, which lists all known
//! equivalence classes, each a list of equivalent expressions.

18use std::collections::BTreeMap;
19use std::fmt::Formatter;
20
21use mz_expr::explain::{HumanizedExplain, HumanizerMode};
22use mz_expr::{AggregateFunc, Id, MirRelationExpr, MirScalarExpr};
23use mz_ore::str::{bracketed, separated};
24use mz_repr::{ColumnType, Datum};
25
26use crate::analysis::{Analysis, Lattice};
27use crate::analysis::{Arity, RelationType};
28use crate::analysis::{Derived, DerivedBuilder};
29
/// Analysis that tracks predicate information as equivalences between scalar expressions.
///
/// For example, a `Filter` predicate `p` is recorded as the equivalence `p = true`, and the
/// recorded facts flow from each input up to the operators that consume it.
#[derive(Default, Debug)]
pub struct Equivalences;
33
34impl Analysis for Equivalences {
35 // A `Some(list)` indicates a list of classes of equivalent expressions.
36 // A `None` indicates all expressions are equivalent, including contradictions;
37 // this is only possible for the empty collection, and as an initial result for
38 // unconstrained recursive terms.
39 type Value = Option<EquivalenceClasses>;
40
41 fn announce_dependencies(builder: &mut DerivedBuilder) {
42 builder.require(Arity);
43 builder.require(RelationType); // needed for expression reduction.
44 }
45
46 fn derive(
47 &self,
48 expr: &MirRelationExpr,
49 index: usize,
50 results: &[Self::Value],
51 depends: &Derived,
52 ) -> Self::Value {
53 let mut equivalences = match expr {
54 MirRelationExpr::Constant { rows, typ } => {
55 // Trawl `rows` for any constant information worth recording.
56 // Literal columns may be valuable; non-nullability could be too.
57 let mut equivalences = EquivalenceClasses::default();
58 if let Ok([(row, _cnt), rows @ ..]) = rows.as_deref() {
59 // Vector of `Option<Datum>` which becomes `None` once a column has a second datum.
60 let len = row.iter().count();
61 let mut common = Vec::with_capacity(len);
62 common.extend(row.iter().map(Some));
63 // Prep initial nullability information.
64 let mut nullable_cols = common
65 .iter()
66 .map(|datum| datum == &Some(Datum::Null))
67 .collect::<Vec<_>>();
68
69 for (row, _cnt) in rows.iter() {
70 for ((datum, common), nullable) in row
71 .iter()
72 .zip(common.iter_mut())
73 .zip(nullable_cols.iter_mut())
74 {
75 if Some(datum) != *common {
76 *common = None;
77 }
78 if datum == Datum::Null {
79 *nullable = true;
80 }
81 }
82 }
83 for (index, common) in common.into_iter().enumerate() {
84 if let Some(datum) = common {
85 equivalences.classes.push(vec![
86 MirScalarExpr::Column(index),
87 MirScalarExpr::literal_ok(
88 datum,
89 typ.column_types[index].scalar_type.clone(),
90 ),
91 ]);
92 }
93 }
94 // If any columns are non-null, introduce this fact.
95 if nullable_cols.iter().any(|x| !*x) {
96 let mut class = vec![MirScalarExpr::literal_false()];
97 for (index, nullable) in nullable_cols.iter().enumerate() {
98 if !*nullable {
99 class.push(MirScalarExpr::column(index).call_is_null());
100 }
101 }
102 equivalences.classes.push(class);
103 }
104 }
105 Some(equivalences)
106 }
107 MirRelationExpr::Get { id, typ, .. } => {
108 let mut equivalences = Some(EquivalenceClasses::default());
109 // Find local identifiers, but nothing for external identifiers.
110 if let Id::Local(id) = id {
111 if let Some(offset) = depends.bindings().get(id) {
112 // It is possible we have derived nothing for a recursive term
113 if let Some(result) = results.get(*offset) {
114 equivalences.clone_from(result);
115 } else {
116 // No top element was prepared.
117 // This means we are executing pessimistically,
118 // but perhaps we must because optimism is off.
119 }
120 }
121 }
122 // Incorporate statements about column nullability.
123 let mut non_null_cols = vec![MirScalarExpr::literal_false()];
124 for (index, col_type) in typ.column_types.iter().enumerate() {
125 if !col_type.nullable {
126 non_null_cols.push(MirScalarExpr::column(index).call_is_null());
127 }
128 }
129 if non_null_cols.len() > 1 {
130 if let Some(equivalences) = equivalences.as_mut() {
131 equivalences.classes.push(non_null_cols);
132 }
133 }
134
135 equivalences
136 }
137 MirRelationExpr::Let { .. } => results.get(index - 1).unwrap().clone(),
138 MirRelationExpr::LetRec { .. } => results.get(index - 1).unwrap().clone(),
139 MirRelationExpr::Project { outputs, .. } => {
140 // restrict equivalences, and introduce equivalences for repeated outputs.
141 let mut equivalences = results.get(index - 1).unwrap().clone();
142 equivalences
143 .as_mut()
144 .map(|e| e.project(outputs.iter().cloned()));
145 equivalences
146 }
147 MirRelationExpr::Map { scalars, .. } => {
148 // introduce equivalences for new columns and expressions that define them.
149 let mut equivalences = results.get(index - 1).unwrap().clone();
150 if let Some(equivalences) = &mut equivalences {
151 let input_arity = depends.results::<Arity>()[index - 1];
152 for (pos, expr) in scalars.iter().enumerate() {
153 equivalences
154 .classes
155 .push(vec![MirScalarExpr::Column(input_arity + pos), expr.clone()]);
156 }
157 }
158 equivalences
159 }
160 MirRelationExpr::FlatMap { .. } => results.get(index - 1).unwrap().clone(),
161 MirRelationExpr::Filter { predicates, .. } => {
162 let mut equivalences = results.get(index - 1).unwrap().clone();
163 if let Some(equivalences) = &mut equivalences {
164 let mut class = predicates.clone();
165 class.push(MirScalarExpr::literal_ok(
166 Datum::True,
167 mz_repr::ScalarType::Bool,
168 ));
169 equivalences.classes.push(class);
170 }
171 equivalences
172 }
173 MirRelationExpr::Join { equivalences, .. } => {
174 // Collect equivalences from all inputs;
175 let expr_index = index;
176 let mut children = depends
177 .children_of_rev(expr_index, expr.children().count())
178 .collect::<Vec<_>>();
179 children.reverse();
180
181 let arity = depends.results::<Arity>();
182 let mut columns = 0;
183 let mut result = Some(EquivalenceClasses::default());
184 for child in children.into_iter() {
185 let input_arity = arity[child];
186 let equivalences = results[child].clone();
187 if let Some(mut equivalences) = equivalences {
188 let permutation = (columns..(columns + input_arity)).collect::<Vec<_>>();
189 equivalences.permute(&permutation);
190 result
191 .as_mut()
192 .map(|e| e.classes.extend(equivalences.classes));
193 } else {
194 result = None;
195 }
196 columns += input_arity;
197 }
198
199 // Fold join equivalences into our results.
200 result
201 .as_mut()
202 .map(|e| e.classes.extend(equivalences.iter().cloned()));
203 result
204 }
205 MirRelationExpr::Reduce {
206 group_key,
207 aggregates,
208 ..
209 } => {
210 let input_arity = depends.results::<Arity>()[index - 1];
211 let mut equivalences = results.get(index - 1).unwrap().clone();
212 if let Some(equivalences) = &mut equivalences {
213 // Introduce keys column equivalences as if a map, then project to those columns.
214 // This should retain as much information as possible about these columns.
215 for (pos, expr) in group_key.iter().enumerate() {
216 equivalences
217 .classes
218 .push(vec![MirScalarExpr::Column(input_arity + pos), expr.clone()]);
219 }
220
221 // Having added classes to `equivalences`, we should minimize the classes to fold the
222 // information in before applying the `project`, to set it up for success.
223 equivalences.minimize(None);
224
225 // Grab a copy of the equivalences with key columns added to use in aggregate reasoning.
226 let extended = equivalences.clone();
227 // Now project down the equivalences, as we will extend them in terms of the output columns.
228 equivalences.project(input_arity..(input_arity + group_key.len()));
229
230 // TODO: MIN, MAX, ANY, ALL aggregates pass through all certain properties of their columns.
231 // They also pass through equivalences of them and other constant columns (e.g. key columns).
232 // However, it is not correct to simply project onto these columns, as relationships amongst
233 // aggregate columns may no longer be preserved. MAX(col) != MIN(col) even though col = col.
234 // The correct thing to do is treat the reduce as a join between single-aggregate reductions,
235 // where each single MIN/MAX/ANY/ALL aggregate propagates equivalences.
236 for (index, aggregate) in aggregates.iter().enumerate() {
237 if aggregate_is_input(&aggregate.func) {
238 let mut temp_equivs = extended.clone();
239 temp_equivs.classes.push(vec![
240 MirScalarExpr::column(input_arity + group_key.len()),
241 aggregate.expr.clone(),
242 ]);
243 temp_equivs.minimize(None);
244 temp_equivs.project(input_arity..(input_arity + group_key.len() + 1));
245 let columns = (0..group_key.len())
246 .chain(std::iter::once(group_key.len() + index))
247 .collect::<Vec<_>>();
248 temp_equivs.permute(&columns[..]);
249 equivalences.classes.extend(temp_equivs.classes);
250 }
251 }
252 }
253 equivalences
254 }
255 MirRelationExpr::TopK { .. } => results.get(index - 1).unwrap().clone(),
256 MirRelationExpr::Negate { .. } => results.get(index - 1).unwrap().clone(),
257 MirRelationExpr::Threshold { .. } => results.get(index - 1).unwrap().clone(),
258 MirRelationExpr::Union { .. } => {
259 let expr_index = index;
260 let mut child_equivs = depends
261 .children_of_rev(expr_index, expr.children().count())
262 .flat_map(|c| &results[c]);
263 if let Some(first) = child_equivs.next() {
264 Some(first.union_many(child_equivs))
265 } else {
266 None
267 }
268 }
269 MirRelationExpr::ArrangeBy { .. } => results.get(index - 1).unwrap().clone(),
270 };
271
272 let expr_type = depends.results::<RelationType>()[index].as_ref();
273 equivalences
274 .as_mut()
275 .map(|e| e.minimize(expr_type.map(|x| &x[..])));
276 equivalences
277 }
278
279 fn lattice() -> Option<Box<dyn Lattice<Self::Value>>> {
280 Some(Box::new(EQLattice))
281 }
282}
283
284struct EQLattice;
285
286impl Lattice<Option<EquivalenceClasses>> for EQLattice {
287 fn top(&self) -> Option<EquivalenceClasses> {
288 None
289 }
290
291 fn meet_assign(
292 &self,
293 a: &mut Option<EquivalenceClasses>,
294 b: Option<EquivalenceClasses>,
295 ) -> bool {
296 match (&mut *a, b) {
297 (_, None) => false,
298 (None, b) => {
299 *a = b;
300 true
301 }
302 (Some(a), Some(b)) => {
303 let mut c = a.union(&b);
304 std::mem::swap(a, &mut c);
305 a != &mut c
306 }
307 }
308 }
309}
310
311/// A compact representation of classes of expressions that must be equivalent.
312///
313/// Each "class" contains a list of expressions, each of which must be `Eq::eq` equal.
314/// Ideally, the first element is the "simplest", e.g. a literal or column reference,
315/// and any other element of that list can be replaced by it.
316///
317/// The classes are meant to be minimized, with each expression as reduced as it can be,
318/// and all classes sharing an element merged.
319#[derive(Clone, Eq, PartialEq, Ord, PartialOrd, Default, Debug)]
320pub struct EquivalenceClasses {
321 /// Multiple lists of equivalent expressions, each representing an equivalence class.
322 ///
323 /// The first element should be the "canonical" simplest element, that any other element
324 /// can be replaced by.
325 /// These classes are unified whenever possible, to minimize the number of classes.
326 /// They are only guaranteed to form an equivalence relation after a call to `minimize`,
327 /// which refreshes both `self.classes` and `self.remap`.
328 pub classes: Vec<Vec<MirScalarExpr>>,
329
330 /// An expression simplification map.
331 ///
332 /// This map reflects an equivalence relation based on a prior version of `self.classes`.
333 /// As users may add to `self.classes`, `self.remap` may become stale. We refresh `remap`
334 /// only in `self.refresh()`, to the equivalence relation that derives from `self.classes`.
335 ///
336 /// It is important to `self.remap.clear()` if you invalidate it by mutating rather than
337 /// appending to `self.classes`. This will be corrected in the next call to `self.refresh()`,
338 /// but until then `remap` could be arbitrarily wrong. This should be improved in the future.
339 remap: BTreeMap<MirScalarExpr, MirScalarExpr>,
340}
341
342/// Raw printing of [`EquivalenceClasses`] with default expression humanization.
343/// Don't use this in `EXPLAIN`! For redaction, column name support, etc., see
344/// [`HumanizedEquivalenceClasses`].
345impl std::fmt::Display for EquivalenceClasses {
346 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
347 HumanizedEquivalenceClasses {
348 equivalence_classes: self,
349 cols: None,
350 mode: HumanizedExplain::default(),
351 }
352 .fmt(f)
353 }
354}
355
356/// Wrapper struct for human-readable printing of expressions inside [`EquivalenceClasses`].
357/// (Similar to `HumanizedExpr`. Unfortunately, we can't just use `HumanizedExpr` here, because
358/// we'd need to `impl Display for HumanizedExpr<'a, EquivalenceClasses, M>`, but neither
359/// `Display` nor `HumanizedExpr` is defined in this crate.)
360#[derive(Debug)]
361pub struct HumanizedEquivalenceClasses<'a, M = HumanizedExplain> {
362 /// The [`EquivalenceClasses`] to be humanized.
363 pub equivalence_classes: &'a EquivalenceClasses,
364 /// An optional vector of inferred column names to be used when rendering
365 /// column references in expressions.
366 pub cols: Option<&'a Vec<String>>,
367 /// The rendering mode to use. See `HumanizerMode` for details.
368 pub mode: M,
369}
370
371impl std::fmt::Display for HumanizedEquivalenceClasses<'_> {
372 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
373 // Only show `classes`.
374 // (The following hopefully avoids allocating any of the intermediate composite strings.)
375 let classes = self.equivalence_classes.classes.iter().map(|class| {
376 format!(
377 "{}",
378 bracketed(
379 "[",
380 "]",
381 separated(
382 ", ",
383 class.iter().map(|expr| self.mode.expr(expr, self.cols))
384 )
385 )
386 )
387 });
388 write!(f, "{}", bracketed("[", "]", separated(", ", classes)))
389 }
390}
391
392impl EquivalenceClasses {
393 /// Comparator function for the complexity of scalar expressions. Simpler expressions are
394 /// smaller. Can be used when we need to decide which of several equivalent expressions to use.
395 pub fn mir_scalar_expr_complexity(
396 e1: &MirScalarExpr,
397 e2: &MirScalarExpr,
398 ) -> std::cmp::Ordering {
399 use MirScalarExpr::*;
400 use std::cmp::Ordering::*;
401 match (e1, e2) {
402 (Literal(_, _), Literal(_, _)) => e1.cmp(e2),
403 (Literal(_, _), _) => Less,
404 (_, Literal(_, _)) => Greater,
405 (Column(_), Column(_)) => e1.cmp(e2),
406 (Column(_), _) => Less,
407 (_, Column(_)) => Greater,
408 (x, y) => {
409 // General expressions should be ordered by their size,
410 // to ensure we only simplify expressions by substitution.
411 // If same size, then fall back to the expressions' Ord.
412 match x.size().cmp(&y.size()) {
413 Equal => x.cmp(y),
414 other => other,
415 }
416 }
417 }
418 }
419
420 /// Sorts and deduplicates each class, removing literal errors.
421 ///
422 /// This method does not ensure equivalence relation structure, but instead performs
423 /// only minimal structural clean-up.
424 fn tidy(&mut self) {
425 for class in self.classes.iter_mut() {
426 // Remove all literal errors, as they cannot be equated to other things.
427 class.retain(|e| !e.is_literal_err());
428 class.sort_by(Self::mir_scalar_expr_complexity);
429 class.dedup();
430 }
431 self.classes.retain(|c| c.len() > 1);
432 self.classes.sort();
433 self.classes.dedup();
434 }
435
436 /// Restore equivalence relation structure to `self.classes` and refresh `self.remap`.
437 ///
438 /// This method takes roughly linear time, and returns true iff `self.remap` has changed.
439 /// This is the only method that refreshes `self.remap`, and is a perfect place to decide
440 /// whether the equivalence classes it represents have experienced any changes since the
441 /// last refresh.
442 fn refresh(&mut self) -> bool {
443 self.tidy();
444
445 // remap may already be the correct answer, and if so we should avoid the work of rebuilding it.
446 // If it contains the same number of expressions as `self.classes`, and for every expression in
447 // `self.classes` the two agree on the representative, they are identical.
448 if self.remap.len() == self.classes.iter().map(|c| c.len()).sum::<usize>()
449 && self
450 .classes
451 .iter()
452 .all(|c| c.iter().all(|e| self.remap.get(e) == Some(&c[0])))
453 {
454 // No change, so return false.
455 return false;
456 }
457
458 // Optimistically build the `remap` we would want.
459 // Note if any unions would be required, in which case we have further work to do,
460 // including re-forming `self.classes`.
461 let mut union_find = BTreeMap::default();
462 let mut dirtied = false;
463 for class in self.classes.iter() {
464 for expr in class.iter() {
465 if let Some(other) = union_find.insert(expr.clone(), class[0].clone()) {
466 // A merge is required, but have the more complex expression point at the simpler one.
467 // This allows `union_find` to end as the `remap` for the new `classes` we form, with
468 // the only required work being compressing all the paths.
469 if Self::mir_scalar_expr_complexity(&other, &class[0])
470 == std::cmp::Ordering::Less
471 {
472 union_find.union(&class[0], &other);
473 } else {
474 union_find.union(&other, &class[0]);
475 }
476 dirtied = true;
477 }
478 }
479 }
480 if dirtied {
481 let mut classes: BTreeMap<_, Vec<_>> = BTreeMap::default();
482 for class in self.classes.drain(..) {
483 for expr in class {
484 let root: MirScalarExpr = union_find.find(&expr).unwrap().clone();
485 classes.entry(root).or_default().push(expr);
486 }
487 }
488 self.classes = classes.into_values().collect();
489 self.tidy();
490 }
491
492 let changed = self.remap != union_find;
493 self.remap = union_find;
494 changed
495 }
496
497 /// Update `self` to maintain the same equivalences which potentially reducing along `Ord::le`.
498 ///
499 /// Informally this means simplifying constraints, removing redundant constraints, and unifying equivalence classes.
500 pub fn minimize(&mut self, columns: Option<&[ColumnType]>) {
501 // Repeatedly, we reduce each of the classes themselves, then unify the classes.
502 // This should strictly reduce complexity, and reach a fixed point.
503 // Ideally it is *confluent*, arriving at the same fixed point no matter the order of operations.
504
505 // We should not rely on nullability information present in `column_types`. (Doing this
506 // every time just before calling `reduce` was found to be a bottleneck during incident-217,
507 // so now we do this nullability tweaking only once here.)
508 let mut columns = columns.map(|x| x.to_vec());
509 let mut nonnull = Vec::new();
510 if let Some(columns) = columns.as_mut() {
511 for (index, col) in columns.iter_mut().enumerate() {
512 let is_null = MirScalarExpr::column(index).call_is_null();
513 if !col.nullable
514 && self
515 .remap
516 .get(&is_null)
517 .map(|e| !e.is_literal_false())
518 .unwrap_or(true)
519 {
520 nonnull.push(is_null);
521 }
522 col.nullable = true;
523 }
524 }
525 if !nonnull.is_empty() {
526 nonnull.push(MirScalarExpr::literal_false());
527 self.classes.push(nonnull);
528 }
529
530 // Ensure `self.classes` and `self.remap` are equivalence relations.
531 // Users are allowed to mutate `self.classes`, so we must perform this normalization at least once.
532 // We have also likely mutated `self.classes` just above with non-nullability information.
533 self.refresh();
534
535 // Termination will be detected by comparing to the map of equivalence classes.
536 let mut previous = Some(self.remap.clone());
537 while let Some(prev) = previous {
538 // Attempt to add new equivalences.
539 let novel = self.expand();
540 if !novel.is_empty() {
541 self.classes.extend(novel);
542 self.refresh();
543 }
544
545 // We continue as long as any simplification has occurred.
546 // An expression can be simplified, a duplication found, or two classes unified.
547 let mut stable = false;
548 while !stable {
549 stable = !self.minimize_once(columns.as_ref().map(|x| &x[..]));
550 }
551
552 // Termination detection.
553 if prev != self.remap {
554 previous = Some(self.remap.clone());
555 } else {
556 previous = None;
557 }
558 }
559 }
560
561 /// Proposes new equivalences that are likely to be novel.
562 ///
563 /// This method invokes `self.implications()` to propose equivalences, and then judges them to be
564 /// novel or not based on existing knowledge, reducing the equivalences down to their novel core.
565 /// This method may produce non-novel equivalences, due to its inability to perform `MSE::reduce`.
566 /// We can end up with e.g. constant expressions that cannot be found until they are so reduced.
567 /// The novelty detection is best-effort, and meant to provide a clearer signal and minimize the
568 /// number of times we call and amount of work we do in `self.refresh()`.
569 fn expand(&self) -> Vec<Vec<MirScalarExpr>> {
570 // Consider expanding `self.classes` with novel equivalences.
571 let mut novel = self.implications();
572 for class in novel.iter_mut() {
573 // reduce each expression to its canonical form.
574 for expr in class.iter_mut() {
575 self.remap.reduce_expr(expr);
576 }
577 class.sort();
578 class.dedup();
579 // for a class to be interesting we require at least two elements that do not reference the same root.
580 let common_class = class
581 .iter()
582 .map(|x| self.remap.get(x))
583 .reduce(|prev, this| if prev == this { prev } else { None });
584 if class.len() == 1 || common_class != Some(None) {
585 class.clear();
586 }
587 }
588 novel.retain(|c| !c.is_empty());
589 novel
590 }
591
592 /// Derives potentially novel equivalences without regard for minimization.
593 ///
594 /// This is an opportunity to explore equivalences that do not correspond to expression minimization,
595 /// and therefore should not be used in `minimize_once`. They are still potentially important, but
596 /// required additional guardrails to ensure we reach a fixed point.
597 ///
598 /// The implications will be introduced into `self.classes` and will prompt a round of minimization,
599 /// making it somewhat polite to avoid producing outputs that cannot result in novel equivalences.
600 /// For example, before producing a new equivalence, one could check that the involved terms are not
601 /// already present in the same class.
602 fn implications(&self) -> Vec<Vec<MirScalarExpr>> {
603 let mut new_equivalences = Vec::new();
604
605 // If we see `false == IsNull(foo)` we can add the non-null implications of `foo`.
606 let mut non_null = std::collections::BTreeSet::default();
607 for class in self.classes.iter() {
608 if Self::class_contains_literal(class, |e| e == &Ok(Datum::False)) {
609 for e in class.iter() {
610 if let MirScalarExpr::CallUnary {
611 func: mz_expr::UnaryFunc::IsNull(_),
612 expr,
613 } = e
614 {
615 expr.non_null_requirements(&mut non_null);
616 }
617 }
618 }
619 }
620 // If we see `true == foo` we can add the non-null implications of `foo`.
621 // TODO: generalize to arbitrary non-null, non-error literals; at the moment `true == pred` is
622 // an important idiom to identify for how we express predicates.
623 for class in self.classes.iter() {
624 if Self::class_contains_literal(class, |e| e == &Ok(Datum::True)) {
625 for expr in class.iter() {
626 expr.non_null_requirements(&mut non_null);
627 }
628 }
629 }
630 // Only keep constraints that are not already known.
631 // Known constraints will present as `COL(_) IS NULL == false`,
632 // which can only happen if `false` is present, and both terms
633 // map to the same canonical representative>
634 let lit_false = MirScalarExpr::literal_false();
635 let target = self.remap.get(&lit_false);
636 if target.is_some() {
637 non_null.retain(|c| {
638 let is_null = MirScalarExpr::column(*c).call_is_null();
639 self.remap.get(&is_null) != target
640 });
641 }
642 if !non_null.is_empty() {
643 let mut class = Vec::with_capacity(non_null.len() + 1);
644 class.push(MirScalarExpr::literal_false());
645 class.extend(
646 non_null
647 .into_iter()
648 .map(|c| MirScalarExpr::column(c).call_is_null()),
649 );
650 new_equivalences.push(class);
651 }
652
653 // If we see records formed from other expressions, we can equate the expressions with
654 // accessors applied to the class of the record former. In `minimize_once` we reduce by
655 // equivalence class representative before we perform expression simplification, so we
656 // shoud be able to just use the expression former, rather than find its representative.
657 // The risk, potentially, is that we would apply accessors to the record former and then
658 // just simplify it away learning nothing.
659 for class in self.classes.iter() {
660 for expr in class.iter() {
661 // Record-forming expressions can equate their accessors and their members.
662 if let MirScalarExpr::CallVariadic {
663 func: mz_expr::VariadicFunc::RecordCreate { .. },
664 exprs,
665 } = expr
666 {
667 for (index, e) in exprs.iter().enumerate() {
668 new_equivalences.push(vec![
669 e.clone(),
670 expr.clone().call_unary(mz_expr::UnaryFunc::RecordGet(
671 mz_expr::func::RecordGet(index),
672 )),
673 ]);
674 }
675 }
676 }
677 }
678
679 // Return all newly established equivalences.
680 new_equivalences
681 }
682
683 /// A single iteration of minimization, which we expect to repeat but benefit from factoring out.
684 ///
685 /// This invocation should take roughly linear time.
686 /// It starts with equivalence class invariants maintained (closed under transitivity), and then
687 /// 1. Performs per-expression reduction, including the class structure to replace subexpressions.
688 /// 2. Applies idiom detection to e.g. unpack expressions equivalence to literal true or false.
689 /// 3. Restores the equivalence class invariants.
690 fn minimize_once(&mut self, columns: Option<&[ColumnType]>) -> bool {
691 // 1. Reduce each expression
692 //
693 // This reduction first looks for subexpression substitutions that can be performed,
694 // and then applies expression reduction if column type information is provided.
695 for class in self.classes.iter_mut() {
696 for expr in class.iter_mut() {
697 self.remap.reduce_child(expr);
698 if let Some(columns) = columns {
699 expr.reduce(columns);
700 }
701 }
702 }
703
704 // 2. Identify idioms
705 // E.g. If Eq(x, y) must be true, we can introduce classes `[x, y]` and `[false, IsNull(x), IsNull(y)]`.
706 let mut to_add = Vec::new();
707 for class in self.classes.iter_mut() {
708 if Self::class_contains_literal(class, |e| e == &Ok(Datum::True)) {
709 for expr in class.iter() {
710 // If Eq(x, y) must be true, we can introduce classes `[x, y]` and `[false, IsNull(x), IsNull(y)]`.
711 // This substitution replaces a complex expression with several smaller expressions, and cannot
712 // cycle if we follow that practice.
713 if let MirScalarExpr::CallBinary {
714 func: mz_expr::BinaryFunc::Eq,
715 expr1,
716 expr2,
717 } = expr
718 {
719 to_add.push(vec![*expr1.clone(), *expr2.clone()]);
720 to_add.push(vec![
721 MirScalarExpr::literal_false(),
722 expr1.clone().call_is_null(),
723 expr2.clone().call_is_null(),
724 ]);
725 }
726 }
727 // Remove the more complex form of the expression.
728 class.retain(|expr| {
729 if let MirScalarExpr::CallBinary {
730 func: mz_expr::BinaryFunc::Eq,
731 ..
732 } = expr
733 {
734 false
735 } else {
736 true
737 }
738 });
739 for expr in class.iter() {
740 // If TRUE == NOT(X) then FALSE == X is a simpler form.
741 if let MirScalarExpr::CallUnary {
742 func: mz_expr::UnaryFunc::Not(_),
743 expr: e,
744 } = expr
745 {
746 to_add.push(vec![MirScalarExpr::literal_false(), (**e).clone()]);
747 }
748 }
749 class.retain(|expr| {
750 if let MirScalarExpr::CallUnary {
751 func: mz_expr::UnaryFunc::Not(_),
752 ..
753 } = expr
754 {
755 false
756 } else {
757 true
758 }
759 });
760 }
761 if Self::class_contains_literal(class, |e| e == &Ok(Datum::False)) {
762 for expr in class.iter() {
763 // If FALSE == NOT(X) then TRUE == X is a simpler form.
764 if let MirScalarExpr::CallUnary {
765 func: mz_expr::UnaryFunc::Not(_),
766 expr: e,
767 } = expr
768 {
769 to_add.push(vec![MirScalarExpr::literal_true(), (**e).clone()]);
770 }
771 }
772 class.retain(|expr| {
773 if let MirScalarExpr::CallUnary {
774 func: mz_expr::UnaryFunc::Not(_),
775 ..
776 } = expr
777 {
778 false
779 } else {
780 true
781 }
782 });
783 }
784 }
785 self.classes.extend(to_add);
786
787 // 3. Restore equivalence relation structure and observe if any changes result.
788 self.refresh()
789 }
790
791 /// Produce the equivalences present in both inputs.
792 pub fn union(&self, other: &Self) -> Self {
793 self.union_many([other])
794 }
795
796 /// The equivalence classes of terms equivalent in all inputs.
797 ///
798 /// This method relies on the `remap` member of each input, and bases the intersection on these rather than `classes`.
799 /// This means one should ensure `minimize()` has been called on all inputs, or risk getting a stale, but conservatively
800 /// correct, result.
801 ///
802 /// This method currently misses opportunities, because it only looks for exactly matches in expressions,
803 /// which may not include all possible matches. For example, `f(#1) == g(#1)` may exist in one class, but
804 /// in another class where `#0 == #1` it may exist as `f(#0) == g(#0)`.
805 pub fn union_many<'a, I>(&self, others: I) -> Self
806 where
807 I: IntoIterator<Item = &'a Self>,
808 {
809 // List of expressions in the intersection, and a proxy equivalence class identifier.
810 let mut intersection: Vec<(&MirScalarExpr, usize)> = Default::default();
811 // Map from expression to a proxy equivalence class identifier.
812 let mut rekey: BTreeMap<&MirScalarExpr, usize> = Default::default();
813 for (key, val) in self.remap.iter() {
814 if !rekey.contains_key(val) {
815 rekey.insert(val, rekey.len());
816 }
817 intersection.push((key, rekey[val]));
818 }
819 for other in others {
820 // Map from proxy equivalence class identifier and equivalence class expr to a new proxy identifier.
821 let mut rekey: BTreeMap<(usize, &MirScalarExpr), usize> = Default::default();
822 intersection.retain_mut(|(key, idx)| {
823 if let Some(val) = other.remap.get(key) {
824 if !rekey.contains_key(&(*idx, val)) {
825 rekey.insert((*idx, val), rekey.len());
826 }
827 *idx = rekey[&(*idx, val)];
828 true
829 } else {
830 false
831 }
832 });
833 }
834 let mut classes: BTreeMap<_, Vec<MirScalarExpr>> = Default::default();
835 for (key, vals) in intersection {
836 classes.entry(vals).or_default().push(key.clone())
837 }
838 let classes = classes.into_values().collect::<Vec<_>>();
839 let mut equivalences = EquivalenceClasses {
840 classes,
841 remap: Default::default(),
842 };
843 equivalences.minimize(None);
844 equivalences
845 }
846
847 /// Permutes each expression, looking up each column reference in `permutation` and replacing with what it finds.
848 pub fn permute(&mut self, permutation: &[usize]) {
849 for class in self.classes.iter_mut() {
850 for expr in class.iter_mut() {
851 expr.permute(permutation);
852 }
853 }
854 self.remap.clear();
855 self.minimize(None);
856 }
857
858 /// Subject the constraints to the column projection, reworking and removing equivalences.
859 ///
860 /// This method should also introduce equivalences representing any repeated columns.
861 pub fn project<I>(&mut self, output_columns: I)
862 where
863 I: IntoIterator<Item = usize> + Clone,
864 {
865 // Retain the first instance of each column, and record subsequent instances as duplicates.
866 let mut dupes = Vec::new();
867 let mut remap = BTreeMap::default();
868 for (idx, col) in output_columns.into_iter().enumerate() {
869 if let Some(pos) = remap.get(&col) {
870 dupes.push((*pos, idx));
871 } else {
872 remap.insert(col, idx);
873 }
874 }
875
876 // Some expressions may be "localized" in that they only reference columns in `output_columns`.
877 // Many expressions may not be localized, but may reference canonical non-localized expressions
878 // for classes that contain a localized expression; in that case we can "backport" the localized
879 // expression to give expressions referencing the canonical expression a shot at localization.
880 //
881 // Expressions should only contain instances of canonical expressions, and so we shouldn't need
882 // to look any further than backporting those. Backporting should have the property that the simplest
883 // localized expression in each class does not contain any non-localized canonical expressions
884 // (as that would make it non-localized); our backporting of non-localized canonicals with localized
885 // expressions should never fire a second
886
887 // Let's say an expression is "localized" once we are able to rewrite its support in terms of `output_columns`.
888 // Not all expressions can be localized, although some of them may be equivalent to localized expressions.
889 // As we find localized expressions, we can replace uses of their equivalent representative with them,
890 // which may allow further expression localization.
891 // We continue the process until no further classes can be localized.
892
893 // A map from representatives to our first localization of their equivalence class.
894 let mut localized = false;
895 while !localized {
896 localized = true;
897 let mut current_map = BTreeMap::default();
898 for class in self.classes.iter_mut() {
899 if !class[0].support().iter().all(|c| remap.contains_key(c)) {
900 if let Some(pos) = class
901 .iter()
902 .position(|e| e.support().iter().all(|c| remap.contains_key(c)))
903 {
904 class.swap(0, pos);
905 localized = false;
906 }
907 }
908 for expr in class[1..].iter() {
909 current_map.insert(expr.clone(), class[0].clone());
910 }
911 }
912
913 // attempt to replace representatives with equivalent localizeable expressions.
914 for class_index in 0..self.classes.len() {
915 for index in 0..self.classes[class_index].len() {
916 current_map.reduce_child(&mut self.classes[class_index][index]);
917 }
918 }
919 // NB: Do *not* `self.minimize()`, as we are developing localizable rather than canonical representatives.
920 }
921
922 // Localize all localizable expressions and discard others.
923 for class in self.classes.iter_mut() {
924 class.retain(|e| e.support().iter().all(|c| remap.contains_key(c)));
925 for expr in class.iter_mut() {
926 expr.permute_map(&remap);
927 }
928 }
929 self.classes.retain(|c| c.len() > 1);
930 // If column repetitions, introduce them as equivalences.
931 // We introduce only the equivalence to the first occurrence, and rely on minimization to collect them.
932 for (col1, col2) in dupes {
933 self.classes.push(vec![
934 MirScalarExpr::Column(col1),
935 MirScalarExpr::Column(col2),
936 ]);
937 }
938 self.remap.clear();
939 self.minimize(None);
940 }
941
942 /// True if any equivalence class contains two distinct non-error literals.
943 pub fn unsatisfiable(&self) -> bool {
944 for class in self.classes.iter() {
945 let mut literal_ok = None;
946 for expr in class.iter() {
947 if let MirScalarExpr::Literal(Ok(row), _) = expr {
948 if literal_ok.is_some() && literal_ok != Some(row) {
949 return true;
950 } else {
951 literal_ok = Some(row);
952 }
953 }
954 }
955 }
956 false
957 }
958
959 /// Returns a map that can be used to replace (sub-)expressions.
960 pub fn reducer(&self) -> &BTreeMap<MirScalarExpr, MirScalarExpr> {
961 &self.remap
962 }
963
964 /// Examines the prefix of `class` of literals, looking for any satisfying `predicate`.
965 ///
966 /// This test bails out as soon as it sees a non-literal, and may have false negatives
967 /// if the data are not sorted with literals at the front.
968 fn class_contains_literal<P>(class: &[MirScalarExpr], mut predicate: P) -> bool
969 where
970 P: FnMut(&Result<Datum, &mz_expr::EvalError>) -> bool,
971 {
972 class
973 .iter()
974 .take_while(|e| e.is_literal())
975 .filter_map(|e| e.as_literal())
976 .any(move |e| predicate(&e))
977 }
978}
979
980/// A type capable of simplifying `MirScalarExpr`s.
981pub trait ExpressionReducer {
982 /// Attempt to replace `expr` itself with another expression.
983 /// Returns true if it does so.
984 fn replace(&self, expr: &mut MirScalarExpr) -> bool;
985 /// Attempt to replace any subexpressions of `expr` with other expressions.
986 /// Returns true if it does so.
987 fn reduce_expr(&self, expr: &mut MirScalarExpr) -> bool {
988 let mut simplified = false;
989 simplified = simplified || self.reduce_child(expr);
990 simplified = simplified || self.replace(expr);
991 simplified
992 }
993 /// Attempt to replace any subexpressions of `expr`'s children with other expressions.
994 /// Returns true if it does so.
995 fn reduce_child(&self, expr: &mut MirScalarExpr) -> bool {
996 let mut simplified = false;
997 match expr {
998 MirScalarExpr::CallBinary { expr1, expr2, .. } => {
999 simplified = self.reduce_expr(expr1) || simplified;
1000 simplified = self.reduce_expr(expr2) || simplified;
1001 }
1002 MirScalarExpr::CallUnary { expr, .. } => {
1003 simplified = self.reduce_expr(expr) || simplified;
1004 }
1005 MirScalarExpr::CallVariadic { exprs, .. } => {
1006 for expr in exprs.iter_mut() {
1007 simplified = self.reduce_expr(expr) || simplified;
1008 }
1009 }
1010 MirScalarExpr::If { cond: _, then, els } => {
1011 // Do not simplify `cond`, as we cannot ensure the simplification
1012 // continues to hold as expressions migrate around.
1013 simplified = self.reduce_expr(then) || simplified;
1014 simplified = self.reduce_expr(els) || simplified;
1015 }
1016 _ => {}
1017 }
1018 simplified
1019 }
1020}
1021
1022impl ExpressionReducer for BTreeMap<&MirScalarExpr, &MirScalarExpr> {
1023 /// Perform any exact replacement for `expr`, report if it had an effect.
1024 fn replace(&self, expr: &mut MirScalarExpr) -> bool {
1025 if let Some(other) = self.get(expr) {
1026 if other != &expr {
1027 expr.clone_from(other);
1028 return true;
1029 }
1030 }
1031 false
1032 }
1033}
1034
1035impl ExpressionReducer for BTreeMap<MirScalarExpr, MirScalarExpr> {
1036 /// Perform any exact replacement for `expr`, report if it had an effect.
1037 fn replace(&self, expr: &mut MirScalarExpr) -> bool {
1038 if let Some(other) = self.get(expr) {
1039 if other != expr {
1040 expr.clone_from(other);
1041 return true;
1042 }
1043 }
1044 false
1045 }
1046}
1047
/// Disjoint-set ("union-find") operations over elements of type `T`.
trait UnionFind<T> {
    /// Ensures that `x` and `y` end up with the same root.
    fn union(&mut self, x: &T, y: &T);
    /// Returns the root reached from `x`, or `None` if `x` is unknown.
    ///
    /// As a side effect, `self[x]` is updated to point directly at that root.
    fn find<'a>(&'a mut self, x: &T) -> Option<&'a T>;
}
1054
1055impl<T: Clone + Ord> UnionFind<T> for BTreeMap<T, T> {
1056 fn find<'a>(&'a mut self, x: &T) -> Option<&'a T> {
1057 if !self.contains_key(x) {
1058 None
1059 } else {
1060 if self[x] != self[&self[x]] {
1061 // Path halving
1062 let mut y = self[x].clone();
1063 while y != self[&y] {
1064 let grandparent = self[&self[&y]].clone();
1065 *self.get_mut(&y).unwrap() = grandparent;
1066 y.clone_from(&self[&y]);
1067 }
1068 *self.get_mut(x).unwrap() = y;
1069 }
1070 Some(&self[x])
1071 }
1072 }
1073
1074 fn union(&mut self, x: &T, y: &T) {
1075 match (self.find(x).is_some(), self.find(y).is_some()) {
1076 (true, true) => {
1077 if self[x] != self[y] {
1078 let root_x = self[x].clone();
1079 let root_y = self[y].clone();
1080 self.insert(root_x, root_y);
1081 }
1082 }
1083 (false, true) => {
1084 self.insert(x.clone(), self[y].clone());
1085 }
1086 (true, false) => {
1087 self.insert(y.clone(), self[x].clone());
1088 }
1089 (false, false) => {
1090 self.insert(x.clone(), x.clone());
1091 self.insert(y.clone(), x.clone());
1092 }
1093 }
1094 }
1095}
1096
1097/// True iff the aggregate function returns an input datum.
1098fn aggregate_is_input(aggregate: &AggregateFunc) -> bool {
1099 match aggregate {
1100 AggregateFunc::MaxInt16
1101 | AggregateFunc::MaxInt32
1102 | AggregateFunc::MaxInt64
1103 | AggregateFunc::MaxUInt16
1104 | AggregateFunc::MaxUInt32
1105 | AggregateFunc::MaxUInt64
1106 | AggregateFunc::MaxMzTimestamp
1107 | AggregateFunc::MaxFloat32
1108 | AggregateFunc::MaxFloat64
1109 | AggregateFunc::MaxBool
1110 | AggregateFunc::MaxString
1111 | AggregateFunc::MaxDate
1112 | AggregateFunc::MaxTimestamp
1113 | AggregateFunc::MaxTimestampTz
1114 | AggregateFunc::MinInt16
1115 | AggregateFunc::MinInt32
1116 | AggregateFunc::MinInt64
1117 | AggregateFunc::MinUInt16
1118 | AggregateFunc::MinUInt32
1119 | AggregateFunc::MinUInt64
1120 | AggregateFunc::MinMzTimestamp
1121 | AggregateFunc::MinFloat32
1122 | AggregateFunc::MinFloat64
1123 | AggregateFunc::MinBool
1124 | AggregateFunc::MinString
1125 | AggregateFunc::MinDate
1126 | AggregateFunc::MinTimestamp
1127 | AggregateFunc::MinTimestampTz
1128 | AggregateFunc::Any
1129 | AggregateFunc::All => true,
1130 _ => false,
1131 }
1132}