// mz_expr/relation.rs

1// Copyright Materialize, Inc. and contributors. All rights reserved.
2//
3// Use of this software is governed by the Business Source License
4// included in the LICENSE file.
5//
6// As of the Change Date specified in that file, in accordance with
7// the Business Source License, use of this software will be governed
8// by the Apache License, Version 2.0.
9
10#![warn(missing_docs)]
11
12use std::cmp::{Ordering, max};
13use std::collections::{BTreeMap, BTreeSet};
14use std::fmt;
15use std::fmt::{Display, Formatter};
16use std::hash::{DefaultHasher, Hash, Hasher};
17use std::num::NonZeroU64;
18use std::time::Instant;
19
20use bytesize::ByteSize;
21use differential_dataflow::containers::{Columnation, CopyRegion};
22use itertools::Itertools;
23use mz_lowertest::MzReflect;
24use mz_ore::cast::CastFrom;
25use mz_ore::collections::CollectionExt;
26use mz_ore::id_gen::IdGen;
27use mz_ore::metrics::Histogram;
28use mz_ore::num::NonNeg;
29use mz_ore::soft_assert_no_log;
30use mz_ore::stack::RecursionLimitError;
31use mz_ore::str::Indent;
32use mz_repr::adt::numeric::NumericMaxScale;
33use mz_repr::explain::text::text_string_at;
34use mz_repr::explain::{
35    DummyHumanizer, ExplainConfig, ExprHumanizer, IndexUsageType, PlanRenderingContext,
36};
37use mz_repr::{
38    ColumnName, Datum, Diff, GlobalId, IntoRowIterator, ReprColumnType, ReprRelationType, Row,
39    RowIterator, SqlColumnType, SqlRelationType, SqlScalarType,
40};
41use serde::{Deserialize, Serialize};
42
43use crate::Id::Local;
44use crate::explain::{HumanizedExpr, HumanizerMode};
45use crate::relation::func::{AggregateFunc, LagLeadType, TableFunc};
46use crate::row::{RowCollection, SortedRowCollectionIter};
47use crate::visit::{Visit, VisitChildren};
48use crate::{
49    EvalError, FilterCharacteristics, Id, LocalId, MirScalarExpr, UnaryFunc, VariadicFunc,
50    func as scalar_func,
51};
52
53pub mod canonicalize;
54pub mod func;
55pub mod join_input_mapper;
56
/// A recursion limit to be used for stack-safe traversals of [`MirRelationExpr`] trees.
///
/// The recursion limit must be large enough to accommodate for the linear representation
/// of some pathological but frequently occurring query fragments.
///
/// For example, in MIR we could have long chains of
/// - (1) `Let` bindings,
/// - (2) `CallBinary` calls with associative functions such as `+`
///
/// Until we fix those, we need to stick with the larger recursion limit.
pub const RECURSION_LIMIT: usize = 2048;
68
69/// A trait for types that describe how to build a collection.
70pub trait CollectionPlan {
71    /// Collects the set of global identifiers from dataflows referenced in Get.
72    fn depends_on_into(&self, out: &mut BTreeSet<GlobalId>);
73
74    /// Returns the set of global identifiers from dataflows referenced in Get.
75    ///
76    /// See [`CollectionPlan::depends_on_into`] to reuse an existing `BTreeSet`.
77    fn depends_on(&self) -> BTreeSet<GlobalId> {
78        let mut out = BTreeSet::new();
79        self.depends_on_into(&mut out);
80        out
81    }
82}
83
/// An abstract syntax tree which defines a collection.
///
/// The AST is meant to reflect the capabilities of the `differential_dataflow::Collection` type,
/// written generically enough to avoid run-time compilation work.
///
/// `derived_hash_with_manual_eq` was complaining for the wrong reason: This lint exists because
/// it's bad when `Eq` doesn't agree with `Hash`, which is often quite likely if one of them is
/// implemented manually. However, our manual implementation of `Eq` _will_ agree with the derived
/// one. This is because the reason for the manual implementation is not to change the semantics
/// from the derived one, but to avoid stack overflows.
#[allow(clippy::derived_hash_with_manual_eq)]
#[derive(Clone, Debug, Ord, PartialOrd, Serialize, Deserialize, MzReflect, Hash)]
pub enum MirRelationExpr {
    /// A constant relation containing specified rows.
    ///
    /// The runtime memory footprint of this operator is zero.
    ///
    /// When you would like to pattern match on this, consider using `MirRelationExpr::as_const`
    /// instead, which looks behind `ArrangeBy`s. You might want this matching behavior because
    /// constant folding doesn't remove `ArrangeBy`s.
    Constant {
        /// Rows of the constant collection and their multiplicities, or the
        /// evaluation error that this collection is known to produce.
        rows: Result<Vec<(Row, Diff)>, EvalError>,
        /// Schema of the collection.
        typ: SqlRelationType,
    },
    /// Get an existing dataflow.
    ///
    /// The runtime memory footprint of this operator is zero.
    Get {
        /// The identifier for the collection to load.
        #[mzreflect(ignore)]
        id: Id,
        /// Schema of the collection.
        typ: SqlRelationType,
        /// If this is a global Get, this will indicate whether we are going to read from Persist or
        /// from an index, or from a different object in `objects_to_build`. If it's an index, then
        /// how downstream dataflow operations will use this index is also recorded. This is filled
        /// by `prune_and_annotate_dataflow_index_imports`. Note that this is not used by the
        /// lowering to LIR, but is used only by EXPLAIN.
        #[mzreflect(ignore)]
        access_strategy: AccessStrategy,
    },
    /// Introduce a temporary dataflow.
    ///
    /// The runtime memory footprint of this operator is zero.
    Let {
        /// The identifier to be used in `Get` variants to retrieve `value`.
        #[mzreflect(ignore)]
        id: LocalId,
        /// The collection to be bound to `id`.
        value: Box<MirRelationExpr>,
        /// The result of the `Let`, evaluated with `id` bound to `value`.
        body: Box<MirRelationExpr>,
    },
    /// Introduce mutually recursive bindings.
    ///
    /// Each `LocalId` is immediately bound to an initially empty collection
    /// with the type of its corresponding `MirRelationExpr`. Repeatedly, each
    /// binding is evaluated using the current contents of each other binding,
    /// and is refreshed to contain the new evaluation. This process continues
    /// through all bindings, and repeats as long as changes continue to occur.
    ///
    /// The resulting value of the expression is `body` evaluated once in the
    /// context of the final iterates.
    ///
    /// A zero-binding instance can be replaced by `body`.
    /// A single-binding instance is equivalent to `MirRelationExpr::Let`.
    ///
    /// The runtime memory footprint of this operator is zero.
    LetRec {
        /// The identifiers to be used in `Get` variants to retrieve each `value`.
        #[mzreflect(ignore)]
        ids: Vec<LocalId>,
        /// The collections to be bound to each `id`.
        values: Vec<MirRelationExpr>,
        /// Maximum number of iterations, after which we should artificially force a fixpoint.
        /// (Whether we error or just stop is configured by `LetRecLimit::return_at_limit`.)
        /// The per-`LetRec` limit that the user specified is initially copied to each binding to
        /// accommodate slicing and merging of `LetRec`s in MIR transforms (e.g., `NormalizeLets`).
        #[mzreflect(ignore)]
        limits: Vec<Option<LetRecLimit>>,
        /// The result of the `Let`, evaluated with `id` bound to `value`.
        body: Box<MirRelationExpr>,
    },
    /// Project out some columns from a dataflow
    ///
    /// The runtime memory footprint of this operator is zero.
    Project {
        /// The source collection.
        input: Box<MirRelationExpr>,
        /// Indices of columns to retain.
        outputs: Vec<usize>,
    },
    /// Append new columns to a dataflow
    ///
    /// The runtime memory footprint of this operator is zero.
    Map {
        /// The source collection.
        input: Box<MirRelationExpr>,
        /// Expressions which determine values to append to each row.
        /// An expression may refer to columns in `input` or
        /// expressions defined earlier in the vector
        scalars: Vec<MirScalarExpr>,
    },
    /// Like Map, but yields zero-or-more output rows per input row
    ///
    /// The runtime memory footprint of this operator is zero.
    FlatMap {
        /// The source collection
        input: Box<MirRelationExpr>,
        /// The table func to apply
        func: TableFunc,
        /// The argument to the table func
        exprs: Vec<MirScalarExpr>,
    },
    /// Keep rows from a dataflow where all the predicates are true
    ///
    /// The runtime memory footprint of this operator is zero.
    Filter {
        /// The source collection.
        input: Box<MirRelationExpr>,
        /// Predicates, each of which must be true.
        predicates: Vec<MirScalarExpr>,
    },
    /// Join several collections, where some columns must be equal.
    ///
    /// For further details consult the documentation for [`MirRelationExpr::join`].
    ///
    /// The runtime memory footprint of this operator can be proportional to
    /// the sizes of all inputs and the size of all joins of prefixes.
    /// This may be reduced due to arrangements available at rendering time.
    Join {
        /// A sequence of input relations.
        inputs: Vec<MirRelationExpr>,
        /// A sequence of equivalence classes of expressions on the cross product of inputs.
        ///
        /// Each equivalence class is a list of scalar expressions, where for each class the
        /// intended interpretation is that all evaluated expressions should be equal.
        ///
        /// Each scalar expression is to be evaluated over the cross-product of all records
        /// from all inputs. In many cases this may just be column selection from specific
        /// inputs, but more general cases exist (e.g. complex functions of multiple columns
        /// from multiple inputs, or just constant literals).
        equivalences: Vec<Vec<MirScalarExpr>>,
        /// Join implementation information.
        #[serde(default)]
        implementation: JoinImplementation,
    },
    /// Group a dataflow by some columns and aggregate over each group
    ///
    /// The runtime memory footprint of this operator is at most proportional to the
    /// number of distinct records in the input and output. The actual requirements
    /// can be less: the number of distinct inputs to each aggregate, summed across
    /// each aggregate, plus the output size. For more details consult the code that
    /// builds the associated dataflow.
    Reduce {
        /// The source collection.
        input: Box<MirRelationExpr>,
        /// Column indices used to form groups.
        group_key: Vec<MirScalarExpr>,
        /// Expressions which determine values to append to each row, after the group keys.
        aggregates: Vec<AggregateExpr>,
        /// True iff the input is known to monotonically increase (only addition of records).
        #[serde(default)]
        monotonic: bool,
        /// User hint: expected number of values per group key. Used to optimize physical rendering.
        #[serde(default)]
        expected_group_size: Option<u64>,
    },
    /// Groups and orders within each group, limiting output.
    ///
    /// The runtime memory footprint of this operator is proportional to its input and output.
    TopK {
        /// The source collection.
        input: Box<MirRelationExpr>,
        /// Column indices used to form groups.
        group_key: Vec<usize>,
        /// Column indices used to order rows within groups.
        order_key: Vec<ColumnOrder>,
        /// Number of records to retain
        #[serde(default)]
        limit: Option<MirScalarExpr>,
        /// Number of records to skip
        #[serde(default)]
        offset: usize,
        /// True iff the input is known to monotonically increase (only addition of records).
        #[serde(default)]
        monotonic: bool,
        /// User-supplied hint: how many rows will have the same group key.
        #[serde(default)]
        expected_group_size: Option<u64>,
    },
    /// Return a dataflow where the row counts are negated
    ///
    /// The runtime memory footprint of this operator is zero.
    Negate {
        /// The source collection.
        input: Box<MirRelationExpr>,
    },
    /// Keep rows from a dataflow where the row counts are positive
    ///
    /// The runtime memory footprint of this operator is proportional to its input and output.
    Threshold {
        /// The source collection.
        input: Box<MirRelationExpr>,
    },
    /// Adds the frequencies of elements in contained sets.
    ///
    /// The runtime memory footprint of this operator is zero.
    Union {
        /// A source collection.
        base: Box<MirRelationExpr>,
        /// Source collections to union.
        inputs: Vec<MirRelationExpr>,
    },
    /// Technically a no-op. Used to render an index. Will be used to optimize queries
    /// on finer grain. Each `keys` item represents a different index that should be
    /// produced from the `keys`.
    ///
    /// The runtime memory footprint of this operator is proportional to its input.
    ArrangeBy {
        /// The source collection
        input: Box<MirRelationExpr>,
        /// Columns to arrange `input` by, in order of decreasing primacy
        keys: Vec<Vec<MirScalarExpr>>,
    },
}
312
impl PartialEq for MirRelationExpr {
    /// Structural equality, computed via `structured_diff` rather than the derived
    /// recursive implementation (see the `derived_hash_with_manual_eq` note on the
    /// enum: the manual impl exists to avoid stack overflows, not to change semantics).
    fn eq(&self, other: &Self) -> bool {
        // Equal iff the structured diff produces no difference at all.
        // Capture the result and test it wrt `Ord` implementation in test environments.
        let result = structured_diff::MreDiff::new(self, other).next().is_none();
        // Soft-check (no-op in production) that this agrees with the derived `Ord`.
        mz_ore::soft_assert_eq_no_log!(result, self.cmp(other) == Ordering::Equal);
        result
    }
}
// `Eq` is sound because `eq` above is a total equivalence consistent with `Ord`.
impl Eq for MirRelationExpr {}
322
323impl MirRelationExpr {
    /// Reports the schema of the relation.
    ///
    /// This method determines the type through recursive traversal of the
    /// relation expression, drawing from the types of base collections.
    /// As such, this is not an especially cheap method, and should be used
    /// judiciously.
    ///
    /// The relation type is computed incrementally with a recursive post-order
    /// traversal, that accumulates the input types for the relations yet to be
    /// visited in `type_stack`.
    pub fn typ(&self) -> SqlRelationType {
        // Stack of computed types for already-visited subtrees; invariant: when the
        // post-visit closure runs on `e`, the top `e.num_inputs()` entries are the
        // types of `e`'s inputs (in `try_visit_children` order).
        let mut type_stack = Vec::new();
        #[allow(deprecated)]
        self.visit_pre_post_nolimit(
            // Pre-visit: optionally override which children get traversed.
            &mut |e: &MirRelationExpr| -> Option<Vec<&MirRelationExpr>> {
                match &e {
                    MirRelationExpr::Let { body, .. } => {
                        // Do not traverse the value sub-graph, since it's not relevant for
                        // determining the relation type of Let operators.
                        Some(vec![&*body])
                    }
                    MirRelationExpr::LetRec { body, .. } => {
                        // Do not traverse the value sub-graph, since it's not relevant for
                        // determining the relation type of Let operators.
                        Some(vec![&*body])
                    }
                    _ => None,
                }
            },
            // Post-visit: pop this node's input types, push its own type.
            &mut |e: &MirRelationExpr| {
                match e {
                    MirRelationExpr::Let { .. } => {
                        let body_typ = type_stack.pop().unwrap();
                        // Insert a dummy relation type for the value, since `typ_with_input_types`
                        // won't look at it, but expects the relation type of the body to be second.
                        type_stack.push(SqlRelationType::empty());
                        type_stack.push(body_typ);
                    }
                    MirRelationExpr::LetRec { values, .. } => {
                        let body_typ = type_stack.pop().unwrap();
                        // Insert dummy relation types for the values, since `typ_with_input_types`
                        // won't look at them, but expects the relation type of the body to be last.
                        type_stack
                            .extend(std::iter::repeat(SqlRelationType::empty()).take(values.len()));
                        type_stack.push(body_typ);
                    }
                    _ => {}
                }
                // Replace the top `num_inputs` entries with this node's computed type.
                let num_inputs = e.num_inputs();
                let relation_type =
                    e.typ_with_input_types(&type_stack[type_stack.len() - num_inputs..]);
                type_stack.truncate(type_stack.len() - num_inputs);
                type_stack.push(relation_type);
            },
        );
        // After the traversal only the root's type remains.
        assert_eq!(type_stack.len(), 1);
        type_stack.pop().unwrap()
    }
382
383    /// Reports the schema of the relation given the schema of the input relations.
384    ///
385    /// `input_types` is required to contain the schemas for the input relations of
386    /// the current relation in the same order as they are visited by `try_visit_children`
387    /// method, even though not all may be used for computing the schema of the
388    /// current relation. For example, `Let` expects two input types, one for the
389    /// value relation and one for the body, in that order, but only the one for the
390    /// body is used to determine the type of the `Let` relation.
391    ///
392    /// It is meant to be used during post-order traversals to compute relation
393    /// schemas incrementally.
394    pub fn typ_with_input_types(&self, input_types: &[SqlRelationType]) -> SqlRelationType {
395        let repr_types = input_types.iter().map(ReprRelationType::from).collect_vec();
396
397        let column_types =
398            self.repr_col_with_input_repr_cols(repr_types.iter().map(|i| &i.column_types));
399        let unique_keys = self.keys_with_input_keys(
400            input_types.iter().map(|i| i.arity()),
401            input_types.iter().map(|i| &i.keys),
402        );
403        SqlRelationType::new(column_types.iter().map(SqlColumnType::from_repr).collect())
404            .with_keys(unique_keys)
405    }
406
407    /// Reports the column types of the relation given the column types of the
408    /// input relations.
409    ///
410    /// This method delegates to `try_col_with_input_cols`, panicking if an `Err`
411    /// variant is returned.
412    pub fn col_with_input_cols<'a, I>(&self, input_types: I) -> Vec<SqlColumnType>
413    where
414        I: Iterator<Item = &'a Vec<SqlColumnType>>,
415    {
416        match self.try_col_with_input_cols(input_types) {
417            Ok(col_types) => col_types,
418            Err(err) => panic!("{err}"),
419        }
420    }
421
    /// Reports the column types of the relation given the column types of the input relations.
    ///
    /// `input_types` is required to contain the column types for the input relations of
    /// the current relation in the same order as they are visited by `try_visit_children`
    /// method, even though not all may be used for computing the schema of the
    /// current relation. For example, `Let` expects two input types, one for the
    /// value relation and one for the body, in that order, but only the one for the
    /// body is used to determine the type of the `Let` relation.
    ///
    /// It is meant to be used during post-order traversals to compute column types
    /// incrementally.
    ///
    /// Returns `Err` only when a `Union` has inputs whose column types cannot be
    /// unioned; the error message includes the offending plan.
    pub fn try_col_with_input_cols<'a, I>(
        &self,
        mut input_types: I,
    ) -> Result<Vec<SqlColumnType>, String>
    where
        I: Iterator<Item = &'a Vec<SqlColumnType>>,
    {
        use MirRelationExpr::*;

        // NOTE: each arm must consume exactly the iterator entries for its inputs
        // (`next` / `nth`), so that `input_types` stays aligned with the children.
        let col_types = match self {
            Constant { rows, typ } => {
                let mut col_types = typ.column_types.clone();
                // Tighten nullability: a column stays nullable only if some row
                // actually contains a null in it.
                let mut seen_null = vec![false; typ.arity()];
                if let Ok(rows) = rows {
                    for (row, _diff) in rows {
                        for (datum, i) in row.iter().zip_eq(0..typ.arity()) {
                            if datum.is_null() {
                                seen_null[i] = true;
                            }
                        }
                    }
                }
                for (&seen_null, i) in seen_null.iter().zip_eq(0..typ.arity()) {
                    if !seen_null {
                        col_types[i].nullable = false;
                    } else {
                        // A null datum in a column declared non-nullable is a bug.
                        assert!(col_types[i].nullable);
                    }
                }
                col_types
            }
            Get { typ, .. } => typ.column_types.clone(),
            Project { outputs, .. } => {
                // Select the retained input columns, in output order.
                let input = input_types.next().unwrap();
                outputs.iter().map(|&i| input[i].clone()).collect()
            }
            Map { scalars, .. } => {
                // Append one column per scalar; each scalar may reference columns
                // appended by earlier scalars, hence typing against `result` so far.
                let mut result = input_types.next().unwrap().clone();
                for scalar in scalars.iter() {
                    result.push(scalar.typ(&result))
                }
                result
            }
            FlatMap { func, .. } => {
                // Input columns followed by the table function's output columns.
                let mut result = input_types.next().unwrap().clone();
                result.extend(func.output_type().column_types);
                result
            }
            Filter { predicates, .. } => {
                let mut result = input_types.next().unwrap().clone();

                // Set as nonnull any columns where null values would cause
                // any predicate to evaluate to null.
                for column in non_nullable_columns(predicates) {
                    result[column].nullable = false;
                }
                result
            }
            Join { equivalences, .. } => {
                // Concatenate input column types
                let mut types = input_types.flat_map(|cols| cols.to_owned()).collect_vec();
                // In an equivalence class, if any column is non-null, then make all non-null
                for equivalence in equivalences {
                    let col_inds = equivalence
                        .iter()
                        .filter_map(|expr| match expr {
                            MirScalarExpr::Column(col, _name) => Some(*col),
                            _ => None,
                        })
                        .collect_vec();
                    if col_inds.iter().any(|i| !types.get(*i).unwrap().nullable) {
                        for i in col_inds {
                            types.get_mut(i).unwrap().nullable = false;
                        }
                    }
                }
                types
            }
            Reduce {
                group_key,
                aggregates,
                ..
            } => {
                // Group key columns first, then one column per aggregate.
                let input = input_types.next().unwrap();
                group_key
                    .iter()
                    .map(|e| e.typ(input))
                    .chain(aggregates.iter().map(|agg| agg.typ(input)))
                    .collect()
            }
            // These operators pass their input's column types through unchanged.
            TopK { .. } | Negate { .. } | Threshold { .. } | ArrangeBy { .. } => {
                input_types.next().unwrap().clone()
            }
            Let { .. } => {
                // skip over the input types for `value`.
                input_types.nth(1).unwrap().clone()
            }
            LetRec { values, .. } => {
                // skip over the input types for `values`.
                input_types.nth(values.len()).unwrap().clone()
            }
            Union { .. } => {
                // Union the column types of all inputs pairwise; fails if any
                // pair of corresponding columns cannot be unioned.
                let mut result = input_types.next().unwrap().clone();
                for input_col_types in input_types {
                    for (base_col, col) in result.iter_mut().zip_eq(input_col_types) {
                        *base_col = base_col
                            .try_union(col)
                            .map_err(|e| format!("{e}\nin plan:\n{}", self.pretty()))?;
                    }
                }
                result
            }
        };

        Ok(col_types)
    }
549
550    /// Reports the column types of the relation given the column types of the
551    /// input relations.
552    ///
553    /// This method delegates to `try_repr_col_with_input_repr_cols`, panicking if an `Err`
554    /// variant is returned.
555    pub fn repr_col_with_input_repr_cols<'a, I>(&self, input_types: I) -> Vec<ReprColumnType>
556    where
557        I: Iterator<Item = &'a Vec<ReprColumnType>>,
558    {
559        match self.try_repr_col_with_input_repr_cols(input_types) {
560            Ok(col_types) => col_types,
561            Err(err) => panic!("{err}"),
562        }
563    }
564
    /// Reports the column types of the relation given the column types of the input relations.
    ///
    /// `input_types` is required to contain the column types for the input relations of
    /// the current relation in the same order as they are visited by `try_visit_children`
    /// method, even though not all may be used for computing the schema of the
    /// current relation. For example, `Let` expects two input types, one for the
    /// value relation and one for the body, in that order, but only the one for the
    /// body is used to determine the type of the `Let` relation.
    ///
    /// It is meant to be used during post-order traversals to compute column types
    /// incrementally.
    ///
    /// Repr-level twin of `try_col_with_input_cols`; the per-operator logic is kept
    /// in sync with it, differing only in the column type representation.
    pub fn try_repr_col_with_input_repr_cols<'a, I>(
        &self,
        mut input_types: I,
    ) -> Result<Vec<ReprColumnType>, String>
    where
        I: Iterator<Item = &'a Vec<ReprColumnType>>,
    {
        use MirRelationExpr::*;

        // NOTE: each arm must consume exactly the iterator entries for its inputs
        // (`next` / `nth`), so that `input_types` stays aligned with the children.
        let col_types = match self {
            Constant { rows, typ } => {
                let mut col_types = typ
                    .column_types
                    .iter()
                    .map(ReprColumnType::from)
                    .collect_vec();
                // Tighten nullability: a column stays nullable only if some row
                // actually contains a null in it.
                let mut seen_null = vec![false; typ.arity()];
                if let Ok(rows) = rows {
                    for (row, _diff) in rows {
                        for (datum, i) in row.iter().zip_eq(0..typ.arity()) {
                            if datum.is_null() {
                                seen_null[i] = true;
                            }
                        }
                    }
                }
                for (&seen_null, i) in seen_null.iter().zip_eq(0..typ.arity()) {
                    if !seen_null {
                        col_types[i].nullable = false;
                    } else {
                        // A null datum in a column declared non-nullable is a bug.
                        assert!(col_types[i].nullable);
                    }
                }
                col_types
            }
            Get { typ, .. } => typ
                .column_types
                .iter()
                .map(ReprColumnType::from)
                .collect_vec(),
            Project { outputs, .. } => {
                // Select the retained input columns, in output order.
                let input = input_types.next().unwrap();
                outputs.iter().map(|&i| input[i].clone()).collect()
            }
            Map { scalars, .. } => {
                // Append one column per scalar; each scalar may reference columns
                // appended by earlier scalars, hence typing against `result` so far.
                let mut result = input_types.next().unwrap().clone();
                for scalar in scalars.iter() {
                    result.push(scalar.repr_typ(&result))
                }
                result
            }
            FlatMap { func, .. } => {
                // Input columns followed by the table function's output columns.
                let mut result = input_types.next().unwrap().clone();
                result.extend(
                    func.output_type()
                        .column_types
                        .iter()
                        .map(ReprColumnType::from),
                );
                result
            }
            Filter { predicates, .. } => {
                let mut result = input_types.next().unwrap().clone();

                // Set as nonnull any columns where null values would cause
                // any predicate to evaluate to null.
                for column in non_nullable_columns(predicates) {
                    result[column].nullable = false;
                }
                result
            }
            Join { equivalences, .. } => {
                // Concatenate input column types
                let mut types = input_types.flat_map(|cols| cols.to_owned()).collect_vec();
                // In an equivalence class, if any column is non-null, then make all non-null
                for equivalence in equivalences {
                    let col_inds = equivalence
                        .iter()
                        .filter_map(|expr| match expr {
                            MirScalarExpr::Column(col, _name) => Some(*col),
                            _ => None,
                        })
                        .collect_vec();
                    if col_inds.iter().any(|i| !types.get(*i).unwrap().nullable) {
                        for i in col_inds {
                            types.get_mut(i).unwrap().nullable = false;
                        }
                    }
                }
                types
            }
            Reduce {
                group_key,
                aggregates,
                ..
            } => {
                // Group key columns first, then one column per aggregate.
                let input = input_types.next().unwrap();
                group_key
                    .iter()
                    .map(|e| e.repr_typ(input))
                    .chain(aggregates.iter().map(|agg| agg.repr_typ(input)))
                    .collect()
            }
            // These operators pass their input's column types through unchanged.
            TopK { .. } | Negate { .. } | Threshold { .. } | ArrangeBy { .. } => {
                input_types.next().unwrap().clone()
            }
            Let { .. } => {
                // skip over the input types for `value`.
                input_types.nth(1).unwrap().clone()
            }
            LetRec { values, .. } => {
                // skip over the input types for `values`.
                input_types.nth(values.len()).unwrap().clone()
            }
            Union { .. } => {
                // Union the column types of all inputs pairwise; fails if any
                // pair of corresponding columns cannot be unioned.
                let mut result = input_types.next().unwrap().clone();
                for input_col_types in input_types {
                    for (base_col, col) in result.iter_mut().zip_eq(input_col_types) {
                        *base_col = base_col
                            .union(col)
                            .map_err(|e| format!("{}\nin plan:\n{}", e, self.pretty()))?;
                    }
                }
                result
            }
        };

        Ok(col_types)
    }
705
    /// Reports the unique keys of the relation given the arities and the unique
    /// keys of the input relations.
    ///
    /// `input_arities` and `input_keys` are required to contain the
    /// corresponding info for the input relations of
    /// the current relation in the same order as they are visited by `try_visit_children`
    /// method, even though not all may be used for computing the schema of the
    /// current relation. For example, `Let` expects two input types, one for the
    /// value relation and one for the body, in that order, but only the one for the
    /// body is used to determine the type of the `Let` relation.
    ///
    /// It is meant to be used during post-order traversals to compute unique keys
    /// incrementally.
    ///
    /// Each returned key is a set of column indices; the empty key `vec![]`
    /// asserts that the relation has at most one row. The result is sorted
    /// and deduplicated before being returned.
    pub fn keys_with_input_keys<'a, I, J>(
        &self,
        mut input_arities: I,
        mut input_keys: J,
    ) -> Vec<Vec<usize>>
    where
        I: Iterator<Item = usize>,
        J: Iterator<Item = &'a Vec<Vec<usize>>>,
    {
        use MirRelationExpr::*;

        let mut keys = match self {
            Constant {
                rows: Ok(rows),
                typ,
            } => {
                let n_cols = typ.arity();
                // If the `i`th entry is `Some`, then we have not yet observed non-uniqueness in the `i`th column.
                let mut unique_values_per_col = vec![Some(BTreeSet::<Datum>::default()); n_cols];
                for (row, diff) in rows {
                    for (i, datum) in row.iter().enumerate() {
                        // `Datum::Dummy` is excluded from uniqueness tracking.
                        if datum != Datum::Dummy {
                            if let Some(unique_vals) = &mut unique_values_per_col[i] {
                                // A non-unit diff implies repeated rows, so the
                                // column cannot be unique either.
                                let is_dupe = *diff != Diff::ONE || !unique_vals.insert(datum);
                                if is_dupe {
                                    unique_values_per_col[i] = None;
                                }
                            }
                        }
                    }
                }
                if rows.len() == 0 || (rows.len() == 1 && rows[0].1 == Diff::ONE) {
                    // At most one row: the empty key is a unique key.
                    vec![vec![]]
                } else {
                    // XXX - Multi-column keys are not detected.
                    typ.keys
                        .iter()
                        .cloned()
                        .chain(
                            unique_values_per_col
                                .into_iter()
                                .enumerate()
                                .filter(|(_idx, unique_vals)| unique_vals.is_some())
                                .map(|(idx, _)| vec![idx]),
                        )
                        .collect()
                }
            }
            Constant { rows: Err(_), typ } | Get { typ, .. } => typ.keys.clone(),
            Threshold { .. } | ArrangeBy { .. } => input_keys.next().unwrap().clone(),
            Let { .. } => {
                // skip over the unique keys for value
                input_keys.nth(1).unwrap().clone()
            }
            LetRec { values, .. } => {
                // skip over the unique keys for value
                input_keys.nth(values.len()).unwrap().clone()
            }
            Project { outputs, .. } => {
                // Keep only keys whose columns all survive the projection,
                // remapped to their new (output) positions.
                let input = input_keys.next().unwrap();
                input
                    .iter()
                    .filter_map(|key_set| {
                        if key_set.iter().all(|k| outputs.contains(k)) {
                            Some(
                                key_set
                                    .iter()
                                    .map(|c| outputs.iter().position(|o| o == c).unwrap())
                                    .collect(),
                            )
                        } else {
                            None
                        }
                    })
                    .collect()
            }
            Map { scalars, .. } => {
                let mut remappings = Vec::new();
                let arity = input_arities.next().unwrap();
                for (column, scalar) in scalars.iter().enumerate() {
                    // assess whether the scalar preserves uniqueness,
                    // and could participate in a key!

                    // Returns the source column if `expr` is a (chain of)
                    // uniqueness-preserving unary function(s) over a column.
                    fn uniqueness(expr: &MirScalarExpr) -> Option<usize> {
                        match expr {
                            MirScalarExpr::CallUnary { func, expr } => {
                                if func.preserves_uniqueness() {
                                    uniqueness(expr)
                                } else {
                                    None
                                }
                            }
                            MirScalarExpr::Column(c, _name) => Some(*c),
                            _ => None,
                        }
                    }

                    if let Some(c) = uniqueness(scalar) {
                        remappings.push((c, column + arity));
                    }
                }

                let mut result = input_keys.next().unwrap().clone();
                let mut new_keys = Vec::new();
                // Any column in `remappings` could be replaced in a key
                // by the corresponding c. This could lead to combinatorial
                // explosion using our current representation, so we wont
                // do that. Instead, we'll handle the case of one remapping.
                if remappings.len() == 1 {
                    let (old, new) = remappings.pop().unwrap();
                    for key in &result {
                        if key.contains(&old) {
                            let mut new_key: Vec<usize> =
                                key.iter().cloned().filter(|k| k != &old).collect();
                            new_key.push(new);
                            new_key.sort_unstable();
                            new_keys.push(new_key);
                        }
                    }
                    result.append(&mut new_keys);
                }
                result
            }
            FlatMap { .. } => {
                // FlatMap can add duplicate rows, so input keys are no longer
                // valid
                vec![]
            }
            Negate { .. } => {
                // Although negate may have distinct records for each key,
                // the multiplicity is -1 rather than 1. This breaks many
                // of the optimization uses of "keys".
                vec![]
            }
            Filter { predicates, .. } => {
                // A filter inherits the keys of its input unless the filters
                // have reduced the input to a single row, in which case the
                // keys of the input are `()`.
                let mut input = input_keys.next().unwrap().clone();

                if !input.is_empty() {
                    // Track columns equated to literals, which we can prune.
                    let mut cols_equal_to_literal = BTreeSet::new();

                    // Perform union find on `col1 = col2` to establish
                    // connected components of equated columns. Absent any
                    // equalities, this will be `0 .. #c` (where #c is the
                    // greatest column referenced by a predicate), but each
                    // equality will orient the root of the greater to the root
                    // of the lesser.
                    let mut union_find = Vec::new();

                    for expr in predicates.iter() {
                        if let MirScalarExpr::CallBinary {
                            func: crate::BinaryFunc::Eq(_),
                            expr1,
                            expr2,
                        } = expr
                        {
                            if let MirScalarExpr::Column(c, _name) = &**expr1 {
                                if expr2.is_literal_ok() {
                                    cols_equal_to_literal.insert(c);
                                }
                            }
                            if let MirScalarExpr::Column(c, _name) = &**expr2 {
                                if expr1.is_literal_ok() {
                                    cols_equal_to_literal.insert(c);
                                }
                            }
                            // Perform union-find to equate columns.
                            if let (Some(c1), Some(c2)) = (expr1.as_column(), expr2.as_column()) {
                                if c1 != c2 {
                                    // Ensure union_find has entries up to
                                    // max(c1, c2) by filling up missing
                                    // positions with identity mappings.
                                    while union_find.len() <= std::cmp::max(c1, c2) {
                                        union_find.push(union_find.len());
                                    }
                                    let mut r1 = c1; // Find the representative column of [c1].
                                    while r1 != union_find[r1] {
                                        assert!(union_find[r1] < r1);
                                        r1 = union_find[r1];
                                    }
                                    let mut r2 = c2; // Find the representative column of [c2].
                                    while r2 != union_find[r2] {
                                        assert!(union_find[r2] < r2);
                                        r2 = union_find[r2];
                                    }
                                    // Union [c1] and [c2] by pointing the
                                    // larger to the smaller representative (we
                                    // update the remaining equivalence class
                                    // members only once after this for-loop).
                                    union_find[std::cmp::max(r1, r2)] = std::cmp::min(r1, r2);
                                }
                            }
                        }
                    }

                    // Complete union-find by pointing each element at its representative column.
                    for i in 0..union_find.len() {
                        // Iteration not required, as each prior already references the right column.
                        union_find[i] = union_find[union_find[i]];
                    }

                    // Remove columns bound to literals, and remap columns equated to earlier columns.
                    // We will re-expand remapped columns in a moment, but this avoids exponential work.
                    for key_set in &mut input {
                        key_set.retain(|k| !cols_equal_to_literal.contains(&k));
                        for col in key_set.iter_mut() {
                            if let Some(equiv) = union_find.get(*col) {
                                *col = *equiv;
                            }
                        }
                        key_set.sort();
                        key_set.dedup();
                    }
                    input.sort();
                    input.dedup();

                    // Expand out each key to each of its equivalent forms.
                    // Each instance of `col` can be replaced by any equivalent column.
                    // This has the potential to result in exponentially sized number of unique keys,
                    // and in the future we should probably maintain unique keys modulo equivalence.

                    // First, compute an inverse map from each representative
                    // column `sub` to all other equivalent columns `col`.
                    let mut subs = Vec::new();
                    for (col, sub) in union_find.iter().enumerate() {
                        if *sub != col {
                            assert!(*sub < col);
                            while subs.len() <= *sub {
                                subs.push(Vec::new());
                            }
                            subs[*sub].push(col);
                        }
                    }
                    // For each column, substitute for it in each occurrence.
                    let mut to_add = Vec::new();
                    for (col, subs) in subs.iter().enumerate() {
                        if !subs.is_empty() {
                            for key_set in input.iter() {
                                if key_set.contains(&col) {
                                    let mut to_extend = key_set.clone();
                                    to_extend.retain(|c| c != &col);
                                    for sub in subs {
                                        to_extend.push(*sub);
                                        to_add.push(to_extend.clone());
                                        to_extend.pop();
                                    }
                                }
                            }
                        }
                        // No deduplication, as we cannot introduce duplicates.
                        input.append(&mut to_add);
                    }
                    for key_set in input.iter_mut() {
                        key_set.sort();
                        key_set.dedup();
                    }
                }
                input
            }
            Join { equivalences, .. } => {
                // It is important the `new_from_input_arities` constructor is
                // used. Otherwise, Materialize may potentially end up in an
                // infinite loop.
                let input_mapper = crate::JoinInputMapper::new_from_input_arities(input_arities);

                input_mapper.global_keys(input_keys, equivalences)
            }
            Reduce { group_key, .. } => {
                // The group key should form a key, but we might already have
                // keys that are subsets of the group key, and should retain
                // those instead, if so.
                let mut result = Vec::new();
                for key_set in input_keys.next().unwrap() {
                    if key_set
                        .iter()
                        .all(|k| group_key.contains(&MirScalarExpr::column(*k)))
                    {
                        result.push(
                            key_set
                                .iter()
                                .map(|i| {
                                    group_key
                                        .iter()
                                        .position(|k| k == &MirScalarExpr::column(*i))
                                        .unwrap()
                                })
                                .collect::<Vec<_>>(),
                        );
                    }
                }
                if result.is_empty() {
                    // Fall back to the full group key as the unique key.
                    result.push((0..group_key.len()).collect());
                }
                result
            }
            TopK {
                group_key, limit, ..
            } => {
                // If `limit` is `Some(1)` then the group key will become
                // a unique key, as there will be only one record with that key.
                let mut result = input_keys.next().unwrap().clone();
                if limit.as_ref().and_then(|x| x.as_literal_int64()) == Some(1) {
                    result.push(group_key.clone())
                }
                result
            }
            Union { base, inputs } => {
                // Generally, unions do not have any unique keys, because
                // each input might duplicate some. However, there is at
                // least one idiomatic structure that does preserve keys,
                // which results from SQL aggregations that must populate
                // absent records with default values. In that pattern,
                // the union of one GET with its negation, which has first
                // been subjected to a projection and map, we can remove
                // their influence on the key structure.
                //
                // If there are A, B, each with a unique `key` such that
                // we are looking at
                //
                //     A.proj(set_containing_key) + (B - A.proj(key)).map(stuff)
                //
                // Then we can report `key` as a unique key.
                //
                // TODO: make unique key structure an optimization analysis
                // rather than part of the type information.
                // TODO: perhaps ensure that (above) A.proj(key) is a
                // subset of B, as otherwise there are negative records
                // and who knows what is true (not expected, but again
                // who knows what the query plan might look like).

                let arity = input_arities.next().unwrap();
                let (base_projection, base_with_project_stripped) =
                    if let MirRelationExpr::Project { input, outputs } = &**base {
                        (outputs.clone(), &**input)
                    } else {
                        // A input without a project is equivalent to an input
                        // with the project being all columns in the input in order.
                        ((0..arity).collect::<Vec<_>>(), &**base)
                    };
                let mut result = Vec::new();
                if let MirRelationExpr::Get {
                    id: first_id,
                    typ: _,
                    ..
                } = base_with_project_stripped
                {
                    if inputs.len() == 1 {
                        if let MirRelationExpr::Map { input, .. } = &inputs[0] {
                            if let MirRelationExpr::Union { base, inputs } = &**input {
                                if inputs.len() == 1 {
                                    if let Some((input, outputs)) = base.is_negated_project() {
                                        if let MirRelationExpr::Get {
                                            id: second_id,
                                            typ: _,
                                            ..
                                        } = input
                                        {
                                            // Both sides read from the same collection;
                                            // keep keys whose columns pass through both
                                            // projections unchanged.
                                            if first_id == second_id {
                                                result.extend(
                                                    input_keys
                                                        .next()
                                                        .unwrap()
                                                        .into_iter()
                                                        .filter(|key| {
                                                            key.iter().all(|c| {
                                                                outputs.get(*c) == Some(c)
                                                                    && base_projection.get(*c)
                                                                        == Some(c)
                                                            })
                                                        })
                                                        .cloned(),
                                                );
                                            }
                                        }
                                    }
                                }
                            }
                        }
                    }
                }
                // Important: do not inherit keys of either input, as not unique.
                result
            }
        };
        keys.sort();
        keys.dedup();
        keys
    }
1110
    /// The number of columns in the relation.
    ///
    /// This number is determined from the type, which is determined recursively
    /// at non-trivial cost.
    ///
    /// The arity is computed incrementally with a recursive post-order
    /// traversal, that accumulates the arities for the relations yet to be
    /// visited in `arity_stack`.
    pub fn arity(&self) -> usize {
        let mut arity_stack = Vec::new();
        #[allow(deprecated)]
        self.visit_pre_post_nolimit(
            // Pre: optionally prune the set of children to traverse.
            &mut |e: &MirRelationExpr| -> Option<Vec<&MirRelationExpr>> {
                match &e {
                    MirRelationExpr::Let { body, .. } => {
                        // Do not traverse the value sub-graph, since it's not relevant for
                        // determining the arity of Let operators.
                        Some(vec![&*body])
                    }
                    MirRelationExpr::LetRec { body, .. } => {
                        // Do not traverse the value sub-graph, since it's not relevant for
                        // determining the arity of Let operators.
                        Some(vec![&*body])
                    }
                    MirRelationExpr::Project { .. } | MirRelationExpr::Reduce { .. } => {
                        // No further traversal is required; these operators know their arity.
                        Some(Vec::new())
                    }
                    _ => None,
                }
            },
            // Post: restore the placeholder arities for the children skipped
            // above, so that `arity_with_input_arities` sees exactly
            // `num_inputs()` entries on the stack.
            &mut |e: &MirRelationExpr| {
                match &e {
                    MirRelationExpr::Let { .. } => {
                        // Insert a dummy arity (0) for the skipped `value`,
                        // keeping the body's arity on top.
                        let body_arity = arity_stack.pop().unwrap();
                        arity_stack.push(0);
                        arity_stack.push(body_arity);
                    }
                    MirRelationExpr::LetRec { values, .. } => {
                        // Insert dummy arities (0) for the skipped `values`,
                        // keeping the body's arity on top.
                        let body_arity = arity_stack.pop().unwrap();
                        arity_stack.extend(std::iter::repeat(0).take(values.len()));
                        arity_stack.push(body_arity);
                    }
                    MirRelationExpr::Project { .. } | MirRelationExpr::Reduce { .. } => {
                        // These operators had their children pruned; push a
                        // dummy arity for the single (unvisited) input.
                        arity_stack.push(0);
                    }
                    _ => {}
                }
                // Replace the children's arities with this node's arity.
                let num_inputs = e.num_inputs();
                let input_arities = arity_stack.drain(arity_stack.len() - num_inputs..);
                let arity = e.arity_with_input_arities(input_arities);
                arity_stack.push(arity);
            },
        );
        // After the traversal only the root's arity remains.
        assert_eq!(arity_stack.len(), 1);
        arity_stack.pop().unwrap()
    }
1168
1169    /// Reports the arity of the relation given the schema of the input relations.
1170    ///
1171    /// `input_arities` is required to contain the arities for the input relations of
1172    /// the current relation in the same order as they are visited by `try_visit_children`
1173    /// method, even though not all may be used for computing the schema of the
1174    /// current relation. For example, `Let` expects two input types, one for the
1175    /// value relation and one for the body, in that order, but only the one for the
1176    /// body is used to determine the type of the `Let` relation.
1177    ///
1178    /// It is meant to be used during post-order traversals to compute arities
1179    /// incrementally.
1180    pub fn arity_with_input_arities<I>(&self, mut input_arities: I) -> usize
1181    where
1182        I: Iterator<Item = usize>,
1183    {
1184        use MirRelationExpr::*;
1185
1186        match self {
1187            Constant { rows: _, typ } => typ.arity(),
1188            Get { typ, .. } => typ.arity(),
1189            Let { .. } => {
1190                input_arities.next();
1191                input_arities.next().unwrap()
1192            }
1193            LetRec { values, .. } => {
1194                for _ in 0..values.len() {
1195                    input_arities.next();
1196                }
1197                input_arities.next().unwrap()
1198            }
1199            Project { outputs, .. } => outputs.len(),
1200            Map { scalars, .. } => input_arities.next().unwrap() + scalars.len(),
1201            FlatMap { func, .. } => input_arities.next().unwrap() + func.output_arity(),
1202            Join { .. } => input_arities.sum(),
1203            Reduce {
1204                input: _,
1205                group_key,
1206                aggregates,
1207                ..
1208            } => group_key.len() + aggregates.len(),
1209            Filter { .. }
1210            | TopK { .. }
1211            | Negate { .. }
1212            | Threshold { .. }
1213            | Union { .. }
1214            | ArrangeBy { .. } => input_arities.next().unwrap(),
1215        }
1216    }
1217
1218    /// The number of child relations this relation has.
1219    pub fn num_inputs(&self) -> usize {
1220        let mut count = 0;
1221
1222        self.visit_children(|_| count += 1);
1223
1224        count
1225    }
1226
1227    /// Constructs a constant collection from specific rows and schema, where
1228    /// each row will have a multiplicity of one.
1229    pub fn constant(rows: Vec<Vec<Datum>>, typ: SqlRelationType) -> Self {
1230        let rows = rows.into_iter().map(|row| (row, Diff::ONE)).collect();
1231        MirRelationExpr::constant_diff(rows, typ)
1232    }
1233
1234    /// Constructs a constant collection from specific rows and schema, where
1235    /// each row can have an arbitrary multiplicity.
1236    pub fn constant_diff(rows: Vec<(Vec<Datum>, Diff)>, typ: SqlRelationType) -> Self {
1237        for (row, _diff) in &rows {
1238            for (datum, column_typ) in row.iter().zip_eq(typ.column_types.iter()) {
1239                assert!(
1240                    datum.is_instance_of_sql(column_typ),
1241                    "Expected datum of type {:?}, got value {:?}",
1242                    column_typ,
1243                    datum
1244                );
1245            }
1246        }
1247        let rows = Ok(rows
1248            .into_iter()
1249            .map(move |(row, diff)| (Row::pack_slice(&row), diff))
1250            .collect());
1251        MirRelationExpr::Constant { rows, typ }
1252    }
1253
1254    /// If self is a constant, return the value and the type, otherwise `None`.
1255    /// Looks behind `ArrangeBy`s.
1256    pub fn as_const(&self) -> Option<(&Result<Vec<(Row, Diff)>, EvalError>, &SqlRelationType)> {
1257        match self {
1258            MirRelationExpr::Constant { rows, typ } => Some((rows, typ)),
1259            MirRelationExpr::ArrangeBy { input, .. } => input.as_const(),
1260            _ => None,
1261        }
1262    }
1263
1264    /// If self is a constant, mutably return the value and the type, otherwise `None`.
1265    /// Looks behind `ArrangeBy`s.
1266    pub fn as_const_mut(
1267        &mut self,
1268    ) -> Option<(
1269        &mut Result<Vec<(Row, Diff)>, EvalError>,
1270        &mut SqlRelationType,
1271    )> {
1272        match self {
1273            MirRelationExpr::Constant { rows, typ } => Some((rows, typ)),
1274            MirRelationExpr::ArrangeBy { input, .. } => input.as_const_mut(),
1275            _ => None,
1276        }
1277    }
1278
1279    /// If self is a constant error, return the error, otherwise `None`.
1280    /// Looks behind `ArrangeBy`s.
1281    pub fn as_const_err(&self) -> Option<&EvalError> {
1282        match self {
1283            MirRelationExpr::Constant { rows: Err(e), .. } => Some(e),
1284            MirRelationExpr::ArrangeBy { input, .. } => input.as_const_err(),
1285            _ => None,
1286        }
1287    }
1288
1289    /// Checks if `self` is the single element collection with no columns.
1290    pub fn is_constant_singleton(&self) -> bool {
1291        if let Some((Ok(rows), typ)) = self.as_const() {
1292            rows.len() == 1 && typ.column_types.len() == 0 && rows[0].1 == Diff::ONE
1293        } else {
1294            false
1295        }
1296    }
1297
1298    /// Constructs the expression for getting a local collection.
1299    pub fn local_get(id: LocalId, typ: SqlRelationType) -> Self {
1300        MirRelationExpr::Get {
1301            id: Id::Local(id),
1302            typ,
1303            access_strategy: AccessStrategy::UnknownOrLocal,
1304        }
1305    }
1306
1307    /// Constructs the expression for getting a global collection
1308    pub fn global_get(id: GlobalId, typ: SqlRelationType) -> Self {
1309        MirRelationExpr::Get {
1310            id: Id::Global(id),
1311            typ,
1312            access_strategy: AccessStrategy::UnknownOrLocal,
1313        }
1314    }
1315
1316    /// Retains only the columns specified by `output`.
1317    pub fn project(mut self, mut outputs: Vec<usize>) -> Self {
1318        if let MirRelationExpr::Project {
1319            outputs: columns, ..
1320        } = &mut self
1321        {
1322            // Update `outputs` to reference base columns of `input`.
1323            for column in outputs.iter_mut() {
1324                *column = columns[*column];
1325            }
1326            *columns = outputs;
1327            self
1328        } else {
1329            MirRelationExpr::Project {
1330                input: Box::new(self),
1331                outputs,
1332            }
1333        }
1334    }
1335
1336    /// Append to each row the results of applying elements of `scalar`.
1337    pub fn map(mut self, scalars: Vec<MirScalarExpr>) -> Self {
1338        if let MirRelationExpr::Map { scalars: s, .. } = &mut self {
1339            s.extend(scalars);
1340            self
1341        } else if !scalars.is_empty() {
1342            MirRelationExpr::Map {
1343                input: Box::new(self),
1344                scalars,
1345            }
1346        } else {
1347            self
1348        }
1349    }
1350
1351    /// Append to each row a single `scalar`.
1352    pub fn map_one(self, scalar: MirScalarExpr) -> Self {
1353        self.map(vec![scalar])
1354    }
1355
1356    /// Like `map`, but yields zero-or-more output rows per input row
1357    pub fn flat_map(self, func: TableFunc, exprs: Vec<MirScalarExpr>) -> Self {
1358        MirRelationExpr::FlatMap {
1359            input: Box::new(self),
1360            func,
1361            exprs,
1362        }
1363    }
1364
1365    /// Retain only the rows satisfying each of several predicates.
1366    pub fn filter<I>(mut self, predicates: I) -> Self
1367    where
1368        I: IntoIterator<Item = MirScalarExpr>,
1369    {
1370        // Extract existing predicates
1371        let mut new_predicates = if let MirRelationExpr::Filter { input, predicates } = self {
1372            self = *input;
1373            predicates
1374        } else {
1375            Vec::new()
1376        };
1377        // Normalize collection of predicates.
1378        new_predicates.extend(predicates);
1379        new_predicates.retain(|p| !p.is_literal_true());
1380        new_predicates.sort();
1381        new_predicates.dedup();
1382        // Introduce a `Filter` only if we have predicates.
1383        if !new_predicates.is_empty() {
1384            self = MirRelationExpr::Filter {
1385                input: Box::new(self),
1386                predicates: new_predicates,
1387            };
1388        }
1389
1390        self
1391    }
1392
1393    /// Form the Cartesian outer-product of rows in both inputs.
1394    pub fn product(mut self, right: Self) -> Self {
1395        if right.is_constant_singleton() {
1396            self
1397        } else if self.is_constant_singleton() {
1398            right
1399        } else if let MirRelationExpr::Join { inputs, .. } = &mut self {
1400            inputs.push(right);
1401            self
1402        } else {
1403            MirRelationExpr::join(vec![self, right], vec![])
1404        }
1405    }
1406
1407    /// Performs a relational equijoin among the input collections.
1408    ///
1409    /// The sequence `inputs` each describe different input collections, and the sequence `variables` describes
1410    /// equality constraints that some of their columns must satisfy. Each element in `variable` describes a set
1411    /// of pairs  `(input_index, column_index)` where every value described by that set must be equal.
1412    ///
1413    /// For example, the pair `(input, column)` indexes into `inputs[input][column]`, extracting the `input`th
1414    /// input collection and for each row examining its `column`th column.
1415    ///
1416    /// # Example
1417    ///
1418    /// ```rust
1419    /// use mz_repr::{Datum, SqlColumnType, SqlRelationType, SqlScalarType};
1420    /// use mz_expr::MirRelationExpr;
1421    ///
1422    /// // A common schema for each input.
1423    /// let schema = SqlRelationType::new(vec![
1424    ///     SqlScalarType::Int32.nullable(false),
1425    ///     SqlScalarType::Int32.nullable(false),
1426    /// ]);
1427    ///
1428    /// // the specific data are not important here.
1429    /// let data = vec![Datum::Int32(0), Datum::Int32(1)];
1430    ///
1431    /// // Three collections that could have been different.
1432    /// let input0 = MirRelationExpr::constant(vec![data.clone()], schema.clone());
1433    /// let input1 = MirRelationExpr::constant(vec![data.clone()], schema.clone());
1434    /// let input2 = MirRelationExpr::constant(vec![data.clone()], schema.clone());
1435    ///
1436    /// // Join the three relations looking for triangles, like so.
1437    /// //
1438    /// //     Output(A,B,C) := Input0(A,B), Input1(B,C), Input2(A,C)
1439    /// let joined = MirRelationExpr::join(
1440    ///     vec![input0, input1, input2],
1441    ///     vec![
1442    ///         vec![(0,0), (2,0)], // fields A of inputs 0 and 2.
1443    ///         vec![(0,1), (1,0)], // fields B of inputs 0 and 1.
1444    ///         vec![(1,1), (2,1)], // fields C of inputs 1 and 2.
1445    ///     ],
1446    /// );
1447    ///
1448    /// // Technically the above produces `Output(A,B,B,C,A,C)` because the columns are concatenated.
1449    /// // A projection resolves this and produces the correct output.
1450    /// let result = joined.project(vec![0, 1, 3]);
1451    /// ```
1452    pub fn join(inputs: Vec<MirRelationExpr>, variables: Vec<Vec<(usize, usize)>>) -> Self {
1453        let input_mapper = join_input_mapper::JoinInputMapper::new(&inputs);
1454
1455        let equivalences = variables
1456            .into_iter()
1457            .map(|vs| {
1458                vs.into_iter()
1459                    .map(|(r, c)| input_mapper.map_expr_to_global(MirScalarExpr::column(c), r))
1460                    .collect::<Vec<_>>()
1461            })
1462            .collect::<Vec<_>>();
1463
1464        Self::join_scalars(inputs, equivalences)
1465    }
1466
1467    /// Constructs a join operator from inputs and required-equal scalar expressions.
1468    pub fn join_scalars(
1469        mut inputs: Vec<MirRelationExpr>,
1470        equivalences: Vec<Vec<MirScalarExpr>>,
1471    ) -> Self {
1472        // Remove all constant inputs that are the identity for join.
1473        // They neither introduce nor modify any column references.
1474        inputs.retain(|i| !i.is_constant_singleton());
1475        MirRelationExpr::Join {
1476            inputs,
1477            equivalences,
1478            implementation: JoinImplementation::Unimplemented,
1479        }
1480    }
1481
1482    /// Perform a key-wise reduction / aggregation.
1483    ///
1484    /// The `group_key` argument indicates columns in the input collection that should
1485    /// be grouped, and `aggregates` lists aggregation functions each of which produces
1486    /// one output column in addition to the keys.
1487    pub fn reduce(
1488        self,
1489        group_key: Vec<usize>,
1490        aggregates: Vec<AggregateExpr>,
1491        expected_group_size: Option<u64>,
1492    ) -> Self {
1493        MirRelationExpr::Reduce {
1494            input: Box::new(self),
1495            group_key: group_key.into_iter().map(MirScalarExpr::column).collect(),
1496            aggregates,
1497            monotonic: false,
1498            expected_group_size,
1499        }
1500    }
1501
1502    /// Perform a key-wise reduction order by and limit.
1503    ///
1504    /// The `group_key` argument indicates columns in the input collection that should
1505    /// be grouped, the `order_key` argument indicates columns that should be further
1506    /// used to order records within groups, and the `limit` argument constrains the
1507    /// total number of records that should be produced in each group.
1508    pub fn top_k(
1509        self,
1510        group_key: Vec<usize>,
1511        order_key: Vec<ColumnOrder>,
1512        limit: Option<MirScalarExpr>,
1513        offset: usize,
1514        expected_group_size: Option<u64>,
1515    ) -> Self {
1516        MirRelationExpr::TopK {
1517            input: Box::new(self),
1518            group_key,
1519            order_key,
1520            limit,
1521            offset,
1522            expected_group_size,
1523            monotonic: false,
1524        }
1525    }
1526
1527    /// Negates the occurrences of each row.
1528    pub fn negate(self) -> Self {
1529        if let MirRelationExpr::Negate { input } = self {
1530            *input
1531        } else {
1532            MirRelationExpr::Negate {
1533                input: Box::new(self),
1534            }
1535        }
1536    }
1537
1538    /// Removes all but the first occurrence of each row.
1539    pub fn distinct(self) -> Self {
1540        let arity = self.arity();
1541        self.distinct_by((0..arity).collect())
1542    }
1543
1544    /// Removes all but the first occurrence of each key. Columns not included
1545    /// in the `group_key` are discarded.
1546    pub fn distinct_by(self, group_key: Vec<usize>) -> Self {
1547        self.reduce(group_key, vec![], None)
1548    }
1549
1550    /// Discards rows with a negative frequency.
1551    pub fn threshold(self) -> Self {
1552        if let MirRelationExpr::Threshold { .. } = &self {
1553            self
1554        } else {
1555            MirRelationExpr::Threshold {
1556                input: Box::new(self),
1557            }
1558        }
1559    }
1560
1561    /// Unions together any number inputs.
1562    ///
1563    /// If `inputs` is empty, then an empty relation of type `typ` is
1564    /// constructed.
1565    pub fn union_many(mut inputs: Vec<Self>, typ: SqlRelationType) -> Self {
1566        // Deconstruct `inputs` as `Union`s and reconstitute.
1567        let mut flat_inputs = Vec::with_capacity(inputs.len());
1568        for input in inputs {
1569            if let MirRelationExpr::Union { base, inputs } = input {
1570                flat_inputs.push(*base);
1571                flat_inputs.extend(inputs);
1572            } else {
1573                flat_inputs.push(input);
1574            }
1575        }
1576        inputs = flat_inputs;
1577        if inputs.len() == 0 {
1578            MirRelationExpr::Constant {
1579                rows: Ok(vec![]),
1580                typ,
1581            }
1582        } else if inputs.len() == 1 {
1583            inputs.into_element()
1584        } else {
1585            MirRelationExpr::Union {
1586                base: Box::new(inputs.remove(0)),
1587                inputs,
1588            }
1589        }
1590    }
1591
1592    /// Produces one collection where each row is present with the sum of its frequencies in each input.
1593    pub fn union(self, other: Self) -> Self {
1594        // Deconstruct `self` and `other` as `Union`s and reconstitute.
1595        let mut flat_inputs = Vec::with_capacity(2);
1596        if let MirRelationExpr::Union { base, inputs } = self {
1597            flat_inputs.push(*base);
1598            flat_inputs.extend(inputs);
1599        } else {
1600            flat_inputs.push(self);
1601        }
1602        if let MirRelationExpr::Union { base, inputs } = other {
1603            flat_inputs.push(*base);
1604            flat_inputs.extend(inputs);
1605        } else {
1606            flat_inputs.push(other);
1607        }
1608
1609        MirRelationExpr::Union {
1610            base: Box::new(flat_inputs.remove(0)),
1611            inputs: flat_inputs,
1612        }
1613    }
1614
1615    /// Arranges the collection by the specified columns
1616    pub fn arrange_by(self, keys: &[Vec<MirScalarExpr>]) -> Self {
1617        MirRelationExpr::ArrangeBy {
1618            input: Box::new(self),
1619            keys: keys.to_owned(),
1620        }
1621    }
1622
1623    /// Indicates if this is a constant empty collection.
1624    ///
1625    /// A false value does not mean the collection is known to be non-empty,
1626    /// only that we cannot currently determine that it is statically empty.
1627    pub fn is_empty(&self) -> bool {
1628        if let Some((Ok(rows), ..)) = self.as_const() {
1629            rows.is_empty()
1630        } else {
1631            false
1632        }
1633    }
1634
1635    /// If the expression is a negated project, return the input and the projection.
1636    pub fn is_negated_project(&self) -> Option<(&MirRelationExpr, &[usize])> {
1637        if let MirRelationExpr::Negate { input } = self {
1638            if let MirRelationExpr::Project { input, outputs } = &**input {
1639                return Some((&**input, outputs));
1640            }
1641        }
1642        if let MirRelationExpr::Project { input, outputs } = self {
1643            if let MirRelationExpr::Negate { input } = &**input {
1644                return Some((&**input, outputs));
1645            }
1646        }
1647        None
1648    }
1649
1650    /// Pretty-print this [MirRelationExpr] to a string.
1651    pub fn pretty(&self) -> String {
1652        let config = ExplainConfig::default();
1653        self.explain(&config, None)
1654    }
1655
1656    /// Pretty-print this [MirRelationExpr] to a string using a custom
1657    /// [ExplainConfig] and an optionally provided [ExprHumanizer].
1658    pub fn explain(&self, config: &ExplainConfig, humanizer: Option<&dyn ExprHumanizer>) -> String {
1659        text_string_at(self, || PlanRenderingContext {
1660            indent: Indent::default(),
1661            humanizer: humanizer.unwrap_or(&DummyHumanizer),
1662            annotations: BTreeMap::default(),
1663            config,
1664            ambiguous_ids: BTreeSet::default(),
1665        })
1666    }
1667
1668    /// Take ownership of `self`, leaving an empty `MirRelationExpr::Constant` with the optionally
1669    /// given scalar types. The given scalar types should be `base_eq` with the types that `typ()`
1670    /// would find. Keys and nullability are ignored in the given `SqlRelationType`, and instead we set
1671    /// the best possible key and nullability, since we are making an empty collection.
1672    ///
1673    /// If `typ` is not given, then this calls `.typ()` (which is possibly expensive) to determine
1674    /// the correct type.
1675    pub fn take_safely(&mut self, typ: Option<SqlRelationType>) -> MirRelationExpr {
1676        if let Some(typ) = &typ {
1677            soft_assert_no_log!(
1678                self.typ()
1679                    .column_types
1680                    .iter()
1681                    .zip_eq(typ.column_types.iter())
1682                    .all(|(t1, t2)| t1
1683                        .scalar_type
1684                        .base_eq_or_repr_eq_for_assertion(&t2.scalar_type))
1685            );
1686        }
1687        let mut typ = typ.unwrap_or_else(|| self.typ());
1688        typ.keys = vec![vec![]];
1689        for ct in typ.column_types.iter_mut() {
1690            ct.nullable = false;
1691        }
1692        std::mem::replace(
1693            self,
1694            MirRelationExpr::Constant {
1695                rows: Ok(vec![]),
1696                typ,
1697            },
1698        )
1699    }
1700
1701    /// Take ownership of `self`, leaving an empty `MirRelationExpr::Constant` with the given scalar
1702    /// types. Nullability is ignored in the given `SqlColumnType`s, and instead we set the best
1703    /// possible nullability, since we are making an empty collection.
1704    pub fn take_safely_with_col_types(&mut self, typ: Vec<SqlColumnType>) -> MirRelationExpr {
1705        self.take_safely(Some(SqlRelationType::new(typ)))
1706    }
1707
1708    /// Take ownership of `self`, leaving an empty `MirRelationExpr::Constant` with an **incorrect** type.
1709    ///
1710    /// This should only be used if `self` is about to be dropped or otherwise overwritten.
1711    pub fn take_dangerous(&mut self) -> MirRelationExpr {
1712        let empty = MirRelationExpr::Constant {
1713            rows: Ok(vec![]),
1714            typ: SqlRelationType::new(Vec::new()),
1715        };
1716        std::mem::replace(self, empty)
1717    }
1718
1719    /// Replaces `self` with some logic applied to `self`.
1720    pub fn replace_using<F>(&mut self, logic: F)
1721    where
1722        F: FnOnce(MirRelationExpr) -> MirRelationExpr,
1723    {
1724        let empty = MirRelationExpr::Constant {
1725            rows: Ok(vec![]),
1726            typ: SqlRelationType::new(Vec::new()),
1727        };
1728        let expr = std::mem::replace(self, empty);
1729        *self = logic(expr);
1730    }
1731
1732    /// Store `self` in a `Let` and pass the corresponding `Get` to `body`.
1733    pub fn let_in<Body, E>(self, id_gen: &mut IdGen, body: Body) -> Result<MirRelationExpr, E>
1734    where
1735        Body: FnOnce(&mut IdGen, MirRelationExpr) -> Result<MirRelationExpr, E>,
1736    {
1737        if let MirRelationExpr::Get { .. } = self {
1738            // already done
1739            body(id_gen, self)
1740        } else {
1741            let id = LocalId::new(id_gen.allocate_id());
1742            let get = MirRelationExpr::Get {
1743                id: Id::Local(id),
1744                typ: self.typ(),
1745                access_strategy: AccessStrategy::UnknownOrLocal,
1746            };
1747            let body = (body)(id_gen, get)?;
1748            Ok(MirRelationExpr::Let {
1749                id,
1750                value: Box::new(self),
1751                body: Box::new(body),
1752            })
1753        }
1754    }
1755
1756    /// Return every row in `self` that does not have a matching row in the first columns of `keys_and_values`, using `default` to fill in the remaining columns
1757    /// (If `default` is a row of nulls, this is the 'outer' part of LEFT OUTER JOIN)
1758    pub fn anti_lookup<E>(
1759        self,
1760        id_gen: &mut IdGen,
1761        keys_and_values: MirRelationExpr,
1762        default: Vec<(Datum, SqlScalarType)>,
1763    ) -> Result<MirRelationExpr, E> {
1764        let (data, column_types): (Vec<_>, Vec<_>) = default
1765            .into_iter()
1766            .map(|(datum, scalar_type)| (datum, scalar_type.nullable(datum.is_null())))
1767            .unzip();
1768        assert_eq!(keys_and_values.arity() - self.arity(), data.len());
1769        self.let_in(id_gen, |_id_gen, get_keys| {
1770            let get_keys_arity = get_keys.arity();
1771            Ok(MirRelationExpr::join(
1772                vec![
1773                    // all the missing keys (with count 1)
1774                    keys_and_values
1775                        .distinct_by((0..get_keys_arity).collect())
1776                        .negate()
1777                        .union(get_keys.clone().distinct()),
1778                    // join with keys to get the correct counts
1779                    get_keys.clone(),
1780                ],
1781                (0..get_keys_arity).map(|i| vec![(0, i), (1, i)]).collect(),
1782            )
1783            // get rid of the extra copies of columns from keys
1784            .project((0..get_keys_arity).collect())
1785            // This join is logically equivalent to
1786            // `.map(<default_expr>)`, but using a join allows for
1787            // potential predicate pushdown and elision in the
1788            // optimizer.
1789            .product(MirRelationExpr::constant(
1790                vec![data],
1791                SqlRelationType::new(column_types),
1792            )))
1793        })
1794    }
1795
1796    /// Return:
1797    /// * every row in keys_and_values
1798    /// * every row in `self` that does not have a matching row in the first columns of
1799    ///   `keys_and_values`, using `default` to fill in the remaining columns
1800    /// (This is LEFT OUTER JOIN if:
1801    /// 1) `default` is a row of null
1802    /// 2) matching rows in `keys_and_values` and `self` have the same multiplicity.)
1803    pub fn lookup<E>(
1804        self,
1805        id_gen: &mut IdGen,
1806        keys_and_values: MirRelationExpr,
1807        default: Vec<(Datum<'static>, SqlScalarType)>,
1808    ) -> Result<MirRelationExpr, E> {
1809        keys_and_values.let_in(id_gen, |id_gen, get_keys_and_values| {
1810            Ok(get_keys_and_values.clone().union(self.anti_lookup(
1811                id_gen,
1812                get_keys_and_values,
1813                default,
1814            )?))
1815        })
1816    }
1817
1818    /// True iff the expression contains a `NullaryFunc::MzLogicalTimestamp`.
1819    pub fn contains_temporal(&self) -> bool {
1820        let mut contains = false;
1821        self.visit_scalars(&mut |e| contains = contains || e.contains_temporal());
1822        contains
1823    }
1824
1825    /// Fallible visitor for the [`MirScalarExpr`]s directly owned by this relation expression.
1826    ///
1827    /// The `f` visitor should not recursively descend into owned [`MirRelationExpr`]s.
1828    pub fn try_visit_scalars_mut1<F, E>(&mut self, f: &mut F) -> Result<(), E>
1829    where
1830        F: FnMut(&mut MirScalarExpr) -> Result<(), E>,
1831    {
1832        use MirRelationExpr::*;
1833        match self {
1834            Map { scalars, .. } => {
1835                for s in scalars {
1836                    f(s)?;
1837                }
1838            }
1839            Filter { predicates, .. } => {
1840                for p in predicates {
1841                    f(p)?;
1842                }
1843            }
1844            FlatMap { exprs, .. } => {
1845                for expr in exprs {
1846                    f(expr)?;
1847                }
1848            }
1849            Join {
1850                inputs: _,
1851                equivalences,
1852                implementation,
1853            } => {
1854                for equivalence in equivalences {
1855                    for expr in equivalence {
1856                        f(expr)?;
1857                    }
1858                }
1859                match implementation {
1860                    JoinImplementation::Differential((_, start_key, _), order) => {
1861                        if let Some(start_key) = start_key {
1862                            for k in start_key {
1863                                f(k)?;
1864                            }
1865                        }
1866                        for (_, lookup_key, _) in order {
1867                            for k in lookup_key {
1868                                f(k)?;
1869                            }
1870                        }
1871                    }
1872                    JoinImplementation::DeltaQuery(paths) => {
1873                        for path in paths {
1874                            for (_, lookup_key, _) in path {
1875                                for k in lookup_key {
1876                                    f(k)?;
1877                                }
1878                            }
1879                        }
1880                    }
1881                    JoinImplementation::IndexedFilter(_coll_id, _idx_id, index_key, _) => {
1882                        for k in index_key {
1883                            f(k)?;
1884                        }
1885                    }
1886                    JoinImplementation::Unimplemented => {} // No scalar exprs
1887                }
1888            }
1889            ArrangeBy { keys, .. } => {
1890                for key in keys {
1891                    for s in key {
1892                        f(s)?;
1893                    }
1894                }
1895            }
1896            Reduce {
1897                group_key,
1898                aggregates,
1899                ..
1900            } => {
1901                for s in group_key {
1902                    f(s)?;
1903                }
1904                for agg in aggregates {
1905                    f(&mut agg.expr)?;
1906                }
1907            }
1908            TopK { limit, .. } => {
1909                if let Some(s) = limit {
1910                    f(s)?;
1911                }
1912            }
1913            Constant { .. }
1914            | Get { .. }
1915            | Let { .. }
1916            | LetRec { .. }
1917            | Project { .. }
1918            | Negate { .. }
1919            | Threshold { .. }
1920            | Union { .. } => (),
1921        }
1922        Ok(())
1923    }
1924
1925    /// Fallible mutable visitor for the [`MirScalarExpr`]s in the [`MirRelationExpr`] subtree
1926    /// rooted at `self`.
1927    ///
1928    /// Note that this does not recurse into [`MirRelationExpr`] subtrees within [`MirScalarExpr`]
1929    /// nodes.
1930    pub fn try_visit_scalars_mut<F, E>(&mut self, f: &mut F) -> Result<(), E>
1931    where
1932        F: FnMut(&mut MirScalarExpr) -> Result<(), E>,
1933        E: From<RecursionLimitError>,
1934    {
1935        self.try_visit_mut_post(&mut |expr| expr.try_visit_scalars_mut1(f))
1936    }
1937
1938    /// Infallible mutable visitor for the [`MirScalarExpr`]s in the [`MirRelationExpr`] subtree
1939    /// rooted at `self`.
1940    ///
1941    /// Note that this does not recurse into [`MirRelationExpr`] subtrees within [`MirScalarExpr`]
1942    /// nodes.
1943    pub fn visit_scalars_mut<F>(&mut self, f: &mut F)
1944    where
1945        F: FnMut(&mut MirScalarExpr),
1946    {
1947        self.try_visit_scalars_mut(&mut |s| {
1948            f(s);
1949            Ok::<_, RecursionLimitError>(())
1950        })
1951        .expect("Unexpected error in `visit_scalars_mut` call");
1952    }
1953
1954    /// Fallible visitor for the [`MirScalarExpr`]s directly owned by this relation expression.
1955    ///
1956    /// The `f` visitor should not recursively descend into owned [`MirRelationExpr`]s.
1957    pub fn try_visit_scalars_1<F, E>(&self, f: &mut F) -> Result<(), E>
1958    where
1959        F: FnMut(&MirScalarExpr) -> Result<(), E>,
1960    {
1961        use MirRelationExpr::*;
1962        match self {
1963            Map { scalars, .. } => {
1964                for s in scalars {
1965                    f(s)?;
1966                }
1967            }
1968            Filter { predicates, .. } => {
1969                for p in predicates {
1970                    f(p)?;
1971                }
1972            }
1973            FlatMap { exprs, .. } => {
1974                for expr in exprs {
1975                    f(expr)?;
1976                }
1977            }
1978            Join {
1979                inputs: _,
1980                equivalences,
1981                implementation,
1982            } => {
1983                for equivalence in equivalences {
1984                    for expr in equivalence {
1985                        f(expr)?;
1986                    }
1987                }
1988                match implementation {
1989                    JoinImplementation::Differential((_, start_key, _), order) => {
1990                        if let Some(start_key) = start_key {
1991                            for k in start_key {
1992                                f(k)?;
1993                            }
1994                        }
1995                        for (_, lookup_key, _) in order {
1996                            for k in lookup_key {
1997                                f(k)?;
1998                            }
1999                        }
2000                    }
2001                    JoinImplementation::DeltaQuery(paths) => {
2002                        for path in paths {
2003                            for (_, lookup_key, _) in path {
2004                                for k in lookup_key {
2005                                    f(k)?;
2006                                }
2007                            }
2008                        }
2009                    }
2010                    JoinImplementation::IndexedFilter(_coll_id, _idx_id, index_key, _) => {
2011                        for k in index_key {
2012                            f(k)?;
2013                        }
2014                    }
2015                    JoinImplementation::Unimplemented => {} // No scalar exprs
2016                }
2017            }
2018            ArrangeBy { keys, .. } => {
2019                for key in keys {
2020                    for s in key {
2021                        f(s)?;
2022                    }
2023                }
2024            }
2025            Reduce {
2026                group_key,
2027                aggregates,
2028                ..
2029            } => {
2030                for s in group_key {
2031                    f(s)?;
2032                }
2033                for agg in aggregates {
2034                    f(&agg.expr)?;
2035                }
2036            }
2037            TopK { limit, .. } => {
2038                if let Some(s) = limit {
2039                    f(s)?;
2040                }
2041            }
2042            Constant { .. }
2043            | Get { .. }
2044            | Let { .. }
2045            | LetRec { .. }
2046            | Project { .. }
2047            | Negate { .. }
2048            | Threshold { .. }
2049            | Union { .. } => (),
2050        }
2051        Ok(())
2052    }
2053
2054    /// Fallible immutable visitor for the [`MirScalarExpr`]s in the [`MirRelationExpr`] subtree
2055    /// rooted at `self`.
2056    ///
2057    /// Note that this does not recurse into [`MirRelationExpr`] subtrees within [`MirScalarExpr`]
2058    /// nodes.
2059    pub fn try_visit_scalars<F, E>(&self, f: &mut F) -> Result<(), E>
2060    where
2061        F: FnMut(&MirScalarExpr) -> Result<(), E>,
2062        E: From<RecursionLimitError>,
2063    {
2064        self.try_visit_post(&mut |expr| expr.try_visit_scalars_1(f))
2065    }
2066
2067    /// Infallible immutable visitor for the [`MirScalarExpr`]s in the [`MirRelationExpr`] subtree
2068    /// rooted at `self`.
2069    ///
2070    /// Note that this does not recurse into [`MirRelationExpr`] subtrees within [`MirScalarExpr`]
2071    /// nodes.
2072    pub fn visit_scalars<F>(&self, f: &mut F)
2073    where
2074        F: FnMut(&MirScalarExpr),
2075    {
2076        self.try_visit_scalars(&mut |s| {
2077            f(s);
2078            Ok::<_, RecursionLimitError>(())
2079        })
2080        .expect("Unexpected error in `visit_scalars` call");
2081    }
2082
2083    /// Clears the contents of `self` even if it's so deep that simply dropping it would cause a
2084    /// stack overflow in `drop_in_place`.
2085    ///
2086    /// Leaves `self` in an unusable state, so this should only be used if `self` is about to be
2087    /// dropped or otherwise overwritten.
2088    pub fn destroy_carefully(&mut self) {
2089        let mut todo = vec![self.take_dangerous()];
2090        while let Some(mut expr) = todo.pop() {
2091            for child in expr.children_mut() {
2092                todo.push(child.take_dangerous());
2093            }
2094        }
2095    }
2096
2097    /// Computes the size (total number of nodes) and maximum depth of a MirRelationExpr for
2098    /// debug printing purposes.
2099    pub fn debug_size_and_depth(&self) -> (usize, usize) {
2100        let mut size = 0;
2101        let mut max_depth = 0;
2102        let mut todo = vec![(self, 1)];
2103        while let Some((expr, depth)) = todo.pop() {
2104            size += 1;
2105            max_depth = max(max_depth, depth);
2106            todo.extend(expr.children().map(|c| (c, depth + 1)));
2107        }
2108        (size, max_depth)
2109    }
2110
2111    /// The MirRelationExpr is considered potentially expensive if and only if
2112    /// at least one of the following conditions is true:
2113    ///
2114    ///  - It contains at least one FlatMap or a Reduce operator.
2115    ///  - It contains at least one MirScalarExpr with a function call.
2116    ///
2117    /// !!!WARNING!!!: this method has an HirRelationExpr counterpart. The two
2118    /// should be kept in sync w.r.t. HIR ⇒ MIR lowering!
2119    pub fn could_run_expensive_function(&self) -> bool {
2120        let mut result = false;
2121        self.visit_pre(|e: &MirRelationExpr| {
2122            use MirRelationExpr::*;
2123            use MirScalarExpr::*;
2124            if let Err(_) = self.try_visit_scalars::<_, RecursionLimitError>(&mut |scalar| {
2125                result |= match scalar {
2126                    Column(_, _) | Literal(_, _) | CallUnmaterializable(_) | If { .. } => false,
2127                    // Function calls are considered expensive
2128                    CallUnary { .. } | CallBinary { .. } | CallVariadic { .. } => true,
2129                };
2130                Ok(())
2131            }) {
2132                // Conservatively set `true` if on RecursionLimitError.
2133                result = true;
2134            }
2135            // FlatMap has a table function; Reduce has an aggregate function.
2136            // Other constructs use MirScalarExpr to run a function
2137            result |= matches!(e, FlatMap { .. } | Reduce { .. });
2138        });
2139        result
2140    }
2141
2142    /// Hash to an u64 using Rust's default Hasher. (Which is a somewhat slower, but better Hasher
2143    /// than what `Hashable::hashed` would give us.)
2144    pub fn hash_to_u64(&self) -> u64 {
2145        let mut h = DefaultHasher::new();
2146        self.hash(&mut h);
2147        h.finish()
2148    }
2149}
2150
2151// `LetRec` helpers
2152impl MirRelationExpr {
2153    /// True when `expr` contains a `LetRec` AST node.
2154    pub fn is_recursive(self: &MirRelationExpr) -> bool {
2155        let mut worklist = vec![self];
2156        while let Some(expr) = worklist.pop() {
2157            if let MirRelationExpr::LetRec { .. } = expr {
2158                return true;
2159            }
2160            worklist.extend(expr.children());
2161        }
2162        false
2163    }
2164
2165    /// Return the number of sub-expressions in the tree (including self).
2166    pub fn size(&self) -> usize {
2167        let mut size = 0;
2168        self.visit_pre(|_| size += 1);
2169        size
2170    }
2171
2172    /// Given the ids and values of a LetRec, it computes the subset of ids that are used across
2173    /// iterations. These are those ids that have a reference before they are defined, when reading
2174    /// all the bindings in order.
2175    ///
2176    /// For example:
2177    /// ```SQL
2178    /// WITH MUTUALLY RECURSIVE
2179    ///     x(...) AS f(z),
2180    ///     y(...) AS g(x),
2181    ///     z(...) AS h(y)
2182    /// ...;
2183    /// ```
2184    /// Here, only `z` is returned, because `x` and `y` are referenced only within the same
2185    /// iteration.
2186    ///
2187    /// Note that if a binding references itself, that is also returned.
2188    pub fn recursive_ids(ids: &[LocalId], values: &[MirRelationExpr]) -> BTreeSet<LocalId> {
2189        let mut used_across_iterations = BTreeSet::new();
2190        let mut defined = BTreeSet::new();
2191        for (binding_id, value) in itertools::zip_eq(ids.iter(), values.iter()) {
2192            value.visit_pre(|expr| {
2193                if let MirRelationExpr::Get {
2194                    id: Local(get_id), ..
2195                } = expr
2196                {
2197                    // If we haven't seen a definition for it yet, then this will refer
2198                    // to the previous iteration.
2199                    // The `ids.contains` part of the condition is needed to exclude
2200                    // those ids that are not really in this LetRec, but either an inner
2201                    // or outer one.
2202                    if !defined.contains(get_id) && ids.contains(get_id) {
2203                        used_across_iterations.insert(*get_id);
2204                    }
2205                }
2206            });
2207            defined.insert(*binding_id);
2208        }
2209        used_across_iterations
2210    }
2211
2212    /// Replaces `LetRec` nodes with a stack of `Let` nodes.
2213    ///
2214    /// In each `Let` binding, uses of `Get` in `value` that are not at strictly greater
2215    /// identifiers are rewritten to be the constant collection.
2216    /// This makes the computation perform exactly "one" iteration.
2217    ///
2218    /// This was used only temporarily while developing `LetRec`.
2219    pub fn make_nonrecursive(self: &mut MirRelationExpr) {
2220        let mut deadlist = BTreeSet::new();
2221        let mut worklist = vec![self];
2222        while let Some(expr) = worklist.pop() {
2223            if let MirRelationExpr::LetRec {
2224                ids,
2225                values,
2226                limits: _,
2227                body,
2228            } = expr
2229            {
2230                let ids_values = values
2231                    .drain(..)
2232                    .zip_eq(ids)
2233                    .map(|(value, id)| (*id, value))
2234                    .collect::<Vec<_>>();
2235                *expr = body.take_dangerous();
2236                for (id, mut value) in ids_values.into_iter().rev() {
2237                    // Remove references to potentially recursive identifiers.
2238                    deadlist.insert(id);
2239                    value.visit_pre_mut(|e| {
2240                        if let MirRelationExpr::Get {
2241                            id: crate::Id::Local(id),
2242                            typ,
2243                            ..
2244                        } = e
2245                        {
2246                            let typ = typ.clone();
2247                            if deadlist.contains(id) {
2248                                e.take_safely(Some(typ));
2249                            }
2250                        }
2251                    });
2252                    *expr = MirRelationExpr::Let {
2253                        id,
2254                        value: Box::new(value),
2255                        body: Box::new(expr.take_dangerous()),
2256                    };
2257                }
2258                worklist.push(expr);
2259            } else {
2260                worklist.extend(expr.children_mut().rev());
2261            }
2262        }
2263    }
2264
2265    /// For each Id `id'` referenced in `expr`, if it is larger or equal than `id`, then record in
2266    /// `expire_whens` that when `id'` is redefined, then we should expire the information that
2267    /// we are holding about `id`. Call `do_expirations` with `expire_whens` at each Id
2268    /// redefinition.
2269    ///
2270    /// IMPORTANT: Relies on the numbering of Ids to be what `renumber_bindings` gives.
2271    pub fn collect_expirations(
2272        id: LocalId,
2273        expr: &MirRelationExpr,
2274        expire_whens: &mut BTreeMap<LocalId, Vec<LocalId>>,
2275    ) {
2276        expr.visit_pre(|e| {
2277            if let MirRelationExpr::Get {
2278                id: Id::Local(referenced_id),
2279                ..
2280            } = e
2281            {
2282                // The following check needs `renumber_bindings` to have run recently
2283                if referenced_id >= &id {
2284                    expire_whens
2285                        .entry(*referenced_id)
2286                        .or_insert_with(Vec::new)
2287                        .push(id);
2288                }
2289            }
2290        });
2291    }
2292
2293    /// Call this function when `id` is redefined. It modifies `id_infos` by removing information
2294    /// about such Ids whose information depended on the earlier definition of `id`, according to
2295    /// `expire_whens`. Also modifies `expire_whens`: it removes the currently processed entry.
2296    pub fn do_expirations<I>(
2297        redefined_id: LocalId,
2298        expire_whens: &mut BTreeMap<LocalId, Vec<LocalId>>,
2299        id_infos: &mut BTreeMap<LocalId, I>,
2300    ) -> Vec<(LocalId, I)> {
2301        let mut expired_infos = Vec::new();
2302        if let Some(expirations) = expire_whens.remove(&redefined_id) {
2303            for expired_id in expirations.into_iter() {
2304                if let Some(offer) = id_infos.remove(&expired_id) {
2305                    expired_infos.push((expired_id, offer));
2306                }
2307            }
2308        }
2309        expired_infos
2310    }
2311}
2312/// Augment non-nullability of columns, by observing either
2313/// 1. Predicates that explicitly test for null values, and
2314/// 2. Columns that if null would make a predicate be null.
2315pub fn non_nullable_columns(predicates: &[MirScalarExpr]) -> BTreeSet<usize> {
2316    let mut nonnull_required_columns = BTreeSet::new();
2317    for predicate in predicates {
2318        // Add any columns that being null would force the predicate to be null.
2319        // Should that happen, the row would be discarded.
2320        predicate.non_null_requirements(&mut nonnull_required_columns);
2321
2322        /*
2323        Test for explicit checks that a column is non-null.
2324
2325        This analysis is ad hoc, and will miss things:
2326
2327        materialize=> create table a(x int, y int);
2328        CREATE TABLE
2329        materialize=> explain with(types) select x from a where (y=x and y is not null) or x is not null;
2330        Optimized Plan
2331        --------------------------------------------------------------------------------------------------------
2332        Explained Query:                                                                                      +
2333        Project (#0) // { types: "(integer?)" }                                                             +
2334        Filter ((#0) IS NOT NULL OR ((#1) IS NOT NULL AND (#0 = #1))) // { types: "(integer?, integer?)" }+
2335        Get materialize.public.a // { types: "(integer?, integer?)" }                                   +
2336                                                                                  +
2337        Source materialize.public.a                                                                           +
2338        filter=(((#0) IS NOT NULL OR ((#1) IS NOT NULL AND (#0 = #1))))                                     +
2339
2340        (1 row)
2341        */
2342
2343        if let MirScalarExpr::CallUnary {
2344            func: UnaryFunc::Not(scalar_func::Not),
2345            expr,
2346        } = predicate
2347        {
2348            if let MirScalarExpr::CallUnary {
2349                func: UnaryFunc::IsNull(scalar_func::IsNull),
2350                expr,
2351            } = &**expr
2352            {
2353                if let MirScalarExpr::Column(c, _name) = &**expr {
2354                    nonnull_required_columns.insert(*c);
2355                }
2356            }
2357        }
2358    }
2359
2360    nonnull_required_columns
2361}
2362
2363impl CollectionPlan for MirRelationExpr {
2364    /// Collects the global collections that this MIR expression directly depends on, i.e., that it
2365    /// has a `Get` for. (It does _not_ traverse view definitions transitively.)
2366    ///
2367    /// !!!WARNING!!!: this method has an HirRelationExpr counterpart. The two
2368    /// should be kept in sync w.r.t. HIR ⇒ MIR lowering!
2369    fn depends_on_into(&self, out: &mut BTreeSet<GlobalId>) {
2370        if let MirRelationExpr::Get {
2371            id: Id::Global(id), ..
2372        } = self
2373        {
2374            out.insert(*id);
2375        }
2376        self.visit_children(|expr| expr.depends_on_into(out))
2377    }
2378}
2379
impl MirRelationExpr {
    /// Iterates through references to child expressions.
    pub fn children(&self) -> impl DoubleEndedIterator<Item = &Self> {
        // Every variant has at most: one or two leading children (`first`, `second`),
        // a variable-length middle (`rest`), and one trailing child (`last`). Chaining
        // these `Option`s/`Vec`s yields the children in left-to-right order for every
        // variant with a single, allocation-free iterator type.
        let mut first = None;
        let mut second = None;
        let mut rest = None;
        let mut last = None;

        use MirRelationExpr::*;
        match self {
            // Leaves: no children.
            Constant { .. } | Get { .. } => (),
            Let { value, body, .. } => {
                first = Some(&**value);
                second = Some(&**body);
            }
            LetRec { values, body, .. } => {
                rest = Some(values);
                last = Some(&**body);
            }
            // Single-input operators.
            Project { input, .. }
            | Map { input, .. }
            | FlatMap { input, .. }
            | Filter { input, .. }
            | Reduce { input, .. }
            | TopK { input, .. }
            | Negate { input }
            | Threshold { input }
            | ArrangeBy { input, .. } => {
                first = Some(&**input);
            }
            Join { inputs, .. } => {
                rest = Some(inputs);
            }
            Union { base, inputs } => {
                first = Some(&**base);
                rest = Some(inputs);
            }
        }

        first
            .into_iter()
            .chain(second)
            .chain(rest.into_iter().flatten())
            .chain(last)
    }

    /// Iterates through mutable references to child expressions.
    ///
    /// Mirrors [`MirRelationExpr::children`]; keep the two in sync.
    pub fn children_mut(&mut self) -> impl DoubleEndedIterator<Item = &mut Self> {
        let mut first = None;
        let mut second = None;
        let mut rest = None;
        let mut last = None;

        use MirRelationExpr::*;
        match self {
            // Leaves: no children.
            Constant { .. } | Get { .. } => (),
            Let { value, body, .. } => {
                first = Some(&mut **value);
                second = Some(&mut **body);
            }
            LetRec { values, body, .. } => {
                rest = Some(values);
                last = Some(&mut **body);
            }
            // Single-input operators.
            Project { input, .. }
            | Map { input, .. }
            | FlatMap { input, .. }
            | Filter { input, .. }
            | Reduce { input, .. }
            | TopK { input, .. }
            | Negate { input }
            | Threshold { input }
            | ArrangeBy { input, .. } => {
                first = Some(&mut **input);
            }
            Join { inputs, .. } => {
                rest = Some(inputs);
            }
            Union { base, inputs } => {
                first = Some(&mut **base);
                rest = Some(inputs);
            }
        }

        first
            .into_iter()
            .chain(second)
            .chain(rest.into_iter().flatten())
            .chain(last)
    }

    /// Iterative pre-order visitor.
    pub fn visit_pre<'a, F: FnMut(&'a Self)>(&'a self, mut f: F) {
        let mut worklist = vec![self];
        while let Some(expr) = worklist.pop() {
            f(expr);
            // Push children reversed so the leftmost child is popped first,
            // preserving left-to-right pre-order.
            worklist.extend(expr.children().rev());
        }
    }

    /// Iterative pre-order visitor.
    pub fn visit_pre_mut<F: FnMut(&mut Self)>(&mut self, mut f: F) {
        let mut worklist = vec![self];
        while let Some(expr) = worklist.pop() {
            f(expr);
            // Push children reversed so the leftmost child is popped first,
            // preserving left-to-right pre-order.
            worklist.extend(expr.children_mut().rev());
        }
    }

    /// Return a vector of references to the subtrees of this expression
    /// in post-visit order (the last element is `&self`).
    pub fn post_order_vec(&self) -> Vec<&Self> {
        // Collect a right-to-left pre-order (children pushed in order, popped
        // last-first), then reverse it: that is exactly a left-to-right post-order.
        let mut stack = vec![self];
        let mut result = vec![];
        while let Some(expr) = stack.pop() {
            result.push(expr);
            stack.extend(expr.children());
        }
        result.reverse();
        result
    }
}
2502
2503impl VisitChildren<Self> for MirRelationExpr {
2504    fn visit_children<F>(&self, mut f: F)
2505    where
2506        F: FnMut(&Self),
2507    {
2508        for child in self.children() {
2509            f(child)
2510        }
2511    }
2512
2513    fn visit_mut_children<F>(&mut self, mut f: F)
2514    where
2515        F: FnMut(&mut Self),
2516    {
2517        for child in self.children_mut() {
2518            f(child)
2519        }
2520    }
2521
2522    fn try_visit_children<F, E>(&self, mut f: F) -> Result<(), E>
2523    where
2524        F: FnMut(&Self) -> Result<(), E>,
2525        E: From<RecursionLimitError>,
2526    {
2527        for child in self.children() {
2528            f(child)?
2529        }
2530        Ok(())
2531    }
2532
2533    fn try_visit_mut_children<F, E>(&mut self, mut f: F) -> Result<(), E>
2534    where
2535        F: FnMut(&mut Self) -> Result<(), E>,
2536        E: From<RecursionLimitError>,
2537    {
2538        for child in self.children_mut() {
2539            f(child)?
2540        }
2541        Ok(())
2542    }
2543}
2544
/// Specification for an ordering by a column.
#[derive(
    Debug,
    Clone,
    Copy,
    Eq,
    PartialEq,
    Ord,
    PartialOrd,
    Serialize,
    Deserialize,
    Hash,
    MzReflect
)]
pub struct ColumnOrder {
    /// The column index.
    pub column: usize,
    /// Whether to sort in descending order.
    /// `#[serde(default)]` makes this deserialize to `false` (ascending) when absent.
    #[serde(default)]
    pub desc: bool,
    /// Whether to sort nulls last.
    /// `#[serde(default)]` makes this deserialize to `false` (nulls first) when absent.
    #[serde(default)]
    pub nulls_last: bool,
}
2569
impl Columnation for ColumnOrder {
    // `ColumnOrder` derives `Copy`, so a bitwise-copy region suffices: there are no
    // owned allocations to spill into the region.
    type InnerRegion = CopyRegion<Self>;
}
2573
2574impl<'a, M> fmt::Display for HumanizedExpr<'a, ColumnOrder, M>
2575where
2576    M: HumanizerMode,
2577{
2578    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
2579        // If you modify this, then please also attend to Display for ColumnOrderWithExpr!
2580        write!(
2581            f,
2582            "{} {} {}",
2583            self.child(&self.expr.column),
2584            if self.expr.desc { "desc" } else { "asc" },
2585            if self.expr.nulls_last {
2586                "nulls_last"
2587            } else {
2588                "nulls_first"
2589            },
2590        )
2591    }
2592}
2593
/// Describes an aggregation expression.
#[derive(
    Clone,
    Debug,
    Eq,
    PartialEq,
    Ord,
    PartialOrd,
    Serialize,
    Deserialize,
    Hash,
    MzReflect
)]
pub struct AggregateExpr {
    /// Names the aggregation function.
    pub func: AggregateFunc,
    /// An expression which extracts from each row the input to `func`.
    pub expr: MirScalarExpr,
    /// Should the aggregation be applied only to distinct results in each group.
    /// `#[serde(default)]` makes this deserialize to `false` when absent.
    #[serde(default)]
    pub distinct: bool,
}
2616
2617impl AggregateExpr {
2618    /// Computes the type of this `AggregateExpr`.
2619    pub fn typ(&self, column_types: &[SqlColumnType]) -> SqlColumnType {
2620        self.func.output_type(self.expr.typ(column_types))
2621    }
2622
2623    /// Computes the type of this `AggregateExpr`.
2624    pub fn repr_typ(&self, column_types: &[ReprColumnType]) -> ReprColumnType {
2625        ReprColumnType::from(
2626            &self
2627                .func
2628                .output_type(SqlColumnType::from_repr(&self.expr.repr_typ(column_types))),
2629        )
2630    }
2631
2632    /// Returns whether the expression has a constant result.
2633    pub fn is_constant(&self) -> bool {
2634        match self.func {
2635            AggregateFunc::MaxInt16
2636            | AggregateFunc::MaxInt32
2637            | AggregateFunc::MaxInt64
2638            | AggregateFunc::MaxUInt16
2639            | AggregateFunc::MaxUInt32
2640            | AggregateFunc::MaxUInt64
2641            | AggregateFunc::MaxMzTimestamp
2642            | AggregateFunc::MaxFloat32
2643            | AggregateFunc::MaxFloat64
2644            | AggregateFunc::MaxBool
2645            | AggregateFunc::MaxString
2646            | AggregateFunc::MaxDate
2647            | AggregateFunc::MaxTimestamp
2648            | AggregateFunc::MaxTimestampTz
2649            | AggregateFunc::MinInt16
2650            | AggregateFunc::MinInt32
2651            | AggregateFunc::MinInt64
2652            | AggregateFunc::MinUInt16
2653            | AggregateFunc::MinUInt32
2654            | AggregateFunc::MinUInt64
2655            | AggregateFunc::MinMzTimestamp
2656            | AggregateFunc::MinFloat32
2657            | AggregateFunc::MinFloat64
2658            | AggregateFunc::MinBool
2659            | AggregateFunc::MinString
2660            | AggregateFunc::MinDate
2661            | AggregateFunc::MinTimestamp
2662            | AggregateFunc::MinTimestampTz
2663            | AggregateFunc::Any
2664            | AggregateFunc::All
2665            | AggregateFunc::Dummy => self.expr.is_literal(),
2666            AggregateFunc::Count => self.expr.is_literal_null(),
2667            _ => self.expr.is_literal_err(),
2668        }
2669    }
2670
2671    /// Returns an expression that computes `self` on a group that has exactly one row.
2672    /// Instead of performing a `Reduce` with `self`, one can perform a `Map` with the expression
2673    /// returned by `on_unique`, which is cheaper. (See `ReduceElision`.)
2674    pub fn on_unique(&self, input_type: &[SqlColumnType]) -> MirScalarExpr {
2675        match &self.func {
2676            // Count is one if non-null, and zero if null.
2677            AggregateFunc::Count => self
2678                .expr
2679                .clone()
2680                .call_unary(UnaryFunc::IsNull(crate::func::IsNull))
2681                .if_then_else(
2682                    MirScalarExpr::literal_ok(Datum::Int64(0), SqlScalarType::Int64),
2683                    MirScalarExpr::literal_ok(Datum::Int64(1), SqlScalarType::Int64),
2684                ),
2685
2686            // SumInt16 takes Int16s as input, but outputs Int64s.
2687            AggregateFunc::SumInt16 => self
2688                .expr
2689                .clone()
2690                .call_unary(UnaryFunc::CastInt16ToInt64(scalar_func::CastInt16ToInt64)),
2691
2692            // SumInt32 takes Int32s as input, but outputs Int64s.
2693            AggregateFunc::SumInt32 => self
2694                .expr
2695                .clone()
2696                .call_unary(UnaryFunc::CastInt32ToInt64(scalar_func::CastInt32ToInt64)),
2697
2698            // SumInt64 takes Int64s as input, but outputs numerics.
2699            AggregateFunc::SumInt64 => self.expr.clone().call_unary(UnaryFunc::CastInt64ToNumeric(
2700                scalar_func::CastInt64ToNumeric(Some(NumericMaxScale::ZERO)),
2701            )),
2702
2703            // SumUInt16 takes UInt16s as input, but outputs UInt64s.
2704            AggregateFunc::SumUInt16 => self.expr.clone().call_unary(
2705                UnaryFunc::CastUint16ToUint64(scalar_func::CastUint16ToUint64),
2706            ),
2707
2708            // SumUInt32 takes UInt32s as input, but outputs UInt64s.
2709            AggregateFunc::SumUInt32 => self.expr.clone().call_unary(
2710                UnaryFunc::CastUint32ToUint64(scalar_func::CastUint32ToUint64),
2711            ),
2712
2713            // SumUInt64 takes UInt64s as input, but outputs numerics.
2714            AggregateFunc::SumUInt64 => {
2715                self.expr.clone().call_unary(UnaryFunc::CastUint64ToNumeric(
2716                    scalar_func::CastUint64ToNumeric(Some(NumericMaxScale::ZERO)),
2717                ))
2718            }
2719
2720            // JsonbAgg takes _anything_ as input, but must output a Jsonb array.
2721            AggregateFunc::JsonbAgg { .. } => MirScalarExpr::CallVariadic {
2722                func: VariadicFunc::JsonbBuildArray,
2723                exprs: vec![
2724                    self.expr
2725                        .clone()
2726                        .call_unary(UnaryFunc::RecordGet(scalar_func::RecordGet(0))),
2727                ],
2728            },
2729
2730            // JsonbAgg takes _anything_ as input, but must output a Jsonb object.
2731            AggregateFunc::JsonbObjectAgg { .. } => {
2732                let record = self
2733                    .expr
2734                    .clone()
2735                    .call_unary(UnaryFunc::RecordGet(scalar_func::RecordGet(0)));
2736                MirScalarExpr::CallVariadic {
2737                    func: VariadicFunc::JsonbBuildObject,
2738                    exprs: (0..2)
2739                        .map(|i| {
2740                            record
2741                                .clone()
2742                                .call_unary(UnaryFunc::RecordGet(scalar_func::RecordGet(i)))
2743                        })
2744                        .collect(),
2745                }
2746            }
2747
2748            AggregateFunc::MapAgg { value_type, .. } => {
2749                let record = self
2750                    .expr
2751                    .clone()
2752                    .call_unary(UnaryFunc::RecordGet(scalar_func::RecordGet(0)));
2753                MirScalarExpr::CallVariadic {
2754                    func: VariadicFunc::MapBuild {
2755                        value_type: value_type.clone(),
2756                    },
2757                    exprs: (0..2)
2758                        .map(|i| {
2759                            record
2760                                .clone()
2761                                .call_unary(UnaryFunc::RecordGet(scalar_func::RecordGet(i)))
2762                        })
2763                        .collect(),
2764                }
2765            }
2766
2767            // StringAgg takes nested records of strings and outputs a string
2768            AggregateFunc::StringAgg { .. } => self
2769                .expr
2770                .clone()
2771                .call_unary(UnaryFunc::RecordGet(scalar_func::RecordGet(0)))
2772                .call_unary(UnaryFunc::RecordGet(scalar_func::RecordGet(0))),
2773
2774            // ListConcat and ArrayConcat take a single level of records and output a list containing exactly 1 element
2775            AggregateFunc::ListConcat { .. } | AggregateFunc::ArrayConcat { .. } => self
2776                .expr
2777                .clone()
2778                .call_unary(UnaryFunc::RecordGet(scalar_func::RecordGet(0))),
2779
2780            // RowNumber, Rank, DenseRank take a list of records and output a list containing exactly 1 element
2781            AggregateFunc::RowNumber { .. } => {
2782                self.on_unique_ranking_window_funcs(input_type, "?row_number?")
2783            }
2784            AggregateFunc::Rank { .. } => self.on_unique_ranking_window_funcs(input_type, "?rank?"),
2785            AggregateFunc::DenseRank { .. } => {
2786                self.on_unique_ranking_window_funcs(input_type, "?dense_rank?")
2787            }
2788
2789            // The input type for LagLead is ((OriginalRow, (InputValue, Offset, Default)), OrderByExprs...)
2790            AggregateFunc::LagLead { lag_lead, .. } => {
2791                let tuple = self
2792                    .expr
2793                    .clone()
2794                    .call_unary(UnaryFunc::RecordGet(scalar_func::RecordGet(0)));
2795
2796                // Get the overall return type
2797                let return_type_with_orig_row = self
2798                    .typ(input_type)
2799                    .scalar_type
2800                    .unwrap_list_element_type()
2801                    .clone();
2802                let lag_lead_return_type =
2803                    return_type_with_orig_row.unwrap_record_element_type()[0].clone();
2804
2805                // Extract the original row
2806                let original_row = tuple
2807                    .clone()
2808                    .call_unary(UnaryFunc::RecordGet(scalar_func::RecordGet(0)));
2809
2810                // Extract the encoded args
2811                let encoded_args =
2812                    tuple.call_unary(UnaryFunc::RecordGet(scalar_func::RecordGet(1)));
2813
2814                let (result_expr, column_name) =
2815                    Self::on_unique_lag_lead(lag_lead, encoded_args, lag_lead_return_type.clone());
2816
2817                MirScalarExpr::CallVariadic {
2818                    func: VariadicFunc::ListCreate {
2819                        elem_type: return_type_with_orig_row,
2820                    },
2821                    exprs: vec![MirScalarExpr::CallVariadic {
2822                        func: VariadicFunc::RecordCreate {
2823                            field_names: vec![column_name, ColumnName::from("?record?")],
2824                        },
2825                        exprs: vec![result_expr, original_row],
2826                    }],
2827                }
2828            }
2829
2830            // The input type for FirstValue is ((OriginalRow, InputValue), OrderByExprs...)
2831            AggregateFunc::FirstValue { window_frame, .. } => {
2832                let tuple = self
2833                    .expr
2834                    .clone()
2835                    .call_unary(UnaryFunc::RecordGet(scalar_func::RecordGet(0)));
2836
2837                // Get the overall return type
2838                let return_type_with_orig_row = self
2839                    .typ(input_type)
2840                    .scalar_type
2841                    .unwrap_list_element_type()
2842                    .clone();
2843                let first_value_return_type =
2844                    return_type_with_orig_row.unwrap_record_element_type()[0].clone();
2845
2846                // Extract the original row
2847                let original_row = tuple
2848                    .clone()
2849                    .call_unary(UnaryFunc::RecordGet(scalar_func::RecordGet(0)));
2850
2851                // Extract the input value
2852                let arg = tuple.call_unary(UnaryFunc::RecordGet(scalar_func::RecordGet(1)));
2853
2854                let (result_expr, column_name) = Self::on_unique_first_value_last_value(
2855                    window_frame,
2856                    arg,
2857                    first_value_return_type,
2858                );
2859
2860                MirScalarExpr::CallVariadic {
2861                    func: VariadicFunc::ListCreate {
2862                        elem_type: return_type_with_orig_row,
2863                    },
2864                    exprs: vec![MirScalarExpr::CallVariadic {
2865                        func: VariadicFunc::RecordCreate {
2866                            field_names: vec![column_name, ColumnName::from("?record?")],
2867                        },
2868                        exprs: vec![result_expr, original_row],
2869                    }],
2870                }
2871            }
2872
2873            // The input type for LastValue is ((OriginalRow, InputValue), OrderByExprs...)
2874            AggregateFunc::LastValue { window_frame, .. } => {
2875                let tuple = self
2876                    .expr
2877                    .clone()
2878                    .call_unary(UnaryFunc::RecordGet(scalar_func::RecordGet(0)));
2879
2880                // Get the overall return type
2881                let return_type_with_orig_row = self
2882                    .typ(input_type)
2883                    .scalar_type
2884                    .unwrap_list_element_type()
2885                    .clone();
2886                let last_value_return_type =
2887                    return_type_with_orig_row.unwrap_record_element_type()[0].clone();
2888
2889                // Extract the original row
2890                let original_row = tuple
2891                    .clone()
2892                    .call_unary(UnaryFunc::RecordGet(scalar_func::RecordGet(0)));
2893
2894                // Extract the input value
2895                let arg = tuple.call_unary(UnaryFunc::RecordGet(scalar_func::RecordGet(1)));
2896
2897                let (result_expr, column_name) = Self::on_unique_first_value_last_value(
2898                    window_frame,
2899                    arg,
2900                    last_value_return_type,
2901                );
2902
2903                MirScalarExpr::CallVariadic {
2904                    func: VariadicFunc::ListCreate {
2905                        elem_type: return_type_with_orig_row,
2906                    },
2907                    exprs: vec![MirScalarExpr::CallVariadic {
2908                        func: VariadicFunc::RecordCreate {
2909                            field_names: vec![column_name, ColumnName::from("?record?")],
2910                        },
2911                        exprs: vec![result_expr, original_row],
2912                    }],
2913                }
2914            }
2915
2916            // The input type for window aggs is ((OriginalRow, InputValue), OrderByExprs...)
2917            // See an example MIR in `window_func_applied_to`.
2918            AggregateFunc::WindowAggregate {
2919                wrapped_aggregate,
2920                window_frame,
2921                order_by: _,
2922            } => {
2923                // TODO: deduplicate code between the various window function cases.
2924
2925                let tuple = self
2926                    .expr
2927                    .clone()
2928                    .call_unary(UnaryFunc::RecordGet(scalar_func::RecordGet(0)));
2929
2930                // Get the overall return type
2931                let return_type = self
2932                    .typ(input_type)
2933                    .scalar_type
2934                    .unwrap_list_element_type()
2935                    .clone();
2936                let window_agg_return_type = return_type.unwrap_record_element_type()[0].clone();
2937
2938                // Extract the original row
2939                let original_row = tuple
2940                    .clone()
2941                    .call_unary(UnaryFunc::RecordGet(scalar_func::RecordGet(0)));
2942
2943                // Extract the input value
2944                let arg_expr = tuple.call_unary(UnaryFunc::RecordGet(scalar_func::RecordGet(1)));
2945
2946                let (result, column_name) = Self::on_unique_window_agg(
2947                    window_frame,
2948                    arg_expr,
2949                    input_type,
2950                    window_agg_return_type,
2951                    wrapped_aggregate,
2952                );
2953
2954                MirScalarExpr::CallVariadic {
2955                    func: VariadicFunc::ListCreate {
2956                        elem_type: return_type,
2957                    },
2958                    exprs: vec![MirScalarExpr::CallVariadic {
2959                        func: VariadicFunc::RecordCreate {
2960                            field_names: vec![column_name, ColumnName::from("?record?")],
2961                        },
2962                        exprs: vec![result, original_row],
2963                    }],
2964                }
2965            }
2966
2967            // The input type is ((OriginalRow, (Arg1, Arg2, ...)), OrderByExprs...)
2968            AggregateFunc::FusedWindowAggregate {
2969                wrapped_aggregates,
2970                order_by: _,
2971                window_frame,
2972            } => {
2973                // Throw away OrderByExprs
2974                let tuple = self
2975                    .expr
2976                    .clone()
2977                    .call_unary(UnaryFunc::RecordGet(scalar_func::RecordGet(0)));
2978
2979                // Extract the original row
2980                let original_row = tuple
2981                    .clone()
2982                    .call_unary(UnaryFunc::RecordGet(scalar_func::RecordGet(0)));
2983
2984                // Extract the args of the fused call
2985                let all_args = tuple.call_unary(UnaryFunc::RecordGet(scalar_func::RecordGet(1)));
2986
2987                let return_type_with_orig_row = self
2988                    .typ(input_type)
2989                    .scalar_type
2990                    .unwrap_list_element_type()
2991                    .clone();
2992
2993                let all_func_return_types =
2994                    return_type_with_orig_row.unwrap_record_element_type()[0].clone();
2995                let mut func_result_exprs = Vec::new();
2996                let mut col_names = Vec::new();
2997                for (idx, wrapped_aggr) in wrapped_aggregates.iter().enumerate() {
2998                    let arg = all_args
2999                        .clone()
3000                        .call_unary(UnaryFunc::RecordGet(scalar_func::RecordGet(idx)));
3001                    let return_type =
3002                        all_func_return_types.unwrap_record_element_type()[idx].clone();
3003                    let (result, column_name) = Self::on_unique_window_agg(
3004                        window_frame,
3005                        arg,
3006                        input_type,
3007                        return_type,
3008                        wrapped_aggr,
3009                    );
3010                    func_result_exprs.push(result);
3011                    col_names.push(column_name);
3012                }
3013
3014                MirScalarExpr::CallVariadic {
3015                    func: VariadicFunc::ListCreate {
3016                        elem_type: return_type_with_orig_row,
3017                    },
3018                    exprs: vec![MirScalarExpr::CallVariadic {
3019                        func: VariadicFunc::RecordCreate {
3020                            field_names: vec![
3021                                ColumnName::from("?fused_window_aggr?"),
3022                                ColumnName::from("?record?"),
3023                            ],
3024                        },
3025                        exprs: vec![
3026                            MirScalarExpr::CallVariadic {
3027                                func: VariadicFunc::RecordCreate {
3028                                    field_names: col_names,
3029                                },
3030                                exprs: func_result_exprs,
3031                            },
3032                            original_row,
3033                        ],
3034                    }],
3035                }
3036            }
3037
3038            // The input type is ((OriginalRow, (Args1, Args2, ...)), OrderByExprs...)
3039            AggregateFunc::FusedValueWindowFunc {
3040                funcs,
3041                order_by: outer_order_by,
3042            } => {
3043                // Throw away OrderByExprs
3044                let tuple = self
3045                    .expr
3046                    .clone()
3047                    .call_unary(UnaryFunc::RecordGet(scalar_func::RecordGet(0)));
3048
3049                // Extract the original row
3050                let original_row = tuple
3051                    .clone()
3052                    .call_unary(UnaryFunc::RecordGet(scalar_func::RecordGet(0)));
3053
3054                // Extract the encoded args of the fused call
3055                let all_encoded_args =
3056                    tuple.call_unary(UnaryFunc::RecordGet(scalar_func::RecordGet(1)));
3057
3058                let return_type_with_orig_row = self
3059                    .typ(input_type)
3060                    .scalar_type
3061                    .unwrap_list_element_type()
3062                    .clone();
3063
3064                let all_func_return_types =
3065                    return_type_with_orig_row.unwrap_record_element_type()[0].clone();
3066                let mut func_result_exprs = Vec::new();
3067                let mut col_names = Vec::new();
3068                for (idx, func) in funcs.iter().enumerate() {
3069                    let args_for_func = all_encoded_args
3070                        .clone()
3071                        .call_unary(UnaryFunc::RecordGet(scalar_func::RecordGet(idx)));
3072                    let return_type_for_func =
3073                        all_func_return_types.unwrap_record_element_type()[idx].clone();
3074                    let (result, column_name) = match func {
3075                        AggregateFunc::LagLead {
3076                            lag_lead,
3077                            order_by,
3078                            ignore_nulls: _,
3079                        } => {
3080                            assert_eq!(order_by, outer_order_by);
3081                            Self::on_unique_lag_lead(lag_lead, args_for_func, return_type_for_func)
3082                        }
3083                        AggregateFunc::FirstValue {
3084                            window_frame,
3085                            order_by,
3086                        } => {
3087                            assert_eq!(order_by, outer_order_by);
3088                            Self::on_unique_first_value_last_value(
3089                                window_frame,
3090                                args_for_func,
3091                                return_type_for_func,
3092                            )
3093                        }
3094                        AggregateFunc::LastValue {
3095                            window_frame,
3096                            order_by,
3097                        } => {
3098                            assert_eq!(order_by, outer_order_by);
3099                            Self::on_unique_first_value_last_value(
3100                                window_frame,
3101                                args_for_func,
3102                                return_type_for_func,
3103                            )
3104                        }
3105                        _ => panic!("unknown function in FusedValueWindowFunc"),
3106                    };
3107                    func_result_exprs.push(result);
3108                    col_names.push(column_name);
3109                }
3110
3111                MirScalarExpr::CallVariadic {
3112                    func: VariadicFunc::ListCreate {
3113                        elem_type: return_type_with_orig_row,
3114                    },
3115                    exprs: vec![MirScalarExpr::CallVariadic {
3116                        func: VariadicFunc::RecordCreate {
3117                            field_names: vec![
3118                                ColumnName::from("?fused_value_window_func?"),
3119                                ColumnName::from("?record?"),
3120                            ],
3121                        },
3122                        exprs: vec![
3123                            MirScalarExpr::CallVariadic {
3124                                func: VariadicFunc::RecordCreate {
3125                                    field_names: col_names,
3126                                },
3127                                exprs: func_result_exprs,
3128                            },
3129                            original_row,
3130                        ],
3131                    }],
3132                }
3133            }
3134
3135            // All other variants should return the argument to the aggregation.
3136            AggregateFunc::MaxNumeric
3137            | AggregateFunc::MaxInt16
3138            | AggregateFunc::MaxInt32
3139            | AggregateFunc::MaxInt64
3140            | AggregateFunc::MaxUInt16
3141            | AggregateFunc::MaxUInt32
3142            | AggregateFunc::MaxUInt64
3143            | AggregateFunc::MaxMzTimestamp
3144            | AggregateFunc::MaxFloat32
3145            | AggregateFunc::MaxFloat64
3146            | AggregateFunc::MaxBool
3147            | AggregateFunc::MaxString
3148            | AggregateFunc::MaxDate
3149            | AggregateFunc::MaxTimestamp
3150            | AggregateFunc::MaxTimestampTz
3151            | AggregateFunc::MaxInterval
3152            | AggregateFunc::MaxTime
3153            | AggregateFunc::MinNumeric
3154            | AggregateFunc::MinInt16
3155            | AggregateFunc::MinInt32
3156            | AggregateFunc::MinInt64
3157            | AggregateFunc::MinUInt16
3158            | AggregateFunc::MinUInt32
3159            | AggregateFunc::MinUInt64
3160            | AggregateFunc::MinMzTimestamp
3161            | AggregateFunc::MinFloat32
3162            | AggregateFunc::MinFloat64
3163            | AggregateFunc::MinBool
3164            | AggregateFunc::MinString
3165            | AggregateFunc::MinDate
3166            | AggregateFunc::MinTimestamp
3167            | AggregateFunc::MinTimestampTz
3168            | AggregateFunc::MinInterval
3169            | AggregateFunc::MinTime
3170            | AggregateFunc::SumFloat32
3171            | AggregateFunc::SumFloat64
3172            | AggregateFunc::SumNumeric
3173            | AggregateFunc::Any
3174            | AggregateFunc::All
3175            | AggregateFunc::Dummy => self.expr.clone(),
3176        }
3177    }
3178
3179    /// `on_unique` for ROW_NUMBER, RANK, DENSE_RANK
3180    fn on_unique_ranking_window_funcs(
3181        &self,
3182        input_type: &[SqlColumnType],
3183        col_name: &str,
3184    ) -> MirScalarExpr {
3185        let list = self
3186            .expr
3187            .clone()
3188            // extract the list within the record
3189            .call_unary(UnaryFunc::RecordGet(scalar_func::RecordGet(0)));
3190
3191        // extract the expression within the list
3192        let record = MirScalarExpr::CallVariadic {
3193            func: VariadicFunc::ListIndex,
3194            exprs: vec![
3195                list,
3196                MirScalarExpr::literal_ok(Datum::Int64(1), SqlScalarType::Int64),
3197            ],
3198        };
3199
3200        MirScalarExpr::CallVariadic {
3201            func: VariadicFunc::ListCreate {
3202                elem_type: self
3203                    .typ(input_type)
3204                    .scalar_type
3205                    .unwrap_list_element_type()
3206                    .clone(),
3207            },
3208            exprs: vec![MirScalarExpr::CallVariadic {
3209                func: VariadicFunc::RecordCreate {
3210                    field_names: vec![ColumnName::from(col_name), ColumnName::from("?record?")],
3211                },
3212                exprs: vec![
3213                    MirScalarExpr::literal_ok(Datum::Int64(1), SqlScalarType::Int64),
3214                    record,
3215                ],
3216            }],
3217        }
3218    }
3219
    /// `on_unique` for `lag` and `lead`
    ///
    /// `encoded_args` is a 3-field record `(expr, offset, default_value)` (the
    /// code below reads fields 0, 1, and 2 of it). Since `on_unique` handles a
    /// single-row partition, the semantics collapse to:
    /// - a NULL offset yields NULL (typed as `return_type`),
    /// - an offset of 0 points at the current row, yielding `expr` itself,
    /// - any other offset points outside the one-row window, yielding the
    ///   default value.
    fn on_unique_lag_lead(
        lag_lead: &LagLeadType,
        encoded_args: MirScalarExpr,
        return_type: SqlScalarType,
    ) -> (MirScalarExpr, ColumnName) {
        // Decompose the argument record: field 0 is the value expression,
        // field 1 the offset, field 2 the default value.
        let expr = encoded_args
            .clone()
            .call_unary(UnaryFunc::RecordGet(scalar_func::RecordGet(0)));
        let offset = encoded_args
            .clone()
            .call_unary(UnaryFunc::RecordGet(scalar_func::RecordGet(1)));
        let default_value =
            encoded_args.call_unary(UnaryFunc::RecordGet(scalar_func::RecordGet(2)));

        // In this case, the window always has only one element, so if the offset is not null and
        // not zero, the default value should be returned instead.
        let value = offset
            .clone()
            .call_binary(
                MirScalarExpr::literal_ok(Datum::Int32(0), SqlScalarType::Int32),
                crate::func::Eq,
            )
            .if_then_else(expr, default_value);
        // A NULL offset takes precedence over everything else: it produces a
        // NULL result regardless of the default value.
        let result_expr = offset
            .call_unary(UnaryFunc::IsNull(crate::func::IsNull))
            .if_then_else(MirScalarExpr::literal_null(return_type), value);

        // The synthesized column name records which window function produced the value.
        let column_name = ColumnName::from(match lag_lead {
            LagLeadType::Lag => "?lag?",
            LagLeadType::Lead => "?lead?",
        });

        (result_expr, column_name)
    }
3255
3256    /// `on_unique` for `first_value` and `last_value`
3257    fn on_unique_first_value_last_value(
3258        window_frame: &WindowFrame,
3259        arg: MirScalarExpr,
3260        return_type: SqlScalarType,
3261    ) -> (MirScalarExpr, ColumnName) {
3262        // If the window frame includes the current (single) row, return its value, null otherwise
3263        let result_expr = if window_frame.includes_current_row() {
3264            arg
3265        } else {
3266            MirScalarExpr::literal_null(return_type)
3267        };
3268        (result_expr, ColumnName::from("?first_value?"))
3269    }
3270
3271    /// `on_unique` for window aggregations
3272    fn on_unique_window_agg(
3273        window_frame: &WindowFrame,
3274        arg_expr: MirScalarExpr,
3275        input_type: &[SqlColumnType],
3276        return_type: SqlScalarType,
3277        wrapped_aggr: &AggregateFunc,
3278    ) -> (MirScalarExpr, ColumnName) {
3279        // If the window frame includes the current (single) row, evaluate the wrapped aggregate on
3280        // that row. Otherwise, return the default value for the aggregate.
3281        let result_expr = if window_frame.includes_current_row() {
3282            AggregateExpr {
3283                func: wrapped_aggr.clone(),
3284                expr: arg_expr,
3285                distinct: false, // We have just one input element; DISTINCT doesn't matter.
3286            }
3287            .on_unique(input_type)
3288        } else {
3289            MirScalarExpr::literal_ok(wrapped_aggr.default(), return_type)
3290        };
3291        (result_expr, ColumnName::from("?window_agg?"))
3292    }
3293
3294    /// Returns whether the expression is COUNT(*) or not.  Note that
3295    /// when we define the count builtin in sql::func, we convert
3296    /// COUNT(*) to COUNT(true), making it indistinguishable from
3297    /// literal COUNT(true), but we prefer to consider this as the
3298    /// former.
3299    ///
3300    /// (HIR has the same `is_count_asterisk`.)
3301    pub fn is_count_asterisk(&self) -> bool {
3302        self.func == AggregateFunc::Count && self.expr.is_literal_true() && !self.distinct
3303    }
3304}
3305
/// Describe a join implementation in dataflow.
#[derive(
    Clone,
    Debug,
    Eq,
    PartialEq,
    Ord,
    PartialOrd,
    Serialize,
    Deserialize,
    Hash,
    MzReflect
)]
pub enum JoinImplementation {
    /// Perform a sequence of binary differential dataflow joins.
    ///
    /// The first argument indicates
    /// 1) the index of the starting collection,
    /// 2) if it should be arranged, the keys to arrange it by, and
    /// 3) the characteristics of the starting collection (for EXPLAINing).
    ///
    /// The sequence that follows lists other relation indexes, and the key for
    /// the arrangement we should use when joining it in.
    /// The `JoinInputCharacteristics` are for EXPLAINing the characteristics
    /// that were used for join ordering.
    ///
    /// Each collection index should occur exactly once, either as the starting collection
    /// or somewhere in the list.
    Differential(
        (
            usize,
            Option<Vec<MirScalarExpr>>,
            Option<JoinInputCharacteristics>,
        ),
        Vec<(usize, Vec<MirScalarExpr>, Option<JoinInputCharacteristics>)>,
    ),
    /// Perform independent delta query dataflows for each input.
    ///
    /// The argument is a sequence of plans, one per input collection, in order.
    /// Each plan starts from the corresponding index, and then in sequence joins
    /// against collections identified by index and with the specified arrangement key.
    /// The `JoinInputCharacteristics` are for EXPLAINing the characteristics that were
    /// used for join ordering.
    DeltaQuery(Vec<Vec<(usize, Vec<MirScalarExpr>, Option<JoinInputCharacteristics>)>>),
    /// Join a user-created index with a constant collection to speed up the evaluation of a
    /// predicate such as `(f1 = 3 AND f2 = 5) OR (f1 = 7 AND f2 = 9)`.
    /// This gets translated to a Differential join during MIR -> LIR lowering, but we still want
    /// to represent it in MIR, because the fast path detection wants to match on this.
    ///
    /// Consists of (`<coll_id>`, `<index_id>`, `<index_key>`, `<constants>`)
    IndexedFilter(
        GlobalId,
        GlobalId,
        Vec<MirScalarExpr>,
        #[mzreflect(ignore)] Vec<Row>,
    ),
    /// No implementation yet selected.
    Unimplemented,
}
3364
3365impl Default for JoinImplementation {
3366    fn default() -> Self {
3367        JoinImplementation::Unimplemented
3368    }
3369}
3370
3371impl JoinImplementation {
3372    /// Returns `true` iff the value is not [`JoinImplementation::Unimplemented`].
3373    pub fn is_implemented(&self) -> bool {
3374        match self {
3375            Self::Unimplemented => false,
3376            _ => true,
3377        }
3378    }
3379
3380    /// Returns an optional implementation name if the value is not [`JoinImplementation::Unimplemented`].
3381    pub fn name(&self) -> Option<&'static str> {
3382        match self {
3383            Self::Differential(..) => Some("differential"),
3384            Self::DeltaQuery(..) => Some("delta"),
3385            Self::IndexedFilter(..) => Some("indexed_filter"),
3386            Self::Unimplemented => None,
3387        }
3388    }
3389}
3390
/// Characteristics of a join order candidate collection.
///
/// A candidate is described by a collection and a key, and may have various liabilities.
/// Primarily, the candidate may risk substantial inflation of records, which is something
/// that concerns us greatly. Additionally, the candidate may be unarranged, and we would
/// prefer candidates that do not require additional memory. Finally, we prefer lower id
/// collections in the interest of consistent tie-breaking. For more characteristics, see
/// comments on individual fields.
///
/// This has more than one version. `new` instantiates the appropriate version based on a
/// feature flag.
///
/// NOTE(review): this enum derives `Ord`, under which any `V1` compares less
/// than any `V2` by variant order; presumably a single planning run only ever
/// compares values of the same version (the flag is fixed per call to `new`) —
/// confirm.
#[derive(
    Eq,
    PartialEq,
    Ord,
    PartialOrd,
    Debug,
    Clone,
    Serialize,
    Deserialize,
    Hash,
    MzReflect
)]
pub enum JoinInputCharacteristics {
    /// Old version, with `enable_join_prioritize_arranged` turned off.
    V1(JoinInputCharacteristicsV1),
    /// Newer version, with `enable_join_prioritize_arranged` turned on.
    V2(JoinInputCharacteristicsV2),
}
3420
3421impl JoinInputCharacteristics {
3422    /// Creates a new instance with the given characteristics.
3423    pub fn new(
3424        unique_key: bool,
3425        key_length: usize,
3426        arranged: bool,
3427        cardinality: Option<usize>,
3428        filters: FilterCharacteristics,
3429        input: usize,
3430        enable_join_prioritize_arranged: bool,
3431    ) -> Self {
3432        if enable_join_prioritize_arranged {
3433            Self::V2(JoinInputCharacteristicsV2::new(
3434                unique_key,
3435                key_length,
3436                arranged,
3437                cardinality,
3438                filters,
3439                input,
3440            ))
3441        } else {
3442            Self::V1(JoinInputCharacteristicsV1::new(
3443                unique_key,
3444                key_length,
3445                arranged,
3446                cardinality,
3447                filters,
3448                input,
3449            ))
3450        }
3451    }
3452
3453    /// Turns the instance into a String to be printed in EXPLAIN.
3454    pub fn explain(&self) -> String {
3455        match self {
3456            Self::V1(jic) => jic.explain(),
3457            Self::V2(jic) => jic.explain(),
3458        }
3459    }
3460
3461    /// Whether the join input described by `self` is arranged.
3462    pub fn arranged(&self) -> bool {
3463        match self {
3464            Self::V1(jic) => jic.arranged,
3465            Self::V2(jic) => jic.arranged,
3466        }
3467    }
3468
3469    /// Returns the `FilterCharacteristics` for the join input described by `self`.
3470    pub fn filters(&mut self) -> &mut FilterCharacteristics {
3471        match self {
3472            Self::V1(jic) => &mut jic.filters,
3473            Self::V2(jic) => &mut jic.filters,
3474        }
3475    }
3476}
3477
/// Newer version of `JoinInputCharacteristics`, with `enable_join_prioritize_arranged` turned on.
///
/// Field order matters: the derived `PartialOrd`/`Ord` compare fields
/// lexicographically top-to-bottom, so earlier fields dominate the ordering.
#[derive(
    Eq,
    PartialEq,
    Ord,
    PartialOrd,
    Debug,
    Clone,
    Serialize,
    Deserialize,
    Hash,
    MzReflect
)]
pub struct JoinInputCharacteristicsV2 {
    /// An excellent indication that record count will not increase.
    pub unique_key: bool,
    /// Cross joins are bad.
    /// (`key_length > 0` also implies that it is not a cross join. However, we need to note cross
    /// joins in a separate field, because not being a cross join is more important than `arranged`,
    /// but otherwise `key_length` is less important than `arranged`.)
    pub not_cross: bool,
    /// Indicates that there will be no additional in-memory footprint.
    pub arranged: bool,
    /// A weaker signal that record count will not increase.
    pub key_length: usize,
    /// Estimated cardinality (lower is better; `Reverse` makes smaller counts
    /// compare as greater, i.e., preferred).
    pub cardinality: Option<std::cmp::Reverse<usize>>,
    /// Characteristics of the filter that is applied at this input.
    pub filters: FilterCharacteristics,
    /// We want to prefer input earlier in the input list, for stability of ordering.
    pub input: std::cmp::Reverse<usize>,
}
3510
3511impl JoinInputCharacteristicsV2 {
3512    /// Creates a new instance with the given characteristics.
3513    pub fn new(
3514        unique_key: bool,
3515        key_length: usize,
3516        arranged: bool,
3517        cardinality: Option<usize>,
3518        filters: FilterCharacteristics,
3519        input: usize,
3520    ) -> Self {
3521        Self {
3522            unique_key,
3523            not_cross: key_length > 0,
3524            arranged,
3525            key_length,
3526            cardinality: cardinality.map(std::cmp::Reverse),
3527            filters,
3528            input: std::cmp::Reverse(input),
3529        }
3530    }
3531
3532    /// Turns the instance into a String to be printed in EXPLAIN.
3533    pub fn explain(&self) -> String {
3534        let mut e = "".to_owned();
3535        if self.unique_key {
3536            e.push_str("U");
3537        }
3538        // Don't need to print `not_cross`, because that is visible in the printed key.
3539        // if !self.not_cross {
3540        //     e.push_str("C");
3541        // }
3542        for _ in 0..self.key_length {
3543            e.push_str("K");
3544        }
3545        if self.arranged {
3546            e.push_str("A");
3547        }
3548        if let Some(std::cmp::Reverse(cardinality)) = self.cardinality {
3549            e.push_str(&format!("|{cardinality}|"));
3550        }
3551        e.push_str(&self.filters.explain());
3552        e
3553    }
3554}
3555
/// Old version of `JoinInputCharacteristics`, with `enable_join_prioritize_arranged` turned off.
///
/// Field order matters: the derived `PartialOrd`/`Ord` compare fields
/// lexicographically top-to-bottom, so earlier fields dominate the ordering.
#[derive(
    Eq,
    PartialEq,
    Ord,
    PartialOrd,
    Debug,
    Clone,
    Serialize,
    Deserialize,
    Hash,
    MzReflect
)]
pub struct JoinInputCharacteristicsV1 {
    /// An excellent indication that record count will not increase.
    pub unique_key: bool,
    /// A weaker signal that record count will not increase.
    pub key_length: usize,
    /// Indicates that there will be no additional in-memory footprint.
    pub arranged: bool,
    /// Estimated cardinality (lower is better; `Reverse` makes smaller counts
    /// compare as greater, i.e., preferred).
    pub cardinality: Option<std::cmp::Reverse<usize>>,
    /// Characteristics of the filter that is applied at this input.
    pub filters: FilterCharacteristics,
    /// We want to prefer input earlier in the input list, for stability of ordering.
    pub input: std::cmp::Reverse<usize>,
}
3583
3584impl JoinInputCharacteristicsV1 {
3585    /// Creates a new instance with the given characteristics.
3586    pub fn new(
3587        unique_key: bool,
3588        key_length: usize,
3589        arranged: bool,
3590        cardinality: Option<usize>,
3591        filters: FilterCharacteristics,
3592        input: usize,
3593    ) -> Self {
3594        Self {
3595            unique_key,
3596            key_length,
3597            arranged,
3598            cardinality: cardinality.map(std::cmp::Reverse),
3599            filters,
3600            input: std::cmp::Reverse(input),
3601        }
3602    }
3603
3604    /// Turns the instance into a String to be printed in EXPLAIN.
3605    pub fn explain(&self) -> String {
3606        let mut e = "".to_owned();
3607        if self.unique_key {
3608            e.push_str("U");
3609        }
3610        for _ in 0..self.key_length {
3611            e.push_str("K");
3612        }
3613        if self.arranged {
3614            e.push_str("A");
3615        }
3616        if let Some(std::cmp::Reverse(cardinality)) = self.cardinality {
3617            e.push_str(&format!("|{cardinality}|"));
3618        }
3619        e.push_str(&self.filters.explain());
3620        e
3621    }
3622}
3623
/// Instructions for finishing the result of a query.
///
/// The primary reason for the existence of this structure and attendant code
/// is that SQL's ORDER BY requires sorting rows (as already implied by the
/// keywords), whereas much of the rest of SQL is defined in terms of unordered
/// multisets. But as it turns out, the same idea can be used to optimize
/// trivial peeks.
///
/// The generic parameters are for accommodating prepared statement parameters in
/// `limit` and `offset`: the planner can hold these fields as HirScalarExpr long enough to call
/// `bind_parameters` on them.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
pub struct RowSetFinishing<L = NonNeg<i64>, O = usize> {
    /// Order rows by the given columns.
    pub order_by: Vec<ColumnOrder>,
    /// Include at most this many rows (counted after `offset` is applied).
    pub limit: Option<L>,
    /// Skip this many rows from the start of the (ordered) result before
    /// applying `limit`.
    pub offset: O,
    /// Include only the given columns, in the given order (a projection).
    pub project: Vec<usize>,
}
3646
3647impl<L> RowSetFinishing<L> {
3648    /// Returns a trivial finishing, i.e., that does nothing to the result set.
3649    pub fn trivial(arity: usize) -> RowSetFinishing<L> {
3650        RowSetFinishing {
3651            order_by: Vec::new(),
3652            limit: None,
3653            offset: 0,
3654            project: (0..arity).collect(),
3655        }
3656    }
3657    /// True if the finishing does nothing to any result set.
3658    pub fn is_trivial(&self, arity: usize) -> bool {
3659        self.limit.is_none()
3660            && self.order_by.is_empty()
3661            && self.offset == 0
3662            && self.project.iter().copied().eq(0..arity)
3663    }
3664    /// True if the finishing does not require an ORDER BY.
3665    ///
3666    /// LIMIT and OFFSET without an ORDER BY _are_ streamable: without an
3667    /// explicit ordering we will skip an arbitrary bag of elements and return
3668    /// the first arbitrary elements in the remaining bag. The result semantics
3669    /// are still correct but maybe surprising for some users.
3670    pub fn is_streamable(&self, arity: usize) -> bool {
3671        self.order_by.is_empty() && self.project.iter().copied().eq(0..arity)
3672    }
3673}
3674
impl RowSetFinishing<NonNeg<i64>, usize> {
    /// The number of rows needed from before the finishing to evaluate the finishing:
    /// offset + limit.
    ///
    /// If it returns None, then we need all the rows.
    ///
    /// NOTE(review): the `+ self.offset` addition is unchecked; for values near
    /// `usize::MAX` it would overflow (panic in debug, wrap in release) —
    /// confirm that upstream validation bounds limit/offset.
    pub fn num_rows_needed(&self) -> Option<usize> {
        self.limit
            .as_ref()
            .map(|l| usize::cast_from(u64::from(l.clone())) + self.offset)
    }
}
3686
impl RowSetFinishing {
    /// Applies finishing actions to a [`RowCollection`], and reports the total
    /// time it took to run.
    ///
    /// Returns a [`SortedRowCollectionIter`] that contains all of the response data, as
    /// well as the size of the response in bytes.
    pub fn finish(
        &self,
        rows: RowCollection,
        max_result_size: u64,
        max_returned_query_size: Option<u64>,
        duration_histogram: &Histogram,
    ) -> Result<(SortedRowCollectionIter, usize), String> {
        // Time the whole finishing pass and record it in the histogram,
        // regardless of whether it succeeded.
        let now = Instant::now();
        let result = self.finish_inner(rows, max_result_size, max_returned_query_size);
        let duration = now.elapsed();
        duration_histogram.observe(duration.as_secs_f64());

        result
    }

    /// Implementation for [`RowSetFinishing::finish`].
    ///
    /// Order of operations: check the memory budget, build a sorted view,
    /// apply OFFSET and projection, apply LIMIT, then measure (and cap) the
    /// response size.
    fn finish_inner(
        &self,
        rows: RowCollection,
        max_result_size: u64,
        max_returned_query_size: Option<u64>,
    ) -> Result<(SortedRowCollectionIter, usize), String> {
        // How much additional memory is required to make a sorted view.
        let sorted_view_mem = rows.entries().saturating_mul(std::mem::size_of::<usize>());
        let required_memory = rows.byte_len().saturating_add(sorted_view_mem);

        // Bail if creating the sorted view would require us to use too much memory.
        if required_memory > usize::cast_from(max_result_size) {
            let max_bytes = ByteSize::b(max_result_size);
            return Err(format!("result exceeds max size of {max_bytes}",));
        }

        let sorted_view = rows.sorted_view(&self.order_by);
        let mut iter = sorted_view
            .into_row_iter()
            .apply_offset(self.offset)
            .with_projection(self.project.clone());

        if let Some(limit) = self.limit {
            let limit = u64::from(limit);
            let limit = usize::cast_from(limit);
            iter = iter.with_limit(limit);
        };

        // TODO(parkmycar): Re-think how we can calculate the total response size without
        // having to iterate through the entire collection of Rows, while still
        // respecting the LIMIT, OFFSET, and projections.
        //
        // Note: It feels a bit bad always calculating the response size, but we almost
        // always need it to either check the `max_returned_query_size`, or for reporting
        // in the query history.
        let response_size: usize = iter.clone().map(|row| row.data().len()).sum();

        // Bail if we would end up returning more data to the client than they can support.
        if let Some(max) = max_returned_query_size {
            if response_size > usize::cast_from(max) {
                let max_bytes = ByteSize::b(max);
                return Err(format!("result exceeds max size of {max_bytes}"));
            }
        }

        Ok((iter, response_size))
    }
}
3757
/// A [RowSetFinishing] that can be repeatedly applied to batches of updates (in
/// a [RowCollection]) and keeps track of the remaining limit, offset, and cap
/// on query result size.
#[derive(Debug)]
pub struct RowSetFinishingIncremental {
    /// How many more rows may still be emitted, across all remaining batches.
    /// `None` means no limit.
    pub remaining_limit: Option<usize>,
    /// How many more rows must still be skipped before any row is emitted.
    pub remaining_offset: usize,
    /// The maximum allowed result size, as requested by the client.
    pub max_returned_query_size: Option<u64>,
    /// Tracks our remaining allowed budget for result size.
    pub remaining_max_returned_query_size: Option<u64>,
    /// Include only the given columns, in the given order.
    pub project: Vec<usize>,
}
3774
3775impl RowSetFinishingIncremental {
3776    /// Turns the given [RowSetFinishing] into a [RowSetFinishingIncremental].
3777    /// Can only be used when [is_streamable](RowSetFinishing::is_streamable) is
3778    /// `true`.
3779    ///
3780    /// # Panics
3781    ///
3782    /// Panics if the result is not streamable, that is it has an ORDER BY.
3783    pub fn new(
3784        offset: usize,
3785        limit: Option<NonNeg<i64>>,
3786        project: Vec<usize>,
3787        max_returned_query_size: Option<u64>,
3788    ) -> Self {
3789        let limit = limit.map(|l| {
3790            let l = u64::from(l);
3791            let l = usize::cast_from(l);
3792            l
3793        });
3794
3795        RowSetFinishingIncremental {
3796            remaining_limit: limit,
3797            remaining_offset: offset,
3798            max_returned_query_size,
3799            remaining_max_returned_query_size: max_returned_query_size,
3800            project,
3801        }
3802    }
3803
3804    /// Applies finishing actions to the given [`RowCollection`], and reports
3805    /// the total time it took to run.
3806    ///
3807    /// Returns a [`SortedRowCollectionIter`] that contains all of the response
3808    /// data.
3809    pub fn finish_incremental(
3810        &mut self,
3811        rows: RowCollection,
3812        max_result_size: u64,
3813        duration_histogram: &Histogram,
3814    ) -> Result<SortedRowCollectionIter, String> {
3815        let now = Instant::now();
3816        let result = self.finish_incremental_inner(rows, max_result_size);
3817        let duration = now.elapsed();
3818        duration_histogram.observe(duration.as_secs_f64());
3819
3820        result
3821    }
3822
3823    fn finish_incremental_inner(
3824        &mut self,
3825        rows: RowCollection,
3826        max_result_size: u64,
3827    ) -> Result<SortedRowCollectionIter, String> {
3828        // How much additional memory is required to make a sorted view.
3829        let sorted_view_mem = rows.entries().saturating_mul(std::mem::size_of::<usize>());
3830        let required_memory = rows.byte_len().saturating_add(sorted_view_mem);
3831
3832        // Bail if creating the sorted view would require us to use too much memory.
3833        if required_memory > usize::cast_from(max_result_size) {
3834            let max_bytes = ByteSize::b(max_result_size);
3835            return Err(format!("total result exceeds max size of {max_bytes}",));
3836        }
3837
3838        let batch_num_rows = rows.count(0, None);
3839
3840        let sorted_view = rows.sorted_view(&[]);
3841        let mut iter = sorted_view
3842            .into_row_iter()
3843            .apply_offset(self.remaining_offset)
3844            .with_projection(self.project.clone());
3845
3846        if let Some(limit) = self.remaining_limit {
3847            iter = iter.with_limit(limit);
3848        };
3849
3850        self.remaining_offset = self.remaining_offset.saturating_sub(batch_num_rows);
3851        if let Some(remaining_limit) = self.remaining_limit.as_mut() {
3852            *remaining_limit -= iter.count();
3853        }
3854
3855        // TODO(parkmycar): Re-think how we can calculate the total response size without
3856        // having to iterate through the entire collection of Rows, while still
3857        // respecting the LIMIT, OFFSET, and projections.
3858        //
3859        // Note: It feels a bit bad always calculating the response size, but we almost
3860        // always need it to either check the `max_returned_query_size`, or for reporting
3861        // in the query history.
3862        let response_size: usize = iter.clone().map(|row| row.data().len()).sum();
3863
3864        // Bail if we would end up returning more data to the client than they can support.
3865        if let Some(max) = self.remaining_max_returned_query_size {
3866            if response_size > usize::cast_from(max) {
3867                let max_bytes = ByteSize::b(self.max_returned_query_size.expect("known to exist"));
3868                return Err(format!("total result exceeds max size of {max_bytes}"));
3869            }
3870        }
3871
3872        Ok(iter)
3873    }
3874}
3875
3876/// Compare `left` and `right` using `order`. If that doesn't produce a strict
3877/// ordering, call `tiebreaker`.
3878pub fn compare_columns<F>(
3879    order: &[ColumnOrder],
3880    left: &[Datum],
3881    right: &[Datum],
3882    tiebreaker: F,
3883) -> Ordering
3884where
3885    F: Fn() -> Ordering,
3886{
3887    for order in order {
3888        let cmp = match (&left[order.column], &right[order.column]) {
3889            (Datum::Null, Datum::Null) => Ordering::Equal,
3890            (Datum::Null, _) => {
3891                if order.nulls_last {
3892                    Ordering::Greater
3893                } else {
3894                    Ordering::Less
3895                }
3896            }
3897            (_, Datum::Null) => {
3898                if order.nulls_last {
3899                    Ordering::Less
3900                } else {
3901                    Ordering::Greater
3902                }
3903            }
3904            (lval, rval) => {
3905                if order.desc {
3906                    rval.cmp(lval)
3907                } else {
3908                    lval.cmp(rval)
3909                }
3910            }
3911        };
3912        if cmp != Ordering::Equal {
3913            return cmp;
3914        }
3915    }
3916    tiebreaker()
3917}
3918
/// Describe a window frame, e.g. `RANGE UNBOUNDED PRECEDING` or
/// `ROWS BETWEEN 5 PRECEDING AND CURRENT ROW`.
///
/// Window frames define a subset of the partition, and only a subset of
/// window functions make use of the window frame.
#[derive(
    Debug,
    Clone,
    Eq,
    PartialEq,
    Ord,
    PartialOrd,
    Serialize,
    Deserialize,
    Hash,
    MzReflect
)]
pub struct WindowFrame {
    /// ROWS, RANGE or GROUPS
    pub units: WindowFrameUnits,
    /// Where the frame starts
    pub start_bound: WindowFrameBound,
    /// Where the frame ends
    pub end_bound: WindowFrameBound,
}
3944
3945impl Display for WindowFrame {
3946    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
3947        write!(
3948            f,
3949            "{} between {} and {}",
3950            self.units, self.start_bound, self.end_bound
3951        )
3952    }
3953}
3954
impl WindowFrame {
    /// Return the default window frame used when one is not explicitly defined:
    /// `RANGE` from `UNBOUNDED PRECEDING` to `CURRENT ROW`.
    /// (Note: this is an inherent method, not a [`Default`] trait impl.)
    pub fn default() -> Self {
        WindowFrame {
            units: WindowFrameUnits::Range,
            start_bound: WindowFrameBound::UnboundedPreceding,
            end_bound: WindowFrameBound::CurrentRow,
        }
    }

    /// Whether the frame includes the current row, judging only by the start
    /// and end bounds (`units` is not consulted).
    ///
    /// The `unreachable!()` arms are end bounds that precede the start bound;
    /// presumably such frames are rejected during planning (see the ordering
    /// note on [`WindowFrameBound`]) — TODO(review): confirm.
    fn includes_current_row(&self) -> bool {
        use WindowFrameBound::*;
        match self.start_bound {
            UnboundedPreceding => match self.end_bound {
                UnboundedPreceding => false,
                // `0 PRECEDING` as the end bound is the current row itself.
                OffsetPreceding(0) => true,
                OffsetPreceding(_) => false,
                CurrentRow => true,
                OffsetFollowing(_) => true,
                UnboundedFollowing => true,
            },
            OffsetPreceding(0) => match self.end_bound {
                UnboundedPreceding => unreachable!(),
                OffsetPreceding(0) => true,
                // Any nonzero offsets here will create an empty window
                OffsetPreceding(_) => false,
                CurrentRow => true,
                OffsetFollowing(_) => true,
                UnboundedFollowing => true,
            },
            OffsetPreceding(_) => match self.end_bound {
                UnboundedPreceding => unreachable!(),
                // Window ends at the current row
                OffsetPreceding(0) => true,
                OffsetPreceding(_) => false,
                CurrentRow => true,
                OffsetFollowing(_) => true,
                UnboundedFollowing => true,
            },
            // A frame starting at the current row always contains it.
            CurrentRow => true,
            OffsetFollowing(0) => match self.end_bound {
                UnboundedPreceding => unreachable!(),
                OffsetPreceding(_) => unreachable!(),
                CurrentRow => unreachable!(),
                OffsetFollowing(_) => true,
                UnboundedFollowing => true,
            },
            // A frame starting strictly after the current row excludes it.
            OffsetFollowing(_) => match self.end_bound {
                UnboundedPreceding => unreachable!(),
                OffsetPreceding(_) => unreachable!(),
                CurrentRow => unreachable!(),
                OffsetFollowing(_) => false,
                UnboundedFollowing => false,
            },
            UnboundedFollowing => false,
        }
    }
}
4013
/// Describe how frame bounds are interpreted.
#[derive(
    Debug,
    Clone,
    Eq,
    PartialEq,
    Ord,
    PartialOrd,
    Serialize,
    Deserialize,
    Hash,
    MzReflect
)]
pub enum WindowFrameUnits {
    /// Each row is treated as the unit of work for bounds.
    Rows,
    /// Each peer group is treated as the unit of work for bounds,
    /// and offset-based bounds use the value of the ORDER BY expression.
    Range,
    /// Each peer group is treated as the unit of work for bounds.
    /// Groups is currently not supported, and it is rejected during planning.
    Groups,
}
4037
4038impl Display for WindowFrameUnits {
4039    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
4040        match self {
4041            WindowFrameUnits::Rows => write!(f, "rows"),
4042            WindowFrameUnits::Range => write!(f, "range"),
4043            WindowFrameUnits::Groups => write!(f, "groups"),
4044        }
4045    }
4046}
4047
/// Specifies [WindowFrame]'s `start_bound` and `end_bound`
///
/// The order between frame bounds is significant, as Postgres enforces
/// some restrictions there. (Note that the derived `PartialOrd`/`Ord`
/// follow the variant order below, so the variant order must not change.)
#[derive(
    Debug,
    Clone,
    Serialize,
    Deserialize,
    PartialEq,
    Eq,
    Hash,
    MzReflect,
    PartialOrd,
    Ord
)]
pub enum WindowFrameBound {
    /// `UNBOUNDED PRECEDING`
    UnboundedPreceding,
    /// `<N> PRECEDING`
    OffsetPreceding(u64),
    /// `CURRENT ROW`
    CurrentRow,
    /// `<N> FOLLOWING`
    OffsetFollowing(u64),
    /// `UNBOUNDED FOLLOWING`.
    UnboundedFollowing,
}
4076
4077impl Display for WindowFrameBound {
4078    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
4079        match self {
4080            WindowFrameBound::UnboundedPreceding => write!(f, "unbounded preceding"),
4081            WindowFrameBound::OffsetPreceding(offset) => write!(f, "{} preceding", offset),
4082            WindowFrameBound::CurrentRow => write!(f, "current row"),
4083            WindowFrameBound::OffsetFollowing(offset) => write!(f, "{} following", offset),
4084            WindowFrameBound::UnboundedFollowing => write!(f, "unbounded following"),
4085        }
4086    }
4087}
4088
/// Maximum iterations for a LetRec.
#[derive(
    Debug,
    Clone,
    Copy,
    PartialEq,
    Eq,
    PartialOrd,
    Ord,
    Hash,
    Serialize,
    Deserialize
)]
pub struct LetRecLimit {
    /// Maximum number of iterations to evaluate.
    pub max_iters: NonZeroU64,
    /// Whether to return the current result (rather than throw an error) when
    /// reaching the above limit.
    /// If true, we simply use the current contents of each Id as the final result.
    pub return_at_limit: bool,
}
4109
4110impl LetRecLimit {
4111    /// Compute the smallest limit from a Vec of `LetRecLimit`s.
4112    pub fn min_max_iter(limits: &Vec<Option<LetRecLimit>>) -> Option<u64> {
4113        limits
4114            .iter()
4115            .filter_map(|l| l.as_ref().map(|l| l.max_iters.get()))
4116            .min()
4117    }
4118
4119    /// The default value of `LetRecLimit::return_at_limit` when using the RECURSION LIMIT option of
4120    /// WMR without ERROR AT or RETURN AT.
4121    pub const RETURN_AT_LIMIT_DEFAULT: bool = false;
4122}
4123
4124impl Display for LetRecLimit {
4125    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
4126        write!(f, "[recursion_limit={}", self.max_iters)?;
4127        if self.return_at_limit != LetRecLimit::RETURN_AT_LIMIT_DEFAULT {
4128            write!(f, ", return_at_limit")?;
4129        }
4130        write!(f, "]")
4131    }
4132}
4133
/// For a global Get, this indicates whether we are going to read from Persist or from an index.
/// (See comment in MirRelationExpr::Get.)
#[derive(
    Clone,
    Debug,
    Eq,
    PartialEq,
    Ord,
    PartialOrd,
    Serialize,
    Deserialize,
    Hash
)]
pub enum AccessStrategy {
    /// It's either a local Get (a CTE), or unknown at the time.
    /// `prune_and_annotate_dataflow_index_imports` decides it for global Gets, and thus switches to
    /// one of the other variants.
    UnknownOrLocal,
    /// The Get will read from Persist.
    Persist,
    /// The Get will read from an index or indexes: (index id, how the index will be used).
    Index(Vec<(GlobalId, IndexUsageType)>),
    /// The Get will read a collection that is computed by the same dataflow, but in a different
    /// `BuildDesc` in `objects_to_build`.
    SameDataflow,
}
4160
#[cfg(test)]
mod tests {
    use mz_repr::explain::text::text_string_at;

    use crate::explain::HumanizedExplain;

    use super::*;

    /// Checks the textual rendering of a `RowSetFinishing` with an ORDER BY,
    /// a LIMIT, and a projection with a contiguous run of columns.
    #[mz_ore::test]
    fn test_row_set_finishing_as_text() {
        let finishing = RowSetFinishing {
            order_by: vec![ColumnOrder {
                column: 4,
                desc: true,
                nulls_last: true,
            }],
            limit: Some(NonNeg::try_from(7).unwrap()),
            offset: Default::default(),
            project: vec![1, 3, 4, 5],
        };

        let mode = HumanizedExplain::new(false);
        let expr = mode.expr(&finishing, None);

        let act = text_string_at(&expr, mz_ore::str::Indent::default);

        // Columns 3, 4, 5 render as the compact range `#3..=#5`.
        let exp = "Finish order_by=[#4 desc nulls_last] limit=7 output=[#1, #3..=#5]\n";

        assert_eq!(act, exp);
    }
}
4201
/// An iterator over AST structures, which calls out nodes in difference.
///
/// The iterators visit two ASTs in tandem, continuing as long as the AST node data matches,
/// and yielding an output pair as soon as the AST nodes do not match. Their intent is to call
/// attention to the moments in the ASTs where they differ, and incidentally a stack-free way
/// to compare two ASTs.
mod structured_diff {

    use super::MirRelationExpr;
    use itertools::Itertools;

    /// An iterator over structured differences between two `MirRelationExpr` instances.
    pub struct MreDiff<'a> {
        /// Pairs of expressions that must still be compared.
        todo: Vec<(&'a MirRelationExpr, &'a MirRelationExpr)>,
    }

    impl<'a> MreDiff<'a> {
        /// Create a new `MirRelationExpr` structured difference.
        pub fn new(expr1: &'a MirRelationExpr, expr2: &'a MirRelationExpr) -> Self {
            MreDiff {
                todo: vec![(expr1, expr2)],
            }
        }
    }

    impl<'a> Iterator for MreDiff<'a> {
        // Pairs of expressions that do not match.
        type Item = (&'a MirRelationExpr, &'a MirRelationExpr);

        fn next(&mut self) -> Option<Self::Item> {
            // Traversal over an explicit worklist (hence stack-safe). Each arm
            // compares the non-recursive fields of a same-variant pair of nodes:
            // if they match, the child pairs are pushed for later comparison;
            // otherwise the mismatching pair is yielded. For variants with child
            // lists (`LetRec`, `Join`, `Union`) the lengths are compared eagerly,
            // which also makes the `zip_eq` calls below safe.
            while let Some((expr1, expr2)) = self.todo.pop() {
                match (expr1, expr2) {
                    (
                        MirRelationExpr::Constant {
                            rows: rows1,
                            typ: typ1,
                        },
                        MirRelationExpr::Constant {
                            rows: rows2,
                            typ: typ2,
                        },
                    ) => {
                        if rows1 != rows2 || typ1 != typ2 {
                            return Some((expr1, expr2));
                        }
                    }
                    (
                        MirRelationExpr::Get {
                            id: id1,
                            typ: typ1,
                            access_strategy: as1,
                        },
                        MirRelationExpr::Get {
                            id: id2,
                            typ: typ2,
                            access_strategy: as2,
                        },
                    ) => {
                        if id1 != id2 || typ1 != typ2 || as1 != as2 {
                            return Some((expr1, expr2));
                        }
                    }
                    (
                        MirRelationExpr::Let {
                            id: id1,
                            body: body1,
                            value: value1,
                        },
                        MirRelationExpr::Let {
                            id: id2,
                            body: body2,
                            value: value2,
                        },
                    ) => {
                        if id1 != id2 {
                            return Some((expr1, expr2));
                        } else {
                            self.todo.push((body1, body2));
                            self.todo.push((value1, value2));
                        }
                    }
                    (
                        MirRelationExpr::LetRec {
                            ids: ids1,
                            body: body1,
                            values: values1,
                            limits: limits1,
                        },
                        MirRelationExpr::LetRec {
                            ids: ids2,
                            body: body2,
                            values: values2,
                            limits: limits2,
                        },
                    ) => {
                        if ids1 != ids2 || values1.len() != values2.len() || limits1 != limits2 {
                            return Some((expr1, expr2));
                        } else {
                            self.todo.push((body1, body2));
                            self.todo.extend(values1.iter().zip_eq(values2.iter()));
                        }
                    }
                    (
                        MirRelationExpr::Project {
                            outputs: outputs1,
                            input: input1,
                        },
                        MirRelationExpr::Project {
                            outputs: outputs2,
                            input: input2,
                        },
                    ) => {
                        if outputs1 != outputs2 {
                            return Some((expr1, expr2));
                        } else {
                            self.todo.push((input1, input2));
                        }
                    }
                    (
                        MirRelationExpr::Map {
                            scalars: scalars1,
                            input: input1,
                        },
                        MirRelationExpr::Map {
                            scalars: scalars2,
                            input: input2,
                        },
                    ) => {
                        if scalars1 != scalars2 {
                            return Some((expr1, expr2));
                        } else {
                            self.todo.push((input1, input2));
                        }
                    }
                    (
                        MirRelationExpr::Filter {
                            predicates: predicates1,
                            input: input1,
                        },
                        MirRelationExpr::Filter {
                            predicates: predicates2,
                            input: input2,
                        },
                    ) => {
                        if predicates1 != predicates2 {
                            return Some((expr1, expr2));
                        } else {
                            self.todo.push((input1, input2));
                        }
                    }
                    (
                        MirRelationExpr::FlatMap {
                            input: input1,
                            func: func1,
                            exprs: exprs1,
                        },
                        MirRelationExpr::FlatMap {
                            input: input2,
                            func: func2,
                            exprs: exprs2,
                        },
                    ) => {
                        if func1 != func2 || exprs1 != exprs2 {
                            return Some((expr1, expr2));
                        } else {
                            self.todo.push((input1, input2));
                        }
                    }
                    (
                        MirRelationExpr::Join {
                            inputs: inputs1,
                            equivalences: eq1,
                            implementation: impl1,
                        },
                        MirRelationExpr::Join {
                            inputs: inputs2,
                            equivalences: eq2,
                            implementation: impl2,
                        },
                    ) => {
                        if inputs1.len() != inputs2.len() || eq1 != eq2 || impl1 != impl2 {
                            return Some((expr1, expr2));
                        } else {
                            self.todo.extend(inputs1.iter().zip_eq(inputs2.iter()));
                        }
                    }
                    (
                        MirRelationExpr::Reduce {
                            aggregates: aggregates1,
                            input: inputs1,
                            group_key: gk1,
                            monotonic: m1,
                            expected_group_size: egs1,
                        },
                        MirRelationExpr::Reduce {
                            aggregates: aggregates2,
                            input: inputs2,
                            group_key: gk2,
                            monotonic: m2,
                            expected_group_size: egs2,
                        },
                    ) => {
                        if aggregates1 != aggregates2 || gk1 != gk2 || m1 != m2 || egs1 != egs2 {
                            return Some((expr1, expr2));
                        } else {
                            self.todo.push((inputs1, inputs2));
                        }
                    }
                    (
                        MirRelationExpr::TopK {
                            group_key: gk1,
                            order_key: order1,
                            input: input1,
                            limit: l1,
                            offset: o1,
                            monotonic: m1,
                            expected_group_size: egs1,
                        },
                        MirRelationExpr::TopK {
                            group_key: gk2,
                            order_key: order2,
                            input: input2,
                            limit: l2,
                            offset: o2,
                            monotonic: m2,
                            expected_group_size: egs2,
                        },
                    ) => {
                        if order1 != order2
                            || gk1 != gk2
                            || l1 != l2
                            || o1 != o2
                            || m1 != m2
                            || egs1 != egs2
                        {
                            return Some((expr1, expr2));
                        } else {
                            self.todo.push((input1, input2));
                        }
                    }
                    // `Negate` and `Threshold` carry no data of their own, so
                    // only the children need comparing.
                    (
                        MirRelationExpr::Negate { input: input1 },
                        MirRelationExpr::Negate { input: input2 },
                    ) => {
                        self.todo.push((input1, input2));
                    }
                    (
                        MirRelationExpr::Threshold { input: input1 },
                        MirRelationExpr::Threshold { input: input2 },
                    ) => {
                        self.todo.push((input1, input2));
                    }
                    (
                        MirRelationExpr::Union {
                            base: base1,
                            inputs: inputs1,
                        },
                        MirRelationExpr::Union {
                            base: base2,
                            inputs: inputs2,
                        },
                    ) => {
                        if inputs1.len() != inputs2.len() {
                            return Some((expr1, expr2));
                        } else {
                            self.todo.push((base1, base2));
                            self.todo.extend(inputs1.iter().zip_eq(inputs2.iter()));
                        }
                    }
                    (
                        MirRelationExpr::ArrangeBy {
                            keys: keys1,
                            input: input1,
                        },
                        MirRelationExpr::ArrangeBy {
                            keys: keys2,
                            input: input2,
                        },
                    ) => {
                        if keys1 != keys2 {
                            return Some((expr1, expr2));
                        } else {
                            self.todo.push((input1, input2));
                        }
                    }
                    // The two nodes have different variants: report the pair.
                    _ => {
                        return Some((expr1, expr2));
                    }
                }
            }
            // Worklist exhausted without finding a mismatch: the ASTs agree.
            None
        }
    }
}