mz_transform/analysis/equivalences.rs
1// Copyright Materialize, Inc. and contributors. All rights reserved.
2//
3// Use of this software is governed by the Business Source License
4// included in the LICENSE file.
5//
6// As of the Change Date specified in that file, in accordance with
7// the Business Source License, use of this software will be governed
8// by the Apache License, Version 2.0.
9
10//! An analysis that reports all known-equivalent expressions for each relation.
11//!
12//! Expressions are equivalent at a relation if they are certain to evaluate to
13//! the same `Datum` for all records in the relation.
14//!
//! Equivalences are recorded in an `EquivalenceClasses`, which lists all known
//! equivalence classes, each a list of equivalent expressions.
17
18use std::collections::BTreeMap;
19use std::fmt::Formatter;
20
21use mz_expr::canonicalize::{UnionFind, canonicalize_equivalence_classes};
22use mz_expr::explain::{HumanizedExplain, HumanizerMode};
23use mz_expr::{AggregateFunc, Id, MirRelationExpr, MirScalarExpr};
24use mz_ore::str::{bracketed, separated};
25use mz_repr::{Datum, SqlColumnType};
26
27use crate::analysis::{Analysis, Lattice};
28use crate::analysis::{Arity, SqlRelationType};
29use crate::analysis::{Derived, DerivedBuilder};
30
/// The equivalences analysis.
///
/// Predicate information is represented as equivalences between scalar
/// expressions, which this analysis pulls up and pushes down through relation
/// expressions.
#[derive(Debug, Default)]
pub struct Equivalences;
34
35impl Analysis for Equivalences {
36 // A `Some(list)` indicates a list of classes of equivalent expressions.
37 // A `None` indicates all expressions are equivalent, including contradictions;
38 // this is only possible for the empty collection, and as an initial result for
39 // unconstrained recursive terms.
40 type Value = Option<EquivalenceClasses>;
41
42 fn announce_dependencies(builder: &mut DerivedBuilder) {
43 builder.require(Arity);
44 builder.require(SqlRelationType); // needed for expression reduction.
45 }
46
47 fn derive(
48 &self,
49 expr: &MirRelationExpr,
50 index: usize,
51 results: &[Self::Value],
52 depends: &Derived,
53 ) -> Self::Value {
54 let mut equivalences = match expr {
55 MirRelationExpr::Constant { rows, typ } => {
56 // Trawl `rows` for any constant information worth recording.
57 // Literal columns may be valuable; non-nullability could be too.
58 let mut equivalences = EquivalenceClasses::default();
59 if let Ok([(row, _cnt), rows @ ..]) = rows.as_deref() {
60 // Vector of `Option<Datum>` which becomes `None` once a column has a second datum.
61 let len = row.iter().count();
62 let mut common = Vec::with_capacity(len);
63 common.extend(row.iter().map(Some));
64 // Prep initial nullability information.
65 let mut nullable_cols = common
66 .iter()
67 .map(|datum| datum == &Some(Datum::Null))
68 .collect::<Vec<_>>();
69
70 for (row, _cnt) in rows.iter() {
71 for ((datum, common), nullable) in row
72 .iter()
73 .zip(common.iter_mut())
74 .zip(nullable_cols.iter_mut())
75 {
76 if Some(datum) != *common {
77 *common = None;
78 }
79 if datum == Datum::Null {
80 *nullable = true;
81 }
82 }
83 }
84 for (index, common) in common.into_iter().enumerate() {
85 if let Some(datum) = common {
86 equivalences.classes.push(vec![
87 MirScalarExpr::column(index),
88 MirScalarExpr::literal_ok(
89 datum,
90 typ.column_types[index].scalar_type.clone(),
91 ),
92 ]);
93 }
94 }
95 // If any columns are non-null, introduce this fact.
96 if nullable_cols.iter().any(|x| !*x) {
97 let mut class = vec![MirScalarExpr::literal_false()];
98 for (index, nullable) in nullable_cols.iter().enumerate() {
99 if !*nullable {
100 class.push(MirScalarExpr::column(index).call_is_null());
101 }
102 }
103 equivalences.classes.push(class);
104 }
105 }
106 Some(equivalences)
107 }
108 MirRelationExpr::Get { id, typ, .. } => {
109 let mut equivalences = Some(EquivalenceClasses::default());
110 // Find local identifiers, but nothing for external identifiers.
111 if let Id::Local(id) = id {
112 if let Some(offset) = depends.bindings().get(id) {
113 // It is possible we have derived nothing for a recursive term
114 if let Some(result) = results.get(*offset) {
115 equivalences.clone_from(result);
116 } else {
117 // No top element was prepared.
118 // This means we are executing pessimistically,
119 // but perhaps we must because optimism is off.
120 }
121 }
122 }
123 // Incorporate statements about column nullability.
124 let mut non_null_cols = vec![MirScalarExpr::literal_false()];
125 for (index, col_type) in typ.column_types.iter().enumerate() {
126 if !col_type.nullable {
127 non_null_cols.push(MirScalarExpr::column(index).call_is_null());
128 }
129 }
130 if non_null_cols.len() > 1 {
131 if let Some(equivalences) = equivalences.as_mut() {
132 equivalences.classes.push(non_null_cols);
133 }
134 }
135
136 equivalences
137 }
138 MirRelationExpr::Let { .. } => results.get(index - 1).unwrap().clone(),
139 MirRelationExpr::LetRec { .. } => results.get(index - 1).unwrap().clone(),
140 MirRelationExpr::Project { outputs, .. } => {
141 // restrict equivalences, and introduce equivalences for repeated outputs.
142 let mut equivalences = results.get(index - 1).unwrap().clone();
143 equivalences
144 .as_mut()
145 .map(|e| e.project(outputs.iter().cloned()));
146 equivalences
147 }
148 MirRelationExpr::Map { scalars, .. } => {
149 // introduce equivalences for new columns and expressions that define them.
150 let mut equivalences = results.get(index - 1).unwrap().clone();
151 if let Some(equivalences) = &mut equivalences {
152 let input_arity = depends.results::<Arity>()[index - 1];
153 for (pos, expr) in scalars.iter().enumerate() {
154 equivalences
155 .classes
156 .push(vec![MirScalarExpr::column(input_arity + pos), expr.clone()]);
157 }
158 }
159 equivalences
160 }
161 MirRelationExpr::FlatMap { .. } => results.get(index - 1).unwrap().clone(),
162 MirRelationExpr::Filter { predicates, .. } => {
163 let mut equivalences = results.get(index - 1).unwrap().clone();
164 if let Some(equivalences) = &mut equivalences {
165 let mut class = predicates.clone();
166 class.push(MirScalarExpr::literal_ok(
167 Datum::True,
168 mz_repr::SqlScalarType::Bool,
169 ));
170 equivalences.classes.push(class);
171 }
172 equivalences
173 }
174 MirRelationExpr::Join { equivalences, .. } => {
175 // Collect equivalences from all inputs;
176 let expr_index = index;
177 let mut children = depends
178 .children_of_rev(expr_index, expr.children().count())
179 .collect::<Vec<_>>();
180 children.reverse();
181
182 let arity = depends.results::<Arity>();
183 let mut columns = 0;
184 let mut result = Some(EquivalenceClasses::default());
185 for child in children.into_iter() {
186 let input_arity = arity[child];
187 let equivalences = results[child].clone();
188 if let Some(mut equivalences) = equivalences {
189 let permutation = (columns..(columns + input_arity)).collect::<Vec<_>>();
190 equivalences.permute(&permutation);
191 result
192 .as_mut()
193 .map(|e| e.classes.extend(equivalences.classes));
194 } else {
195 result = None;
196 }
197 columns += input_arity;
198 }
199
200 // Fold join equivalences into our results.
201 result
202 .as_mut()
203 .map(|e| e.classes.extend(equivalences.iter().cloned()));
204 result
205 }
206 MirRelationExpr::Reduce {
207 group_key,
208 aggregates,
209 ..
210 } => {
211 let input_arity = depends.results::<Arity>()[index - 1];
212 let mut equivalences = results.get(index - 1).unwrap().clone();
213 if let Some(equivalences) = &mut equivalences {
214 // Introduce keys column equivalences as if a map, then project to those columns.
215 // This should retain as much information as possible about these columns.
216 for (pos, expr) in group_key.iter().enumerate() {
217 equivalences
218 .classes
219 .push(vec![MirScalarExpr::column(input_arity + pos), expr.clone()]);
220 }
221
222 // Having added classes to `equivalences`, we should minimize the classes to fold the
223 // information in before applying the `project`, to set it up for success.
224 equivalences.minimize(None);
225
226 // Grab a copy of the equivalences with key columns added to use in aggregate reasoning.
227 let extended = equivalences.clone();
228 // Now project down the equivalences, as we will extend them in terms of the output columns.
229 equivalences.project(input_arity..(input_arity + group_key.len()));
230
231 // TODO: MIN, MAX, ANY, ALL aggregates pass through all certain properties of their columns.
232 // They also pass through equivalences of them and other constant columns (e.g. key columns).
233 // However, it is not correct to simply project onto these columns, as relationships amongst
234 // aggregate columns may no longer be preserved. MAX(col) != MIN(col) even though col = col.
235 // The correct thing to do is treat the reduce as a join between single-aggregate reductions,
236 // where each single MIN/MAX/ANY/ALL aggregate propagates equivalences.
237 for (index, aggregate) in aggregates.iter().enumerate() {
238 if aggregate_is_input(&aggregate.func) {
239 let mut temp_equivs = extended.clone();
240 temp_equivs.classes.push(vec![
241 MirScalarExpr::column(input_arity + group_key.len()),
242 aggregate.expr.clone(),
243 ]);
244 temp_equivs.minimize(None);
245 temp_equivs.project(input_arity..(input_arity + group_key.len() + 1));
246 let columns = (0..group_key.len())
247 .chain(std::iter::once(group_key.len() + index))
248 .collect::<Vec<_>>();
249 temp_equivs.permute(&columns[..]);
250 equivalences.classes.extend(temp_equivs.classes);
251 }
252 }
253 }
254 equivalences
255 }
256 MirRelationExpr::TopK { .. } => results.get(index - 1).unwrap().clone(),
257 MirRelationExpr::Negate { .. } => results.get(index - 1).unwrap().clone(),
258 MirRelationExpr::Threshold { .. } => results.get(index - 1).unwrap().clone(),
259 MirRelationExpr::Union { .. } => {
260 let expr_index = index;
261 let mut child_equivs = depends
262 .children_of_rev(expr_index, expr.children().count())
263 .flat_map(|c| &results[c]);
264 if let Some(first) = child_equivs.next() {
265 Some(first.union_many(child_equivs))
266 } else {
267 None
268 }
269 }
270 MirRelationExpr::ArrangeBy { .. } => results.get(index - 1).unwrap().clone(),
271 };
272
273 let expr_type = depends.results::<SqlRelationType>()[index].as_ref();
274 equivalences
275 .as_mut()
276 .map(|e| e.minimize(expr_type.map(|x| &x[..])));
277 equivalences
278 }
279
280 fn lattice() -> Option<Box<dyn Lattice<Self::Value>>> {
281 Some(Box::new(EQLattice))
282 }
283}
284
285struct EQLattice;
286
287impl Lattice<Option<EquivalenceClasses>> for EQLattice {
288 fn top(&self) -> Option<EquivalenceClasses> {
289 None
290 }
291
292 fn meet_assign(
293 &self,
294 a: &mut Option<EquivalenceClasses>,
295 b: Option<EquivalenceClasses>,
296 ) -> bool {
297 match (&mut *a, b) {
298 (_, None) => false,
299 (None, b) => {
300 *a = b;
301 true
302 }
303 (Some(a), Some(b)) => {
304 let mut c = a.union(&b);
305 std::mem::swap(a, &mut c);
306 a != &mut c
307 }
308 }
309 }
310}
311
312/// A compact representation of classes of expressions that must be equivalent.
313///
314/// Each "class" contains a list of expressions, each of which must be `Eq::eq` equal.
315/// Ideally, the first element is the "simplest", e.g. a literal or column reference,
316/// and any other element of that list can be replaced by it.
317///
318/// The classes are meant to be minimized, with each expression as reduced as it can be,
319/// and all classes sharing an element merged.
320#[derive(Clone, Eq, PartialEq, Ord, PartialOrd, Default, Debug)]
321pub struct EquivalenceClasses {
322 /// Multiple lists of equivalent expressions, each representing an equivalence class.
323 ///
324 /// The first element should be the "canonical" simplest element, that any other element
325 /// can be replaced by.
326 /// These classes are unified whenever possible, to minimize the number of classes.
327 /// They are only guaranteed to form an equivalence relation after a call to `minimize`,
328 /// which refreshes both `self.classes` and `self.remap`.
329 pub classes: Vec<Vec<MirScalarExpr>>,
330
331 /// An expression simplification map.
332 ///
333 /// This map reflects an equivalence relation based on a prior version of `self.classes`.
334 /// As users may add to `self.classes`, `self.remap` may become stale. We refresh `remap`
335 /// only in `self.refresh()`, to the equivalence relation that derives from `self.classes`.
336 ///
337 /// It is important to `self.remap.clear()` if you invalidate it by mutating rather than
338 /// appending to `self.classes`. This will be corrected in the next call to `self.refresh()`,
339 /// but until then `remap` could be arbitrarily wrong. This should be improved in the future.
340 remap: BTreeMap<MirScalarExpr, MirScalarExpr>,
341}
342
343/// An enum that can represent either a standard `EquivalenceClasses` or a
344/// `EquivalenceClassesWithholdingErrors`.
345#[derive(Clone, Debug)]
346pub enum EqClassesImpl {
347 /// Standard equivalence classes.
348 EquivalenceClasses(EquivalenceClasses),
349 /// Equivalence classes that withhold errors, reintroducing them when extracting equivalences.
350 EquivalenceClassesWithholdingErrors(EquivalenceClassesWithholdingErrors),
351}
352
353impl EqClassesImpl {
354 /// Returns a reference to the underlying `EquivalenceClasses`.
355 pub fn equivalence_classes(&self) -> &EquivalenceClasses {
356 match self {
357 EqClassesImpl::EquivalenceClasses(classes) => classes,
358 EqClassesImpl::EquivalenceClassesWithholdingErrors(classes) => {
359 &classes.equivalence_classes
360 }
361 }
362 }
363
364 /// Returns a mutable reference to the underlying `EquivalenceClasses`.
365 pub fn equivalence_classes_mut(&mut self) -> &mut EquivalenceClasses {
366 match self {
367 EqClassesImpl::EquivalenceClasses(classes) => classes,
368 EqClassesImpl::EquivalenceClassesWithholdingErrors(classes) => {
369 &mut classes.equivalence_classes
370 }
371 }
372 }
373
374 /// Extend the equivalence classes with new equivalences.
375 pub fn extend_equivalences(&mut self, equivalences: Vec<Vec<MirScalarExpr>>) {
376 match self {
377 EqClassesImpl::EquivalenceClasses(classes) => {
378 classes.classes.extend(equivalences);
379 }
380 EqClassesImpl::EquivalenceClassesWithholdingErrors(classes) => {
381 classes.extend_equivalences(equivalences);
382 }
383 }
384 }
385
386 /// Push a new equivalence class.
387 pub fn push_equivalences(&mut self, equivalences: Vec<MirScalarExpr>) {
388 match self {
389 EqClassesImpl::EquivalenceClasses(classes) => {
390 classes.classes.push(equivalences);
391 }
392 EqClassesImpl::EquivalenceClassesWithholdingErrors(classes) => {
393 classes.push_equivalences(equivalences);
394 }
395 }
396 }
397
398 /// Subject the constraints to the column projection, reworking and removing equivalences.
399 pub fn project<I>(&mut self, output_columns: I)
400 where
401 I: IntoIterator<Item = usize> + Clone,
402 {
403 match self {
404 EqClassesImpl::EquivalenceClasses(classes) => {
405 classes.project(output_columns);
406 }
407 EqClassesImpl::EquivalenceClassesWithholdingErrors(classes) => {
408 classes.project(output_columns);
409 }
410 }
411 }
412
413 /// Extract the equivalences
414 pub fn extract_equivalences(
415 &mut self,
416 columns: Option<&[SqlColumnType]>,
417 ) -> Vec<Vec<MirScalarExpr>> {
418 match self {
419 EqClassesImpl::EquivalenceClasses(classes) => classes.classes.clone(),
420 EqClassesImpl::EquivalenceClassesWithholdingErrors(classes) => {
421 classes.extract_equivalences(columns)
422 }
423 }
424 }
425}
426
427/// A wrapper struct for equivalence classes that witholds errors
428/// from the underlying equivalence classes, reintroducing them
429/// when extracting equivalences.
430#[derive(Clone, Eq, PartialEq, Ord, PartialOrd, Default, Debug)]
431pub struct EquivalenceClassesWithholdingErrors {
432 /// The underlying equivalence classes.
433 pub equivalence_classes: EquivalenceClasses,
434 /// A list of equivalence classes that contain errors, which we hold back
435 pub held_back_classes: Vec<(Option<MirScalarExpr>, Vec<MirScalarExpr>)>,
436}
437
438impl EquivalenceClassesWithholdingErrors {
439 /// Extend the equivalence classes with new equivalences.
440 pub fn extend_equivalences(&mut self, equivalences: Vec<Vec<MirScalarExpr>>) {
441 for class in equivalences {
442 self.push_equivalences(class);
443 }
444 }
445
446 /// Push a new equivalence class, which may contain errors.
447 pub fn push_equivalences(&mut self, equivalences: Vec<MirScalarExpr>) {
448 let (with_errs, without_errs): (Vec<_>, Vec<_>) =
449 equivalences.into_iter().partition(|e| e.contains_err());
450 self.held_back_classes
451 .push((without_errs.first().cloned(), with_errs));
452 if without_errs.len() > 1 {
453 self.equivalence_classes.classes.push(without_errs);
454 }
455 }
456
457 /// Subject the constraints to the column projection, reworking and removing equivalences.
458 /// This method should also introduce equivalences representing any repeated columns.
459 /// This notably does not project the held back classes.
460 pub fn project<I>(&mut self, output_columns: I)
461 where
462 I: IntoIterator<Item = usize> + Clone,
463 {
464 self.equivalence_classes.project(output_columns.clone());
465 }
466
467 /// Minimize the equivalence classes, and return the result, reintroducing any held back classes.
468 pub fn extract_equivalences(
469 &mut self,
470 columns: Option<&[SqlColumnType]>,
471 ) -> Vec<Vec<MirScalarExpr>> {
472 self.equivalence_classes.minimize(columns);
473
474 if self.held_back_classes.is_empty() {
475 self.equivalence_classes.classes.clone()
476 } else {
477 let reducer = self.equivalence_classes.reducer();
478 let mut classes = self.equivalence_classes.classes.clone();
479 for (anchor, mut class) in self.held_back_classes.drain(..) {
480 if let Some(mut anchor) = anchor {
481 // Reduce the anchor expression to its canonical form.
482 reducer.reduce_expr(&mut anchor);
483 class.push(anchor.clone());
484 }
485 classes.push(class.clone());
486 }
487 canonicalize_equivalence_classes(&mut classes);
488 classes
489 }
490 }
491}
492
493/// Raw printing of [`EquivalenceClasses`] with default expression humanization.
494/// Don't use this in `EXPLAIN`! For redaction, column name support, etc., see
495/// [`HumanizedEquivalenceClasses`].
496impl std::fmt::Display for EquivalenceClasses {
497 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
498 HumanizedEquivalenceClasses {
499 equivalence_classes: self,
500 cols: None,
501 mode: HumanizedExplain::default(),
502 }
503 .fmt(f)
504 }
505}
506
507/// Wrapper struct for human-readable printing of expressions inside [`EquivalenceClasses`].
508/// (Similar to `HumanizedExpr`. Unfortunately, we can't just use `HumanizedExpr` here, because
509/// we'd need to `impl Display for HumanizedExpr<'a, EquivalenceClasses, M>`, but neither
510/// `Display` nor `HumanizedExpr` is defined in this crate.)
511#[derive(Debug)]
512pub struct HumanizedEquivalenceClasses<'a, M = HumanizedExplain> {
513 /// The [`EquivalenceClasses`] to be humanized.
514 pub equivalence_classes: &'a EquivalenceClasses,
515 /// An optional vector of inferred column names to be used when rendering
516 /// column references in expressions.
517 pub cols: Option<&'a Vec<String>>,
518 /// The rendering mode to use. See `HumanizerMode` for details.
519 pub mode: M,
520}
521
522impl std::fmt::Display for HumanizedEquivalenceClasses<'_> {
523 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
524 // Only show `classes`.
525 // (The following hopefully avoids allocating any of the intermediate composite strings.)
526 let classes = self.equivalence_classes.classes.iter().map(|class| {
527 format!(
528 "{}",
529 bracketed(
530 "[",
531 "]",
532 separated(
533 ", ",
534 class.iter().map(|expr| self.mode.expr(expr, self.cols))
535 )
536 )
537 )
538 });
539 write!(f, "{}", bracketed("[", "]", separated(", ", classes)))
540 }
541}
542
543impl EquivalenceClasses {
544 /// Comparator function for the complexity of scalar expressions. Simpler expressions are
545 /// smaller. Can be used when we need to decide which of several equivalent expressions to use.
546 pub fn mir_scalar_expr_complexity(
547 e1: &MirScalarExpr,
548 e2: &MirScalarExpr,
549 ) -> std::cmp::Ordering {
550 use MirScalarExpr::*;
551 use std::cmp::Ordering::*;
552 match (e1, e2) {
553 (Literal(_, _), Literal(_, _)) => e1.cmp(e2),
554 (Literal(_, _), _) => Less,
555 (_, Literal(_, _)) => Greater,
556 (Column(_, _), Column(_, _)) => e1.cmp(e2),
557 (Column(_, _), _) => Less,
558 (_, Column(_, _)) => Greater,
559 (x, y) => {
560 // General expressions should be ordered by their size,
561 // to ensure we only simplify expressions by substitution.
562 // If same size, then fall back to the expressions' Ord.
563 match x.size().cmp(&y.size()) {
564 Equal => x.cmp(y),
565 other => other,
566 }
567 }
568 }
569 }
570
571 /// Sorts and deduplicates each class, removing literal errors.
572 ///
573 /// This method does not ensure equivalence relation structure, but instead performs
574 /// only minimal structural clean-up.
575 fn tidy(&mut self) {
576 for class in self.classes.iter_mut() {
577 // Remove all literal errors, as they cannot be equated to other things.
578 class.retain(|e| !e.is_literal_err());
579 class.sort_by(Self::mir_scalar_expr_complexity);
580 class.dedup();
581 }
582 self.classes.retain(|c| c.len() > 1);
583 self.classes.sort();
584 self.classes.dedup();
585 }
586
587 /// Restore equivalence relation structure to `self.classes` and refresh `self.remap`.
588 ///
589 /// This method takes roughly linear time, and returns true iff `self.remap` has changed.
590 /// This is the only method that refreshes `self.remap`, and is a perfect place to decide
591 /// whether the equivalence classes it represents have experienced any changes since the
592 /// last refresh.
593 fn refresh(&mut self) -> bool {
594 self.tidy();
595
596 // remap may already be the correct answer, and if so we should avoid the work of rebuilding it.
597 // If it contains the same number of expressions as `self.classes`, and for every expression in
598 // `self.classes` the two agree on the representative, they are identical.
599 if self.remap.len() == self.classes.iter().map(|c| c.len()).sum::<usize>()
600 && self
601 .classes
602 .iter()
603 .all(|c| c.iter().all(|e| self.remap.get(e) == Some(&c[0])))
604 {
605 // No change, so return false.
606 return false;
607 }
608
609 // Optimistically build the `remap` we would want.
610 // Note if any unions would be required, in which case we have further work to do,
611 // including re-forming `self.classes`.
612 let mut union_find = BTreeMap::default();
613 let mut dirtied = false;
614 for class in self.classes.iter() {
615 for expr in class.iter() {
616 if let Some(other) = union_find.insert(expr.clone(), class[0].clone()) {
617 // A merge is required, but have the more complex expression point at the simpler one.
618 // This allows `union_find` to end as the `remap` for the new `classes` we form, with
619 // the only required work being compressing all the paths.
620 if Self::mir_scalar_expr_complexity(&other, &class[0])
621 == std::cmp::Ordering::Less
622 {
623 union_find.union(&class[0], &other);
624 } else {
625 union_find.union(&other, &class[0]);
626 }
627 dirtied = true;
628 }
629 }
630 }
631 if dirtied {
632 let mut classes: BTreeMap<_, Vec<_>> = BTreeMap::default();
633 for class in self.classes.drain(..) {
634 for expr in class {
635 let root: MirScalarExpr = union_find.find(&expr).unwrap().clone();
636 classes.entry(root).or_default().push(expr);
637 }
638 }
639 self.classes = classes.into_values().collect();
640 self.tidy();
641 }
642
643 let changed = self.remap != union_find;
644 self.remap = union_find;
645 changed
646 }
647
648 /// Update `self` to maintain the same equivalences which potentially reducing along `Ord::le`.
649 ///
650 /// Informally this means simplifying constraints, removing redundant constraints, and unifying equivalence classes.
651 pub fn minimize(&mut self, columns: Option<&[SqlColumnType]>) {
652 // Repeatedly, we reduce each of the classes themselves, then unify the classes.
653 // This should strictly reduce complexity, and reach a fixed point.
654 // Ideally it is *confluent*, arriving at the same fixed point no matter the order of operations.
655
656 // We should not rely on nullability information present in `column_types`. (Doing this
657 // every time just before calling `reduce` was found to be a bottleneck during incident-217,
658 // so now we do this nullability tweaking only once here.)
659 let mut columns = columns.map(|x| x.to_vec());
660 let mut nonnull = Vec::new();
661 if let Some(columns) = columns.as_mut() {
662 for (index, col) in columns.iter_mut().enumerate() {
663 let is_null = MirScalarExpr::column(index).call_is_null();
664 if !col.nullable
665 && self
666 .remap
667 .get(&is_null)
668 .map(|e| !e.is_literal_false())
669 .unwrap_or(true)
670 {
671 nonnull.push(is_null);
672 }
673 col.nullable = true;
674 }
675 }
676 if !nonnull.is_empty() {
677 nonnull.push(MirScalarExpr::literal_false());
678 self.classes.push(nonnull);
679 }
680
681 // Ensure `self.classes` and `self.remap` are equivalence relations.
682 // Users are allowed to mutate `self.classes`, so we must perform this normalization at least once.
683 // We have also likely mutated `self.classes` just above with non-nullability information.
684 self.refresh();
685
686 // Termination will be detected by comparing to the map of equivalence classes.
687 let mut previous = Some(self.remap.clone());
688 while let Some(prev) = previous {
689 // Attempt to add new equivalences.
690 let novel = self.expand();
691 if !novel.is_empty() {
692 self.classes.extend(novel);
693 self.refresh();
694 }
695
696 // We continue as long as any simplification has occurred.
697 // An expression can be simplified, a duplication found, or two classes unified.
698 let mut stable = false;
699 while !stable {
700 stable = !self.minimize_once(columns.as_ref().map(|x| &x[..]));
701 }
702
703 // Termination detection.
704 if prev != self.remap {
705 previous = Some(self.remap.clone());
706 } else {
707 previous = None;
708 }
709 }
710 }
711
712 /// Proposes new equivalences that are likely to be novel.
713 ///
714 /// This method invokes `self.implications()` to propose equivalences, and then judges them to be
715 /// novel or not based on existing knowledge, reducing the equivalences down to their novel core.
716 /// This method may produce non-novel equivalences, due to its inability to perform `MSE::reduce`.
717 /// We can end up with e.g. constant expressions that cannot be found until they are so reduced.
718 /// The novelty detection is best-effort, and meant to provide a clearer signal and minimize the
719 /// number of times we call and amount of work we do in `self.refresh()`.
720 fn expand(&self) -> Vec<Vec<MirScalarExpr>> {
721 // Consider expanding `self.classes` with novel equivalences.
722 let mut novel = self.implications();
723 for class in novel.iter_mut() {
724 // reduce each expression to its canonical form.
725 for expr in class.iter_mut() {
726 self.remap.reduce_expr(expr);
727 }
728 class.sort();
729 class.dedup();
730 // for a class to be interesting we require at least two elements that do not reference the same root.
731 let common_class = class
732 .iter()
733 .map(|x| self.remap.get(x))
734 .reduce(|prev, this| if prev == this { prev } else { None });
735 if class.len() == 1 || common_class != Some(None) {
736 class.clear();
737 }
738 }
739 novel.retain(|c| !c.is_empty());
740 novel
741 }
742
743 /// Derives potentially novel equivalences without regard for minimization.
744 ///
745 /// This is an opportunity to explore equivalences that do not correspond to expression minimization,
746 /// and therefore should not be used in `minimize_once`. They are still potentially important, but
747 /// required additional guardrails to ensure we reach a fixed point.
748 ///
749 /// The implications will be introduced into `self.classes` and will prompt a round of minimization,
750 /// making it somewhat polite to avoid producing outputs that cannot result in novel equivalences.
751 /// For example, before producing a new equivalence, one could check that the involved terms are not
752 /// already present in the same class.
753 fn implications(&self) -> Vec<Vec<MirScalarExpr>> {
754 let mut new_equivalences = Vec::new();
755
756 // If we see `false == IsNull(foo)` we can add the non-null implications of `foo`.
757 let mut non_null = std::collections::BTreeSet::default();
758 for class in self.classes.iter() {
759 if Self::class_contains_literal(class, |e| e == &Ok(Datum::False)) {
760 for e in class.iter() {
761 if let MirScalarExpr::CallUnary {
762 func: mz_expr::UnaryFunc::IsNull(_),
763 expr,
764 } = e
765 {
766 expr.non_null_requirements(&mut non_null);
767 }
768 }
769 }
770 }
771 // If we see `true == foo` we can add the non-null implications of `foo`.
772 // TODO: generalize to arbitrary non-null, non-error literals; at the moment `true == pred` is
773 // an important idiom to identify for how we express predicates.
774 for class in self.classes.iter() {
775 if Self::class_contains_literal(class, |e| e == &Ok(Datum::True)) {
776 for expr in class.iter() {
777 expr.non_null_requirements(&mut non_null);
778 }
779 }
780 }
781 // Only keep constraints that are not already known.
782 // Known constraints will present as `COL(_) IS NULL == false`,
783 // which can only happen if `false` is present, and both terms
784 // map to the same canonical representative>
785 let lit_false = MirScalarExpr::literal_false();
786 let target = self.remap.get(&lit_false);
787 if target.is_some() {
788 non_null.retain(|c| {
789 let is_null = MirScalarExpr::column(*c).call_is_null();
790 self.remap.get(&is_null) != target
791 });
792 }
793 if !non_null.is_empty() {
794 let mut class = Vec::with_capacity(non_null.len() + 1);
795 class.push(MirScalarExpr::literal_false());
796 class.extend(
797 non_null
798 .into_iter()
799 .map(|c| MirScalarExpr::column(c).call_is_null()),
800 );
801 new_equivalences.push(class);
802 }
803
804 // If we see records formed from other expressions, we can equate the expressions with
805 // accessors applied to the class of the record former. In `minimize_once` we reduce by
806 // equivalence class representative before we perform expression simplification, so we
807 // shoud be able to just use the expression former, rather than find its representative.
808 // The risk, potentially, is that we would apply accessors to the record former and then
809 // just simplify it away learning nothing.
810 for class in self.classes.iter() {
811 for expr in class.iter() {
812 // Record-forming expressions can equate their accessors and their members.
813 if let MirScalarExpr::CallVariadic {
814 func: mz_expr::VariadicFunc::RecordCreate { .. },
815 exprs,
816 } = expr
817 {
818 for (index, e) in exprs.iter().enumerate() {
819 new_equivalences.push(vec![
820 e.clone(),
821 expr.clone().call_unary(mz_expr::UnaryFunc::RecordGet(
822 mz_expr::func::RecordGet(index),
823 )),
824 ]);
825 }
826 }
827 }
828 }
829
830 // Return all newly established equivalences.
831 new_equivalences
832 }
833
834 /// A single iteration of minimization, which we expect to repeat but benefit from factoring out.
835 ///
836 /// This invocation should take roughly linear time.
837 /// It starts with equivalence class invariants maintained (closed under transitivity), and then
838 /// 1. Performs per-expression reduction, including the class structure to replace subexpressions.
839 /// 2. Applies idiom detection to e.g. unpack expressions equivalence to literal true or false.
840 /// 3. Restores the equivalence class invariants.
841 fn minimize_once(&mut self, columns: Option<&[SqlColumnType]>) -> bool {
842 // 1. Reduce each expression
843 //
844 // This reduction first looks for subexpression substitutions that can be performed,
845 // and then applies expression reduction if column type information is provided.
846 for class in self.classes.iter_mut() {
847 for expr in class.iter_mut() {
848 self.remap.reduce_child(expr);
849 if let Some(columns) = columns {
850 let orig_expr = expr.clone();
851 expr.reduce(columns);
852 if expr.contains_err() {
853 // Rollback to the original expression if it contains an error.
854 *expr = orig_expr;
855 }
856 }
857 }
858 }
859
860 // 2. Identify idioms
861 // E.g. If Eq(x, y) must be true, we can introduce classes `[x, y]` and `[false, IsNull(x), IsNull(y)]`.
862 let mut to_add = Vec::new();
863 for class in self.classes.iter_mut() {
864 if Self::class_contains_literal(class, |e| e == &Ok(Datum::True)) {
865 for expr in class.iter() {
866 // If Eq(x, y) must be true, we can introduce classes `[x, y]` and `[false, IsNull(x), IsNull(y)]`.
867 // This substitution replaces a complex expression with several smaller expressions, and cannot
868 // cycle if we follow that practice.
869 if let MirScalarExpr::CallBinary {
870 func: mz_expr::BinaryFunc::Eq,
871 expr1,
872 expr2,
873 } = expr
874 {
875 to_add.push(vec![*expr1.clone(), *expr2.clone()]);
876 to_add.push(vec![
877 MirScalarExpr::literal_false(),
878 expr1.clone().call_is_null(),
879 expr2.clone().call_is_null(),
880 ]);
881 }
882 }
883 // Remove the more complex form of the expression.
884 class.retain(|expr| {
885 if let MirScalarExpr::CallBinary {
886 func: mz_expr::BinaryFunc::Eq,
887 ..
888 } = expr
889 {
890 false
891 } else {
892 true
893 }
894 });
895 for expr in class.iter() {
896 // If TRUE == NOT(X) then FALSE == X is a simpler form.
897 if let MirScalarExpr::CallUnary {
898 func: mz_expr::UnaryFunc::Not(_),
899 expr: e,
900 } = expr
901 {
902 to_add.push(vec![MirScalarExpr::literal_false(), (**e).clone()]);
903 }
904 }
905 class.retain(|expr| {
906 if let MirScalarExpr::CallUnary {
907 func: mz_expr::UnaryFunc::Not(_),
908 ..
909 } = expr
910 {
911 false
912 } else {
913 true
914 }
915 });
916 }
917 if Self::class_contains_literal(class, |e| e == &Ok(Datum::False)) {
918 for expr in class.iter() {
919 // If FALSE == NOT(X) then TRUE == X is a simpler form.
920 if let MirScalarExpr::CallUnary {
921 func: mz_expr::UnaryFunc::Not(_),
922 expr: e,
923 } = expr
924 {
925 to_add.push(vec![MirScalarExpr::literal_true(), (**e).clone()]);
926 }
927 }
928 class.retain(|expr| {
929 if let MirScalarExpr::CallUnary {
930 func: mz_expr::UnaryFunc::Not(_),
931 ..
932 } = expr
933 {
934 false
935 } else {
936 true
937 }
938 });
939 }
940 }
941 self.classes.extend(to_add);
942
943 // 3. Restore equivalence relation structure and observe if any changes result.
944 self.refresh()
945 }
946
947 /// Produce the equivalences present in both inputs.
948 pub fn union(&self, other: &Self) -> Self {
949 self.union_many([other])
950 }
951
952 /// The equivalence classes of terms equivalent in all inputs.
953 ///
954 /// This method relies on the `remap` member of each input, and bases the intersection on these rather than `classes`.
955 /// This means one should ensure `minimize()` has been called on all inputs, or risk getting a stale, but conservatively
956 /// correct, result.
957 ///
958 /// This method currently misses opportunities, because it only looks for exactly matches in expressions,
959 /// which may not include all possible matches. For example, `f(#1) == g(#1)` may exist in one class, but
960 /// in another class where `#0 == #1` it may exist as `f(#0) == g(#0)`.
961 pub fn union_many<'a, I>(&self, others: I) -> Self
962 where
963 I: IntoIterator<Item = &'a Self>,
964 {
965 // List of expressions in the intersection, and a proxy equivalence class identifier.
966 let mut intersection: Vec<(&MirScalarExpr, usize)> = Default::default();
967 // Map from expression to a proxy equivalence class identifier.
968 let mut rekey: BTreeMap<&MirScalarExpr, usize> = Default::default();
969 for (key, val) in self.remap.iter() {
970 if !rekey.contains_key(val) {
971 rekey.insert(val, rekey.len());
972 }
973 intersection.push((key, rekey[val]));
974 }
975 for other in others {
976 // Map from proxy equivalence class identifier and equivalence class expr to a new proxy identifier.
977 let mut rekey: BTreeMap<(usize, &MirScalarExpr), usize> = Default::default();
978 intersection.retain_mut(|(key, idx)| {
979 if let Some(val) = other.remap.get(key) {
980 if !rekey.contains_key(&(*idx, val)) {
981 rekey.insert((*idx, val), rekey.len());
982 }
983 *idx = rekey[&(*idx, val)];
984 true
985 } else {
986 false
987 }
988 });
989 }
990 let mut classes: BTreeMap<_, Vec<MirScalarExpr>> = Default::default();
991 for (key, vals) in intersection {
992 classes.entry(vals).or_default().push(key.clone())
993 }
994 let classes = classes.into_values().collect::<Vec<_>>();
995 let mut equivalences = EquivalenceClasses {
996 classes,
997 remap: Default::default(),
998 };
999 equivalences.minimize(None);
1000 equivalences
1001 }
1002
1003 /// Permutes each expression, looking up each column reference in `permutation` and replacing with what it finds.
1004 pub fn permute(&mut self, permutation: &[usize]) {
1005 for class in self.classes.iter_mut() {
1006 for expr in class.iter_mut() {
1007 expr.permute(permutation);
1008 }
1009 }
1010 self.remap.clear();
1011 self.minimize(None);
1012 }
1013
1014 /// Subject the constraints to the column projection, reworking and removing equivalences.
1015 ///
1016 /// This method should also introduce equivalences representing any repeated columns.
1017 pub fn project<I>(&mut self, output_columns: I)
1018 where
1019 I: IntoIterator<Item = usize> + Clone,
1020 {
1021 // Retain the first instance of each column, and record subsequent instances as duplicates.
1022 let mut dupes = Vec::new();
1023 let mut remap = BTreeMap::default();
1024 for (idx, col) in output_columns.into_iter().enumerate() {
1025 if let Some(pos) = remap.get(&col) {
1026 dupes.push((*pos, idx));
1027 } else {
1028 remap.insert(col, idx);
1029 }
1030 }
1031
1032 // Some expressions may be "localized" in that they only reference columns in `output_columns`.
1033 // Many expressions may not be localized, but may reference canonical non-localized expressions
1034 // for classes that contain a localized expression; in that case we can "backport" the localized
1035 // expression to give expressions referencing the canonical expression a shot at localization.
1036 //
1037 // Expressions should only contain instances of canonical expressions, and so we shouldn't need
1038 // to look any further than backporting those. Backporting should have the property that the simplest
1039 // localized expression in each class does not contain any non-localized canonical expressions
1040 // (as that would make it non-localized); our backporting of non-localized canonicals with localized
1041 // expressions should never fire a second
1042
1043 // Let's say an expression is "localized" once we are able to rewrite its support in terms of `output_columns`.
1044 // Not all expressions can be localized, although some of them may be equivalent to localized expressions.
1045 // As we find localized expressions, we can replace uses of their equivalent representative with them,
1046 // which may allow further expression localization.
1047 // We continue the process until no further classes can be localized.
1048
1049 // A map from representatives to our first localization of their equivalence class.
1050 let mut localized = false;
1051 while !localized {
1052 localized = true;
1053 let mut current_map = BTreeMap::default();
1054 for class in self.classes.iter_mut() {
1055 if !class[0].support().iter().all(|c| remap.contains_key(c)) {
1056 if let Some(pos) = class
1057 .iter()
1058 .position(|e| e.support().iter().all(|c| remap.contains_key(c)))
1059 {
1060 class.swap(0, pos);
1061 localized = false;
1062 }
1063 }
1064 for expr in class[1..].iter() {
1065 current_map.insert(expr.clone(), class[0].clone());
1066 }
1067 }
1068
1069 // attempt to replace representatives with equivalent localizeable expressions.
1070 for class_index in 0..self.classes.len() {
1071 for index in 0..self.classes[class_index].len() {
1072 current_map.reduce_child(&mut self.classes[class_index][index]);
1073 }
1074 }
1075 // NB: Do *not* `self.minimize()`, as we are developing localizable rather than canonical representatives.
1076 }
1077
1078 // Localize all localizable expressions and discard others.
1079 for class in self.classes.iter_mut() {
1080 class.retain(|e| e.support().iter().all(|c| remap.contains_key(c)));
1081 for expr in class.iter_mut() {
1082 expr.permute_map(&remap);
1083 }
1084 }
1085 self.classes.retain(|c| c.len() > 1);
1086 // If column repetitions, introduce them as equivalences.
1087 // We introduce only the equivalence to the first occurrence, and rely on minimization to collect them.
1088 for (col1, col2) in dupes {
1089 self.classes.push(vec![
1090 MirScalarExpr::column(col1),
1091 MirScalarExpr::column(col2),
1092 ]);
1093 }
1094 self.remap.clear();
1095 self.minimize(None);
1096 }
1097
1098 /// True if any equivalence class contains two distinct non-error literals.
1099 pub fn unsatisfiable(&self) -> bool {
1100 for class in self.classes.iter() {
1101 let mut literal_ok = None;
1102 for expr in class.iter() {
1103 if let MirScalarExpr::Literal(Ok(row), _) = expr {
1104 if literal_ok.is_some() && literal_ok != Some(row) {
1105 return true;
1106 } else {
1107 literal_ok = Some(row);
1108 }
1109 }
1110 }
1111 }
1112 false
1113 }
1114
1115 /// Returns a map that can be used to replace (sub-)expressions.
1116 pub fn reducer(&self) -> &BTreeMap<MirScalarExpr, MirScalarExpr> {
1117 &self.remap
1118 }
1119
1120 /// Examines the prefix of `class` of literals, looking for any satisfying `predicate`.
1121 ///
1122 /// This test bails out as soon as it sees a non-literal, and may have false negatives
1123 /// if the data are not sorted with literals at the front.
1124 fn class_contains_literal<P>(class: &[MirScalarExpr], mut predicate: P) -> bool
1125 where
1126 P: FnMut(&Result<Datum, &mz_expr::EvalError>) -> bool,
1127 {
1128 class
1129 .iter()
1130 .take_while(|e| e.is_literal())
1131 .filter_map(|e| e.as_literal())
1132 .any(move |e| predicate(&e))
1133 }
1134}
1135
1136/// A type capable of simplifying `MirScalarExpr`s.
1137pub trait ExpressionReducer {
1138 /// Attempt to replace `expr` itself with another expression.
1139 /// Returns true if it does so.
1140 fn replace(&self, expr: &mut MirScalarExpr) -> bool;
1141 /// Attempt to replace any subexpressions of `expr` with other expressions.
1142 /// Returns true if it does so.
1143 fn reduce_expr(&self, expr: &mut MirScalarExpr) -> bool {
1144 let mut simplified = false;
1145 simplified = simplified || self.reduce_child(expr);
1146 simplified = simplified || self.replace(expr);
1147 simplified
1148 }
1149 /// Attempt to replace any subexpressions of `expr`'s children with other expressions.
1150 /// Returns true if it does so.
1151 fn reduce_child(&self, expr: &mut MirScalarExpr) -> bool {
1152 let mut simplified = false;
1153 match expr {
1154 MirScalarExpr::CallBinary { expr1, expr2, .. } => {
1155 simplified = self.reduce_expr(expr1) || simplified;
1156 simplified = self.reduce_expr(expr2) || simplified;
1157 }
1158 MirScalarExpr::CallUnary { expr, .. } => {
1159 simplified = self.reduce_expr(expr) || simplified;
1160 }
1161 MirScalarExpr::CallVariadic { exprs, .. } => {
1162 for expr in exprs.iter_mut() {
1163 simplified = self.reduce_expr(expr) || simplified;
1164 }
1165 }
1166 MirScalarExpr::If { cond: _, then, els } => {
1167 // Do not simplify `cond`, as we cannot ensure the simplification
1168 // continues to hold as expressions migrate around.
1169 simplified = self.reduce_expr(then) || simplified;
1170 simplified = self.reduce_expr(els) || simplified;
1171 }
1172 _ => {}
1173 }
1174 simplified
1175 }
1176}
1177
1178impl ExpressionReducer for BTreeMap<&MirScalarExpr, &MirScalarExpr> {
1179 /// Perform any exact replacement for `expr`, report if it had an effect.
1180 fn replace(&self, expr: &mut MirScalarExpr) -> bool {
1181 if let Some(other) = self.get(expr) {
1182 if other != &expr {
1183 expr.clone_from(other);
1184 return true;
1185 }
1186 }
1187 false
1188 }
1189}
1190
1191impl ExpressionReducer for BTreeMap<MirScalarExpr, MirScalarExpr> {
1192 /// Perform any exact replacement for `expr`, report if it had an effect.
1193 fn replace(&self, expr: &mut MirScalarExpr) -> bool {
1194 if let Some(other) = self.get(expr) {
1195 if other != expr {
1196 expr.clone_from(other);
1197 return true;
1198 }
1199 }
1200 false
1201 }
1202}
1203
1204/// True iff the aggregate function returns an input datum.
1205fn aggregate_is_input(aggregate: &AggregateFunc) -> bool {
1206 match aggregate {
1207 AggregateFunc::MaxInt16
1208 | AggregateFunc::MaxInt32
1209 | AggregateFunc::MaxInt64
1210 | AggregateFunc::MaxUInt16
1211 | AggregateFunc::MaxUInt32
1212 | AggregateFunc::MaxUInt64
1213 | AggregateFunc::MaxMzTimestamp
1214 | AggregateFunc::MaxFloat32
1215 | AggregateFunc::MaxFloat64
1216 | AggregateFunc::MaxBool
1217 | AggregateFunc::MaxString
1218 | AggregateFunc::MaxDate
1219 | AggregateFunc::MaxTimestamp
1220 | AggregateFunc::MaxTimestampTz
1221 | AggregateFunc::MinInt16
1222 | AggregateFunc::MinInt32
1223 | AggregateFunc::MinInt64
1224 | AggregateFunc::MinUInt16
1225 | AggregateFunc::MinUInt32
1226 | AggregateFunc::MinUInt64
1227 | AggregateFunc::MinMzTimestamp
1228 | AggregateFunc::MinFloat32
1229 | AggregateFunc::MinFloat64
1230 | AggregateFunc::MinBool
1231 | AggregateFunc::MinString
1232 | AggregateFunc::MinDate
1233 | AggregateFunc::MinTimestamp
1234 | AggregateFunc::MinTimestampTz
1235 | AggregateFunc::Any
1236 | AggregateFunc::All => true,
1237 _ => false,
1238 }
1239}