mz_transform/analysis/
equivalences.rs

1// Copyright Materialize, Inc. and contributors. All rights reserved.
2//
3// Use of this software is governed by the Business Source License
4// included in the LICENSE file.
5//
6// As of the Change Date specified in that file, in accordance with
7// the Business Source License, use of this software will be governed
8// by the Apache License, Version 2.0.
9
10//! An analysis that reports all known-equivalent expressions for each relation.
11//!
12//! Expressions are equivalent at a relation if they are certain to evaluate to
13//! the same `Datum` for all records in the relation.
14//!
//! Equivalences are recorded in an `EquivalenceClasses`, which lists all known
//! equivalence classes, each a list of equivalent expressions.
17
18use std::collections::BTreeMap;
19use std::fmt::Formatter;
20
21use itertools::Itertools;
22use mz_expr::canonicalize::{UnionFind, canonicalize_equivalence_classes};
23use mz_expr::explain::{HumanizedExplain, HumanizerMode};
24use mz_expr::{AggregateFunc, Id, MirRelationExpr, MirScalarExpr};
25use mz_ore::str::{bracketed, separated};
26use mz_repr::{ColumnType, Datum};
27
28use crate::analysis::{Analysis, Lattice};
29use crate::analysis::{Arity, RelationType};
30use crate::analysis::{Derived, DerivedBuilder};
31
/// Pulls up and pushes down predicate information represented as equivalences.
///
/// The analysis value for each expression is an `Option<EquivalenceClasses>`;
/// see the `Analysis` implementation for the meaning of `None`.
#[derive(Debug, Default)]
pub struct Equivalences;
35
impl Analysis for Equivalences {
    // A `Some(list)` indicates a list of classes of equivalent expressions.
    // A `None` indicates all expressions are equivalent, including contradictions;
    // this is only possible for the empty collection, and as an initial result for
    // unconstrained recursive terms.
    type Value = Option<EquivalenceClasses>;

    /// Registers the analyses whose results `derive` reads through `depends`:
    /// `Arity` (column offsets) and `RelationType` (column types for minimization).
    fn announce_dependencies(builder: &mut DerivedBuilder) {
        builder.require(Arity);
        builder.require(RelationType); // needed for expression reduction.
    }

    /// Derives the equivalences that hold for the output of `expr`.
    ///
    /// `index` is the position of `expr` in `results` and in the dependency results;
    /// single-input operators read their input's value from `results[index - 1]`.
    /// The per-variant result is minimized against the expression's `RelationType`
    /// (when available) before being returned.
    fn derive(
        &self,
        expr: &MirRelationExpr,
        index: usize,
        results: &[Self::Value],
        depends: &Derived,
    ) -> Self::Value {
        let mut equivalences = match expr {
            MirRelationExpr::Constant { rows, typ } => {
                // Trawl `rows` for any constant information worth recording.
                // Literal columns may be valuable; non-nullability could be too.
                let mut equivalences = EquivalenceClasses::default();
                // Only non-error constants with at least one row are inspected.
                if let Ok([(row, _cnt), rows @ ..]) = rows.as_deref() {
                    // Vector of `Option<Datum>` which becomes `None` once a column has a second datum.
                    let len = row.iter().count();
                    let mut common = Vec::with_capacity(len);
                    common.extend(row.iter().map(Some));
                    // Prep initial nullability information.
                    let mut nullable_cols = common
                        .iter()
                        .map(|datum| datum == &Some(Datum::Null))
                        .collect::<Vec<_>>();

                    // Fold the remaining rows into the per-column summaries.
                    for (row, _cnt) in rows.iter() {
                        for ((datum, common), nullable) in row
                            .iter()
                            .zip_eq(common.iter_mut())
                            .zip_eq(nullable_cols.iter_mut())
                        {
                            if Some(datum) != *common {
                                *common = None;
                            }
                            if datum == Datum::Null {
                                *nullable = true;
                            }
                        }
                    }
                    // Columns that hold a single datum across all rows equal that literal.
                    for (index, common) in common.into_iter().enumerate() {
                        if let Some(datum) = common {
                            equivalences.classes.push(vec![
                                MirScalarExpr::column(index),
                                MirScalarExpr::literal_ok(
                                    datum,
                                    typ.column_types[index].scalar_type.clone(),
                                ),
                            ]);
                        }
                    }
                    // If any columns are non-null, introduce this fact.
                    if nullable_cols.iter().any(|x| !*x) {
                        // One class equating `false` with `#i IS NULL` for each non-null column.
                        let mut class = vec![MirScalarExpr::literal_false()];
                        for (index, nullable) in nullable_cols.iter().enumerate() {
                            if !*nullable {
                                class.push(MirScalarExpr::column(index).call_is_null());
                            }
                        }
                        equivalences.classes.push(class);
                    }
                }
                Some(equivalences)
            }
            MirRelationExpr::Get { id, typ, .. } => {
                let mut equivalences = Some(EquivalenceClasses::default());
                // Find local identifiers, but nothing for external identifiers.
                if let Id::Local(id) = id {
                    if let Some(offset) = depends.bindings().get(id) {
                        // It is possible we have derived nothing for a recursive term
                        if let Some(result) = results.get(*offset) {
                            equivalences.clone_from(result);
                        } else {
                            // No top element was prepared.
                            // This means we are executing pessimistically,
                            // but perhaps we must because optimism is off.
                        }
                    }
                }
                // Incorporate statements about column nullability.
                let mut non_null_cols = vec![MirScalarExpr::literal_false()];
                for (index, col_type) in typ.column_types.iter().enumerate() {
                    if !col_type.nullable {
                        non_null_cols.push(MirScalarExpr::column(index).call_is_null());
                    }
                }
                // Only record the class if at least one column is non-nullable;
                // a `None` value is left untouched.
                if non_null_cols.len() > 1 {
                    if let Some(equivalences) = equivalences.as_mut() {
                        equivalences.classes.push(non_null_cols);
                    }
                }

                equivalences
            }
            // `Let` and `LetRec` adopt the value at `index - 1`
            // (presumably the body's result; confirm against the traversal order).
            MirRelationExpr::Let { .. } => results.get(index - 1).unwrap().clone(),
            MirRelationExpr::LetRec { .. } => results.get(index - 1).unwrap().clone(),
            MirRelationExpr::Project { outputs, .. } => {
                // restrict equivalences, and introduce equivalences for repeated outputs.
                let mut equivalences = results.get(index - 1).unwrap().clone();
                equivalences
                    .as_mut()
                    .map(|e| e.project(outputs.iter().cloned()));
                equivalences
            }
            MirRelationExpr::Map { scalars, .. } => {
                // introduce equivalences for new columns and expressions that define them.
                let mut equivalences = results.get(index - 1).unwrap().clone();
                if let Some(equivalences) = &mut equivalences {
                    let input_arity = depends.results::<Arity>()[index - 1];
                    // The `pos`-th mapped scalar defines output column `input_arity + pos`.
                    for (pos, expr) in scalars.iter().enumerate() {
                        equivalences
                            .classes
                            .push(vec![MirScalarExpr::column(input_arity + pos), expr.clone()]);
                    }
                }
                equivalences
            }
            // `FlatMap` passes its input's equivalences through unchanged.
            MirRelationExpr::FlatMap { .. } => results.get(index - 1).unwrap().clone(),
            MirRelationExpr::Filter { predicates, .. } => {
                let mut equivalences = results.get(index - 1).unwrap().clone();
                if let Some(equivalences) = &mut equivalences {
                    // All predicates are placed in one class together with the literal `true`.
                    let mut class = predicates.clone();
                    class.push(MirScalarExpr::literal_ok(
                        Datum::True,
                        mz_repr::ScalarType::Bool,
                    ));
                    equivalences.classes.push(class);
                }
                equivalences
            }
            MirRelationExpr::Join { equivalences, .. } => {
                // Collect equivalences from all inputs;
                let expr_index = index;
                // `children_of_rev` yields children in reverse; restore input order.
                let mut children = depends
                    .children_of_rev(expr_index, expr.children().count())
                    .collect::<Vec<_>>();
                children.reverse();

                let arity = depends.results::<Arity>();
                // Running count of columns contributed by the inputs visited so far.
                let mut columns = 0;
                let mut result = Some(EquivalenceClasses::default());
                for child in children.into_iter() {
                    let input_arity = arity[child];
                    let equivalences = results[child].clone();
                    if let Some(mut equivalences) = equivalences {
                        // Shift the input's column references to their position in the join output.
                        let permutation = (columns..(columns + input_arity)).collect::<Vec<_>>();
                        equivalences.permute(&permutation);
                        result
                            .as_mut()
                            .map(|e| e.classes.extend(equivalences.classes));
                    } else {
                        // Any input with a `None` value makes the join's value `None`.
                        result = None;
                    }
                    columns += input_arity;
                }

                // Fold join equivalences into our results.
                result
                    .as_mut()
                    .map(|e| e.classes.extend(equivalences.iter().cloned()));
                result
            }
            MirRelationExpr::Reduce {
                group_key,
                aggregates,
                ..
            } => {
                let input_arity = depends.results::<Arity>()[index - 1];
                let mut equivalences = results.get(index - 1).unwrap().clone();
                if let Some(equivalences) = &mut equivalences {
                    // Introduce keys column equivalences as if a map, then project to those columns.
                    // This should retain as much information as possible about these columns.
                    for (pos, expr) in group_key.iter().enumerate() {
                        equivalences
                            .classes
                            .push(vec![MirScalarExpr::column(input_arity + pos), expr.clone()]);
                    }

                    // Having added classes to `equivalences`, we should minimize the classes to fold the
                    // information in before applying the `project`, to set it up for success.
                    equivalences.minimize(None);

                    // Grab a copy of the equivalences with key columns added to use in aggregate reasoning.
                    let extended = equivalences.clone();
                    // Now project down the equivalences, as we will extend them in terms of the output columns.
                    equivalences.project(input_arity..(input_arity + group_key.len()));

                    // TODO: MIN, MAX, ANY, ALL aggregates pass through all certain properties of their columns.
                    // They also pass through equivalences of them and other constant columns (e.g. key columns).
                    // However, it is not correct to simply project onto these columns, as relationships amongst
                    // aggregate columns may no longer be preserved. MAX(col) != MIN(col) even though col = col.
                    // The correct thing to do is treat the reduce as a join between single-aggregate reductions,
                    // where each single MIN/MAX/ANY/ALL aggregate propagates equivalences.
                    for (index, aggregate) in aggregates.iter().enumerate() {
                        if aggregate_is_input(&aggregate.func) {
                            // Treat this aggregate as one extra mapped column after the keys.
                            let mut temp_equivs = extended.clone();
                            temp_equivs.classes.push(vec![
                                MirScalarExpr::column(input_arity + group_key.len()),
                                aggregate.expr.clone(),
                            ]);
                            temp_equivs.minimize(None);
                            temp_equivs.project(input_arity..(input_arity + group_key.len() + 1));
                            // Keys keep their positions; the single aggregate column moves to
                            // its actual output position `group_key.len() + index`.
                            let columns = (0..group_key.len())
                                .chain(std::iter::once(group_key.len() + index))
                                .collect::<Vec<_>>();
                            temp_equivs.permute(&columns[..]);
                            equivalences.classes.extend(temp_equivs.classes);
                        }
                    }
                }
                equivalences
            }
            // These operators pass their input's equivalences through unchanged.
            MirRelationExpr::TopK { .. } => results.get(index - 1).unwrap().clone(),
            MirRelationExpr::Negate { .. } => results.get(index - 1).unwrap().clone(),
            MirRelationExpr::Threshold { .. } => results.get(index - 1).unwrap().clone(),
            MirRelationExpr::Union { .. } => {
                let expr_index = index;
                // `flat_map` over `&Option<_>` skips children whose value is `None`.
                let mut child_equivs = depends
                    .children_of_rev(expr_index, expr.children().count())
                    .flat_map(|c| &results[c]);
                if let Some(first) = child_equivs.next() {
                    Some(first.union_many(child_equivs))
                } else {
                    // Every child was `None`, so the union is `None` as well.
                    None
                }
            }
            MirRelationExpr::ArrangeBy { .. } => results.get(index - 1).unwrap().clone(),
        };

        // Normalize the result, using this expression's column types when they are known.
        let expr_type = depends.results::<RelationType>()[index].as_ref();
        equivalences
            .as_mut()
            .map(|e| e.minimize(expr_type.map(|x| &x[..])));
        equivalences
    }

    /// Supplies the lattice over `Option<EquivalenceClasses>` for this analysis.
    fn lattice() -> Option<Box<dyn Lattice<Self::Value>>> {
        Some(Box::new(EQLattice))
    }
}
285
/// The lattice over `Option<EquivalenceClasses>` handed out by `Equivalences::lattice`.
///
/// Its top element is `None`.
struct EQLattice;
287
288impl Lattice<Option<EquivalenceClasses>> for EQLattice {
289    fn top(&self) -> Option<EquivalenceClasses> {
290        None
291    }
292
293    fn meet_assign(
294        &self,
295        a: &mut Option<EquivalenceClasses>,
296        b: Option<EquivalenceClasses>,
297    ) -> bool {
298        match (&mut *a, b) {
299            (_, None) => false,
300            (None, b) => {
301                *a = b;
302                true
303            }
304            (Some(a), Some(b)) => {
305                let mut c = a.union(&b);
306                std::mem::swap(a, &mut c);
307                a != &mut c
308            }
309        }
310    }
311}
312
/// A compact representation of classes of expressions that must be equivalent.
///
/// Each "class" contains a list of expressions, each of which must be `Eq::eq` equal.
/// Ideally, the first element is the "simplest", e.g. a literal or column reference,
/// and any other element of that list can be replaced by it.
///
/// The classes are meant to be minimized, with each expression as reduced as it can be,
/// and all classes sharing an element merged.
#[derive(Clone, Eq, PartialEq, Ord, PartialOrd, Default, Debug)]
pub struct EquivalenceClasses {
    /// Multiple lists of equivalent expressions, each representing an equivalence class.
    ///
    /// The first element should be the "canonical" simplest element, that any other element
    /// can be replaced by.
    /// These classes are unified whenever possible, to minimize the number of classes.
    /// They are only guaranteed to form an equivalence relation after a call to `minimize`,
    /// which refreshes both `self.classes` and `self.remap`.
    pub classes: Vec<Vec<MirScalarExpr>>,

    /// An expression simplification map.
    ///
    /// When up to date, it maps every expression in `self.classes` to the first
    /// (representative) element of its class.
    ///
    /// This map reflects an equivalence relation based on a prior version of `self.classes`.
    /// As users may add to `self.classes`, `self.remap` may become stale. We refresh `remap`
    /// only in `self.refresh()`, to the equivalence relation that derives from `self.classes`.
    ///
    /// It is important to `self.remap.clear()` if you invalidate it by mutating rather than
    /// appending to `self.classes`. This will be corrected in the next call to `self.refresh()`,
    /// but until then `remap` could be arbitrarily wrong. This should be improved in the future.
    remap: BTreeMap<MirScalarExpr, MirScalarExpr>,
}
343
/// An enum that can represent either a standard `EquivalenceClasses` or an
/// `EquivalenceClassesWithholdingErrors`.
///
/// Its methods dispatch to whichever variant is held.
#[derive(Clone, Debug)]
pub enum EqClassesImpl {
    /// Standard equivalence classes.
    EquivalenceClasses(EquivalenceClasses),
    /// Equivalence classes that withhold errors, reintroducing them when extracting equivalences.
    EquivalenceClassesWithholdingErrors(EquivalenceClassesWithholdingErrors),
}
353
354impl EqClassesImpl {
355    /// Returns a reference to the underlying `EquivalenceClasses`.
356    pub fn equivalence_classes(&self) -> &EquivalenceClasses {
357        match self {
358            EqClassesImpl::EquivalenceClasses(classes) => classes,
359            EqClassesImpl::EquivalenceClassesWithholdingErrors(classes) => {
360                &classes.equivalence_classes
361            }
362        }
363    }
364
365    /// Returns a mutable reference to the underlying `EquivalenceClasses`.
366    pub fn equivalence_classes_mut(&mut self) -> &mut EquivalenceClasses {
367        match self {
368            EqClassesImpl::EquivalenceClasses(classes) => classes,
369            EqClassesImpl::EquivalenceClassesWithholdingErrors(classes) => {
370                &mut classes.equivalence_classes
371            }
372        }
373    }
374
375    /// Extend the equivalence classes with new equivalences.
376    pub fn extend_equivalences(&mut self, equivalences: Vec<Vec<MirScalarExpr>>) {
377        match self {
378            EqClassesImpl::EquivalenceClasses(classes) => {
379                classes.classes.extend(equivalences);
380            }
381            EqClassesImpl::EquivalenceClassesWithholdingErrors(classes) => {
382                classes.extend_equivalences(equivalences);
383            }
384        }
385    }
386
387    /// Push a new equivalence class.
388    pub fn push_equivalences(&mut self, equivalences: Vec<MirScalarExpr>) {
389        match self {
390            EqClassesImpl::EquivalenceClasses(classes) => {
391                classes.classes.push(equivalences);
392            }
393            EqClassesImpl::EquivalenceClassesWithholdingErrors(classes) => {
394                classes.push_equivalences(equivalences);
395            }
396        }
397    }
398
399    /// Subject the constraints to the column projection, reworking and removing equivalences.
400    pub fn project<I>(&mut self, output_columns: I)
401    where
402        I: IntoIterator<Item = usize> + Clone,
403    {
404        match self {
405            EqClassesImpl::EquivalenceClasses(classes) => {
406                classes.project(output_columns);
407            }
408            EqClassesImpl::EquivalenceClassesWithholdingErrors(classes) => {
409                classes.project(output_columns);
410            }
411        }
412    }
413
414    /// Extract the equivalences
415    pub fn extract_equivalences(
416        &mut self,
417        columns: Option<&[ColumnType]>,
418    ) -> Vec<Vec<MirScalarExpr>> {
419        match self {
420            EqClassesImpl::EquivalenceClasses(classes) => classes.classes.clone(),
421            EqClassesImpl::EquivalenceClassesWithholdingErrors(classes) => {
422                classes.extract_equivalences(columns)
423            }
424        }
425    }
426}
427
/// A wrapper struct for equivalence classes that withholds errors
/// from the underlying equivalence classes, reintroducing them
/// when extracting equivalences.
#[derive(Clone, Eq, PartialEq, Ord, PartialOrd, Default, Debug)]
pub struct EquivalenceClassesWithholdingErrors {
    /// The underlying equivalence classes.
    pub equivalence_classes: EquivalenceClasses,
    /// A list of equivalence classes that contain errors, which we hold back.
    ///
    /// Each entry pairs an optional error-free "anchor" expression from the original
    /// class with the error-containing expressions that were split off from it.
    pub held_back_classes: Vec<(Option<MirScalarExpr>, Vec<MirScalarExpr>)>,
}
438
439impl EquivalenceClassesWithholdingErrors {
440    /// Extend the equivalence classes with new equivalences.
441    pub fn extend_equivalences(&mut self, equivalences: Vec<Vec<MirScalarExpr>>) {
442        for class in equivalences {
443            self.push_equivalences(class);
444        }
445    }
446
447    /// Push a new equivalence class, which may contain errors.
448    pub fn push_equivalences(&mut self, equivalences: Vec<MirScalarExpr>) {
449        let (with_errs, without_errs): (Vec<_>, Vec<_>) =
450            equivalences.into_iter().partition(|e| e.contains_err());
451        self.held_back_classes
452            .push((without_errs.first().cloned(), with_errs));
453        if without_errs.len() > 1 {
454            self.equivalence_classes.classes.push(without_errs);
455        }
456    }
457
458    /// Subject the constraints to the column projection, reworking and removing equivalences.
459    /// This method should also introduce equivalences representing any repeated columns.
460    /// This notably does not project the held back classes.
461    pub fn project<I>(&mut self, output_columns: I)
462    where
463        I: IntoIterator<Item = usize> + Clone,
464    {
465        self.equivalence_classes.project(output_columns.clone());
466    }
467
468    /// Minimize the equivalence classes, and return the result, reintroducing any held back classes.
469    pub fn extract_equivalences(
470        &mut self,
471        columns: Option<&[ColumnType]>,
472    ) -> Vec<Vec<MirScalarExpr>> {
473        self.equivalence_classes.minimize(columns);
474
475        if self.held_back_classes.is_empty() {
476            self.equivalence_classes.classes.clone()
477        } else {
478            let reducer = self.equivalence_classes.reducer();
479            let mut classes = self.equivalence_classes.classes.clone();
480            for (anchor, mut class) in self.held_back_classes.drain(..) {
481                if let Some(mut anchor) = anchor {
482                    // Reduce the anchor expression to its canonical form.
483                    reducer.reduce_expr(&mut anchor);
484                    class.push(anchor.clone());
485                }
486                classes.push(class.clone());
487            }
488            canonicalize_equivalence_classes(&mut classes);
489            classes
490        }
491    }
492}
493
494/// Raw printing of [`EquivalenceClasses`] with default expression humanization.
495/// Don't use this in `EXPLAIN`! For redaction, column name support, etc., see
496/// [`HumanizedEquivalenceClasses`].
497impl std::fmt::Display for EquivalenceClasses {
498    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
499        HumanizedEquivalenceClasses {
500            equivalence_classes: self,
501            cols: None,
502            mode: HumanizedExplain::default(),
503        }
504        .fmt(f)
505    }
506}
507
/// Wrapper struct for human-readable printing of expressions inside [`EquivalenceClasses`].
/// (Similar to `HumanizedExpr`. Unfortunately, we can't just use `HumanizedExpr` here, because
/// we'd need to `impl Display for HumanizedExpr<'a, EquivalenceClasses, M>`, but neither
/// `Display` nor `HumanizedExpr` is defined in this crate.)
///
/// The mode type parameter `M` defaults to `HumanizedExplain`.
#[derive(Debug)]
pub struct HumanizedEquivalenceClasses<'a, M = HumanizedExplain> {
    /// The [`EquivalenceClasses`] to be humanized.
    pub equivalence_classes: &'a EquivalenceClasses,
    /// An optional vector of inferred column names to be used when rendering
    /// column references in expressions.
    pub cols: Option<&'a Vec<String>>,
    /// The rendering mode to use. See `HumanizerMode` for details.
    pub mode: M,
}
522
523impl std::fmt::Display for HumanizedEquivalenceClasses<'_> {
524    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
525        // Only show `classes`.
526        // (The following hopefully avoids allocating any of the intermediate composite strings.)
527        let classes = self.equivalence_classes.classes.iter().map(|class| {
528            format!(
529                "{}",
530                bracketed(
531                    "[",
532                    "]",
533                    separated(
534                        ", ",
535                        class.iter().map(|expr| self.mode.expr(expr, self.cols))
536                    )
537                )
538            )
539        });
540        write!(f, "{}", bracketed("[", "]", separated(", ", classes)))
541    }
542}
543
544impl EquivalenceClasses {
545    /// Comparator function for the complexity of scalar expressions. Simpler expressions are
546    /// smaller. Can be used when we need to decide which of several equivalent expressions to use.
547    pub fn mir_scalar_expr_complexity(
548        e1: &MirScalarExpr,
549        e2: &MirScalarExpr,
550    ) -> std::cmp::Ordering {
551        use MirScalarExpr::*;
552        use std::cmp::Ordering::*;
553        match (e1, e2) {
554            (Literal(_, _), Literal(_, _)) => e1.cmp(e2),
555            (Literal(_, _), _) => Less,
556            (_, Literal(_, _)) => Greater,
557            (Column(_, _), Column(_, _)) => e1.cmp(e2),
558            (Column(_, _), _) => Less,
559            (_, Column(_, _)) => Greater,
560            (x, y) => {
561                // General expressions should be ordered by their size,
562                // to ensure we only simplify expressions by substitution.
563                // If same size, then fall back to the expressions' Ord.
564                match x.size().cmp(&y.size()) {
565                    Equal => x.cmp(y),
566                    other => other,
567                }
568            }
569        }
570    }
571
572    /// Sorts and deduplicates each class, removing literal errors.
573    ///
574    /// This method does not ensure equivalence relation structure, but instead performs
575    /// only minimal structural clean-up.
576    fn tidy(&mut self) {
577        for class in self.classes.iter_mut() {
578            // Remove all literal errors, as they cannot be equated to other things.
579            class.retain(|e| !e.is_literal_err());
580            class.sort_by(Self::mir_scalar_expr_complexity);
581            class.dedup();
582        }
583        self.classes.retain(|c| c.len() > 1);
584        self.classes.sort();
585        self.classes.dedup();
586    }
587
588    /// Restore equivalence relation structure to `self.classes` and refresh `self.remap`.
589    ///
590    /// This method takes roughly linear time, and returns true iff `self.remap` has changed.
591    /// This is the only method that refreshes `self.remap`, and is a perfect place to decide
592    /// whether the equivalence classes it represents have experienced any changes since the
593    /// last refresh.
594    fn refresh(&mut self) -> bool {
595        self.tidy();
596
597        // remap may already be the correct answer, and if so we should avoid the work of rebuilding it.
598        // If it contains the same number of expressions as `self.classes`, and for every expression in
599        // `self.classes` the two agree on the representative, they are identical.
600        if self.remap.len() == self.classes.iter().map(|c| c.len()).sum::<usize>()
601            && self
602                .classes
603                .iter()
604                .all(|c| c.iter().all(|e| self.remap.get(e) == Some(&c[0])))
605        {
606            // No change, so return false.
607            return false;
608        }
609
610        // Optimistically build the `remap` we would want.
611        // Note if any unions would be required, in which case we have further work to do,
612        // including re-forming `self.classes`.
613        let mut union_find = BTreeMap::default();
614        let mut dirtied = false;
615        for class in self.classes.iter() {
616            for expr in class.iter() {
617                if let Some(other) = union_find.insert(expr.clone(), class[0].clone()) {
618                    // A merge is required, but have the more complex expression point at the simpler one.
619                    // This allows `union_find` to end as the `remap` for the new `classes` we form, with
620                    // the only required work being compressing all the paths.
621                    if Self::mir_scalar_expr_complexity(&other, &class[0])
622                        == std::cmp::Ordering::Less
623                    {
624                        union_find.union(&class[0], &other);
625                    } else {
626                        union_find.union(&other, &class[0]);
627                    }
628                    dirtied = true;
629                }
630            }
631        }
632        if dirtied {
633            let mut classes: BTreeMap<_, Vec<_>> = BTreeMap::default();
634            for class in self.classes.drain(..) {
635                for expr in class {
636                    let root: MirScalarExpr = union_find.find(&expr).unwrap().clone();
637                    classes.entry(root).or_default().push(expr);
638                }
639            }
640            self.classes = classes.into_values().collect();
641            self.tidy();
642        }
643
644        let changed = self.remap != union_find;
645        self.remap = union_find;
646        changed
647    }
648
649    /// Update `self` to maintain the same equivalences which potentially reducing along `Ord::le`.
650    ///
651    /// Informally this means simplifying constraints, removing redundant constraints, and unifying equivalence classes.
652    pub fn minimize(&mut self, columns: Option<&[ColumnType]>) {
653        // Repeatedly, we reduce each of the classes themselves, then unify the classes.
654        // This should strictly reduce complexity, and reach a fixed point.
655        // Ideally it is *confluent*, arriving at the same fixed point no matter the order of operations.
656
657        // We should not rely on nullability information present in `column_types`. (Doing this
658        // every time just before calling `reduce` was found to be a bottleneck during incident-217,
659        // so now we do this nullability tweaking only once here.)
660        let mut columns = columns.map(|x| x.to_vec());
661        let mut nonnull = Vec::new();
662        if let Some(columns) = columns.as_mut() {
663            for (index, col) in columns.iter_mut().enumerate() {
664                let is_null = MirScalarExpr::column(index).call_is_null();
665                if !col.nullable
666                    && self
667                        .remap
668                        .get(&is_null)
669                        .map(|e| !e.is_literal_false())
670                        .unwrap_or(true)
671                {
672                    nonnull.push(is_null);
673                }
674                col.nullable = true;
675            }
676        }
677        if !nonnull.is_empty() {
678            nonnull.push(MirScalarExpr::literal_false());
679            self.classes.push(nonnull);
680        }
681
682        // Ensure `self.classes` and `self.remap` are equivalence relations.
683        // Users are allowed to mutate `self.classes`, so we must perform this normalization at least once.
684        // We have also likely mutated `self.classes` just above with non-nullability information.
685        self.refresh();
686
687        // Termination will be detected by comparing to the map of equivalence classes.
688        let mut previous = Some(self.remap.clone());
689        while let Some(prev) = previous {
690            // Attempt to add new equivalences.
691            let novel = self.expand();
692            if !novel.is_empty() {
693                self.classes.extend(novel);
694                self.refresh();
695            }
696
697            // We continue as long as any simplification has occurred.
698            // An expression can be simplified, a duplication found, or two classes unified.
699            let mut stable = false;
700            while !stable {
701                stable = !self.minimize_once(columns.as_ref().map(|x| &x[..]));
702            }
703
704            // Termination detection.
705            if prev != self.remap {
706                previous = Some(self.remap.clone());
707            } else {
708                previous = None;
709            }
710        }
711    }
712
713    /// Proposes new equivalences that are likely to be novel.
714    ///
715    /// This method invokes `self.implications()` to propose equivalences, and then judges them to be
716    /// novel or not based on existing knowledge, reducing the equivalences down to their novel core.
717    /// This method may produce non-novel equivalences, due to its inability to perform `MSE::reduce`.
718    /// We can end up with e.g. constant expressions that cannot be found until they are so reduced.
719    /// The novelty detection is best-effort, and meant to provide a clearer signal and minimize the
720    /// number of times we call and amount of work we do in `self.refresh()`.
721    fn expand(&self) -> Vec<Vec<MirScalarExpr>> {
722        // Consider expanding `self.classes` with novel equivalences.
723        let mut novel = self.implications();
724        for class in novel.iter_mut() {
725            // reduce each expression to its canonical form.
726            for expr in class.iter_mut() {
727                self.remap.reduce_expr(expr);
728            }
729            class.sort();
730            class.dedup();
731            // for a class to be interesting we require at least two elements that do not reference the same root.
732            let common_class = class
733                .iter()
734                .map(|x| self.remap.get(x))
735                .reduce(|prev, this| if prev == this { prev } else { None });
736            if class.len() == 1 || common_class != Some(None) {
737                class.clear();
738            }
739        }
740        novel.retain(|c| !c.is_empty());
741        novel
742    }
743
744    /// Derives potentially novel equivalences without regard for minimization.
745    ///
746    /// This is an opportunity to explore equivalences that do not correspond to expression minimization,
747    /// and therefore should not be used in `minimize_once`. They are still potentially important, but
748    /// required additional guardrails to ensure we reach a fixed point.
749    ///
750    /// The implications will be introduced into `self.classes` and will prompt a round of minimization,
751    /// making it somewhat polite to avoid producing outputs that cannot result in novel equivalences.
752    /// For example, before producing a new equivalence, one could check that the involved terms are not
753    /// already present in the same class.
754    fn implications(&self) -> Vec<Vec<MirScalarExpr>> {
755        let mut new_equivalences = Vec::new();
756
757        // If we see `false == IsNull(foo)` we can add the non-null implications of `foo`.
758        let mut non_null = std::collections::BTreeSet::default();
759        for class in self.classes.iter() {
760            if Self::class_contains_literal(class, |e| e == &Ok(Datum::False)) {
761                for e in class.iter() {
762                    if let MirScalarExpr::CallUnary {
763                        func: mz_expr::UnaryFunc::IsNull(_),
764                        expr,
765                    } = e
766                    {
767                        expr.non_null_requirements(&mut non_null);
768                    }
769                }
770            }
771        }
772        // If we see `true == foo` we can add the non-null implications of `foo`.
773        // TODO: generalize to arbitrary non-null, non-error literals; at the moment `true == pred` is
774        // an important idiom to identify for how we express predicates.
775        for class in self.classes.iter() {
776            if Self::class_contains_literal(class, |e| e == &Ok(Datum::True)) {
777                for expr in class.iter() {
778                    expr.non_null_requirements(&mut non_null);
779                }
780            }
781        }
782        // Only keep constraints that are not already known.
783        // Known constraints will present as `COL(_) IS NULL == false`,
784        // which can only happen if `false` is present, and both terms
785        // map to the same canonical representative>
786        let lit_false = MirScalarExpr::literal_false();
787        let target = self.remap.get(&lit_false);
788        if target.is_some() {
789            non_null.retain(|c| {
790                let is_null = MirScalarExpr::column(*c).call_is_null();
791                self.remap.get(&is_null) != target
792            });
793        }
794        if !non_null.is_empty() {
795            let mut class = Vec::with_capacity(non_null.len() + 1);
796            class.push(MirScalarExpr::literal_false());
797            class.extend(
798                non_null
799                    .into_iter()
800                    .map(|c| MirScalarExpr::column(c).call_is_null()),
801            );
802            new_equivalences.push(class);
803        }
804
805        // If we see records formed from other expressions, we can equate the expressions with
806        // accessors applied to the class of the record former. In `minimize_once` we reduce by
807        // equivalence class representative before we perform expression simplification, so we
808        // shoud be able to just use the expression former, rather than find its representative.
809        // The risk, potentially, is that we would apply accessors to the record former and then
810        // just simplify it away learning nothing.
811        for class in self.classes.iter() {
812            for expr in class.iter() {
813                // Record-forming expressions can equate their accessors and their members.
814                if let MirScalarExpr::CallVariadic {
815                    func: mz_expr::VariadicFunc::RecordCreate { .. },
816                    exprs,
817                } = expr
818                {
819                    for (index, e) in exprs.iter().enumerate() {
820                        new_equivalences.push(vec![
821                            e.clone(),
822                            expr.clone().call_unary(mz_expr::UnaryFunc::RecordGet(
823                                mz_expr::func::RecordGet(index),
824                            )),
825                        ]);
826                    }
827                }
828            }
829        }
830
831        // Return all newly established equivalences.
832        new_equivalences
833    }
834
835    /// A single iteration of minimization, which we expect to repeat but benefit from factoring out.
836    ///
837    /// This invocation should take roughly linear time.
838    /// It starts with equivalence class invariants maintained (closed under transitivity), and then
839    ///   1. Performs per-expression reduction, including the class structure to replace subexpressions.
840    ///   2. Applies idiom detection to e.g. unpack expressions equivalence to literal true or false.
841    ///   3. Restores the equivalence class invariants.
842    fn minimize_once(&mut self, columns: Option<&[ColumnType]>) -> bool {
843        // 1. Reduce each expression
844        //
845        // This reduction first looks for subexpression substitutions that can be performed,
846        // and then applies expression reduction if column type information is provided.
847        for class in self.classes.iter_mut() {
848            for expr in class.iter_mut() {
849                self.remap.reduce_child(expr);
850                if let Some(columns) = columns {
851                    let orig_expr = expr.clone();
852                    expr.reduce(columns);
853                    if expr.contains_err() {
854                        // Rollback to the original expression if it contains an error.
855                        *expr = orig_expr;
856                    }
857                }
858            }
859        }
860
861        // 2. Identify idioms
862        //    E.g. If Eq(x, y) must be true, we can introduce classes `[x, y]` and `[false, IsNull(x), IsNull(y)]`.
863        let mut to_add = Vec::new();
864        for class in self.classes.iter_mut() {
865            if Self::class_contains_literal(class, |e| e == &Ok(Datum::True)) {
866                for expr in class.iter() {
867                    // If Eq(x, y) must be true, we can introduce classes `[x, y]` and `[false, IsNull(x), IsNull(y)]`.
868                    // This substitution replaces a complex expression with several smaller expressions, and cannot
869                    // cycle if we follow that practice.
870                    if let MirScalarExpr::CallBinary {
871                        func: mz_expr::BinaryFunc::Eq,
872                        expr1,
873                        expr2,
874                    } = expr
875                    {
876                        to_add.push(vec![*expr1.clone(), *expr2.clone()]);
877                        to_add.push(vec![
878                            MirScalarExpr::literal_false(),
879                            expr1.clone().call_is_null(),
880                            expr2.clone().call_is_null(),
881                        ]);
882                    }
883                }
884                // Remove the more complex form of the expression.
885                class.retain(|expr| {
886                    if let MirScalarExpr::CallBinary {
887                        func: mz_expr::BinaryFunc::Eq,
888                        ..
889                    } = expr
890                    {
891                        false
892                    } else {
893                        true
894                    }
895                });
896                for expr in class.iter() {
897                    // If TRUE == NOT(X) then FALSE == X is a simpler form.
898                    if let MirScalarExpr::CallUnary {
899                        func: mz_expr::UnaryFunc::Not(_),
900                        expr: e,
901                    } = expr
902                    {
903                        to_add.push(vec![MirScalarExpr::literal_false(), (**e).clone()]);
904                    }
905                }
906                class.retain(|expr| {
907                    if let MirScalarExpr::CallUnary {
908                        func: mz_expr::UnaryFunc::Not(_),
909                        ..
910                    } = expr
911                    {
912                        false
913                    } else {
914                        true
915                    }
916                });
917            }
918            if Self::class_contains_literal(class, |e| e == &Ok(Datum::False)) {
919                for expr in class.iter() {
920                    // If FALSE == NOT(X) then TRUE == X is a simpler form.
921                    if let MirScalarExpr::CallUnary {
922                        func: mz_expr::UnaryFunc::Not(_),
923                        expr: e,
924                    } = expr
925                    {
926                        to_add.push(vec![MirScalarExpr::literal_true(), (**e).clone()]);
927                    }
928                }
929                class.retain(|expr| {
930                    if let MirScalarExpr::CallUnary {
931                        func: mz_expr::UnaryFunc::Not(_),
932                        ..
933                    } = expr
934                    {
935                        false
936                    } else {
937                        true
938                    }
939                });
940            }
941        }
942        self.classes.extend(to_add);
943
944        // 3. Restore equivalence relation structure and observe if any changes result.
945        self.refresh()
946    }
947
948    /// Produce the equivalences present in both inputs.
949    pub fn union(&self, other: &Self) -> Self {
950        self.union_many([other])
951    }
952
953    /// The equivalence classes of terms equivalent in all inputs.
954    ///
955    /// This method relies on the `remap` member of each input, and bases the intersection on these rather than `classes`.
956    /// This means one should ensure `minimize()` has been called on all inputs, or risk getting a stale, but conservatively
957    /// correct, result.
958    ///
959    /// This method currently misses opportunities, because it only looks for exactly matches in expressions,
960    /// which may not include all possible matches. For example, `f(#1) == g(#1)` may exist in one class, but
961    /// in another class where `#0 == #1` it may exist as `f(#0) == g(#0)`.
962    pub fn union_many<'a, I>(&self, others: I) -> Self
963    where
964        I: IntoIterator<Item = &'a Self>,
965    {
966        // List of expressions in the intersection, and a proxy equivalence class identifier.
967        let mut intersection: Vec<(&MirScalarExpr, usize)> = Default::default();
968        // Map from expression to a proxy equivalence class identifier.
969        let mut rekey: BTreeMap<&MirScalarExpr, usize> = Default::default();
970        for (key, val) in self.remap.iter() {
971            if !rekey.contains_key(val) {
972                rekey.insert(val, rekey.len());
973            }
974            intersection.push((key, rekey[val]));
975        }
976        for other in others {
977            // Map from proxy equivalence class identifier and equivalence class expr to a new proxy identifier.
978            let mut rekey: BTreeMap<(usize, &MirScalarExpr), usize> = Default::default();
979            intersection.retain_mut(|(key, idx)| {
980                if let Some(val) = other.remap.get(key) {
981                    if !rekey.contains_key(&(*idx, val)) {
982                        rekey.insert((*idx, val), rekey.len());
983                    }
984                    *idx = rekey[&(*idx, val)];
985                    true
986                } else {
987                    false
988                }
989            });
990        }
991        let mut classes: BTreeMap<_, Vec<MirScalarExpr>> = Default::default();
992        for (key, vals) in intersection {
993            classes.entry(vals).or_default().push(key.clone())
994        }
995        let classes = classes.into_values().collect::<Vec<_>>();
996        let mut equivalences = EquivalenceClasses {
997            classes,
998            remap: Default::default(),
999        };
1000        equivalences.minimize(None);
1001        equivalences
1002    }
1003
1004    /// Permutes each expression, looking up each column reference in `permutation` and replacing with what it finds.
1005    pub fn permute(&mut self, permutation: &[usize]) {
1006        for class in self.classes.iter_mut() {
1007            for expr in class.iter_mut() {
1008                expr.permute(permutation);
1009            }
1010        }
1011        self.remap.clear();
1012        self.minimize(None);
1013    }
1014
1015    /// Subject the constraints to the column projection, reworking and removing equivalences.
1016    ///
1017    /// This method should also introduce equivalences representing any repeated columns.
1018    pub fn project<I>(&mut self, output_columns: I)
1019    where
1020        I: IntoIterator<Item = usize> + Clone,
1021    {
1022        // Retain the first instance of each column, and record subsequent instances as duplicates.
1023        let mut dupes = Vec::new();
1024        let mut remap = BTreeMap::default();
1025        for (idx, col) in output_columns.into_iter().enumerate() {
1026            if let Some(pos) = remap.get(&col) {
1027                dupes.push((*pos, idx));
1028            } else {
1029                remap.insert(col, idx);
1030            }
1031        }
1032
1033        // Some expressions may be "localized" in that they only reference columns in `output_columns`.
1034        // Many expressions may not be localized, but may reference canonical non-localized expressions
1035        // for classes that contain a localized expression; in that case we can "backport" the localized
1036        // expression to give expressions referencing the canonical expression a shot at localization.
1037        //
1038        // Expressions should only contain instances of canonical expressions, and so we shouldn't need
1039        // to look any further than backporting those. Backporting should have the property that the simplest
1040        // localized expression in each class does not contain any non-localized canonical expressions
1041        // (as that would make it non-localized); our backporting of non-localized canonicals with localized
1042        // expressions should never fire a second
1043
1044        // Let's say an expression is "localized" once we are able to rewrite its support in terms of `output_columns`.
1045        // Not all expressions can be localized, although some of them may be equivalent to localized expressions.
1046        // As we find localized expressions, we can replace uses of their equivalent representative with them,
1047        // which may allow further expression localization.
1048        // We continue the process until no further classes can be localized.
1049
1050        // A map from representatives to our first localization of their equivalence class.
1051        let mut localized = false;
1052        while !localized {
1053            localized = true;
1054            let mut current_map = BTreeMap::default();
1055            for class in self.classes.iter_mut() {
1056                if !class[0].support().iter().all(|c| remap.contains_key(c)) {
1057                    if let Some(pos) = class
1058                        .iter()
1059                        .position(|e| e.support().iter().all(|c| remap.contains_key(c)))
1060                    {
1061                        class.swap(0, pos);
1062                        localized = false;
1063                    }
1064                }
1065                for expr in class[1..].iter() {
1066                    current_map.insert(expr.clone(), class[0].clone());
1067                }
1068            }
1069
1070            // attempt to replace representatives with equivalent localizeable expressions.
1071            for class_index in 0..self.classes.len() {
1072                for index in 0..self.classes[class_index].len() {
1073                    current_map.reduce_child(&mut self.classes[class_index][index]);
1074                }
1075            }
1076            // NB: Do *not* `self.minimize()`, as we are developing localizable rather than canonical representatives.
1077        }
1078
1079        // Localize all localizable expressions and discard others.
1080        for class in self.classes.iter_mut() {
1081            class.retain(|e| e.support().iter().all(|c| remap.contains_key(c)));
1082            for expr in class.iter_mut() {
1083                expr.permute_map(&remap);
1084            }
1085        }
1086        self.classes.retain(|c| c.len() > 1);
1087        // If column repetitions, introduce them as equivalences.
1088        // We introduce only the equivalence to the first occurrence, and rely on minimization to collect them.
1089        for (col1, col2) in dupes {
1090            self.classes.push(vec![
1091                MirScalarExpr::column(col1),
1092                MirScalarExpr::column(col2),
1093            ]);
1094        }
1095        self.remap.clear();
1096        self.minimize(None);
1097    }
1098
1099    /// True if any equivalence class contains two distinct non-error literals.
1100    pub fn unsatisfiable(&self) -> bool {
1101        for class in self.classes.iter() {
1102            let mut literal_ok = None;
1103            for expr in class.iter() {
1104                if let MirScalarExpr::Literal(Ok(row), _) = expr {
1105                    if literal_ok.is_some() && literal_ok != Some(row) {
1106                        return true;
1107                    } else {
1108                        literal_ok = Some(row);
1109                    }
1110                }
1111            }
1112        }
1113        false
1114    }
1115
1116    /// Returns a map that can be used to replace (sub-)expressions.
1117    pub fn reducer(&self) -> &BTreeMap<MirScalarExpr, MirScalarExpr> {
1118        &self.remap
1119    }
1120
1121    /// Examines the prefix of `class` of literals, looking for any satisfying `predicate`.
1122    ///
1123    /// This test bails out as soon as it sees a non-literal, and may have false negatives
1124    /// if the data are not sorted with literals at the front.
1125    fn class_contains_literal<P>(class: &[MirScalarExpr], mut predicate: P) -> bool
1126    where
1127        P: FnMut(&Result<Datum, &mz_expr::EvalError>) -> bool,
1128    {
1129        class
1130            .iter()
1131            .take_while(|e| e.is_literal())
1132            .filter_map(|e| e.as_literal())
1133            .any(move |e| predicate(&e))
1134    }
1135}
1136
1137/// A type capable of simplifying `MirScalarExpr`s.
1138pub trait ExpressionReducer {
1139    /// Attempt to replace `expr` itself with another expression.
1140    /// Returns true if it does so.
1141    fn replace(&self, expr: &mut MirScalarExpr) -> bool;
1142    /// Attempt to replace any subexpressions of `expr` with other expressions.
1143    /// Returns true if it does so.
1144    fn reduce_expr(&self, expr: &mut MirScalarExpr) -> bool {
1145        let mut simplified = false;
1146        simplified = simplified || self.reduce_child(expr);
1147        simplified = simplified || self.replace(expr);
1148        simplified
1149    }
1150    /// Attempt to replace any subexpressions of `expr`'s children with other expressions.
1151    /// Returns true if it does so.
1152    fn reduce_child(&self, expr: &mut MirScalarExpr) -> bool {
1153        let mut simplified = false;
1154        match expr {
1155            MirScalarExpr::CallBinary { expr1, expr2, .. } => {
1156                simplified = self.reduce_expr(expr1) || simplified;
1157                simplified = self.reduce_expr(expr2) || simplified;
1158            }
1159            MirScalarExpr::CallUnary { expr, .. } => {
1160                simplified = self.reduce_expr(expr) || simplified;
1161            }
1162            MirScalarExpr::CallVariadic { exprs, .. } => {
1163                for expr in exprs.iter_mut() {
1164                    simplified = self.reduce_expr(expr) || simplified;
1165                }
1166            }
1167            MirScalarExpr::If { cond: _, then, els } => {
1168                // Do not simplify `cond`, as we cannot ensure the simplification
1169                // continues to hold as expressions migrate around.
1170                simplified = self.reduce_expr(then) || simplified;
1171                simplified = self.reduce_expr(els) || simplified;
1172            }
1173            _ => {}
1174        }
1175        simplified
1176    }
1177}
1178
1179impl ExpressionReducer for BTreeMap<&MirScalarExpr, &MirScalarExpr> {
1180    /// Perform any exact replacement for `expr`, report if it had an effect.
1181    fn replace(&self, expr: &mut MirScalarExpr) -> bool {
1182        if let Some(other) = self.get(expr) {
1183            if other != &expr {
1184                expr.clone_from(other);
1185                return true;
1186            }
1187        }
1188        false
1189    }
1190}
1191
1192impl ExpressionReducer for BTreeMap<MirScalarExpr, MirScalarExpr> {
1193    /// Perform any exact replacement for `expr`, report if it had an effect.
1194    fn replace(&self, expr: &mut MirScalarExpr) -> bool {
1195        if let Some(other) = self.get(expr) {
1196            if other != expr {
1197                expr.clone_from(other);
1198                return true;
1199            }
1200        }
1201        false
1202    }
1203}
1204
1205/// True iff the aggregate function returns an input datum.
1206fn aggregate_is_input(aggregate: &AggregateFunc) -> bool {
1207    match aggregate {
1208        AggregateFunc::MaxInt16
1209        | AggregateFunc::MaxInt32
1210        | AggregateFunc::MaxInt64
1211        | AggregateFunc::MaxUInt16
1212        | AggregateFunc::MaxUInt32
1213        | AggregateFunc::MaxUInt64
1214        | AggregateFunc::MaxMzTimestamp
1215        | AggregateFunc::MaxFloat32
1216        | AggregateFunc::MaxFloat64
1217        | AggregateFunc::MaxBool
1218        | AggregateFunc::MaxString
1219        | AggregateFunc::MaxDate
1220        | AggregateFunc::MaxTimestamp
1221        | AggregateFunc::MaxTimestampTz
1222        | AggregateFunc::MinInt16
1223        | AggregateFunc::MinInt32
1224        | AggregateFunc::MinInt64
1225        | AggregateFunc::MinUInt16
1226        | AggregateFunc::MinUInt32
1227        | AggregateFunc::MinUInt64
1228        | AggregateFunc::MinMzTimestamp
1229        | AggregateFunc::MinFloat32
1230        | AggregateFunc::MinFloat64
1231        | AggregateFunc::MinBool
1232        | AggregateFunc::MinString
1233        | AggregateFunc::MinDate
1234        | AggregateFunc::MinTimestamp
1235        | AggregateFunc::MinTimestampTz
1236        | AggregateFunc::Any
1237        | AggregateFunc::All => true,
1238        _ => false,
1239    }
1240}