mz_transform/analysis/equivalences.rs
1// Copyright Materialize, Inc. and contributors. All rights reserved.
2//
3// Use of this software is governed by the Business Source License
4// included in the LICENSE file.
5//
6// As of the Change Date specified in that file, in accordance with
7// the Business Source License, use of this software will be governed
8// by the Apache License, Version 2.0.
9
10//! An analysis that reports all known-equivalent expressions for each relation.
11//!
12//! Expressions are equivalent at a relation if they are certain to evaluate to
13//! the same `Datum` for all records in the relation.
14//!
//! Equivalences are recorded in an `EquivalenceClasses`, which lists all known
//! equivalence classes, each a list of equivalent expressions.
17
18use std::collections::BTreeMap;
19use std::fmt::Formatter;
20
21use itertools::Itertools;
22use mz_expr::canonicalize::{UnionFind, canonicalize_equivalence_classes};
23use mz_expr::explain::{HumanizedExplain, HumanizerMode};
24use mz_expr::{AggregateFunc, Id, MirRelationExpr, MirScalarExpr};
25use mz_ore::str::{bracketed, separated};
26use mz_repr::{ColumnType, Datum};
27
28use crate::analysis::{Analysis, Lattice};
29use crate::analysis::{Arity, RelationType};
30use crate::analysis::{Derived, DerivedBuilder};
31
/// Pulls up and pushes down predicate information represented as equivalences.
///
/// This is the [`Analysis`] that derives, for each relation expression, the classes of scalar
/// expressions known to be equivalent over its records (see [`EquivalenceClasses`]). Facts such
/// as `Filter` predicates, `Join` equivalences, and non-nullable columns are all recorded in
/// this form.
#[derive(Debug, Default)]
pub struct Equivalences;
35
36impl Analysis for Equivalences {
37 // A `Some(list)` indicates a list of classes of equivalent expressions.
38 // A `None` indicates all expressions are equivalent, including contradictions;
39 // this is only possible for the empty collection, and as an initial result for
40 // unconstrained recursive terms.
41 type Value = Option<EquivalenceClasses>;
42
43 fn announce_dependencies(builder: &mut DerivedBuilder) {
44 builder.require(Arity);
45 builder.require(RelationType); // needed for expression reduction.
46 }
47
48 fn derive(
49 &self,
50 expr: &MirRelationExpr,
51 index: usize,
52 results: &[Self::Value],
53 depends: &Derived,
54 ) -> Self::Value {
55 let mut equivalences = match expr {
56 MirRelationExpr::Constant { rows, typ } => {
57 // Trawl `rows` for any constant information worth recording.
58 // Literal columns may be valuable; non-nullability could be too.
59 let mut equivalences = EquivalenceClasses::default();
60 if let Ok([(row, _cnt), rows @ ..]) = rows.as_deref() {
61 // Vector of `Option<Datum>` which becomes `None` once a column has a second datum.
62 let len = row.iter().count();
63 let mut common = Vec::with_capacity(len);
64 common.extend(row.iter().map(Some));
65 // Prep initial nullability information.
66 let mut nullable_cols = common
67 .iter()
68 .map(|datum| datum == &Some(Datum::Null))
69 .collect::<Vec<_>>();
70
71 for (row, _cnt) in rows.iter() {
72 for ((datum, common), nullable) in row
73 .iter()
74 .zip_eq(common.iter_mut())
75 .zip_eq(nullable_cols.iter_mut())
76 {
77 if Some(datum) != *common {
78 *common = None;
79 }
80 if datum == Datum::Null {
81 *nullable = true;
82 }
83 }
84 }
85 for (index, common) in common.into_iter().enumerate() {
86 if let Some(datum) = common {
87 equivalences.classes.push(vec![
88 MirScalarExpr::column(index),
89 MirScalarExpr::literal_ok(
90 datum,
91 typ.column_types[index].scalar_type.clone(),
92 ),
93 ]);
94 }
95 }
96 // If any columns are non-null, introduce this fact.
97 if nullable_cols.iter().any(|x| !*x) {
98 let mut class = vec![MirScalarExpr::literal_false()];
99 for (index, nullable) in nullable_cols.iter().enumerate() {
100 if !*nullable {
101 class.push(MirScalarExpr::column(index).call_is_null());
102 }
103 }
104 equivalences.classes.push(class);
105 }
106 }
107 Some(equivalences)
108 }
109 MirRelationExpr::Get { id, typ, .. } => {
110 let mut equivalences = Some(EquivalenceClasses::default());
111 // Find local identifiers, but nothing for external identifiers.
112 if let Id::Local(id) = id {
113 if let Some(offset) = depends.bindings().get(id) {
114 // It is possible we have derived nothing for a recursive term
115 if let Some(result) = results.get(*offset) {
116 equivalences.clone_from(result);
117 } else {
118 // No top element was prepared.
119 // This means we are executing pessimistically,
120 // but perhaps we must because optimism is off.
121 }
122 }
123 }
124 // Incorporate statements about column nullability.
125 let mut non_null_cols = vec![MirScalarExpr::literal_false()];
126 for (index, col_type) in typ.column_types.iter().enumerate() {
127 if !col_type.nullable {
128 non_null_cols.push(MirScalarExpr::column(index).call_is_null());
129 }
130 }
131 if non_null_cols.len() > 1 {
132 if let Some(equivalences) = equivalences.as_mut() {
133 equivalences.classes.push(non_null_cols);
134 }
135 }
136
137 equivalences
138 }
139 MirRelationExpr::Let { .. } => results.get(index - 1).unwrap().clone(),
140 MirRelationExpr::LetRec { .. } => results.get(index - 1).unwrap().clone(),
141 MirRelationExpr::Project { outputs, .. } => {
142 // restrict equivalences, and introduce equivalences for repeated outputs.
143 let mut equivalences = results.get(index - 1).unwrap().clone();
144 equivalences
145 .as_mut()
146 .map(|e| e.project(outputs.iter().cloned()));
147 equivalences
148 }
149 MirRelationExpr::Map { scalars, .. } => {
150 // introduce equivalences for new columns and expressions that define them.
151 let mut equivalences = results.get(index - 1).unwrap().clone();
152 if let Some(equivalences) = &mut equivalences {
153 let input_arity = depends.results::<Arity>()[index - 1];
154 for (pos, expr) in scalars.iter().enumerate() {
155 equivalences
156 .classes
157 .push(vec![MirScalarExpr::column(input_arity + pos), expr.clone()]);
158 }
159 }
160 equivalences
161 }
162 MirRelationExpr::FlatMap { .. } => results.get(index - 1).unwrap().clone(),
163 MirRelationExpr::Filter { predicates, .. } => {
164 let mut equivalences = results.get(index - 1).unwrap().clone();
165 if let Some(equivalences) = &mut equivalences {
166 let mut class = predicates.clone();
167 class.push(MirScalarExpr::literal_ok(
168 Datum::True,
169 mz_repr::ScalarType::Bool,
170 ));
171 equivalences.classes.push(class);
172 }
173 equivalences
174 }
175 MirRelationExpr::Join { equivalences, .. } => {
176 // Collect equivalences from all inputs;
177 let expr_index = index;
178 let mut children = depends
179 .children_of_rev(expr_index, expr.children().count())
180 .collect::<Vec<_>>();
181 children.reverse();
182
183 let arity = depends.results::<Arity>();
184 let mut columns = 0;
185 let mut result = Some(EquivalenceClasses::default());
186 for child in children.into_iter() {
187 let input_arity = arity[child];
188 let equivalences = results[child].clone();
189 if let Some(mut equivalences) = equivalences {
190 let permutation = (columns..(columns + input_arity)).collect::<Vec<_>>();
191 equivalences.permute(&permutation);
192 result
193 .as_mut()
194 .map(|e| e.classes.extend(equivalences.classes));
195 } else {
196 result = None;
197 }
198 columns += input_arity;
199 }
200
201 // Fold join equivalences into our results.
202 result
203 .as_mut()
204 .map(|e| e.classes.extend(equivalences.iter().cloned()));
205 result
206 }
207 MirRelationExpr::Reduce {
208 group_key,
209 aggregates,
210 ..
211 } => {
212 let input_arity = depends.results::<Arity>()[index - 1];
213 let mut equivalences = results.get(index - 1).unwrap().clone();
214 if let Some(equivalences) = &mut equivalences {
215 // Introduce keys column equivalences as if a map, then project to those columns.
216 // This should retain as much information as possible about these columns.
217 for (pos, expr) in group_key.iter().enumerate() {
218 equivalences
219 .classes
220 .push(vec![MirScalarExpr::column(input_arity + pos), expr.clone()]);
221 }
222
223 // Having added classes to `equivalences`, we should minimize the classes to fold the
224 // information in before applying the `project`, to set it up for success.
225 equivalences.minimize(None);
226
227 // Grab a copy of the equivalences with key columns added to use in aggregate reasoning.
228 let extended = equivalences.clone();
229 // Now project down the equivalences, as we will extend them in terms of the output columns.
230 equivalences.project(input_arity..(input_arity + group_key.len()));
231
232 // TODO: MIN, MAX, ANY, ALL aggregates pass through all certain properties of their columns.
233 // They also pass through equivalences of them and other constant columns (e.g. key columns).
234 // However, it is not correct to simply project onto these columns, as relationships amongst
235 // aggregate columns may no longer be preserved. MAX(col) != MIN(col) even though col = col.
236 // The correct thing to do is treat the reduce as a join between single-aggregate reductions,
237 // where each single MIN/MAX/ANY/ALL aggregate propagates equivalences.
238 for (index, aggregate) in aggregates.iter().enumerate() {
239 if aggregate_is_input(&aggregate.func) {
240 let mut temp_equivs = extended.clone();
241 temp_equivs.classes.push(vec![
242 MirScalarExpr::column(input_arity + group_key.len()),
243 aggregate.expr.clone(),
244 ]);
245 temp_equivs.minimize(None);
246 temp_equivs.project(input_arity..(input_arity + group_key.len() + 1));
247 let columns = (0..group_key.len())
248 .chain(std::iter::once(group_key.len() + index))
249 .collect::<Vec<_>>();
250 temp_equivs.permute(&columns[..]);
251 equivalences.classes.extend(temp_equivs.classes);
252 }
253 }
254 }
255 equivalences
256 }
257 MirRelationExpr::TopK { .. } => results.get(index - 1).unwrap().clone(),
258 MirRelationExpr::Negate { .. } => results.get(index - 1).unwrap().clone(),
259 MirRelationExpr::Threshold { .. } => results.get(index - 1).unwrap().clone(),
260 MirRelationExpr::Union { .. } => {
261 let expr_index = index;
262 let mut child_equivs = depends
263 .children_of_rev(expr_index, expr.children().count())
264 .flat_map(|c| &results[c]);
265 if let Some(first) = child_equivs.next() {
266 Some(first.union_many(child_equivs))
267 } else {
268 None
269 }
270 }
271 MirRelationExpr::ArrangeBy { .. } => results.get(index - 1).unwrap().clone(),
272 };
273
274 let expr_type = depends.results::<RelationType>()[index].as_ref();
275 equivalences
276 .as_mut()
277 .map(|e| e.minimize(expr_type.map(|x| &x[..])));
278 equivalences
279 }
280
281 fn lattice() -> Option<Box<dyn Lattice<Self::Value>>> {
282 Some(Box::new(EQLattice))
283 }
284}
285
286struct EQLattice;
287
288impl Lattice<Option<EquivalenceClasses>> for EQLattice {
289 fn top(&self) -> Option<EquivalenceClasses> {
290 None
291 }
292
293 fn meet_assign(
294 &self,
295 a: &mut Option<EquivalenceClasses>,
296 b: Option<EquivalenceClasses>,
297 ) -> bool {
298 match (&mut *a, b) {
299 (_, None) => false,
300 (None, b) => {
301 *a = b;
302 true
303 }
304 (Some(a), Some(b)) => {
305 let mut c = a.union(&b);
306 std::mem::swap(a, &mut c);
307 a != &mut c
308 }
309 }
310 }
311}
312
/// A compact representation of classes of expressions that must be equivalent.
///
/// Each "class" contains a list of expressions, each of which must be `Eq::eq` equal.
/// Ideally, the first element is the "simplest", e.g. a literal or column reference,
/// and any other element of that list can be replaced by it.
///
/// The classes are meant to be minimized, with each expression as reduced as it can be,
/// and all classes sharing an element merged.
///
/// Note: the derived `Ord`/`PartialOrd` compare `classes` before `remap`, so field order matters.
#[derive(Clone, Eq, PartialEq, Ord, PartialOrd, Default, Debug)]
pub struct EquivalenceClasses {
    /// Multiple lists of equivalent expressions, each representing an equivalence class.
    ///
    /// The first element should be the "canonical" simplest element, that any other element
    /// can be replaced by.
    /// These classes are unified whenever possible, to minimize the number of classes.
    /// They are only guaranteed to form an equivalence relation after a call to `minimize`,
    /// which refreshes both `self.classes` and `self.remap`.
    pub classes: Vec<Vec<MirScalarExpr>>,

    /// An expression simplification map.
    ///
    /// This map reflects an equivalence relation based on a prior version of `self.classes`.
    /// As users may add to `self.classes`, `self.remap` may become stale. We refresh `remap`
    /// only in `self.refresh()`, to the equivalence relation that derives from `self.classes`.
    ///
    /// It is important to `self.remap.clear()` if you invalidate it by mutating rather than
    /// appending to `self.classes`. This will be corrected in the next call to `self.refresh()`,
    /// but until then `remap` could be arbitrarily wrong. This should be improved in the future.
    remap: BTreeMap<MirScalarExpr, MirScalarExpr>,
}
343
/// An enum that can represent either a standard `EquivalenceClasses` or a
/// `EquivalenceClassesWithholdingErrors`.
///
/// Callers can use one API (extend, push, project, extract) without knowing whether
/// error-containing expressions are being withheld.
#[derive(Clone, Debug)]
pub enum EqClassesImpl {
    /// Standard equivalence classes.
    EquivalenceClasses(EquivalenceClasses),
    /// Equivalence classes that withhold errors, reintroducing them when extracting equivalences.
    EquivalenceClassesWithholdingErrors(EquivalenceClassesWithholdingErrors),
}
353
354impl EqClassesImpl {
355 /// Returns a reference to the underlying `EquivalenceClasses`.
356 pub fn equivalence_classes(&self) -> &EquivalenceClasses {
357 match self {
358 EqClassesImpl::EquivalenceClasses(classes) => classes,
359 EqClassesImpl::EquivalenceClassesWithholdingErrors(classes) => {
360 &classes.equivalence_classes
361 }
362 }
363 }
364
365 /// Returns a mutable reference to the underlying `EquivalenceClasses`.
366 pub fn equivalence_classes_mut(&mut self) -> &mut EquivalenceClasses {
367 match self {
368 EqClassesImpl::EquivalenceClasses(classes) => classes,
369 EqClassesImpl::EquivalenceClassesWithholdingErrors(classes) => {
370 &mut classes.equivalence_classes
371 }
372 }
373 }
374
375 /// Extend the equivalence classes with new equivalences.
376 pub fn extend_equivalences(&mut self, equivalences: Vec<Vec<MirScalarExpr>>) {
377 match self {
378 EqClassesImpl::EquivalenceClasses(classes) => {
379 classes.classes.extend(equivalences);
380 }
381 EqClassesImpl::EquivalenceClassesWithholdingErrors(classes) => {
382 classes.extend_equivalences(equivalences);
383 }
384 }
385 }
386
387 /// Push a new equivalence class.
388 pub fn push_equivalences(&mut self, equivalences: Vec<MirScalarExpr>) {
389 match self {
390 EqClassesImpl::EquivalenceClasses(classes) => {
391 classes.classes.push(equivalences);
392 }
393 EqClassesImpl::EquivalenceClassesWithholdingErrors(classes) => {
394 classes.push_equivalences(equivalences);
395 }
396 }
397 }
398
399 /// Subject the constraints to the column projection, reworking and removing equivalences.
400 pub fn project<I>(&mut self, output_columns: I)
401 where
402 I: IntoIterator<Item = usize> + Clone,
403 {
404 match self {
405 EqClassesImpl::EquivalenceClasses(classes) => {
406 classes.project(output_columns);
407 }
408 EqClassesImpl::EquivalenceClassesWithholdingErrors(classes) => {
409 classes.project(output_columns);
410 }
411 }
412 }
413
414 /// Extract the equivalences
415 pub fn extract_equivalences(
416 &mut self,
417 columns: Option<&[ColumnType]>,
418 ) -> Vec<Vec<MirScalarExpr>> {
419 match self {
420 EqClassesImpl::EquivalenceClasses(classes) => classes.classes.clone(),
421 EqClassesImpl::EquivalenceClassesWithholdingErrors(classes) => {
422 classes.extract_equivalences(columns)
423 }
424 }
425 }
426}
427
/// A wrapper struct for equivalence classes that withholds errors
/// from the underlying equivalence classes, reintroducing them
/// when extracting equivalences.
#[derive(Clone, Eq, PartialEq, Ord, PartialOrd, Default, Debug)]
pub struct EquivalenceClassesWithholdingErrors {
    /// The underlying equivalence classes.
    pub equivalence_classes: EquivalenceClasses,
    /// A list of equivalence classes that contain errors, which we hold back.
    ///
    /// Each entry pairs an optional error-free "anchor" expression from the original class
    /// (used to reattach the withheld expressions on extraction) with the withheld
    /// error-containing expressions.
    pub held_back_classes: Vec<(Option<MirScalarExpr>, Vec<MirScalarExpr>)>,
}
438
439impl EquivalenceClassesWithholdingErrors {
440 /// Extend the equivalence classes with new equivalences.
441 pub fn extend_equivalences(&mut self, equivalences: Vec<Vec<MirScalarExpr>>) {
442 for class in equivalences {
443 self.push_equivalences(class);
444 }
445 }
446
447 /// Push a new equivalence class, which may contain errors.
448 pub fn push_equivalences(&mut self, equivalences: Vec<MirScalarExpr>) {
449 let (with_errs, without_errs): (Vec<_>, Vec<_>) =
450 equivalences.into_iter().partition(|e| e.contains_err());
451 self.held_back_classes
452 .push((without_errs.first().cloned(), with_errs));
453 if without_errs.len() > 1 {
454 self.equivalence_classes.classes.push(without_errs);
455 }
456 }
457
458 /// Subject the constraints to the column projection, reworking and removing equivalences.
459 /// This method should also introduce equivalences representing any repeated columns.
460 /// This notably does not project the held back classes.
461 pub fn project<I>(&mut self, output_columns: I)
462 where
463 I: IntoIterator<Item = usize> + Clone,
464 {
465 self.equivalence_classes.project(output_columns.clone());
466 }
467
468 /// Minimize the equivalence classes, and return the result, reintroducing any held back classes.
469 pub fn extract_equivalences(
470 &mut self,
471 columns: Option<&[ColumnType]>,
472 ) -> Vec<Vec<MirScalarExpr>> {
473 self.equivalence_classes.minimize(columns);
474
475 if self.held_back_classes.is_empty() {
476 self.equivalence_classes.classes.clone()
477 } else {
478 let reducer = self.equivalence_classes.reducer();
479 let mut classes = self.equivalence_classes.classes.clone();
480 for (anchor, mut class) in self.held_back_classes.drain(..) {
481 if let Some(mut anchor) = anchor {
482 // Reduce the anchor expression to its canonical form.
483 reducer.reduce_expr(&mut anchor);
484 class.push(anchor.clone());
485 }
486 classes.push(class.clone());
487 }
488 canonicalize_equivalence_classes(&mut classes);
489 classes
490 }
491 }
492}
493
494/// Raw printing of [`EquivalenceClasses`] with default expression humanization.
495/// Don't use this in `EXPLAIN`! For redaction, column name support, etc., see
496/// [`HumanizedEquivalenceClasses`].
497impl std::fmt::Display for EquivalenceClasses {
498 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
499 HumanizedEquivalenceClasses {
500 equivalence_classes: self,
501 cols: None,
502 mode: HumanizedExplain::default(),
503 }
504 .fmt(f)
505 }
506}
507
/// Wrapper struct for human-readable printing of expressions inside [`EquivalenceClasses`].
/// (Similar to `HumanizedExpr`. Unfortunately, we can't just use `HumanizedExpr` here, because
/// we'd need to `impl Display for HumanizedExpr<'a, EquivalenceClasses, M>`, but neither
/// `Display` nor `HumanizedExpr` is defined in this crate.)
#[derive(Debug)]
pub struct HumanizedEquivalenceClasses<'a, M = HumanizedExplain> {
    /// The [`EquivalenceClasses`] to be humanized.
    pub equivalence_classes: &'a EquivalenceClasses,
    /// An optional vector of inferred column names to be used when rendering
    /// column references in expressions.
    pub cols: Option<&'a Vec<String>>,
    /// The rendering mode to use. See `HumanizerMode` for details.
    pub mode: M,
}
522
523impl std::fmt::Display for HumanizedEquivalenceClasses<'_> {
524 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
525 // Only show `classes`.
526 // (The following hopefully avoids allocating any of the intermediate composite strings.)
527 let classes = self.equivalence_classes.classes.iter().map(|class| {
528 format!(
529 "{}",
530 bracketed(
531 "[",
532 "]",
533 separated(
534 ", ",
535 class.iter().map(|expr| self.mode.expr(expr, self.cols))
536 )
537 )
538 )
539 });
540 write!(f, "{}", bracketed("[", "]", separated(", ", classes)))
541 }
542}
543
544impl EquivalenceClasses {
545 /// Comparator function for the complexity of scalar expressions. Simpler expressions are
546 /// smaller. Can be used when we need to decide which of several equivalent expressions to use.
547 pub fn mir_scalar_expr_complexity(
548 e1: &MirScalarExpr,
549 e2: &MirScalarExpr,
550 ) -> std::cmp::Ordering {
551 use MirScalarExpr::*;
552 use std::cmp::Ordering::*;
553 match (e1, e2) {
554 (Literal(_, _), Literal(_, _)) => e1.cmp(e2),
555 (Literal(_, _), _) => Less,
556 (_, Literal(_, _)) => Greater,
557 (Column(_, _), Column(_, _)) => e1.cmp(e2),
558 (Column(_, _), _) => Less,
559 (_, Column(_, _)) => Greater,
560 (x, y) => {
561 // General expressions should be ordered by their size,
562 // to ensure we only simplify expressions by substitution.
563 // If same size, then fall back to the expressions' Ord.
564 match x.size().cmp(&y.size()) {
565 Equal => x.cmp(y),
566 other => other,
567 }
568 }
569 }
570 }
571
572 /// Sorts and deduplicates each class, removing literal errors.
573 ///
574 /// This method does not ensure equivalence relation structure, but instead performs
575 /// only minimal structural clean-up.
576 fn tidy(&mut self) {
577 for class in self.classes.iter_mut() {
578 // Remove all literal errors, as they cannot be equated to other things.
579 class.retain(|e| !e.is_literal_err());
580 class.sort_by(Self::mir_scalar_expr_complexity);
581 class.dedup();
582 }
583 self.classes.retain(|c| c.len() > 1);
584 self.classes.sort();
585 self.classes.dedup();
586 }
587
    /// Restore equivalence relation structure to `self.classes` and refresh `self.remap`.
    ///
    /// This method takes roughly linear time, and returns true iff `self.remap` has changed.
    /// This is the only method that refreshes `self.remap`, and is a perfect place to decide
    /// whether the equivalence classes it represents have experienced any changes since the
    /// last refresh.
    ///
    /// Classes that share an expression are merged and re-tidied. `self.remap` is then
    /// replaced by the map built here.
    fn refresh(&mut self) -> bool {
        // Normalize each class first: simplest member first, no duplicates, no literal errors.
        self.tidy();

        // remap may already be the correct answer, and if so we should avoid the work of rebuilding it.
        // If it contains the same number of expressions as `self.classes`, and for every expression in
        // `self.classes` the two agree on the representative, they are identical.
        if self.remap.len() == self.classes.iter().map(|c| c.len()).sum::<usize>()
            && self
                .classes
                .iter()
                .all(|c| c.iter().all(|e| self.remap.get(e) == Some(&c[0])))
        {
            // No change, so return false.
            return false;
        }

        // Optimistically build the `remap` we would want.
        // Note if any unions would be required, in which case we have further work to do,
        // including re-forming `self.classes`.
        let mut union_find = BTreeMap::default();
        let mut dirtied = false;
        for class in self.classes.iter() {
            for expr in class.iter() {
                // `insert` returns the previous entry when `expr` already appeared in an earlier
                // class; such a shared expression means the two classes must be merged.
                if let Some(other) = union_find.insert(expr.clone(), class[0].clone()) {
                    // A merge is required, but have the more complex expression point at the simpler one.
                    // This allows `union_find` to end as the `remap` for the new `classes` we form, with
                    // the only required work being compressing all the paths.
                    if Self::mir_scalar_expr_complexity(&other, &class[0])
                        == std::cmp::Ordering::Less
                    {
                        union_find.union(&class[0], &other);
                    } else {
                        union_find.union(&other, &class[0]);
                    }
                    dirtied = true;
                }
            }
        }
        if dirtied {
            // Regroup every expression under the root `union_find` reports for it, then re-tidy
            // so each merged class is sorted and deduplicated again.
            let mut classes: BTreeMap<_, Vec<_>> = BTreeMap::default();
            for class in self.classes.drain(..) {
                for expr in class {
                    let root: MirScalarExpr = union_find.find(&expr).unwrap().clone();
                    classes.entry(root).or_default().push(expr);
                }
            }
            self.classes = classes.into_values().collect();
            self.tidy();
        }

        let changed = self.remap != union_find;
        self.remap = union_find;
        changed
    }
648
649 /// Update `self` to maintain the same equivalences which potentially reducing along `Ord::le`.
650 ///
651 /// Informally this means simplifying constraints, removing redundant constraints, and unifying equivalence classes.
652 pub fn minimize(&mut self, columns: Option<&[ColumnType]>) {
653 // Repeatedly, we reduce each of the classes themselves, then unify the classes.
654 // This should strictly reduce complexity, and reach a fixed point.
655 // Ideally it is *confluent*, arriving at the same fixed point no matter the order of operations.
656
657 // We should not rely on nullability information present in `column_types`. (Doing this
658 // every time just before calling `reduce` was found to be a bottleneck during incident-217,
659 // so now we do this nullability tweaking only once here.)
660 let mut columns = columns.map(|x| x.to_vec());
661 let mut nonnull = Vec::new();
662 if let Some(columns) = columns.as_mut() {
663 for (index, col) in columns.iter_mut().enumerate() {
664 let is_null = MirScalarExpr::column(index).call_is_null();
665 if !col.nullable
666 && self
667 .remap
668 .get(&is_null)
669 .map(|e| !e.is_literal_false())
670 .unwrap_or(true)
671 {
672 nonnull.push(is_null);
673 }
674 col.nullable = true;
675 }
676 }
677 if !nonnull.is_empty() {
678 nonnull.push(MirScalarExpr::literal_false());
679 self.classes.push(nonnull);
680 }
681
682 // Ensure `self.classes` and `self.remap` are equivalence relations.
683 // Users are allowed to mutate `self.classes`, so we must perform this normalization at least once.
684 // We have also likely mutated `self.classes` just above with non-nullability information.
685 self.refresh();
686
687 // Termination will be detected by comparing to the map of equivalence classes.
688 let mut previous = Some(self.remap.clone());
689 while let Some(prev) = previous {
690 // Attempt to add new equivalences.
691 let novel = self.expand();
692 if !novel.is_empty() {
693 self.classes.extend(novel);
694 self.refresh();
695 }
696
697 // We continue as long as any simplification has occurred.
698 // An expression can be simplified, a duplication found, or two classes unified.
699 let mut stable = false;
700 while !stable {
701 stable = !self.minimize_once(columns.as_ref().map(|x| &x[..]));
702 }
703
704 // Termination detection.
705 if prev != self.remap {
706 previous = Some(self.remap.clone());
707 } else {
708 previous = None;
709 }
710 }
711 }
712
713 /// Proposes new equivalences that are likely to be novel.
714 ///
715 /// This method invokes `self.implications()` to propose equivalences, and then judges them to be
716 /// novel or not based on existing knowledge, reducing the equivalences down to their novel core.
717 /// This method may produce non-novel equivalences, due to its inability to perform `MSE::reduce`.
718 /// We can end up with e.g. constant expressions that cannot be found until they are so reduced.
719 /// The novelty detection is best-effort, and meant to provide a clearer signal and minimize the
720 /// number of times we call and amount of work we do in `self.refresh()`.
721 fn expand(&self) -> Vec<Vec<MirScalarExpr>> {
722 // Consider expanding `self.classes` with novel equivalences.
723 let mut novel = self.implications();
724 for class in novel.iter_mut() {
725 // reduce each expression to its canonical form.
726 for expr in class.iter_mut() {
727 self.remap.reduce_expr(expr);
728 }
729 class.sort();
730 class.dedup();
731 // for a class to be interesting we require at least two elements that do not reference the same root.
732 let common_class = class
733 .iter()
734 .map(|x| self.remap.get(x))
735 .reduce(|prev, this| if prev == this { prev } else { None });
736 if class.len() == 1 || common_class != Some(None) {
737 class.clear();
738 }
739 }
740 novel.retain(|c| !c.is_empty());
741 novel
742 }
743
744 /// Derives potentially novel equivalences without regard for minimization.
745 ///
746 /// This is an opportunity to explore equivalences that do not correspond to expression minimization,
747 /// and therefore should not be used in `minimize_once`. They are still potentially important, but
748 /// required additional guardrails to ensure we reach a fixed point.
749 ///
750 /// The implications will be introduced into `self.classes` and will prompt a round of minimization,
751 /// making it somewhat polite to avoid producing outputs that cannot result in novel equivalences.
752 /// For example, before producing a new equivalence, one could check that the involved terms are not
753 /// already present in the same class.
754 fn implications(&self) -> Vec<Vec<MirScalarExpr>> {
755 let mut new_equivalences = Vec::new();
756
757 // If we see `false == IsNull(foo)` we can add the non-null implications of `foo`.
758 let mut non_null = std::collections::BTreeSet::default();
759 for class in self.classes.iter() {
760 if Self::class_contains_literal(class, |e| e == &Ok(Datum::False)) {
761 for e in class.iter() {
762 if let MirScalarExpr::CallUnary {
763 func: mz_expr::UnaryFunc::IsNull(_),
764 expr,
765 } = e
766 {
767 expr.non_null_requirements(&mut non_null);
768 }
769 }
770 }
771 }
772 // If we see `true == foo` we can add the non-null implications of `foo`.
773 // TODO: generalize to arbitrary non-null, non-error literals; at the moment `true == pred` is
774 // an important idiom to identify for how we express predicates.
775 for class in self.classes.iter() {
776 if Self::class_contains_literal(class, |e| e == &Ok(Datum::True)) {
777 for expr in class.iter() {
778 expr.non_null_requirements(&mut non_null);
779 }
780 }
781 }
782 // Only keep constraints that are not already known.
783 // Known constraints will present as `COL(_) IS NULL == false`,
784 // which can only happen if `false` is present, and both terms
785 // map to the same canonical representative>
786 let lit_false = MirScalarExpr::literal_false();
787 let target = self.remap.get(&lit_false);
788 if target.is_some() {
789 non_null.retain(|c| {
790 let is_null = MirScalarExpr::column(*c).call_is_null();
791 self.remap.get(&is_null) != target
792 });
793 }
794 if !non_null.is_empty() {
795 let mut class = Vec::with_capacity(non_null.len() + 1);
796 class.push(MirScalarExpr::literal_false());
797 class.extend(
798 non_null
799 .into_iter()
800 .map(|c| MirScalarExpr::column(c).call_is_null()),
801 );
802 new_equivalences.push(class);
803 }
804
805 // If we see records formed from other expressions, we can equate the expressions with
806 // accessors applied to the class of the record former. In `minimize_once` we reduce by
807 // equivalence class representative before we perform expression simplification, so we
808 // shoud be able to just use the expression former, rather than find its representative.
809 // The risk, potentially, is that we would apply accessors to the record former and then
810 // just simplify it away learning nothing.
811 for class in self.classes.iter() {
812 for expr in class.iter() {
813 // Record-forming expressions can equate their accessors and their members.
814 if let MirScalarExpr::CallVariadic {
815 func: mz_expr::VariadicFunc::RecordCreate { .. },
816 exprs,
817 } = expr
818 {
819 for (index, e) in exprs.iter().enumerate() {
820 new_equivalences.push(vec![
821 e.clone(),
822 expr.clone().call_unary(mz_expr::UnaryFunc::RecordGet(
823 mz_expr::func::RecordGet(index),
824 )),
825 ]);
826 }
827 }
828 }
829 }
830
831 // Return all newly established equivalences.
832 new_equivalences
833 }
834
835 /// A single iteration of minimization, which we expect to repeat but benefit from factoring out.
836 ///
837 /// This invocation should take roughly linear time.
838 /// It starts with equivalence class invariants maintained (closed under transitivity), and then
839 /// 1. Performs per-expression reduction, including the class structure to replace subexpressions.
840 /// 2. Applies idiom detection to e.g. unpack expressions equivalence to literal true or false.
841 /// 3. Restores the equivalence class invariants.
842 fn minimize_once(&mut self, columns: Option<&[ColumnType]>) -> bool {
843 // 1. Reduce each expression
844 //
845 // This reduction first looks for subexpression substitutions that can be performed,
846 // and then applies expression reduction if column type information is provided.
847 for class in self.classes.iter_mut() {
848 for expr in class.iter_mut() {
849 self.remap.reduce_child(expr);
850 if let Some(columns) = columns {
851 let orig_expr = expr.clone();
852 expr.reduce(columns);
853 if expr.contains_err() {
854 // Rollback to the original expression if it contains an error.
855 *expr = orig_expr;
856 }
857 }
858 }
859 }
860
861 // 2. Identify idioms
862 // E.g. If Eq(x, y) must be true, we can introduce classes `[x, y]` and `[false, IsNull(x), IsNull(y)]`.
863 let mut to_add = Vec::new();
864 for class in self.classes.iter_mut() {
865 if Self::class_contains_literal(class, |e| e == &Ok(Datum::True)) {
866 for expr in class.iter() {
867 // If Eq(x, y) must be true, we can introduce classes `[x, y]` and `[false, IsNull(x), IsNull(y)]`.
868 // This substitution replaces a complex expression with several smaller expressions, and cannot
869 // cycle if we follow that practice.
870 if let MirScalarExpr::CallBinary {
871 func: mz_expr::BinaryFunc::Eq,
872 expr1,
873 expr2,
874 } = expr
875 {
876 to_add.push(vec![*expr1.clone(), *expr2.clone()]);
877 to_add.push(vec![
878 MirScalarExpr::literal_false(),
879 expr1.clone().call_is_null(),
880 expr2.clone().call_is_null(),
881 ]);
882 }
883 }
884 // Remove the more complex form of the expression.
885 class.retain(|expr| {
886 if let MirScalarExpr::CallBinary {
887 func: mz_expr::BinaryFunc::Eq,
888 ..
889 } = expr
890 {
891 false
892 } else {
893 true
894 }
895 });
896 for expr in class.iter() {
897 // If TRUE == NOT(X) then FALSE == X is a simpler form.
898 if let MirScalarExpr::CallUnary {
899 func: mz_expr::UnaryFunc::Not(_),
900 expr: e,
901 } = expr
902 {
903 to_add.push(vec![MirScalarExpr::literal_false(), (**e).clone()]);
904 }
905 }
906 class.retain(|expr| {
907 if let MirScalarExpr::CallUnary {
908 func: mz_expr::UnaryFunc::Not(_),
909 ..
910 } = expr
911 {
912 false
913 } else {
914 true
915 }
916 });
917 }
918 if Self::class_contains_literal(class, |e| e == &Ok(Datum::False)) {
919 for expr in class.iter() {
920 // If FALSE == NOT(X) then TRUE == X is a simpler form.
921 if let MirScalarExpr::CallUnary {
922 func: mz_expr::UnaryFunc::Not(_),
923 expr: e,
924 } = expr
925 {
926 to_add.push(vec![MirScalarExpr::literal_true(), (**e).clone()]);
927 }
928 }
929 class.retain(|expr| {
930 if let MirScalarExpr::CallUnary {
931 func: mz_expr::UnaryFunc::Not(_),
932 ..
933 } = expr
934 {
935 false
936 } else {
937 true
938 }
939 });
940 }
941 }
942 self.classes.extend(to_add);
943
944 // 3. Restore equivalence relation structure and observe if any changes result.
945 self.refresh()
946 }
947
948 /// Produce the equivalences present in both inputs.
949 pub fn union(&self, other: &Self) -> Self {
950 self.union_many([other])
951 }
952
953 /// The equivalence classes of terms equivalent in all inputs.
954 ///
955 /// This method relies on the `remap` member of each input, and bases the intersection on these rather than `classes`.
956 /// This means one should ensure `minimize()` has been called on all inputs, or risk getting a stale, but conservatively
957 /// correct, result.
958 ///
959 /// This method currently misses opportunities, because it only looks for exactly matches in expressions,
960 /// which may not include all possible matches. For example, `f(#1) == g(#1)` may exist in one class, but
961 /// in another class where `#0 == #1` it may exist as `f(#0) == g(#0)`.
962 pub fn union_many<'a, I>(&self, others: I) -> Self
963 where
964 I: IntoIterator<Item = &'a Self>,
965 {
966 // List of expressions in the intersection, and a proxy equivalence class identifier.
967 let mut intersection: Vec<(&MirScalarExpr, usize)> = Default::default();
968 // Map from expression to a proxy equivalence class identifier.
969 let mut rekey: BTreeMap<&MirScalarExpr, usize> = Default::default();
970 for (key, val) in self.remap.iter() {
971 if !rekey.contains_key(val) {
972 rekey.insert(val, rekey.len());
973 }
974 intersection.push((key, rekey[val]));
975 }
976 for other in others {
977 // Map from proxy equivalence class identifier and equivalence class expr to a new proxy identifier.
978 let mut rekey: BTreeMap<(usize, &MirScalarExpr), usize> = Default::default();
979 intersection.retain_mut(|(key, idx)| {
980 if let Some(val) = other.remap.get(key) {
981 if !rekey.contains_key(&(*idx, val)) {
982 rekey.insert((*idx, val), rekey.len());
983 }
984 *idx = rekey[&(*idx, val)];
985 true
986 } else {
987 false
988 }
989 });
990 }
991 let mut classes: BTreeMap<_, Vec<MirScalarExpr>> = Default::default();
992 for (key, vals) in intersection {
993 classes.entry(vals).or_default().push(key.clone())
994 }
995 let classes = classes.into_values().collect::<Vec<_>>();
996 let mut equivalences = EquivalenceClasses {
997 classes,
998 remap: Default::default(),
999 };
1000 equivalences.minimize(None);
1001 equivalences
1002 }
1003
1004 /// Permutes each expression, looking up each column reference in `permutation` and replacing with what it finds.
1005 pub fn permute(&mut self, permutation: &[usize]) {
1006 for class in self.classes.iter_mut() {
1007 for expr in class.iter_mut() {
1008 expr.permute(permutation);
1009 }
1010 }
1011 self.remap.clear();
1012 self.minimize(None);
1013 }
1014
1015 /// Subject the constraints to the column projection, reworking and removing equivalences.
1016 ///
1017 /// This method should also introduce equivalences representing any repeated columns.
1018 pub fn project<I>(&mut self, output_columns: I)
1019 where
1020 I: IntoIterator<Item = usize> + Clone,
1021 {
1022 // Retain the first instance of each column, and record subsequent instances as duplicates.
1023 let mut dupes = Vec::new();
1024 let mut remap = BTreeMap::default();
1025 for (idx, col) in output_columns.into_iter().enumerate() {
1026 if let Some(pos) = remap.get(&col) {
1027 dupes.push((*pos, idx));
1028 } else {
1029 remap.insert(col, idx);
1030 }
1031 }
1032
1033 // Some expressions may be "localized" in that they only reference columns in `output_columns`.
1034 // Many expressions may not be localized, but may reference canonical non-localized expressions
1035 // for classes that contain a localized expression; in that case we can "backport" the localized
1036 // expression to give expressions referencing the canonical expression a shot at localization.
1037 //
1038 // Expressions should only contain instances of canonical expressions, and so we shouldn't need
1039 // to look any further than backporting those. Backporting should have the property that the simplest
1040 // localized expression in each class does not contain any non-localized canonical expressions
1041 // (as that would make it non-localized); our backporting of non-localized canonicals with localized
1042 // expressions should never fire a second
1043
1044 // Let's say an expression is "localized" once we are able to rewrite its support in terms of `output_columns`.
1045 // Not all expressions can be localized, although some of them may be equivalent to localized expressions.
1046 // As we find localized expressions, we can replace uses of their equivalent representative with them,
1047 // which may allow further expression localization.
1048 // We continue the process until no further classes can be localized.
1049
1050 // A map from representatives to our first localization of their equivalence class.
1051 let mut localized = false;
1052 while !localized {
1053 localized = true;
1054 let mut current_map = BTreeMap::default();
1055 for class in self.classes.iter_mut() {
1056 if !class[0].support().iter().all(|c| remap.contains_key(c)) {
1057 if let Some(pos) = class
1058 .iter()
1059 .position(|e| e.support().iter().all(|c| remap.contains_key(c)))
1060 {
1061 class.swap(0, pos);
1062 localized = false;
1063 }
1064 }
1065 for expr in class[1..].iter() {
1066 current_map.insert(expr.clone(), class[0].clone());
1067 }
1068 }
1069
1070 // attempt to replace representatives with equivalent localizeable expressions.
1071 for class_index in 0..self.classes.len() {
1072 for index in 0..self.classes[class_index].len() {
1073 current_map.reduce_child(&mut self.classes[class_index][index]);
1074 }
1075 }
1076 // NB: Do *not* `self.minimize()`, as we are developing localizable rather than canonical representatives.
1077 }
1078
1079 // Localize all localizable expressions and discard others.
1080 for class in self.classes.iter_mut() {
1081 class.retain(|e| e.support().iter().all(|c| remap.contains_key(c)));
1082 for expr in class.iter_mut() {
1083 expr.permute_map(&remap);
1084 }
1085 }
1086 self.classes.retain(|c| c.len() > 1);
1087 // If column repetitions, introduce them as equivalences.
1088 // We introduce only the equivalence to the first occurrence, and rely on minimization to collect them.
1089 for (col1, col2) in dupes {
1090 self.classes.push(vec![
1091 MirScalarExpr::column(col1),
1092 MirScalarExpr::column(col2),
1093 ]);
1094 }
1095 self.remap.clear();
1096 self.minimize(None);
1097 }
1098
1099 /// True if any equivalence class contains two distinct non-error literals.
1100 pub fn unsatisfiable(&self) -> bool {
1101 for class in self.classes.iter() {
1102 let mut literal_ok = None;
1103 for expr in class.iter() {
1104 if let MirScalarExpr::Literal(Ok(row), _) = expr {
1105 if literal_ok.is_some() && literal_ok != Some(row) {
1106 return true;
1107 } else {
1108 literal_ok = Some(row);
1109 }
1110 }
1111 }
1112 }
1113 false
1114 }
1115
1116 /// Returns a map that can be used to replace (sub-)expressions.
1117 pub fn reducer(&self) -> &BTreeMap<MirScalarExpr, MirScalarExpr> {
1118 &self.remap
1119 }
1120
1121 /// Examines the prefix of `class` of literals, looking for any satisfying `predicate`.
1122 ///
1123 /// This test bails out as soon as it sees a non-literal, and may have false negatives
1124 /// if the data are not sorted with literals at the front.
1125 fn class_contains_literal<P>(class: &[MirScalarExpr], mut predicate: P) -> bool
1126 where
1127 P: FnMut(&Result<Datum, &mz_expr::EvalError>) -> bool,
1128 {
1129 class
1130 .iter()
1131 .take_while(|e| e.is_literal())
1132 .filter_map(|e| e.as_literal())
1133 .any(move |e| predicate(&e))
1134 }
1135}
1136
1137/// A type capable of simplifying `MirScalarExpr`s.
1138pub trait ExpressionReducer {
1139 /// Attempt to replace `expr` itself with another expression.
1140 /// Returns true if it does so.
1141 fn replace(&self, expr: &mut MirScalarExpr) -> bool;
1142 /// Attempt to replace any subexpressions of `expr` with other expressions.
1143 /// Returns true if it does so.
1144 fn reduce_expr(&self, expr: &mut MirScalarExpr) -> bool {
1145 let mut simplified = false;
1146 simplified = simplified || self.reduce_child(expr);
1147 simplified = simplified || self.replace(expr);
1148 simplified
1149 }
1150 /// Attempt to replace any subexpressions of `expr`'s children with other expressions.
1151 /// Returns true if it does so.
1152 fn reduce_child(&self, expr: &mut MirScalarExpr) -> bool {
1153 let mut simplified = false;
1154 match expr {
1155 MirScalarExpr::CallBinary { expr1, expr2, .. } => {
1156 simplified = self.reduce_expr(expr1) || simplified;
1157 simplified = self.reduce_expr(expr2) || simplified;
1158 }
1159 MirScalarExpr::CallUnary { expr, .. } => {
1160 simplified = self.reduce_expr(expr) || simplified;
1161 }
1162 MirScalarExpr::CallVariadic { exprs, .. } => {
1163 for expr in exprs.iter_mut() {
1164 simplified = self.reduce_expr(expr) || simplified;
1165 }
1166 }
1167 MirScalarExpr::If { cond: _, then, els } => {
1168 // Do not simplify `cond`, as we cannot ensure the simplification
1169 // continues to hold as expressions migrate around.
1170 simplified = self.reduce_expr(then) || simplified;
1171 simplified = self.reduce_expr(els) || simplified;
1172 }
1173 _ => {}
1174 }
1175 simplified
1176 }
1177}
1178
1179impl ExpressionReducer for BTreeMap<&MirScalarExpr, &MirScalarExpr> {
1180 /// Perform any exact replacement for `expr`, report if it had an effect.
1181 fn replace(&self, expr: &mut MirScalarExpr) -> bool {
1182 if let Some(other) = self.get(expr) {
1183 if other != &expr {
1184 expr.clone_from(other);
1185 return true;
1186 }
1187 }
1188 false
1189 }
1190}
1191
1192impl ExpressionReducer for BTreeMap<MirScalarExpr, MirScalarExpr> {
1193 /// Perform any exact replacement for `expr`, report if it had an effect.
1194 fn replace(&self, expr: &mut MirScalarExpr) -> bool {
1195 if let Some(other) = self.get(expr) {
1196 if other != expr {
1197 expr.clone_from(other);
1198 return true;
1199 }
1200 }
1201 false
1202 }
1203}
1204
1205/// True iff the aggregate function returns an input datum.
1206fn aggregate_is_input(aggregate: &AggregateFunc) -> bool {
1207 match aggregate {
1208 AggregateFunc::MaxInt16
1209 | AggregateFunc::MaxInt32
1210 | AggregateFunc::MaxInt64
1211 | AggregateFunc::MaxUInt16
1212 | AggregateFunc::MaxUInt32
1213 | AggregateFunc::MaxUInt64
1214 | AggregateFunc::MaxMzTimestamp
1215 | AggregateFunc::MaxFloat32
1216 | AggregateFunc::MaxFloat64
1217 | AggregateFunc::MaxBool
1218 | AggregateFunc::MaxString
1219 | AggregateFunc::MaxDate
1220 | AggregateFunc::MaxTimestamp
1221 | AggregateFunc::MaxTimestampTz
1222 | AggregateFunc::MinInt16
1223 | AggregateFunc::MinInt32
1224 | AggregateFunc::MinInt64
1225 | AggregateFunc::MinUInt16
1226 | AggregateFunc::MinUInt32
1227 | AggregateFunc::MinUInt64
1228 | AggregateFunc::MinMzTimestamp
1229 | AggregateFunc::MinFloat32
1230 | AggregateFunc::MinFloat64
1231 | AggregateFunc::MinBool
1232 | AggregateFunc::MinString
1233 | AggregateFunc::MinDate
1234 | AggregateFunc::MinTimestamp
1235 | AggregateFunc::MinTimestampTz
1236 | AggregateFunc::Any
1237 | AggregateFunc::All => true,
1238 _ => false,
1239 }
1240}