// mz_expr/relation.rs

// Copyright Materialize, Inc. and contributors. All rights reserved.
//
// Use of this software is governed by the Business Source License
// included in the LICENSE file.
//
// As of the Change Date specified in that file, in accordance with
// the Business Source License, use of this software will be governed
// by the Apache License, Version 2.0.

10#![warn(missing_docs)]
11
12use std::cmp::{Ordering, max};
13use std::collections::{BTreeMap, BTreeSet};
14use std::fmt;
15use std::fmt::{Display, Formatter};
16use std::hash::{DefaultHasher, Hash, Hasher};
17use std::num::NonZeroU64;
18use std::time::Instant;
19
20use bytesize::ByteSize;
21use differential_dataflow::containers::{Columnation, CopyRegion};
22use itertools::Itertools;
23use mz_lowertest::MzReflect;
24use mz_ore::cast::CastFrom;
25use mz_ore::collections::CollectionExt;
26use mz_ore::id_gen::IdGen;
27use mz_ore::metrics::Histogram;
28use mz_ore::num::NonNeg;
29use mz_ore::soft_assert_no_log;
30use mz_ore::stack::RecursionLimitError;
31use mz_ore::str::Indent;
32use mz_repr::adt::numeric::NumericMaxScale;
33use mz_repr::explain::text::text_string_at;
34use mz_repr::explain::{
35    DummyHumanizer, ExplainConfig, ExprHumanizer, IndexUsageType, PlanRenderingContext,
36};
37use mz_repr::{
38    ColumnName, Datum, Diff, GlobalId, IntoRowIterator, ReprColumnType, ReprRelationType,
39    ReprScalarType, Row, RowIterator, SqlColumnType, SqlRelationType, SqlScalarType,
40};
41use serde::{Deserialize, Serialize};
42
43use crate::Id::Local;
44use crate::explain::{HumanizedExpr, HumanizerMode};
45use crate::relation::func::{AggregateFunc, LagLeadType, TableFunc};
46use crate::row::{RowCollection, RowCollectionIter};
47use crate::scalar::func::variadic::{
48    JsonbBuildArray, JsonbBuildObject, ListCreate, ListIndex, MapBuild, RecordCreate,
49};
50use crate::visit::{Visit, VisitChildren};
51use crate::{
52    EvalError, FilterCharacteristics, Id, LocalId, MirScalarExpr, UnaryFunc, func as scalar_func,
53};
54
55pub mod canonicalize;
56pub mod func;
57pub mod join_input_mapper;
58
/// A recursion limit to be used for stack-safe traversals of [`MirRelationExpr`] trees.
///
/// The recursion limit must be large enough to accommodate for the linear representation
/// of some pathological but frequently occurring query fragments.
///
/// For example, in MIR we could have long chains of
/// - (1) `Let` bindings,
/// - (2) `CallBinary` calls with associative functions such as `+`
///
/// Until we fix those, we need to stick with the larger recursion limit.
pub const RECURSION_LIMIT: usize = 2048;
70
71/// A trait for types that describe how to build a collection.
72pub trait CollectionPlan {
73    /// Collects the set of global identifiers from dataflows referenced in Get.
74    fn depends_on_into(&self, out: &mut BTreeSet<GlobalId>);
75
76    /// Returns the set of global identifiers from dataflows referenced in Get.
77    ///
78    /// See [`CollectionPlan::depends_on_into`] to reuse an existing `BTreeSet`.
79    fn depends_on(&self) -> BTreeSet<GlobalId> {
80        let mut out = BTreeSet::new();
81        self.depends_on_into(&mut out);
82        out
83    }
84}
85
/// An abstract syntax tree which defines a collection.
///
/// The AST is meant to reflect the capabilities of the `differential_dataflow::Collection` type,
/// written generically enough to avoid run-time compilation work.
///
/// `derived_hash_with_manual_eq` was complaining for the wrong reason: This lint exists because
/// it's bad when `Eq` doesn't agree with `Hash`, which is often quite likely if one of them is
/// implemented manually. However, our manual implementation of `Eq` _will_ agree with the derived
/// one. This is because the reason for the manual implementation is not to change the semantics
/// from the derived one, but to avoid stack overflows.
#[allow(clippy::derived_hash_with_manual_eq)]
#[derive(Clone, Debug, Ord, PartialOrd, Serialize, Deserialize, MzReflect, Hash)]
pub enum MirRelationExpr {
    /// A constant relation containing specified rows.
    ///
    /// The runtime memory footprint of this operator is zero.
    ///
    /// When you would like to pattern match on this, consider using `MirRelationExpr::as_const`
    /// instead, which looks behind `ArrangeBy`s. You might want this matching behavior because
    /// constant folding doesn't remove `ArrangeBy`s.
    Constant {
        /// Rows of the constant collection and their multiplicities.
        rows: Result<Vec<(Row, Diff)>, EvalError>,
        /// Schema of the collection.
        typ: ReprRelationType,
    },
    /// Get an existing dataflow.
    ///
    /// The runtime memory footprint of this operator is zero.
    Get {
        /// The identifier for the collection to load.
        #[mzreflect(ignore)]
        id: Id,
        /// Schema of the collection.
        typ: ReprRelationType,
        /// If this is a global Get, this will indicate whether we are going to read from Persist or
        /// from an index, or from a different object in `objects_to_build`. If it's an index, then
        /// how downstream dataflow operations will use this index is also recorded. This is filled
        /// by `prune_and_annotate_dataflow_index_imports`. Note that this is not used by the
        /// lowering to LIR, but is used only by EXPLAIN.
        #[mzreflect(ignore)]
        access_strategy: AccessStrategy,
    },
    /// Introduce a temporary dataflow.
    ///
    /// The runtime memory footprint of this operator is zero.
    Let {
        /// The identifier to be used in `Get` variants to retrieve `value`.
        #[mzreflect(ignore)]
        id: LocalId,
        /// The collection to be bound to `id`.
        value: Box<MirRelationExpr>,
        /// The result of the `Let`, evaluated with `id` bound to `value`.
        body: Box<MirRelationExpr>,
    },
    /// Introduce mutually recursive bindings.
    ///
    /// Each `LocalId` is immediately bound to an initially empty collection
    /// with the type of its corresponding `MirRelationExpr`. Repeatedly, each
    /// binding is evaluated using the current contents of each other binding,
    /// and is refreshed to contain the new evaluation. This process continues
    /// through all bindings, and repeats as long as changes continue to occur.
    ///
    /// The resulting value of the expression is `body` evaluated once in the
    /// context of the final iterates.
    ///
    /// A zero-binding instance can be replaced by `body`.
    /// A single-binding instance is equivalent to `MirRelationExpr::Let`.
    ///
    /// The runtime memory footprint of this operator is zero.
    LetRec {
        /// The identifiers to be used in `Get` variants to retrieve each `value`.
        #[mzreflect(ignore)]
        ids: Vec<LocalId>,
        /// The collections to be bound to each `id`.
        values: Vec<MirRelationExpr>,
        /// Maximum number of iterations, after which we should artificially force a fixpoint.
        /// (Whether we error or just stop is configured by `LetRecLimit::return_at_limit`.)
        /// The per-`LetRec` limit that the user specified is initially copied to each binding to
        /// accommodate slicing and merging of `LetRec`s in MIR transforms (e.g., `NormalizeLets`).
        #[mzreflect(ignore)]
        limits: Vec<Option<LetRecLimit>>,
        /// The result of the `LetRec`, evaluated with each `id` bound to its `value`.
        body: Box<MirRelationExpr>,
    },
    /// Project out some columns from a dataflow
    ///
    /// The runtime memory footprint of this operator is zero.
    Project {
        /// The source collection.
        input: Box<MirRelationExpr>,
        /// Indices of columns to retain.
        outputs: Vec<usize>,
    },
    /// Append new columns to a dataflow
    ///
    /// The runtime memory footprint of this operator is zero.
    Map {
        /// The source collection.
        input: Box<MirRelationExpr>,
        /// Expressions which determine values to append to each row.
        /// An expression may refer to columns in `input` or
        /// expressions defined earlier in the vector
        scalars: Vec<MirScalarExpr>,
    },
    /// Like Map, but yields zero-or-more output rows per input row
    ///
    /// The runtime memory footprint of this operator is zero.
    FlatMap {
        /// The source collection
        input: Box<MirRelationExpr>,
        /// The table func to apply
        func: TableFunc,
        /// The arguments to the table func
        exprs: Vec<MirScalarExpr>,
    },
    /// Keep rows from a dataflow where all the predicates are true
    ///
    /// The runtime memory footprint of this operator is zero.
    Filter {
        /// The source collection.
        input: Box<MirRelationExpr>,
        /// Predicates, each of which must be true.
        predicates: Vec<MirScalarExpr>,
    },
    /// Join several collections, where some columns must be equal.
    ///
    /// For further details consult the documentation for [`MirRelationExpr::join`].
    ///
    /// The runtime memory footprint of this operator can be proportional to
    /// the sizes of all inputs and the size of all joins of prefixes.
    /// This may be reduced due to arrangements available at rendering time.
    Join {
        /// A sequence of input relations.
        inputs: Vec<MirRelationExpr>,
        /// A sequence of equivalence classes of expressions on the cross product of inputs.
        ///
        /// Each equivalence class is a list of scalar expressions, where for each class the
        /// intended interpretation is that all evaluated expressions should be equal.
        ///
        /// Each scalar expression is to be evaluated over the cross-product of all records
        /// from all inputs. In many cases this may just be column selection from specific
        /// inputs, but more general cases exist (e.g. complex functions of multiple columns
        /// from multiple inputs, or just constant literals).
        equivalences: Vec<Vec<MirScalarExpr>>,
        /// Join implementation information.
        #[serde(default)]
        implementation: JoinImplementation,
    },
    /// Group a dataflow by some columns and aggregate over each group
    ///
    /// The runtime memory footprint of this operator is at most proportional to the
    /// number of distinct records in the input and output. The actual requirements
    /// can be less: the number of distinct inputs to each aggregate, summed across
    /// each aggregate, plus the output size. For more details consult the code that
    /// builds the associated dataflow.
    Reduce {
        /// The source collection.
        input: Box<MirRelationExpr>,
        /// Column indices used to form groups.
        group_key: Vec<MirScalarExpr>,
        /// Expressions which determine values to append to each row, after the group keys.
        aggregates: Vec<AggregateExpr>,
        /// True iff the input is known to monotonically increase (only addition of records).
        #[serde(default)]
        monotonic: bool,
        /// User hint: expected number of values per group key. Used to optimize physical rendering.
        #[serde(default)]
        expected_group_size: Option<u64>,
    },
    /// Groups and orders within each group, limiting output.
    ///
    /// The runtime memory footprint of this operator is proportional to its input and output.
    TopK {
        /// The source collection.
        input: Box<MirRelationExpr>,
        /// Column indices used to form groups.
        group_key: Vec<usize>,
        /// Column indices used to order rows within groups.
        order_key: Vec<ColumnOrder>,
        /// Number of records to retain
        #[serde(default)]
        limit: Option<MirScalarExpr>,
        /// Number of records to skip
        #[serde(default)]
        offset: usize,
        /// True iff the input is known to monotonically increase (only addition of records).
        #[serde(default)]
        monotonic: bool,
        /// User-supplied hint: how many rows will have the same group key.
        #[serde(default)]
        expected_group_size: Option<u64>,
    },
    /// Return a dataflow where the row counts are negated
    ///
    /// The runtime memory footprint of this operator is zero.
    Negate {
        /// The source collection.
        input: Box<MirRelationExpr>,
    },
    /// Keep rows from a dataflow where the row counts are positive
    ///
    /// The runtime memory footprint of this operator is proportional to its input and output.
    Threshold {
        /// The source collection.
        input: Box<MirRelationExpr>,
    },
    /// Adds the frequencies of elements in contained sets.
    ///
    /// The runtime memory footprint of this operator is zero.
    Union {
        /// A source collection.
        base: Box<MirRelationExpr>,
        /// Source collections to union.
        inputs: Vec<MirRelationExpr>,
    },
    /// Technically a no-op. Used to render an index. Will be used to optimize queries
    /// on finer grain. Each `keys` item represents a different index that should be
    /// produced from the `keys`.
    ///
    /// The runtime memory footprint of this operator is proportional to its input.
    ArrangeBy {
        /// The source collection
        input: Box<MirRelationExpr>,
        /// Columns to arrange `input` by, in order of decreasing primacy
        keys: Vec<Vec<MirScalarExpr>>,
    },
}
314
315impl PartialEq for MirRelationExpr {
316    fn eq(&self, other: &Self) -> bool {
317        // Capture the result and test it wrt `Ord` implementation in test environments.
318        let result = structured_diff::MreDiff::new(self, other).next().is_none();
319        mz_ore::soft_assert_eq_no_log!(result, self.cmp(other) == Ordering::Equal);
320        result
321    }
322}
323impl Eq for MirRelationExpr {}
324
325impl MirRelationExpr {
326    /// Reports the schema of the relation.
327    ///
328    /// This is the SQL-type parallel of [`Self::typ`]; it is merely
329    /// a wrapper around it, returning a [`SqlRelationType`] instead of
330    /// a [`ReprRelationType`].
331    pub fn sql_typ(&self) -> SqlRelationType {
332        let repr_typ = self.typ();
333        SqlRelationType::from_repr(&repr_typ)
334    }
335
    /// Reports the repr schema of the relation.
    ///
    /// This method determines the type through recursive traversal of the
    /// relation expression, drawing from the types of base collections.
    /// As such, this is not an especially cheap method, and should be used
    /// judiciously.
    ///
    /// The relation type is computed incrementally with a recursive post-order
    /// traversal, that accumulates the input types for the relations yet to be
    /// visited in `type_stack`.
    pub fn typ(&self) -> ReprRelationType {
        let mut type_stack = Vec::new();
        #[allow(deprecated)]
        self.visit_pre_post_nolimit(
            // Pre-visit: for `Let`/`LetRec`, descend into the body only; the bound
            // values' types are not needed (dummies are pushed in the post-visit below).
            &mut |e: &MirRelationExpr| -> Option<Vec<&MirRelationExpr>> {
                match &e {
                    MirRelationExpr::Let { body, .. } => Some(vec![&*body]),
                    MirRelationExpr::LetRec { body, .. } => Some(vec![&*body]),
                    _ => None,
                }
            },
            // Post-visit: pop this node's input types, compute its own type, push it.
            &mut |e: &MirRelationExpr| {
                match e {
                    MirRelationExpr::Let { .. } => {
                        let body_typ = type_stack.pop().unwrap();
                        // Insert a dummy relation type for the value, since `typ_with_input_types`
                        // won't look at it, but expects the relation type of the body to be second.
                        type_stack.push(ReprRelationType::empty());
                        type_stack.push(body_typ);
                    }
                    MirRelationExpr::LetRec { values, .. } => {
                        let body_typ = type_stack.pop().unwrap();
                        type_stack.extend(
                            std::iter::repeat(ReprRelationType::empty()).take(values.len()),
                        );
                        // Insert dummy relation types for the values, since `typ_with_input_types`
                        // won't look at them, but expects the relation type of the body to be last.
                        type_stack.push(body_typ);
                    }
                    _ => {}
                }
                let num_inputs = e.num_inputs();
                let relation_type =
                    e.typ_with_input_types(&type_stack[type_stack.len() - num_inputs..]);
                type_stack.truncate(type_stack.len() - num_inputs);
                type_stack.push(relation_type);
            },
        );
        // After the traversal only the root's type remains on the stack.
        assert_eq!(type_stack.len(), 1);
        type_stack.pop().unwrap()
    }
387
388    /// Reports the repr schema of the relation given the repr schema of the input relations.
389    pub fn typ_with_input_types(&self, input_types: &[ReprRelationType]) -> ReprRelationType {
390        let column_types = self.col_with_input_cols(input_types.iter().map(|i| &i.column_types));
391        let unique_keys = self.keys_with_input_keys(
392            input_types.iter().map(|i| i.arity()),
393            input_types.iter().map(|i| &i.keys),
394        );
395        ReprRelationType::new(column_types).with_keys(unique_keys)
396    }
397
398    /// Reports the column types of the relation given the column types of the
399    /// input relations.
400    ///
401    /// This method delegates to `try_col_with_input_cols`, panicking if an `Err`
402    /// variant is returned.
403    pub fn col_with_input_cols<'a, I>(&self, input_types: I) -> Vec<ReprColumnType>
404    where
405        I: Iterator<Item = &'a Vec<ReprColumnType>>,
406    {
407        match self.try_col_with_input_cols(input_types) {
408            Ok(col_types) => col_types,
409            Err(err) => panic!("{err}"),
410        }
411    }
412
    /// Reports the column types of the relation given the column types of the input relations.
    ///
    /// `input_types` is required to contain the column types for the input relations of
    /// the current relation in the same order as they are visited by `try_visit_children`
    /// method, even though not all may be used for computing the schema of the
    /// current relation. For example, `Let` expects two input types, one for the
    /// value relation and one for the body, in that order, but only the one for the
    /// body is used to determine the type of the `Let` relation.
    ///
    /// It is meant to be used during post-order traversals to compute column types
    /// incrementally.
    ///
    /// # Errors
    ///
    /// Returns an error (rendered as a string) when the column types of `Union`
    /// inputs cannot be unified.
    pub fn try_col_with_input_cols<'a, I>(
        &self,
        mut input_types: I,
    ) -> Result<Vec<ReprColumnType>, String>
    where
        I: Iterator<Item = &'a Vec<ReprColumnType>>,
    {
        use MirRelationExpr::*;

        let col_types = match self {
            Constant { rows, typ } => {
                // Tighten nullability: a column stays nullable only if some row
                // actually contains a null in that position.
                let mut col_types = typ.column_types.clone();
                let mut seen_null = vec![false; typ.arity()];
                if let Ok(rows) = rows {
                    for (row, _diff) in rows {
                        for (datum, i) in row.iter().zip_eq(0..typ.arity()) {
                            if datum.is_null() {
                                seen_null[i] = true;
                            }
                        }
                    }
                }
                for (&seen_null, i) in seen_null.iter().zip_eq(0..typ.arity()) {
                    if !seen_null {
                        col_types[i].nullable = false;
                    } else {
                        // A null was observed, so the declared type must already allow it.
                        assert!(col_types[i].nullable);
                    }
                }
                col_types
            }
            Get { typ, .. } => typ.column_types.clone(),
            Project { outputs, .. } => {
                let input = input_types.next().unwrap();
                outputs.iter().map(|&i| input[i].clone()).collect()
            }
            Map { scalars, .. } => {
                // Scalars may refer to columns introduced by earlier scalars,
                // so each type is appended before the next is computed.
                let mut result = input_types.next().unwrap().clone();
                for scalar in scalars.iter() {
                    result.push(scalar.typ(&result))
                }
                result
            }
            FlatMap { func, .. } => {
                // Input columns followed by the table function's output columns.
                let mut result = input_types.next().unwrap().clone();
                result.extend(
                    func.output_sql_type()
                        .column_types
                        .iter()
                        .map(ReprColumnType::from),
                );
                result
            }
            Filter { predicates, .. } => {
                let mut result = input_types.next().unwrap().clone();

                // Set as nonnull any columns where null values would cause
                // any predicate to evaluate to null.
                for column in non_nullable_columns(predicates) {
                    result[column].nullable = false;
                }
                result
            }
            Join { equivalences, .. } => {
                // Concatenate input column types
                let mut types = input_types.flat_map(|cols| cols.to_owned()).collect_vec();
                // In an equivalence class, if any column is non-null, then make all non-null
                for equivalence in equivalences {
                    let col_inds = equivalence
                        .iter()
                        .filter_map(|expr| match expr {
                            MirScalarExpr::Column(col, _name) => Some(*col),
                            _ => None,
                        })
                        .collect_vec();
                    if col_inds.iter().any(|i| !types.get(*i).unwrap().nullable) {
                        for i in col_inds {
                            types.get_mut(i).unwrap().nullable = false;
                        }
                    }
                }
                types
            }
            Reduce {
                group_key,
                aggregates,
                ..
            } => {
                // Output columns are the group key expressions followed by the aggregates.
                let input = input_types.next().unwrap();
                group_key
                    .iter()
                    .map(|e| e.typ(input))
                    .chain(aggregates.iter().map(|agg| agg.typ(input)))
                    .collect()
            }
            TopK { .. } | Negate { .. } | Threshold { .. } | ArrangeBy { .. } => {
                // These operators pass their input columns through unchanged.
                input_types.next().unwrap().clone()
            }
            Let { .. } => {
                // skip over the input types for `value`.
                input_types.nth(1).unwrap().clone()
            }
            LetRec { values, .. } => {
                // skip over the input types for `values`.
                input_types.nth(values.len()).unwrap().clone()
            }
            Union { .. } => {
                // Unify column types pairwise across all inputs; fails if incompatible.
                let mut result = input_types.next().unwrap().clone();
                for input_col_types in input_types {
                    for (base_col, col) in result.iter_mut().zip_eq(input_col_types) {
                        *base_col = base_col
                            .union(col)
                            .map_err(|e| format!("{}\nin plan:\n{}", e, self.pretty()))?;
                    }
                }
                result
            }
        };

        Ok(col_types)
    }
545
546    /// Reports the unique keys of the relation given the arities and the unique
547    /// keys of the input relations.
548    ///
549    /// `input_arities` and `input_keys` are required to contain the
550    /// corresponding info for the input relations of
551    /// the current relation in the same order as they are visited by `try_visit_children`
552    /// method, even though not all may be used for computing the schema of the
553    /// current relation. For example, `Let` expects two input types, one for the
554    /// value relation and one for the body, in that order, but only the one for the
555    /// body is used to determine the type of the `Let` relation.
556    ///
557    /// It is meant to be used during post-order traversals to compute unique keys
558    /// incrementally.
559    pub fn keys_with_input_keys<'a, I, J>(
560        &self,
561        mut input_arities: I,
562        mut input_keys: J,
563    ) -> Vec<Vec<usize>>
564    where
565        I: Iterator<Item = usize>,
566        J: Iterator<Item = &'a Vec<Vec<usize>>>,
567    {
568        use MirRelationExpr::*;
569
570        let mut keys = match self {
571            Constant {
572                rows: Ok(rows),
573                typ,
574            } => {
575                let n_cols = typ.arity();
576                // If the `i`th entry is `Some`, then we have not yet observed non-uniqueness in the `i`th column.
577                let mut unique_values_per_col = vec![Some(BTreeSet::<Datum>::default()); n_cols];
578                for (row, diff) in rows {
579                    for (i, datum) in row.iter().enumerate() {
580                        if datum != Datum::Dummy {
581                            if let Some(unique_vals) = &mut unique_values_per_col[i] {
582                                let is_dupe = *diff != Diff::ONE || !unique_vals.insert(datum);
583                                if is_dupe {
584                                    unique_values_per_col[i] = None;
585                                }
586                            }
587                        }
588                    }
589                }
590                if rows.len() == 0 || (rows.len() == 1 && rows[0].1 == Diff::ONE) {
591                    vec![vec![]]
592                } else {
593                    // XXX - Multi-column keys are not detected.
594                    typ.keys
595                        .iter()
596                        .cloned()
597                        .chain(
598                            unique_values_per_col
599                                .into_iter()
600                                .enumerate()
601                                .filter(|(_idx, unique_vals)| unique_vals.is_some())
602                                .map(|(idx, _)| vec![idx]),
603                        )
604                        .collect()
605                }
606            }
607            Constant { rows: Err(_), typ } | Get { typ, .. } => typ.keys.clone(),
608            Threshold { .. } | ArrangeBy { .. } => input_keys.next().unwrap().clone(),
609            Let { .. } => {
610                // skip over the unique keys for value
611                input_keys.nth(1).unwrap().clone()
612            }
613            LetRec { values, .. } => {
614                // skip over the unique keys for value
615                input_keys.nth(values.len()).unwrap().clone()
616            }
617            Project { outputs, .. } => {
618                let input = input_keys.next().unwrap();
619                input
620                    .iter()
621                    .filter_map(|key_set| {
622                        if key_set.iter().all(|k| outputs.contains(k)) {
623                            Some(
624                                key_set
625                                    .iter()
626                                    .map(|c| outputs.iter().position(|o| o == c).unwrap())
627                                    .collect(),
628                            )
629                        } else {
630                            None
631                        }
632                    })
633                    .collect()
634            }
635            Map { scalars, .. } => {
636                let mut remappings = Vec::new();
637                let arity = input_arities.next().unwrap();
638                for (column, scalar) in scalars.iter().enumerate() {
639                    // assess whether the scalar preserves uniqueness,
640                    // and could participate in a key!
641
642                    fn uniqueness(expr: &MirScalarExpr) -> Option<usize> {
643                        match expr {
644                            MirScalarExpr::CallUnary { func, expr } => {
645                                if func.preserves_uniqueness() {
646                                    uniqueness(expr)
647                                } else {
648                                    None
649                                }
650                            }
651                            MirScalarExpr::Column(c, _name) => Some(*c),
652                            _ => None,
653                        }
654                    }
655
656                    if let Some(c) = uniqueness(scalar) {
657                        remappings.push((c, column + arity));
658                    }
659                }
660
661                let mut result = input_keys.next().unwrap().clone();
662                let mut new_keys = Vec::new();
663                // Any column in `remappings` could be replaced in a key
664                // by the corresponding c. This could lead to combinatorial
665                // explosion using our current representation, so we wont
666                // do that. Instead, we'll handle the case of one remapping.
667                if remappings.len() == 1 {
668                    let (old, new) = remappings.pop().unwrap();
669                    for key in &result {
670                        if key.contains(&old) {
671                            let mut new_key: Vec<usize> =
672                                key.iter().cloned().filter(|k| k != &old).collect();
673                            new_key.push(new);
674                            new_key.sort_unstable();
675                            new_keys.push(new_key);
676                        }
677                    }
678                    result.append(&mut new_keys);
679                }
680                result
681            }
682            FlatMap { .. } => {
683                // FlatMap can add duplicate rows, so input keys are no longer
684                // valid
685                vec![]
686            }
687            Negate { .. } => {
688                // Although negate may have distinct records for each key,
689                // the multiplicity is -1 rather than 1. This breaks many
690                // of the optimization uses of "keys".
691                vec![]
692            }
693            Filter { predicates, .. } => {
694                // A filter inherits the keys of its input unless the filters
695                // have reduced the input to a single row, in which case the
696                // keys of the input are `()`.
697                let mut input = input_keys.next().unwrap().clone();
698
699                if !input.is_empty() {
700                    // Track columns equated to literals, which we can prune.
701                    let mut cols_equal_to_literal = BTreeSet::new();
702
703                    // Perform union find on `col1 = col2` to establish
704                    // connected components of equated columns. Absent any
705                    // equalities, this will be `0 .. #c` (where #c is the
706                    // greatest column referenced by a predicate), but each
707                    // equality will orient the root of the greater to the root
708                    // of the lesser.
709                    let mut union_find = Vec::new();
710
711                    for expr in predicates.iter() {
712                        if let MirScalarExpr::CallBinary {
713                            func: crate::BinaryFunc::Eq(_),
714                            expr1,
715                            expr2,
716                        } = expr
717                        {
718                            if let MirScalarExpr::Column(c, _name) = &**expr1 {
719                                if expr2.is_literal_ok() {
720                                    cols_equal_to_literal.insert(c);
721                                }
722                            }
723                            if let MirScalarExpr::Column(c, _name) = &**expr2 {
724                                if expr1.is_literal_ok() {
725                                    cols_equal_to_literal.insert(c);
726                                }
727                            }
728                            // Perform union-find to equate columns.
729                            if let (Some(c1), Some(c2)) = (expr1.as_column(), expr2.as_column()) {
730                                if c1 != c2 {
731                                    // Ensure union_find has entries up to
732                                    // max(c1, c2) by filling up missing
733                                    // positions with identity mappings.
734                                    while union_find.len() <= std::cmp::max(c1, c2) {
735                                        union_find.push(union_find.len());
736                                    }
737                                    let mut r1 = c1; // Find the representative column of [c1].
738                                    while r1 != union_find[r1] {
739                                        assert!(union_find[r1] < r1);
740                                        r1 = union_find[r1];
741                                    }
742                                    let mut r2 = c2; // Find the representative column of [c2].
743                                    while r2 != union_find[r2] {
744                                        assert!(union_find[r2] < r2);
745                                        r2 = union_find[r2];
746                                    }
747                                    // Union [c1] and [c2] by pointing the
748                                    // larger to the smaller representative (we
749                                    // update the remaining equivalence class
750                                    // members only once after this for-loop).
751                                    union_find[std::cmp::max(r1, r2)] = std::cmp::min(r1, r2);
752                                }
753                            }
754                        }
755                    }
756
757                    // Complete union-find by pointing each element at its representative column.
758                    for i in 0..union_find.len() {
759                        // Iteration not required, as each prior already references the right column.
760                        union_find[i] = union_find[union_find[i]];
761                    }
762
763                    // Remove columns bound to literals, and remap columns equated to earlier columns.
764                    // We will re-expand remapped columns in a moment, but this avoids exponential work.
765                    for key_set in &mut input {
766                        key_set.retain(|k| !cols_equal_to_literal.contains(&k));
767                        for col in key_set.iter_mut() {
768                            if let Some(equiv) = union_find.get(*col) {
769                                *col = *equiv;
770                            }
771                        }
772                        key_set.sort();
773                        key_set.dedup();
774                    }
775                    input.sort();
776                    input.dedup();
777
778                    // Expand out each key to each of its equivalent forms.
779                    // Each instance of `col` can be replaced by any equivalent column.
780                    // This has the potential to result in exponentially sized number of unique keys,
781                    // and in the future we should probably maintain unique keys modulo equivalence.
782
783                    // First, compute an inverse map from each representative
784                    // column `sub` to all other equivalent columns `col`.
785                    let mut subs = Vec::new();
786                    for (col, sub) in union_find.iter().enumerate() {
787                        if *sub != col {
788                            assert!(*sub < col);
789                            while subs.len() <= *sub {
790                                subs.push(Vec::new());
791                            }
792                            subs[*sub].push(col);
793                        }
794                    }
795                    // For each column, substitute for it in each occurrence.
796                    let mut to_add = Vec::new();
797                    for (col, subs) in subs.iter().enumerate() {
798                        if !subs.is_empty() {
799                            for key_set in input.iter() {
800                                if key_set.contains(&col) {
801                                    let mut to_extend = key_set.clone();
802                                    to_extend.retain(|c| c != &col);
803                                    for sub in subs {
804                                        to_extend.push(*sub);
805                                        to_add.push(to_extend.clone());
806                                        to_extend.pop();
807                                    }
808                                }
809                            }
810                        }
811                        // No deduplication, as we cannot introduce duplicates.
812                        input.append(&mut to_add);
813                    }
814                    for key_set in input.iter_mut() {
815                        key_set.sort();
816                        key_set.dedup();
817                    }
818                }
819                input
820            }
821            Join { equivalences, .. } => {
822                // It is important the `new_from_input_arities` constructor is
823                // used. Otherwise, Materialize may potentially end up in an
824                // infinite loop.
825                let input_mapper = crate::JoinInputMapper::new_from_input_arities(input_arities);
826
827                input_mapper.global_keys(input_keys, equivalences)
828            }
829            Reduce { group_key, .. } => {
830                // The group key should form a key, but we might already have
831                // keys that are subsets of the group key, and should retain
832                // those instead, if so.
833                let mut result = Vec::new();
834                for key_set in input_keys.next().unwrap() {
835                    if key_set
836                        .iter()
837                        .all(|k| group_key.contains(&MirScalarExpr::column(*k)))
838                    {
839                        result.push(
840                            key_set
841                                .iter()
842                                .map(|i| {
843                                    group_key
844                                        .iter()
845                                        .position(|k| k == &MirScalarExpr::column(*i))
846                                        .unwrap()
847                                })
848                                .collect::<Vec<_>>(),
849                        );
850                    }
851                }
852                if result.is_empty() {
853                    result.push((0..group_key.len()).collect());
854                }
855                result
856            }
857            TopK {
858                group_key, limit, ..
859            } => {
860                // If `limit` is `Some(1)` then the group key will become
861                // a unique key, as there will be only one record with that key.
862                let mut result = input_keys.next().unwrap().clone();
863                if limit.as_ref().and_then(|x| x.as_literal_int64()) == Some(1) {
864                    result.push(group_key.clone())
865                }
866                result
867            }
868            Union { base, inputs } => {
869                // Generally, unions do not have any unique keys, because
870                // each input might duplicate some. However, there is at
871                // least one idiomatic structure that does preserve keys,
872                // which results from SQL aggregations that must populate
873                // absent records with default values. In that pattern,
874                // the union of one GET with its negation, which has first
875                // been subjected to a projection and map, we can remove
876                // their influence on the key structure.
877                //
878                // If there are A, B, each with a unique `key` such that
879                // we are looking at
880                //
881                //     A.proj(set_containing_key) + (B - A.proj(key)).map(stuff)
882                //
883                // Then we can report `key` as a unique key.
884                //
885                // TODO: make unique key structure an optimization analysis
886                // rather than part of the type information.
887                // TODO: perhaps ensure that (above) A.proj(key) is a
888                // subset of B, as otherwise there are negative records
889                // and who knows what is true (not expected, but again
890                // who knows what the query plan might look like).
891
892                let arity = input_arities.next().unwrap();
893                let (base_projection, base_with_project_stripped) =
894                    if let MirRelationExpr::Project { input, outputs } = &**base {
895                        (outputs.clone(), &**input)
896                    } else {
897                        // A input without a project is equivalent to an input
898                        // with the project being all columns in the input in order.
899                        ((0..arity).collect::<Vec<_>>(), &**base)
900                    };
901                let mut result = Vec::new();
902                if let MirRelationExpr::Get {
903                    id: first_id,
904                    typ: _,
905                    ..
906                } = base_with_project_stripped
907                {
908                    if inputs.len() == 1 {
909                        if let MirRelationExpr::Map { input, .. } = &inputs[0] {
910                            if let MirRelationExpr::Union { base, inputs } = &**input {
911                                if inputs.len() == 1 {
912                                    if let Some((input, outputs)) = base.is_negated_project() {
913                                        if let MirRelationExpr::Get {
914                                            id: second_id,
915                                            typ: _,
916                                            ..
917                                        } = input
918                                        {
919                                            if first_id == second_id {
920                                                result.extend(
921                                                    input_keys
922                                                        .next()
923                                                        .unwrap()
924                                                        .into_iter()
925                                                        .filter(|key| {
926                                                            key.iter().all(|c| {
927                                                                outputs.get(*c) == Some(c)
928                                                                    && base_projection.get(*c)
929                                                                        == Some(c)
930                                                            })
931                                                        })
932                                                        .cloned(),
933                                                );
934                                            }
935                                        }
936                                    }
937                                }
938                            }
939                        }
940                    }
941                }
942                // Important: do not inherit keys of either input, as not unique.
943                result
944            }
945        };
946        keys.sort();
947        keys.dedup();
948        keys
949    }
950
    /// The number of columns in the relation.
    ///
    /// This number is determined from the type, which is determined recursively
    /// at non-trivial cost.
    ///
    /// The arity is computed incrementally with a recursive post-order
    /// traversal, that accumulates the arities for the relations yet to be
    /// visited in `arity_stack`.
    pub fn arity(&self) -> usize {
        // Stack of arities for already-visited subtrees, consumed in the
        // post-order callback. Placeholder zeros stand in for children we
        // deliberately did not traverse (see below).
        let mut arity_stack = Vec::new();
        #[allow(deprecated)]
        self.visit_pre_post_nolimit(
            // Pre-order: optionally restrict which children get traversed.
            &mut |e: &MirRelationExpr| -> Option<Vec<&MirRelationExpr>> {
                match &e {
                    MirRelationExpr::Let { body, .. } => {
                        // Do not traverse the value sub-graph, since it's not relevant for
                        // determining the arity of Let operators.
                        Some(vec![&*body])
                    }
                    MirRelationExpr::LetRec { body, .. } => {
                        // Do not traverse the value sub-graph, since it's not relevant for
                        // determining the arity of Let operators.
                        Some(vec![&*body])
                    }
                    MirRelationExpr::Project { .. } | MirRelationExpr::Reduce { .. } => {
                        // No further traversal is required; these operators know their arity.
                        Some(Vec::new())
                    }
                    _ => None,
                }
            },
            // Post-order: compute this node's arity from its children's.
            &mut |e: &MirRelationExpr| {
                // `arity_with_input_arities` expects exactly one entry per
                // child (in child order), so re-insert placeholder zeros for
                // the children skipped by the pre-order callback above.
                match &e {
                    MirRelationExpr::Let { .. } => {
                        // Only the body was traversed; slot a zero in for the
                        // skipped value, keeping the body's arity on top.
                        let body_arity = arity_stack.pop().unwrap();
                        arity_stack.push(0);
                        arity_stack.push(body_arity);
                    }
                    MirRelationExpr::LetRec { values, .. } => {
                        // Only the body was traversed; slot zeros in for each
                        // skipped bound value, keeping the body's arity on top.
                        let body_arity = arity_stack.pop().unwrap();
                        arity_stack.extend(std::iter::repeat(0).take(values.len()));
                        arity_stack.push(body_arity);
                    }
                    MirRelationExpr::Project { .. } | MirRelationExpr::Reduce { .. } => {
                        // No children were traversed; push one placeholder.
                        arity_stack.push(0);
                    }
                    _ => {}
                }
                // Pop one arity per child and replace them with this node's arity.
                let num_inputs = e.num_inputs();
                let input_arities = arity_stack.drain(arity_stack.len() - num_inputs..);
                let arity = e.arity_with_input_arities(input_arities);
                arity_stack.push(arity);
            },
        );
        // After the full traversal exactly the root's arity remains.
        assert_eq!(arity_stack.len(), 1);
        arity_stack.pop().unwrap()
    }
1008
1009    /// Reports the arity of the relation given the schema of the input relations.
1010    ///
1011    /// `input_arities` is required to contain the arities for the input relations of
1012    /// the current relation in the same order as they are visited by `try_visit_children`
1013    /// method, even though not all may be used for computing the schema of the
1014    /// current relation. For example, `Let` expects two input types, one for the
1015    /// value relation and one for the body, in that order, but only the one for the
1016    /// body is used to determine the type of the `Let` relation.
1017    ///
1018    /// It is meant to be used during post-order traversals to compute arities
1019    /// incrementally.
1020    pub fn arity_with_input_arities<I>(&self, mut input_arities: I) -> usize
1021    where
1022        I: Iterator<Item = usize>,
1023    {
1024        use MirRelationExpr::*;
1025
1026        match self {
1027            Constant { rows: _, typ } => typ.arity(),
1028            Get { typ, .. } => typ.arity(),
1029            Let { .. } => {
1030                input_arities.next();
1031                input_arities.next().unwrap()
1032            }
1033            LetRec { values, .. } => {
1034                for _ in 0..values.len() {
1035                    input_arities.next();
1036                }
1037                input_arities.next().unwrap()
1038            }
1039            Project { outputs, .. } => outputs.len(),
1040            Map { scalars, .. } => input_arities.next().unwrap() + scalars.len(),
1041            FlatMap { func, .. } => input_arities.next().unwrap() + func.output_arity(),
1042            Join { .. } => input_arities.sum(),
1043            Reduce {
1044                input: _,
1045                group_key,
1046                aggregates,
1047                ..
1048            } => group_key.len() + aggregates.len(),
1049            Filter { .. }
1050            | TopK { .. }
1051            | Negate { .. }
1052            | Threshold { .. }
1053            | Union { .. }
1054            | ArrangeBy { .. } => input_arities.next().unwrap(),
1055        }
1056    }
1057
1058    /// The number of child relations this relation has.
1059    pub fn num_inputs(&self) -> usize {
1060        let mut count = 0;
1061
1062        self.visit_children(|_| count += 1);
1063
1064        count
1065    }
1066
1067    /// Constructs a constant collection from specific rows and schema, where
1068    /// each row will have a multiplicity of one.
1069    pub fn constant(rows: Vec<Vec<Datum>>, typ: ReprRelationType) -> Self {
1070        let rows = rows.into_iter().map(|row| (row, Diff::ONE)).collect();
1071        MirRelationExpr::constant_diff(rows, typ)
1072    }
1073
1074    /// Constructs a constant collection from specific rows and schema, where
1075    /// each row can have an arbitrary multiplicity.
1076    pub fn constant_diff(rows: Vec<(Vec<Datum>, Diff)>, typ: ReprRelationType) -> Self {
1077        for (row, _diff) in &rows {
1078            for (datum, column_typ) in row.iter().zip_eq(typ.column_types.iter()) {
1079                assert!(
1080                    datum.is_instance_of(column_typ),
1081                    "Expected datum of type {:?}, got value {:?}",
1082                    column_typ,
1083                    datum
1084                );
1085            }
1086        }
1087        let rows = Ok(rows
1088            .into_iter()
1089            .map(move |(row, diff)| (Row::pack_slice(&row), diff))
1090            .collect());
1091        MirRelationExpr::Constant { rows, typ }
1092    }
1093
1094    /// If self is a constant, return the value and the type, otherwise `None`.
1095    /// Looks behind `ArrangeBy`s.
1096    pub fn as_const(&self) -> Option<(&Result<Vec<(Row, Diff)>, EvalError>, &ReprRelationType)> {
1097        match self {
1098            MirRelationExpr::Constant { rows, typ } => Some((rows, typ)),
1099            MirRelationExpr::ArrangeBy { input, .. } => input.as_const(),
1100            _ => None,
1101        }
1102    }
1103
1104    /// If self is a constant, mutably return the value and the type, otherwise `None`.
1105    /// Looks behind `ArrangeBy`s.
1106    pub fn as_const_mut(
1107        &mut self,
1108    ) -> Option<(
1109        &mut Result<Vec<(Row, Diff)>, EvalError>,
1110        &mut ReprRelationType,
1111    )> {
1112        match self {
1113            MirRelationExpr::Constant { rows, typ } => Some((rows, typ)),
1114            MirRelationExpr::ArrangeBy { input, .. } => input.as_const_mut(),
1115            _ => None,
1116        }
1117    }
1118
1119    /// If self is a constant error, return the error, otherwise `None`.
1120    /// Looks behind `ArrangeBy`s.
1121    pub fn as_const_err(&self) -> Option<&EvalError> {
1122        match self {
1123            MirRelationExpr::Constant { rows: Err(e), .. } => Some(e),
1124            MirRelationExpr::ArrangeBy { input, .. } => input.as_const_err(),
1125            _ => None,
1126        }
1127    }
1128
1129    /// Checks if `self` is the single element collection with no columns.
1130    pub fn is_constant_singleton(&self) -> bool {
1131        if let Some((Ok(rows), typ)) = self.as_const() {
1132            rows.len() == 1 && typ.column_types.len() == 0 && rows[0].1 == Diff::ONE
1133        } else {
1134            false
1135        }
1136    }
1137
1138    /// Constructs the expression for getting a local collection.
1139    pub fn local_get(id: LocalId, typ: ReprRelationType) -> Self {
1140        MirRelationExpr::Get {
1141            id: Id::Local(id),
1142            typ,
1143            access_strategy: AccessStrategy::UnknownOrLocal,
1144        }
1145    }
1146
1147    /// Constructs the expression for getting a global collection
1148    pub fn global_get(id: GlobalId, typ: ReprRelationType) -> Self {
1149        MirRelationExpr::Get {
1150            id: Id::Global(id),
1151            typ,
1152            access_strategy: AccessStrategy::UnknownOrLocal,
1153        }
1154    }
1155
1156    /// Retains only the columns specified by `output`.
1157    pub fn project(mut self, mut outputs: Vec<usize>) -> Self {
1158        if let MirRelationExpr::Project {
1159            outputs: columns, ..
1160        } = &mut self
1161        {
1162            // Update `outputs` to reference base columns of `input`.
1163            for column in outputs.iter_mut() {
1164                *column = columns[*column];
1165            }
1166            *columns = outputs;
1167            self
1168        } else {
1169            MirRelationExpr::Project {
1170                input: Box::new(self),
1171                outputs,
1172            }
1173        }
1174    }
1175
1176    /// Append to each row the results of applying elements of `scalar`.
1177    pub fn map(mut self, scalars: Vec<MirScalarExpr>) -> Self {
1178        if let MirRelationExpr::Map { scalars: s, .. } = &mut self {
1179            s.extend(scalars);
1180            self
1181        } else if !scalars.is_empty() {
1182            MirRelationExpr::Map {
1183                input: Box::new(self),
1184                scalars,
1185            }
1186        } else {
1187            self
1188        }
1189    }
1190
1191    /// Append to each row a single `scalar`.
1192    pub fn map_one(self, scalar: MirScalarExpr) -> Self {
1193        self.map(vec![scalar])
1194    }
1195
1196    /// Like `map`, but yields zero-or-more output rows per input row
1197    pub fn flat_map(self, func: TableFunc, exprs: Vec<MirScalarExpr>) -> Self {
1198        MirRelationExpr::FlatMap {
1199            input: Box::new(self),
1200            func,
1201            exprs,
1202        }
1203    }
1204
1205    /// Retain only the rows satisfying each of several predicates.
1206    pub fn filter<I>(mut self, predicates: I) -> Self
1207    where
1208        I: IntoIterator<Item = MirScalarExpr>,
1209    {
1210        // Extract existing predicates
1211        let mut new_predicates = if let MirRelationExpr::Filter { input, predicates } = self {
1212            self = *input;
1213            predicates
1214        } else {
1215            Vec::new()
1216        };
1217        // Normalize collection of predicates.
1218        new_predicates.extend(predicates);
1219        new_predicates.retain(|p| !p.is_literal_true());
1220        new_predicates.sort();
1221        new_predicates.dedup();
1222        // Introduce a `Filter` only if we have predicates.
1223        if !new_predicates.is_empty() {
1224            self = MirRelationExpr::Filter {
1225                input: Box::new(self),
1226                predicates: new_predicates,
1227            };
1228        }
1229
1230        self
1231    }
1232
1233    /// Form the Cartesian outer-product of rows in both inputs.
1234    pub fn product(mut self, right: Self) -> Self {
1235        if right.is_constant_singleton() {
1236            self
1237        } else if self.is_constant_singleton() {
1238            right
1239        } else if let MirRelationExpr::Join { inputs, .. } = &mut self {
1240            inputs.push(right);
1241            self
1242        } else {
1243            MirRelationExpr::join(vec![self, right], vec![])
1244        }
1245    }
1246
1247    /// Performs a relational equijoin among the input collections.
1248    ///
1249    /// The sequence `inputs` each describe different input collections, and the sequence `variables` describes
1250    /// equality constraints that some of their columns must satisfy. Each element in `variable` describes a set
1251    /// of pairs  `(input_index, column_index)` where every value described by that set must be equal.
1252    ///
1253    /// For example, the pair `(input, column)` indexes into `inputs[input][column]`, extracting the `input`th
1254    /// input collection and for each row examining its `column`th column.
1255    ///
1256    /// # Example
1257    ///
1258    /// ```rust
1259    /// use mz_repr::{Datum, SqlColumnType, ReprRelationType, ReprScalarType};
1260    /// use mz_expr::MirRelationExpr;
1261    ///
1262    /// // A common schema for each input.
1263    /// let schema = ReprRelationType::new(vec![
1264    ///     ReprScalarType::Int32.nullable(false),
1265    ///     ReprScalarType::Int32.nullable(false),
1266    /// ]);
1267    ///
1268    /// // the specific data are not important here.
1269    /// let data = vec![Datum::Int32(0), Datum::Int32(1)];
1270    ///
1271    /// // Three collections that could have been different.
1272    /// let input0 = MirRelationExpr::constant(vec![data.clone()], schema.clone());
1273    /// let input1 = MirRelationExpr::constant(vec![data.clone()], schema.clone());
1274    /// let input2 = MirRelationExpr::constant(vec![data.clone()], schema.clone());
1275    ///
1276    /// // Join the three relations looking for triangles, like so.
1277    /// //
1278    /// //     Output(A,B,C) := Input0(A,B), Input1(B,C), Input2(A,C)
1279    /// let joined = MirRelationExpr::join(
1280    ///     vec![input0, input1, input2],
1281    ///     vec![
1282    ///         vec![(0,0), (2,0)], // fields A of inputs 0 and 2.
1283    ///         vec![(0,1), (1,0)], // fields B of inputs 0 and 1.
1284    ///         vec![(1,1), (2,1)], // fields C of inputs 1 and 2.
1285    ///     ],
1286    /// );
1287    ///
1288    /// // Technically the above produces `Output(A,B,B,C,A,C)` because the columns are concatenated.
1289    /// // A projection resolves this and produces the correct output.
1290    /// let result = joined.project(vec![0, 1, 3]);
1291    /// ```
1292    pub fn join(inputs: Vec<MirRelationExpr>, variables: Vec<Vec<(usize, usize)>>) -> Self {
1293        let input_mapper = join_input_mapper::JoinInputMapper::new(&inputs);
1294
1295        let equivalences = variables
1296            .into_iter()
1297            .map(|vs| {
1298                vs.into_iter()
1299                    .map(|(r, c)| input_mapper.map_expr_to_global(MirScalarExpr::column(c), r))
1300                    .collect::<Vec<_>>()
1301            })
1302            .collect::<Vec<_>>();
1303
1304        Self::join_scalars(inputs, equivalences)
1305    }
1306
1307    /// Constructs a join operator from inputs and required-equal scalar expressions.
1308    pub fn join_scalars(
1309        mut inputs: Vec<MirRelationExpr>,
1310        equivalences: Vec<Vec<MirScalarExpr>>,
1311    ) -> Self {
1312        // Remove all constant inputs that are the identity for join.
1313        // They neither introduce nor modify any column references.
1314        inputs.retain(|i| !i.is_constant_singleton());
1315        MirRelationExpr::Join {
1316            inputs,
1317            equivalences,
1318            implementation: JoinImplementation::Unimplemented,
1319        }
1320    }
1321
1322    /// Perform a key-wise reduction / aggregation.
1323    ///
1324    /// The `group_key` argument indicates columns in the input collection that should
1325    /// be grouped, and `aggregates` lists aggregation functions each of which produces
1326    /// one output column in addition to the keys.
1327    pub fn reduce(
1328        self,
1329        group_key: Vec<usize>,
1330        aggregates: Vec<AggregateExpr>,
1331        expected_group_size: Option<u64>,
1332    ) -> Self {
1333        MirRelationExpr::Reduce {
1334            input: Box::new(self),
1335            group_key: group_key.into_iter().map(MirScalarExpr::column).collect(),
1336            aggregates,
1337            monotonic: false,
1338            expected_group_size,
1339        }
1340    }
1341
1342    /// Perform a key-wise reduction order by and limit.
1343    ///
1344    /// The `group_key` argument indicates columns in the input collection that should
1345    /// be grouped, the `order_key` argument indicates columns that should be further
1346    /// used to order records within groups, and the `limit` argument constrains the
1347    /// total number of records that should be produced in each group.
1348    pub fn top_k(
1349        self,
1350        group_key: Vec<usize>,
1351        order_key: Vec<ColumnOrder>,
1352        limit: Option<MirScalarExpr>,
1353        offset: usize,
1354        expected_group_size: Option<u64>,
1355    ) -> Self {
1356        MirRelationExpr::TopK {
1357            input: Box::new(self),
1358            group_key,
1359            order_key,
1360            limit,
1361            offset,
1362            expected_group_size,
1363            monotonic: false,
1364        }
1365    }
1366
1367    /// Negates the occurrences of each row.
1368    pub fn negate(self) -> Self {
1369        if let MirRelationExpr::Negate { input } = self {
1370            *input
1371        } else {
1372            MirRelationExpr::Negate {
1373                input: Box::new(self),
1374            }
1375        }
1376    }
1377
1378    /// Removes all but the first occurrence of each row.
1379    pub fn distinct(self) -> Self {
1380        let arity = self.arity();
1381        self.distinct_by((0..arity).collect())
1382    }
1383
1384    /// Removes all but the first occurrence of each key. Columns not included
1385    /// in the `group_key` are discarded.
1386    pub fn distinct_by(self, group_key: Vec<usize>) -> Self {
1387        self.reduce(group_key, vec![], None)
1388    }
1389
1390    /// Discards rows with a negative frequency.
1391    pub fn threshold(self) -> Self {
1392        if let MirRelationExpr::Threshold { .. } = &self {
1393            self
1394        } else {
1395            MirRelationExpr::Threshold {
1396                input: Box::new(self),
1397            }
1398        }
1399    }
1400
1401    /// Unions together any number inputs.
1402    ///
1403    /// If `inputs` is empty, then an empty relation of type `typ` is
1404    /// constructed.
1405    pub fn union_many(mut inputs: Vec<Self>, typ: ReprRelationType) -> Self {
1406        // Deconstruct `inputs` as `Union`s and reconstitute.
1407        let mut flat_inputs = Vec::with_capacity(inputs.len());
1408        for input in inputs {
1409            if let MirRelationExpr::Union { base, inputs } = input {
1410                flat_inputs.push(*base);
1411                flat_inputs.extend(inputs);
1412            } else {
1413                flat_inputs.push(input);
1414            }
1415        }
1416        inputs = flat_inputs;
1417        if inputs.len() == 0 {
1418            MirRelationExpr::Constant {
1419                rows: Ok(vec![]),
1420                typ,
1421            }
1422        } else if inputs.len() == 1 {
1423            inputs.into_element()
1424        } else {
1425            MirRelationExpr::Union {
1426                base: Box::new(inputs.remove(0)),
1427                inputs,
1428            }
1429        }
1430    }
1431
1432    /// Produces one collection where each row is present with the sum of its frequencies in each input.
1433    pub fn union(self, other: Self) -> Self {
1434        // Deconstruct `self` and `other` as `Union`s and reconstitute.
1435        let mut flat_inputs = Vec::with_capacity(2);
1436        if let MirRelationExpr::Union { base, inputs } = self {
1437            flat_inputs.push(*base);
1438            flat_inputs.extend(inputs);
1439        } else {
1440            flat_inputs.push(self);
1441        }
1442        if let MirRelationExpr::Union { base, inputs } = other {
1443            flat_inputs.push(*base);
1444            flat_inputs.extend(inputs);
1445        } else {
1446            flat_inputs.push(other);
1447        }
1448
1449        MirRelationExpr::Union {
1450            base: Box::new(flat_inputs.remove(0)),
1451            inputs: flat_inputs,
1452        }
1453    }
1454
1455    /// Arranges the collection by the specified columns
1456    pub fn arrange_by(self, keys: &[Vec<MirScalarExpr>]) -> Self {
1457        MirRelationExpr::ArrangeBy {
1458            input: Box::new(self),
1459            keys: keys.to_owned(),
1460        }
1461    }
1462
1463    /// Indicates if this is a constant empty collection.
1464    ///
1465    /// A false value does not mean the collection is known to be non-empty,
1466    /// only that we cannot currently determine that it is statically empty.
1467    pub fn is_empty(&self) -> bool {
1468        if let Some((Ok(rows), ..)) = self.as_const() {
1469            rows.is_empty()
1470        } else {
1471            false
1472        }
1473    }
1474
1475    /// If the expression is a negated project, return the input and the projection.
1476    pub fn is_negated_project(&self) -> Option<(&MirRelationExpr, &[usize])> {
1477        if let MirRelationExpr::Negate { input } = self {
1478            if let MirRelationExpr::Project { input, outputs } = &**input {
1479                return Some((&**input, outputs));
1480            }
1481        }
1482        if let MirRelationExpr::Project { input, outputs } = self {
1483            if let MirRelationExpr::Negate { input } = &**input {
1484                return Some((&**input, outputs));
1485            }
1486        }
1487        None
1488    }
1489
1490    /// Pretty-print this [MirRelationExpr] to a string.
1491    pub fn pretty(&self) -> String {
1492        let config = ExplainConfig::default();
1493        self.debug_explain(&config, None)
1494    }
1495
1496    /// Pretty-print this [MirRelationExpr] to a string using a custom
1497    /// [ExplainConfig] and an optionally provided [ExprHumanizer].
1498    /// This is intended for debugging and tests, not users.
1499    pub fn debug_explain(
1500        &self,
1501        config: &ExplainConfig,
1502        humanizer: Option<&dyn ExprHumanizer>,
1503    ) -> String {
1504        text_string_at(self, || PlanRenderingContext {
1505            indent: Indent::default(),
1506            humanizer: humanizer.unwrap_or(&DummyHumanizer),
1507            annotations: BTreeMap::default(),
1508            config,
1509            ambiguous_ids: BTreeSet::default(),
1510        })
1511    }
1512
1513    /// Take ownership of `self`, leaving an empty `MirRelationExpr::Constant` with the optionally
1514    /// given scalar types. The given scalar types should be `base_eq` with the types that `typ()`
1515    /// would find. Keys and nullability are ignored in the given `SqlRelationType`, and instead we set
1516    /// the best possible key and nullability, since we are making an empty collection.
1517    ///
1518    /// If `typ` is not given, then this calls `.typ()` (which is possibly expensive) to determine
1519    /// the correct type.
1520    pub fn take_safely(&mut self, typ: Option<ReprRelationType>) -> MirRelationExpr {
1521        if let Some(typ) = &typ {
1522            let self_typ = self.typ();
1523            soft_assert_no_log!(
1524                self_typ
1525                    .column_types
1526                    .iter()
1527                    .zip_eq(typ.column_types.iter())
1528                    .all(|(t1, t2)| t1.scalar_type == t2.scalar_type)
1529            );
1530        }
1531        let mut typ = typ.unwrap_or_else(|| self.typ());
1532        typ.keys = vec![vec![]];
1533        for ct in typ.column_types.iter_mut() {
1534            ct.nullable = false;
1535        }
1536        std::mem::replace(
1537            self,
1538            MirRelationExpr::Constant {
1539                rows: Ok(vec![]),
1540                typ,
1541            },
1542        )
1543    }
1544
1545    /// Take ownership of `self`, leaving an empty `MirRelationExpr::Constant` with the given scalar
1546    /// types. Nullability is ignored in the given `SqlColumnType`s, and instead we set the best
1547    /// possible nullability, since we are making an empty collection.
1548    pub fn take_safely_with_sql_col_types(&mut self, typ: Vec<SqlColumnType>) -> MirRelationExpr {
1549        self.take_safely(Some(ReprRelationType::from(&SqlRelationType::new(typ))))
1550    }
1551
    /// Like [`Self::take_safely_with_sql_col_types`], but accepts `Vec<ReprColumnType>`.
    ///
    /// This is the preferred entry point for optimizer transforms, where repr
    /// types are the native currency. Builds a [`ReprRelationType`] directly
    /// from the given columns and delegates to [`Self::take_safely`].
    pub fn take_safely_with_col_types(&mut self, typ: Vec<ReprColumnType>) -> MirRelationExpr {
        self.take_safely(Some(ReprRelationType::new(typ)))
    }
1560
1561    /// Take ownership of `self`, leaving an empty `MirRelationExpr::Constant` with an **incorrect** type.
1562    ///
1563    /// This should only be used if `self` is about to be dropped or otherwise overwritten.
1564    pub fn take_dangerous(&mut self) -> MirRelationExpr {
1565        let empty = MirRelationExpr::Constant {
1566            rows: Ok(vec![]),
1567            typ: ReprRelationType::new(Vec::new()),
1568        };
1569        std::mem::replace(self, empty)
1570    }
1571
1572    /// Replaces `self` with some logic applied to `self`.
1573    pub fn replace_using<F>(&mut self, logic: F)
1574    where
1575        F: FnOnce(MirRelationExpr) -> MirRelationExpr,
1576    {
1577        let empty = MirRelationExpr::Constant {
1578            rows: Ok(vec![]),
1579            typ: ReprRelationType::new(Vec::new()),
1580        };
1581        let expr = std::mem::replace(self, empty);
1582        *self = logic(expr);
1583    }
1584
1585    /// Store `self` in a `Let` and pass the corresponding `Get` to `body`.
1586    pub fn let_in<Body, E>(self, id_gen: &mut IdGen, body: Body) -> Result<MirRelationExpr, E>
1587    where
1588        Body: FnOnce(&mut IdGen, MirRelationExpr) -> Result<MirRelationExpr, E>,
1589    {
1590        if let MirRelationExpr::Get { .. } = self {
1591            // already done
1592            body(id_gen, self)
1593        } else {
1594            let id = LocalId::new(id_gen.allocate_id());
1595            let get = MirRelationExpr::Get {
1596                id: Id::Local(id),
1597                typ: self.typ(),
1598                access_strategy: AccessStrategy::UnknownOrLocal,
1599            };
1600            let body = (body)(id_gen, get)?;
1601            Ok(MirRelationExpr::Let {
1602                id,
1603                value: Box::new(self),
1604                body: Box::new(body),
1605            })
1606        }
1607    }
1608
    /// Return every row in `self` that does not have a matching row in the first columns of `keys_and_values`, using `default` to fill in the remaining columns
    /// (If `default` is a row of nulls, this is the 'outer' part of LEFT OUTER JOIN)
    pub fn anti_lookup<E>(
        self,
        id_gen: &mut IdGen,
        keys_and_values: MirRelationExpr,
        default: Vec<(Datum, ReprScalarType)>,
    ) -> Result<MirRelationExpr, E> {
        // Split `default` into the datums themselves and matching column
        // types; a column is nullable exactly when its default datum is null.
        let (data, column_types): (Vec<_>, Vec<_>) = default
            .into_iter()
            .map(|(datum, scalar_type)| {
                (
                    datum,
                    ReprColumnType {
                        scalar_type,
                        nullable: datum.is_null(),
                    },
                )
            })
            .unzip();
        // `default` must cover exactly the non-key (value) columns of
        // `keys_and_values`.
        assert_eq!(keys_and_values.arity() - self.arity(), data.len());
        self.let_in(id_gen, |_id_gen, get_keys| {
            let get_keys_arity = get_keys.arity();
            Ok(MirRelationExpr::join(
                vec![
                    // all the missing keys (with count 1)
                    keys_and_values
                        .distinct_by((0..get_keys_arity).collect())
                        .negate()
                        .union(get_keys.clone().distinct()),
                    // join with keys to get the correct counts
                    get_keys.clone(),
                ],
                (0..get_keys_arity).map(|i| vec![(0, i), (1, i)]).collect(),
            )
            // get rid of the extra copies of columns from keys
            .project((0..get_keys_arity).collect())
            // This join is logically equivalent to
            // `.map(<default_expr>)`, but using a join allows for
            // potential predicate pushdown and elision in the
            // optimizer.
            .product(MirRelationExpr::constant(
                vec![data],
                ReprRelationType::new(column_types),
            )))
        })
    }
1656
1657    /// Return:
1658    /// * every row in keys_and_values
1659    /// * every row in `self` that does not have a matching row in the first columns of
1660    ///   `keys_and_values`, using `default` to fill in the remaining columns
1661    /// (This is LEFT OUTER JOIN if:
1662    /// 1) `default` is a row of null
1663    /// 2) matching rows in `keys_and_values` and `self` have the same multiplicity.)
1664    pub fn lookup<E>(
1665        self,
1666        id_gen: &mut IdGen,
1667        keys_and_values: MirRelationExpr,
1668        default: Vec<(Datum<'static>, ReprScalarType)>,
1669    ) -> Result<MirRelationExpr, E> {
1670        keys_and_values.let_in(id_gen, |id_gen, get_keys_and_values| {
1671            Ok(get_keys_and_values.clone().union(self.anti_lookup(
1672                id_gen,
1673                get_keys_and_values,
1674                default,
1675            )?))
1676        })
1677    }
1678
1679    /// True iff the expression contains a `NullaryFunc::MzLogicalTimestamp`.
1680    pub fn contains_temporal(&self) -> bool {
1681        let mut contains = false;
1682        self.visit_scalars(&mut |e| contains = contains || e.contains_temporal());
1683        contains
1684    }
1685
1686    /// Fallible visitor for the [`MirScalarExpr`]s directly owned by this relation expression.
1687    ///
1688    /// The `f` visitor should not recursively descend into owned [`MirRelationExpr`]s.
1689    pub fn try_visit_scalars_mut1<F, E>(&mut self, f: &mut F) -> Result<(), E>
1690    where
1691        F: FnMut(&mut MirScalarExpr) -> Result<(), E>,
1692    {
1693        use MirRelationExpr::*;
1694        match self {
1695            Map { scalars, .. } => {
1696                for s in scalars {
1697                    f(s)?;
1698                }
1699            }
1700            Filter { predicates, .. } => {
1701                for p in predicates {
1702                    f(p)?;
1703                }
1704            }
1705            FlatMap { exprs, .. } => {
1706                for expr in exprs {
1707                    f(expr)?;
1708                }
1709            }
1710            Join {
1711                inputs: _,
1712                equivalences,
1713                implementation,
1714            } => {
1715                for equivalence in equivalences {
1716                    for expr in equivalence {
1717                        f(expr)?;
1718                    }
1719                }
1720                match implementation {
1721                    JoinImplementation::Differential((_, start_key, _), order) => {
1722                        if let Some(start_key) = start_key {
1723                            for k in start_key {
1724                                f(k)?;
1725                            }
1726                        }
1727                        for (_, lookup_key, _) in order {
1728                            for k in lookup_key {
1729                                f(k)?;
1730                            }
1731                        }
1732                    }
1733                    JoinImplementation::DeltaQuery(paths) => {
1734                        for path in paths {
1735                            for (_, lookup_key, _) in path {
1736                                for k in lookup_key {
1737                                    f(k)?;
1738                                }
1739                            }
1740                        }
1741                    }
1742                    JoinImplementation::IndexedFilter(_coll_id, _idx_id, index_key, _) => {
1743                        for k in index_key {
1744                            f(k)?;
1745                        }
1746                    }
1747                    JoinImplementation::Unimplemented => {} // No scalar exprs
1748                }
1749            }
1750            ArrangeBy { keys, .. } => {
1751                for key in keys {
1752                    for s in key {
1753                        f(s)?;
1754                    }
1755                }
1756            }
1757            Reduce {
1758                group_key,
1759                aggregates,
1760                ..
1761            } => {
1762                for s in group_key {
1763                    f(s)?;
1764                }
1765                for agg in aggregates {
1766                    f(&mut agg.expr)?;
1767                }
1768            }
1769            TopK { limit, .. } => {
1770                if let Some(s) = limit {
1771                    f(s)?;
1772                }
1773            }
1774            Constant { .. }
1775            | Get { .. }
1776            | Let { .. }
1777            | LetRec { .. }
1778            | Project { .. }
1779            | Negate { .. }
1780            | Threshold { .. }
1781            | Union { .. } => (),
1782        }
1783        Ok(())
1784    }
1785
1786    /// Fallible mutable visitor for the [`MirScalarExpr`]s in the [`MirRelationExpr`] subtree
1787    /// rooted at `self`.
1788    ///
1789    /// Note that this does not recurse into [`MirRelationExpr`] subtrees within [`MirScalarExpr`]
1790    /// nodes.
1791    pub fn try_visit_scalars_mut<F, E>(&mut self, f: &mut F) -> Result<(), E>
1792    where
1793        F: FnMut(&mut MirScalarExpr) -> Result<(), E>,
1794        E: From<RecursionLimitError>,
1795    {
1796        self.try_visit_mut_post(&mut |expr| expr.try_visit_scalars_mut1(f))
1797    }
1798
1799    /// Infallible mutable visitor for the [`MirScalarExpr`]s in the [`MirRelationExpr`] subtree
1800    /// rooted at `self`.
1801    ///
1802    /// Note that this does not recurse into [`MirRelationExpr`] subtrees within [`MirScalarExpr`]
1803    /// nodes.
1804    pub fn visit_scalars_mut<F>(&mut self, f: &mut F)
1805    where
1806        F: FnMut(&mut MirScalarExpr),
1807    {
1808        self.try_visit_scalars_mut(&mut |s| {
1809            f(s);
1810            Ok::<_, RecursionLimitError>(())
1811        })
1812        .expect("Unexpected error in `visit_scalars_mut` call");
1813    }
1814
1815    /// Fallible visitor for the [`MirScalarExpr`]s directly owned by this relation expression.
1816    ///
1817    /// The `f` visitor should not recursively descend into owned [`MirRelationExpr`]s.
1818    pub fn try_visit_scalars_1<F, E>(&self, f: &mut F) -> Result<(), E>
1819    where
1820        F: FnMut(&MirScalarExpr) -> Result<(), E>,
1821    {
1822        use MirRelationExpr::*;
1823        match self {
1824            Map { scalars, .. } => {
1825                for s in scalars {
1826                    f(s)?;
1827                }
1828            }
1829            Filter { predicates, .. } => {
1830                for p in predicates {
1831                    f(p)?;
1832                }
1833            }
1834            FlatMap { exprs, .. } => {
1835                for expr in exprs {
1836                    f(expr)?;
1837                }
1838            }
1839            Join {
1840                inputs: _,
1841                equivalences,
1842                implementation,
1843            } => {
1844                for equivalence in equivalences {
1845                    for expr in equivalence {
1846                        f(expr)?;
1847                    }
1848                }
1849                match implementation {
1850                    JoinImplementation::Differential((_, start_key, _), order) => {
1851                        if let Some(start_key) = start_key {
1852                            for k in start_key {
1853                                f(k)?;
1854                            }
1855                        }
1856                        for (_, lookup_key, _) in order {
1857                            for k in lookup_key {
1858                                f(k)?;
1859                            }
1860                        }
1861                    }
1862                    JoinImplementation::DeltaQuery(paths) => {
1863                        for path in paths {
1864                            for (_, lookup_key, _) in path {
1865                                for k in lookup_key {
1866                                    f(k)?;
1867                                }
1868                            }
1869                        }
1870                    }
1871                    JoinImplementation::IndexedFilter(_coll_id, _idx_id, index_key, _) => {
1872                        for k in index_key {
1873                            f(k)?;
1874                        }
1875                    }
1876                    JoinImplementation::Unimplemented => {} // No scalar exprs
1877                }
1878            }
1879            ArrangeBy { keys, .. } => {
1880                for key in keys {
1881                    for s in key {
1882                        f(s)?;
1883                    }
1884                }
1885            }
1886            Reduce {
1887                group_key,
1888                aggregates,
1889                ..
1890            } => {
1891                for s in group_key {
1892                    f(s)?;
1893                }
1894                for agg in aggregates {
1895                    f(&agg.expr)?;
1896                }
1897            }
1898            TopK { limit, .. } => {
1899                if let Some(s) = limit {
1900                    f(s)?;
1901                }
1902            }
1903            Constant { .. }
1904            | Get { .. }
1905            | Let { .. }
1906            | LetRec { .. }
1907            | Project { .. }
1908            | Negate { .. }
1909            | Threshold { .. }
1910            | Union { .. } => (),
1911        }
1912        Ok(())
1913    }
1914
1915    /// Fallible immutable visitor for the [`MirScalarExpr`]s in the [`MirRelationExpr`] subtree
1916    /// rooted at `self`.
1917    ///
1918    /// Note that this does not recurse into [`MirRelationExpr`] subtrees within [`MirScalarExpr`]
1919    /// nodes.
1920    pub fn try_visit_scalars<F, E>(&self, f: &mut F) -> Result<(), E>
1921    where
1922        F: FnMut(&MirScalarExpr) -> Result<(), E>,
1923        E: From<RecursionLimitError>,
1924    {
1925        self.try_visit_post(&mut |expr| expr.try_visit_scalars_1(f))
1926    }
1927
1928    /// Infallible immutable visitor for the [`MirScalarExpr`]s in the [`MirRelationExpr`] subtree
1929    /// rooted at `self`.
1930    ///
1931    /// Note that this does not recurse into [`MirRelationExpr`] subtrees within [`MirScalarExpr`]
1932    /// nodes.
1933    pub fn visit_scalars<F>(&self, f: &mut F)
1934    where
1935        F: FnMut(&MirScalarExpr),
1936    {
1937        self.try_visit_scalars(&mut |s| {
1938            f(s);
1939            Ok::<_, RecursionLimitError>(())
1940        })
1941        .expect("Unexpected error in `visit_scalars` call");
1942    }
1943
1944    /// Clears the contents of `self` even if it's so deep that simply dropping it would cause a
1945    /// stack overflow in `drop_in_place`.
1946    ///
1947    /// Leaves `self` in an unusable state, so this should only be used if `self` is about to be
1948    /// dropped or otherwise overwritten.
1949    pub fn destroy_carefully(&mut self) {
1950        let mut todo = vec![self.take_dangerous()];
1951        while let Some(mut expr) = todo.pop() {
1952            for child in expr.children_mut() {
1953                todo.push(child.take_dangerous());
1954            }
1955        }
1956    }
1957
1958    /// Computes the size (total number of nodes) and maximum depth of a MirRelationExpr for
1959    /// debug printing purposes.
1960    pub fn debug_size_and_depth(&self) -> (usize, usize) {
1961        let mut size = 0;
1962        let mut max_depth = 0;
1963        let mut todo = vec![(self, 1)];
1964        while let Some((expr, depth)) = todo.pop() {
1965            size += 1;
1966            max_depth = max(max_depth, depth);
1967            todo.extend(expr.children().map(|c| (c, depth + 1)));
1968        }
1969        (size, max_depth)
1970    }
1971
1972    /// The MirRelationExpr is considered potentially expensive if and only if
1973    /// at least one of the following conditions is true:
1974    ///
1975    ///  - It contains at least one FlatMap or a Reduce operator.
1976    ///  - It contains at least one MirScalarExpr with a function call.
1977    ///
1978    /// !!!WARNING!!!: this method has an HirRelationExpr counterpart. The two
1979    /// should be kept in sync w.r.t. HIR ⇒ MIR lowering!
1980    pub fn could_run_expensive_function(&self) -> bool {
1981        let mut result = false;
1982        self.visit_pre(|e: &MirRelationExpr| {
1983            use MirRelationExpr::*;
1984            use MirScalarExpr::*;
1985            if let Err(_) = self.try_visit_scalars::<_, RecursionLimitError>(&mut |scalar| {
1986                result |= match scalar {
1987                    Column(_, _) | Literal(_, _) | CallUnmaterializable(_) | If { .. } => false,
1988                    // Function calls are considered expensive
1989                    CallUnary { .. } | CallBinary { .. } | CallVariadic { .. } => true,
1990                };
1991                Ok(())
1992            }) {
1993                // Conservatively set `true` if on RecursionLimitError.
1994                result = true;
1995            }
1996            // FlatMap has a table function; Reduce has an aggregate function.
1997            // Other constructs use MirScalarExpr to run a function
1998            result |= matches!(e, FlatMap { .. } | Reduce { .. });
1999        });
2000        result
2001    }
2002
2003    /// Hash to an u64 using Rust's default Hasher. (Which is a somewhat slower, but better Hasher
2004    /// than what `Hashable::hashed` would give us.)
2005    pub fn hash_to_u64(&self) -> u64 {
2006        let mut h = DefaultHasher::new();
2007        self.hash(&mut h);
2008        h.finish()
2009    }
2010}
2011
2012// `LetRec` helpers
2013impl MirRelationExpr {
2014    /// True when `expr` contains a `LetRec` AST node.
2015    pub fn is_recursive(self: &MirRelationExpr) -> bool {
2016        let mut worklist = vec![self];
2017        while let Some(expr) = worklist.pop() {
2018            if let MirRelationExpr::LetRec { .. } = expr {
2019                return true;
2020            }
2021            worklist.extend(expr.children());
2022        }
2023        false
2024    }
2025
2026    /// Return the number of sub-expressions in the tree (including self).
2027    pub fn size(&self) -> usize {
2028        let mut size = 0;
2029        self.visit_pre(|_| size += 1);
2030        size
2031    }
2032
2033    /// Given the ids and values of a LetRec, it computes the subset of ids that are used across
2034    /// iterations. These are those ids that have a reference before they are defined, when reading
2035    /// all the bindings in order.
2036    ///
2037    /// For example:
2038    /// ```SQL
2039    /// WITH MUTUALLY RECURSIVE
2040    ///     x(...) AS f(z),
2041    ///     y(...) AS g(x),
2042    ///     z(...) AS h(y)
2043    /// ...;
2044    /// ```
2045    /// Here, only `z` is returned, because `x` and `y` are referenced only within the same
2046    /// iteration.
2047    ///
2048    /// Note that if a binding references itself, that is also returned.
2049    pub fn recursive_ids(ids: &[LocalId], values: &[MirRelationExpr]) -> BTreeSet<LocalId> {
2050        let mut used_across_iterations = BTreeSet::new();
2051        let mut defined = BTreeSet::new();
2052        for (binding_id, value) in itertools::zip_eq(ids.iter(), values.iter()) {
2053            value.visit_pre(|expr| {
2054                if let MirRelationExpr::Get {
2055                    id: Local(get_id), ..
2056                } = expr
2057                {
2058                    // If we haven't seen a definition for it yet, then this will refer
2059                    // to the previous iteration.
2060                    // The `ids.contains` part of the condition is needed to exclude
2061                    // those ids that are not really in this LetRec, but either an inner
2062                    // or outer one.
2063                    if !defined.contains(get_id) && ids.contains(get_id) {
2064                        used_across_iterations.insert(*get_id);
2065                    }
2066                }
2067            });
2068            defined.insert(*binding_id);
2069        }
2070        used_across_iterations
2071    }
2072
2073    /// Replaces `LetRec` nodes with a stack of `Let` nodes.
2074    ///
2075    /// In each `Let` binding, uses of `Get` in `value` that are not at strictly greater
2076    /// identifiers are rewritten to be the constant collection.
2077    /// This makes the computation perform exactly "one" iteration.
2078    ///
2079    /// This was used only temporarily while developing `LetRec`.
2080    pub fn make_nonrecursive(self: &mut MirRelationExpr) {
2081        let mut deadlist = BTreeSet::new();
2082        let mut worklist = vec![self];
2083        while let Some(expr) = worklist.pop() {
2084            if let MirRelationExpr::LetRec {
2085                ids,
2086                values,
2087                limits: _,
2088                body,
2089            } = expr
2090            {
2091                let ids_values = values
2092                    .drain(..)
2093                    .zip_eq(ids)
2094                    .map(|(value, id)| (*id, value))
2095                    .collect::<Vec<_>>();
2096                *expr = body.take_dangerous();
2097                for (id, mut value) in ids_values.into_iter().rev() {
2098                    // Remove references to potentially recursive identifiers.
2099                    deadlist.insert(id);
2100                    value.visit_pre_mut(|e| {
2101                        if let MirRelationExpr::Get {
2102                            id: crate::Id::Local(id),
2103                            typ,
2104                            ..
2105                        } = e
2106                        {
2107                            let typ = typ.clone();
2108                            if deadlist.contains(id) {
2109                                e.take_safely(Some(typ));
2110                            }
2111                        }
2112                    });
2113                    *expr = MirRelationExpr::Let {
2114                        id,
2115                        value: Box::new(value),
2116                        body: Box::new(expr.take_dangerous()),
2117                    };
2118                }
2119                worklist.push(expr);
2120            } else {
2121                worklist.extend(expr.children_mut().rev());
2122            }
2123        }
2124    }
2125
2126    /// For each Id `id'` referenced in `expr`, if it is larger or equal than `id`, then record in
2127    /// `expire_whens` that when `id'` is redefined, then we should expire the information that
2128    /// we are holding about `id`. Call `do_expirations` with `expire_whens` at each Id
2129    /// redefinition.
2130    ///
2131    /// IMPORTANT: Relies on the numbering of Ids to be what `renumber_bindings` gives.
2132    pub fn collect_expirations(
2133        id: LocalId,
2134        expr: &MirRelationExpr,
2135        expire_whens: &mut BTreeMap<LocalId, Vec<LocalId>>,
2136    ) {
2137        expr.visit_pre(|e| {
2138            if let MirRelationExpr::Get {
2139                id: Id::Local(referenced_id),
2140                ..
2141            } = e
2142            {
2143                // The following check needs `renumber_bindings` to have run recently
2144                if referenced_id >= &id {
2145                    expire_whens
2146                        .entry(*referenced_id)
2147                        .or_insert_with(Vec::new)
2148                        .push(id);
2149                }
2150            }
2151        });
2152    }
2153
2154    /// Call this function when `id` is redefined. It modifies `id_infos` by removing information
2155    /// about such Ids whose information depended on the earlier definition of `id`, according to
2156    /// `expire_whens`. Also modifies `expire_whens`: it removes the currently processed entry.
2157    pub fn do_expirations<I>(
2158        redefined_id: LocalId,
2159        expire_whens: &mut BTreeMap<LocalId, Vec<LocalId>>,
2160        id_infos: &mut BTreeMap<LocalId, I>,
2161    ) -> Vec<(LocalId, I)> {
2162        let mut expired_infos = Vec::new();
2163        if let Some(expirations) = expire_whens.remove(&redefined_id) {
2164            for expired_id in expirations.into_iter() {
2165                if let Some(offer) = id_infos.remove(&expired_id) {
2166                    expired_infos.push((expired_id, offer));
2167                }
2168            }
2169        }
2170        expired_infos
2171    }
2172}
2173/// Augment non-nullability of columns, by observing either
2174/// 1. Predicates that explicitly test for null values, and
2175/// 2. Columns that if null would make a predicate be null.
2176pub fn non_nullable_columns(predicates: &[MirScalarExpr]) -> BTreeSet<usize> {
2177    let mut nonnull_required_columns = BTreeSet::new();
2178    for predicate in predicates {
2179        // Add any columns that being null would force the predicate to be null.
2180        // Should that happen, the row would be discarded.
2181        predicate.non_null_requirements(&mut nonnull_required_columns);
2182
2183        /*
2184        Test for explicit checks that a column is non-null.
2185
2186        This analysis is ad hoc, and will miss things:
2187
2188        materialize=> create table a(x int, y int);
2189        CREATE TABLE
2190        materialize=> explain with(types) select x from a where (y=x and y is not null) or x is not null;
2191        Optimized Plan
2192        --------------------------------------------------------------------------------------------------------
2193        Explained Query:                                                                                      +
2194        Project (#0) // { types: "(integer?)" }                                                             +
2195        Filter ((#0) IS NOT NULL OR ((#1) IS NOT NULL AND (#0 = #1))) // { types: "(integer?, integer?)" }+
2196        Get materialize.public.a // { types: "(integer?, integer?)" }                                   +
2197                                                                                  +
2198        Source materialize.public.a                                                                           +
2199        filter=(((#0) IS NOT NULL OR ((#1) IS NOT NULL AND (#0 = #1))))                                     +
2200
2201        (1 row)
2202        */
2203
2204        if let MirScalarExpr::CallUnary {
2205            func: UnaryFunc::Not(scalar_func::Not),
2206            expr,
2207        } = predicate
2208        {
2209            if let MirScalarExpr::CallUnary {
2210                func: UnaryFunc::IsNull(scalar_func::IsNull),
2211                expr,
2212            } = &**expr
2213            {
2214                if let MirScalarExpr::Column(c, _name) = &**expr {
2215                    nonnull_required_columns.insert(*c);
2216                }
2217            }
2218        }
2219    }
2220
2221    nonnull_required_columns
2222}
2223
2224impl CollectionPlan for MirRelationExpr {
2225    /// Collects the global collections that this MIR expression directly depends on, i.e., that it
2226    /// has a `Get` for. (It does _not_ traverse view definitions transitively.)
2227    ///
2228    /// !!!WARNING!!!: this method has an HirRelationExpr counterpart. The two
2229    /// should be kept in sync w.r.t. HIR ⇒ MIR lowering!
2230    fn depends_on_into(&self, out: &mut BTreeSet<GlobalId>) {
2231        if let MirRelationExpr::Get {
2232            id: Id::Global(id), ..
2233        } = self
2234        {
2235            out.insert(*id);
2236        }
2237        self.visit_children(|expr| expr.depends_on_into(out))
2238    }
2239}
2240
2241impl MirRelationExpr {
2242    /// Iterates through references to child expressions.
2243    pub fn children(&self) -> impl DoubleEndedIterator<Item = &Self> {
2244        let mut first = None;
2245        let mut second = None;
2246        let mut rest = None;
2247        let mut last = None;
2248
2249        use MirRelationExpr::*;
2250        match self {
2251            Constant { .. } | Get { .. } => (),
2252            Let { value, body, .. } => {
2253                first = Some(&**value);
2254                second = Some(&**body);
2255            }
2256            LetRec { values, body, .. } => {
2257                rest = Some(values);
2258                last = Some(&**body);
2259            }
2260            Project { input, .. }
2261            | Map { input, .. }
2262            | FlatMap { input, .. }
2263            | Filter { input, .. }
2264            | Reduce { input, .. }
2265            | TopK { input, .. }
2266            | Negate { input }
2267            | Threshold { input }
2268            | ArrangeBy { input, .. } => {
2269                first = Some(&**input);
2270            }
2271            Join { inputs, .. } => {
2272                rest = Some(inputs);
2273            }
2274            Union { base, inputs } => {
2275                first = Some(&**base);
2276                rest = Some(inputs);
2277            }
2278        }
2279
2280        first
2281            .into_iter()
2282            .chain(second)
2283            .chain(rest.into_iter().flatten())
2284            .chain(last)
2285    }
2286
2287    /// Iterates through mutable references to child expressions.
2288    pub fn children_mut(&mut self) -> impl DoubleEndedIterator<Item = &mut Self> {
2289        let mut first = None;
2290        let mut second = None;
2291        let mut rest = None;
2292        let mut last = None;
2293
2294        use MirRelationExpr::*;
2295        match self {
2296            Constant { .. } | Get { .. } => (),
2297            Let { value, body, .. } => {
2298                first = Some(&mut **value);
2299                second = Some(&mut **body);
2300            }
2301            LetRec { values, body, .. } => {
2302                rest = Some(values);
2303                last = Some(&mut **body);
2304            }
2305            Project { input, .. }
2306            | Map { input, .. }
2307            | FlatMap { input, .. }
2308            | Filter { input, .. }
2309            | Reduce { input, .. }
2310            | TopK { input, .. }
2311            | Negate { input }
2312            | Threshold { input }
2313            | ArrangeBy { input, .. } => {
2314                first = Some(&mut **input);
2315            }
2316            Join { inputs, .. } => {
2317                rest = Some(inputs);
2318            }
2319            Union { base, inputs } => {
2320                first = Some(&mut **base);
2321                rest = Some(inputs);
2322            }
2323        }
2324
2325        first
2326            .into_iter()
2327            .chain(second)
2328            .chain(rest.into_iter().flatten())
2329            .chain(last)
2330    }
2331
2332    /// Iterative pre-order visitor.
2333    pub fn visit_pre<'a, F: FnMut(&'a Self)>(&'a self, mut f: F) {
2334        let mut worklist = vec![self];
2335        while let Some(expr) = worklist.pop() {
2336            f(expr);
2337            worklist.extend(expr.children().rev());
2338        }
2339    }
2340
2341    /// Iterative pre-order visitor.
2342    pub fn visit_pre_mut<F: FnMut(&mut Self)>(&mut self, mut f: F) {
2343        let mut worklist = vec![self];
2344        while let Some(expr) = worklist.pop() {
2345            f(expr);
2346            worklist.extend(expr.children_mut().rev());
2347        }
2348    }
2349
2350    /// Return a vector of references to the subtrees of this expression
2351    /// in post-visit order (the last element is `&self`).
2352    pub fn post_order_vec(&self) -> Vec<&Self> {
2353        let mut stack = vec![self];
2354        let mut result = vec![];
2355        while let Some(expr) = stack.pop() {
2356            result.push(expr);
2357            stack.extend(expr.children());
2358        }
2359        result.reverse();
2360        result
2361    }
2362}
2363
2364impl VisitChildren<Self> for MirRelationExpr {
2365    fn visit_children<F>(&self, mut f: F)
2366    where
2367        F: FnMut(&Self),
2368    {
2369        for child in self.children() {
2370            f(child)
2371        }
2372    }
2373
2374    fn visit_mut_children<F>(&mut self, mut f: F)
2375    where
2376        F: FnMut(&mut Self),
2377    {
2378        for child in self.children_mut() {
2379            f(child)
2380        }
2381    }
2382
2383    fn try_visit_children<F, E>(&self, mut f: F) -> Result<(), E>
2384    where
2385        F: FnMut(&Self) -> Result<(), E>,
2386        E: From<RecursionLimitError>,
2387    {
2388        for child in self.children() {
2389            f(child)?
2390        }
2391        Ok(())
2392    }
2393
2394    fn try_visit_mut_children<F, E>(&mut self, mut f: F) -> Result<(), E>
2395    where
2396        F: FnMut(&mut Self) -> Result<(), E>,
2397        E: From<RecursionLimitError>,
2398    {
2399        for child in self.children_mut() {
2400            f(child)?
2401        }
2402        Ok(())
2403    }
2404}
2405
2406/// Specification for an ordering by a column.
2407#[derive(
2408    Debug,
2409    Clone,
2410    Copy,
2411    Eq,
2412    PartialEq,
2413    Ord,
2414    PartialOrd,
2415    Serialize,
2416    Deserialize,
2417    Hash,
2418    MzReflect
2419)]
2420pub struct ColumnOrder {
2421    /// The column index.
2422    pub column: usize,
2423    /// Whether to sort in descending order.
2424    #[serde(default)]
2425    pub desc: bool,
2426    /// Whether to sort nulls last.
2427    #[serde(default)]
2428    pub nulls_last: bool,
2429}
2430
2431impl Columnation for ColumnOrder {
2432    type InnerRegion = CopyRegion<Self>;
2433}
2434
2435impl<'a, M> fmt::Display for HumanizedExpr<'a, ColumnOrder, M>
2436where
2437    M: HumanizerMode,
2438{
2439    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
2440        // If you modify this, then please also attend to Display for ColumnOrderWithExpr!
2441        write!(
2442            f,
2443            "{} {} {}",
2444            self.child(&self.expr.column),
2445            if self.expr.desc { "desc" } else { "asc" },
2446            if self.expr.nulls_last {
2447                "nulls_last"
2448            } else {
2449                "nulls_first"
2450            },
2451        )
2452    }
2453}
2454
2455/// Describes an aggregation expression.
2456#[derive(
2457    Clone,
2458    Debug,
2459    Eq,
2460    PartialEq,
2461    Ord,
2462    PartialOrd,
2463    Serialize,
2464    Deserialize,
2465    Hash,
2466    MzReflect
2467)]
2468pub struct AggregateExpr {
2469    /// Names the aggregation function.
2470    pub func: AggregateFunc,
2471    /// An expression which extracts from each row the input to `func`.
2472    pub expr: MirScalarExpr,
2473    /// Should the aggregation be applied only to distinct results in each group.
2474    #[serde(default)]
2475    pub distinct: bool,
2476}
2477
2478impl AggregateExpr {
2479    /// Computes the type of this `AggregateExpr`.
2480    pub fn sql_typ(&self, column_types: &[SqlColumnType]) -> SqlColumnType {
2481        self.func.output_sql_type(self.expr.sql_typ(column_types))
2482    }
2483
2484    /// Computes the type of this `AggregateExpr`.
2485    pub fn typ(&self, column_types: &[ReprColumnType]) -> ReprColumnType {
2486        self.func.output_type(self.expr.typ(column_types))
2487    }
2488
2489    /// Returns whether the expression has a constant result.
2490    pub fn is_constant(&self) -> bool {
2491        match self.func {
2492            AggregateFunc::MaxInt16
2493            | AggregateFunc::MaxInt32
2494            | AggregateFunc::MaxInt64
2495            | AggregateFunc::MaxUInt16
2496            | AggregateFunc::MaxUInt32
2497            | AggregateFunc::MaxUInt64
2498            | AggregateFunc::MaxMzTimestamp
2499            | AggregateFunc::MaxFloat32
2500            | AggregateFunc::MaxFloat64
2501            | AggregateFunc::MaxBool
2502            | AggregateFunc::MaxString
2503            | AggregateFunc::MaxDate
2504            | AggregateFunc::MaxTimestamp
2505            | AggregateFunc::MaxTimestampTz
2506            | AggregateFunc::MinInt16
2507            | AggregateFunc::MinInt32
2508            | AggregateFunc::MinInt64
2509            | AggregateFunc::MinUInt16
2510            | AggregateFunc::MinUInt32
2511            | AggregateFunc::MinUInt64
2512            | AggregateFunc::MinMzTimestamp
2513            | AggregateFunc::MinFloat32
2514            | AggregateFunc::MinFloat64
2515            | AggregateFunc::MinBool
2516            | AggregateFunc::MinString
2517            | AggregateFunc::MinDate
2518            | AggregateFunc::MinTimestamp
2519            | AggregateFunc::MinTimestampTz
2520            | AggregateFunc::Any
2521            | AggregateFunc::All
2522            | AggregateFunc::Dummy => self.expr.is_literal(),
2523            AggregateFunc::Count => self.expr.is_literal_null(),
2524            _ => self.expr.is_literal_err(),
2525        }
2526    }
2527
2528    /// Returns an expression that computes `self` on a group that has exactly one row.
2529    /// Instead of performing a `Reduce` with `self`, one can perform a `Map` with the expression
2530    /// returned by `on_unique`, which is cheaper. (See `ReduceElision`.)
2531    pub fn on_unique(&self, input_type: &[ReprColumnType]) -> MirScalarExpr {
2532        match &self.func {
2533            // Count is one if non-null, and zero if null.
2534            AggregateFunc::Count => self
2535                .expr
2536                .clone()
2537                .call_unary(UnaryFunc::IsNull(crate::func::IsNull))
2538                .if_then_else(
2539                    MirScalarExpr::literal_ok(Datum::Int64(0), ReprScalarType::Int64),
2540                    MirScalarExpr::literal_ok(Datum::Int64(1), ReprScalarType::Int64),
2541                ),
2542
2543            // SumInt16 takes Int16s as input, but outputs Int64s.
2544            AggregateFunc::SumInt16 => self
2545                .expr
2546                .clone()
2547                .call_unary(UnaryFunc::CastInt16ToInt64(scalar_func::CastInt16ToInt64)),
2548
2549            // SumInt32 takes Int32s as input, but outputs Int64s.
2550            AggregateFunc::SumInt32 => self
2551                .expr
2552                .clone()
2553                .call_unary(UnaryFunc::CastInt32ToInt64(scalar_func::CastInt32ToInt64)),
2554
2555            // SumInt64 takes Int64s as input, but outputs numerics.
2556            AggregateFunc::SumInt64 => self.expr.clone().call_unary(UnaryFunc::CastInt64ToNumeric(
2557                scalar_func::CastInt64ToNumeric(Some(NumericMaxScale::ZERO)),
2558            )),
2559
2560            // SumUInt16 takes UInt16s as input, but outputs UInt64s.
2561            AggregateFunc::SumUInt16 => self.expr.clone().call_unary(
2562                UnaryFunc::CastUint16ToUint64(scalar_func::CastUint16ToUint64),
2563            ),
2564
2565            // SumUInt32 takes UInt32s as input, but outputs UInt64s.
2566            AggregateFunc::SumUInt32 => self.expr.clone().call_unary(
2567                UnaryFunc::CastUint32ToUint64(scalar_func::CastUint32ToUint64),
2568            ),
2569
2570            // SumUInt64 takes UInt64s as input, but outputs numerics.
2571            AggregateFunc::SumUInt64 => {
2572                self.expr.clone().call_unary(UnaryFunc::CastUint64ToNumeric(
2573                    scalar_func::CastUint64ToNumeric(Some(NumericMaxScale::ZERO)),
2574                ))
2575            }
2576
2577            // JsonbAgg takes _anything_ as input, but must output a Jsonb array.
2578            AggregateFunc::JsonbAgg { .. } => MirScalarExpr::call_variadic(
2579                JsonbBuildArray,
2580                vec![
2581                    self.expr
2582                        .clone()
2583                        .call_unary(UnaryFunc::RecordGet(scalar_func::RecordGet(0))),
2584                ],
2585            ),
2586
2587            // JsonbAgg takes _anything_ as input, but must output a Jsonb object.
2588            AggregateFunc::JsonbObjectAgg { .. } => {
2589                let record = self
2590                    .expr
2591                    .clone()
2592                    .call_unary(UnaryFunc::RecordGet(scalar_func::RecordGet(0)));
2593                MirScalarExpr::call_variadic(
2594                    JsonbBuildObject,
2595                    (0..2)
2596                        .map(|i| {
2597                            record
2598                                .clone()
2599                                .call_unary(UnaryFunc::RecordGet(scalar_func::RecordGet(i)))
2600                        })
2601                        .collect(),
2602                )
2603            }
2604
2605            AggregateFunc::MapAgg { value_type, .. } => {
2606                let record = self
2607                    .expr
2608                    .clone()
2609                    .call_unary(UnaryFunc::RecordGet(scalar_func::RecordGet(0)));
2610                MirScalarExpr::call_variadic(
2611                    MapBuild {
2612                        value_type: value_type.clone(),
2613                    },
2614                    (0..2)
2615                        .map(|i| {
2616                            record
2617                                .clone()
2618                                .call_unary(UnaryFunc::RecordGet(scalar_func::RecordGet(i)))
2619                        })
2620                        .collect(),
2621                )
2622            }
2623
2624            // StringAgg takes nested records of strings and outputs a string
2625            AggregateFunc::StringAgg { .. } => self
2626                .expr
2627                .clone()
2628                .call_unary(UnaryFunc::RecordGet(scalar_func::RecordGet(0)))
2629                .call_unary(UnaryFunc::RecordGet(scalar_func::RecordGet(0))),
2630
2631            // ListConcat and ArrayConcat take a single level of records and output a list containing exactly 1 element
2632            AggregateFunc::ListConcat { .. } | AggregateFunc::ArrayConcat { .. } => self
2633                .expr
2634                .clone()
2635                .call_unary(UnaryFunc::RecordGet(scalar_func::RecordGet(0))),
2636
2637            // RowNumber, Rank, DenseRank take a list of records and output a list containing exactly 1 element
2638            AggregateFunc::RowNumber { .. } => {
2639                self.on_unique_ranking_window_funcs(input_type, "?row_number?")
2640            }
2641            AggregateFunc::Rank { .. } => self.on_unique_ranking_window_funcs(input_type, "?rank?"),
2642            AggregateFunc::DenseRank { .. } => {
2643                self.on_unique_ranking_window_funcs(input_type, "?dense_rank?")
2644            }
2645
2646            // The input type for LagLead is ((OriginalRow, (InputValue, Offset, Default)), OrderByExprs...)
2647            AggregateFunc::LagLead { lag_lead, .. } => {
2648                let tuple = self
2649                    .expr
2650                    .clone()
2651                    .call_unary(UnaryFunc::RecordGet(scalar_func::RecordGet(0)));
2652
2653                // Get the overall return type
2654                let return_type_with_orig_row = self
2655                    .typ(input_type)
2656                    .scalar_type
2657                    .unwrap_list_element_type()
2658                    .clone();
2659                let lag_lead_return_type =
2660                    return_type_with_orig_row.unwrap_record_element_type()[0].clone();
2661
2662                // Extract the original row
2663                let original_row = tuple
2664                    .clone()
2665                    .call_unary(UnaryFunc::RecordGet(scalar_func::RecordGet(0)));
2666
2667                // Extract the encoded args
2668                let encoded_args =
2669                    tuple.call_unary(UnaryFunc::RecordGet(scalar_func::RecordGet(1)));
2670
2671                let (result_expr, column_name) =
2672                    Self::on_unique_lag_lead(lag_lead, encoded_args, lag_lead_return_type.clone());
2673
2674                MirScalarExpr::call_variadic(
2675                    ListCreate {
2676                        elem_type: SqlScalarType::from_repr(&return_type_with_orig_row),
2677                    },
2678                    vec![MirScalarExpr::call_variadic(
2679                        RecordCreate {
2680                            field_names: vec![column_name, ColumnName::from("?record?")],
2681                        },
2682                        vec![result_expr, original_row],
2683                    )],
2684                )
2685            }
2686
2687            // The input type for FirstValue is ((OriginalRow, InputValue), OrderByExprs...)
2688            AggregateFunc::FirstValue { window_frame, .. } => {
2689                let tuple = self
2690                    .expr
2691                    .clone()
2692                    .call_unary(UnaryFunc::RecordGet(scalar_func::RecordGet(0)));
2693
2694                // Get the overall return type
2695                let return_type_with_orig_row = self
2696                    .typ(input_type)
2697                    .scalar_type
2698                    .unwrap_list_element_type()
2699                    .clone();
2700                let first_value_return_type =
2701                    return_type_with_orig_row.unwrap_record_element_type()[0].clone();
2702
2703                // Extract the original row
2704                let original_row = tuple
2705                    .clone()
2706                    .call_unary(UnaryFunc::RecordGet(scalar_func::RecordGet(0)));
2707
2708                // Extract the input value
2709                let arg = tuple.call_unary(UnaryFunc::RecordGet(scalar_func::RecordGet(1)));
2710
2711                let (result_expr, column_name) = Self::on_unique_first_value_last_value(
2712                    window_frame,
2713                    arg,
2714                    first_value_return_type,
2715                );
2716
2717                MirScalarExpr::call_variadic(
2718                    ListCreate {
2719                        elem_type: SqlScalarType::from_repr(&return_type_with_orig_row),
2720                    },
2721                    vec![MirScalarExpr::call_variadic(
2722                        RecordCreate {
2723                            field_names: vec![column_name, ColumnName::from("?record?")],
2724                        },
2725                        vec![result_expr, original_row],
2726                    )],
2727                )
2728            }
2729
2730            // The input type for LastValue is ((OriginalRow, InputValue), OrderByExprs...)
2731            AggregateFunc::LastValue { window_frame, .. } => {
2732                let tuple = self
2733                    .expr
2734                    .clone()
2735                    .call_unary(UnaryFunc::RecordGet(scalar_func::RecordGet(0)));
2736
2737                // Get the overall return type
2738                let return_type_with_orig_row = self
2739                    .typ(input_type)
2740                    .scalar_type
2741                    .unwrap_list_element_type()
2742                    .clone();
2743                let last_value_return_type =
2744                    return_type_with_orig_row.unwrap_record_element_type()[0].clone();
2745
2746                // Extract the original row
2747                let original_row = tuple
2748                    .clone()
2749                    .call_unary(UnaryFunc::RecordGet(scalar_func::RecordGet(0)));
2750
2751                // Extract the input value
2752                let arg = tuple.call_unary(UnaryFunc::RecordGet(scalar_func::RecordGet(1)));
2753
2754                let (result_expr, column_name) = Self::on_unique_first_value_last_value(
2755                    window_frame,
2756                    arg,
2757                    last_value_return_type,
2758                );
2759
2760                MirScalarExpr::call_variadic(
2761                    ListCreate {
2762                        elem_type: SqlScalarType::from_repr(&return_type_with_orig_row),
2763                    },
2764                    vec![MirScalarExpr::call_variadic(
2765                        RecordCreate {
2766                            field_names: vec![column_name, ColumnName::from("?record?")],
2767                        },
2768                        vec![result_expr, original_row],
2769                    )],
2770                )
2771            }
2772
2773            // The input type for window aggs is ((OriginalRow, InputValue), OrderByExprs...)
2774            // See an example MIR in `window_func_applied_to`.
2775            AggregateFunc::WindowAggregate {
2776                wrapped_aggregate,
2777                window_frame,
2778                order_by: _,
2779            } => {
2780                // TODO: deduplicate code between the various window function cases.
2781
2782                let tuple = self
2783                    .expr
2784                    .clone()
2785                    .call_unary(UnaryFunc::RecordGet(scalar_func::RecordGet(0)));
2786
2787                // Get the overall return type
2788                let return_type = self
2789                    .typ(input_type)
2790                    .scalar_type
2791                    .unwrap_list_element_type()
2792                    .clone();
2793                let window_agg_return_type = return_type.unwrap_record_element_type()[0].clone();
2794
2795                // Extract the original row
2796                let original_row = tuple
2797                    .clone()
2798                    .call_unary(UnaryFunc::RecordGet(scalar_func::RecordGet(0)));
2799
2800                // Extract the input value
2801                let arg_expr = tuple.call_unary(UnaryFunc::RecordGet(scalar_func::RecordGet(1)));
2802
2803                let (result, column_name) = Self::on_unique_window_agg(
2804                    window_frame,
2805                    arg_expr,
2806                    input_type,
2807                    window_agg_return_type,
2808                    wrapped_aggregate,
2809                );
2810
2811                MirScalarExpr::call_variadic(
2812                    ListCreate {
2813                        elem_type: SqlScalarType::from_repr(&return_type),
2814                    },
2815                    vec![MirScalarExpr::call_variadic(
2816                        RecordCreate {
2817                            field_names: vec![column_name, ColumnName::from("?record?")],
2818                        },
2819                        vec![result, original_row],
2820                    )],
2821                )
2822            }
2823
2824            // The input type is ((OriginalRow, (Arg1, Arg2, ...)), OrderByExprs...)
2825            AggregateFunc::FusedWindowAggregate {
2826                wrapped_aggregates,
2827                order_by: _,
2828                window_frame,
2829            } => {
2830                // Throw away OrderByExprs
2831                let tuple = self
2832                    .expr
2833                    .clone()
2834                    .call_unary(UnaryFunc::RecordGet(scalar_func::RecordGet(0)));
2835
2836                // Extract the original row
2837                let original_row = tuple
2838                    .clone()
2839                    .call_unary(UnaryFunc::RecordGet(scalar_func::RecordGet(0)));
2840
2841                // Extract the args of the fused call
2842                let all_args = tuple.call_unary(UnaryFunc::RecordGet(scalar_func::RecordGet(1)));
2843
2844                let return_type_with_orig_row = self
2845                    .typ(input_type)
2846                    .scalar_type
2847                    .unwrap_list_element_type()
2848                    .clone();
2849
2850                let all_func_return_types =
2851                    return_type_with_orig_row.unwrap_record_element_type()[0].clone();
2852                let mut func_result_exprs = Vec::new();
2853                let mut col_names = Vec::new();
2854                for (idx, wrapped_aggr) in wrapped_aggregates.iter().enumerate() {
2855                    let arg = all_args
2856                        .clone()
2857                        .call_unary(UnaryFunc::RecordGet(scalar_func::RecordGet(idx)));
2858                    let return_type =
2859                        all_func_return_types.unwrap_record_element_type()[idx].clone();
2860                    let (result, column_name) = Self::on_unique_window_agg(
2861                        window_frame,
2862                        arg,
2863                        input_type,
2864                        return_type,
2865                        wrapped_aggr,
2866                    );
2867                    func_result_exprs.push(result);
2868                    col_names.push(column_name);
2869                }
2870
2871                MirScalarExpr::call_variadic(
2872                    ListCreate {
2873                        elem_type: SqlScalarType::from_repr(&return_type_with_orig_row),
2874                    },
2875                    vec![MirScalarExpr::call_variadic(
2876                        RecordCreate {
2877                            field_names: vec![
2878                                ColumnName::from("?fused_window_aggr?"),
2879                                ColumnName::from("?record?"),
2880                            ],
2881                        },
2882                        vec![
2883                            MirScalarExpr::call_variadic(
2884                                RecordCreate {
2885                                    field_names: col_names,
2886                                },
2887                                func_result_exprs,
2888                            ),
2889                            original_row,
2890                        ],
2891                    )],
2892                )
2893            }
2894
2895            // The input type is ((OriginalRow, (Args1, Args2, ...)), OrderByExprs...)
2896            AggregateFunc::FusedValueWindowFunc {
2897                funcs,
2898                order_by: outer_order_by,
2899            } => {
2900                // Throw away OrderByExprs
2901                let tuple = self
2902                    .expr
2903                    .clone()
2904                    .call_unary(UnaryFunc::RecordGet(scalar_func::RecordGet(0)));
2905
2906                // Extract the original row
2907                let original_row = tuple
2908                    .clone()
2909                    .call_unary(UnaryFunc::RecordGet(scalar_func::RecordGet(0)));
2910
2911                // Extract the encoded args of the fused call
2912                let all_encoded_args =
2913                    tuple.call_unary(UnaryFunc::RecordGet(scalar_func::RecordGet(1)));
2914
2915                let return_type_with_orig_row = self
2916                    .typ(input_type)
2917                    .scalar_type
2918                    .unwrap_list_element_type()
2919                    .clone();
2920
2921                let all_func_return_types =
2922                    return_type_with_orig_row.unwrap_record_element_type()[0].clone();
2923                let mut func_result_exprs = Vec::new();
2924                let mut col_names = Vec::new();
2925                for (idx, func) in funcs.iter().enumerate() {
2926                    let args_for_func = all_encoded_args
2927                        .clone()
2928                        .call_unary(UnaryFunc::RecordGet(scalar_func::RecordGet(idx)));
2929                    let return_type_for_func =
2930                        all_func_return_types.unwrap_record_element_type()[idx].clone();
2931                    let (result, column_name) = match func {
2932                        AggregateFunc::LagLead {
2933                            lag_lead,
2934                            order_by,
2935                            ignore_nulls: _,
2936                        } => {
2937                            assert_eq!(order_by, outer_order_by);
2938                            Self::on_unique_lag_lead(lag_lead, args_for_func, return_type_for_func)
2939                        }
2940                        AggregateFunc::FirstValue {
2941                            window_frame,
2942                            order_by,
2943                        } => {
2944                            assert_eq!(order_by, outer_order_by);
2945                            Self::on_unique_first_value_last_value(
2946                                window_frame,
2947                                args_for_func,
2948                                return_type_for_func,
2949                            )
2950                        }
2951                        AggregateFunc::LastValue {
2952                            window_frame,
2953                            order_by,
2954                        } => {
2955                            assert_eq!(order_by, outer_order_by);
2956                            Self::on_unique_first_value_last_value(
2957                                window_frame,
2958                                args_for_func,
2959                                return_type_for_func,
2960                            )
2961                        }
2962                        _ => panic!("unknown function in FusedValueWindowFunc"),
2963                    };
2964                    func_result_exprs.push(result);
2965                    col_names.push(column_name);
2966                }
2967
2968                MirScalarExpr::call_variadic(
2969                    ListCreate {
2970                        elem_type: SqlScalarType::from_repr(&return_type_with_orig_row),
2971                    },
2972                    vec![MirScalarExpr::call_variadic(
2973                        RecordCreate {
2974                            field_names: vec![
2975                                ColumnName::from("?fused_value_window_func?"),
2976                                ColumnName::from("?record?"),
2977                            ],
2978                        },
2979                        vec![
2980                            MirScalarExpr::call_variadic(
2981                                RecordCreate {
2982                                    field_names: col_names,
2983                                },
2984                                func_result_exprs,
2985                            ),
2986                            original_row,
2987                        ],
2988                    )],
2989                )
2990            }
2991
2992            // All other variants should return the argument to the aggregation.
2993            AggregateFunc::MaxNumeric
2994            | AggregateFunc::MaxInt16
2995            | AggregateFunc::MaxInt32
2996            | AggregateFunc::MaxInt64
2997            | AggregateFunc::MaxUInt16
2998            | AggregateFunc::MaxUInt32
2999            | AggregateFunc::MaxUInt64
3000            | AggregateFunc::MaxMzTimestamp
3001            | AggregateFunc::MaxFloat32
3002            | AggregateFunc::MaxFloat64
3003            | AggregateFunc::MaxBool
3004            | AggregateFunc::MaxString
3005            | AggregateFunc::MaxDate
3006            | AggregateFunc::MaxTimestamp
3007            | AggregateFunc::MaxTimestampTz
3008            | AggregateFunc::MaxInterval
3009            | AggregateFunc::MaxTime
3010            | AggregateFunc::MinNumeric
3011            | AggregateFunc::MinInt16
3012            | AggregateFunc::MinInt32
3013            | AggregateFunc::MinInt64
3014            | AggregateFunc::MinUInt16
3015            | AggregateFunc::MinUInt32
3016            | AggregateFunc::MinUInt64
3017            | AggregateFunc::MinMzTimestamp
3018            | AggregateFunc::MinFloat32
3019            | AggregateFunc::MinFloat64
3020            | AggregateFunc::MinBool
3021            | AggregateFunc::MinString
3022            | AggregateFunc::MinDate
3023            | AggregateFunc::MinTimestamp
3024            | AggregateFunc::MinTimestampTz
3025            | AggregateFunc::MinInterval
3026            | AggregateFunc::MinTime
3027            | AggregateFunc::SumFloat32
3028            | AggregateFunc::SumFloat64
3029            | AggregateFunc::SumNumeric
3030            | AggregateFunc::Any
3031            | AggregateFunc::All
3032            | AggregateFunc::Dummy => self.expr.clone(),
3033        }
3034    }
3035
3036    /// `on_unique` for ROW_NUMBER, RANK, DENSE_RANK
3037    fn on_unique_ranking_window_funcs(
3038        &self,
3039        input_type: &[ReprColumnType],
3040        col_name: &str,
3041    ) -> MirScalarExpr {
3042        let sql_input_type: Vec<SqlColumnType> =
3043            input_type.iter().map(SqlColumnType::from_repr).collect();
3044        let list = self
3045            .expr
3046            .clone()
3047            // extract the list within the record
3048            .call_unary(UnaryFunc::RecordGet(scalar_func::RecordGet(0)));
3049
3050        // extract the expression within the list
3051        let record = MirScalarExpr::call_variadic(
3052            ListIndex,
3053            vec![
3054                list,
3055                MirScalarExpr::literal_ok(Datum::Int64(1), ReprScalarType::Int64),
3056            ],
3057        );
3058
3059        MirScalarExpr::call_variadic(
3060            ListCreate {
3061                elem_type: self
3062                    .sql_typ(&sql_input_type)
3063                    .scalar_type
3064                    .unwrap_list_element_type()
3065                    .clone(),
3066            },
3067            vec![MirScalarExpr::call_variadic(
3068                RecordCreate {
3069                    field_names: vec![ColumnName::from(col_name), ColumnName::from("?record?")],
3070                },
3071                vec![
3072                    MirScalarExpr::literal_ok(Datum::Int64(1), ReprScalarType::Int64),
3073                    record,
3074                ],
3075            )],
3076        )
3077    }
3078
3079    /// `on_unique` for `lag` and `lead`
3080    fn on_unique_lag_lead(
3081        lag_lead: &LagLeadType,
3082        encoded_args: MirScalarExpr,
3083        return_type: ReprScalarType,
3084    ) -> (MirScalarExpr, ColumnName) {
3085        let expr = encoded_args
3086            .clone()
3087            .call_unary(UnaryFunc::RecordGet(scalar_func::RecordGet(0)));
3088        let offset = encoded_args
3089            .clone()
3090            .call_unary(UnaryFunc::RecordGet(scalar_func::RecordGet(1)));
3091        let default_value =
3092            encoded_args.call_unary(UnaryFunc::RecordGet(scalar_func::RecordGet(2)));
3093
3094        // In this case, the window always has only one element, so if the offset is not null and
3095        // not zero, the default value should be returned instead.
3096        let value = offset
3097            .clone()
3098            .call_binary(
3099                MirScalarExpr::literal_ok(Datum::Int32(0), ReprScalarType::Int32),
3100                crate::func::Eq,
3101            )
3102            .if_then_else(expr, default_value);
3103        let result_expr = offset
3104            .call_unary(UnaryFunc::IsNull(crate::func::IsNull))
3105            .if_then_else(MirScalarExpr::literal_null(return_type), value);
3106
3107        let column_name = ColumnName::from(match lag_lead {
3108            LagLeadType::Lag => "?lag?",
3109            LagLeadType::Lead => "?lead?",
3110        });
3111
3112        (result_expr, column_name)
3113    }
3114
3115    /// `on_unique` for `first_value` and `last_value`
3116    fn on_unique_first_value_last_value(
3117        window_frame: &WindowFrame,
3118        arg: MirScalarExpr,
3119        return_type: ReprScalarType,
3120    ) -> (MirScalarExpr, ColumnName) {
3121        // If the window frame includes the current (single) row, return its value, null otherwise
3122        let result_expr = if window_frame.includes_current_row() {
3123            arg
3124        } else {
3125            MirScalarExpr::literal_null(return_type)
3126        };
3127        (result_expr, ColumnName::from("?first_value?"))
3128    }
3129
3130    /// `on_unique` for window aggregations
3131    fn on_unique_window_agg(
3132        window_frame: &WindowFrame,
3133        arg_expr: MirScalarExpr,
3134        input_type: &[ReprColumnType],
3135        return_type: ReprScalarType,
3136        wrapped_aggr: &AggregateFunc,
3137    ) -> (MirScalarExpr, ColumnName) {
3138        // If the window frame includes the current (single) row, evaluate the wrapped aggregate on
3139        // that row. Otherwise, return the default value for the aggregate.
3140        let result_expr = if window_frame.includes_current_row() {
3141            AggregateExpr {
3142                func: wrapped_aggr.clone(),
3143                expr: arg_expr,
3144                distinct: false, // We have just one input element; DISTINCT doesn't matter.
3145            }
3146            .on_unique(input_type)
3147        } else {
3148            MirScalarExpr::literal_ok(wrapped_aggr.default(), return_type)
3149        };
3150        (result_expr, ColumnName::from("?window_agg?"))
3151    }
3152
3153    /// Returns whether the expression is COUNT(*) or not.  Note that
3154    /// when we define the count builtin in sql::func, we convert
3155    /// COUNT(*) to COUNT(true), making it indistinguishable from
3156    /// literal COUNT(true), but we prefer to consider this as the
3157    /// former.
3158    ///
3159    /// (HIR has the same `is_count_asterisk`.)
3160    pub fn is_count_asterisk(&self) -> bool {
3161        self.func == AggregateFunc::Count && self.expr.is_literal_true() && !self.distinct
3162    }
3163}
3164
/// Describe a join implementation in dataflow.
#[derive(
    Clone,
    Debug,
    Eq,
    PartialEq,
    Ord,
    PartialOrd,
    Serialize,
    Deserialize,
    Hash,
    MzReflect
)]
pub enum JoinImplementation {
    /// Perform a sequence of binary differential dataflow joins.
    ///
    /// The first argument indicates
    /// 1) the index of the starting collection,
    /// 2) if it should be arranged, the keys to arrange it by, and
    /// 3) the characteristics of the starting collection (for EXPLAINing).
    /// The sequence that follows lists other relation indexes, and the key for
    /// the arrangement we should use when joining it in.
    /// The JoinInputCharacteristics are for EXPLAINing the characteristics that
    /// were used for join ordering.
    ///
    /// Each collection index should occur exactly once, either as the starting collection
    /// or somewhere in the list.
    Differential(
        // The starting collection: (index, optional arrangement key, characteristics).
        (
            usize,
            Option<Vec<MirScalarExpr>>,
            Option<JoinInputCharacteristics>,
        ),
        // The remaining collections to join in, in order, each given as
        // (index, arrangement key, characteristics).
        Vec<(usize, Vec<MirScalarExpr>, Option<JoinInputCharacteristics>)>,
    ),
    /// Perform independent delta query dataflows for each input.
    ///
    /// The argument is a sequence of plans, for the input collections in order.
    /// Each plan starts from the corresponding index, and then in sequence joins
    /// against collections identified by index and with the specified arrangement key.
    /// The JoinInputCharacteristics are for EXPLAINing the characteristics that were
    /// used for join ordering.
    DeltaQuery(Vec<Vec<(usize, Vec<MirScalarExpr>, Option<JoinInputCharacteristics>)>>),
    /// Join a user-created index with a constant collection to speed up the evaluation of a
    /// predicate such as `(f1 = 3 AND f2 = 5) OR (f1 = 7 AND f2 = 9)`.
    /// This gets translated to a Differential join during MIR -> LIR lowering, but we still want
    /// to represent it in MIR, because the fast path detection wants to match on this.
    ///
    /// Consists of (`<coll_id>`, `<index_id>`, `<index_key>`, `<constants>`)
    IndexedFilter(
        GlobalId,
        GlobalId,
        Vec<MirScalarExpr>,
        #[mzreflect(ignore)] Vec<Row>,
    ),
    /// No implementation yet selected.
    Unimplemented,
}
3223
3224impl Default for JoinImplementation {
3225    fn default() -> Self {
3226        JoinImplementation::Unimplemented
3227    }
3228}
3229
3230impl JoinImplementation {
3231    /// Returns `true` iff the value is not [`JoinImplementation::Unimplemented`].
3232    pub fn is_implemented(&self) -> bool {
3233        match self {
3234            Self::Unimplemented => false,
3235            _ => true,
3236        }
3237    }
3238
3239    /// Returns an optional implementation name if the value is not [`JoinImplementation::Unimplemented`].
3240    pub fn name(&self) -> Option<&'static str> {
3241        match self {
3242            Self::Differential(..) => Some("differential"),
3243            Self::DeltaQuery(..) => Some("delta"),
3244            Self::IndexedFilter(..) => Some("indexed_filter"),
3245            Self::Unimplemented => None,
3246        }
3247    }
3248}
3249
/// Characteristics of a join order candidate collection.
///
/// A candidate is described by a collection and a key, and may have various liabilities.
/// Primarily, the candidate may risk substantial inflation of records, which is something
/// that concerns us greatly. Additionally, the candidate may be unarranged, and we would
/// prefer candidates that do not require additional memory. Finally, we prefer lower id
/// collections in the interest of consistent tie-breaking. For more characteristics, see
/// comments on individual fields.
///
/// This has more than one version. `new` instantiates the appropriate version based on a
/// feature flag.
///
/// NOTE(review): the derived `Ord` orders every `V1` before every `V2`;
/// presumably only same-version values are ever compared within one join
/// planning run — confirm at call sites.
#[derive(
    Eq,
    PartialEq,
    Ord,
    PartialOrd,
    Debug,
    Clone,
    Serialize,
    Deserialize,
    Hash,
    MzReflect
)]
pub enum JoinInputCharacteristics {
    /// Old version, with `enable_join_prioritize_arranged` turned off.
    V1(JoinInputCharacteristicsV1),
    /// Newer version, with `enable_join_prioritize_arranged` turned on.
    V2(JoinInputCharacteristicsV2),
}
3279
3280impl JoinInputCharacteristics {
3281    /// Creates a new instance with the given characteristics.
3282    pub fn new(
3283        unique_key: bool,
3284        key_length: usize,
3285        arranged: bool,
3286        cardinality: Option<usize>,
3287        filters: FilterCharacteristics,
3288        input: usize,
3289        enable_join_prioritize_arranged: bool,
3290    ) -> Self {
3291        if enable_join_prioritize_arranged {
3292            Self::V2(JoinInputCharacteristicsV2::new(
3293                unique_key,
3294                key_length,
3295                arranged,
3296                cardinality,
3297                filters,
3298                input,
3299            ))
3300        } else {
3301            Self::V1(JoinInputCharacteristicsV1::new(
3302                unique_key,
3303                key_length,
3304                arranged,
3305                cardinality,
3306                filters,
3307                input,
3308            ))
3309        }
3310    }
3311
3312    /// Turns the instance into a String to be printed in EXPLAIN.
3313    pub fn explain(&self) -> String {
3314        match self {
3315            Self::V1(jic) => jic.explain(),
3316            Self::V2(jic) => jic.explain(),
3317        }
3318    }
3319
3320    /// Whether the join input described by `self` is arranged.
3321    pub fn arranged(&self) -> bool {
3322        match self {
3323            Self::V1(jic) => jic.arranged,
3324            Self::V2(jic) => jic.arranged,
3325        }
3326    }
3327
3328    /// Returns the `FilterCharacteristics` for the join input described by `self`.
3329    pub fn filters(&mut self) -> &mut FilterCharacteristics {
3330        match self {
3331            Self::V1(jic) => &mut jic.filters,
3332            Self::V2(jic) => &mut jic.filters,
3333        }
3334    }
3335}
3336
/// Newer version of `JoinInputCharacteristics`, with `enable_join_prioritize_arranged` turned on.
///
/// Note: field declaration order matters, because the derived `Ord`/`PartialOrd`
/// compare fields lexicographically in that order — earlier fields dominate.
#[derive(
    Eq,
    PartialEq,
    Ord,
    PartialOrd,
    Debug,
    Clone,
    Serialize,
    Deserialize,
    Hash,
    MzReflect
)]
pub struct JoinInputCharacteristicsV2 {
    /// An excellent indication that record count will not increase.
    pub unique_key: bool,
    /// Cross joins are bad.
    /// (`key_length > 0` also implies that it is not a cross join. However, we need to note cross
    /// joins in a separate field, because not being a cross join is more important than `arranged`,
    /// but otherwise `key_length` is less important than `arranged`.)
    pub not_cross: bool,
    /// Indicates that there will be no additional in-memory footprint.
    pub arranged: bool,
    /// A weaker signal that record count will not increase.
    pub key_length: usize,
    /// Estimated cardinality (lower is better; stored reversed so the derived
    /// ordering prefers smaller estimates)
    pub cardinality: Option<std::cmp::Reverse<usize>>,
    /// Characteristics of the filter that is applied at this input.
    pub filters: FilterCharacteristics,
    /// We want to prefer input earlier in the input list, for stability of ordering.
    pub input: std::cmp::Reverse<usize>,
}
3369
3370impl JoinInputCharacteristicsV2 {
3371    /// Creates a new instance with the given characteristics.
3372    pub fn new(
3373        unique_key: bool,
3374        key_length: usize,
3375        arranged: bool,
3376        cardinality: Option<usize>,
3377        filters: FilterCharacteristics,
3378        input: usize,
3379    ) -> Self {
3380        Self {
3381            unique_key,
3382            not_cross: key_length > 0,
3383            arranged,
3384            key_length,
3385            cardinality: cardinality.map(std::cmp::Reverse),
3386            filters,
3387            input: std::cmp::Reverse(input),
3388        }
3389    }
3390
3391    /// Turns the instance into a String to be printed in EXPLAIN.
3392    pub fn explain(&self) -> String {
3393        let mut e = "".to_owned();
3394        if self.unique_key {
3395            e.push_str("U");
3396        }
3397        // Don't need to print `not_cross`, because that is visible in the printed key.
3398        // if !self.not_cross {
3399        //     e.push_str("C");
3400        // }
3401        for _ in 0..self.key_length {
3402            e.push_str("K");
3403        }
3404        if self.arranged {
3405            e.push_str("A");
3406        }
3407        if let Some(std::cmp::Reverse(cardinality)) = self.cardinality {
3408            e.push_str(&format!("|{cardinality}|"));
3409        }
3410        e.push_str(&self.filters.explain());
3411        e
3412    }
3413}
3414
/// Old version of `JoinInputCharacteristics`, with `enable_join_prioritize_arranged` turned off.
///
/// Note: field declaration order matters, because the derived `Ord`/`PartialOrd`
/// compare fields lexicographically in that order — earlier fields dominate.
#[derive(
    Eq,
    PartialEq,
    Ord,
    PartialOrd,
    Debug,
    Clone,
    Serialize,
    Deserialize,
    Hash,
    MzReflect
)]
pub struct JoinInputCharacteristicsV1 {
    /// An excellent indication that record count will not increase.
    pub unique_key: bool,
    /// A weaker signal that record count will not increase.
    pub key_length: usize,
    /// Indicates that there will be no additional in-memory footprint.
    pub arranged: bool,
    /// Estimated cardinality (lower is better; stored reversed so the derived
    /// ordering prefers smaller estimates)
    pub cardinality: Option<std::cmp::Reverse<usize>>,
    /// Characteristics of the filter that is applied at this input.
    pub filters: FilterCharacteristics,
    /// We want to prefer input earlier in the input list, for stability of ordering.
    pub input: std::cmp::Reverse<usize>,
}
3442
3443impl JoinInputCharacteristicsV1 {
3444    /// Creates a new instance with the given characteristics.
3445    pub fn new(
3446        unique_key: bool,
3447        key_length: usize,
3448        arranged: bool,
3449        cardinality: Option<usize>,
3450        filters: FilterCharacteristics,
3451        input: usize,
3452    ) -> Self {
3453        Self {
3454            unique_key,
3455            key_length,
3456            arranged,
3457            cardinality: cardinality.map(std::cmp::Reverse),
3458            filters,
3459            input: std::cmp::Reverse(input),
3460        }
3461    }
3462
3463    /// Turns the instance into a String to be printed in EXPLAIN.
3464    pub fn explain(&self) -> String {
3465        let mut e = "".to_owned();
3466        if self.unique_key {
3467            e.push_str("U");
3468        }
3469        for _ in 0..self.key_length {
3470            e.push_str("K");
3471        }
3472        if self.arranged {
3473            e.push_str("A");
3474        }
3475        if let Some(std::cmp::Reverse(cardinality)) = self.cardinality {
3476            e.push_str(&format!("|{cardinality}|"));
3477        }
3478        e.push_str(&self.filters.explain());
3479        e
3480    }
3481}
3482
/// Instructions for finishing the result of a query.
///
/// The primary reason for the existence of this structure and attendant code
/// is that SQL's ORDER BY requires sorting rows (as already implied by the
/// keywords), whereas much of the rest of SQL is defined in terms of unordered
/// multisets. But as it turns out, the same idea can be used to optimize
/// trivial peeks.
///
/// The generic parameters are for accommodating prepared statement parameters in
/// `limit` and `offset`: the planner can hold these fields as HirScalarExpr long enough to call
/// `bind_parameters` on them.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
pub struct RowSetFinishing<L = NonNeg<i64>, O = usize> {
    /// Order rows by the given columns.
    pub order_by: Vec<ColumnOrder>,
    /// Include only as many rows (counted after the offset is applied).
    pub limit: Option<L>,
    /// Omit as many initial rows.
    pub offset: O,
    /// Include only given columns.
    pub project: Vec<usize>,
}
3505
3506impl<L> RowSetFinishing<L> {
3507    /// Returns a trivial finishing, i.e., that does nothing to the result set.
3508    pub fn trivial(arity: usize) -> RowSetFinishing<L> {
3509        RowSetFinishing {
3510            order_by: Vec::new(),
3511            limit: None,
3512            offset: 0,
3513            project: (0..arity).collect(),
3514        }
3515    }
3516    /// True if the finishing does nothing to any result set.
3517    pub fn is_trivial(&self, arity: usize) -> bool {
3518        self.limit.is_none()
3519            && self.order_by.is_empty()
3520            && self.offset == 0
3521            && self.project.iter().copied().eq(0..arity)
3522    }
3523    /// True if the finishing does not require an ORDER BY.
3524    ///
3525    /// LIMIT and OFFSET without an ORDER BY _are_ streamable: without an
3526    /// explicit ordering we will skip an arbitrary bag of elements and return
3527    /// the first arbitrary elements in the remaining bag. The result semantics
3528    /// are still correct but maybe surprising for some users.
3529    pub fn is_streamable(&self, arity: usize) -> bool {
3530        self.order_by.is_empty() && self.project.iter().copied().eq(0..arity)
3531    }
3532}
3533
3534impl RowSetFinishing<NonNeg<i64>, usize> {
3535    /// The number of rows needed from before the finishing to evaluate the finishing:
3536    /// offset + limit.
3537    ///
3538    /// If it returns None, then we need all the rows.
3539    pub fn num_rows_needed(&self) -> Option<usize> {
3540        self.limit
3541            .as_ref()
3542            .map(|l| usize::cast_from(u64::from(l.clone())) + self.offset)
3543    }
3544}
3545
3546impl RowSetFinishing {
3547    /// Applies finishing actions to a [`RowCollection`], and reports the total
3548    /// time it took to run.
3549    ///
3550    /// Returns a [`RowCollectionIter`] that contains all of the response data, as
3551    /// well as the size of the response in bytes.
3552    pub fn finish(
3553        &self,
3554        rows: RowCollection,
3555        max_result_size: u64,
3556        max_returned_query_size: Option<u64>,
3557        duration_histogram: &Histogram,
3558    ) -> Result<(RowCollectionIter, usize), String> {
3559        let now = Instant::now();
3560        let result = self.finish_inner(rows, max_result_size, max_returned_query_size);
3561        let duration = now.elapsed();
3562        duration_histogram.observe(duration.as_secs_f64());
3563
3564        result
3565    }
3566
3567    /// Implementation for [`RowSetFinishing::finish`].
3568    fn finish_inner(
3569        &self,
3570        rows: RowCollection,
3571        max_result_size: u64,
3572        max_returned_query_size: Option<u64>,
3573    ) -> Result<(RowCollectionIter, usize), String> {
3574        // How much additional memory is required to make a sorted view.
3575        let sorted_view_mem = rows.entries().saturating_mul(std::mem::size_of::<usize>());
3576        let required_memory = rows.byte_len().saturating_add(sorted_view_mem);
3577
3578        // Bail if creating the sorted view would require us to use too much memory.
3579        if required_memory > usize::cast_from(max_result_size) {
3580            let max_bytes = ByteSize::b(max_result_size);
3581            return Err(format!("result exceeds max size of {max_bytes}",));
3582        }
3583
3584        let sorted_view = rows;
3585        let mut iter = sorted_view
3586            .into_row_iter()
3587            .apply_offset(self.offset)
3588            .with_projection(self.project.clone());
3589
3590        if let Some(limit) = self.limit {
3591            let limit = u64::from(limit);
3592            let limit = usize::cast_from(limit);
3593            iter = iter.with_limit(limit);
3594        };
3595
3596        // TODO(parkmycar): Re-think how we can calculate the total response size without
3597        // having to iterate through the entire collection of Rows, while still
3598        // respecting the LIMIT, OFFSET, and projections.
3599        //
3600        // Note: It feels a bit bad always calculating the response size, but we almost
3601        // always need it to either check the `max_returned_query_size`, or for reporting
3602        // in the query history.
3603        let response_size: usize = iter.clone().map(|row| row.data().len()).sum();
3604
3605        // Bail if we would end up returning more data to the client than they can support.
3606        if let Some(max) = max_returned_query_size {
3607            if response_size > usize::cast_from(max) {
3608                let max_bytes = ByteSize::b(max);
3609                return Err(format!("result exceeds max size of {max_bytes}"));
3610            }
3611        }
3612
3613        Ok((iter, response_size))
3614    }
3615}
3616
/// A [RowSetFinishing] that can be repeatedly applied to batches of updates (in
/// a [RowCollection]) and keeps track of the remaining limit, offset, and cap
/// on query result size.
#[derive(Debug)]
pub struct RowSetFinishingIncremental {
    /// Include only as many rows (after offset). Counts down as batches are
    /// processed.
    pub remaining_limit: Option<usize>,
    /// Omit as many rows. Counts down as batches are processed.
    pub remaining_offset: usize,
    /// The maximum allowed result size, as requested by the client.
    pub max_returned_query_size: Option<u64>,
    /// Tracks our remaining allowed budget for result size.
    pub remaining_max_returned_query_size: Option<u64>,
    /// Include only given columns.
    pub project: Vec<usize>,
}
3633
3634impl RowSetFinishingIncremental {
3635    /// Turns the given [RowSetFinishing] into a [RowSetFinishingIncremental].
3636    /// Can only be used when [is_streamable](RowSetFinishing::is_streamable) is
3637    /// `true`.
3638    ///
3639    /// # Panics
3640    ///
3641    /// Panics if the result is not streamable, that is it has an ORDER BY.
3642    pub fn new(
3643        offset: usize,
3644        limit: Option<NonNeg<i64>>,
3645        project: Vec<usize>,
3646        max_returned_query_size: Option<u64>,
3647    ) -> Self {
3648        let limit = limit.map(|l| {
3649            let l = u64::from(l);
3650            let l = usize::cast_from(l);
3651            l
3652        });
3653
3654        RowSetFinishingIncremental {
3655            remaining_limit: limit,
3656            remaining_offset: offset,
3657            max_returned_query_size,
3658            remaining_max_returned_query_size: max_returned_query_size,
3659            project,
3660        }
3661    }
3662
3663    /// Applies finishing actions to the given [`RowCollection`], and reports
3664    /// the total time it took to run.
3665    ///
3666    /// Returns a [`RowCollectionIter`] that contains all of the response
3667    /// data.
3668    pub fn finish_incremental(
3669        &mut self,
3670        rows: RowCollection,
3671        max_result_size: u64,
3672        duration_histogram: &Histogram,
3673    ) -> Result<RowCollectionIter, String> {
3674        let now = Instant::now();
3675        let result = self.finish_incremental_inner(rows, max_result_size);
3676        let duration = now.elapsed();
3677        duration_histogram.observe(duration.as_secs_f64());
3678
3679        result
3680    }
3681
3682    fn finish_incremental_inner(
3683        &mut self,
3684        rows: RowCollection,
3685        max_result_size: u64,
3686    ) -> Result<RowCollectionIter, String> {
3687        // How much additional memory is required to make a sorted view.
3688        let sorted_view_mem = rows.entries().saturating_mul(std::mem::size_of::<usize>());
3689        let required_memory = rows.byte_len().saturating_add(sorted_view_mem);
3690
3691        // Bail if creating the sorted view would require us to use too much memory.
3692        if required_memory > usize::cast_from(max_result_size) {
3693            let max_bytes = ByteSize::b(max_result_size);
3694            return Err(format!("total result exceeds max size of {max_bytes}",));
3695        }
3696
3697        let batch_num_rows = rows.count();
3698
3699        let sorted_view = rows;
3700        let mut iter = sorted_view
3701            .into_row_iter()
3702            .apply_offset(self.remaining_offset)
3703            .with_projection(self.project.clone());
3704
3705        if let Some(limit) = self.remaining_limit {
3706            iter = iter.with_limit(limit);
3707        };
3708
3709        self.remaining_offset = self.remaining_offset.saturating_sub(batch_num_rows);
3710        if let Some(remaining_limit) = self.remaining_limit.as_mut() {
3711            *remaining_limit -= iter.count();
3712        }
3713
3714        // TODO(parkmycar): Re-think how we can calculate the total response size without
3715        // having to iterate through the entire collection of Rows, while still
3716        // respecting the LIMIT, OFFSET, and projections.
3717        //
3718        // Note: It feels a bit bad always calculating the response size, but we almost
3719        // always need it to either check the `max_returned_query_size`, or for reporting
3720        // in the query history.
3721        let response_size: usize = iter.clone().map(|row| row.data().len()).sum();
3722
3723        // Bail if we would end up returning more data to the client than they can support.
3724        if let Some(max) = self.remaining_max_returned_query_size {
3725            if response_size > usize::cast_from(max) {
3726                let max_bytes = ByteSize::b(self.max_returned_query_size.expect("known to exist"));
3727                return Err(format!("total result exceeds max size of {max_bytes}"));
3728            }
3729        }
3730
3731        Ok(iter)
3732    }
3733}
3734
3735/// Compare `left` and `right` using `order`. If that doesn't produce a strict
3736/// ordering, call `tiebreaker`.
3737pub fn compare_columns<F>(
3738    order: &[ColumnOrder],
3739    left: &[Datum],
3740    right: &[Datum],
3741    tiebreaker: F,
3742) -> Ordering
3743where
3744    F: Fn() -> Ordering,
3745{
3746    for order in order {
3747        let cmp = match (&left[order.column], &right[order.column]) {
3748            (Datum::Null, Datum::Null) => Ordering::Equal,
3749            (Datum::Null, _) => {
3750                if order.nulls_last {
3751                    Ordering::Greater
3752                } else {
3753                    Ordering::Less
3754                }
3755            }
3756            (_, Datum::Null) => {
3757                if order.nulls_last {
3758                    Ordering::Less
3759                } else {
3760                    Ordering::Greater
3761                }
3762            }
3763            (lval, rval) => {
3764                if order.desc {
3765                    rval.cmp(lval)
3766                } else {
3767                    lval.cmp(rval)
3768                }
3769            }
3770        };
3771        if cmp != Ordering::Equal {
3772            return cmp;
3773        }
3774    }
3775    tiebreaker()
3776}
3777
/// Describe a window frame, e.g. `RANGE UNBOUNDED PRECEDING` or
/// `ROWS BETWEEN 5 PRECEDING AND CURRENT ROW`.
///
/// Window frames define a subset of the partition, and only a subset of
/// window functions make use of the window frame.
#[derive(
    Debug,
    Clone,
    Eq,
    PartialEq,
    Ord,
    PartialOrd,
    Serialize,
    Deserialize,
    Hash,
    MzReflect
)]
pub struct WindowFrame {
    /// ROWS, RANGE or GROUPS
    pub units: WindowFrameUnits,
    /// Where the frame starts (relative to the current row or partition edge)
    pub start_bound: WindowFrameBound,
    /// Where the frame ends (relative to the current row or partition edge)
    pub end_bound: WindowFrameBound,
}
3803
3804impl Display for WindowFrame {
3805    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
3806        write!(
3807            f,
3808            "{} between {} and {}",
3809            self.units, self.start_bound, self.end_bound
3810        )
3811    }
3812}
3813
impl WindowFrame {
    /// Return the default window frame used when one is not explicitly defined:
    /// `RANGE` from `UNBOUNDED PRECEDING` to `CURRENT ROW`.
    // NOTE(review): this is an inherent `default`, not a `Default` trait impl;
    // callers invoke it as `WindowFrame::default()` either way.
    pub fn default() -> Self {
        WindowFrame {
            units: WindowFrameUnits::Range,
            start_bound: WindowFrameBound::UnboundedPreceding,
            end_bound: WindowFrameBound::CurrentRow,
        }
    }

    /// Whether the frame, as described by its start and end bounds, always
    /// contains the current row.
    ///
    /// The match arms assume the bound ordering restrictions that Postgres
    /// enforces at planning time (the end bound may not precede the start
    /// bound), hence the `unreachable!()` combinations.
    fn includes_current_row(&self) -> bool {
        use WindowFrameBound::*;
        match self.start_bound {
            UnboundedPreceding => match self.end_bound {
                // Frame ends before it begins; degenerate/empty.
                UnboundedPreceding => false,
                // `0 PRECEDING` is equivalent to `CURRENT ROW` as an end bound.
                OffsetPreceding(0) => true,
                OffsetPreceding(_) => false,
                CurrentRow => true,
                OffsetFollowing(_) => true,
                UnboundedFollowing => true,
            },
            // `0 PRECEDING` start is equivalent to starting at the current row.
            OffsetPreceding(0) => match self.end_bound {
                UnboundedPreceding => unreachable!(),
                OffsetPreceding(0) => true,
                // Any nonzero offsets here will create an empty window
                OffsetPreceding(_) => false,
                CurrentRow => true,
                OffsetFollowing(_) => true,
                UnboundedFollowing => true,
            },
            OffsetPreceding(_) => match self.end_bound {
                UnboundedPreceding => unreachable!(),
                // Window ends at the current row
                OffsetPreceding(0) => true,
                OffsetPreceding(_) => false,
                CurrentRow => true,
                OffsetFollowing(_) => true,
                UnboundedFollowing => true,
            },
            // Starting at the current row always includes it (end cannot precede start).
            CurrentRow => true,
            // `0 FOLLOWING` start is equivalent to starting at the current row.
            OffsetFollowing(0) => match self.end_bound {
                UnboundedPreceding => unreachable!(),
                OffsetPreceding(_) => unreachable!(),
                CurrentRow => unreachable!(),
                OffsetFollowing(_) => true,
                UnboundedFollowing => true,
            },
            // Frame starts strictly after the current row.
            OffsetFollowing(_) => match self.end_bound {
                UnboundedPreceding => unreachable!(),
                OffsetPreceding(_) => unreachable!(),
                CurrentRow => unreachable!(),
                OffsetFollowing(_) => false,
                UnboundedFollowing => false,
            },
            // Frame starts after the end of the partition; empty.
            UnboundedFollowing => false,
        }
    }
}
3872
3873/// Describe how frame bounds are interpreted
3874#[derive(
3875    Debug,
3876    Clone,
3877    Eq,
3878    PartialEq,
3879    Ord,
3880    PartialOrd,
3881    Serialize,
3882    Deserialize,
3883    Hash,
3884    MzReflect
3885)]
3886pub enum WindowFrameUnits {
3887    /// Each row is treated as the unit of work for bounds
3888    Rows,
3889    /// Each peer group is treated as the unit of work for bounds,
3890    /// and offset-based bounds use the value of the ORDER BY expression
3891    Range,
3892    /// Each peer group is treated as the unit of work for bounds.
3893    /// Groups is currently not supported, and it is rejected during planning.
3894    Groups,
3895}
3896
3897impl Display for WindowFrameUnits {
3898    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
3899        match self {
3900            WindowFrameUnits::Rows => write!(f, "rows"),
3901            WindowFrameUnits::Range => write!(f, "range"),
3902            WindowFrameUnits::Groups => write!(f, "groups"),
3903        }
3904    }
3905}
3906
/// Specifies [WindowFrame]'s `start_bound` and `end_bound`
///
/// The order between frame bounds is significant, as Postgres enforces
/// some restrictions there (the end bound may not precede the start bound),
/// which is why the variants are declared in frame order and the type
/// derives `Ord`.
#[derive(
    Debug,
    Clone,
    Serialize,
    Deserialize,
    PartialEq,
    Eq,
    Hash,
    MzReflect,
    PartialOrd,
    Ord
)]
pub enum WindowFrameBound {
    /// `UNBOUNDED PRECEDING`
    UnboundedPreceding,
    /// `<N> PRECEDING`
    OffsetPreceding(u64),
    /// `CURRENT ROW`
    CurrentRow,
    /// `<N> FOLLOWING`
    OffsetFollowing(u64),
    /// `UNBOUNDED FOLLOWING`.
    UnboundedFollowing,
}
3935
3936impl Display for WindowFrameBound {
3937    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
3938        match self {
3939            WindowFrameBound::UnboundedPreceding => write!(f, "unbounded preceding"),
3940            WindowFrameBound::OffsetPreceding(offset) => write!(f, "{} preceding", offset),
3941            WindowFrameBound::CurrentRow => write!(f, "current row"),
3942            WindowFrameBound::OffsetFollowing(offset) => write!(f, "{} following", offset),
3943            WindowFrameBound::UnboundedFollowing => write!(f, "unbounded following"),
3944        }
3945    }
3946}
3947
/// Maximum iterations for a LetRec.
#[derive(
    Debug,
    Clone,
    Copy,
    PartialEq,
    Eq,
    PartialOrd,
    Ord,
    Hash,
    Serialize,
    Deserialize
)]
pub struct LetRecLimit {
    /// Maximum number of iterations to evaluate.
    // `NonZeroU64` because a limit of 0 iterations would be meaningless.
    pub max_iters: NonZeroU64,
    /// Whether to throw an error when reaching the above limit.
    /// If true, we simply use the current contents of each Id as the final result.
    pub return_at_limit: bool,
}
3968
3969impl LetRecLimit {
3970    /// Compute the smallest limit from a Vec of `LetRecLimit`s.
3971    pub fn min_max_iter(limits: &Vec<Option<LetRecLimit>>) -> Option<u64> {
3972        limits
3973            .iter()
3974            .filter_map(|l| l.as_ref().map(|l| l.max_iters.get()))
3975            .min()
3976    }
3977
3978    /// The default value of `LetRecLimit::return_at_limit` when using the RECURSION LIMIT option of
3979    /// WMR without ERROR AT or RETURN AT.
3980    pub const RETURN_AT_LIMIT_DEFAULT: bool = false;
3981}
3982
3983impl Display for LetRecLimit {
3984    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
3985        write!(f, "[recursion_limit={}", self.max_iters)?;
3986        if self.return_at_limit != LetRecLimit::RETURN_AT_LIMIT_DEFAULT {
3987            write!(f, ", return_at_limit")?;
3988        }
3989        write!(f, "]")
3990    }
3991}
3992
/// For a global Get, this indicates whether we are going to read from Persist or from an index.
/// (See comment in MirRelationExpr::Get.)
#[derive(
    Clone,
    Debug,
    Eq,
    PartialEq,
    Ord,
    PartialOrd,
    Serialize,
    Deserialize,
    Hash
)]
pub enum AccessStrategy {
    /// It's either a local Get (a CTE), or unknown at the time.
    /// `prune_and_annotate_dataflow_index_imports` decides it for global Gets, and thus switches to
    /// one of the other variants.
    UnknownOrLocal,
    /// The Get will read from Persist.
    Persist,
    /// The Get will read from an index or indexes: (index id, how the index will be used).
    Index(Vec<(GlobalId, IndexUsageType)>),
    /// The Get will read a collection that is computed by the same dataflow, but in a different
    /// `BuildDesc` in `objects_to_build`.
    SameDataflow,
}
4019
#[cfg(test)]
mod tests {
    use mz_repr::explain::text::text_string_at;

    use crate::explain::HumanizedExplain;

    use super::*;

    /// Verify the one-line `Finish ...` rendering of a `RowSetFinishing`
    /// with an ORDER BY, a LIMIT, and a projection that collapses into a
    /// `#3..=#5` range in the output.
    #[mz_ore::test]
    fn test_row_set_finishing_as_text() {
        let finishing = RowSetFinishing {
            order_by: vec![ColumnOrder {
                column: 4,
                desc: true,
                nulls_last: true,
            }],
            limit: Some(NonNeg::try_from(7).unwrap()),
            offset: Default::default(),
            project: vec![1, 3, 4, 5],
        };

        let mode = HumanizedExplain::new(false);
        let expr = mode.expr(&finishing, None);

        let act = text_string_at(&expr, mz_ore::str::Indent::default);

        // Build the expected string piecewise.
        // NOTE(review): `FormatBuffer`'s `write!` appears infallible here (no
        // `unwrap`/`?`) — confirm against `mz_ore::fmt`.
        let exp = {
            use mz_ore::fmt::FormatBuffer;
            let mut s = String::new();
            write!(&mut s, "Finish");
            write!(&mut s, " order_by=[#4 desc nulls_last]");
            write!(&mut s, " limit=7");
            write!(&mut s, " output=[#1, #3..=#5]");
            writeln!(&mut s, "");
            s
        };

        assert_eq!(act, exp);
    }
}
4060
/// An iterator over AST structures, which calls out nodes in difference.
///
/// The iterators visit two ASTs in tandem, continuing as long as the AST node data matches,
/// and yielding an output pair as soon as the AST nodes do not match. Their intent is to call
/// attention to the moments in the ASTs where they differ, and incidentally a stack-free way
/// to compare two ASTs.
mod structured_diff {

    use super::MirRelationExpr;
    use itertools::Itertools;

    /// An iterator over structured differences between two `MirRelationExpr` instances.
    pub struct MreDiff<'a> {
        /// Pairs of expressions that must still be compared.
        todo: Vec<(&'a MirRelationExpr, &'a MirRelationExpr)>,
    }

    impl<'a> MreDiff<'a> {
        /// Create a new `MirRelationExpr` structured difference.
        pub fn new(expr1: &'a MirRelationExpr, expr2: &'a MirRelationExpr) -> Self {
            MreDiff {
                todo: vec![(expr1, expr2)],
            }
        }
    }

    impl<'a> Iterator for MreDiff<'a> {
        // Pairs of expressions that do not match.
        type Item = (&'a MirRelationExpr, &'a MirRelationExpr);

        // Pops pairs off the worklist; for each pair, either yields it (when
        // the nodes' own data differs, or the variants differ) or pushes the
        // children on for later comparison. This makes the traversal
        // stack-free: recursion depth never exceeds one node.
        fn next(&mut self) -> Option<Self::Item> {
            while let Some((expr1, expr2)) = self.todo.pop() {
                match (expr1, expr2) {
                    // Leaf: differs when rows or type differ.
                    (
                        MirRelationExpr::Constant {
                            rows: rows1,
                            typ: typ1,
                        },
                        MirRelationExpr::Constant {
                            rows: rows2,
                            typ: typ2,
                        },
                    ) => {
                        if rows1 != rows2 || typ1 != typ2 {
                            return Some((expr1, expr2));
                        }
                    }
                    // Leaf: differs when id, type, or access strategy differ.
                    (
                        MirRelationExpr::Get {
                            id: id1,
                            typ: typ1,
                            access_strategy: as1,
                        },
                        MirRelationExpr::Get {
                            id: id2,
                            typ: typ2,
                            access_strategy: as2,
                        },
                    ) => {
                        if id1 != id2 || typ1 != typ2 || as1 != as2 {
                            return Some((expr1, expr2));
                        }
                    }
                    // Same binding id: recurse into body and value.
                    (
                        MirRelationExpr::Let {
                            id: id1,
                            body: body1,
                            value: value1,
                        },
                        MirRelationExpr::Let {
                            id: id2,
                            body: body2,
                            value: value2,
                        },
                    ) => {
                        if id1 != id2 {
                            return Some((expr1, expr2));
                        } else {
                            self.todo.push((body1, body2));
                            self.todo.push((value1, value2));
                        }
                    }
                    // The length check guarantees `zip_eq` below cannot panic.
                    (
                        MirRelationExpr::LetRec {
                            ids: ids1,
                            body: body1,
                            values: values1,
                            limits: limits1,
                        },
                        MirRelationExpr::LetRec {
                            ids: ids2,
                            body: body2,
                            values: values2,
                            limits: limits2,
                        },
                    ) => {
                        if ids1 != ids2 || values1.len() != values2.len() || limits1 != limits2 {
                            return Some((expr1, expr2));
                        } else {
                            self.todo.push((body1, body2));
                            self.todo.extend(values1.iter().zip_eq(values2.iter()));
                        }
                    }
                    (
                        MirRelationExpr::Project {
                            outputs: outputs1,
                            input: input1,
                        },
                        MirRelationExpr::Project {
                            outputs: outputs2,
                            input: input2,
                        },
                    ) => {
                        if outputs1 != outputs2 {
                            return Some((expr1, expr2));
                        } else {
                            self.todo.push((input1, input2));
                        }
                    }
                    (
                        MirRelationExpr::Map {
                            scalars: scalars1,
                            input: input1,
                        },
                        MirRelationExpr::Map {
                            scalars: scalars2,
                            input: input2,
                        },
                    ) => {
                        if scalars1 != scalars2 {
                            return Some((expr1, expr2));
                        } else {
                            self.todo.push((input1, input2));
                        }
                    }
                    (
                        MirRelationExpr::Filter {
                            predicates: predicates1,
                            input: input1,
                        },
                        MirRelationExpr::Filter {
                            predicates: predicates2,
                            input: input2,
                        },
                    ) => {
                        if predicates1 != predicates2 {
                            return Some((expr1, expr2));
                        } else {
                            self.todo.push((input1, input2));
                        }
                    }
                    (
                        MirRelationExpr::FlatMap {
                            input: input1,
                            func: func1,
                            exprs: exprs1,
                        },
                        MirRelationExpr::FlatMap {
                            input: input2,
                            func: func2,
                            exprs: exprs2,
                        },
                    ) => {
                        if func1 != func2 || exprs1 != exprs2 {
                            return Some((expr1, expr2));
                        } else {
                            self.todo.push((input1, input2));
                        }
                    }
                    // The length check guarantees `zip_eq` below cannot panic.
                    (
                        MirRelationExpr::Join {
                            inputs: inputs1,
                            equivalences: eq1,
                            implementation: impl1,
                        },
                        MirRelationExpr::Join {
                            inputs: inputs2,
                            equivalences: eq2,
                            implementation: impl2,
                        },
                    ) => {
                        if inputs1.len() != inputs2.len() || eq1 != eq2 || impl1 != impl2 {
                            return Some((expr1, expr2));
                        } else {
                            self.todo.extend(inputs1.iter().zip_eq(inputs2.iter()));
                        }
                    }
                    (
                        MirRelationExpr::Reduce {
                            aggregates: aggregates1,
                            input: inputs1,
                            group_key: gk1,
                            monotonic: m1,
                            expected_group_size: egs1,
                        },
                        MirRelationExpr::Reduce {
                            aggregates: aggregates2,
                            input: inputs2,
                            group_key: gk2,
                            monotonic: m2,
                            expected_group_size: egs2,
                        },
                    ) => {
                        if aggregates1 != aggregates2 || gk1 != gk2 || m1 != m2 || egs1 != egs2 {
                            return Some((expr1, expr2));
                        } else {
                            self.todo.push((inputs1, inputs2));
                        }
                    }
                    (
                        MirRelationExpr::TopK {
                            group_key: gk1,
                            order_key: order1,
                            input: input1,
                            limit: l1,
                            offset: o1,
                            monotonic: m1,
                            expected_group_size: egs1,
                        },
                        MirRelationExpr::TopK {
                            group_key: gk2,
                            order_key: order2,
                            input: input2,
                            limit: l2,
                            offset: o2,
                            monotonic: m2,
                            expected_group_size: egs2,
                        },
                    ) => {
                        if order1 != order2
                            || gk1 != gk2
                            || l1 != l2
                            || o1 != o2
                            || m1 != m2
                            || egs1 != egs2
                        {
                            return Some((expr1, expr2));
                        } else {
                            self.todo.push((input1, input2));
                        }
                    }
                    // Structural-only nodes: always recurse into the input.
                    (
                        MirRelationExpr::Negate { input: input1 },
                        MirRelationExpr::Negate { input: input2 },
                    ) => {
                        self.todo.push((input1, input2));
                    }
                    (
                        MirRelationExpr::Threshold { input: input1 },
                        MirRelationExpr::Threshold { input: input2 },
                    ) => {
                        self.todo.push((input1, input2));
                    }
                    // The length check guarantees `zip_eq` below cannot panic.
                    (
                        MirRelationExpr::Union {
                            base: base1,
                            inputs: inputs1,
                        },
                        MirRelationExpr::Union {
                            base: base2,
                            inputs: inputs2,
                        },
                    ) => {
                        if inputs1.len() != inputs2.len() {
                            return Some((expr1, expr2));
                        } else {
                            self.todo.push((base1, base2));
                            self.todo.extend(inputs1.iter().zip_eq(inputs2.iter()));
                        }
                    }
                    (
                        MirRelationExpr::ArrangeBy {
                            keys: keys1,
                            input: input1,
                        },
                        MirRelationExpr::ArrangeBy {
                            keys: keys2,
                            input: input2,
                        },
                    ) => {
                        if keys1 != keys2 {
                            return Some((expr1, expr2));
                        } else {
                            self.todo.push((input1, input2));
                        }
                    }
                    // Different variants: the nodes themselves differ.
                    _ => {
                        return Some((expr1, expr2));
                    }
                }
            }
            None
        }
    }
}
4355}