Skip to main content

mz_transform/analysis/
equivalences.rs

1// Copyright Materialize, Inc. and contributors. All rights reserved.
2//
3// Use of this software is governed by the Business Source License
4// included in the LICENSE file.
5//
6// As of the Change Date specified in that file, in accordance with
7// the Business Source License, use of this software will be governed
8// by the Apache License, Version 2.0.
9
10//! An analysis that reports all known-equivalent expressions for each relation.
11//!
12//! Expressions are equivalent at a relation if they are certain to evaluate to
13//! the same `Datum` for all records in the relation.
14//!
//! Equivalences are recorded in an `EquivalenceClasses`, which lists all known
//! equivalence classes, each a list of equivalent expressions.
17
18use std::collections::BTreeMap;
19use std::fmt::Formatter;
20
21use itertools::Itertools;
22use mz_expr::canonicalize::{UnionFind, canonicalize_equivalence_classes};
23use mz_expr::explain::{HumanizedExplain, HumanizerMode};
24use mz_expr::{AggregateFunc, Id, MirRelationExpr, MirScalarExpr};
25use mz_ore::str::{bracketed, separated};
26use mz_repr::{Datum, ReprColumnType, ReprScalarType};
27
28use crate::analysis::{Analysis, Lattice};
29use crate::analysis::{Arity, ReprRelationType};
30use crate::analysis::{Derived, DerivedBuilder};
31
/// Pulls up and pushes down predicate information represented as equivalences.
///
/// As an [`Analysis`], this derives for each relation expression an
/// `Option<EquivalenceClasses>` describing which scalar expressions are known to
/// be equal on every record of that relation.
#[derive(Debug, Default)]
pub struct Equivalences;
35
36impl Analysis for Equivalences {
37    // A `Some(list)` indicates a list of classes of equivalent expressions.
38    // A `None` indicates all expressions are equivalent, including contradictions;
39    // this is only possible for the empty collection, and as an initial result for
40    // unconstrained recursive terms.
41    type Value = Option<EquivalenceClasses>;
42
43    fn announce_dependencies(builder: &mut DerivedBuilder) {
44        builder.require(Arity);
45        builder.require(ReprRelationType); // needed for expression reduction.
46    }
47
48    fn derive(
49        &self,
50        expr: &MirRelationExpr,
51        index: usize,
52        results: &[Self::Value],
53        depends: &Derived,
54    ) -> Self::Value {
55        let mut equivalences = match expr {
56            MirRelationExpr::Constant { rows, typ } => {
57                // Trawl `rows` for any constant information worth recording.
58                // Literal columns may be valuable; non-nullability could be too.
59                let mut equivalences = EquivalenceClasses::default();
60                if let Ok([(row, _cnt), rows @ ..]) = rows.as_deref() {
61                    // Vector of `Option<Datum>` which becomes `None` once a column has a second datum.
62                    let len = row.iter().count();
63                    let mut common = Vec::with_capacity(len);
64                    common.extend(row.iter().map(Some));
65                    // Prep initial nullability information.
66                    let mut nullable_cols = common
67                        .iter()
68                        .map(|datum| datum == &Some(Datum::Null))
69                        .collect::<Vec<_>>();
70
71                    for (row, _cnt) in rows.iter() {
72                        for ((datum, common), nullable) in row
73                            .iter()
74                            .zip_eq(common.iter_mut())
75                            .zip_eq(nullable_cols.iter_mut())
76                        {
77                            if Some(datum) != *common {
78                                *common = None;
79                            }
80                            if datum == Datum::Null {
81                                *nullable = true;
82                            }
83                        }
84                    }
85                    for (index, common) in common.into_iter().enumerate() {
86                        if let Some(datum) = common {
87                            equivalences.classes.push(vec![
88                                MirScalarExpr::column(index),
89                                MirScalarExpr::literal_ok(
90                                    datum,
91                                    typ.column_types[index].scalar_type.clone(),
92                                ),
93                            ]);
94                        }
95                    }
96                    // If any columns are non-null, introduce this fact.
97                    if nullable_cols.iter().any(|x| !*x) {
98                        let mut class = vec![MirScalarExpr::literal_false()];
99                        for (index, nullable) in nullable_cols.iter().enumerate() {
100                            if !*nullable {
101                                class.push(MirScalarExpr::column(index).call_is_null());
102                            }
103                        }
104                        equivalences.classes.push(class);
105                    }
106                }
107                Some(equivalences)
108            }
109            MirRelationExpr::Get { id, typ, .. } => {
110                let mut equivalences = Some(EquivalenceClasses::default());
111                // Find local identifiers, but nothing for external identifiers.
112                if let Id::Local(id) = id {
113                    if let Some(offset) = depends.bindings().get(id) {
114                        // It is possible we have derived nothing for a recursive term
115                        if let Some(result) = results.get(*offset) {
116                            equivalences.clone_from(result);
117                        } else {
118                            // No top element was prepared.
119                            // This means we are executing pessimistically,
120                            // but perhaps we must because optimism is off.
121                        }
122                    }
123                }
124                // Incorporate statements about column nullability.
125                let mut non_null_cols = vec![MirScalarExpr::literal_false()];
126                for (index, col_type) in typ.column_types.iter().enumerate() {
127                    if !col_type.nullable {
128                        non_null_cols.push(MirScalarExpr::column(index).call_is_null());
129                    }
130                }
131                if non_null_cols.len() > 1 {
132                    if let Some(equivalences) = equivalences.as_mut() {
133                        equivalences.classes.push(non_null_cols);
134                    }
135                }
136
137                equivalences
138            }
139            MirRelationExpr::Let { .. } => results.get(index - 1).unwrap().clone(),
140            MirRelationExpr::LetRec { .. } => results.get(index - 1).unwrap().clone(),
141            MirRelationExpr::Project { outputs, .. } => {
142                // restrict equivalences, and introduce equivalences for repeated outputs.
143                let mut equivalences = results.get(index - 1).unwrap().clone();
144                equivalences
145                    .as_mut()
146                    .map(|e| e.project(outputs.iter().cloned()));
147                equivalences
148            }
149            MirRelationExpr::Map { scalars, .. } => {
150                // introduce equivalences for new columns and expressions that define them.
151                let mut equivalences = results.get(index - 1).unwrap().clone();
152                if let Some(equivalences) = &mut equivalences {
153                    let input_arity = depends.results::<Arity>()[index - 1];
154                    for (pos, expr) in scalars.iter().enumerate() {
155                        equivalences
156                            .classes
157                            .push(vec![MirScalarExpr::column(input_arity + pos), expr.clone()]);
158                    }
159                }
160                equivalences
161            }
162            MirRelationExpr::FlatMap { .. } => results.get(index - 1).unwrap().clone(),
163            MirRelationExpr::Filter { predicates, .. } => {
164                let mut equivalences = results.get(index - 1).unwrap().clone();
165                if let Some(equivalences) = &mut equivalences {
166                    let mut class = predicates.clone();
167                    class.push(MirScalarExpr::literal_ok(Datum::True, ReprScalarType::Bool));
168                    equivalences.classes.push(class);
169                }
170                equivalences
171            }
172            MirRelationExpr::Join { equivalences, .. } => {
173                // Collect equivalences from all inputs;
174                let expr_index = index;
175                let mut children = depends
176                    .children_of_rev(expr_index, expr.children().count())
177                    .collect::<Vec<_>>();
178                children.reverse();
179
180                let arity = depends.results::<Arity>();
181                let mut columns = 0;
182                let mut result = Some(EquivalenceClasses::default());
183                for child in children.into_iter() {
184                    let input_arity = arity[child];
185                    let equivalences = results[child].clone();
186                    if let Some(mut equivalences) = equivalences {
187                        let permutation = (columns..(columns + input_arity)).collect::<Vec<_>>();
188                        equivalences.permute(&permutation);
189                        result
190                            .as_mut()
191                            .map(|e| e.classes.extend(equivalences.classes));
192                    } else {
193                        result = None;
194                    }
195                    columns += input_arity;
196                }
197
198                // Fold join equivalences into our results.
199                result
200                    .as_mut()
201                    .map(|e| e.classes.extend(equivalences.iter().cloned()));
202                result
203            }
204            MirRelationExpr::Reduce {
205                group_key,
206                aggregates,
207                ..
208            } => {
209                let input_arity = depends.results::<Arity>()[index - 1];
210                let mut equivalences = results.get(index - 1).unwrap().clone();
211                if let Some(equivalences) = &mut equivalences {
212                    // Introduce keys column equivalences as if a map, then project to those columns.
213                    // This should retain as much information as possible about these columns.
214                    for (pos, expr) in group_key.iter().enumerate() {
215                        equivalences
216                            .classes
217                            .push(vec![MirScalarExpr::column(input_arity + pos), expr.clone()]);
218                    }
219
220                    // Having added classes to `equivalences`, we should minimize the classes to fold the
221                    // information in before applying the `project`, to set it up for success.
222                    equivalences.minimize(None);
223
224                    // Grab a copy of the equivalences with key columns added to use in aggregate reasoning.
225                    let extended = equivalences.clone();
226                    // Now project down the equivalences, as we will extend them in terms of the output columns.
227                    equivalences.project(input_arity..(input_arity + group_key.len()));
228
229                    // TODO: MIN, MAX, ANY, ALL aggregates pass through all certain properties of their columns.
230                    // They also pass through equivalences of them and other constant columns (e.g. key columns).
231                    // However, it is not correct to simply project onto these columns, as relationships amongst
232                    // aggregate columns may no longer be preserved. MAX(col) != MIN(col) even though col = col.
233                    // The correct thing to do is treat the reduce as a join between single-aggregate reductions,
234                    // where each single MIN/MAX/ANY/ALL aggregate propagates equivalences.
235                    for (index, aggregate) in aggregates.iter().enumerate() {
236                        if aggregate_is_input(&aggregate.func) {
237                            let mut temp_equivs = extended.clone();
238                            temp_equivs.classes.push(vec![
239                                MirScalarExpr::column(input_arity + group_key.len()),
240                                aggregate.expr.clone(),
241                            ]);
242                            temp_equivs.minimize(None);
243                            temp_equivs.project(input_arity..(input_arity + group_key.len() + 1));
244                            let columns = (0..group_key.len())
245                                .chain(std::iter::once(group_key.len() + index))
246                                .collect::<Vec<_>>();
247                            temp_equivs.permute(&columns[..]);
248                            equivalences.classes.extend(temp_equivs.classes);
249                        }
250                    }
251                }
252                equivalences
253            }
254            MirRelationExpr::TopK { .. } => results.get(index - 1).unwrap().clone(),
255            MirRelationExpr::Negate { .. } => results.get(index - 1).unwrap().clone(),
256            MirRelationExpr::Threshold { .. } => results.get(index - 1).unwrap().clone(),
257            MirRelationExpr::Union { .. } => {
258                let expr_index = index;
259                let mut child_equivs = depends
260                    .children_of_rev(expr_index, expr.children().count())
261                    .flat_map(|c| &results[c]);
262                if let Some(first) = child_equivs.next() {
263                    Some(first.union_many(child_equivs))
264                } else {
265                    None
266                }
267            }
268            MirRelationExpr::ArrangeBy { .. } => results.get(index - 1).unwrap().clone(),
269        };
270
271        let repr_expr_type = depends.results::<ReprRelationType>()[index].as_ref();
272        equivalences
273            .as_mut()
274            .map(|e| e.minimize(repr_expr_type.as_ref().map(|x| &x[..])));
275        equivalences
276    }
277
278    fn lattice() -> Option<Box<dyn Lattice<Self::Value>>> {
279        Some(Box::new(EQLattice))
280    }
281}
282
/// Lattice over `Option<EquivalenceClasses>`, with `None` as its top element.
struct EQLattice;
284
285impl Lattice<Option<EquivalenceClasses>> for EQLattice {
286    fn top(&self) -> Option<EquivalenceClasses> {
287        None
288    }
289
290    fn meet_assign(
291        &self,
292        a: &mut Option<EquivalenceClasses>,
293        b: Option<EquivalenceClasses>,
294    ) -> bool {
295        match (&mut *a, b) {
296            (_, None) => false,
297            (None, b) => {
298                *a = b;
299                true
300            }
301            (Some(a), Some(b)) => {
302                let mut c = a.union(&b);
303                std::mem::swap(a, &mut c);
304                a != &mut c
305            }
306        }
307    }
308}
309
/// A compact representation of classes of expressions that must be equivalent.
///
/// Each "class" contains a list of expressions, each of which must be `Eq::eq` equal.
/// Ideally, the first element is the "simplest", e.g. a literal or column reference,
/// and any other element of that list can be replaced by it.
///
/// The classes are meant to be minimized, with each expression as reduced as it can be,
/// and all classes sharing an element merged.
#[derive(Clone, Eq, PartialEq, Ord, PartialOrd, Default, Debug)]
pub struct EquivalenceClasses {
    /// Multiple lists of equivalent expressions, each representing an equivalence class.
    ///
    /// The first element should be the "canonical" simplest element, that any other element
    /// can be replaced by.
    /// These classes are unified whenever possible, to minimize the number of classes.
    /// They are only guaranteed to form an equivalence relation after a call to `minimize`,
    /// which refreshes both `self.classes` and `self.remap`.
    ///
    /// The field is public, so callers may append classes directly; such additions take
    /// effect at the next `minimize`.
    pub classes: Vec<Vec<MirScalarExpr>>,

    /// An expression simplification map, from an expression to its class representative.
    ///
    /// This map reflects an equivalence relation based on a prior version of `self.classes`.
    /// As users may add to `self.classes`, `self.remap` may become stale. We refresh `remap`
    /// only in `self.refresh()`, to the equivalence relation that derives from `self.classes`.
    ///
    /// It is important to `self.remap.clear()` if you invalidate it by mutating rather than
    /// appending to `self.classes`. This will be corrected in the next call to `self.refresh()`,
    /// but until then `remap` could be arbitrarily wrong. This should be improved in the future.
    remap: BTreeMap<MirScalarExpr, MirScalarExpr>,
}
340
/// An enum that can represent either a standard `EquivalenceClasses` or a
/// `EquivalenceClassesWithholdingErrors`.
///
/// The methods on this type dispatch to whichever variant is held, so callers
/// can work with either representation through one interface.
#[derive(Clone, Debug)]
pub enum EqClassesImpl {
    /// Standard equivalence classes.
    EquivalenceClasses(EquivalenceClasses),
    /// Equivalence classes that withhold errors, reintroducing them when extracting equivalences.
    EquivalenceClassesWithholdingErrors(EquivalenceClassesWithholdingErrors),
}
350
351impl EqClassesImpl {
352    /// Returns a reference to the underlying `EquivalenceClasses`.
353    pub fn equivalence_classes(&self) -> &EquivalenceClasses {
354        match self {
355            EqClassesImpl::EquivalenceClasses(classes) => classes,
356            EqClassesImpl::EquivalenceClassesWithholdingErrors(classes) => {
357                &classes.equivalence_classes
358            }
359        }
360    }
361
362    /// Returns a mutable reference to the underlying `EquivalenceClasses`.
363    pub fn equivalence_classes_mut(&mut self) -> &mut EquivalenceClasses {
364        match self {
365            EqClassesImpl::EquivalenceClasses(classes) => classes,
366            EqClassesImpl::EquivalenceClassesWithholdingErrors(classes) => {
367                &mut classes.equivalence_classes
368            }
369        }
370    }
371
372    /// Extend the equivalence classes with new equivalences.
373    pub fn extend_equivalences(&mut self, equivalences: Vec<Vec<MirScalarExpr>>) {
374        match self {
375            EqClassesImpl::EquivalenceClasses(classes) => {
376                classes.classes.extend(equivalences);
377            }
378            EqClassesImpl::EquivalenceClassesWithholdingErrors(classes) => {
379                classes.extend_equivalences(equivalences);
380            }
381        }
382    }
383
384    /// Push a new equivalence class.
385    pub fn push_equivalences(&mut self, equivalences: Vec<MirScalarExpr>) {
386        match self {
387            EqClassesImpl::EquivalenceClasses(classes) => {
388                classes.classes.push(equivalences);
389            }
390            EqClassesImpl::EquivalenceClassesWithholdingErrors(classes) => {
391                classes.push_equivalences(equivalences);
392            }
393        }
394    }
395
396    /// Subject the constraints to the column projection, reworking and removing equivalences.
397    pub fn project<I>(&mut self, output_columns: I)
398    where
399        I: IntoIterator<Item = usize> + Clone,
400    {
401        match self {
402            EqClassesImpl::EquivalenceClasses(classes) => {
403                classes.project(output_columns);
404            }
405            EqClassesImpl::EquivalenceClassesWithholdingErrors(classes) => {
406                classes.project(output_columns);
407            }
408        }
409    }
410
411    /// Extract the equivalences
412    pub fn extract_equivalences(
413        &mut self,
414        columns: Option<&[ReprColumnType]>,
415    ) -> Vec<Vec<MirScalarExpr>> {
416        match self {
417            EqClassesImpl::EquivalenceClasses(classes) => classes.classes.clone(),
418            EqClassesImpl::EquivalenceClassesWithholdingErrors(classes) => {
419                classes.extract_equivalences(columns)
420            }
421        }
422    }
423}
424
/// A wrapper struct for equivalence classes that withholds errors
/// from the underlying equivalence classes, reintroducing them
/// when extracting equivalences.
#[derive(Clone, Eq, PartialEq, Ord, PartialOrd, Default, Debug)]
pub struct EquivalenceClassesWithholdingErrors {
    /// The underlying equivalence classes.
    pub equivalence_classes: EquivalenceClasses,
    /// A list of equivalence classes that contain errors, which we hold back.
    ///
    /// Each entry pairs an optional "anchor" (the first error-free expression of
    /// the class it came from, if there was one) with the error-containing
    /// expressions of that class.
    pub held_back_classes: Vec<(Option<MirScalarExpr>, Vec<MirScalarExpr>)>,
}
435
436impl EquivalenceClassesWithholdingErrors {
437    /// Extend the equivalence classes with new equivalences.
438    pub fn extend_equivalences(&mut self, equivalences: Vec<Vec<MirScalarExpr>>) {
439        for class in equivalences {
440            self.push_equivalences(class);
441        }
442    }
443
444    /// Push a new equivalence class, which may contain errors.
445    pub fn push_equivalences(&mut self, equivalences: Vec<MirScalarExpr>) {
446        let (with_errs, without_errs): (Vec<_>, Vec<_>) =
447            equivalences.into_iter().partition(|e| e.contains_err());
448        self.held_back_classes
449            .push((without_errs.first().cloned(), with_errs));
450        if without_errs.len() > 1 {
451            self.equivalence_classes.classes.push(without_errs);
452        }
453    }
454
455    /// Subject the constraints to the column projection, reworking and removing equivalences.
456    /// This method should also introduce equivalences representing any repeated columns.
457    /// This notably does not project the held back classes.
458    pub fn project<I>(&mut self, output_columns: I)
459    where
460        I: IntoIterator<Item = usize> + Clone,
461    {
462        self.equivalence_classes.project(output_columns.clone());
463    }
464
465    /// Minimize the equivalence classes, and return the result, reintroducing any held back classes.
466    pub fn extract_equivalences(
467        &mut self,
468        columns: Option<&[ReprColumnType]>,
469    ) -> Vec<Vec<MirScalarExpr>> {
470        self.equivalence_classes.minimize(columns);
471
472        if self.held_back_classes.is_empty() {
473            self.equivalence_classes.classes.clone()
474        } else {
475            let reducer = self.equivalence_classes.reducer();
476            let mut classes = self.equivalence_classes.classes.clone();
477            for (anchor, mut class) in self.held_back_classes.drain(..) {
478                if let Some(mut anchor) = anchor {
479                    // Reduce the anchor expression to its canonical form.
480                    reducer.reduce_expr(&mut anchor);
481                    class.push(anchor.clone());
482                }
483                classes.push(class.clone());
484            }
485            canonicalize_equivalence_classes(&mut classes);
486            classes
487        }
488    }
489}
490
491/// Raw printing of [`EquivalenceClasses`] with default expression humanization.
492/// Don't use this in `EXPLAIN`! For redaction, column name support, etc., see
493/// [`HumanizedEquivalenceClasses`].
494impl std::fmt::Display for EquivalenceClasses {
495    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
496        HumanizedEquivalenceClasses {
497            equivalence_classes: self,
498            cols: None,
499            mode: HumanizedExplain::default(),
500        }
501        .fmt(f)
502    }
503}
504
/// Wrapper struct for human-readable printing of expressions inside [`EquivalenceClasses`].
/// (Similar to `HumanizedExpr`. Unfortunately, we can't just use `HumanizedExpr` here, because
/// we'd need to `impl Display for HumanizedExpr<'a, EquivalenceClasses, M>`, but neither
/// `Display` nor `HumanizedExpr` is defined in this crate.)
///
/// The mode parameter `M` defaults to [`HumanizedExplain`].
#[derive(Debug)]
pub struct HumanizedEquivalenceClasses<'a, M = HumanizedExplain> {
    /// The [`EquivalenceClasses`] to be humanized.
    pub equivalence_classes: &'a EquivalenceClasses,
    /// An optional vector of inferred column names to be used when rendering
    /// column references in expressions.
    pub cols: Option<&'a Vec<String>>,
    /// The rendering mode to use. See `HumanizerMode` for details.
    pub mode: M,
}
519
520impl std::fmt::Display for HumanizedEquivalenceClasses<'_> {
521    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
522        // Only show `classes`.
523        // (The following hopefully avoids allocating any of the intermediate composite strings.)
524        let classes = self.equivalence_classes.classes.iter().map(|class| {
525            format!(
526                "{}",
527                bracketed(
528                    "[",
529                    "]",
530                    separated(
531                        ", ",
532                        class.iter().map(|expr| self.mode.expr(expr, self.cols))
533                    )
534                )
535            )
536        });
537        write!(f, "{}", bracketed("[", "]", separated(", ", classes)))
538    }
539}
540
541impl EquivalenceClasses {
542    /// Comparator function for the complexity of scalar expressions. Simpler expressions are
543    /// smaller. Can be used when we need to decide which of several equivalent expressions to use.
544    pub fn mir_scalar_expr_complexity(
545        e1: &MirScalarExpr,
546        e2: &MirScalarExpr,
547    ) -> std::cmp::Ordering {
548        use MirScalarExpr::*;
549        use std::cmp::Ordering::*;
550        match (e1, e2) {
551            (Literal(_, _), Literal(_, _)) => e1.cmp(e2),
552            (Literal(_, _), _) => Less,
553            (_, Literal(_, _)) => Greater,
554            (Column(_, _), Column(_, _)) => e1.cmp(e2),
555            (Column(_, _), _) => Less,
556            (_, Column(_, _)) => Greater,
557            (x, y) => {
558                // General expressions should be ordered by their size,
559                // to ensure we only simplify expressions by substitution.
560                // If same size, then fall back to the expressions' Ord.
561                match x.size().cmp(&y.size()) {
562                    Equal => x.cmp(y),
563                    other => other,
564                }
565            }
566        }
567    }
568
569    /// Sorts and deduplicates each class, removing literal errors.
570    ///
571    /// This method does not ensure equivalence relation structure, but instead performs
572    /// only minimal structural clean-up.
573    fn tidy(&mut self) {
574        for class in self.classes.iter_mut() {
575            // Remove all literal errors, as they cannot be equated to other things.
576            class.retain(|e| !e.is_literal_err());
577            class.sort_by(Self::mir_scalar_expr_complexity);
578            class.dedup();
579        }
580        self.classes.retain(|c| c.len() > 1);
581        self.classes.sort();
582        self.classes.dedup();
583    }
584
    /// Restore equivalence relation structure to `self.classes` and refresh `self.remap`.
    ///
    /// This method takes roughly linear time, and returns true iff `self.remap` has changed.
    /// This is the only method that refreshes `self.remap`, and is a perfect place to decide
    /// whether the equivalence classes it represents have experienced any changes since the
    /// last refresh.
    ///
    /// The steps are: tidy the classes; return early if `self.remap` already agrees with
    /// them; otherwise build a fresh map from each expression to its class's first element,
    /// merging classes that share an expression; then install the new map.
    fn refresh(&mut self) -> bool {
        self.tidy();

        // remap may already be the correct answer, and if so we should avoid the work of rebuilding it.
        // If it contains the same number of expressions as `self.classes`, and for every expression in
        // `self.classes` the two agree on the representative, they are identical.
        if self.remap.len() == self.classes.iter().map(|c| c.len()).sum::<usize>()
            && self
                .classes
                .iter()
                .all(|c| c.iter().all(|e| self.remap.get(e) == Some(&c[0])))
        {
            // No change, so return false.
            return false;
        }

        // Optimistically build the `remap` we would want.
        // Note if any unions would be required, in which case we have further work to do,
        // including re-forming `self.classes`.
        let mut union_find = BTreeMap::default();
        let mut dirtied = false;
        for class in self.classes.iter() {
            for expr in class.iter() {
                // `insert` returns the previous target when `expr` was already present,
                // i.e. when `expr` also occurs in a class visited earlier.
                if let Some(other) = union_find.insert(expr.clone(), class[0].clone()) {
                    // A merge is required, but have the more complex expression point at the simpler one.
                    // This allows `union_find` to end as the `remap` for the new `classes` we form, with
                    // the only required work being compressing all the paths.
                    if Self::mir_scalar_expr_complexity(&other, &class[0])
                        == std::cmp::Ordering::Less
                    {
                        union_find.union(&class[0], &other);
                    } else {
                        union_find.union(&other, &class[0]);
                    }
                    dirtied = true;
                }
            }
        }
        if dirtied {
            // At least one merge happened: regroup every expression under the root
            // that `union_find` reports for it, then tidy the regrouped classes.
            let mut classes: BTreeMap<_, Vec<_>> = BTreeMap::default();
            for class in self.classes.drain(..) {
                for expr in class {
                    let root: MirScalarExpr = union_find.find(&expr).unwrap().clone();
                    classes.entry(root).or_default().push(expr);
                }
            }
            self.classes = classes.into_values().collect();
            self.tidy();
        }

        // Report whether the installed map differs from the one it replaces.
        let changed = self.remap != union_find;
        self.remap = union_find;
        changed
    }
645
646    /// Update `self` to maintain the same equivalences which potentially reducing along `Ord::le`.
647    ///
648    /// Informally this means simplifying constraints, removing redundant constraints, and unifying equivalence classes.
649    pub fn minimize(&mut self, columns: Option<&[ReprColumnType]>) {
650        // Repeatedly, we reduce each of the classes themselves, then unify the classes.
651        // This should strictly reduce complexity, and reach a fixed point.
652        // Ideally it is *confluent*, arriving at the same fixed point no matter the order of operations.
653
654        // We should not rely on nullability information present in `column_types`. (Doing this
655        // every time just before calling `reduce` was found to be a bottleneck during incident-217,
656        // so now we do this nullability tweaking only once here.)
657        let mut columns = columns.map(|x| x.to_vec());
658        let mut nonnull = Vec::new();
659        if let Some(columns) = columns.as_mut() {
660            for (index, col) in columns.iter_mut().enumerate() {
661                let is_null = MirScalarExpr::column(index).call_is_null();
662                if !col.nullable
663                    && self
664                        .remap
665                        .get(&is_null)
666                        .map(|e| !e.is_literal_false())
667                        .unwrap_or(true)
668                {
669                    nonnull.push(is_null);
670                }
671                col.nullable = true;
672            }
673        }
674        if !nonnull.is_empty() {
675            nonnull.push(MirScalarExpr::literal_false());
676            self.classes.push(nonnull);
677        }
678
679        // Ensure `self.classes` and `self.remap` are equivalence relations.
680        // Users are allowed to mutate `self.classes`, so we must perform this normalization at least once.
681        // We have also likely mutated `self.classes` just above with non-nullability information.
682        self.refresh();
683
684        // Termination will be detected by comparing to the map of equivalence classes.
685        let mut previous = Some(self.remap.clone());
686        while let Some(prev) = previous {
687            // Attempt to add new equivalences.
688            let novel = self.expand();
689            if !novel.is_empty() {
690                self.classes.extend(novel);
691                self.refresh();
692            }
693
694            // We continue as long as any simplification has occurred.
695            // An expression can be simplified, a duplication found, or two classes unified.
696            let mut stable = false;
697            while !stable {
698                stable = !self.minimize_once(columns.as_ref().map(|x| &x[..]));
699            }
700
701            // Termination detection.
702            if prev != self.remap {
703                previous = Some(self.remap.clone());
704            } else {
705                previous = None;
706            }
707        }
708    }
709
710    /// Proposes new equivalences that are likely to be novel.
711    ///
712    /// This method invokes `self.implications()` to propose equivalences, and then judges them to be
713    /// novel or not based on existing knowledge, reducing the equivalences down to their novel core.
714    /// This method may produce non-novel equivalences, due to its inability to perform `MSE::reduce`.
715    /// We can end up with e.g. constant expressions that cannot be found until they are so reduced.
716    /// The novelty detection is best-effort, and meant to provide a clearer signal and minimize the
717    /// number of times we call and amount of work we do in `self.refresh()`.
718    fn expand(&self) -> Vec<Vec<MirScalarExpr>> {
719        // Consider expanding `self.classes` with novel equivalences.
720        let mut novel = self.implications();
721        for class in novel.iter_mut() {
722            // reduce each expression to its canonical form.
723            for expr in class.iter_mut() {
724                self.remap.reduce_expr(expr);
725            }
726            class.sort();
727            class.dedup();
728            // for a class to be interesting we require at least two elements that do not reference the same root.
729            let common_class = class
730                .iter()
731                .map(|x| self.remap.get(x))
732                .reduce(|prev, this| if prev == this { prev } else { None });
733            if class.len() == 1 || common_class != Some(None) {
734                class.clear();
735            }
736        }
737        novel.retain(|c| !c.is_empty());
738        novel
739    }
740
741    /// Derives potentially novel equivalences without regard for minimization.
742    ///
743    /// This is an opportunity to explore equivalences that do not correspond to expression minimization,
744    /// and therefore should not be used in `minimize_once`. They are still potentially important, but
745    /// required additional guardrails to ensure we reach a fixed point.
746    ///
747    /// The implications will be introduced into `self.classes` and will prompt a round of minimization,
748    /// making it somewhat polite to avoid producing outputs that cannot result in novel equivalences.
749    /// For example, before producing a new equivalence, one could check that the involved terms are not
750    /// already present in the same class.
751    fn implications(&self) -> Vec<Vec<MirScalarExpr>> {
752        let mut new_equivalences = Vec::new();
753
754        // If we see `false == IsNull(foo)` we can add the non-null implications of `foo`.
755        let mut non_null = std::collections::BTreeSet::default();
756        for class in self.classes.iter() {
757            if Self::class_contains_literal(class, |e| e == &Ok(Datum::False)) {
758                for e in class.iter() {
759                    if let MirScalarExpr::CallUnary {
760                        func: mz_expr::UnaryFunc::IsNull(_),
761                        expr,
762                    } = e
763                    {
764                        expr.non_null_requirements(&mut non_null);
765                    }
766                }
767            }
768        }
769        // If we see `true == foo` we can add the non-null implications of `foo`.
770        // TODO: generalize to arbitrary non-null, non-error literals; at the moment `true == pred` is
771        // an important idiom to identify for how we express predicates.
772        for class in self.classes.iter() {
773            if Self::class_contains_literal(class, |e| e == &Ok(Datum::True)) {
774                for expr in class.iter() {
775                    expr.non_null_requirements(&mut non_null);
776                }
777            }
778        }
779        // Only keep constraints that are not already known.
780        // Known constraints will present as `COL(_) IS NULL == false`,
781        // which can only happen if `false` is present, and both terms
782        // map to the same canonical representative>
783        let lit_false = MirScalarExpr::literal_false();
784        let target = self.remap.get(&lit_false);
785        if target.is_some() {
786            non_null.retain(|c| {
787                let is_null = MirScalarExpr::column(*c).call_is_null();
788                self.remap.get(&is_null) != target
789            });
790        }
791        if !non_null.is_empty() {
792            let mut class = Vec::with_capacity(non_null.len() + 1);
793            class.push(MirScalarExpr::literal_false());
794            class.extend(
795                non_null
796                    .into_iter()
797                    .map(|c| MirScalarExpr::column(c).call_is_null()),
798            );
799            new_equivalences.push(class);
800        }
801
802        // If we see records formed from other expressions, we can equate the expressions with
803        // accessors applied to the class of the record former. In `minimize_once` we reduce by
804        // equivalence class representative before we perform expression simplification, so we
805        // shoud be able to just use the expression former, rather than find its representative.
806        // The risk, potentially, is that we would apply accessors to the record former and then
807        // just simplify it away learning nothing.
808        for class in self.classes.iter() {
809            for expr in class.iter() {
810                // Record-forming expressions can equate their accessors and their members.
811                if let MirScalarExpr::CallVariadic {
812                    func: mz_expr::VariadicFunc::RecordCreate(..),
813                    exprs,
814                } = expr
815                {
816                    for (index, e) in exprs.iter().enumerate() {
817                        new_equivalences.push(vec![
818                            e.clone(),
819                            expr.clone().call_unary(mz_expr::UnaryFunc::RecordGet(
820                                mz_expr::func::RecordGet(index),
821                            )),
822                        ]);
823                    }
824                }
825            }
826        }
827
828        // Return all newly established equivalences.
829        new_equivalences
830    }
831
832    /// A single iteration of minimization, which we expect to repeat but benefit from factoring out.
833    ///
834    /// This invocation should take roughly linear time.
835    /// It starts with equivalence class invariants maintained (closed under transitivity), and then
836    ///   1. Performs per-expression reduction, including the class structure to replace subexpressions.
837    ///   2. Applies idiom detection to e.g. unpack expressions equivalence to literal true or false.
838    ///   3. Restores the equivalence class invariants.
839    fn minimize_once(&mut self, columns: Option<&[ReprColumnType]>) -> bool {
840        // 1. Reduce each expression
841        //
842        // This reduction first looks for subexpression substitutions that can be performed,
843        // and then applies expression reduction if column type information is provided.
844        for class in self.classes.iter_mut() {
845            for expr in class.iter_mut() {
846                self.remap.reduce_child(expr);
847                if let Some(columns) = columns {
848                    let orig_expr = expr.clone();
849                    expr.reduce(columns);
850                    if expr.contains_err() {
851                        // Rollback to the original expression if it contains an error.
852                        *expr = orig_expr;
853                    }
854                }
855            }
856        }
857
858        // 2. Identify idioms
859        //    E.g. If Eq(x, y) must be true, we can introduce classes `[x, y]` and `[false, IsNull(x), IsNull(y)]`.
860        let mut to_add = Vec::new();
861        for class in self.classes.iter_mut() {
862            if Self::class_contains_literal(class, |e| e == &Ok(Datum::True)) {
863                for expr in class.iter() {
864                    // If Eq(x, y) must be true, we can introduce classes `[x, y]` and `[false, IsNull(x), IsNull(y)]`.
865                    // This substitution replaces a complex expression with several smaller expressions, and cannot
866                    // cycle if we follow that practice.
867                    if let MirScalarExpr::CallBinary {
868                        func: mz_expr::BinaryFunc::Eq(_),
869                        expr1,
870                        expr2,
871                    } = expr
872                    {
873                        to_add.push(vec![*expr1.clone(), *expr2.clone()]);
874                        to_add.push(vec![
875                            MirScalarExpr::literal_false(),
876                            expr1.clone().call_is_null(),
877                            expr2.clone().call_is_null(),
878                        ]);
879                    }
880                }
881                // Remove the more complex form of the expression.
882                class.retain(|expr| {
883                    if let MirScalarExpr::CallBinary {
884                        func: mz_expr::BinaryFunc::Eq(_),
885                        ..
886                    } = expr
887                    {
888                        false
889                    } else {
890                        true
891                    }
892                });
893                for expr in class.iter() {
894                    // If TRUE == NOT(X) then FALSE == X is a simpler form.
895                    if let MirScalarExpr::CallUnary {
896                        func: mz_expr::UnaryFunc::Not(_),
897                        expr: e,
898                    } = expr
899                    {
900                        to_add.push(vec![MirScalarExpr::literal_false(), (**e).clone()]);
901                    }
902                }
903                class.retain(|expr| {
904                    if let MirScalarExpr::CallUnary {
905                        func: mz_expr::UnaryFunc::Not(_),
906                        ..
907                    } = expr
908                    {
909                        false
910                    } else {
911                        true
912                    }
913                });
914            }
915            if Self::class_contains_literal(class, |e| e == &Ok(Datum::False)) {
916                for expr in class.iter() {
917                    // If FALSE == NOT(X) then TRUE == X is a simpler form.
918                    if let MirScalarExpr::CallUnary {
919                        func: mz_expr::UnaryFunc::Not(_),
920                        expr: e,
921                    } = expr
922                    {
923                        to_add.push(vec![MirScalarExpr::literal_true(), (**e).clone()]);
924                    }
925                }
926                class.retain(|expr| {
927                    if let MirScalarExpr::CallUnary {
928                        func: mz_expr::UnaryFunc::Not(_),
929                        ..
930                    } = expr
931                    {
932                        false
933                    } else {
934                        true
935                    }
936                });
937            }
938        }
939        self.classes.extend(to_add);
940
941        // 3. Restore equivalence relation structure and observe if any changes result.
942        self.refresh()
943    }
944
945    /// Produce the equivalences present in both inputs.
946    pub fn union(&self, other: &Self) -> Self {
947        self.union_many([other])
948    }
949
950    /// The equivalence classes of terms equivalent in all inputs.
951    ///
952    /// This method relies on the `remap` member of each input, and bases the intersection on these rather than `classes`.
953    /// This means one should ensure `minimize()` has been called on all inputs, or risk getting a stale, but conservatively
954    /// correct, result.
955    ///
956    /// This method currently misses opportunities, because it only looks for exactly matches in expressions,
957    /// which may not include all possible matches. For example, `f(#1) == g(#1)` may exist in one class, but
958    /// in another class where `#0 == #1` it may exist as `f(#0) == g(#0)`.
959    pub fn union_many<'a, I>(&self, others: I) -> Self
960    where
961        I: IntoIterator<Item = &'a Self>,
962    {
963        // List of expressions in the intersection, and a proxy equivalence class identifier.
964        let mut intersection: Vec<(&MirScalarExpr, usize)> = Default::default();
965        // Map from expression to a proxy equivalence class identifier.
966        let mut rekey: BTreeMap<&MirScalarExpr, usize> = Default::default();
967        for (key, val) in self.remap.iter() {
968            if !rekey.contains_key(val) {
969                rekey.insert(val, rekey.len());
970            }
971            intersection.push((key, rekey[val]));
972        }
973        for other in others {
974            // Map from proxy equivalence class identifier and equivalence class expr to a new proxy identifier.
975            let mut rekey: BTreeMap<(usize, &MirScalarExpr), usize> = Default::default();
976            intersection.retain_mut(|(key, idx)| {
977                if let Some(val) = other.remap.get(key) {
978                    if !rekey.contains_key(&(*idx, val)) {
979                        rekey.insert((*idx, val), rekey.len());
980                    }
981                    *idx = rekey[&(*idx, val)];
982                    true
983                } else {
984                    false
985                }
986            });
987        }
988        let mut classes: BTreeMap<_, Vec<MirScalarExpr>> = Default::default();
989        for (key, vals) in intersection {
990            classes.entry(vals).or_default().push(key.clone())
991        }
992        let classes = classes.into_values().collect::<Vec<_>>();
993        let mut equivalences = EquivalenceClasses {
994            classes,
995            remap: Default::default(),
996        };
997        equivalences.minimize(None);
998        equivalences
999    }
1000
1001    /// Permutes each expression, looking up each column reference in `permutation` and replacing with what it finds.
1002    pub fn permute(&mut self, permutation: &[usize]) {
1003        for class in self.classes.iter_mut() {
1004            for expr in class.iter_mut() {
1005                expr.permute(permutation);
1006            }
1007        }
1008        self.remap.clear();
1009        self.minimize(None);
1010    }
1011
1012    /// Subject the constraints to the column projection, reworking and removing equivalences.
1013    ///
1014    /// This method should also introduce equivalences representing any repeated columns.
1015    pub fn project<I>(&mut self, output_columns: I)
1016    where
1017        I: IntoIterator<Item = usize> + Clone,
1018    {
1019        // Retain the first instance of each column, and record subsequent instances as duplicates.
1020        let mut dupes = Vec::new();
1021        let mut remap = BTreeMap::default();
1022        for (idx, col) in output_columns.into_iter().enumerate() {
1023            if let Some(pos) = remap.get(&col) {
1024                dupes.push((*pos, idx));
1025            } else {
1026                remap.insert(col, idx);
1027            }
1028        }
1029
1030        // Some expressions may be "localized" in that they only reference columns in `output_columns`.
1031        // Many expressions may not be localized, but may reference canonical non-localized expressions
1032        // for classes that contain a localized expression; in that case we can "backport" the localized
1033        // expression to give expressions referencing the canonical expression a shot at localization.
1034        //
1035        // Expressions should only contain instances of canonical expressions, and so we shouldn't need
1036        // to look any further than backporting those. Backporting should have the property that the simplest
1037        // localized expression in each class does not contain any non-localized canonical expressions
1038        // (as that would make it non-localized); our backporting of non-localized canonicals with localized
1039        // expressions should never fire a second
1040
1041        // Let's say an expression is "localized" once we are able to rewrite its support in terms of `output_columns`.
1042        // Not all expressions can be localized, although some of them may be equivalent to localized expressions.
1043        // As we find localized expressions, we can replace uses of their equivalent representative with them,
1044        // which may allow further expression localization.
1045        // We continue the process until no further classes can be localized.
1046
1047        // A map from representatives to our first localization of their equivalence class.
1048        let mut localized = false;
1049        while !localized {
1050            localized = true;
1051            let mut current_map = BTreeMap::default();
1052            for class in self.classes.iter_mut() {
1053                if !class[0].support().iter().all(|c| remap.contains_key(c)) {
1054                    if let Some(pos) = class
1055                        .iter()
1056                        .position(|e| e.support().iter().all(|c| remap.contains_key(c)))
1057                    {
1058                        class.swap(0, pos);
1059                        localized = false;
1060                    }
1061                }
1062                for expr in class[1..].iter() {
1063                    current_map.insert(expr.clone(), class[0].clone());
1064                }
1065            }
1066
1067            // attempt to replace representatives with equivalent localizeable expressions.
1068            for class_index in 0..self.classes.len() {
1069                for index in 0..self.classes[class_index].len() {
1070                    current_map.reduce_child(&mut self.classes[class_index][index]);
1071                }
1072            }
1073            // NB: Do *not* `self.minimize()`, as we are developing localizable rather than canonical representatives.
1074        }
1075
1076        // Localize all localizable expressions and discard others.
1077        for class in self.classes.iter_mut() {
1078            class.retain(|e| e.support().iter().all(|c| remap.contains_key(c)));
1079            for expr in class.iter_mut() {
1080                expr.permute_map(&remap);
1081            }
1082        }
1083        self.classes.retain(|c| c.len() > 1);
1084        // If column repetitions, introduce them as equivalences.
1085        // We introduce only the equivalence to the first occurrence, and rely on minimization to collect them.
1086        for (col1, col2) in dupes {
1087            self.classes.push(vec![
1088                MirScalarExpr::column(col1),
1089                MirScalarExpr::column(col2),
1090            ]);
1091        }
1092        self.remap.clear();
1093        self.minimize(None);
1094    }
1095
1096    /// True if any equivalence class contains two distinct non-error literals.
1097    pub fn unsatisfiable(&self) -> bool {
1098        for class in self.classes.iter() {
1099            let mut literal_ok = None;
1100            for expr in class.iter() {
1101                if let MirScalarExpr::Literal(Ok(row), _) = expr {
1102                    if literal_ok.is_some() && literal_ok != Some(row) {
1103                        return true;
1104                    } else {
1105                        literal_ok = Some(row);
1106                    }
1107                }
1108            }
1109        }
1110        false
1111    }
1112
1113    /// Returns a map that can be used to replace (sub-)expressions.
1114    pub fn reducer(&self) -> &BTreeMap<MirScalarExpr, MirScalarExpr> {
1115        &self.remap
1116    }
1117
1118    /// Examines the prefix of `class` of literals, looking for any satisfying `predicate`.
1119    ///
1120    /// This test bails out as soon as it sees a non-literal, and may have false negatives
1121    /// if the data are not sorted with literals at the front.
1122    fn class_contains_literal<P>(class: &[MirScalarExpr], mut predicate: P) -> bool
1123    where
1124        P: FnMut(&Result<Datum, &mz_expr::EvalError>) -> bool,
1125    {
1126        class
1127            .iter()
1128            .take_while(|e| e.is_literal())
1129            .filter_map(|e| e.as_literal())
1130            .any(move |e| predicate(&e))
1131    }
1132}
1133
1134/// A type capable of simplifying `MirScalarExpr`s.
1135pub trait ExpressionReducer {
1136    /// Attempt to replace `expr` itself with another expression.
1137    /// Returns true if it does so.
1138    fn replace(&self, expr: &mut MirScalarExpr) -> bool;
1139    /// Attempt to replace any subexpressions of `expr` with other expressions.
1140    /// Returns true if it does so.
1141    fn reduce_expr(&self, expr: &mut MirScalarExpr) -> bool {
1142        let mut simplified = false;
1143        simplified = simplified || self.reduce_child(expr);
1144        simplified = simplified || self.replace(expr);
1145        simplified
1146    }
1147    /// Attempt to replace any subexpressions of `expr`'s children with other expressions.
1148    /// Returns true if it does so.
1149    fn reduce_child(&self, expr: &mut MirScalarExpr) -> bool {
1150        let mut simplified = false;
1151        match expr {
1152            MirScalarExpr::CallBinary { expr1, expr2, .. } => {
1153                simplified = self.reduce_expr(expr1) || simplified;
1154                simplified = self.reduce_expr(expr2) || simplified;
1155            }
1156            MirScalarExpr::CallUnary { expr, .. } => {
1157                simplified = self.reduce_expr(expr) || simplified;
1158            }
1159            MirScalarExpr::CallVariadic { exprs, .. } => {
1160                for expr in exprs.iter_mut() {
1161                    simplified = self.reduce_expr(expr) || simplified;
1162                }
1163            }
1164            MirScalarExpr::If { cond: _, then, els } => {
1165                // Do not simplify `cond`, as we cannot ensure the simplification
1166                // continues to hold as expressions migrate around.
1167                simplified = self.reduce_expr(then) || simplified;
1168                simplified = self.reduce_expr(els) || simplified;
1169            }
1170            _ => {}
1171        }
1172        simplified
1173    }
1174}
1175
1176impl ExpressionReducer for BTreeMap<&MirScalarExpr, &MirScalarExpr> {
1177    /// Perform any exact replacement for `expr`, report if it had an effect.
1178    fn replace(&self, expr: &mut MirScalarExpr) -> bool {
1179        if let Some(other) = self.get(expr) {
1180            if other != &expr {
1181                expr.clone_from(other);
1182                return true;
1183            }
1184        }
1185        false
1186    }
1187}
1188
1189impl ExpressionReducer for BTreeMap<MirScalarExpr, MirScalarExpr> {
1190    /// Perform any exact replacement for `expr`, report if it had an effect.
1191    fn replace(&self, expr: &mut MirScalarExpr) -> bool {
1192        if let Some(other) = self.get(expr) {
1193            if other != expr {
1194                expr.clone_from(other);
1195                return true;
1196            }
1197        }
1198        false
1199    }
1200}
1201
1202/// True iff the aggregate function returns an input datum.
1203fn aggregate_is_input(aggregate: &AggregateFunc) -> bool {
1204    match aggregate {
1205        AggregateFunc::MaxInt16
1206        | AggregateFunc::MaxInt32
1207        | AggregateFunc::MaxInt64
1208        | AggregateFunc::MaxUInt16
1209        | AggregateFunc::MaxUInt32
1210        | AggregateFunc::MaxUInt64
1211        | AggregateFunc::MaxMzTimestamp
1212        | AggregateFunc::MaxFloat32
1213        | AggregateFunc::MaxFloat64
1214        | AggregateFunc::MaxBool
1215        | AggregateFunc::MaxString
1216        | AggregateFunc::MaxDate
1217        | AggregateFunc::MaxTimestamp
1218        | AggregateFunc::MaxTimestampTz
1219        | AggregateFunc::MinInt16
1220        | AggregateFunc::MinInt32
1221        | AggregateFunc::MinInt64
1222        | AggregateFunc::MinUInt16
1223        | AggregateFunc::MinUInt32
1224        | AggregateFunc::MinUInt64
1225        | AggregateFunc::MinMzTimestamp
1226        | AggregateFunc::MinFloat32
1227        | AggregateFunc::MinFloat64
1228        | AggregateFunc::MinBool
1229        | AggregateFunc::MinString
1230        | AggregateFunc::MinDate
1231        | AggregateFunc::MinTimestamp
1232        | AggregateFunc::MinTimestampTz
1233        | AggregateFunc::Any
1234        | AggregateFunc::All => true,
1235        _ => false,
1236    }
1237}