mz_transform/analysis/equivalences.rs
1// Copyright Materialize, Inc. and contributors. All rights reserved.
2//
3// Use of this software is governed by the Business Source License
4// included in the LICENSE file.
5//
6// As of the Change Date specified in that file, in accordance with
7// the Business Source License, use of this software will be governed
8// by the Apache License, Version 2.0.
9
10//! An analysis that reports all known-equivalent expressions for each relation.
11//!
12//! Expressions are equivalent at a relation if they are certain to evaluate to
13//! the same `Datum` for all records in the relation.
14//!
15//! Equivalences are recorded in an `EquivalenceClasses`, which lists all known
16//! equivalence classes, each a list of equivalent expressions.
17
18use std::collections::BTreeMap;
19use std::fmt::Formatter;
20
21use itertools::Itertools;
22use mz_expr::canonicalize::{UnionFind, canonicalize_equivalence_classes};
23use mz_expr::explain::{HumanizedExplain, HumanizerMode};
24use mz_expr::{AggregateFunc, Id, MirRelationExpr, MirScalarExpr};
25use mz_ore::str::{bracketed, separated};
26use mz_repr::{Datum, ReprColumnType, ReprScalarType};
27
28use crate::analysis::{Analysis, Lattice};
29use crate::analysis::{Arity, ReprRelationType};
30use crate::analysis::{Derived, DerivedBuilder};
31
32/// Pulls up and pushes down predicate information represented as equivalences
33#[derive(Debug, Default)]
34pub struct Equivalences;
35
36impl Analysis for Equivalences {
37 // A `Some(list)` indicates a list of classes of equivalent expressions.
38 // A `None` indicates all expressions are equivalent, including contradictions;
39 // this is only possible for the empty collection, and as an initial result for
40 // unconstrained recursive terms.
41 type Value = Option<EquivalenceClasses>;
42
43 fn announce_dependencies(builder: &mut DerivedBuilder) {
44 builder.require(Arity);
45 builder.require(ReprRelationType); // needed for expression reduction.
46 }
47
48 fn derive(
49 &self,
50 expr: &MirRelationExpr,
51 index: usize,
52 results: &[Self::Value],
53 depends: &Derived,
54 ) -> Self::Value {
55 let mut equivalences = match expr {
56 MirRelationExpr::Constant { rows, typ } => {
57 // Trawl `rows` for any constant information worth recording.
58 // Literal columns may be valuable; non-nullability could be too.
59 let mut equivalences = EquivalenceClasses::default();
60 if let Ok([(row, _cnt), rows @ ..]) = rows.as_deref() {
61 // Vector of `Option<Datum>` which becomes `None` once a column has a second datum.
62 let len = row.iter().count();
63 let mut common = Vec::with_capacity(len);
64 common.extend(row.iter().map(Some));
65 // Prep initial nullability information.
66 let mut nullable_cols = common
67 .iter()
68 .map(|datum| datum == &Some(Datum::Null))
69 .collect::<Vec<_>>();
70
71 for (row, _cnt) in rows.iter() {
72 for ((datum, common), nullable) in row
73 .iter()
74 .zip_eq(common.iter_mut())
75 .zip_eq(nullable_cols.iter_mut())
76 {
77 if Some(datum) != *common {
78 *common = None;
79 }
80 if datum == Datum::Null {
81 *nullable = true;
82 }
83 }
84 }
85 for (index, common) in common.into_iter().enumerate() {
86 if let Some(datum) = common {
87 equivalences.classes.push(vec![
88 MirScalarExpr::column(index),
89 MirScalarExpr::literal_ok(
90 datum,
91 typ.column_types[index].scalar_type.clone(),
92 ),
93 ]);
94 }
95 }
96 // If any columns are non-null, introduce this fact.
97 if nullable_cols.iter().any(|x| !*x) {
98 let mut class = vec![MirScalarExpr::literal_false()];
99 for (index, nullable) in nullable_cols.iter().enumerate() {
100 if !*nullable {
101 class.push(MirScalarExpr::column(index).call_is_null());
102 }
103 }
104 equivalences.classes.push(class);
105 }
106 }
107 Some(equivalences)
108 }
109 MirRelationExpr::Get { id, typ, .. } => {
110 let mut equivalences = Some(EquivalenceClasses::default());
111 // Find local identifiers, but nothing for external identifiers.
112 if let Id::Local(id) = id {
113 if let Some(offset) = depends.bindings().get(id) {
114 // It is possible we have derived nothing for a recursive term
115 if let Some(result) = results.get(*offset) {
116 equivalences.clone_from(result);
117 } else {
118 // No top element was prepared.
119 // This means we are executing pessimistically,
120 // but perhaps we must because optimism is off.
121 }
122 }
123 }
124 // Incorporate statements about column nullability.
125 let mut non_null_cols = vec![MirScalarExpr::literal_false()];
126 for (index, col_type) in typ.column_types.iter().enumerate() {
127 if !col_type.nullable {
128 non_null_cols.push(MirScalarExpr::column(index).call_is_null());
129 }
130 }
131 if non_null_cols.len() > 1 {
132 if let Some(equivalences) = equivalences.as_mut() {
133 equivalences.classes.push(non_null_cols);
134 }
135 }
136
137 equivalences
138 }
139 MirRelationExpr::Let { .. } => results.get(index - 1).unwrap().clone(),
140 MirRelationExpr::LetRec { .. } => results.get(index - 1).unwrap().clone(),
141 MirRelationExpr::Project { outputs, .. } => {
142 // restrict equivalences, and introduce equivalences for repeated outputs.
143 let mut equivalences = results.get(index - 1).unwrap().clone();
144 equivalences
145 .as_mut()
146 .map(|e| e.project(outputs.iter().cloned()));
147 equivalences
148 }
149 MirRelationExpr::Map { scalars, .. } => {
150 // introduce equivalences for new columns and expressions that define them.
151 let mut equivalences = results.get(index - 1).unwrap().clone();
152 if let Some(equivalences) = &mut equivalences {
153 let input_arity = depends.results::<Arity>()[index - 1];
154 for (pos, expr) in scalars.iter().enumerate() {
155 equivalences
156 .classes
157 .push(vec![MirScalarExpr::column(input_arity + pos), expr.clone()]);
158 }
159 }
160 equivalences
161 }
162 MirRelationExpr::FlatMap { .. } => results.get(index - 1).unwrap().clone(),
163 MirRelationExpr::Filter { predicates, .. } => {
164 let mut equivalences = results.get(index - 1).unwrap().clone();
165 if let Some(equivalences) = &mut equivalences {
166 let mut class = predicates.clone();
167 class.push(MirScalarExpr::literal_ok(Datum::True, ReprScalarType::Bool));
168 equivalences.classes.push(class);
169 }
170 equivalences
171 }
172 MirRelationExpr::Join { equivalences, .. } => {
173 // Collect equivalences from all inputs;
174 let expr_index = index;
175 let mut children = depends
176 .children_of_rev(expr_index, expr.children().count())
177 .collect::<Vec<_>>();
178 children.reverse();
179
180 let arity = depends.results::<Arity>();
181 let mut columns = 0;
182 let mut result = Some(EquivalenceClasses::default());
183 for child in children.into_iter() {
184 let input_arity = arity[child];
185 let equivalences = results[child].clone();
186 if let Some(mut equivalences) = equivalences {
187 let permutation = (columns..(columns + input_arity)).collect::<Vec<_>>();
188 equivalences.permute(&permutation);
189 result
190 .as_mut()
191 .map(|e| e.classes.extend(equivalences.classes));
192 } else {
193 result = None;
194 }
195 columns += input_arity;
196 }
197
198 // Fold join equivalences into our results.
199 result
200 .as_mut()
201 .map(|e| e.classes.extend(equivalences.iter().cloned()));
202 result
203 }
204 MirRelationExpr::Reduce {
205 group_key,
206 aggregates,
207 ..
208 } => {
209 let input_arity = depends.results::<Arity>()[index - 1];
210 let mut equivalences = results.get(index - 1).unwrap().clone();
211 if let Some(equivalences) = &mut equivalences {
212 // Introduce keys column equivalences as if a map, then project to those columns.
213 // This should retain as much information as possible about these columns.
214 for (pos, expr) in group_key.iter().enumerate() {
215 equivalences
216 .classes
217 .push(vec![MirScalarExpr::column(input_arity + pos), expr.clone()]);
218 }
219
220 // Having added classes to `equivalences`, we should minimize the classes to fold the
221 // information in before applying the `project`, to set it up for success.
222 equivalences.minimize(None);
223
224 // Grab a copy of the equivalences with key columns added to use in aggregate reasoning.
225 let extended = equivalences.clone();
226 // Now project down the equivalences, as we will extend them in terms of the output columns.
227 equivalences.project(input_arity..(input_arity + group_key.len()));
228
229 // TODO: MIN, MAX, ANY, ALL aggregates pass through all certain properties of their columns.
230 // They also pass through equivalences of them and other constant columns (e.g. key columns).
231 // However, it is not correct to simply project onto these columns, as relationships amongst
232 // aggregate columns may no longer be preserved. MAX(col) != MIN(col) even though col = col.
233 // The correct thing to do is treat the reduce as a join between single-aggregate reductions,
234 // where each single MIN/MAX/ANY/ALL aggregate propagates equivalences.
235 for (index, aggregate) in aggregates.iter().enumerate() {
236 if aggregate_is_input(&aggregate.func) {
237 let mut temp_equivs = extended.clone();
238 temp_equivs.classes.push(vec![
239 MirScalarExpr::column(input_arity + group_key.len()),
240 aggregate.expr.clone(),
241 ]);
242 temp_equivs.minimize(None);
243 temp_equivs.project(input_arity..(input_arity + group_key.len() + 1));
244 let columns = (0..group_key.len())
245 .chain(std::iter::once(group_key.len() + index))
246 .collect::<Vec<_>>();
247 temp_equivs.permute(&columns[..]);
248 equivalences.classes.extend(temp_equivs.classes);
249 }
250 }
251 }
252 equivalences
253 }
254 MirRelationExpr::TopK { .. } => results.get(index - 1).unwrap().clone(),
255 MirRelationExpr::Negate { .. } => results.get(index - 1).unwrap().clone(),
256 MirRelationExpr::Threshold { .. } => results.get(index - 1).unwrap().clone(),
257 MirRelationExpr::Union { .. } => {
258 let expr_index = index;
259 let mut child_equivs = depends
260 .children_of_rev(expr_index, expr.children().count())
261 .flat_map(|c| &results[c]);
262 if let Some(first) = child_equivs.next() {
263 Some(first.union_many(child_equivs))
264 } else {
265 None
266 }
267 }
268 MirRelationExpr::ArrangeBy { .. } => results.get(index - 1).unwrap().clone(),
269 };
270
271 let repr_expr_type = depends.results::<ReprRelationType>()[index].as_ref();
272 equivalences
273 .as_mut()
274 .map(|e| e.minimize(repr_expr_type.as_ref().map(|x| &x[..])));
275 equivalences
276 }
277
278 fn lattice() -> Option<Box<dyn Lattice<Self::Value>>> {
279 Some(Box::new(EQLattice))
280 }
281}
282
283struct EQLattice;
284
285impl Lattice<Option<EquivalenceClasses>> for EQLattice {
286 fn top(&self) -> Option<EquivalenceClasses> {
287 None
288 }
289
290 fn meet_assign(
291 &self,
292 a: &mut Option<EquivalenceClasses>,
293 b: Option<EquivalenceClasses>,
294 ) -> bool {
295 match (&mut *a, b) {
296 (_, None) => false,
297 (None, b) => {
298 *a = b;
299 true
300 }
301 (Some(a), Some(b)) => {
302 let mut c = a.union(&b);
303 std::mem::swap(a, &mut c);
304 a != &mut c
305 }
306 }
307 }
308}
309
310/// A compact representation of classes of expressions that must be equivalent.
311///
312/// Each "class" contains a list of expressions, each of which must be `Eq::eq` equal.
313/// Ideally, the first element is the "simplest", e.g. a literal or column reference,
314/// and any other element of that list can be replaced by it.
315///
316/// The classes are meant to be minimized, with each expression as reduced as it can be,
317/// and all classes sharing an element merged.
318#[derive(Clone, Eq, PartialEq, Ord, PartialOrd, Default, Debug)]
319pub struct EquivalenceClasses {
320 /// Multiple lists of equivalent expressions, each representing an equivalence class.
321 ///
322 /// The first element should be the "canonical" simplest element, that any other element
323 /// can be replaced by.
324 /// These classes are unified whenever possible, to minimize the number of classes.
325 /// They are only guaranteed to form an equivalence relation after a call to `minimize`,
326 /// which refreshes both `self.classes` and `self.remap`.
327 pub classes: Vec<Vec<MirScalarExpr>>,
328
329 /// An expression simplification map.
330 ///
331 /// This map reflects an equivalence relation based on a prior version of `self.classes`.
332 /// As users may add to `self.classes`, `self.remap` may become stale. We refresh `remap`
333 /// only in `self.refresh()`, to the equivalence relation that derives from `self.classes`.
334 ///
335 /// It is important to `self.remap.clear()` if you invalidate it by mutating rather than
336 /// appending to `self.classes`. This will be corrected in the next call to `self.refresh()`,
337 /// but until then `remap` could be arbitrarily wrong. This should be improved in the future.
338 remap: BTreeMap<MirScalarExpr, MirScalarExpr>,
339}
340
341/// An enum that can represent either a standard `EquivalenceClasses` or a
342/// `EquivalenceClassesWithholdingErrors`.
343#[derive(Clone, Debug)]
344pub enum EqClassesImpl {
345 /// Standard equivalence classes.
346 EquivalenceClasses(EquivalenceClasses),
347 /// Equivalence classes that withhold errors, reintroducing them when extracting equivalences.
348 EquivalenceClassesWithholdingErrors(EquivalenceClassesWithholdingErrors),
349}
350
351impl EqClassesImpl {
352 /// Returns a reference to the underlying `EquivalenceClasses`.
353 pub fn equivalence_classes(&self) -> &EquivalenceClasses {
354 match self {
355 EqClassesImpl::EquivalenceClasses(classes) => classes,
356 EqClassesImpl::EquivalenceClassesWithholdingErrors(classes) => {
357 &classes.equivalence_classes
358 }
359 }
360 }
361
362 /// Returns a mutable reference to the underlying `EquivalenceClasses`.
363 pub fn equivalence_classes_mut(&mut self) -> &mut EquivalenceClasses {
364 match self {
365 EqClassesImpl::EquivalenceClasses(classes) => classes,
366 EqClassesImpl::EquivalenceClassesWithholdingErrors(classes) => {
367 &mut classes.equivalence_classes
368 }
369 }
370 }
371
372 /// Extend the equivalence classes with new equivalences.
373 pub fn extend_equivalences(&mut self, equivalences: Vec<Vec<MirScalarExpr>>) {
374 match self {
375 EqClassesImpl::EquivalenceClasses(classes) => {
376 classes.classes.extend(equivalences);
377 }
378 EqClassesImpl::EquivalenceClassesWithholdingErrors(classes) => {
379 classes.extend_equivalences(equivalences);
380 }
381 }
382 }
383
384 /// Push a new equivalence class.
385 pub fn push_equivalences(&mut self, equivalences: Vec<MirScalarExpr>) {
386 match self {
387 EqClassesImpl::EquivalenceClasses(classes) => {
388 classes.classes.push(equivalences);
389 }
390 EqClassesImpl::EquivalenceClassesWithholdingErrors(classes) => {
391 classes.push_equivalences(equivalences);
392 }
393 }
394 }
395
396 /// Subject the constraints to the column projection, reworking and removing equivalences.
397 pub fn project<I>(&mut self, output_columns: I)
398 where
399 I: IntoIterator<Item = usize> + Clone,
400 {
401 match self {
402 EqClassesImpl::EquivalenceClasses(classes) => {
403 classes.project(output_columns);
404 }
405 EqClassesImpl::EquivalenceClassesWithholdingErrors(classes) => {
406 classes.project(output_columns);
407 }
408 }
409 }
410
411 /// Extract the equivalences
412 pub fn extract_equivalences(
413 &mut self,
414 columns: Option<&[ReprColumnType]>,
415 ) -> Vec<Vec<MirScalarExpr>> {
416 match self {
417 EqClassesImpl::EquivalenceClasses(classes) => classes.classes.clone(),
418 EqClassesImpl::EquivalenceClassesWithholdingErrors(classes) => {
419 classes.extract_equivalences(columns)
420 }
421 }
422 }
423}
424
425/// A wrapper struct for equivalence classes that witholds errors
426/// from the underlying equivalence classes, reintroducing them
427/// when extracting equivalences.
428#[derive(Clone, Eq, PartialEq, Ord, PartialOrd, Default, Debug)]
429pub struct EquivalenceClassesWithholdingErrors {
430 /// The underlying equivalence classes.
431 pub equivalence_classes: EquivalenceClasses,
432 /// A list of equivalence classes that contain errors, which we hold back
433 pub held_back_classes: Vec<(Option<MirScalarExpr>, Vec<MirScalarExpr>)>,
434}
435
436impl EquivalenceClassesWithholdingErrors {
437 /// Extend the equivalence classes with new equivalences.
438 pub fn extend_equivalences(&mut self, equivalences: Vec<Vec<MirScalarExpr>>) {
439 for class in equivalences {
440 self.push_equivalences(class);
441 }
442 }
443
444 /// Push a new equivalence class, which may contain errors.
445 pub fn push_equivalences(&mut self, equivalences: Vec<MirScalarExpr>) {
446 let (with_errs, without_errs): (Vec<_>, Vec<_>) =
447 equivalences.into_iter().partition(|e| e.contains_err());
448 self.held_back_classes
449 .push((without_errs.first().cloned(), with_errs));
450 if without_errs.len() > 1 {
451 self.equivalence_classes.classes.push(without_errs);
452 }
453 }
454
455 /// Subject the constraints to the column projection, reworking and removing equivalences.
456 /// This method should also introduce equivalences representing any repeated columns.
457 /// This notably does not project the held back classes.
458 pub fn project<I>(&mut self, output_columns: I)
459 where
460 I: IntoIterator<Item = usize> + Clone,
461 {
462 self.equivalence_classes.project(output_columns.clone());
463 }
464
465 /// Minimize the equivalence classes, and return the result, reintroducing any held back classes.
466 pub fn extract_equivalences(
467 &mut self,
468 columns: Option<&[ReprColumnType]>,
469 ) -> Vec<Vec<MirScalarExpr>> {
470 self.equivalence_classes.minimize(columns);
471
472 if self.held_back_classes.is_empty() {
473 self.equivalence_classes.classes.clone()
474 } else {
475 let reducer = self.equivalence_classes.reducer();
476 let mut classes = self.equivalence_classes.classes.clone();
477 for (anchor, mut class) in self.held_back_classes.drain(..) {
478 if let Some(mut anchor) = anchor {
479 // Reduce the anchor expression to its canonical form.
480 reducer.reduce_expr(&mut anchor);
481 class.push(anchor.clone());
482 }
483 classes.push(class.clone());
484 }
485 canonicalize_equivalence_classes(&mut classes);
486 classes
487 }
488 }
489}
490
491/// Raw printing of [`EquivalenceClasses`] with default expression humanization.
492/// Don't use this in `EXPLAIN`! For redaction, column name support, etc., see
493/// [`HumanizedEquivalenceClasses`].
494impl std::fmt::Display for EquivalenceClasses {
495 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
496 HumanizedEquivalenceClasses {
497 equivalence_classes: self,
498 cols: None,
499 mode: HumanizedExplain::default(),
500 }
501 .fmt(f)
502 }
503}
504
505/// Wrapper struct for human-readable printing of expressions inside [`EquivalenceClasses`].
506/// (Similar to `HumanizedExpr`. Unfortunately, we can't just use `HumanizedExpr` here, because
507/// we'd need to `impl Display for HumanizedExpr<'a, EquivalenceClasses, M>`, but neither
508/// `Display` nor `HumanizedExpr` is defined in this crate.)
509#[derive(Debug)]
510pub struct HumanizedEquivalenceClasses<'a, M = HumanizedExplain> {
511 /// The [`EquivalenceClasses`] to be humanized.
512 pub equivalence_classes: &'a EquivalenceClasses,
513 /// An optional vector of inferred column names to be used when rendering
514 /// column references in expressions.
515 pub cols: Option<&'a Vec<String>>,
516 /// The rendering mode to use. See `HumanizerMode` for details.
517 pub mode: M,
518}
519
520impl std::fmt::Display for HumanizedEquivalenceClasses<'_> {
521 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
522 // Only show `classes`.
523 // (The following hopefully avoids allocating any of the intermediate composite strings.)
524 let classes = self.equivalence_classes.classes.iter().map(|class| {
525 format!(
526 "{}",
527 bracketed(
528 "[",
529 "]",
530 separated(
531 ", ",
532 class.iter().map(|expr| self.mode.expr(expr, self.cols))
533 )
534 )
535 )
536 });
537 write!(f, "{}", bracketed("[", "]", separated(", ", classes)))
538 }
539}
540
541impl EquivalenceClasses {
542 /// Comparator function for the complexity of scalar expressions. Simpler expressions are
543 /// smaller. Can be used when we need to decide which of several equivalent expressions to use.
544 pub fn mir_scalar_expr_complexity(
545 e1: &MirScalarExpr,
546 e2: &MirScalarExpr,
547 ) -> std::cmp::Ordering {
548 use MirScalarExpr::*;
549 use std::cmp::Ordering::*;
550 match (e1, e2) {
551 (Literal(_, _), Literal(_, _)) => e1.cmp(e2),
552 (Literal(_, _), _) => Less,
553 (_, Literal(_, _)) => Greater,
554 (Column(_, _), Column(_, _)) => e1.cmp(e2),
555 (Column(_, _), _) => Less,
556 (_, Column(_, _)) => Greater,
557 (x, y) => {
558 // General expressions should be ordered by their size,
559 // to ensure we only simplify expressions by substitution.
560 // If same size, then fall back to the expressions' Ord.
561 match x.size().cmp(&y.size()) {
562 Equal => x.cmp(y),
563 other => other,
564 }
565 }
566 }
567 }
568
569 /// Sorts and deduplicates each class, removing literal errors.
570 ///
571 /// This method does not ensure equivalence relation structure, but instead performs
572 /// only minimal structural clean-up.
573 fn tidy(&mut self) {
574 for class in self.classes.iter_mut() {
575 // Remove all literal errors, as they cannot be equated to other things.
576 class.retain(|e| !e.is_literal_err());
577 class.sort_by(Self::mir_scalar_expr_complexity);
578 class.dedup();
579 }
580 self.classes.retain(|c| c.len() > 1);
581 self.classes.sort();
582 self.classes.dedup();
583 }
584
585 /// Restore equivalence relation structure to `self.classes` and refresh `self.remap`.
586 ///
587 /// This method takes roughly linear time, and returns true iff `self.remap` has changed.
588 /// This is the only method that refreshes `self.remap`, and is a perfect place to decide
589 /// whether the equivalence classes it represents have experienced any changes since the
590 /// last refresh.
591 fn refresh(&mut self) -> bool {
592 self.tidy();
593
594 // remap may already be the correct answer, and if so we should avoid the work of rebuilding it.
595 // If it contains the same number of expressions as `self.classes`, and for every expression in
596 // `self.classes` the two agree on the representative, they are identical.
597 if self.remap.len() == self.classes.iter().map(|c| c.len()).sum::<usize>()
598 && self
599 .classes
600 .iter()
601 .all(|c| c.iter().all(|e| self.remap.get(e) == Some(&c[0])))
602 {
603 // No change, so return false.
604 return false;
605 }
606
607 // Optimistically build the `remap` we would want.
608 // Note if any unions would be required, in which case we have further work to do,
609 // including re-forming `self.classes`.
610 let mut union_find = BTreeMap::default();
611 let mut dirtied = false;
612 for class in self.classes.iter() {
613 for expr in class.iter() {
614 if let Some(other) = union_find.insert(expr.clone(), class[0].clone()) {
615 // A merge is required, but have the more complex expression point at the simpler one.
616 // This allows `union_find` to end as the `remap` for the new `classes` we form, with
617 // the only required work being compressing all the paths.
618 if Self::mir_scalar_expr_complexity(&other, &class[0])
619 == std::cmp::Ordering::Less
620 {
621 union_find.union(&class[0], &other);
622 } else {
623 union_find.union(&other, &class[0]);
624 }
625 dirtied = true;
626 }
627 }
628 }
629 if dirtied {
630 let mut classes: BTreeMap<_, Vec<_>> = BTreeMap::default();
631 for class in self.classes.drain(..) {
632 for expr in class {
633 let root: MirScalarExpr = union_find.find(&expr).unwrap().clone();
634 classes.entry(root).or_default().push(expr);
635 }
636 }
637 self.classes = classes.into_values().collect();
638 self.tidy();
639 }
640
641 let changed = self.remap != union_find;
642 self.remap = union_find;
643 changed
644 }
645
646 /// Update `self` to maintain the same equivalences which potentially reducing along `Ord::le`.
647 ///
648 /// Informally this means simplifying constraints, removing redundant constraints, and unifying equivalence classes.
649 pub fn minimize(&mut self, columns: Option<&[ReprColumnType]>) {
650 // Repeatedly, we reduce each of the classes themselves, then unify the classes.
651 // This should strictly reduce complexity, and reach a fixed point.
652 // Ideally it is *confluent*, arriving at the same fixed point no matter the order of operations.
653
654 // We should not rely on nullability information present in `column_types`. (Doing this
655 // every time just before calling `reduce` was found to be a bottleneck during incident-217,
656 // so now we do this nullability tweaking only once here.)
657 let mut columns = columns.map(|x| x.to_vec());
658 let mut nonnull = Vec::new();
659 if let Some(columns) = columns.as_mut() {
660 for (index, col) in columns.iter_mut().enumerate() {
661 let is_null = MirScalarExpr::column(index).call_is_null();
662 if !col.nullable
663 && self
664 .remap
665 .get(&is_null)
666 .map(|e| !e.is_literal_false())
667 .unwrap_or(true)
668 {
669 nonnull.push(is_null);
670 }
671 col.nullable = true;
672 }
673 }
674 if !nonnull.is_empty() {
675 nonnull.push(MirScalarExpr::literal_false());
676 self.classes.push(nonnull);
677 }
678
679 // Ensure `self.classes` and `self.remap` are equivalence relations.
680 // Users are allowed to mutate `self.classes`, so we must perform this normalization at least once.
681 // We have also likely mutated `self.classes` just above with non-nullability information.
682 self.refresh();
683
684 // Termination will be detected by comparing to the map of equivalence classes.
685 let mut previous = Some(self.remap.clone());
686 while let Some(prev) = previous {
687 // Attempt to add new equivalences.
688 let novel = self.expand();
689 if !novel.is_empty() {
690 self.classes.extend(novel);
691 self.refresh();
692 }
693
694 // We continue as long as any simplification has occurred.
695 // An expression can be simplified, a duplication found, or two classes unified.
696 let mut stable = false;
697 while !stable {
698 stable = !self.minimize_once(columns.as_ref().map(|x| &x[..]));
699 }
700
701 // Termination detection.
702 if prev != self.remap {
703 previous = Some(self.remap.clone());
704 } else {
705 previous = None;
706 }
707 }
708 }
709
710 /// Proposes new equivalences that are likely to be novel.
711 ///
712 /// This method invokes `self.implications()` to propose equivalences, and then judges them to be
713 /// novel or not based on existing knowledge, reducing the equivalences down to their novel core.
714 /// This method may produce non-novel equivalences, due to its inability to perform `MSE::reduce`.
715 /// We can end up with e.g. constant expressions that cannot be found until they are so reduced.
716 /// The novelty detection is best-effort, and meant to provide a clearer signal and minimize the
717 /// number of times we call and amount of work we do in `self.refresh()`.
718 fn expand(&self) -> Vec<Vec<MirScalarExpr>> {
719 // Consider expanding `self.classes` with novel equivalences.
720 let mut novel = self.implications();
721 for class in novel.iter_mut() {
722 // reduce each expression to its canonical form.
723 for expr in class.iter_mut() {
724 self.remap.reduce_expr(expr);
725 }
726 class.sort();
727 class.dedup();
728 // for a class to be interesting we require at least two elements that do not reference the same root.
729 let common_class = class
730 .iter()
731 .map(|x| self.remap.get(x))
732 .reduce(|prev, this| if prev == this { prev } else { None });
733 if class.len() == 1 || common_class != Some(None) {
734 class.clear();
735 }
736 }
737 novel.retain(|c| !c.is_empty());
738 novel
739 }
740
741 /// Derives potentially novel equivalences without regard for minimization.
742 ///
743 /// This is an opportunity to explore equivalences that do not correspond to expression minimization,
744 /// and therefore should not be used in `minimize_once`. They are still potentially important, but
745 /// required additional guardrails to ensure we reach a fixed point.
746 ///
747 /// The implications will be introduced into `self.classes` and will prompt a round of minimization,
748 /// making it somewhat polite to avoid producing outputs that cannot result in novel equivalences.
749 /// For example, before producing a new equivalence, one could check that the involved terms are not
750 /// already present in the same class.
751 fn implications(&self) -> Vec<Vec<MirScalarExpr>> {
752 let mut new_equivalences = Vec::new();
753
754 // If we see `false == IsNull(foo)` we can add the non-null implications of `foo`.
755 let mut non_null = std::collections::BTreeSet::default();
756 for class in self.classes.iter() {
757 if Self::class_contains_literal(class, |e| e == &Ok(Datum::False)) {
758 for e in class.iter() {
759 if let MirScalarExpr::CallUnary {
760 func: mz_expr::UnaryFunc::IsNull(_),
761 expr,
762 } = e
763 {
764 expr.non_null_requirements(&mut non_null);
765 }
766 }
767 }
768 }
769 // If we see `true == foo` we can add the non-null implications of `foo`.
770 // TODO: generalize to arbitrary non-null, non-error literals; at the moment `true == pred` is
771 // an important idiom to identify for how we express predicates.
772 for class in self.classes.iter() {
773 if Self::class_contains_literal(class, |e| e == &Ok(Datum::True)) {
774 for expr in class.iter() {
775 expr.non_null_requirements(&mut non_null);
776 }
777 }
778 }
779 // Only keep constraints that are not already known.
780 // Known constraints will present as `COL(_) IS NULL == false`,
781 // which can only happen if `false` is present, and both terms
782 // map to the same canonical representative>
783 let lit_false = MirScalarExpr::literal_false();
784 let target = self.remap.get(&lit_false);
785 if target.is_some() {
786 non_null.retain(|c| {
787 let is_null = MirScalarExpr::column(*c).call_is_null();
788 self.remap.get(&is_null) != target
789 });
790 }
791 if !non_null.is_empty() {
792 let mut class = Vec::with_capacity(non_null.len() + 1);
793 class.push(MirScalarExpr::literal_false());
794 class.extend(
795 non_null
796 .into_iter()
797 .map(|c| MirScalarExpr::column(c).call_is_null()),
798 );
799 new_equivalences.push(class);
800 }
801
802 // If we see records formed from other expressions, we can equate the expressions with
803 // accessors applied to the class of the record former. In `minimize_once` we reduce by
804 // equivalence class representative before we perform expression simplification, so we
805 // shoud be able to just use the expression former, rather than find its representative.
806 // The risk, potentially, is that we would apply accessors to the record former and then
807 // just simplify it away learning nothing.
808 for class in self.classes.iter() {
809 for expr in class.iter() {
810 // Record-forming expressions can equate their accessors and their members.
811 if let MirScalarExpr::CallVariadic {
812 func: mz_expr::VariadicFunc::RecordCreate(..),
813 exprs,
814 } = expr
815 {
816 for (index, e) in exprs.iter().enumerate() {
817 new_equivalences.push(vec![
818 e.clone(),
819 expr.clone().call_unary(mz_expr::UnaryFunc::RecordGet(
820 mz_expr::func::RecordGet(index),
821 )),
822 ]);
823 }
824 }
825 }
826 }
827
828 // Return all newly established equivalences.
829 new_equivalences
830 }
831
832 /// A single iteration of minimization, which we expect to repeat but benefit from factoring out.
833 ///
834 /// This invocation should take roughly linear time.
835 /// It starts with equivalence class invariants maintained (closed under transitivity), and then
836 /// 1. Performs per-expression reduction, including the class structure to replace subexpressions.
837 /// 2. Applies idiom detection to e.g. unpack expressions equivalence to literal true or false.
838 /// 3. Restores the equivalence class invariants.
839 fn minimize_once(&mut self, columns: Option<&[ReprColumnType]>) -> bool {
840 // 1. Reduce each expression
841 //
842 // This reduction first looks for subexpression substitutions that can be performed,
843 // and then applies expression reduction if column type information is provided.
844 for class in self.classes.iter_mut() {
845 for expr in class.iter_mut() {
846 self.remap.reduce_child(expr);
847 if let Some(columns) = columns {
848 let orig_expr = expr.clone();
849 expr.reduce(columns);
850 if expr.contains_err() {
851 // Rollback to the original expression if it contains an error.
852 *expr = orig_expr;
853 }
854 }
855 }
856 }
857
858 // 2. Identify idioms
859 // E.g. If Eq(x, y) must be true, we can introduce classes `[x, y]` and `[false, IsNull(x), IsNull(y)]`.
860 let mut to_add = Vec::new();
861 for class in self.classes.iter_mut() {
862 if Self::class_contains_literal(class, |e| e == &Ok(Datum::True)) {
863 for expr in class.iter() {
864 // If Eq(x, y) must be true, we can introduce classes `[x, y]` and `[false, IsNull(x), IsNull(y)]`.
865 // This substitution replaces a complex expression with several smaller expressions, and cannot
866 // cycle if we follow that practice.
867 if let MirScalarExpr::CallBinary {
868 func: mz_expr::BinaryFunc::Eq(_),
869 expr1,
870 expr2,
871 } = expr
872 {
873 to_add.push(vec![*expr1.clone(), *expr2.clone()]);
874 to_add.push(vec![
875 MirScalarExpr::literal_false(),
876 expr1.clone().call_is_null(),
877 expr2.clone().call_is_null(),
878 ]);
879 }
880 }
881 // Remove the more complex form of the expression.
882 class.retain(|expr| {
883 if let MirScalarExpr::CallBinary {
884 func: mz_expr::BinaryFunc::Eq(_),
885 ..
886 } = expr
887 {
888 false
889 } else {
890 true
891 }
892 });
893 for expr in class.iter() {
894 // If TRUE == NOT(X) then FALSE == X is a simpler form.
895 if let MirScalarExpr::CallUnary {
896 func: mz_expr::UnaryFunc::Not(_),
897 expr: e,
898 } = expr
899 {
900 to_add.push(vec![MirScalarExpr::literal_false(), (**e).clone()]);
901 }
902 }
903 class.retain(|expr| {
904 if let MirScalarExpr::CallUnary {
905 func: mz_expr::UnaryFunc::Not(_),
906 ..
907 } = expr
908 {
909 false
910 } else {
911 true
912 }
913 });
914 }
915 if Self::class_contains_literal(class, |e| e == &Ok(Datum::False)) {
916 for expr in class.iter() {
917 // If FALSE == NOT(X) then TRUE == X is a simpler form.
918 if let MirScalarExpr::CallUnary {
919 func: mz_expr::UnaryFunc::Not(_),
920 expr: e,
921 } = expr
922 {
923 to_add.push(vec![MirScalarExpr::literal_true(), (**e).clone()]);
924 }
925 }
926 class.retain(|expr| {
927 if let MirScalarExpr::CallUnary {
928 func: mz_expr::UnaryFunc::Not(_),
929 ..
930 } = expr
931 {
932 false
933 } else {
934 true
935 }
936 });
937 }
938 }
939 self.classes.extend(to_add);
940
941 // 3. Restore equivalence relation structure and observe if any changes result.
942 self.refresh()
943 }
944
945 /// Produce the equivalences present in both inputs.
946 pub fn union(&self, other: &Self) -> Self {
947 self.union_many([other])
948 }
949
950 /// The equivalence classes of terms equivalent in all inputs.
951 ///
952 /// This method relies on the `remap` member of each input, and bases the intersection on these rather than `classes`.
953 /// This means one should ensure `minimize()` has been called on all inputs, or risk getting a stale, but conservatively
954 /// correct, result.
955 ///
956 /// This method currently misses opportunities, because it only looks for exactly matches in expressions,
957 /// which may not include all possible matches. For example, `f(#1) == g(#1)` may exist in one class, but
958 /// in another class where `#0 == #1` it may exist as `f(#0) == g(#0)`.
959 pub fn union_many<'a, I>(&self, others: I) -> Self
960 where
961 I: IntoIterator<Item = &'a Self>,
962 {
963 // List of expressions in the intersection, and a proxy equivalence class identifier.
964 let mut intersection: Vec<(&MirScalarExpr, usize)> = Default::default();
965 // Map from expression to a proxy equivalence class identifier.
966 let mut rekey: BTreeMap<&MirScalarExpr, usize> = Default::default();
967 for (key, val) in self.remap.iter() {
968 if !rekey.contains_key(val) {
969 rekey.insert(val, rekey.len());
970 }
971 intersection.push((key, rekey[val]));
972 }
973 for other in others {
974 // Map from proxy equivalence class identifier and equivalence class expr to a new proxy identifier.
975 let mut rekey: BTreeMap<(usize, &MirScalarExpr), usize> = Default::default();
976 intersection.retain_mut(|(key, idx)| {
977 if let Some(val) = other.remap.get(key) {
978 if !rekey.contains_key(&(*idx, val)) {
979 rekey.insert((*idx, val), rekey.len());
980 }
981 *idx = rekey[&(*idx, val)];
982 true
983 } else {
984 false
985 }
986 });
987 }
988 let mut classes: BTreeMap<_, Vec<MirScalarExpr>> = Default::default();
989 for (key, vals) in intersection {
990 classes.entry(vals).or_default().push(key.clone())
991 }
992 let classes = classes.into_values().collect::<Vec<_>>();
993 let mut equivalences = EquivalenceClasses {
994 classes,
995 remap: Default::default(),
996 };
997 equivalences.minimize(None);
998 equivalences
999 }
1000
1001 /// Permutes each expression, looking up each column reference in `permutation` and replacing with what it finds.
1002 pub fn permute(&mut self, permutation: &[usize]) {
1003 for class in self.classes.iter_mut() {
1004 for expr in class.iter_mut() {
1005 expr.permute(permutation);
1006 }
1007 }
1008 self.remap.clear();
1009 self.minimize(None);
1010 }
1011
1012 /// Subject the constraints to the column projection, reworking and removing equivalences.
1013 ///
1014 /// This method should also introduce equivalences representing any repeated columns.
1015 pub fn project<I>(&mut self, output_columns: I)
1016 where
1017 I: IntoIterator<Item = usize> + Clone,
1018 {
1019 // Retain the first instance of each column, and record subsequent instances as duplicates.
1020 let mut dupes = Vec::new();
1021 let mut remap = BTreeMap::default();
1022 for (idx, col) in output_columns.into_iter().enumerate() {
1023 if let Some(pos) = remap.get(&col) {
1024 dupes.push((*pos, idx));
1025 } else {
1026 remap.insert(col, idx);
1027 }
1028 }
1029
1030 // Some expressions may be "localized" in that they only reference columns in `output_columns`.
1031 // Many expressions may not be localized, but may reference canonical non-localized expressions
1032 // for classes that contain a localized expression; in that case we can "backport" the localized
1033 // expression to give expressions referencing the canonical expression a shot at localization.
1034 //
1035 // Expressions should only contain instances of canonical expressions, and so we shouldn't need
1036 // to look any further than backporting those. Backporting should have the property that the simplest
1037 // localized expression in each class does not contain any non-localized canonical expressions
1038 // (as that would make it non-localized); our backporting of non-localized canonicals with localized
1039 // expressions should never fire a second
1040
1041 // Let's say an expression is "localized" once we are able to rewrite its support in terms of `output_columns`.
1042 // Not all expressions can be localized, although some of them may be equivalent to localized expressions.
1043 // As we find localized expressions, we can replace uses of their equivalent representative with them,
1044 // which may allow further expression localization.
1045 // We continue the process until no further classes can be localized.
1046
1047 // A map from representatives to our first localization of their equivalence class.
1048 let mut localized = false;
1049 while !localized {
1050 localized = true;
1051 let mut current_map = BTreeMap::default();
1052 for class in self.classes.iter_mut() {
1053 if !class[0].support().iter().all(|c| remap.contains_key(c)) {
1054 if let Some(pos) = class
1055 .iter()
1056 .position(|e| e.support().iter().all(|c| remap.contains_key(c)))
1057 {
1058 class.swap(0, pos);
1059 localized = false;
1060 }
1061 }
1062 for expr in class[1..].iter() {
1063 current_map.insert(expr.clone(), class[0].clone());
1064 }
1065 }
1066
1067 // attempt to replace representatives with equivalent localizeable expressions.
1068 for class_index in 0..self.classes.len() {
1069 for index in 0..self.classes[class_index].len() {
1070 current_map.reduce_child(&mut self.classes[class_index][index]);
1071 }
1072 }
1073 // NB: Do *not* `self.minimize()`, as we are developing localizable rather than canonical representatives.
1074 }
1075
1076 // Localize all localizable expressions and discard others.
1077 for class in self.classes.iter_mut() {
1078 class.retain(|e| e.support().iter().all(|c| remap.contains_key(c)));
1079 for expr in class.iter_mut() {
1080 expr.permute_map(&remap);
1081 }
1082 }
1083 self.classes.retain(|c| c.len() > 1);
1084 // If column repetitions, introduce them as equivalences.
1085 // We introduce only the equivalence to the first occurrence, and rely on minimization to collect them.
1086 for (col1, col2) in dupes {
1087 self.classes.push(vec![
1088 MirScalarExpr::column(col1),
1089 MirScalarExpr::column(col2),
1090 ]);
1091 }
1092 self.remap.clear();
1093 self.minimize(None);
1094 }
1095
1096 /// True if any equivalence class contains two distinct non-error literals.
1097 pub fn unsatisfiable(&self) -> bool {
1098 for class in self.classes.iter() {
1099 let mut literal_ok = None;
1100 for expr in class.iter() {
1101 if let MirScalarExpr::Literal(Ok(row), _) = expr {
1102 if literal_ok.is_some() && literal_ok != Some(row) {
1103 return true;
1104 } else {
1105 literal_ok = Some(row);
1106 }
1107 }
1108 }
1109 }
1110 false
1111 }
1112
1113 /// Returns a map that can be used to replace (sub-)expressions.
1114 pub fn reducer(&self) -> &BTreeMap<MirScalarExpr, MirScalarExpr> {
1115 &self.remap
1116 }
1117
1118 /// Examines the prefix of `class` of literals, looking for any satisfying `predicate`.
1119 ///
1120 /// This test bails out as soon as it sees a non-literal, and may have false negatives
1121 /// if the data are not sorted with literals at the front.
1122 fn class_contains_literal<P>(class: &[MirScalarExpr], mut predicate: P) -> bool
1123 where
1124 P: FnMut(&Result<Datum, &mz_expr::EvalError>) -> bool,
1125 {
1126 class
1127 .iter()
1128 .take_while(|e| e.is_literal())
1129 .filter_map(|e| e.as_literal())
1130 .any(move |e| predicate(&e))
1131 }
1132}
1133
1134/// A type capable of simplifying `MirScalarExpr`s.
1135pub trait ExpressionReducer {
1136 /// Attempt to replace `expr` itself with another expression.
1137 /// Returns true if it does so.
1138 fn replace(&self, expr: &mut MirScalarExpr) -> bool;
1139 /// Attempt to replace any subexpressions of `expr` with other expressions.
1140 /// Returns true if it does so.
1141 fn reduce_expr(&self, expr: &mut MirScalarExpr) -> bool {
1142 let mut simplified = false;
1143 simplified = simplified || self.reduce_child(expr);
1144 simplified = simplified || self.replace(expr);
1145 simplified
1146 }
1147 /// Attempt to replace any subexpressions of `expr`'s children with other expressions.
1148 /// Returns true if it does so.
1149 fn reduce_child(&self, expr: &mut MirScalarExpr) -> bool {
1150 let mut simplified = false;
1151 match expr {
1152 MirScalarExpr::CallBinary { expr1, expr2, .. } => {
1153 simplified = self.reduce_expr(expr1) || simplified;
1154 simplified = self.reduce_expr(expr2) || simplified;
1155 }
1156 MirScalarExpr::CallUnary { expr, .. } => {
1157 simplified = self.reduce_expr(expr) || simplified;
1158 }
1159 MirScalarExpr::CallVariadic { exprs, .. } => {
1160 for expr in exprs.iter_mut() {
1161 simplified = self.reduce_expr(expr) || simplified;
1162 }
1163 }
1164 MirScalarExpr::If { cond: _, then, els } => {
1165 // Do not simplify `cond`, as we cannot ensure the simplification
1166 // continues to hold as expressions migrate around.
1167 simplified = self.reduce_expr(then) || simplified;
1168 simplified = self.reduce_expr(els) || simplified;
1169 }
1170 _ => {}
1171 }
1172 simplified
1173 }
1174}
1175
1176impl ExpressionReducer for BTreeMap<&MirScalarExpr, &MirScalarExpr> {
1177 /// Perform any exact replacement for `expr`, report if it had an effect.
1178 fn replace(&self, expr: &mut MirScalarExpr) -> bool {
1179 if let Some(other) = self.get(expr) {
1180 if other != &expr {
1181 expr.clone_from(other);
1182 return true;
1183 }
1184 }
1185 false
1186 }
1187}
1188
1189impl ExpressionReducer for BTreeMap<MirScalarExpr, MirScalarExpr> {
1190 /// Perform any exact replacement for `expr`, report if it had an effect.
1191 fn replace(&self, expr: &mut MirScalarExpr) -> bool {
1192 if let Some(other) = self.get(expr) {
1193 if other != expr {
1194 expr.clone_from(other);
1195 return true;
1196 }
1197 }
1198 false
1199 }
1200}
1201
1202/// True iff the aggregate function returns an input datum.
1203fn aggregate_is_input(aggregate: &AggregateFunc) -> bool {
1204 match aggregate {
1205 AggregateFunc::MaxInt16
1206 | AggregateFunc::MaxInt32
1207 | AggregateFunc::MaxInt64
1208 | AggregateFunc::MaxUInt16
1209 | AggregateFunc::MaxUInt32
1210 | AggregateFunc::MaxUInt64
1211 | AggregateFunc::MaxMzTimestamp
1212 | AggregateFunc::MaxFloat32
1213 | AggregateFunc::MaxFloat64
1214 | AggregateFunc::MaxBool
1215 | AggregateFunc::MaxString
1216 | AggregateFunc::MaxDate
1217 | AggregateFunc::MaxTimestamp
1218 | AggregateFunc::MaxTimestampTz
1219 | AggregateFunc::MinInt16
1220 | AggregateFunc::MinInt32
1221 | AggregateFunc::MinInt64
1222 | AggregateFunc::MinUInt16
1223 | AggregateFunc::MinUInt32
1224 | AggregateFunc::MinUInt64
1225 | AggregateFunc::MinMzTimestamp
1226 | AggregateFunc::MinFloat32
1227 | AggregateFunc::MinFloat64
1228 | AggregateFunc::MinBool
1229 | AggregateFunc::MinString
1230 | AggregateFunc::MinDate
1231 | AggregateFunc::MinTimestamp
1232 | AggregateFunc::MinTimestampTz
1233 | AggregateFunc::Any
1234 | AggregateFunc::All => true,
1235 _ => false,
1236 }
1237}