mz_transform/analysis/
equivalences.rs

1// Copyright Materialize, Inc. and contributors. All rights reserved.
2//
3// Use of this software is governed by the Business Source License
4// included in the LICENSE file.
5//
6// As of the Change Date specified in that file, in accordance with
7// the Business Source License, use of this software will be governed
8// by the Apache License, Version 2.0.
9
//! An analysis that reports all known-equivalent expressions for each relation.
//!
//! Expressions are equivalent at a relation if they are certain to evaluate to
//! the same `Datum` for all records in the relation.
//!
//! Equivalences are recorded in an `EquivalenceClasses`, which lists all known
//! equivalence classes, each a list of equivalent expressions.
17
18use std::collections::BTreeMap;
19use std::fmt::Formatter;
20
21use mz_expr::canonicalize::{UnionFind, canonicalize_equivalence_classes};
22use mz_expr::explain::{HumanizedExplain, HumanizerMode};
23use mz_expr::{AggregateFunc, Id, MirRelationExpr, MirScalarExpr};
24use mz_ore::str::{bracketed, separated};
25use mz_repr::{Datum, SqlColumnType};
26
27use crate::analysis::{Analysis, Lattice};
28use crate::analysis::{Arity, SqlRelationType};
29use crate::analysis::{Derived, DerivedBuilder};
30
/// An analysis that pulls up and pushes down predicate information, represented
/// as classes of equivalent expressions.
#[derive(Default, Debug)]
pub struct Equivalences;
34
35impl Analysis for Equivalences {
36    // A `Some(list)` indicates a list of classes of equivalent expressions.
37    // A `None` indicates all expressions are equivalent, including contradictions;
38    // this is only possible for the empty collection, and as an initial result for
39    // unconstrained recursive terms.
40    type Value = Option<EquivalenceClasses>;
41
42    fn announce_dependencies(builder: &mut DerivedBuilder) {
43        builder.require(Arity);
44        builder.require(SqlRelationType); // needed for expression reduction.
45    }
46
47    fn derive(
48        &self,
49        expr: &MirRelationExpr,
50        index: usize,
51        results: &[Self::Value],
52        depends: &Derived,
53    ) -> Self::Value {
54        let mut equivalences = match expr {
55            MirRelationExpr::Constant { rows, typ } => {
56                // Trawl `rows` for any constant information worth recording.
57                // Literal columns may be valuable; non-nullability could be too.
58                let mut equivalences = EquivalenceClasses::default();
59                if let Ok([(row, _cnt), rows @ ..]) = rows.as_deref() {
60                    // Vector of `Option<Datum>` which becomes `None` once a column has a second datum.
61                    let len = row.iter().count();
62                    let mut common = Vec::with_capacity(len);
63                    common.extend(row.iter().map(Some));
64                    // Prep initial nullability information.
65                    let mut nullable_cols = common
66                        .iter()
67                        .map(|datum| datum == &Some(Datum::Null))
68                        .collect::<Vec<_>>();
69
70                    for (row, _cnt) in rows.iter() {
71                        for ((datum, common), nullable) in row
72                            .iter()
73                            .zip(common.iter_mut())
74                            .zip(nullable_cols.iter_mut())
75                        {
76                            if Some(datum) != *common {
77                                *common = None;
78                            }
79                            if datum == Datum::Null {
80                                *nullable = true;
81                            }
82                        }
83                    }
84                    for (index, common) in common.into_iter().enumerate() {
85                        if let Some(datum) = common {
86                            equivalences.classes.push(vec![
87                                MirScalarExpr::column(index),
88                                MirScalarExpr::literal_ok(
89                                    datum,
90                                    typ.column_types[index].scalar_type.clone(),
91                                ),
92                            ]);
93                        }
94                    }
95                    // If any columns are non-null, introduce this fact.
96                    if nullable_cols.iter().any(|x| !*x) {
97                        let mut class = vec![MirScalarExpr::literal_false()];
98                        for (index, nullable) in nullable_cols.iter().enumerate() {
99                            if !*nullable {
100                                class.push(MirScalarExpr::column(index).call_is_null());
101                            }
102                        }
103                        equivalences.classes.push(class);
104                    }
105                }
106                Some(equivalences)
107            }
108            MirRelationExpr::Get { id, typ, .. } => {
109                let mut equivalences = Some(EquivalenceClasses::default());
110                // Find local identifiers, but nothing for external identifiers.
111                if let Id::Local(id) = id {
112                    if let Some(offset) = depends.bindings().get(id) {
113                        // It is possible we have derived nothing for a recursive term
114                        if let Some(result) = results.get(*offset) {
115                            equivalences.clone_from(result);
116                        } else {
117                            // No top element was prepared.
118                            // This means we are executing pessimistically,
119                            // but perhaps we must because optimism is off.
120                        }
121                    }
122                }
123                // Incorporate statements about column nullability.
124                let mut non_null_cols = vec![MirScalarExpr::literal_false()];
125                for (index, col_type) in typ.column_types.iter().enumerate() {
126                    if !col_type.nullable {
127                        non_null_cols.push(MirScalarExpr::column(index).call_is_null());
128                    }
129                }
130                if non_null_cols.len() > 1 {
131                    if let Some(equivalences) = equivalences.as_mut() {
132                        equivalences.classes.push(non_null_cols);
133                    }
134                }
135
136                equivalences
137            }
138            MirRelationExpr::Let { .. } => results.get(index - 1).unwrap().clone(),
139            MirRelationExpr::LetRec { .. } => results.get(index - 1).unwrap().clone(),
140            MirRelationExpr::Project { outputs, .. } => {
141                // restrict equivalences, and introduce equivalences for repeated outputs.
142                let mut equivalences = results.get(index - 1).unwrap().clone();
143                equivalences
144                    .as_mut()
145                    .map(|e| e.project(outputs.iter().cloned()));
146                equivalences
147            }
148            MirRelationExpr::Map { scalars, .. } => {
149                // introduce equivalences for new columns and expressions that define them.
150                let mut equivalences = results.get(index - 1).unwrap().clone();
151                if let Some(equivalences) = &mut equivalences {
152                    let input_arity = depends.results::<Arity>()[index - 1];
153                    for (pos, expr) in scalars.iter().enumerate() {
154                        equivalences
155                            .classes
156                            .push(vec![MirScalarExpr::column(input_arity + pos), expr.clone()]);
157                    }
158                }
159                equivalences
160            }
161            MirRelationExpr::FlatMap { .. } => results.get(index - 1).unwrap().clone(),
162            MirRelationExpr::Filter { predicates, .. } => {
163                let mut equivalences = results.get(index - 1).unwrap().clone();
164                if let Some(equivalences) = &mut equivalences {
165                    let mut class = predicates.clone();
166                    class.push(MirScalarExpr::literal_ok(
167                        Datum::True,
168                        mz_repr::SqlScalarType::Bool,
169                    ));
170                    equivalences.classes.push(class);
171                }
172                equivalences
173            }
174            MirRelationExpr::Join { equivalences, .. } => {
175                // Collect equivalences from all inputs;
176                let expr_index = index;
177                let mut children = depends
178                    .children_of_rev(expr_index, expr.children().count())
179                    .collect::<Vec<_>>();
180                children.reverse();
181
182                let arity = depends.results::<Arity>();
183                let mut columns = 0;
184                let mut result = Some(EquivalenceClasses::default());
185                for child in children.into_iter() {
186                    let input_arity = arity[child];
187                    let equivalences = results[child].clone();
188                    if let Some(mut equivalences) = equivalences {
189                        let permutation = (columns..(columns + input_arity)).collect::<Vec<_>>();
190                        equivalences.permute(&permutation);
191                        result
192                            .as_mut()
193                            .map(|e| e.classes.extend(equivalences.classes));
194                    } else {
195                        result = None;
196                    }
197                    columns += input_arity;
198                }
199
200                // Fold join equivalences into our results.
201                result
202                    .as_mut()
203                    .map(|e| e.classes.extend(equivalences.iter().cloned()));
204                result
205            }
206            MirRelationExpr::Reduce {
207                group_key,
208                aggregates,
209                ..
210            } => {
211                let input_arity = depends.results::<Arity>()[index - 1];
212                let mut equivalences = results.get(index - 1).unwrap().clone();
213                if let Some(equivalences) = &mut equivalences {
214                    // Introduce keys column equivalences as if a map, then project to those columns.
215                    // This should retain as much information as possible about these columns.
216                    for (pos, expr) in group_key.iter().enumerate() {
217                        equivalences
218                            .classes
219                            .push(vec![MirScalarExpr::column(input_arity + pos), expr.clone()]);
220                    }
221
222                    // Having added classes to `equivalences`, we should minimize the classes to fold the
223                    // information in before applying the `project`, to set it up for success.
224                    equivalences.minimize(None);
225
226                    // Grab a copy of the equivalences with key columns added to use in aggregate reasoning.
227                    let extended = equivalences.clone();
228                    // Now project down the equivalences, as we will extend them in terms of the output columns.
229                    equivalences.project(input_arity..(input_arity + group_key.len()));
230
231                    // TODO: MIN, MAX, ANY, ALL aggregates pass through all certain properties of their columns.
232                    // They also pass through equivalences of them and other constant columns (e.g. key columns).
233                    // However, it is not correct to simply project onto these columns, as relationships amongst
234                    // aggregate columns may no longer be preserved. MAX(col) != MIN(col) even though col = col.
235                    // The correct thing to do is treat the reduce as a join between single-aggregate reductions,
236                    // where each single MIN/MAX/ANY/ALL aggregate propagates equivalences.
237                    for (index, aggregate) in aggregates.iter().enumerate() {
238                        if aggregate_is_input(&aggregate.func) {
239                            let mut temp_equivs = extended.clone();
240                            temp_equivs.classes.push(vec![
241                                MirScalarExpr::column(input_arity + group_key.len()),
242                                aggregate.expr.clone(),
243                            ]);
244                            temp_equivs.minimize(None);
245                            temp_equivs.project(input_arity..(input_arity + group_key.len() + 1));
246                            let columns = (0..group_key.len())
247                                .chain(std::iter::once(group_key.len() + index))
248                                .collect::<Vec<_>>();
249                            temp_equivs.permute(&columns[..]);
250                            equivalences.classes.extend(temp_equivs.classes);
251                        }
252                    }
253                }
254                equivalences
255            }
256            MirRelationExpr::TopK { .. } => results.get(index - 1).unwrap().clone(),
257            MirRelationExpr::Negate { .. } => results.get(index - 1).unwrap().clone(),
258            MirRelationExpr::Threshold { .. } => results.get(index - 1).unwrap().clone(),
259            MirRelationExpr::Union { .. } => {
260                let expr_index = index;
261                let mut child_equivs = depends
262                    .children_of_rev(expr_index, expr.children().count())
263                    .flat_map(|c| &results[c]);
264                if let Some(first) = child_equivs.next() {
265                    Some(first.union_many(child_equivs))
266                } else {
267                    None
268                }
269            }
270            MirRelationExpr::ArrangeBy { .. } => results.get(index - 1).unwrap().clone(),
271        };
272
273        let expr_type = depends.results::<SqlRelationType>()[index].as_ref();
274        equivalences
275            .as_mut()
276            .map(|e| e.minimize(expr_type.map(|x| &x[..])));
277        equivalences
278    }
279
280    fn lattice() -> Option<Box<dyn Lattice<Self::Value>>> {
281        Some(Box::new(EQLattice))
282    }
283}
284
285struct EQLattice;
286
287impl Lattice<Option<EquivalenceClasses>> for EQLattice {
288    fn top(&self) -> Option<EquivalenceClasses> {
289        None
290    }
291
292    fn meet_assign(
293        &self,
294        a: &mut Option<EquivalenceClasses>,
295        b: Option<EquivalenceClasses>,
296    ) -> bool {
297        match (&mut *a, b) {
298            (_, None) => false,
299            (None, b) => {
300                *a = b;
301                true
302            }
303            (Some(a), Some(b)) => {
304                let mut c = a.union(&b);
305                std::mem::swap(a, &mut c);
306                a != &mut c
307            }
308        }
309    }
310}
311
312/// A compact representation of classes of expressions that must be equivalent.
313///
314/// Each "class" contains a list of expressions, each of which must be `Eq::eq` equal.
315/// Ideally, the first element is the "simplest", e.g. a literal or column reference,
316/// and any other element of that list can be replaced by it.
317///
318/// The classes are meant to be minimized, with each expression as reduced as it can be,
319/// and all classes sharing an element merged.
320#[derive(Clone, Eq, PartialEq, Ord, PartialOrd, Default, Debug)]
321pub struct EquivalenceClasses {
322    /// Multiple lists of equivalent expressions, each representing an equivalence class.
323    ///
324    /// The first element should be the "canonical" simplest element, that any other element
325    /// can be replaced by.
326    /// These classes are unified whenever possible, to minimize the number of classes.
327    /// They are only guaranteed to form an equivalence relation after a call to `minimize`,
328    /// which refreshes both `self.classes` and `self.remap`.
329    pub classes: Vec<Vec<MirScalarExpr>>,
330
331    /// An expression simplification map.
332    ///
333    /// This map reflects an equivalence relation based on a prior version of `self.classes`.
334    /// As users may add to `self.classes`, `self.remap` may become stale. We refresh `remap`
335    /// only in `self.refresh()`, to the equivalence relation that derives from `self.classes`.
336    ///
337    /// It is important to `self.remap.clear()` if you invalidate it by mutating rather than
338    /// appending to `self.classes`. This will be corrected in the next call to `self.refresh()`,
339    /// but until then `remap` could be arbitrarily wrong. This should be improved in the future.
340    remap: BTreeMap<MirScalarExpr, MirScalarExpr>,
341}
342
343/// An enum that can represent either a standard `EquivalenceClasses` or a
344/// `EquivalenceClassesWithholdingErrors`.
345#[derive(Clone, Debug)]
346pub enum EqClassesImpl {
347    /// Standard equivalence classes.
348    EquivalenceClasses(EquivalenceClasses),
349    /// Equivalence classes that withhold errors, reintroducing them when extracting equivalences.
350    EquivalenceClassesWithholdingErrors(EquivalenceClassesWithholdingErrors),
351}
352
353impl EqClassesImpl {
354    /// Returns a reference to the underlying `EquivalenceClasses`.
355    pub fn equivalence_classes(&self) -> &EquivalenceClasses {
356        match self {
357            EqClassesImpl::EquivalenceClasses(classes) => classes,
358            EqClassesImpl::EquivalenceClassesWithholdingErrors(classes) => {
359                &classes.equivalence_classes
360            }
361        }
362    }
363
364    /// Returns a mutable reference to the underlying `EquivalenceClasses`.
365    pub fn equivalence_classes_mut(&mut self) -> &mut EquivalenceClasses {
366        match self {
367            EqClassesImpl::EquivalenceClasses(classes) => classes,
368            EqClassesImpl::EquivalenceClassesWithholdingErrors(classes) => {
369                &mut classes.equivalence_classes
370            }
371        }
372    }
373
374    /// Extend the equivalence classes with new equivalences.
375    pub fn extend_equivalences(&mut self, equivalences: Vec<Vec<MirScalarExpr>>) {
376        match self {
377            EqClassesImpl::EquivalenceClasses(classes) => {
378                classes.classes.extend(equivalences);
379            }
380            EqClassesImpl::EquivalenceClassesWithholdingErrors(classes) => {
381                classes.extend_equivalences(equivalences);
382            }
383        }
384    }
385
386    /// Push a new equivalence class.
387    pub fn push_equivalences(&mut self, equivalences: Vec<MirScalarExpr>) {
388        match self {
389            EqClassesImpl::EquivalenceClasses(classes) => {
390                classes.classes.push(equivalences);
391            }
392            EqClassesImpl::EquivalenceClassesWithholdingErrors(classes) => {
393                classes.push_equivalences(equivalences);
394            }
395        }
396    }
397
398    /// Subject the constraints to the column projection, reworking and removing equivalences.
399    pub fn project<I>(&mut self, output_columns: I)
400    where
401        I: IntoIterator<Item = usize> + Clone,
402    {
403        match self {
404            EqClassesImpl::EquivalenceClasses(classes) => {
405                classes.project(output_columns);
406            }
407            EqClassesImpl::EquivalenceClassesWithholdingErrors(classes) => {
408                classes.project(output_columns);
409            }
410        }
411    }
412
413    /// Extract the equivalences
414    pub fn extract_equivalences(
415        &mut self,
416        columns: Option<&[SqlColumnType]>,
417    ) -> Vec<Vec<MirScalarExpr>> {
418        match self {
419            EqClassesImpl::EquivalenceClasses(classes) => classes.classes.clone(),
420            EqClassesImpl::EquivalenceClassesWithholdingErrors(classes) => {
421                classes.extract_equivalences(columns)
422            }
423        }
424    }
425}
426
427/// A wrapper struct for equivalence classes that witholds errors
428/// from the underlying equivalence classes, reintroducing them
429/// when extracting equivalences.
430#[derive(Clone, Eq, PartialEq, Ord, PartialOrd, Default, Debug)]
431pub struct EquivalenceClassesWithholdingErrors {
432    /// The underlying equivalence classes.
433    pub equivalence_classes: EquivalenceClasses,
434    /// A list of equivalence classes that contain errors, which we hold back
435    pub held_back_classes: Vec<(Option<MirScalarExpr>, Vec<MirScalarExpr>)>,
436}
437
438impl EquivalenceClassesWithholdingErrors {
439    /// Extend the equivalence classes with new equivalences.
440    pub fn extend_equivalences(&mut self, equivalences: Vec<Vec<MirScalarExpr>>) {
441        for class in equivalences {
442            self.push_equivalences(class);
443        }
444    }
445
446    /// Push a new equivalence class, which may contain errors.
447    pub fn push_equivalences(&mut self, equivalences: Vec<MirScalarExpr>) {
448        let (with_errs, without_errs): (Vec<_>, Vec<_>) =
449            equivalences.into_iter().partition(|e| e.contains_err());
450        self.held_back_classes
451            .push((without_errs.first().cloned(), with_errs));
452        if without_errs.len() > 1 {
453            self.equivalence_classes.classes.push(without_errs);
454        }
455    }
456
457    /// Subject the constraints to the column projection, reworking and removing equivalences.
458    /// This method should also introduce equivalences representing any repeated columns.
459    /// This notably does not project the held back classes.
460    pub fn project<I>(&mut self, output_columns: I)
461    where
462        I: IntoIterator<Item = usize> + Clone,
463    {
464        self.equivalence_classes.project(output_columns.clone());
465    }
466
467    /// Minimize the equivalence classes, and return the result, reintroducing any held back classes.
468    pub fn extract_equivalences(
469        &mut self,
470        columns: Option<&[SqlColumnType]>,
471    ) -> Vec<Vec<MirScalarExpr>> {
472        self.equivalence_classes.minimize(columns);
473
474        if self.held_back_classes.is_empty() {
475            self.equivalence_classes.classes.clone()
476        } else {
477            let reducer = self.equivalence_classes.reducer();
478            let mut classes = self.equivalence_classes.classes.clone();
479            for (anchor, mut class) in self.held_back_classes.drain(..) {
480                if let Some(mut anchor) = anchor {
481                    // Reduce the anchor expression to its canonical form.
482                    reducer.reduce_expr(&mut anchor);
483                    class.push(anchor.clone());
484                }
485                classes.push(class.clone());
486            }
487            canonicalize_equivalence_classes(&mut classes);
488            classes
489        }
490    }
491}
492
493/// Raw printing of [`EquivalenceClasses`] with default expression humanization.
494/// Don't use this in `EXPLAIN`! For redaction, column name support, etc., see
495/// [`HumanizedEquivalenceClasses`].
496impl std::fmt::Display for EquivalenceClasses {
497    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
498        HumanizedEquivalenceClasses {
499            equivalence_classes: self,
500            cols: None,
501            mode: HumanizedExplain::default(),
502        }
503        .fmt(f)
504    }
505}
506
507/// Wrapper struct for human-readable printing of expressions inside [`EquivalenceClasses`].
508/// (Similar to `HumanizedExpr`. Unfortunately, we can't just use `HumanizedExpr` here, because
509/// we'd need to `impl Display for HumanizedExpr<'a, EquivalenceClasses, M>`, but neither
510/// `Display` nor `HumanizedExpr` is defined in this crate.)
511#[derive(Debug)]
512pub struct HumanizedEquivalenceClasses<'a, M = HumanizedExplain> {
513    /// The [`EquivalenceClasses`] to be humanized.
514    pub equivalence_classes: &'a EquivalenceClasses,
515    /// An optional vector of inferred column names to be used when rendering
516    /// column references in expressions.
517    pub cols: Option<&'a Vec<String>>,
518    /// The rendering mode to use. See `HumanizerMode` for details.
519    pub mode: M,
520}
521
522impl std::fmt::Display for HumanizedEquivalenceClasses<'_> {
523    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
524        // Only show `classes`.
525        // (The following hopefully avoids allocating any of the intermediate composite strings.)
526        let classes = self.equivalence_classes.classes.iter().map(|class| {
527            format!(
528                "{}",
529                bracketed(
530                    "[",
531                    "]",
532                    separated(
533                        ", ",
534                        class.iter().map(|expr| self.mode.expr(expr, self.cols))
535                    )
536                )
537            )
538        });
539        write!(f, "{}", bracketed("[", "]", separated(", ", classes)))
540    }
541}
542
543impl EquivalenceClasses {
544    /// Comparator function for the complexity of scalar expressions. Simpler expressions are
545    /// smaller. Can be used when we need to decide which of several equivalent expressions to use.
546    pub fn mir_scalar_expr_complexity(
547        e1: &MirScalarExpr,
548        e2: &MirScalarExpr,
549    ) -> std::cmp::Ordering {
550        use MirScalarExpr::*;
551        use std::cmp::Ordering::*;
552        match (e1, e2) {
553            (Literal(_, _), Literal(_, _)) => e1.cmp(e2),
554            (Literal(_, _), _) => Less,
555            (_, Literal(_, _)) => Greater,
556            (Column(_, _), Column(_, _)) => e1.cmp(e2),
557            (Column(_, _), _) => Less,
558            (_, Column(_, _)) => Greater,
559            (x, y) => {
560                // General expressions should be ordered by their size,
561                // to ensure we only simplify expressions by substitution.
562                // If same size, then fall back to the expressions' Ord.
563                match x.size().cmp(&y.size()) {
564                    Equal => x.cmp(y),
565                    other => other,
566                }
567            }
568        }
569    }
570
571    /// Sorts and deduplicates each class, removing literal errors.
572    ///
573    /// This method does not ensure equivalence relation structure, but instead performs
574    /// only minimal structural clean-up.
575    fn tidy(&mut self) {
576        for class in self.classes.iter_mut() {
577            // Remove all literal errors, as they cannot be equated to other things.
578            class.retain(|e| !e.is_literal_err());
579            class.sort_by(Self::mir_scalar_expr_complexity);
580            class.dedup();
581        }
582        self.classes.retain(|c| c.len() > 1);
583        self.classes.sort();
584        self.classes.dedup();
585    }
586
587    /// Restore equivalence relation structure to `self.classes` and refresh `self.remap`.
588    ///
589    /// This method takes roughly linear time, and returns true iff `self.remap` has changed.
590    /// This is the only method that refreshes `self.remap`, and is a perfect place to decide
591    /// whether the equivalence classes it represents have experienced any changes since the
592    /// last refresh.
593    fn refresh(&mut self) -> bool {
594        self.tidy();
595
596        // remap may already be the correct answer, and if so we should avoid the work of rebuilding it.
597        // If it contains the same number of expressions as `self.classes`, and for every expression in
598        // `self.classes` the two agree on the representative, they are identical.
599        if self.remap.len() == self.classes.iter().map(|c| c.len()).sum::<usize>()
600            && self
601                .classes
602                .iter()
603                .all(|c| c.iter().all(|e| self.remap.get(e) == Some(&c[0])))
604        {
605            // No change, so return false.
606            return false;
607        }
608
609        // Optimistically build the `remap` we would want.
610        // Note if any unions would be required, in which case we have further work to do,
611        // including re-forming `self.classes`.
612        let mut union_find = BTreeMap::default();
613        let mut dirtied = false;
614        for class in self.classes.iter() {
615            for expr in class.iter() {
616                if let Some(other) = union_find.insert(expr.clone(), class[0].clone()) {
617                    // A merge is required, but have the more complex expression point at the simpler one.
618                    // This allows `union_find` to end as the `remap` for the new `classes` we form, with
619                    // the only required work being compressing all the paths.
620                    if Self::mir_scalar_expr_complexity(&other, &class[0])
621                        == std::cmp::Ordering::Less
622                    {
623                        union_find.union(&class[0], &other);
624                    } else {
625                        union_find.union(&other, &class[0]);
626                    }
627                    dirtied = true;
628                }
629            }
630        }
631        if dirtied {
632            let mut classes: BTreeMap<_, Vec<_>> = BTreeMap::default();
633            for class in self.classes.drain(..) {
634                for expr in class {
635                    let root: MirScalarExpr = union_find.find(&expr).unwrap().clone();
636                    classes.entry(root).or_default().push(expr);
637                }
638            }
639            self.classes = classes.into_values().collect();
640            self.tidy();
641        }
642
643        let changed = self.remap != union_find;
644        self.remap = union_find;
645        changed
646    }
647
648    /// Update `self` to maintain the same equivalences which potentially reducing along `Ord::le`.
649    ///
650    /// Informally this means simplifying constraints, removing redundant constraints, and unifying equivalence classes.
651    pub fn minimize(&mut self, columns: Option<&[SqlColumnType]>) {
652        // Repeatedly, we reduce each of the classes themselves, then unify the classes.
653        // This should strictly reduce complexity, and reach a fixed point.
654        // Ideally it is *confluent*, arriving at the same fixed point no matter the order of operations.
655
656        // We should not rely on nullability information present in `column_types`. (Doing this
657        // every time just before calling `reduce` was found to be a bottleneck during incident-217,
658        // so now we do this nullability tweaking only once here.)
659        let mut columns = columns.map(|x| x.to_vec());
660        let mut nonnull = Vec::new();
661        if let Some(columns) = columns.as_mut() {
662            for (index, col) in columns.iter_mut().enumerate() {
663                let is_null = MirScalarExpr::column(index).call_is_null();
664                if !col.nullable
665                    && self
666                        .remap
667                        .get(&is_null)
668                        .map(|e| !e.is_literal_false())
669                        .unwrap_or(true)
670                {
671                    nonnull.push(is_null);
672                }
673                col.nullable = true;
674            }
675        }
676        if !nonnull.is_empty() {
677            nonnull.push(MirScalarExpr::literal_false());
678            self.classes.push(nonnull);
679        }
680
681        // Ensure `self.classes` and `self.remap` are equivalence relations.
682        // Users are allowed to mutate `self.classes`, so we must perform this normalization at least once.
683        // We have also likely mutated `self.classes` just above with non-nullability information.
684        self.refresh();
685
686        // Termination will be detected by comparing to the map of equivalence classes.
687        let mut previous = Some(self.remap.clone());
688        while let Some(prev) = previous {
689            // Attempt to add new equivalences.
690            let novel = self.expand();
691            if !novel.is_empty() {
692                self.classes.extend(novel);
693                self.refresh();
694            }
695
696            // We continue as long as any simplification has occurred.
697            // An expression can be simplified, a duplication found, or two classes unified.
698            let mut stable = false;
699            while !stable {
700                stable = !self.minimize_once(columns.as_ref().map(|x| &x[..]));
701            }
702
703            // Termination detection.
704            if prev != self.remap {
705                previous = Some(self.remap.clone());
706            } else {
707                previous = None;
708            }
709        }
710    }
711
712    /// Proposes new equivalences that are likely to be novel.
713    ///
714    /// This method invokes `self.implications()` to propose equivalences, and then judges them to be
715    /// novel or not based on existing knowledge, reducing the equivalences down to their novel core.
716    /// This method may produce non-novel equivalences, due to its inability to perform `MSE::reduce`.
717    /// We can end up with e.g. constant expressions that cannot be found until they are so reduced.
718    /// The novelty detection is best-effort, and meant to provide a clearer signal and minimize the
719    /// number of times we call and amount of work we do in `self.refresh()`.
720    fn expand(&self) -> Vec<Vec<MirScalarExpr>> {
721        // Consider expanding `self.classes` with novel equivalences.
722        let mut novel = self.implications();
723        for class in novel.iter_mut() {
724            // reduce each expression to its canonical form.
725            for expr in class.iter_mut() {
726                self.remap.reduce_expr(expr);
727            }
728            class.sort();
729            class.dedup();
730            // for a class to be interesting we require at least two elements that do not reference the same root.
731            let common_class = class
732                .iter()
733                .map(|x| self.remap.get(x))
734                .reduce(|prev, this| if prev == this { prev } else { None });
735            if class.len() == 1 || common_class != Some(None) {
736                class.clear();
737            }
738        }
739        novel.retain(|c| !c.is_empty());
740        novel
741    }
742
743    /// Derives potentially novel equivalences without regard for minimization.
744    ///
745    /// This is an opportunity to explore equivalences that do not correspond to expression minimization,
746    /// and therefore should not be used in `minimize_once`. They are still potentially important, but
747    /// required additional guardrails to ensure we reach a fixed point.
748    ///
749    /// The implications will be introduced into `self.classes` and will prompt a round of minimization,
750    /// making it somewhat polite to avoid producing outputs that cannot result in novel equivalences.
751    /// For example, before producing a new equivalence, one could check that the involved terms are not
752    /// already present in the same class.
753    fn implications(&self) -> Vec<Vec<MirScalarExpr>> {
754        let mut new_equivalences = Vec::new();
755
756        // If we see `false == IsNull(foo)` we can add the non-null implications of `foo`.
757        let mut non_null = std::collections::BTreeSet::default();
758        for class in self.classes.iter() {
759            if Self::class_contains_literal(class, |e| e == &Ok(Datum::False)) {
760                for e in class.iter() {
761                    if let MirScalarExpr::CallUnary {
762                        func: mz_expr::UnaryFunc::IsNull(_),
763                        expr,
764                    } = e
765                    {
766                        expr.non_null_requirements(&mut non_null);
767                    }
768                }
769            }
770        }
771        // If we see `true == foo` we can add the non-null implications of `foo`.
772        // TODO: generalize to arbitrary non-null, non-error literals; at the moment `true == pred` is
773        // an important idiom to identify for how we express predicates.
774        for class in self.classes.iter() {
775            if Self::class_contains_literal(class, |e| e == &Ok(Datum::True)) {
776                for expr in class.iter() {
777                    expr.non_null_requirements(&mut non_null);
778                }
779            }
780        }
781        // Only keep constraints that are not already known.
782        // Known constraints will present as `COL(_) IS NULL == false`,
783        // which can only happen if `false` is present, and both terms
784        // map to the same canonical representative>
785        let lit_false = MirScalarExpr::literal_false();
786        let target = self.remap.get(&lit_false);
787        if target.is_some() {
788            non_null.retain(|c| {
789                let is_null = MirScalarExpr::column(*c).call_is_null();
790                self.remap.get(&is_null) != target
791            });
792        }
793        if !non_null.is_empty() {
794            let mut class = Vec::with_capacity(non_null.len() + 1);
795            class.push(MirScalarExpr::literal_false());
796            class.extend(
797                non_null
798                    .into_iter()
799                    .map(|c| MirScalarExpr::column(c).call_is_null()),
800            );
801            new_equivalences.push(class);
802        }
803
804        // If we see records formed from other expressions, we can equate the expressions with
805        // accessors applied to the class of the record former. In `minimize_once` we reduce by
806        // equivalence class representative before we perform expression simplification, so we
807        // shoud be able to just use the expression former, rather than find its representative.
808        // The risk, potentially, is that we would apply accessors to the record former and then
809        // just simplify it away learning nothing.
810        for class in self.classes.iter() {
811            for expr in class.iter() {
812                // Record-forming expressions can equate their accessors and their members.
813                if let MirScalarExpr::CallVariadic {
814                    func: mz_expr::VariadicFunc::RecordCreate { .. },
815                    exprs,
816                } = expr
817                {
818                    for (index, e) in exprs.iter().enumerate() {
819                        new_equivalences.push(vec![
820                            e.clone(),
821                            expr.clone().call_unary(mz_expr::UnaryFunc::RecordGet(
822                                mz_expr::func::RecordGet(index),
823                            )),
824                        ]);
825                    }
826                }
827            }
828        }
829
830        // Return all newly established equivalences.
831        new_equivalences
832    }
833
834    /// A single iteration of minimization, which we expect to repeat but benefit from factoring out.
835    ///
836    /// This invocation should take roughly linear time.
837    /// It starts with equivalence class invariants maintained (closed under transitivity), and then
838    ///   1. Performs per-expression reduction, including the class structure to replace subexpressions.
839    ///   2. Applies idiom detection to e.g. unpack expressions equivalence to literal true or false.
840    ///   3. Restores the equivalence class invariants.
841    fn minimize_once(&mut self, columns: Option<&[SqlColumnType]>) -> bool {
842        // 1. Reduce each expression
843        //
844        // This reduction first looks for subexpression substitutions that can be performed,
845        // and then applies expression reduction if column type information is provided.
846        for class in self.classes.iter_mut() {
847            for expr in class.iter_mut() {
848                self.remap.reduce_child(expr);
849                if let Some(columns) = columns {
850                    let orig_expr = expr.clone();
851                    expr.reduce(columns);
852                    if expr.contains_err() {
853                        // Rollback to the original expression if it contains an error.
854                        *expr = orig_expr;
855                    }
856                }
857            }
858        }
859
860        // 2. Identify idioms
861        //    E.g. If Eq(x, y) must be true, we can introduce classes `[x, y]` and `[false, IsNull(x), IsNull(y)]`.
862        let mut to_add = Vec::new();
863        for class in self.classes.iter_mut() {
864            if Self::class_contains_literal(class, |e| e == &Ok(Datum::True)) {
865                for expr in class.iter() {
866                    // If Eq(x, y) must be true, we can introduce classes `[x, y]` and `[false, IsNull(x), IsNull(y)]`.
867                    // This substitution replaces a complex expression with several smaller expressions, and cannot
868                    // cycle if we follow that practice.
869                    if let MirScalarExpr::CallBinary {
870                        func: mz_expr::BinaryFunc::Eq,
871                        expr1,
872                        expr2,
873                    } = expr
874                    {
875                        to_add.push(vec![*expr1.clone(), *expr2.clone()]);
876                        to_add.push(vec![
877                            MirScalarExpr::literal_false(),
878                            expr1.clone().call_is_null(),
879                            expr2.clone().call_is_null(),
880                        ]);
881                    }
882                }
883                // Remove the more complex form of the expression.
884                class.retain(|expr| {
885                    if let MirScalarExpr::CallBinary {
886                        func: mz_expr::BinaryFunc::Eq,
887                        ..
888                    } = expr
889                    {
890                        false
891                    } else {
892                        true
893                    }
894                });
895                for expr in class.iter() {
896                    // If TRUE == NOT(X) then FALSE == X is a simpler form.
897                    if let MirScalarExpr::CallUnary {
898                        func: mz_expr::UnaryFunc::Not(_),
899                        expr: e,
900                    } = expr
901                    {
902                        to_add.push(vec![MirScalarExpr::literal_false(), (**e).clone()]);
903                    }
904                }
905                class.retain(|expr| {
906                    if let MirScalarExpr::CallUnary {
907                        func: mz_expr::UnaryFunc::Not(_),
908                        ..
909                    } = expr
910                    {
911                        false
912                    } else {
913                        true
914                    }
915                });
916            }
917            if Self::class_contains_literal(class, |e| e == &Ok(Datum::False)) {
918                for expr in class.iter() {
919                    // If FALSE == NOT(X) then TRUE == X is a simpler form.
920                    if let MirScalarExpr::CallUnary {
921                        func: mz_expr::UnaryFunc::Not(_),
922                        expr: e,
923                    } = expr
924                    {
925                        to_add.push(vec![MirScalarExpr::literal_true(), (**e).clone()]);
926                    }
927                }
928                class.retain(|expr| {
929                    if let MirScalarExpr::CallUnary {
930                        func: mz_expr::UnaryFunc::Not(_),
931                        ..
932                    } = expr
933                    {
934                        false
935                    } else {
936                        true
937                    }
938                });
939            }
940        }
941        self.classes.extend(to_add);
942
943        // 3. Restore equivalence relation structure and observe if any changes result.
944        self.refresh()
945    }
946
947    /// Produce the equivalences present in both inputs.
948    pub fn union(&self, other: &Self) -> Self {
949        self.union_many([other])
950    }
951
952    /// The equivalence classes of terms equivalent in all inputs.
953    ///
954    /// This method relies on the `remap` member of each input, and bases the intersection on these rather than `classes`.
955    /// This means one should ensure `minimize()` has been called on all inputs, or risk getting a stale, but conservatively
956    /// correct, result.
957    ///
958    /// This method currently misses opportunities, because it only looks for exactly matches in expressions,
959    /// which may not include all possible matches. For example, `f(#1) == g(#1)` may exist in one class, but
960    /// in another class where `#0 == #1` it may exist as `f(#0) == g(#0)`.
961    pub fn union_many<'a, I>(&self, others: I) -> Self
962    where
963        I: IntoIterator<Item = &'a Self>,
964    {
965        // List of expressions in the intersection, and a proxy equivalence class identifier.
966        let mut intersection: Vec<(&MirScalarExpr, usize)> = Default::default();
967        // Map from expression to a proxy equivalence class identifier.
968        let mut rekey: BTreeMap<&MirScalarExpr, usize> = Default::default();
969        for (key, val) in self.remap.iter() {
970            if !rekey.contains_key(val) {
971                rekey.insert(val, rekey.len());
972            }
973            intersection.push((key, rekey[val]));
974        }
975        for other in others {
976            // Map from proxy equivalence class identifier and equivalence class expr to a new proxy identifier.
977            let mut rekey: BTreeMap<(usize, &MirScalarExpr), usize> = Default::default();
978            intersection.retain_mut(|(key, idx)| {
979                if let Some(val) = other.remap.get(key) {
980                    if !rekey.contains_key(&(*idx, val)) {
981                        rekey.insert((*idx, val), rekey.len());
982                    }
983                    *idx = rekey[&(*idx, val)];
984                    true
985                } else {
986                    false
987                }
988            });
989        }
990        let mut classes: BTreeMap<_, Vec<MirScalarExpr>> = Default::default();
991        for (key, vals) in intersection {
992            classes.entry(vals).or_default().push(key.clone())
993        }
994        let classes = classes.into_values().collect::<Vec<_>>();
995        let mut equivalences = EquivalenceClasses {
996            classes,
997            remap: Default::default(),
998        };
999        equivalences.minimize(None);
1000        equivalences
1001    }
1002
1003    /// Permutes each expression, looking up each column reference in `permutation` and replacing with what it finds.
1004    pub fn permute(&mut self, permutation: &[usize]) {
1005        for class in self.classes.iter_mut() {
1006            for expr in class.iter_mut() {
1007                expr.permute(permutation);
1008            }
1009        }
1010        self.remap.clear();
1011        self.minimize(None);
1012    }
1013
1014    /// Subject the constraints to the column projection, reworking and removing equivalences.
1015    ///
1016    /// This method should also introduce equivalences representing any repeated columns.
1017    pub fn project<I>(&mut self, output_columns: I)
1018    where
1019        I: IntoIterator<Item = usize> + Clone,
1020    {
1021        // Retain the first instance of each column, and record subsequent instances as duplicates.
1022        let mut dupes = Vec::new();
1023        let mut remap = BTreeMap::default();
1024        for (idx, col) in output_columns.into_iter().enumerate() {
1025            if let Some(pos) = remap.get(&col) {
1026                dupes.push((*pos, idx));
1027            } else {
1028                remap.insert(col, idx);
1029            }
1030        }
1031
1032        // Some expressions may be "localized" in that they only reference columns in `output_columns`.
1033        // Many expressions may not be localized, but may reference canonical non-localized expressions
1034        // for classes that contain a localized expression; in that case we can "backport" the localized
1035        // expression to give expressions referencing the canonical expression a shot at localization.
1036        //
1037        // Expressions should only contain instances of canonical expressions, and so we shouldn't need
1038        // to look any further than backporting those. Backporting should have the property that the simplest
1039        // localized expression in each class does not contain any non-localized canonical expressions
1040        // (as that would make it non-localized); our backporting of non-localized canonicals with localized
1041        // expressions should never fire a second
1042
1043        // Let's say an expression is "localized" once we are able to rewrite its support in terms of `output_columns`.
1044        // Not all expressions can be localized, although some of them may be equivalent to localized expressions.
1045        // As we find localized expressions, we can replace uses of their equivalent representative with them,
1046        // which may allow further expression localization.
1047        // We continue the process until no further classes can be localized.
1048
1049        // A map from representatives to our first localization of their equivalence class.
1050        let mut localized = false;
1051        while !localized {
1052            localized = true;
1053            let mut current_map = BTreeMap::default();
1054            for class in self.classes.iter_mut() {
1055                if !class[0].support().iter().all(|c| remap.contains_key(c)) {
1056                    if let Some(pos) = class
1057                        .iter()
1058                        .position(|e| e.support().iter().all(|c| remap.contains_key(c)))
1059                    {
1060                        class.swap(0, pos);
1061                        localized = false;
1062                    }
1063                }
1064                for expr in class[1..].iter() {
1065                    current_map.insert(expr.clone(), class[0].clone());
1066                }
1067            }
1068
1069            // attempt to replace representatives with equivalent localizeable expressions.
1070            for class_index in 0..self.classes.len() {
1071                for index in 0..self.classes[class_index].len() {
1072                    current_map.reduce_child(&mut self.classes[class_index][index]);
1073                }
1074            }
1075            // NB: Do *not* `self.minimize()`, as we are developing localizable rather than canonical representatives.
1076        }
1077
1078        // Localize all localizable expressions and discard others.
1079        for class in self.classes.iter_mut() {
1080            class.retain(|e| e.support().iter().all(|c| remap.contains_key(c)));
1081            for expr in class.iter_mut() {
1082                expr.permute_map(&remap);
1083            }
1084        }
1085        self.classes.retain(|c| c.len() > 1);
1086        // If column repetitions, introduce them as equivalences.
1087        // We introduce only the equivalence to the first occurrence, and rely on minimization to collect them.
1088        for (col1, col2) in dupes {
1089            self.classes.push(vec![
1090                MirScalarExpr::column(col1),
1091                MirScalarExpr::column(col2),
1092            ]);
1093        }
1094        self.remap.clear();
1095        self.minimize(None);
1096    }
1097
1098    /// True if any equivalence class contains two distinct non-error literals.
1099    pub fn unsatisfiable(&self) -> bool {
1100        for class in self.classes.iter() {
1101            let mut literal_ok = None;
1102            for expr in class.iter() {
1103                if let MirScalarExpr::Literal(Ok(row), _) = expr {
1104                    if literal_ok.is_some() && literal_ok != Some(row) {
1105                        return true;
1106                    } else {
1107                        literal_ok = Some(row);
1108                    }
1109                }
1110            }
1111        }
1112        false
1113    }
1114
    /// Returns a map that can be used to replace (sub-)expressions.
    ///
    /// This is a borrow of `self.remap`. Other methods of this type rebuild `remap` through
    /// `minimize()`, so the map presumably reflects `self.classes` only as of the most recent
    /// minimization (NOTE(review): confirm against `refresh`).
    pub fn reducer(&self) -> &BTreeMap<MirScalarExpr, MirScalarExpr> {
        &self.remap
    }
1119
1120    /// Examines the prefix of `class` of literals, looking for any satisfying `predicate`.
1121    ///
1122    /// This test bails out as soon as it sees a non-literal, and may have false negatives
1123    /// if the data are not sorted with literals at the front.
1124    fn class_contains_literal<P>(class: &[MirScalarExpr], mut predicate: P) -> bool
1125    where
1126        P: FnMut(&Result<Datum, &mz_expr::EvalError>) -> bool,
1127    {
1128        class
1129            .iter()
1130            .take_while(|e| e.is_literal())
1131            .filter_map(|e| e.as_literal())
1132            .any(move |e| predicate(&e))
1133    }
1134}
1135
1136/// A type capable of simplifying `MirScalarExpr`s.
1137pub trait ExpressionReducer {
1138    /// Attempt to replace `expr` itself with another expression.
1139    /// Returns true if it does so.
1140    fn replace(&self, expr: &mut MirScalarExpr) -> bool;
1141    /// Attempt to replace any subexpressions of `expr` with other expressions.
1142    /// Returns true if it does so.
1143    fn reduce_expr(&self, expr: &mut MirScalarExpr) -> bool {
1144        let mut simplified = false;
1145        simplified = simplified || self.reduce_child(expr);
1146        simplified = simplified || self.replace(expr);
1147        simplified
1148    }
1149    /// Attempt to replace any subexpressions of `expr`'s children with other expressions.
1150    /// Returns true if it does so.
1151    fn reduce_child(&self, expr: &mut MirScalarExpr) -> bool {
1152        let mut simplified = false;
1153        match expr {
1154            MirScalarExpr::CallBinary { expr1, expr2, .. } => {
1155                simplified = self.reduce_expr(expr1) || simplified;
1156                simplified = self.reduce_expr(expr2) || simplified;
1157            }
1158            MirScalarExpr::CallUnary { expr, .. } => {
1159                simplified = self.reduce_expr(expr) || simplified;
1160            }
1161            MirScalarExpr::CallVariadic { exprs, .. } => {
1162                for expr in exprs.iter_mut() {
1163                    simplified = self.reduce_expr(expr) || simplified;
1164                }
1165            }
1166            MirScalarExpr::If { cond: _, then, els } => {
1167                // Do not simplify `cond`, as we cannot ensure the simplification
1168                // continues to hold as expressions migrate around.
1169                simplified = self.reduce_expr(then) || simplified;
1170                simplified = self.reduce_expr(els) || simplified;
1171            }
1172            _ => {}
1173        }
1174        simplified
1175    }
1176}
1177
1178impl ExpressionReducer for BTreeMap<&MirScalarExpr, &MirScalarExpr> {
1179    /// Perform any exact replacement for `expr`, report if it had an effect.
1180    fn replace(&self, expr: &mut MirScalarExpr) -> bool {
1181        if let Some(other) = self.get(expr) {
1182            if other != &expr {
1183                expr.clone_from(other);
1184                return true;
1185            }
1186        }
1187        false
1188    }
1189}
1190
1191impl ExpressionReducer for BTreeMap<MirScalarExpr, MirScalarExpr> {
1192    /// Perform any exact replacement for `expr`, report if it had an effect.
1193    fn replace(&self, expr: &mut MirScalarExpr) -> bool {
1194        if let Some(other) = self.get(expr) {
1195            if other != expr {
1196                expr.clone_from(other);
1197                return true;
1198            }
1199        }
1200        false
1201    }
1202}
1203
1204/// True iff the aggregate function returns an input datum.
1205fn aggregate_is_input(aggregate: &AggregateFunc) -> bool {
1206    match aggregate {
1207        AggregateFunc::MaxInt16
1208        | AggregateFunc::MaxInt32
1209        | AggregateFunc::MaxInt64
1210        | AggregateFunc::MaxUInt16
1211        | AggregateFunc::MaxUInt32
1212        | AggregateFunc::MaxUInt64
1213        | AggregateFunc::MaxMzTimestamp
1214        | AggregateFunc::MaxFloat32
1215        | AggregateFunc::MaxFloat64
1216        | AggregateFunc::MaxBool
1217        | AggregateFunc::MaxString
1218        | AggregateFunc::MaxDate
1219        | AggregateFunc::MaxTimestamp
1220        | AggregateFunc::MaxTimestampTz
1221        | AggregateFunc::MinInt16
1222        | AggregateFunc::MinInt32
1223        | AggregateFunc::MinInt64
1224        | AggregateFunc::MinUInt16
1225        | AggregateFunc::MinUInt32
1226        | AggregateFunc::MinUInt64
1227        | AggregateFunc::MinMzTimestamp
1228        | AggregateFunc::MinFloat32
1229        | AggregateFunc::MinFloat64
1230        | AggregateFunc::MinBool
1231        | AggregateFunc::MinString
1232        | AggregateFunc::MinDate
1233        | AggregateFunc::MinTimestamp
1234        | AggregateFunc::MinTimestampTz
1235        | AggregateFunc::Any
1236        | AggregateFunc::All => true,
1237        _ => false,
1238    }
1239}