mz_transform/analysis/
equivalences.rs

1// Copyright Materialize, Inc. and contributors. All rights reserved.
2//
3// Use of this software is governed by the Business Source License
4// included in the LICENSE file.
5//
6// As of the Change Date specified in that file, in accordance with
7// the Business Source License, use of this software will be governed
8// by the Apache License, Version 2.0.
9
10//! An analysis that reports all known-equivalent expressions for each relation.
11//!
12//! Expressions are equivalent at a relation if they are certain to evaluate to
13//! the same `Datum` for all records in the relation.
14//!
//! Equivalences are recorded in an `EquivalenceClasses`, which lists all known
//! equivalence classes, each a list of equivalent expressions.
17
18use std::collections::BTreeMap;
19use std::fmt::Formatter;
20
21use mz_expr::explain::{HumanizedExplain, HumanizerMode};
22use mz_expr::{AggregateFunc, Id, MirRelationExpr, MirScalarExpr};
23use mz_ore::str::{bracketed, separated};
24use mz_repr::{ColumnType, Datum};
25
26use crate::analysis::{Analysis, Lattice};
27use crate::analysis::{Arity, RelationType};
28use crate::analysis::{Derived, DerivedBuilder};
29
/// Analysis that derives, for each relation expression, the classes of scalar expressions
/// known to evaluate to the same `Datum` on every record of that relation (i.e. predicate
/// information expressed as equivalences).
#[derive(Default, Debug)]
pub struct Equivalences;
33
34impl Analysis for Equivalences {
35    // A `Some(list)` indicates a list of classes of equivalent expressions.
36    // A `None` indicates all expressions are equivalent, including contradictions;
37    // this is only possible for the empty collection, and as an initial result for
38    // unconstrained recursive terms.
39    type Value = Option<EquivalenceClasses>;
40
41    fn announce_dependencies(builder: &mut DerivedBuilder) {
42        builder.require(Arity);
43        builder.require(RelationType); // needed for expression reduction.
44    }
45
46    fn derive(
47        &self,
48        expr: &MirRelationExpr,
49        index: usize,
50        results: &[Self::Value],
51        depends: &Derived,
52    ) -> Self::Value {
53        let mut equivalences = match expr {
54            MirRelationExpr::Constant { rows, typ } => {
55                // Trawl `rows` for any constant information worth recording.
56                // Literal columns may be valuable; non-nullability could be too.
57                let mut equivalences = EquivalenceClasses::default();
58                if let Ok([(row, _cnt), rows @ ..]) = rows.as_deref() {
59                    // Vector of `Option<Datum>` which becomes `None` once a column has a second datum.
60                    let len = row.iter().count();
61                    let mut common = Vec::with_capacity(len);
62                    common.extend(row.iter().map(Some));
63                    // Prep initial nullability information.
64                    let mut nullable_cols = common
65                        .iter()
66                        .map(|datum| datum == &Some(Datum::Null))
67                        .collect::<Vec<_>>();
68
69                    for (row, _cnt) in rows.iter() {
70                        for ((datum, common), nullable) in row
71                            .iter()
72                            .zip(common.iter_mut())
73                            .zip(nullable_cols.iter_mut())
74                        {
75                            if Some(datum) != *common {
76                                *common = None;
77                            }
78                            if datum == Datum::Null {
79                                *nullable = true;
80                            }
81                        }
82                    }
83                    for (index, common) in common.into_iter().enumerate() {
84                        if let Some(datum) = common {
85                            equivalences.classes.push(vec![
86                                MirScalarExpr::Column(index),
87                                MirScalarExpr::literal_ok(
88                                    datum,
89                                    typ.column_types[index].scalar_type.clone(),
90                                ),
91                            ]);
92                        }
93                    }
94                    // If any columns are non-null, introduce this fact.
95                    if nullable_cols.iter().any(|x| !*x) {
96                        let mut class = vec![MirScalarExpr::literal_false()];
97                        for (index, nullable) in nullable_cols.iter().enumerate() {
98                            if !*nullable {
99                                class.push(MirScalarExpr::column(index).call_is_null());
100                            }
101                        }
102                        equivalences.classes.push(class);
103                    }
104                }
105                Some(equivalences)
106            }
107            MirRelationExpr::Get { id, typ, .. } => {
108                let mut equivalences = Some(EquivalenceClasses::default());
109                // Find local identifiers, but nothing for external identifiers.
110                if let Id::Local(id) = id {
111                    if let Some(offset) = depends.bindings().get(id) {
112                        // It is possible we have derived nothing for a recursive term
113                        if let Some(result) = results.get(*offset) {
114                            equivalences.clone_from(result);
115                        } else {
116                            // No top element was prepared.
117                            // This means we are executing pessimistically,
118                            // but perhaps we must because optimism is off.
119                        }
120                    }
121                }
122                // Incorporate statements about column nullability.
123                let mut non_null_cols = vec![MirScalarExpr::literal_false()];
124                for (index, col_type) in typ.column_types.iter().enumerate() {
125                    if !col_type.nullable {
126                        non_null_cols.push(MirScalarExpr::column(index).call_is_null());
127                    }
128                }
129                if non_null_cols.len() > 1 {
130                    if let Some(equivalences) = equivalences.as_mut() {
131                        equivalences.classes.push(non_null_cols);
132                    }
133                }
134
135                equivalences
136            }
137            MirRelationExpr::Let { .. } => results.get(index - 1).unwrap().clone(),
138            MirRelationExpr::LetRec { .. } => results.get(index - 1).unwrap().clone(),
139            MirRelationExpr::Project { outputs, .. } => {
140                // restrict equivalences, and introduce equivalences for repeated outputs.
141                let mut equivalences = results.get(index - 1).unwrap().clone();
142                equivalences
143                    .as_mut()
144                    .map(|e| e.project(outputs.iter().cloned()));
145                equivalences
146            }
147            MirRelationExpr::Map { scalars, .. } => {
148                // introduce equivalences for new columns and expressions that define them.
149                let mut equivalences = results.get(index - 1).unwrap().clone();
150                if let Some(equivalences) = &mut equivalences {
151                    let input_arity = depends.results::<Arity>()[index - 1];
152                    for (pos, expr) in scalars.iter().enumerate() {
153                        equivalences
154                            .classes
155                            .push(vec![MirScalarExpr::Column(input_arity + pos), expr.clone()]);
156                    }
157                }
158                equivalences
159            }
160            MirRelationExpr::FlatMap { .. } => results.get(index - 1).unwrap().clone(),
161            MirRelationExpr::Filter { predicates, .. } => {
162                let mut equivalences = results.get(index - 1).unwrap().clone();
163                if let Some(equivalences) = &mut equivalences {
164                    let mut class = predicates.clone();
165                    class.push(MirScalarExpr::literal_ok(
166                        Datum::True,
167                        mz_repr::ScalarType::Bool,
168                    ));
169                    equivalences.classes.push(class);
170                }
171                equivalences
172            }
173            MirRelationExpr::Join { equivalences, .. } => {
174                // Collect equivalences from all inputs;
175                let expr_index = index;
176                let mut children = depends
177                    .children_of_rev(expr_index, expr.children().count())
178                    .collect::<Vec<_>>();
179                children.reverse();
180
181                let arity = depends.results::<Arity>();
182                let mut columns = 0;
183                let mut result = Some(EquivalenceClasses::default());
184                for child in children.into_iter() {
185                    let input_arity = arity[child];
186                    let equivalences = results[child].clone();
187                    if let Some(mut equivalences) = equivalences {
188                        let permutation = (columns..(columns + input_arity)).collect::<Vec<_>>();
189                        equivalences.permute(&permutation);
190                        result
191                            .as_mut()
192                            .map(|e| e.classes.extend(equivalences.classes));
193                    } else {
194                        result = None;
195                    }
196                    columns += input_arity;
197                }
198
199                // Fold join equivalences into our results.
200                result
201                    .as_mut()
202                    .map(|e| e.classes.extend(equivalences.iter().cloned()));
203                result
204            }
205            MirRelationExpr::Reduce {
206                group_key,
207                aggregates,
208                ..
209            } => {
210                let input_arity = depends.results::<Arity>()[index - 1];
211                let mut equivalences = results.get(index - 1).unwrap().clone();
212                if let Some(equivalences) = &mut equivalences {
213                    // Introduce keys column equivalences as if a map, then project to those columns.
214                    // This should retain as much information as possible about these columns.
215                    for (pos, expr) in group_key.iter().enumerate() {
216                        equivalences
217                            .classes
218                            .push(vec![MirScalarExpr::Column(input_arity + pos), expr.clone()]);
219                    }
220
221                    // Having added classes to `equivalences`, we should minimize the classes to fold the
222                    // information in before applying the `project`, to set it up for success.
223                    equivalences.minimize(None);
224
225                    // Grab a copy of the equivalences with key columns added to use in aggregate reasoning.
226                    let extended = equivalences.clone();
227                    // Now project down the equivalences, as we will extend them in terms of the output columns.
228                    equivalences.project(input_arity..(input_arity + group_key.len()));
229
230                    // TODO: MIN, MAX, ANY, ALL aggregates pass through all certain properties of their columns.
231                    // They also pass through equivalences of them and other constant columns (e.g. key columns).
232                    // However, it is not correct to simply project onto these columns, as relationships amongst
233                    // aggregate columns may no longer be preserved. MAX(col) != MIN(col) even though col = col.
234                    // The correct thing to do is treat the reduce as a join between single-aggregate reductions,
235                    // where each single MIN/MAX/ANY/ALL aggregate propagates equivalences.
236                    for (index, aggregate) in aggregates.iter().enumerate() {
237                        if aggregate_is_input(&aggregate.func) {
238                            let mut temp_equivs = extended.clone();
239                            temp_equivs.classes.push(vec![
240                                MirScalarExpr::column(input_arity + group_key.len()),
241                                aggregate.expr.clone(),
242                            ]);
243                            temp_equivs.minimize(None);
244                            temp_equivs.project(input_arity..(input_arity + group_key.len() + 1));
245                            let columns = (0..group_key.len())
246                                .chain(std::iter::once(group_key.len() + index))
247                                .collect::<Vec<_>>();
248                            temp_equivs.permute(&columns[..]);
249                            equivalences.classes.extend(temp_equivs.classes);
250                        }
251                    }
252                }
253                equivalences
254            }
255            MirRelationExpr::TopK { .. } => results.get(index - 1).unwrap().clone(),
256            MirRelationExpr::Negate { .. } => results.get(index - 1).unwrap().clone(),
257            MirRelationExpr::Threshold { .. } => results.get(index - 1).unwrap().clone(),
258            MirRelationExpr::Union { .. } => {
259                let expr_index = index;
260                let mut child_equivs = depends
261                    .children_of_rev(expr_index, expr.children().count())
262                    .flat_map(|c| &results[c]);
263                if let Some(first) = child_equivs.next() {
264                    Some(first.union_many(child_equivs))
265                } else {
266                    None
267                }
268            }
269            MirRelationExpr::ArrangeBy { .. } => results.get(index - 1).unwrap().clone(),
270        };
271
272        let expr_type = depends.results::<RelationType>()[index].as_ref();
273        equivalences
274            .as_mut()
275            .map(|e| e.minimize(expr_type.map(|x| &x[..])));
276        equivalences
277    }
278
279    fn lattice() -> Option<Box<dyn Lattice<Self::Value>>> {
280        Some(Box::new(EQLattice))
281    }
282}
283
284struct EQLattice;
285
286impl Lattice<Option<EquivalenceClasses>> for EQLattice {
287    fn top(&self) -> Option<EquivalenceClasses> {
288        None
289    }
290
291    fn meet_assign(
292        &self,
293        a: &mut Option<EquivalenceClasses>,
294        b: Option<EquivalenceClasses>,
295    ) -> bool {
296        match (&mut *a, b) {
297            (_, None) => false,
298            (None, b) => {
299                *a = b;
300                true
301            }
302            (Some(a), Some(b)) => {
303                let mut c = a.union(&b);
304                std::mem::swap(a, &mut c);
305                a != &mut c
306            }
307        }
308    }
309}
310
311/// A compact representation of classes of expressions that must be equivalent.
312///
313/// Each "class" contains a list of expressions, each of which must be `Eq::eq` equal.
314/// Ideally, the first element is the "simplest", e.g. a literal or column reference,
315/// and any other element of that list can be replaced by it.
316///
317/// The classes are meant to be minimized, with each expression as reduced as it can be,
318/// and all classes sharing an element merged.
319#[derive(Clone, Eq, PartialEq, Ord, PartialOrd, Default, Debug)]
320pub struct EquivalenceClasses {
321    /// Multiple lists of equivalent expressions, each representing an equivalence class.
322    ///
323    /// The first element should be the "canonical" simplest element, that any other element
324    /// can be replaced by.
325    /// These classes are unified whenever possible, to minimize the number of classes.
326    /// They are only guaranteed to form an equivalence relation after a call to `minimize`,
327    /// which refreshes both `self.classes` and `self.remap`.
328    pub classes: Vec<Vec<MirScalarExpr>>,
329
330    /// An expression simplification map.
331    ///
332    /// This map reflects an equivalence relation based on a prior version of `self.classes`.
333    /// As users may add to `self.classes`, `self.remap` may become stale. We refresh `remap`
334    /// only in `self.refresh()`, to the equivalence relation that derives from `self.classes`.
335    ///
336    /// It is important to `self.remap.clear()` if you invalidate it by mutating rather than
337    /// appending to `self.classes`. This will be corrected in the next call to `self.refresh()`,
338    /// but until then `remap` could be arbitrarily wrong. This should be improved in the future.
339    remap: BTreeMap<MirScalarExpr, MirScalarExpr>,
340}
341
342/// Raw printing of [`EquivalenceClasses`] with default expression humanization.
343/// Don't use this in `EXPLAIN`! For redaction, column name support, etc., see
344/// [`HumanizedEquivalenceClasses`].
345impl std::fmt::Display for EquivalenceClasses {
346    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
347        HumanizedEquivalenceClasses {
348            equivalence_classes: self,
349            cols: None,
350            mode: HumanizedExplain::default(),
351        }
352        .fmt(f)
353    }
354}
355
356/// Wrapper struct for human-readable printing of expressions inside [`EquivalenceClasses`].
357/// (Similar to `HumanizedExpr`. Unfortunately, we can't just use `HumanizedExpr` here, because
358/// we'd need to `impl Display for HumanizedExpr<'a, EquivalenceClasses, M>`, but neither
359/// `Display` nor `HumanizedExpr` is defined in this crate.)
360#[derive(Debug)]
361pub struct HumanizedEquivalenceClasses<'a, M = HumanizedExplain> {
362    /// The [`EquivalenceClasses`] to be humanized.
363    pub equivalence_classes: &'a EquivalenceClasses,
364    /// An optional vector of inferred column names to be used when rendering
365    /// column references in expressions.
366    pub cols: Option<&'a Vec<String>>,
367    /// The rendering mode to use. See `HumanizerMode` for details.
368    pub mode: M,
369}
370
371impl std::fmt::Display for HumanizedEquivalenceClasses<'_> {
372    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
373        // Only show `classes`.
374        // (The following hopefully avoids allocating any of the intermediate composite strings.)
375        let classes = self.equivalence_classes.classes.iter().map(|class| {
376            format!(
377                "{}",
378                bracketed(
379                    "[",
380                    "]",
381                    separated(
382                        ", ",
383                        class.iter().map(|expr| self.mode.expr(expr, self.cols))
384                    )
385                )
386            )
387        });
388        write!(f, "{}", bracketed("[", "]", separated(", ", classes)))
389    }
390}
391
392impl EquivalenceClasses {
393    /// Comparator function for the complexity of scalar expressions. Simpler expressions are
394    /// smaller. Can be used when we need to decide which of several equivalent expressions to use.
395    pub fn mir_scalar_expr_complexity(
396        e1: &MirScalarExpr,
397        e2: &MirScalarExpr,
398    ) -> std::cmp::Ordering {
399        use MirScalarExpr::*;
400        use std::cmp::Ordering::*;
401        match (e1, e2) {
402            (Literal(_, _), Literal(_, _)) => e1.cmp(e2),
403            (Literal(_, _), _) => Less,
404            (_, Literal(_, _)) => Greater,
405            (Column(_), Column(_)) => e1.cmp(e2),
406            (Column(_), _) => Less,
407            (_, Column(_)) => Greater,
408            (x, y) => {
409                // General expressions should be ordered by their size,
410                // to ensure we only simplify expressions by substitution.
411                // If same size, then fall back to the expressions' Ord.
412                match x.size().cmp(&y.size()) {
413                    Equal => x.cmp(y),
414                    other => other,
415                }
416            }
417        }
418    }
419
420    /// Sorts and deduplicates each class, removing literal errors.
421    ///
422    /// This method does not ensure equivalence relation structure, but instead performs
423    /// only minimal structural clean-up.
424    fn tidy(&mut self) {
425        for class in self.classes.iter_mut() {
426            // Remove all literal errors, as they cannot be equated to other things.
427            class.retain(|e| !e.is_literal_err());
428            class.sort_by(Self::mir_scalar_expr_complexity);
429            class.dedup();
430        }
431        self.classes.retain(|c| c.len() > 1);
432        self.classes.sort();
433        self.classes.dedup();
434    }
435
436    /// Restore equivalence relation structure to `self.classes` and refresh `self.remap`.
437    ///
438    /// This method takes roughly linear time, and returns true iff `self.remap` has changed.
439    /// This is the only method that refreshes `self.remap`, and is a perfect place to decide
440    /// whether the equivalence classes it represents have experienced any changes since the
441    /// last refresh.
442    fn refresh(&mut self) -> bool {
443        self.tidy();
444
445        // remap may already be the correct answer, and if so we should avoid the work of rebuilding it.
446        // If it contains the same number of expressions as `self.classes`, and for every expression in
447        // `self.classes` the two agree on the representative, they are identical.
448        if self.remap.len() == self.classes.iter().map(|c| c.len()).sum::<usize>()
449            && self
450                .classes
451                .iter()
452                .all(|c| c.iter().all(|e| self.remap.get(e) == Some(&c[0])))
453        {
454            // No change, so return false.
455            return false;
456        }
457
458        // Optimistically build the `remap` we would want.
459        // Note if any unions would be required, in which case we have further work to do,
460        // including re-forming `self.classes`.
461        let mut union_find = BTreeMap::default();
462        let mut dirtied = false;
463        for class in self.classes.iter() {
464            for expr in class.iter() {
465                if let Some(other) = union_find.insert(expr.clone(), class[0].clone()) {
466                    // A merge is required, but have the more complex expression point at the simpler one.
467                    // This allows `union_find` to end as the `remap` for the new `classes` we form, with
468                    // the only required work being compressing all the paths.
469                    if Self::mir_scalar_expr_complexity(&other, &class[0])
470                        == std::cmp::Ordering::Less
471                    {
472                        union_find.union(&class[0], &other);
473                    } else {
474                        union_find.union(&other, &class[0]);
475                    }
476                    dirtied = true;
477                }
478            }
479        }
480        if dirtied {
481            let mut classes: BTreeMap<_, Vec<_>> = BTreeMap::default();
482            for class in self.classes.drain(..) {
483                for expr in class {
484                    let root: MirScalarExpr = union_find.find(&expr).unwrap().clone();
485                    classes.entry(root).or_default().push(expr);
486                }
487            }
488            self.classes = classes.into_values().collect();
489            self.tidy();
490        }
491
492        let changed = self.remap != union_find;
493        self.remap = union_find;
494        changed
495    }
496
497    /// Update `self` to maintain the same equivalences which potentially reducing along `Ord::le`.
498    ///
499    /// Informally this means simplifying constraints, removing redundant constraints, and unifying equivalence classes.
500    pub fn minimize(&mut self, columns: Option<&[ColumnType]>) {
501        // Repeatedly, we reduce each of the classes themselves, then unify the classes.
502        // This should strictly reduce complexity, and reach a fixed point.
503        // Ideally it is *confluent*, arriving at the same fixed point no matter the order of operations.
504
505        // We should not rely on nullability information present in `column_types`. (Doing this
506        // every time just before calling `reduce` was found to be a bottleneck during incident-217,
507        // so now we do this nullability tweaking only once here.)
508        let mut columns = columns.map(|x| x.to_vec());
509        let mut nonnull = Vec::new();
510        if let Some(columns) = columns.as_mut() {
511            for (index, col) in columns.iter_mut().enumerate() {
512                let is_null = MirScalarExpr::column(index).call_is_null();
513                if !col.nullable
514                    && self
515                        .remap
516                        .get(&is_null)
517                        .map(|e| !e.is_literal_false())
518                        .unwrap_or(true)
519                {
520                    nonnull.push(is_null);
521                }
522                col.nullable = true;
523            }
524        }
525        if !nonnull.is_empty() {
526            nonnull.push(MirScalarExpr::literal_false());
527            self.classes.push(nonnull);
528        }
529
530        // Ensure `self.classes` and `self.remap` are equivalence relations.
531        // Users are allowed to mutate `self.classes`, so we must perform this normalization at least once.
532        // We have also likely mutated `self.classes` just above with non-nullability information.
533        self.refresh();
534
535        // Termination will be detected by comparing to the map of equivalence classes.
536        let mut previous = Some(self.remap.clone());
537        while let Some(prev) = previous {
538            // Attempt to add new equivalences.
539            let novel = self.expand();
540            if !novel.is_empty() {
541                self.classes.extend(novel);
542                self.refresh();
543            }
544
545            // We continue as long as any simplification has occurred.
546            // An expression can be simplified, a duplication found, or two classes unified.
547            let mut stable = false;
548            while !stable {
549                stable = !self.minimize_once(columns.as_ref().map(|x| &x[..]));
550            }
551
552            // Termination detection.
553            if prev != self.remap {
554                previous = Some(self.remap.clone());
555            } else {
556                previous = None;
557            }
558        }
559    }
560
561    /// Proposes new equivalences that are likely to be novel.
562    ///
563    /// This method invokes `self.implications()` to propose equivalences, and then judges them to be
564    /// novel or not based on existing knowledge, reducing the equivalences down to their novel core.
565    /// This method may produce non-novel equivalences, due to its inability to perform `MSE::reduce`.
566    /// We can end up with e.g. constant expressions that cannot be found until they are so reduced.
567    /// The novelty detection is best-effort, and meant to provide a clearer signal and minimize the
568    /// number of times we call and amount of work we do in `self.refresh()`.
569    fn expand(&self) -> Vec<Vec<MirScalarExpr>> {
570        // Consider expanding `self.classes` with novel equivalences.
571        let mut novel = self.implications();
572        for class in novel.iter_mut() {
573            // reduce each expression to its canonical form.
574            for expr in class.iter_mut() {
575                self.remap.reduce_expr(expr);
576            }
577            class.sort();
578            class.dedup();
579            // for a class to be interesting we require at least two elements that do not reference the same root.
580            let common_class = class
581                .iter()
582                .map(|x| self.remap.get(x))
583                .reduce(|prev, this| if prev == this { prev } else { None });
584            if class.len() == 1 || common_class != Some(None) {
585                class.clear();
586            }
587        }
588        novel.retain(|c| !c.is_empty());
589        novel
590    }
591
592    /// Derives potentially novel equivalences without regard for minimization.
593    ///
594    /// This is an opportunity to explore equivalences that do not correspond to expression minimization,
595    /// and therefore should not be used in `minimize_once`. They are still potentially important, but
596    /// required additional guardrails to ensure we reach a fixed point.
597    ///
598    /// The implications will be introduced into `self.classes` and will prompt a round of minimization,
599    /// making it somewhat polite to avoid producing outputs that cannot result in novel equivalences.
600    /// For example, before producing a new equivalence, one could check that the involved terms are not
601    /// already present in the same class.
602    fn implications(&self) -> Vec<Vec<MirScalarExpr>> {
603        let mut new_equivalences = Vec::new();
604
605        // If we see `false == IsNull(foo)` we can add the non-null implications of `foo`.
606        let mut non_null = std::collections::BTreeSet::default();
607        for class in self.classes.iter() {
608            if Self::class_contains_literal(class, |e| e == &Ok(Datum::False)) {
609                for e in class.iter() {
610                    if let MirScalarExpr::CallUnary {
611                        func: mz_expr::UnaryFunc::IsNull(_),
612                        expr,
613                    } = e
614                    {
615                        expr.non_null_requirements(&mut non_null);
616                    }
617                }
618            }
619        }
620        // If we see `true == foo` we can add the non-null implications of `foo`.
621        // TODO: generalize to arbitrary non-null, non-error literals; at the moment `true == pred` is
622        // an important idiom to identify for how we express predicates.
623        for class in self.classes.iter() {
624            if Self::class_contains_literal(class, |e| e == &Ok(Datum::True)) {
625                for expr in class.iter() {
626                    expr.non_null_requirements(&mut non_null);
627                }
628            }
629        }
630        // Only keep constraints that are not already known.
631        // Known constraints will present as `COL(_) IS NULL == false`,
632        // which can only happen if `false` is present, and both terms
633        // map to the same canonical representative>
634        let lit_false = MirScalarExpr::literal_false();
635        let target = self.remap.get(&lit_false);
636        if target.is_some() {
637            non_null.retain(|c| {
638                let is_null = MirScalarExpr::column(*c).call_is_null();
639                self.remap.get(&is_null) != target
640            });
641        }
642        if !non_null.is_empty() {
643            let mut class = Vec::with_capacity(non_null.len() + 1);
644            class.push(MirScalarExpr::literal_false());
645            class.extend(
646                non_null
647                    .into_iter()
648                    .map(|c| MirScalarExpr::column(c).call_is_null()),
649            );
650            new_equivalences.push(class);
651        }
652
653        // If we see records formed from other expressions, we can equate the expressions with
654        // accessors applied to the class of the record former. In `minimize_once` we reduce by
655        // equivalence class representative before we perform expression simplification, so we
656        // shoud be able to just use the expression former, rather than find its representative.
657        // The risk, potentially, is that we would apply accessors to the record former and then
658        // just simplify it away learning nothing.
659        for class in self.classes.iter() {
660            for expr in class.iter() {
661                // Record-forming expressions can equate their accessors and their members.
662                if let MirScalarExpr::CallVariadic {
663                    func: mz_expr::VariadicFunc::RecordCreate { .. },
664                    exprs,
665                } = expr
666                {
667                    for (index, e) in exprs.iter().enumerate() {
668                        new_equivalences.push(vec![
669                            e.clone(),
670                            expr.clone().call_unary(mz_expr::UnaryFunc::RecordGet(
671                                mz_expr::func::RecordGet(index),
672                            )),
673                        ]);
674                    }
675                }
676            }
677        }
678
679        // Return all newly established equivalences.
680        new_equivalences
681    }
682
683    /// A single iteration of minimization, which we expect to repeat but benefit from factoring out.
684    ///
685    /// This invocation should take roughly linear time.
686    /// It starts with equivalence class invariants maintained (closed under transitivity), and then
687    ///   1. Performs per-expression reduction, including the class structure to replace subexpressions.
688    ///   2. Applies idiom detection to e.g. unpack expressions equivalence to literal true or false.
689    ///   3. Restores the equivalence class invariants.
690    fn minimize_once(&mut self, columns: Option<&[ColumnType]>) -> bool {
691        // 1. Reduce each expression
692        //
693        // This reduction first looks for subexpression substitutions that can be performed,
694        // and then applies expression reduction if column type information is provided.
695        for class in self.classes.iter_mut() {
696            for expr in class.iter_mut() {
697                self.remap.reduce_child(expr);
698                if let Some(columns) = columns {
699                    expr.reduce(columns);
700                }
701            }
702        }
703
704        // 2. Identify idioms
705        //    E.g. If Eq(x, y) must be true, we can introduce classes `[x, y]` and `[false, IsNull(x), IsNull(y)]`.
706        let mut to_add = Vec::new();
707        for class in self.classes.iter_mut() {
708            if Self::class_contains_literal(class, |e| e == &Ok(Datum::True)) {
709                for expr in class.iter() {
710                    // If Eq(x, y) must be true, we can introduce classes `[x, y]` and `[false, IsNull(x), IsNull(y)]`.
711                    // This substitution replaces a complex expression with several smaller expressions, and cannot
712                    // cycle if we follow that practice.
713                    if let MirScalarExpr::CallBinary {
714                        func: mz_expr::BinaryFunc::Eq,
715                        expr1,
716                        expr2,
717                    } = expr
718                    {
719                        to_add.push(vec![*expr1.clone(), *expr2.clone()]);
720                        to_add.push(vec![
721                            MirScalarExpr::literal_false(),
722                            expr1.clone().call_is_null(),
723                            expr2.clone().call_is_null(),
724                        ]);
725                    }
726                }
727                // Remove the more complex form of the expression.
728                class.retain(|expr| {
729                    if let MirScalarExpr::CallBinary {
730                        func: mz_expr::BinaryFunc::Eq,
731                        ..
732                    } = expr
733                    {
734                        false
735                    } else {
736                        true
737                    }
738                });
739                for expr in class.iter() {
740                    // If TRUE == NOT(X) then FALSE == X is a simpler form.
741                    if let MirScalarExpr::CallUnary {
742                        func: mz_expr::UnaryFunc::Not(_),
743                        expr: e,
744                    } = expr
745                    {
746                        to_add.push(vec![MirScalarExpr::literal_false(), (**e).clone()]);
747                    }
748                }
749                class.retain(|expr| {
750                    if let MirScalarExpr::CallUnary {
751                        func: mz_expr::UnaryFunc::Not(_),
752                        ..
753                    } = expr
754                    {
755                        false
756                    } else {
757                        true
758                    }
759                });
760            }
761            if Self::class_contains_literal(class, |e| e == &Ok(Datum::False)) {
762                for expr in class.iter() {
763                    // If FALSE == NOT(X) then TRUE == X is a simpler form.
764                    if let MirScalarExpr::CallUnary {
765                        func: mz_expr::UnaryFunc::Not(_),
766                        expr: e,
767                    } = expr
768                    {
769                        to_add.push(vec![MirScalarExpr::literal_true(), (**e).clone()]);
770                    }
771                }
772                class.retain(|expr| {
773                    if let MirScalarExpr::CallUnary {
774                        func: mz_expr::UnaryFunc::Not(_),
775                        ..
776                    } = expr
777                    {
778                        false
779                    } else {
780                        true
781                    }
782                });
783            }
784        }
785        self.classes.extend(to_add);
786
787        // 3. Restore equivalence relation structure and observe if any changes result.
788        self.refresh()
789    }
790
791    /// Produce the equivalences present in both inputs.
792    pub fn union(&self, other: &Self) -> Self {
793        self.union_many([other])
794    }
795
796    /// The equivalence classes of terms equivalent in all inputs.
797    ///
798    /// This method relies on the `remap` member of each input, and bases the intersection on these rather than `classes`.
799    /// This means one should ensure `minimize()` has been called on all inputs, or risk getting a stale, but conservatively
800    /// correct, result.
801    ///
802    /// This method currently misses opportunities, because it only looks for exactly matches in expressions,
803    /// which may not include all possible matches. For example, `f(#1) == g(#1)` may exist in one class, but
804    /// in another class where `#0 == #1` it may exist as `f(#0) == g(#0)`.
805    pub fn union_many<'a, I>(&self, others: I) -> Self
806    where
807        I: IntoIterator<Item = &'a Self>,
808    {
809        // List of expressions in the intersection, and a proxy equivalence class identifier.
810        let mut intersection: Vec<(&MirScalarExpr, usize)> = Default::default();
811        // Map from expression to a proxy equivalence class identifier.
812        let mut rekey: BTreeMap<&MirScalarExpr, usize> = Default::default();
813        for (key, val) in self.remap.iter() {
814            if !rekey.contains_key(val) {
815                rekey.insert(val, rekey.len());
816            }
817            intersection.push((key, rekey[val]));
818        }
819        for other in others {
820            // Map from proxy equivalence class identifier and equivalence class expr to a new proxy identifier.
821            let mut rekey: BTreeMap<(usize, &MirScalarExpr), usize> = Default::default();
822            intersection.retain_mut(|(key, idx)| {
823                if let Some(val) = other.remap.get(key) {
824                    if !rekey.contains_key(&(*idx, val)) {
825                        rekey.insert((*idx, val), rekey.len());
826                    }
827                    *idx = rekey[&(*idx, val)];
828                    true
829                } else {
830                    false
831                }
832            });
833        }
834        let mut classes: BTreeMap<_, Vec<MirScalarExpr>> = Default::default();
835        for (key, vals) in intersection {
836            classes.entry(vals).or_default().push(key.clone())
837        }
838        let classes = classes.into_values().collect::<Vec<_>>();
839        let mut equivalences = EquivalenceClasses {
840            classes,
841            remap: Default::default(),
842        };
843        equivalences.minimize(None);
844        equivalences
845    }
846
847    /// Permutes each expression, looking up each column reference in `permutation` and replacing with what it finds.
848    pub fn permute(&mut self, permutation: &[usize]) {
849        for class in self.classes.iter_mut() {
850            for expr in class.iter_mut() {
851                expr.permute(permutation);
852            }
853        }
854        self.remap.clear();
855        self.minimize(None);
856    }
857
858    /// Subject the constraints to the column projection, reworking and removing equivalences.
859    ///
860    /// This method should also introduce equivalences representing any repeated columns.
861    pub fn project<I>(&mut self, output_columns: I)
862    where
863        I: IntoIterator<Item = usize> + Clone,
864    {
865        // Retain the first instance of each column, and record subsequent instances as duplicates.
866        let mut dupes = Vec::new();
867        let mut remap = BTreeMap::default();
868        for (idx, col) in output_columns.into_iter().enumerate() {
869            if let Some(pos) = remap.get(&col) {
870                dupes.push((*pos, idx));
871            } else {
872                remap.insert(col, idx);
873            }
874        }
875
876        // Some expressions may be "localized" in that they only reference columns in `output_columns`.
877        // Many expressions may not be localized, but may reference canonical non-localized expressions
878        // for classes that contain a localized expression; in that case we can "backport" the localized
879        // expression to give expressions referencing the canonical expression a shot at localization.
880        //
881        // Expressions should only contain instances of canonical expressions, and so we shouldn't need
882        // to look any further than backporting those. Backporting should have the property that the simplest
883        // localized expression in each class does not contain any non-localized canonical expressions
884        // (as that would make it non-localized); our backporting of non-localized canonicals with localized
885        // expressions should never fire a second
886
887        // Let's say an expression is "localized" once we are able to rewrite its support in terms of `output_columns`.
888        // Not all expressions can be localized, although some of them may be equivalent to localized expressions.
889        // As we find localized expressions, we can replace uses of their equivalent representative with them,
890        // which may allow further expression localization.
891        // We continue the process until no further classes can be localized.
892
893        // A map from representatives to our first localization of their equivalence class.
894        let mut localized = false;
895        while !localized {
896            localized = true;
897            let mut current_map = BTreeMap::default();
898            for class in self.classes.iter_mut() {
899                if !class[0].support().iter().all(|c| remap.contains_key(c)) {
900                    if let Some(pos) = class
901                        .iter()
902                        .position(|e| e.support().iter().all(|c| remap.contains_key(c)))
903                    {
904                        class.swap(0, pos);
905                        localized = false;
906                    }
907                }
908                for expr in class[1..].iter() {
909                    current_map.insert(expr.clone(), class[0].clone());
910                }
911            }
912
913            // attempt to replace representatives with equivalent localizeable expressions.
914            for class_index in 0..self.classes.len() {
915                for index in 0..self.classes[class_index].len() {
916                    current_map.reduce_child(&mut self.classes[class_index][index]);
917                }
918            }
919            // NB: Do *not* `self.minimize()`, as we are developing localizable rather than canonical representatives.
920        }
921
922        // Localize all localizable expressions and discard others.
923        for class in self.classes.iter_mut() {
924            class.retain(|e| e.support().iter().all(|c| remap.contains_key(c)));
925            for expr in class.iter_mut() {
926                expr.permute_map(&remap);
927            }
928        }
929        self.classes.retain(|c| c.len() > 1);
930        // If column repetitions, introduce them as equivalences.
931        // We introduce only the equivalence to the first occurrence, and rely on minimization to collect them.
932        for (col1, col2) in dupes {
933            self.classes.push(vec![
934                MirScalarExpr::Column(col1),
935                MirScalarExpr::Column(col2),
936            ]);
937        }
938        self.remap.clear();
939        self.minimize(None);
940    }
941
942    /// True if any equivalence class contains two distinct non-error literals.
943    pub fn unsatisfiable(&self) -> bool {
944        for class in self.classes.iter() {
945            let mut literal_ok = None;
946            for expr in class.iter() {
947                if let MirScalarExpr::Literal(Ok(row), _) = expr {
948                    if literal_ok.is_some() && literal_ok != Some(row) {
949                        return true;
950                    } else {
951                        literal_ok = Some(row);
952                    }
953                }
954            }
955        }
956        false
957    }
958
959    /// Returns a map that can be used to replace (sub-)expressions.
960    pub fn reducer(&self) -> &BTreeMap<MirScalarExpr, MirScalarExpr> {
961        &self.remap
962    }
963
964    /// Examines the prefix of `class` of literals, looking for any satisfying `predicate`.
965    ///
966    /// This test bails out as soon as it sees a non-literal, and may have false negatives
967    /// if the data are not sorted with literals at the front.
968    fn class_contains_literal<P>(class: &[MirScalarExpr], mut predicate: P) -> bool
969    where
970        P: FnMut(&Result<Datum, &mz_expr::EvalError>) -> bool,
971    {
972        class
973            .iter()
974            .take_while(|e| e.is_literal())
975            .filter_map(|e| e.as_literal())
976            .any(move |e| predicate(&e))
977    }
978}
979
980/// A type capable of simplifying `MirScalarExpr`s.
981pub trait ExpressionReducer {
982    /// Attempt to replace `expr` itself with another expression.
983    /// Returns true if it does so.
984    fn replace(&self, expr: &mut MirScalarExpr) -> bool;
985    /// Attempt to replace any subexpressions of `expr` with other expressions.
986    /// Returns true if it does so.
987    fn reduce_expr(&self, expr: &mut MirScalarExpr) -> bool {
988        let mut simplified = false;
989        simplified = simplified || self.reduce_child(expr);
990        simplified = simplified || self.replace(expr);
991        simplified
992    }
993    /// Attempt to replace any subexpressions of `expr`'s children with other expressions.
994    /// Returns true if it does so.
995    fn reduce_child(&self, expr: &mut MirScalarExpr) -> bool {
996        let mut simplified = false;
997        match expr {
998            MirScalarExpr::CallBinary { expr1, expr2, .. } => {
999                simplified = self.reduce_expr(expr1) || simplified;
1000                simplified = self.reduce_expr(expr2) || simplified;
1001            }
1002            MirScalarExpr::CallUnary { expr, .. } => {
1003                simplified = self.reduce_expr(expr) || simplified;
1004            }
1005            MirScalarExpr::CallVariadic { exprs, .. } => {
1006                for expr in exprs.iter_mut() {
1007                    simplified = self.reduce_expr(expr) || simplified;
1008                }
1009            }
1010            MirScalarExpr::If { cond: _, then, els } => {
1011                // Do not simplify `cond`, as we cannot ensure the simplification
1012                // continues to hold as expressions migrate around.
1013                simplified = self.reduce_expr(then) || simplified;
1014                simplified = self.reduce_expr(els) || simplified;
1015            }
1016            _ => {}
1017        }
1018        simplified
1019    }
1020}
1021
1022impl ExpressionReducer for BTreeMap<&MirScalarExpr, &MirScalarExpr> {
1023    /// Perform any exact replacement for `expr`, report if it had an effect.
1024    fn replace(&self, expr: &mut MirScalarExpr) -> bool {
1025        if let Some(other) = self.get(expr) {
1026            if other != &expr {
1027                expr.clone_from(other);
1028                return true;
1029            }
1030        }
1031        false
1032    }
1033}
1034
1035impl ExpressionReducer for BTreeMap<MirScalarExpr, MirScalarExpr> {
1036    /// Perform any exact replacement for `expr`, report if it had an effect.
1037    fn replace(&self, expr: &mut MirScalarExpr) -> bool {
1038        if let Some(other) = self.get(expr) {
1039            if other != expr {
1040                expr.clone_from(other);
1041                return true;
1042            }
1043        }
1044        false
1045    }
1046}
1047
1048trait UnionFind<T> {
1049    /// Sets `self[x]` to the root from `x`, and returns a reference to the root.
1050    fn find<'a>(&'a mut self, x: &T) -> Option<&'a T>;
1051    /// Ensures that `x` and `y` have the same root.
1052    fn union(&mut self, x: &T, y: &T);
1053}
1054
1055impl<T: Clone + Ord> UnionFind<T> for BTreeMap<T, T> {
1056    fn find<'a>(&'a mut self, x: &T) -> Option<&'a T> {
1057        if !self.contains_key(x) {
1058            None
1059        } else {
1060            if self[x] != self[&self[x]] {
1061                // Path halving
1062                let mut y = self[x].clone();
1063                while y != self[&y] {
1064                    let grandparent = self[&self[&y]].clone();
1065                    *self.get_mut(&y).unwrap() = grandparent;
1066                    y.clone_from(&self[&y]);
1067                }
1068                *self.get_mut(x).unwrap() = y;
1069            }
1070            Some(&self[x])
1071        }
1072    }
1073
1074    fn union(&mut self, x: &T, y: &T) {
1075        match (self.find(x).is_some(), self.find(y).is_some()) {
1076            (true, true) => {
1077                if self[x] != self[y] {
1078                    let root_x = self[x].clone();
1079                    let root_y = self[y].clone();
1080                    self.insert(root_x, root_y);
1081                }
1082            }
1083            (false, true) => {
1084                self.insert(x.clone(), self[y].clone());
1085            }
1086            (true, false) => {
1087                self.insert(y.clone(), self[x].clone());
1088            }
1089            (false, false) => {
1090                self.insert(x.clone(), x.clone());
1091                self.insert(y.clone(), x.clone());
1092            }
1093        }
1094    }
1095}
1096
1097/// True iff the aggregate function returns an input datum.
1098fn aggregate_is_input(aggregate: &AggregateFunc) -> bool {
1099    match aggregate {
1100        AggregateFunc::MaxInt16
1101        | AggregateFunc::MaxInt32
1102        | AggregateFunc::MaxInt64
1103        | AggregateFunc::MaxUInt16
1104        | AggregateFunc::MaxUInt32
1105        | AggregateFunc::MaxUInt64
1106        | AggregateFunc::MaxMzTimestamp
1107        | AggregateFunc::MaxFloat32
1108        | AggregateFunc::MaxFloat64
1109        | AggregateFunc::MaxBool
1110        | AggregateFunc::MaxString
1111        | AggregateFunc::MaxDate
1112        | AggregateFunc::MaxTimestamp
1113        | AggregateFunc::MaxTimestampTz
1114        | AggregateFunc::MinInt16
1115        | AggregateFunc::MinInt32
1116        | AggregateFunc::MinInt64
1117        | AggregateFunc::MinUInt16
1118        | AggregateFunc::MinUInt32
1119        | AggregateFunc::MinUInt64
1120        | AggregateFunc::MinMzTimestamp
1121        | AggregateFunc::MinFloat32
1122        | AggregateFunc::MinFloat64
1123        | AggregateFunc::MinBool
1124        | AggregateFunc::MinString
1125        | AggregateFunc::MinDate
1126        | AggregateFunc::MinTimestamp
1127        | AggregateFunc::MinTimestampTz
1128        | AggregateFunc::Any
1129        | AggregateFunc::All => true,
1130        _ => false,
1131    }
1132}