// mz_expr/src/relation.rs

1// Copyright Materialize, Inc. and contributors. All rights reserved.
2//
3// Use of this software is governed by the Business Source License
4// included in the LICENSE file.
5//
6// As of the Change Date specified in that file, in accordance with
7// the Business Source License, use of this software will be governed
8// by the Apache License, Version 2.0.
9
10#![warn(missing_docs)]
11
12use std::cmp::{Ordering, max};
13use std::collections::{BTreeMap, BTreeSet};
14use std::fmt;
15use std::fmt::{Display, Formatter};
16use std::hash::{DefaultHasher, Hash, Hasher};
17use std::num::NonZeroU64;
18use std::time::Instant;
19
20use bytesize::ByteSize;
21use differential_dataflow::containers::{Columnation, CopyRegion};
22use itertools::Itertools;
23use mz_lowertest::MzReflect;
24use mz_ore::cast::CastFrom;
25use mz_ore::collections::CollectionExt;
26use mz_ore::id_gen::IdGen;
27use mz_ore::metrics::Histogram;
28use mz_ore::num::NonNeg;
29use mz_ore::soft_assert_no_log;
30use mz_ore::stack::RecursionLimitError;
31use mz_ore::str::Indent;
32use mz_repr::adt::numeric::NumericMaxScale;
33use mz_repr::explain::text::text_string_at;
34use mz_repr::explain::{
35    DummyHumanizer, ExplainConfig, ExprHumanizer, IndexUsageType, PlanRenderingContext,
36};
37use mz_repr::{
38    ColumnName, Datum, Diff, GlobalId, IntoRowIterator, ReprColumnType, Row, RowIterator,
39    SqlColumnType, SqlRelationType, SqlScalarType,
40};
41use serde::{Deserialize, Serialize};
42
43use crate::Id::Local;
44use crate::explain::{HumanizedExpr, HumanizerMode};
45use crate::relation::func::{AggregateFunc, LagLeadType, TableFunc};
46use crate::row::{RowCollection, SortedRowCollectionIter};
47use crate::visit::{Visit, VisitChildren};
48use crate::{
49    EvalError, FilterCharacteristics, Id, LocalId, MirScalarExpr, UnaryFunc, VariadicFunc,
50    func as scalar_func,
51};
52
/// Canonicalization routines for relation expressions.
pub mod canonicalize;
/// Definitions of aggregate and table functions (e.g. `AggregateFunc`, `TableFunc`).
pub mod func;
/// Mapping between join-input-local and join-global column indices.
pub mod join_input_mapper;

/// A recursion limit to be used for stack-safe traversals of [`MirRelationExpr`] trees.
///
/// The recursion limit must be large enough to accommodate for the linear representation
/// of some pathological but frequently occurring query fragments.
///
/// For example, in MIR we could have long chains of
/// - (1) `Let` bindings,
/// - (2) `CallBinary` calls with associative functions such as `+`
///
/// Until we fix those, we need to stick with the larger recursion limit.
pub const RECURSION_LIMIT: usize = 2048;
68
69/// A trait for types that describe how to build a collection.
70pub trait CollectionPlan {
71    /// Collects the set of global identifiers from dataflows referenced in Get.
72    fn depends_on_into(&self, out: &mut BTreeSet<GlobalId>);
73
74    /// Returns the set of global identifiers from dataflows referenced in Get.
75    ///
76    /// See [`CollectionPlan::depends_on_into`] to reuse an existing `BTreeSet`.
77    fn depends_on(&self) -> BTreeSet<GlobalId> {
78        let mut out = BTreeSet::new();
79        self.depends_on_into(&mut out);
80        out
81    }
82}
83
/// An abstract syntax tree which defines a collection.
///
/// The AST is meant to reflect the capabilities of the `differential_dataflow::Collection` type,
/// written generically enough to avoid run-time compilation work.
///
/// `derived_hash_with_manual_eq` was complaining for the wrong reason: This lint exists because
/// it's bad when `Eq` doesn't agree with `Hash`, which is often quite likely if one of them is
/// implemented manually. However, our manual implementation of `Eq` _will_ agree with the derived
/// one. This is because the reason for the manual implementation is not to change the semantics
/// from the derived one, but to avoid stack overflows.
#[allow(clippy::derived_hash_with_manual_eq)]
// NOTE: the derived `Ord`/`PartialOrd`/`Hash` depend on the variant and field order below;
// do not reorder variants or fields without considering those derives.
#[derive(Clone, Debug, Ord, PartialOrd, Serialize, Deserialize, MzReflect, Hash)]
pub enum MirRelationExpr {
    /// A constant relation containing specified rows.
    ///
    /// The runtime memory footprint of this operator is zero.
    ///
    /// When you would like to pattern match on this, consider using `MirRelationExpr::as_const`
    /// instead, which looks behind `ArrangeBy`s. You might want this matching behavior because
    /// constant folding doesn't remove `ArrangeBy`s.
    Constant {
        /// Rows of the constant collection and their multiplicities.
        rows: Result<Vec<(Row, Diff)>, EvalError>,
        /// Schema of the collection.
        typ: SqlRelationType,
    },
    /// Get an existing dataflow.
    ///
    /// The runtime memory footprint of this operator is zero.
    Get {
        /// The identifier for the collection to load.
        #[mzreflect(ignore)]
        id: Id,
        /// Schema of the collection.
        typ: SqlRelationType,
        /// If this is a global Get, this will indicate whether we are going to read from Persist or
        /// from an index, or from a different object in `objects_to_build`. If it's an index, then
        /// how downstream dataflow operations will use this index is also recorded. This is filled
        /// by `prune_and_annotate_dataflow_index_imports`. Note that this is not used by the
        /// lowering to LIR, but is used only by EXPLAIN.
        #[mzreflect(ignore)]
        access_strategy: AccessStrategy,
    },
    /// Introduce a temporary dataflow.
    ///
    /// The runtime memory footprint of this operator is zero.
    Let {
        /// The identifier to be used in `Get` variants to retrieve `value`.
        #[mzreflect(ignore)]
        id: LocalId,
        /// The collection to be bound to `id`.
        value: Box<MirRelationExpr>,
        /// The result of the `Let`, evaluated with `id` bound to `value`.
        body: Box<MirRelationExpr>,
    },
    /// Introduce mutually recursive bindings.
    ///
    /// Each `LocalId` is immediately bound to an initially empty collection
    /// with the type of its corresponding `MirRelationExpr`. Repeatedly, each
    /// binding is evaluated using the current contents of each other binding,
    /// and is refreshed to contain the new evaluation. This process continues
    /// through all bindings, and repeats as long as changes continue to occur.
    ///
    /// The resulting value of the expression is `body` evaluated once in the
    /// context of the final iterates.
    ///
    /// A zero-binding instance can be replaced by `body`.
    /// A single-binding instance is equivalent to `MirRelationExpr::Let`.
    ///
    /// The runtime memory footprint of this operator is zero.
    LetRec {
        /// The identifiers to be used in `Get` variants to retrieve each `value`.
        #[mzreflect(ignore)]
        ids: Vec<LocalId>,
        /// The collections to be bound to each `id`.
        values: Vec<MirRelationExpr>,
        /// Maximum number of iterations, after which we should artificially force a fixpoint.
        /// (Whether we error or just stop is configured by `LetRecLimit::return_at_limit`.)
        /// The per-`LetRec` limit that the user specified is initially copied to each binding to
        /// accommodate slicing and merging of `LetRec`s in MIR transforms (e.g., `NormalizeLets`).
        #[mzreflect(ignore)]
        limits: Vec<Option<LetRecLimit>>,
        /// The result of the `LetRec`, evaluated with each `id` bound to its `value`.
        body: Box<MirRelationExpr>,
    },
    /// Project out some columns from a dataflow
    ///
    /// The runtime memory footprint of this operator is zero.
    Project {
        /// The source collection.
        input: Box<MirRelationExpr>,
        /// Indices of columns to retain.
        outputs: Vec<usize>,
    },
    /// Append new columns to a dataflow
    ///
    /// The runtime memory footprint of this operator is zero.
    Map {
        /// The source collection.
        input: Box<MirRelationExpr>,
        /// Expressions which determine values to append to each row.
        /// An expression may refer to columns in `input` or
        /// expressions defined earlier in the vector
        scalars: Vec<MirScalarExpr>,
    },
    /// Like Map, but yields zero-or-more output rows per input row
    ///
    /// The runtime memory footprint of this operator is zero.
    FlatMap {
        /// The source collection
        input: Box<MirRelationExpr>,
        /// The table func to apply
        func: TableFunc,
        /// The arguments to the table func
        exprs: Vec<MirScalarExpr>,
    },
    /// Keep rows from a dataflow where all the predicates are true
    ///
    /// The runtime memory footprint of this operator is zero.
    Filter {
        /// The source collection.
        input: Box<MirRelationExpr>,
        /// Predicates, each of which must be true.
        predicates: Vec<MirScalarExpr>,
    },
    /// Join several collections, where some columns must be equal.
    ///
    /// For further details consult the documentation for [`MirRelationExpr::join`].
    ///
    /// The runtime memory footprint of this operator can be proportional to
    /// the sizes of all inputs and the size of all joins of prefixes.
    /// This may be reduced due to arrangements available at rendering time.
    Join {
        /// A sequence of input relations.
        inputs: Vec<MirRelationExpr>,
        /// A sequence of equivalence classes of expressions on the cross product of inputs.
        ///
        /// Each equivalence class is a list of scalar expressions, where for each class the
        /// intended interpretation is that all evaluated expressions should be equal.
        ///
        /// Each scalar expression is to be evaluated over the cross-product of all records
        /// from all inputs. In many cases this may just be column selection from specific
        /// inputs, but more general cases exist (e.g. complex functions of multiple columns
        /// from multiple inputs, or just constant literals).
        equivalences: Vec<Vec<MirScalarExpr>>,
        /// Join implementation information.
        #[serde(default)]
        implementation: JoinImplementation,
    },
    /// Group a dataflow by some columns and aggregate over each group
    ///
    /// The runtime memory footprint of this operator is at most proportional to the
    /// number of distinct records in the input and output. The actual requirements
    /// can be less: the number of distinct inputs to each aggregate, summed across
    /// each aggregate, plus the output size. For more details consult the code that
    /// builds the associated dataflow.
    Reduce {
        /// The source collection.
        input: Box<MirRelationExpr>,
        /// Column indices used to form groups.
        group_key: Vec<MirScalarExpr>,
        /// Expressions which determine values to append to each row, after the group keys.
        aggregates: Vec<AggregateExpr>,
        /// True iff the input is known to monotonically increase (only addition of records).
        #[serde(default)]
        monotonic: bool,
        /// User hint: expected number of values per group key. Used to optimize physical rendering.
        #[serde(default)]
        expected_group_size: Option<u64>,
    },
    /// Groups and orders within each group, limiting output.
    ///
    /// The runtime memory footprint of this operator is proportional to its input and output.
    TopK {
        /// The source collection.
        input: Box<MirRelationExpr>,
        /// Column indices used to form groups.
        group_key: Vec<usize>,
        /// Column indices used to order rows within groups.
        order_key: Vec<ColumnOrder>,
        /// Number of records to retain
        #[serde(default)]
        limit: Option<MirScalarExpr>,
        /// Number of records to skip
        #[serde(default)]
        offset: usize,
        /// True iff the input is known to monotonically increase (only addition of records).
        #[serde(default)]
        monotonic: bool,
        /// User-supplied hint: how many rows will have the same group key.
        #[serde(default)]
        expected_group_size: Option<u64>,
    },
    /// Return a dataflow where the row counts are negated
    ///
    /// The runtime memory footprint of this operator is zero.
    Negate {
        /// The source collection.
        input: Box<MirRelationExpr>,
    },
    /// Keep rows from a dataflow where the row counts are positive
    ///
    /// The runtime memory footprint of this operator is proportional to its input and output.
    Threshold {
        /// The source collection.
        input: Box<MirRelationExpr>,
    },
    /// Adds the frequencies of elements in contained sets.
    ///
    /// The runtime memory footprint of this operator is zero.
    Union {
        /// A source collection.
        base: Box<MirRelationExpr>,
        /// Source collections to union.
        inputs: Vec<MirRelationExpr>,
    },
    /// Technically a no-op. Used to render an index. Will be used to optimize queries
    /// on finer grain. Each `keys` item represents a different index that should be
    /// produced from the `keys`.
    ///
    /// The runtime memory footprint of this operator is proportional to its input.
    ArrangeBy {
        /// The source collection
        input: Box<MirRelationExpr>,
        /// Columns to arrange `input` by, in order of decreasing primacy
        keys: Vec<Vec<MirScalarExpr>>,
    },
}
312
313impl PartialEq for MirRelationExpr {
314    fn eq(&self, other: &Self) -> bool {
315        // Capture the result and test it wrt `Ord` implementation in test environments.
316        let result = structured_diff::MreDiff::new(self, other).next().is_none();
317        mz_ore::soft_assert_eq_no_log!(result, self.cmp(other) == Ordering::Equal);
318        result
319    }
320}
321impl Eq for MirRelationExpr {}
322
323impl MirRelationExpr {
    /// Reports the schema of the relation.
    ///
    /// This method determines the type through recursive traversal of the
    /// relation expression, drawing from the types of base collections.
    /// As such, this is not an especially cheap method, and should be used
    /// judiciously.
    ///
    /// The relation type is computed incrementally with a recursive post-order
    /// traversal, that accumulates the input types for the relations yet to be
    /// visited in `type_stack`.
    pub fn typ(&self) -> SqlRelationType {
        // Stack of already-computed types for subtrees, in post-order.
        let mut type_stack = Vec::new();
        #[allow(deprecated)]
        self.visit_pre_post_nolimit(
            &mut |e: &MirRelationExpr| -> Option<Vec<&MirRelationExpr>> {
                match &e {
                    MirRelationExpr::Let { body, .. } => {
                        // Do not traverse the value sub-graph, since it's not relevant for
                        // determining the relation type of Let operators.
                        Some(vec![&*body])
                    }
                    MirRelationExpr::LetRec { body, .. } => {
                        // Do not traverse the value sub-graph, since it's not relevant for
                        // determining the relation type of Let operators.
                        Some(vec![&*body])
                    }
                    _ => None,
                }
            },
            &mut |e: &MirRelationExpr| {
                // Because the pre-visit above skipped the value sub-graphs, only the body's
                // type is on the stack for Let/LetRec; pad with dummies so the stack layout
                // matches what `typ_with_input_types` expects (one entry per input).
                match e {
                    MirRelationExpr::Let { .. } => {
                        let body_typ = type_stack.pop().unwrap();
                        // Insert a dummy relation type for the value, since `typ_with_input_types`
                        // won't look at it, but expects the relation type of the body to be second.
                        type_stack.push(SqlRelationType::empty());
                        type_stack.push(body_typ);
                    }
                    MirRelationExpr::LetRec { values, .. } => {
                        let body_typ = type_stack.pop().unwrap();
                        // Insert dummy relation types for the values, since `typ_with_input_types`
                        // won't look at them, but expects the relation type of the body to be last.
                        type_stack
                            .extend(std::iter::repeat(SqlRelationType::empty()).take(values.len()));
                        type_stack.push(body_typ);
                    }
                    _ => {}
                }
                // Replace the types of this node's inputs with this node's own type.
                let num_inputs = e.num_inputs();
                let relation_type =
                    e.typ_with_input_types(&type_stack[type_stack.len() - num_inputs..]);
                type_stack.truncate(type_stack.len() - num_inputs);
                type_stack.push(relation_type);
            },
        );
        // After the full traversal exactly the root's type remains.
        assert_eq!(type_stack.len(), 1);
        type_stack.pop().unwrap()
    }
382
383    /// Reports the schema of the relation given the schema of the input relations.
384    ///
385    /// `input_types` is required to contain the schemas for the input relations of
386    /// the current relation in the same order as they are visited by `try_visit_children`
387    /// method, even though not all may be used for computing the schema of the
388    /// current relation. For example, `Let` expects two input types, one for the
389    /// value relation and one for the body, in that order, but only the one for the
390    /// body is used to determine the type of the `Let` relation.
391    ///
392    /// It is meant to be used during post-order traversals to compute relation
393    /// schemas incrementally.
394    pub fn typ_with_input_types(&self, input_types: &[SqlRelationType]) -> SqlRelationType {
395        let column_types = self.col_with_input_cols(input_types.iter().map(|i| &i.column_types));
396        let unique_keys = self.keys_with_input_keys(
397            input_types.iter().map(|i| i.arity()),
398            input_types.iter().map(|i| &i.keys),
399        );
400        SqlRelationType::new(column_types).with_keys(unique_keys)
401    }
402
403    /// Reports the column types of the relation given the column types of the
404    /// input relations.
405    ///
406    /// This method delegates to `try_col_with_input_cols`, panicking if an `Err`
407    /// variant is returned.
408    pub fn col_with_input_cols<'a, I>(&self, input_types: I) -> Vec<SqlColumnType>
409    where
410        I: Iterator<Item = &'a Vec<SqlColumnType>>,
411    {
412        match self.try_col_with_input_cols(input_types) {
413            Ok(col_types) => col_types,
414            Err(err) => panic!("{err}"),
415        }
416    }
417
    /// Reports the column types of the relation given the column types of the input relations.
    ///
    /// `input_types` is required to contain the column types for the input relations of
    /// the current relation in the same order as they are visited by `try_visit_children`
    /// method, even though not all may be used for computing the schema of the
    /// current relation. For example, `Let` expects two input types, one for the
    /// value relation and one for the body, in that order, but only the one for the
    /// body is used to determine the type of the `Let` relation.
    ///
    /// It is meant to be used during post-order traversals to compute column types
    /// incrementally.
    pub fn try_col_with_input_cols<'a, I>(
        &self,
        mut input_types: I,
    ) -> Result<Vec<SqlColumnType>, String>
    where
        I: Iterator<Item = &'a Vec<SqlColumnType>>,
    {
        use MirRelationExpr::*;

        let col_types = match self {
            Constant { rows, typ } => {
                // Tighten nullability: a column stays nullable only if some row actually
                // contains a null. (If `rows` is an `Err`, no nulls are observed, so every
                // column is marked non-nullable.)
                let mut col_types = typ.column_types.clone();
                let mut seen_null = vec![false; typ.arity()];
                if let Ok(rows) = rows {
                    for (row, _diff) in rows {
                        for (datum, i) in row.iter().zip_eq(0..typ.arity()) {
                            if datum.is_null() {
                                seen_null[i] = true;
                            }
                        }
                    }
                }
                for (&seen_null, i) in seen_null.iter().zip_eq(0..typ.arity()) {
                    if !seen_null {
                        col_types[i].nullable = false;
                    } else {
                        // A null datum in a column declared non-nullable is a type error.
                        assert!(col_types[i].nullable);
                    }
                }
                col_types
            }
            Get { typ, .. } => typ.column_types.clone(),
            Project { outputs, .. } => {
                // Retain exactly the referenced input columns, in `outputs` order.
                let input = input_types.next().unwrap();
                outputs.iter().map(|&i| input[i].clone()).collect()
            }
            Map { scalars, .. } => {
                // Each scalar may refer to columns produced by earlier scalars, so types
                // are computed incrementally against the growing `result`.
                let mut result = input_types.next().unwrap().clone();
                for scalar in scalars.iter() {
                    result.push(scalar.typ(&result))
                }
                result
            }
            FlatMap { func, .. } => {
                // Input columns followed by the table function's output columns.
                let mut result = input_types.next().unwrap().clone();
                result.extend(func.output_type().column_types);
                result
            }
            Filter { predicates, .. } => {
                let mut result = input_types.next().unwrap().clone();

                // Set as nonnull any columns where null values would cause
                // any predicate to evaluate to null.
                for column in non_nullable_columns(predicates) {
                    result[column].nullable = false;
                }
                result
            }
            Join { equivalences, .. } => {
                // Concatenate input column types
                let mut types = input_types.flat_map(|cols| cols.to_owned()).collect_vec();
                // In an equivalence class, if any column is non-null, then make all non-null
                for equivalence in equivalences {
                    let col_inds = equivalence
                        .iter()
                        .filter_map(|expr| match expr {
                            MirScalarExpr::Column(col, _name) => Some(*col),
                            _ => None,
                        })
                        .collect_vec();
                    if col_inds.iter().any(|i| !types.get(*i).unwrap().nullable) {
                        for i in col_inds {
                            types.get_mut(i).unwrap().nullable = false;
                        }
                    }
                }
                types
            }
            Reduce {
                group_key,
                aggregates,
                ..
            } => {
                // Output is the group key columns followed by one column per aggregate.
                let input = input_types.next().unwrap();
                group_key
                    .iter()
                    .map(|e| e.typ(input))
                    .chain(aggregates.iter().map(|agg| agg.typ(input)))
                    .collect()
            }
            // These operators do not change the column types of their input.
            TopK { .. } | Negate { .. } | Threshold { .. } | ArrangeBy { .. } => {
                input_types.next().unwrap().clone()
            }
            Let { .. } => {
                // skip over the input types for `value`.
                input_types.nth(1).unwrap().clone()
            }
            LetRec { values, .. } => {
                // skip over the input types for `values`.
                input_types.nth(values.len()).unwrap().clone()
            }
            Union { .. } => {
                // Union the column types pairwise; errors on incompatible column types.
                let mut result = input_types.next().unwrap().clone();
                for input_col_types in input_types {
                    for (base_col, col) in result.iter_mut().zip_eq(input_col_types) {
                        *base_col = base_col
                            .union(col)
                            .map_err(|e| format!("{}\nin plan:\n{}", e, self.pretty()))?;
                    }
                }
                result
            }
        };

        Ok(col_types)
    }
545
546    /// Reports the column types of the relation given the column types of the
547    /// input relations.
548    ///
549    /// This method delegates to `try_repr_col_with_input_repr_cols`, panicking if an `Err`
550    /// variant is returned.
551    pub fn repr_col_with_input_repr_cols<'a, I>(&self, input_types: I) -> Vec<ReprColumnType>
552    where
553        I: Iterator<Item = &'a Vec<ReprColumnType>>,
554    {
555        match self.try_repr_col_with_input_repr_cols(input_types) {
556            Ok(col_types) => col_types,
557            Err(err) => panic!("{err}"),
558        }
559    }
560
    /// Reports the column types of the relation given the column types of the input relations.
    ///
    /// This is the `ReprColumnType` analogue of `try_col_with_input_cols`; the per-operator
    /// logic is intentionally kept in sync with that method.
    ///
    /// `input_types` is required to contain the column types for the input relations of
    /// the current relation in the same order as they are visited by `try_visit_children`
    /// method, even though not all may be used for computing the schema of the
    /// current relation. For example, `Let` expects two input types, one for the
    /// value relation and one for the body, in that order, but only the one for the
    /// body is used to determine the type of the `Let` relation.
    ///
    /// It is meant to be used during post-order traversals to compute column types
    /// incrementally.
    pub fn try_repr_col_with_input_repr_cols<'a, I>(
        &self,
        mut input_types: I,
    ) -> Result<Vec<ReprColumnType>, String>
    where
        I: Iterator<Item = &'a Vec<ReprColumnType>>,
    {
        use MirRelationExpr::*;

        let col_types = match self {
            Constant { rows, typ } => {
                // Tighten nullability: a column stays nullable only if some row actually
                // contains a null. (If `rows` is an `Err`, no nulls are observed, so every
                // column is marked non-nullable.)
                let mut col_types = typ
                    .column_types
                    .iter()
                    .map(ReprColumnType::from)
                    .collect_vec();
                let mut seen_null = vec![false; typ.arity()];
                if let Ok(rows) = rows {
                    for (row, _diff) in rows {
                        for (datum, i) in row.iter().zip_eq(0..typ.arity()) {
                            if datum.is_null() {
                                seen_null[i] = true;
                            }
                        }
                    }
                }
                for (&seen_null, i) in seen_null.iter().zip_eq(0..typ.arity()) {
                    if !seen_null {
                        col_types[i].nullable = false;
                    } else {
                        // A null datum in a column declared non-nullable is a type error.
                        assert!(col_types[i].nullable);
                    }
                }
                col_types
            }
            Get { typ, .. } => typ
                .column_types
                .iter()
                .map(ReprColumnType::from)
                .collect_vec(),
            Project { outputs, .. } => {
                // Retain exactly the referenced input columns, in `outputs` order.
                let input = input_types.next().unwrap();
                outputs.iter().map(|&i| input[i].clone()).collect()
            }
            Map { scalars, .. } => {
                // Each scalar may refer to columns produced by earlier scalars, so types
                // are computed incrementally against the growing `result`.
                let mut result = input_types.next().unwrap().clone();
                for scalar in scalars.iter() {
                    result.push(scalar.repr_typ(&result))
                }
                result
            }
            FlatMap { func, .. } => {
                // Input columns followed by the table function's output columns.
                let mut result = input_types.next().unwrap().clone();
                result.extend(
                    func.output_type()
                        .column_types
                        .iter()
                        .map(ReprColumnType::from),
                );
                result
            }
            Filter { predicates, .. } => {
                let mut result = input_types.next().unwrap().clone();

                // Set as nonnull any columns where null values would cause
                // any predicate to evaluate to null.
                for column in non_nullable_columns(predicates) {
                    result[column].nullable = false;
                }
                result
            }
            Join { equivalences, .. } => {
                // Concatenate input column types
                let mut types = input_types.flat_map(|cols| cols.to_owned()).collect_vec();
                // In an equivalence class, if any column is non-null, then make all non-null
                for equivalence in equivalences {
                    let col_inds = equivalence
                        .iter()
                        .filter_map(|expr| match expr {
                            MirScalarExpr::Column(col, _name) => Some(*col),
                            _ => None,
                        })
                        .collect_vec();
                    if col_inds.iter().any(|i| !types.get(*i).unwrap().nullable) {
                        for i in col_inds {
                            types.get_mut(i).unwrap().nullable = false;
                        }
                    }
                }
                types
            }
            Reduce {
                group_key,
                aggregates,
                ..
            } => {
                // Output is the group key columns followed by one column per aggregate.
                let input = input_types.next().unwrap();
                group_key
                    .iter()
                    .map(|e| e.repr_typ(input))
                    .chain(aggregates.iter().map(|agg| agg.repr_typ(input)))
                    .collect()
            }
            // These operators do not change the column types of their input.
            TopK { .. } | Negate { .. } | Threshold { .. } | ArrangeBy { .. } => {
                input_types.next().unwrap().clone()
            }
            Let { .. } => {
                // skip over the input types for `value`.
                input_types.nth(1).unwrap().clone()
            }
            LetRec { values, .. } => {
                // skip over the input types for `values`.
                input_types.nth(values.len()).unwrap().clone()
            }
            Union { .. } => {
                // Union the column types pairwise; errors on incompatible column types.
                let mut result = input_types.next().unwrap().clone();
                for input_col_types in input_types {
                    for (base_col, col) in result.iter_mut().zip_eq(input_col_types) {
                        *base_col = base_col
                            .union(col)
                            .map_err(|e| format!("{}\nin plan:\n{}", e, self.pretty()))?;
                    }
                }
                result
            }
        };

        Ok(col_types)
    }
701
    /// Reports the unique keys of the relation given the arities and the unique
    /// keys of the input relations.
    ///
    /// `input_arities` and `input_keys` are required to contain the
    /// corresponding info for the input relations of
    /// the current relation in the same order as they are visited by `try_visit_children`
    /// method, even though not all may be used for computing the schema of the
    /// current relation. For example, `Let` expects two input types, one for the
    /// value relation and one for the body, in that order, but only the one for the
    /// body is used to determine the type of the `Let` relation.
    ///
    /// It is meant to be used during post-order traversals to compute unique keys
    /// incrementally.
    pub fn keys_with_input_keys<'a, I, J>(
        &self,
        mut input_arities: I,
        mut input_keys: J,
    ) -> Vec<Vec<usize>>
    where
        I: Iterator<Item = usize>,
        J: Iterator<Item = &'a Vec<Vec<usize>>>,
    {
        use MirRelationExpr::*;

        let mut keys = match self {
            Constant {
                rows: Ok(rows),
                typ,
            } => {
                let n_cols = typ.arity();
                // If the `i`th entry is `Some`, then we have not yet observed non-uniqueness in the `i`th column.
                let mut unique_values_per_col = vec![Some(BTreeSet::<Datum>::default()); n_cols];
                for (row, diff) in rows {
                    for (i, datum) in row.iter().enumerate() {
                        if datum != Datum::Dummy {
                            if let Some(unique_vals) = &mut unique_values_per_col[i] {
                                // A diff other than +1 means the row repeats, so no
                                // single-column key can be witnessed by this column.
                                let is_dupe = *diff != Diff::ONE || !unique_vals.insert(datum);
                                if is_dupe {
                                    unique_values_per_col[i] = None;
                                }
                            }
                        }
                    }
                }
                if rows.len() == 0 || (rows.len() == 1 && rows[0].1 == Diff::ONE) {
                    // At most one row: the empty key is a unique key.
                    vec![vec![]]
                } else {
                    // XXX - Multi-column keys are not detected.
                    typ.keys
                        .iter()
                        .cloned()
                        .chain(
                            unique_values_per_col
                                .into_iter()
                                .enumerate()
                                .filter(|(_idx, unique_vals)| unique_vals.is_some())
                                .map(|(idx, _)| vec![idx]),
                        )
                        .collect()
                }
            }
            // Constant errors and `Get`s only know the keys recorded in their type.
            Constant { rows: Err(_), typ } | Get { typ, .. } => typ.keys.clone(),
            // These operators neither add rows nor change columns; keys pass through.
            Threshold { .. } | ArrangeBy { .. } => input_keys.next().unwrap().clone(),
            Let { .. } => {
                // skip over the unique keys for value
                input_keys.nth(1).unwrap().clone()
            }
            LetRec { values, .. } => {
                // skip over the unique keys for value
                input_keys.nth(values.len()).unwrap().clone()
            }
            Project { outputs, .. } => {
                // A key survives the projection iff all of its columns are retained;
                // surviving keys are rewritten in terms of output column positions.
                let input = input_keys.next().unwrap();
                input
                    .iter()
                    .filter_map(|key_set| {
                        if key_set.iter().all(|k| outputs.contains(k)) {
                            Some(
                                key_set
                                    .iter()
                                    .map(|c| outputs.iter().position(|o| o == c).unwrap())
                                    .collect(),
                            )
                        } else {
                            None
                        }
                    })
                    .collect()
            }
            Map { scalars, .. } => {
                let mut remappings = Vec::new();
                let arity = input_arities.next().unwrap();
                for (column, scalar) in scalars.iter().enumerate() {
                    // assess whether the scalar preserves uniqueness,
                    // and could participate in a key!

                    // Returns the input column this expression is an injective
                    // function of, if any (chains of uniqueness-preserving
                    // unary functions over a column reference).
                    fn uniqueness(expr: &MirScalarExpr) -> Option<usize> {
                        match expr {
                            MirScalarExpr::CallUnary { func, expr } => {
                                if func.preserves_uniqueness() {
                                    uniqueness(expr)
                                } else {
                                    None
                                }
                            }
                            MirScalarExpr::Column(c, _name) => Some(*c),
                            _ => None,
                        }
                    }

                    if let Some(c) = uniqueness(scalar) {
                        remappings.push((c, column + arity));
                    }
                }

                let mut result = input_keys.next().unwrap().clone();
                let mut new_keys = Vec::new();
                // Any column in `remappings` could be replaced in a key
                // by the corresponding c. This could lead to combinatorial
                // explosion using our current representation, so we wont
                // do that. Instead, we'll handle the case of one remapping.
                if remappings.len() == 1 {
                    let (old, new) = remappings.pop().unwrap();
                    for key in &result {
                        if key.contains(&old) {
                            let mut new_key: Vec<usize> =
                                key.iter().cloned().filter(|k| k != &old).collect();
                            new_key.push(new);
                            new_key.sort_unstable();
                            new_keys.push(new_key);
                        }
                    }
                    result.append(&mut new_keys);
                }
                result
            }
            FlatMap { .. } => {
                // FlatMap can add duplicate rows, so input keys are no longer
                // valid
                vec![]
            }
            Negate { .. } => {
                // Although negate may have distinct records for each key,
                // the multiplicity is -1 rather than 1. This breaks many
                // of the optimization uses of "keys".
                vec![]
            }
            Filter { predicates, .. } => {
                // A filter inherits the keys of its input unless the filters
                // have reduced the input to a single row, in which case the
                // keys of the input are `()`.
                let mut input = input_keys.next().unwrap().clone();

                if !input.is_empty() {
                    // Track columns equated to literals, which we can prune.
                    let mut cols_equal_to_literal = BTreeSet::new();

                    // Perform union find on `col1 = col2` to establish
                    // connected components of equated columns. Absent any
                    // equalities, this will be `0 .. #c` (where #c is the
                    // greatest column referenced by a predicate), but each
                    // equality will orient the root of the greater to the root
                    // of the lesser.
                    let mut union_find = Vec::new();

                    for expr in predicates.iter() {
                        if let MirScalarExpr::CallBinary {
                            func: crate::BinaryFunc::Eq(_),
                            expr1,
                            expr2,
                        } = expr
                        {
                            if let MirScalarExpr::Column(c, _name) = &**expr1 {
                                if expr2.is_literal_ok() {
                                    cols_equal_to_literal.insert(c);
                                }
                            }
                            if let MirScalarExpr::Column(c, _name) = &**expr2 {
                                if expr1.is_literal_ok() {
                                    cols_equal_to_literal.insert(c);
                                }
                            }
                            // Perform union-find to equate columns.
                            if let (Some(c1), Some(c2)) = (expr1.as_column(), expr2.as_column()) {
                                if c1 != c2 {
                                    // Ensure union_find has entries up to
                                    // max(c1, c2) by filling up missing
                                    // positions with identity mappings.
                                    while union_find.len() <= std::cmp::max(c1, c2) {
                                        union_find.push(union_find.len());
                                    }
                                    let mut r1 = c1; // Find the representative column of [c1].
                                    while r1 != union_find[r1] {
                                        assert!(union_find[r1] < r1);
                                        r1 = union_find[r1];
                                    }
                                    let mut r2 = c2; // Find the representative column of [c2].
                                    while r2 != union_find[r2] {
                                        assert!(union_find[r2] < r2);
                                        r2 = union_find[r2];
                                    }
                                    // Union [c1] and [c2] by pointing the
                                    // larger to the smaller representative (we
                                    // update the remaining equivalence class
                                    // members only once after this for-loop).
                                    union_find[std::cmp::max(r1, r2)] = std::cmp::min(r1, r2);
                                }
                            }
                        }
                    }

                    // Complete union-find by pointing each element at its representative column.
                    for i in 0..union_find.len() {
                        // Iteration not required, as each prior already references the right column.
                        union_find[i] = union_find[union_find[i]];
                    }

                    // Remove columns bound to literals, and remap columns equated to earlier columns.
                    // We will re-expand remapped columns in a moment, but this avoids exponential work.
                    for key_set in &mut input {
                        key_set.retain(|k| !cols_equal_to_literal.contains(&k));
                        for col in key_set.iter_mut() {
                            if let Some(equiv) = union_find.get(*col) {
                                *col = *equiv;
                            }
                        }
                        key_set.sort();
                        key_set.dedup();
                    }
                    input.sort();
                    input.dedup();

                    // Expand out each key to each of its equivalent forms.
                    // Each instance of `col` can be replaced by any equivalent column.
                    // This has the potential to result in exponentially sized number of unique keys,
                    // and in the future we should probably maintain unique keys modulo equivalence.

                    // First, compute an inverse map from each representative
                    // column `sub` to all other equivalent columns `col`.
                    let mut subs = Vec::new();
                    for (col, sub) in union_find.iter().enumerate() {
                        if *sub != col {
                            assert!(*sub < col);
                            while subs.len() <= *sub {
                                subs.push(Vec::new());
                            }
                            subs[*sub].push(col);
                        }
                    }
                    // For each column, substitute for it in each occurrence.
                    let mut to_add = Vec::new();
                    for (col, subs) in subs.iter().enumerate() {
                        if !subs.is_empty() {
                            for key_set in input.iter() {
                                if key_set.contains(&col) {
                                    let mut to_extend = key_set.clone();
                                    to_extend.retain(|c| c != &col);
                                    for sub in subs {
                                        to_extend.push(*sub);
                                        to_add.push(to_extend.clone());
                                        to_extend.pop();
                                    }
                                }
                            }
                        }
                        // No deduplication, as we cannot introduce duplicates.
                        input.append(&mut to_add);
                    }
                    for key_set in input.iter_mut() {
                        key_set.sort();
                        key_set.dedup();
                    }
                }
                input
            }
            Join { equivalences, .. } => {
                // It is important the `new_from_input_arities` constructor is
                // used. Otherwise, Materialize may potentially end up in an
                // infinite loop.
                let input_mapper = crate::JoinInputMapper::new_from_input_arities(input_arities);

                input_mapper.global_keys(input_keys, equivalences)
            }
            Reduce { group_key, .. } => {
                // The group key should form a key, but we might already have
                // keys that are subsets of the group key, and should retain
                // those instead, if so.
                let mut result = Vec::new();
                for key_set in input_keys.next().unwrap() {
                    if key_set
                        .iter()
                        .all(|k| group_key.contains(&MirScalarExpr::column(*k)))
                    {
                        result.push(
                            key_set
                                .iter()
                                .map(|i| {
                                    group_key
                                        .iter()
                                        .position(|k| k == &MirScalarExpr::column(*i))
                                        .unwrap()
                                })
                                .collect::<Vec<_>>(),
                        );
                    }
                }
                if result.is_empty() {
                    result.push((0..group_key.len()).collect());
                }
                result
            }
            TopK {
                group_key, limit, ..
            } => {
                // If `limit` is `Some(1)` then the group key will become
                // a unique key, as there will be only one record with that key.
                let mut result = input_keys.next().unwrap().clone();
                if limit.as_ref().and_then(|x| x.as_literal_int64()) == Some(1) {
                    result.push(group_key.clone())
                }
                result
            }
            Union { base, inputs } => {
                // Generally, unions do not have any unique keys, because
                // each input might duplicate some. However, there is at
                // least one idiomatic structure that does preserve keys,
                // which results from SQL aggregations that must populate
                // absent records with default values. In that pattern,
                // the union of one GET with its negation, which has first
                // been subjected to a projection and map, we can remove
                // their influence on the key structure.
                //
                // If there are A, B, each with a unique `key` such that
                // we are looking at
                //
                //     A.proj(set_containing_key) + (B - A.proj(key)).map(stuff)
                //
                // Then we can report `key` as a unique key.
                //
                // TODO: make unique key structure an optimization analysis
                // rather than part of the type information.
                // TODO: perhaps ensure that (above) A.proj(key) is a
                // subset of B, as otherwise there are negative records
                // and who knows what is true (not expected, but again
                // who knows what the query plan might look like).

                let arity = input_arities.next().unwrap();
                let (base_projection, base_with_project_stripped) =
                    if let MirRelationExpr::Project { input, outputs } = &**base {
                        (outputs.clone(), &**input)
                    } else {
                        // A input without a project is equivalent to an input
                        // with the project being all columns in the input in order.
                        ((0..arity).collect::<Vec<_>>(), &**base)
                    };
                let mut result = Vec::new();
                if let MirRelationExpr::Get {
                    id: first_id,
                    typ: _,
                    ..
                } = base_with_project_stripped
                {
                    if inputs.len() == 1 {
                        if let MirRelationExpr::Map { input, .. } = &inputs[0] {
                            if let MirRelationExpr::Union { base, inputs } = &**input {
                                if inputs.len() == 1 {
                                    if let Some((input, outputs)) = base.is_negated_project() {
                                        if let MirRelationExpr::Get {
                                            id: second_id,
                                            typ: _,
                                            ..
                                        } = input
                                        {
                                            if first_id == second_id {
                                                // Keep only keys whose columns are
                                                // passed through unpermuted by both
                                                // projections.
                                                result.extend(
                                                    input_keys
                                                        .next()
                                                        .unwrap()
                                                        .into_iter()
                                                        .filter(|key| {
                                                            key.iter().all(|c| {
                                                                outputs.get(*c) == Some(c)
                                                                    && base_projection.get(*c)
                                                                        == Some(c)
                                                            })
                                                        })
                                                        .cloned(),
                                                );
                                            }
                                        }
                                    }
                                }
                            }
                        }
                    }
                }
                // Important: do not inherit keys of either input, as not unique.
                result
            }
        };
        // Normalize the representation of the reported key sets.
        keys.sort();
        keys.dedup();
        keys
    }
1106
    /// The number of columns in the relation.
    ///
    /// This number is determined from the type, which is determined recursively
    /// at non-trivial cost.
    ///
    /// The arity is computed incrementally with a recursive post-order
    /// traversal, that accumulates the arities for the relations yet to be
    /// visited in `arity_stack`.
    pub fn arity(&self) -> usize {
        // Stack of arities of already-visited sub-expressions, consumed by
        // each expression in the post-order step.
        let mut arity_stack = Vec::new();
        #[allow(deprecated)]
        self.visit_pre_post_nolimit(
            &mut |e: &MirRelationExpr| -> Option<Vec<&MirRelationExpr>> {
                match &e {
                    MirRelationExpr::Let { body, .. } => {
                        // Do not traverse the value sub-graph, since it's not relevant for
                        // determining the arity of Let operators.
                        Some(vec![&*body])
                    }
                    MirRelationExpr::LetRec { body, .. } => {
                        // Do not traverse the value sub-graph, since it's not relevant for
                        // determining the arity of Let operators.
                        Some(vec![&*body])
                    }
                    MirRelationExpr::Project { .. } | MirRelationExpr::Reduce { .. } => {
                        // No further traversal is required; these operators know their arity.
                        Some(Vec::new())
                    }
                    _ => None,
                }
            },
            &mut |e: &MirRelationExpr| {
                // `arity_with_input_arities` expects one arity per child, but the
                // pre-order step above skipped some children; push placeholder
                // zeros so the stack lines up with the child count.
                match &e {
                    MirRelationExpr::Let { .. } => {
                        // Placeholder for the skipped `value` child, below the body's arity.
                        let body_arity = arity_stack.pop().unwrap();
                        arity_stack.push(0);
                        arity_stack.push(body_arity);
                    }
                    MirRelationExpr::LetRec { values, .. } => {
                        // Placeholders for the skipped `values` children, below the body's arity.
                        let body_arity = arity_stack.pop().unwrap();
                        arity_stack.extend(std::iter::repeat(0).take(values.len()));
                        arity_stack.push(body_arity);
                    }
                    MirRelationExpr::Project { .. } | MirRelationExpr::Reduce { .. } => {
                        // Placeholder for the single unvisited input.
                        arity_stack.push(0);
                    }
                    _ => {}
                }
                // Pop this expression's input arities and push its own arity.
                let num_inputs = e.num_inputs();
                let input_arities = arity_stack.drain(arity_stack.len() - num_inputs..);
                let arity = e.arity_with_input_arities(input_arities);
                arity_stack.push(arity);
            },
        );
        // After the traversal, only the root's arity remains.
        assert_eq!(arity_stack.len(), 1);
        arity_stack.pop().unwrap()
    }
1164
1165    /// Reports the arity of the relation given the schema of the input relations.
1166    ///
1167    /// `input_arities` is required to contain the arities for the input relations of
1168    /// the current relation in the same order as they are visited by `try_visit_children`
1169    /// method, even though not all may be used for computing the schema of the
1170    /// current relation. For example, `Let` expects two input types, one for the
1171    /// value relation and one for the body, in that order, but only the one for the
1172    /// body is used to determine the type of the `Let` relation.
1173    ///
1174    /// It is meant to be used during post-order traversals to compute arities
1175    /// incrementally.
1176    pub fn arity_with_input_arities<I>(&self, mut input_arities: I) -> usize
1177    where
1178        I: Iterator<Item = usize>,
1179    {
1180        use MirRelationExpr::*;
1181
1182        match self {
1183            Constant { rows: _, typ } => typ.arity(),
1184            Get { typ, .. } => typ.arity(),
1185            Let { .. } => {
1186                input_arities.next();
1187                input_arities.next().unwrap()
1188            }
1189            LetRec { values, .. } => {
1190                for _ in 0..values.len() {
1191                    input_arities.next();
1192                }
1193                input_arities.next().unwrap()
1194            }
1195            Project { outputs, .. } => outputs.len(),
1196            Map { scalars, .. } => input_arities.next().unwrap() + scalars.len(),
1197            FlatMap { func, .. } => input_arities.next().unwrap() + func.output_arity(),
1198            Join { .. } => input_arities.sum(),
1199            Reduce {
1200                input: _,
1201                group_key,
1202                aggregates,
1203                ..
1204            } => group_key.len() + aggregates.len(),
1205            Filter { .. }
1206            | TopK { .. }
1207            | Negate { .. }
1208            | Threshold { .. }
1209            | Union { .. }
1210            | ArrangeBy { .. } => input_arities.next().unwrap(),
1211        }
1212    }
1213
1214    /// The number of child relations this relation has.
1215    pub fn num_inputs(&self) -> usize {
1216        let mut count = 0;
1217
1218        self.visit_children(|_| count += 1);
1219
1220        count
1221    }
1222
1223    /// Constructs a constant collection from specific rows and schema, where
1224    /// each row will have a multiplicity of one.
1225    pub fn constant(rows: Vec<Vec<Datum>>, typ: SqlRelationType) -> Self {
1226        let rows = rows.into_iter().map(|row| (row, Diff::ONE)).collect();
1227        MirRelationExpr::constant_diff(rows, typ)
1228    }
1229
1230    /// Constructs a constant collection from specific rows and schema, where
1231    /// each row can have an arbitrary multiplicity.
1232    pub fn constant_diff(rows: Vec<(Vec<Datum>, Diff)>, typ: SqlRelationType) -> Self {
1233        for (row, _diff) in &rows {
1234            for (datum, column_typ) in row.iter().zip_eq(typ.column_types.iter()) {
1235                assert!(
1236                    datum.is_instance_of_sql(column_typ),
1237                    "Expected datum of type {:?}, got value {:?}",
1238                    column_typ,
1239                    datum
1240                );
1241            }
1242        }
1243        let rows = Ok(rows
1244            .into_iter()
1245            .map(move |(row, diff)| (Row::pack_slice(&row), diff))
1246            .collect());
1247        MirRelationExpr::Constant { rows, typ }
1248    }
1249
1250    /// If self is a constant, return the value and the type, otherwise `None`.
1251    /// Looks behind `ArrangeBy`s.
1252    pub fn as_const(&self) -> Option<(&Result<Vec<(Row, Diff)>, EvalError>, &SqlRelationType)> {
1253        match self {
1254            MirRelationExpr::Constant { rows, typ } => Some((rows, typ)),
1255            MirRelationExpr::ArrangeBy { input, .. } => input.as_const(),
1256            _ => None,
1257        }
1258    }
1259
1260    /// If self is a constant, mutably return the value and the type, otherwise `None`.
1261    /// Looks behind `ArrangeBy`s.
1262    pub fn as_const_mut(
1263        &mut self,
1264    ) -> Option<(
1265        &mut Result<Vec<(Row, Diff)>, EvalError>,
1266        &mut SqlRelationType,
1267    )> {
1268        match self {
1269            MirRelationExpr::Constant { rows, typ } => Some((rows, typ)),
1270            MirRelationExpr::ArrangeBy { input, .. } => input.as_const_mut(),
1271            _ => None,
1272        }
1273    }
1274
1275    /// If self is a constant error, return the error, otherwise `None`.
1276    /// Looks behind `ArrangeBy`s.
1277    pub fn as_const_err(&self) -> Option<&EvalError> {
1278        match self {
1279            MirRelationExpr::Constant { rows: Err(e), .. } => Some(e),
1280            MirRelationExpr::ArrangeBy { input, .. } => input.as_const_err(),
1281            _ => None,
1282        }
1283    }
1284
1285    /// Checks if `self` is the single element collection with no columns.
1286    pub fn is_constant_singleton(&self) -> bool {
1287        if let Some((Ok(rows), typ)) = self.as_const() {
1288            rows.len() == 1 && typ.column_types.len() == 0 && rows[0].1 == Diff::ONE
1289        } else {
1290            false
1291        }
1292    }
1293
1294    /// Constructs the expression for getting a local collection.
1295    pub fn local_get(id: LocalId, typ: SqlRelationType) -> Self {
1296        MirRelationExpr::Get {
1297            id: Id::Local(id),
1298            typ,
1299            access_strategy: AccessStrategy::UnknownOrLocal,
1300        }
1301    }
1302
1303    /// Constructs the expression for getting a global collection
1304    pub fn global_get(id: GlobalId, typ: SqlRelationType) -> Self {
1305        MirRelationExpr::Get {
1306            id: Id::Global(id),
1307            typ,
1308            access_strategy: AccessStrategy::UnknownOrLocal,
1309        }
1310    }
1311
1312    /// Retains only the columns specified by `output`.
1313    pub fn project(mut self, mut outputs: Vec<usize>) -> Self {
1314        if let MirRelationExpr::Project {
1315            outputs: columns, ..
1316        } = &mut self
1317        {
1318            // Update `outputs` to reference base columns of `input`.
1319            for column in outputs.iter_mut() {
1320                *column = columns[*column];
1321            }
1322            *columns = outputs;
1323            self
1324        } else {
1325            MirRelationExpr::Project {
1326                input: Box::new(self),
1327                outputs,
1328            }
1329        }
1330    }
1331
1332    /// Append to each row the results of applying elements of `scalar`.
1333    pub fn map(mut self, scalars: Vec<MirScalarExpr>) -> Self {
1334        if let MirRelationExpr::Map { scalars: s, .. } = &mut self {
1335            s.extend(scalars);
1336            self
1337        } else if !scalars.is_empty() {
1338            MirRelationExpr::Map {
1339                input: Box::new(self),
1340                scalars,
1341            }
1342        } else {
1343            self
1344        }
1345    }
1346
1347    /// Append to each row a single `scalar`.
1348    pub fn map_one(self, scalar: MirScalarExpr) -> Self {
1349        self.map(vec![scalar])
1350    }
1351
1352    /// Like `map`, but yields zero-or-more output rows per input row
1353    pub fn flat_map(self, func: TableFunc, exprs: Vec<MirScalarExpr>) -> Self {
1354        MirRelationExpr::FlatMap {
1355            input: Box::new(self),
1356            func,
1357            exprs,
1358        }
1359    }
1360
1361    /// Retain only the rows satisfying each of several predicates.
1362    pub fn filter<I>(mut self, predicates: I) -> Self
1363    where
1364        I: IntoIterator<Item = MirScalarExpr>,
1365    {
1366        // Extract existing predicates
1367        let mut new_predicates = if let MirRelationExpr::Filter { input, predicates } = self {
1368            self = *input;
1369            predicates
1370        } else {
1371            Vec::new()
1372        };
1373        // Normalize collection of predicates.
1374        new_predicates.extend(predicates);
1375        new_predicates.retain(|p| !p.is_literal_true());
1376        new_predicates.sort();
1377        new_predicates.dedup();
1378        // Introduce a `Filter` only if we have predicates.
1379        if !new_predicates.is_empty() {
1380            self = MirRelationExpr::Filter {
1381                input: Box::new(self),
1382                predicates: new_predicates,
1383            };
1384        }
1385
1386        self
1387    }
1388
1389    /// Form the Cartesian outer-product of rows in both inputs.
1390    pub fn product(mut self, right: Self) -> Self {
1391        if right.is_constant_singleton() {
1392            self
1393        } else if self.is_constant_singleton() {
1394            right
1395        } else if let MirRelationExpr::Join { inputs, .. } = &mut self {
1396            inputs.push(right);
1397            self
1398        } else {
1399            MirRelationExpr::join(vec![self, right], vec![])
1400        }
1401    }
1402
1403    /// Performs a relational equijoin among the input collections.
1404    ///
1405    /// The sequence `inputs` each describe different input collections, and the sequence `variables` describes
1406    /// equality constraints that some of their columns must satisfy. Each element in `variable` describes a set
1407    /// of pairs  `(input_index, column_index)` where every value described by that set must be equal.
1408    ///
1409    /// For example, the pair `(input, column)` indexes into `inputs[input][column]`, extracting the `input`th
1410    /// input collection and for each row examining its `column`th column.
1411    ///
1412    /// # Example
1413    ///
1414    /// ```rust
1415    /// use mz_repr::{Datum, SqlColumnType, SqlRelationType, SqlScalarType};
1416    /// use mz_expr::MirRelationExpr;
1417    ///
1418    /// // A common schema for each input.
1419    /// let schema = SqlRelationType::new(vec![
1420    ///     SqlScalarType::Int32.nullable(false),
1421    ///     SqlScalarType::Int32.nullable(false),
1422    /// ]);
1423    ///
1424    /// // the specific data are not important here.
1425    /// let data = vec![Datum::Int32(0), Datum::Int32(1)];
1426    ///
1427    /// // Three collections that could have been different.
1428    /// let input0 = MirRelationExpr::constant(vec![data.clone()], schema.clone());
1429    /// let input1 = MirRelationExpr::constant(vec![data.clone()], schema.clone());
1430    /// let input2 = MirRelationExpr::constant(vec![data.clone()], schema.clone());
1431    ///
1432    /// // Join the three relations looking for triangles, like so.
1433    /// //
1434    /// //     Output(A,B,C) := Input0(A,B), Input1(B,C), Input2(A,C)
1435    /// let joined = MirRelationExpr::join(
1436    ///     vec![input0, input1, input2],
1437    ///     vec![
1438    ///         vec![(0,0), (2,0)], // fields A of inputs 0 and 2.
1439    ///         vec![(0,1), (1,0)], // fields B of inputs 0 and 1.
1440    ///         vec![(1,1), (2,1)], // fields C of inputs 1 and 2.
1441    ///     ],
1442    /// );
1443    ///
1444    /// // Technically the above produces `Output(A,B,B,C,A,C)` because the columns are concatenated.
1445    /// // A projection resolves this and produces the correct output.
1446    /// let result = joined.project(vec![0, 1, 3]);
1447    /// ```
1448    pub fn join(inputs: Vec<MirRelationExpr>, variables: Vec<Vec<(usize, usize)>>) -> Self {
1449        let input_mapper = join_input_mapper::JoinInputMapper::new(&inputs);
1450
1451        let equivalences = variables
1452            .into_iter()
1453            .map(|vs| {
1454                vs.into_iter()
1455                    .map(|(r, c)| input_mapper.map_expr_to_global(MirScalarExpr::column(c), r))
1456                    .collect::<Vec<_>>()
1457            })
1458            .collect::<Vec<_>>();
1459
1460        Self::join_scalars(inputs, equivalences)
1461    }
1462
1463    /// Constructs a join operator from inputs and required-equal scalar expressions.
1464    pub fn join_scalars(
1465        mut inputs: Vec<MirRelationExpr>,
1466        equivalences: Vec<Vec<MirScalarExpr>>,
1467    ) -> Self {
1468        // Remove all constant inputs that are the identity for join.
1469        // They neither introduce nor modify any column references.
1470        inputs.retain(|i| !i.is_constant_singleton());
1471        MirRelationExpr::Join {
1472            inputs,
1473            equivalences,
1474            implementation: JoinImplementation::Unimplemented,
1475        }
1476    }
1477
1478    /// Perform a key-wise reduction / aggregation.
1479    ///
1480    /// The `group_key` argument indicates columns in the input collection that should
1481    /// be grouped, and `aggregates` lists aggregation functions each of which produces
1482    /// one output column in addition to the keys.
1483    pub fn reduce(
1484        self,
1485        group_key: Vec<usize>,
1486        aggregates: Vec<AggregateExpr>,
1487        expected_group_size: Option<u64>,
1488    ) -> Self {
1489        MirRelationExpr::Reduce {
1490            input: Box::new(self),
1491            group_key: group_key.into_iter().map(MirScalarExpr::column).collect(),
1492            aggregates,
1493            monotonic: false,
1494            expected_group_size,
1495        }
1496    }
1497
1498    /// Perform a key-wise reduction order by and limit.
1499    ///
1500    /// The `group_key` argument indicates columns in the input collection that should
1501    /// be grouped, the `order_key` argument indicates columns that should be further
1502    /// used to order records within groups, and the `limit` argument constrains the
1503    /// total number of records that should be produced in each group.
1504    pub fn top_k(
1505        self,
1506        group_key: Vec<usize>,
1507        order_key: Vec<ColumnOrder>,
1508        limit: Option<MirScalarExpr>,
1509        offset: usize,
1510        expected_group_size: Option<u64>,
1511    ) -> Self {
1512        MirRelationExpr::TopK {
1513            input: Box::new(self),
1514            group_key,
1515            order_key,
1516            limit,
1517            offset,
1518            expected_group_size,
1519            monotonic: false,
1520        }
1521    }
1522
1523    /// Negates the occurrences of each row.
1524    pub fn negate(self) -> Self {
1525        if let MirRelationExpr::Negate { input } = self {
1526            *input
1527        } else {
1528            MirRelationExpr::Negate {
1529                input: Box::new(self),
1530            }
1531        }
1532    }
1533
1534    /// Removes all but the first occurrence of each row.
1535    pub fn distinct(self) -> Self {
1536        let arity = self.arity();
1537        self.distinct_by((0..arity).collect())
1538    }
1539
1540    /// Removes all but the first occurrence of each key. Columns not included
1541    /// in the `group_key` are discarded.
1542    pub fn distinct_by(self, group_key: Vec<usize>) -> Self {
1543        self.reduce(group_key, vec![], None)
1544    }
1545
1546    /// Discards rows with a negative frequency.
1547    pub fn threshold(self) -> Self {
1548        if let MirRelationExpr::Threshold { .. } = &self {
1549            self
1550        } else {
1551            MirRelationExpr::Threshold {
1552                input: Box::new(self),
1553            }
1554        }
1555    }
1556
1557    /// Unions together any number inputs.
1558    ///
1559    /// If `inputs` is empty, then an empty relation of type `typ` is
1560    /// constructed.
1561    pub fn union_many(mut inputs: Vec<Self>, typ: SqlRelationType) -> Self {
1562        // Deconstruct `inputs` as `Union`s and reconstitute.
1563        let mut flat_inputs = Vec::with_capacity(inputs.len());
1564        for input in inputs {
1565            if let MirRelationExpr::Union { base, inputs } = input {
1566                flat_inputs.push(*base);
1567                flat_inputs.extend(inputs);
1568            } else {
1569                flat_inputs.push(input);
1570            }
1571        }
1572        inputs = flat_inputs;
1573        if inputs.len() == 0 {
1574            MirRelationExpr::Constant {
1575                rows: Ok(vec![]),
1576                typ,
1577            }
1578        } else if inputs.len() == 1 {
1579            inputs.into_element()
1580        } else {
1581            MirRelationExpr::Union {
1582                base: Box::new(inputs.remove(0)),
1583                inputs,
1584            }
1585        }
1586    }
1587
1588    /// Produces one collection where each row is present with the sum of its frequencies in each input.
1589    pub fn union(self, other: Self) -> Self {
1590        // Deconstruct `self` and `other` as `Union`s and reconstitute.
1591        let mut flat_inputs = Vec::with_capacity(2);
1592        if let MirRelationExpr::Union { base, inputs } = self {
1593            flat_inputs.push(*base);
1594            flat_inputs.extend(inputs);
1595        } else {
1596            flat_inputs.push(self);
1597        }
1598        if let MirRelationExpr::Union { base, inputs } = other {
1599            flat_inputs.push(*base);
1600            flat_inputs.extend(inputs);
1601        } else {
1602            flat_inputs.push(other);
1603        }
1604
1605        MirRelationExpr::Union {
1606            base: Box::new(flat_inputs.remove(0)),
1607            inputs: flat_inputs,
1608        }
1609    }
1610
1611    /// Arranges the collection by the specified columns
1612    pub fn arrange_by(self, keys: &[Vec<MirScalarExpr>]) -> Self {
1613        MirRelationExpr::ArrangeBy {
1614            input: Box::new(self),
1615            keys: keys.to_owned(),
1616        }
1617    }
1618
1619    /// Indicates if this is a constant empty collection.
1620    ///
1621    /// A false value does not mean the collection is known to be non-empty,
1622    /// only that we cannot currently determine that it is statically empty.
1623    pub fn is_empty(&self) -> bool {
1624        if let Some((Ok(rows), ..)) = self.as_const() {
1625            rows.is_empty()
1626        } else {
1627            false
1628        }
1629    }
1630
1631    /// If the expression is a negated project, return the input and the projection.
1632    pub fn is_negated_project(&self) -> Option<(&MirRelationExpr, &[usize])> {
1633        if let MirRelationExpr::Negate { input } = self {
1634            if let MirRelationExpr::Project { input, outputs } = &**input {
1635                return Some((&**input, outputs));
1636            }
1637        }
1638        if let MirRelationExpr::Project { input, outputs } = self {
1639            if let MirRelationExpr::Negate { input } = &**input {
1640                return Some((&**input, outputs));
1641            }
1642        }
1643        None
1644    }
1645
1646    /// Pretty-print this [MirRelationExpr] to a string.
1647    pub fn pretty(&self) -> String {
1648        let config = ExplainConfig::default();
1649        self.explain(&config, None)
1650    }
1651
1652    /// Pretty-print this [MirRelationExpr] to a string using a custom
1653    /// [ExplainConfig] and an optionally provided [ExprHumanizer].
1654    pub fn explain(&self, config: &ExplainConfig, humanizer: Option<&dyn ExprHumanizer>) -> String {
1655        text_string_at(self, || PlanRenderingContext {
1656            indent: Indent::default(),
1657            humanizer: humanizer.unwrap_or(&DummyHumanizer),
1658            annotations: BTreeMap::default(),
1659            config,
1660            ambiguous_ids: BTreeSet::default(),
1661        })
1662    }
1663
1664    /// Take ownership of `self`, leaving an empty `MirRelationExpr::Constant` with the optionally
1665    /// given scalar types. The given scalar types should be `base_eq` with the types that `typ()`
1666    /// would find. Keys and nullability are ignored in the given `SqlRelationType`, and instead we set
1667    /// the best possible key and nullability, since we are making an empty collection.
1668    ///
1669    /// If `typ` is not given, then this calls `.typ()` (which is possibly expensive) to determine
1670    /// the correct type.
1671    pub fn take_safely(&mut self, typ: Option<SqlRelationType>) -> MirRelationExpr {
1672        if let Some(typ) = &typ {
1673            soft_assert_no_log!(
1674                self.typ()
1675                    .column_types
1676                    .iter()
1677                    .zip_eq(typ.column_types.iter())
1678                    .all(|(t1, t2)| t1.scalar_type.base_eq(&t2.scalar_type))
1679            );
1680        }
1681        let mut typ = typ.unwrap_or_else(|| self.typ());
1682        typ.keys = vec![vec![]];
1683        for ct in typ.column_types.iter_mut() {
1684            ct.nullable = false;
1685        }
1686        std::mem::replace(
1687            self,
1688            MirRelationExpr::Constant {
1689                rows: Ok(vec![]),
1690                typ,
1691            },
1692        )
1693    }
1694
1695    /// Take ownership of `self`, leaving an empty `MirRelationExpr::Constant` with the given scalar
1696    /// types. Nullability is ignored in the given `SqlColumnType`s, and instead we set the best
1697    /// possible nullability, since we are making an empty collection.
1698    pub fn take_safely_with_col_types(&mut self, typ: Vec<SqlColumnType>) -> MirRelationExpr {
1699        self.take_safely(Some(SqlRelationType::new(typ)))
1700    }
1701
1702    /// Take ownership of `self`, leaving an empty `MirRelationExpr::Constant` with an **incorrect** type.
1703    ///
1704    /// This should only be used if `self` is about to be dropped or otherwise overwritten.
1705    pub fn take_dangerous(&mut self) -> MirRelationExpr {
1706        let empty = MirRelationExpr::Constant {
1707            rows: Ok(vec![]),
1708            typ: SqlRelationType::new(Vec::new()),
1709        };
1710        std::mem::replace(self, empty)
1711    }
1712
1713    /// Replaces `self` with some logic applied to `self`.
1714    pub fn replace_using<F>(&mut self, logic: F)
1715    where
1716        F: FnOnce(MirRelationExpr) -> MirRelationExpr,
1717    {
1718        let empty = MirRelationExpr::Constant {
1719            rows: Ok(vec![]),
1720            typ: SqlRelationType::new(Vec::new()),
1721        };
1722        let expr = std::mem::replace(self, empty);
1723        *self = logic(expr);
1724    }
1725
1726    /// Store `self` in a `Let` and pass the corresponding `Get` to `body`.
1727    pub fn let_in<Body, E>(self, id_gen: &mut IdGen, body: Body) -> Result<MirRelationExpr, E>
1728    where
1729        Body: FnOnce(&mut IdGen, MirRelationExpr) -> Result<MirRelationExpr, E>,
1730    {
1731        if let MirRelationExpr::Get { .. } = self {
1732            // already done
1733            body(id_gen, self)
1734        } else {
1735            let id = LocalId::new(id_gen.allocate_id());
1736            let get = MirRelationExpr::Get {
1737                id: Id::Local(id),
1738                typ: self.typ(),
1739                access_strategy: AccessStrategy::UnknownOrLocal,
1740            };
1741            let body = (body)(id_gen, get)?;
1742            Ok(MirRelationExpr::Let {
1743                id,
1744                value: Box::new(self),
1745                body: Box::new(body),
1746            })
1747        }
1748    }
1749
    /// Return every row in `self` that does not have a matching row in the first columns of `keys_and_values`, using `default` to fill in the remaining columns
    /// (If `default` is a row of nulls, this is the 'outer' part of LEFT OUTER JOIN)
    ///
    /// The error type `E` is never produced here; it exists so the result
    /// composes with fallible callers of `let_in`.
    pub fn anti_lookup<E>(
        self,
        id_gen: &mut IdGen,
        keys_and_values: MirRelationExpr,
        default: Vec<(Datum, SqlScalarType)>,
    ) -> Result<MirRelationExpr, E> {
        // Split the defaults into datums and column types; a column is marked
        // nullable exactly when its default datum is null.
        let (data, column_types): (Vec<_>, Vec<_>) = default
            .into_iter()
            .map(|(datum, scalar_type)| (datum, scalar_type.nullable(datum.is_null())))
            .unzip();
        // The defaults must cover exactly the value (non-key) columns of
        // `keys_and_values`.
        assert_eq!(keys_and_values.arity() - self.arity(), data.len());
        self.let_in(id_gen, |_id_gen, get_keys| {
            let get_keys_arity = get_keys.arity();
            Ok(MirRelationExpr::join(
                vec![
                    // all the missing keys (with count 1)
                    keys_and_values
                        .distinct_by((0..get_keys_arity).collect())
                        .negate()
                        .union(get_keys.clone().distinct()),
                    // join with keys to get the correct counts
                    get_keys.clone(),
                ],
                // Equate the key columns of the two join inputs pairwise.
                (0..get_keys_arity).map(|i| vec![(0, i), (1, i)]).collect(),
            )
            // get rid of the extra copies of columns from keys
            .project((0..get_keys_arity).collect())
            // This join is logically equivalent to
            // `.map(<default_expr>)`, but using a join allows for
            // potential predicate pushdown and elision in the
            // optimizer.
            .product(MirRelationExpr::constant(
                vec![data],
                SqlRelationType::new(column_types),
            )))
        })
    }
1789
1790    /// Return:
1791    /// * every row in keys_and_values
1792    /// * every row in `self` that does not have a matching row in the first columns of
1793    ///   `keys_and_values`, using `default` to fill in the remaining columns
1794    /// (This is LEFT OUTER JOIN if:
1795    /// 1) `default` is a row of null
1796    /// 2) matching rows in `keys_and_values` and `self` have the same multiplicity.)
1797    pub fn lookup<E>(
1798        self,
1799        id_gen: &mut IdGen,
1800        keys_and_values: MirRelationExpr,
1801        default: Vec<(Datum<'static>, SqlScalarType)>,
1802    ) -> Result<MirRelationExpr, E> {
1803        keys_and_values.let_in(id_gen, |id_gen, get_keys_and_values| {
1804            Ok(get_keys_and_values.clone().union(self.anti_lookup(
1805                id_gen,
1806                get_keys_and_values,
1807                default,
1808            )?))
1809        })
1810    }
1811
1812    /// True iff the expression contains a `NullaryFunc::MzLogicalTimestamp`.
1813    pub fn contains_temporal(&self) -> bool {
1814        let mut contains = false;
1815        self.visit_scalars(&mut |e| contains = contains || e.contains_temporal());
1816        contains
1817    }
1818
    /// Fallible visitor for the [`MirScalarExpr`]s directly owned by this relation expression.
    ///
    /// The `f` visitor should not recursively descend into owned [`MirRelationExpr`]s.
    ///
    /// # Errors
    ///
    /// Returns the first error produced by `f`, short-circuiting the traversal.
    pub fn try_visit_scalars_mut1<F, E>(&mut self, f: &mut F) -> Result<(), E>
    where
        F: FnMut(&mut MirScalarExpr) -> Result<(), E>,
    {
        use MirRelationExpr::*;
        match self {
            // Scalar expressions appended to each input row.
            Map { scalars, .. } => {
                for s in scalars {
                    f(s)?;
                }
            }
            // Predicates tested against each row.
            Filter { predicates, .. } => {
                for p in predicates {
                    f(p)?;
                }
            }
            // Arguments supplied to the table function.
            FlatMap { exprs, .. } => {
                for expr in exprs {
                    f(expr)?;
                }
            }
            Join {
                inputs: _,
                equivalences,
                implementation,
            } => {
                // Classes of expressions required to be equal.
                for equivalence in equivalences {
                    for expr in equivalence {
                        f(expr)?;
                    }
                }
                // A planned implementation carries additional key expressions.
                match implementation {
                    JoinImplementation::Differential((_, start_key, _), order) => {
                        if let Some(start_key) = start_key {
                            for k in start_key {
                                f(k)?;
                            }
                        }
                        for (_, lookup_key, _) in order {
                            for k in lookup_key {
                                f(k)?;
                            }
                        }
                    }
                    JoinImplementation::DeltaQuery(paths) => {
                        // One path per input, each a sequence of keyed lookups.
                        for path in paths {
                            for (_, lookup_key, _) in path {
                                for k in lookup_key {
                                    f(k)?;
                                }
                            }
                        }
                    }
                    JoinImplementation::IndexedFilter(_coll_id, _idx_id, index_key, _) => {
                        for k in index_key {
                            f(k)?;
                        }
                    }
                    JoinImplementation::Unimplemented => {} // No scalar exprs
                }
            }
            // Expressions forming each arrangement key.
            ArrangeBy { keys, .. } => {
                for key in keys {
                    for s in key {
                        f(s)?;
                    }
                }
            }
            Reduce {
                group_key,
                aggregates,
                ..
            } => {
                // Grouping expressions, then each aggregate's input expression.
                for s in group_key {
                    f(s)?;
                }
                for agg in aggregates {
                    f(&mut agg.expr)?;
                }
            }
            // The row limit, when present, is a scalar expression.
            TopK { limit, .. } => {
                if let Some(s) = limit {
                    f(s)?;
                }
            }
            // These operators own no scalar expressions directly.
            Constant { .. }
            | Get { .. }
            | Let { .. }
            | LetRec { .. }
            | Project { .. }
            | Negate { .. }
            | Threshold { .. }
            | Union { .. } => (),
        }
        Ok(())
    }
1918
    /// Fallible mutable visitor for the [`MirScalarExpr`]s in the [`MirRelationExpr`] subtree
    /// rooted at `self`.
    ///
    /// Note that this does not recurse into [`MirRelationExpr`] subtrees within [`MirScalarExpr`]
    /// nodes.
    ///
    /// # Errors
    ///
    /// Fails if `f` fails or if the stack-safe traversal hits its recursion limit.
    pub fn try_visit_scalars_mut<F, E>(&mut self, f: &mut F) -> Result<(), E>
    where
        F: FnMut(&mut MirScalarExpr) -> Result<(), E>,
        E: From<RecursionLimitError>,
    {
        // Visit relation nodes post-order, applying `f` to the scalars each
        // node directly owns.
        self.try_visit_mut_post(&mut |expr| expr.try_visit_scalars_mut1(f))
    }
1931
1932    /// Infallible mutable visitor for the [`MirScalarExpr`]s in the [`MirRelationExpr`] subtree
1933    /// rooted at `self`.
1934    ///
1935    /// Note that this does not recurse into [`MirRelationExpr`] subtrees within [`MirScalarExpr`]
1936    /// nodes.
1937    pub fn visit_scalars_mut<F>(&mut self, f: &mut F)
1938    where
1939        F: FnMut(&mut MirScalarExpr),
1940    {
1941        self.try_visit_scalars_mut(&mut |s| {
1942            f(s);
1943            Ok::<_, RecursionLimitError>(())
1944        })
1945        .expect("Unexpected error in `visit_scalars_mut` call");
1946    }
1947
    /// Fallible visitor for the [`MirScalarExpr`]s directly owned by this relation expression.
    ///
    /// The `f` visitor should not recursively descend into owned [`MirRelationExpr`]s.
    ///
    /// # Errors
    ///
    /// Returns the first error produced by `f`, short-circuiting the traversal.
    pub fn try_visit_scalars_1<F, E>(&self, f: &mut F) -> Result<(), E>
    where
        F: FnMut(&MirScalarExpr) -> Result<(), E>,
    {
        use MirRelationExpr::*;
        match self {
            // Scalar expressions appended to each input row.
            Map { scalars, .. } => {
                for s in scalars {
                    f(s)?;
                }
            }
            // Predicates tested against each row.
            Filter { predicates, .. } => {
                for p in predicates {
                    f(p)?;
                }
            }
            // Arguments supplied to the table function.
            FlatMap { exprs, .. } => {
                for expr in exprs {
                    f(expr)?;
                }
            }
            Join {
                inputs: _,
                equivalences,
                implementation,
            } => {
                // Classes of expressions required to be equal.
                for equivalence in equivalences {
                    for expr in equivalence {
                        f(expr)?;
                    }
                }
                // A planned implementation carries additional key expressions.
                match implementation {
                    JoinImplementation::Differential((_, start_key, _), order) => {
                        if let Some(start_key) = start_key {
                            for k in start_key {
                                f(k)?;
                            }
                        }
                        for (_, lookup_key, _) in order {
                            for k in lookup_key {
                                f(k)?;
                            }
                        }
                    }
                    JoinImplementation::DeltaQuery(paths) => {
                        // One path per input, each a sequence of keyed lookups.
                        for path in paths {
                            for (_, lookup_key, _) in path {
                                for k in lookup_key {
                                    f(k)?;
                                }
                            }
                        }
                    }
                    JoinImplementation::IndexedFilter(_coll_id, _idx_id, index_key, _) => {
                        for k in index_key {
                            f(k)?;
                        }
                    }
                    JoinImplementation::Unimplemented => {} // No scalar exprs
                }
            }
            // Expressions forming each arrangement key.
            ArrangeBy { keys, .. } => {
                for key in keys {
                    for s in key {
                        f(s)?;
                    }
                }
            }
            Reduce {
                group_key,
                aggregates,
                ..
            } => {
                // Grouping expressions, then each aggregate's input expression.
                for s in group_key {
                    f(s)?;
                }
                for agg in aggregates {
                    f(&agg.expr)?;
                }
            }
            // The row limit, when present, is a scalar expression.
            TopK { limit, .. } => {
                if let Some(s) = limit {
                    f(s)?;
                }
            }
            // These operators own no scalar expressions directly.
            Constant { .. }
            | Get { .. }
            | Let { .. }
            | LetRec { .. }
            | Project { .. }
            | Negate { .. }
            | Threshold { .. }
            | Union { .. } => (),
        }
        Ok(())
    }
2047
    /// Fallible immutable visitor for the [`MirScalarExpr`]s in the [`MirRelationExpr`] subtree
    /// rooted at `self`.
    ///
    /// Note that this does not recurse into [`MirRelationExpr`] subtrees within [`MirScalarExpr`]
    /// nodes.
    ///
    /// # Errors
    ///
    /// Fails if `f` fails or if the stack-safe traversal hits its recursion limit.
    pub fn try_visit_scalars<F, E>(&self, f: &mut F) -> Result<(), E>
    where
        F: FnMut(&MirScalarExpr) -> Result<(), E>,
        E: From<RecursionLimitError>,
    {
        // Visit relation nodes post-order, applying `f` to the scalars each
        // node directly owns.
        self.try_visit_post(&mut |expr| expr.try_visit_scalars_1(f))
    }
2060
2061    /// Infallible immutable visitor for the [`MirScalarExpr`]s in the [`MirRelationExpr`] subtree
2062    /// rooted at `self`.
2063    ///
2064    /// Note that this does not recurse into [`MirRelationExpr`] subtrees within [`MirScalarExpr`]
2065    /// nodes.
2066    pub fn visit_scalars<F>(&self, f: &mut F)
2067    where
2068        F: FnMut(&MirScalarExpr),
2069    {
2070        self.try_visit_scalars(&mut |s| {
2071            f(s);
2072            Ok::<_, RecursionLimitError>(())
2073        })
2074        .expect("Unexpected error in `visit_scalars` call");
2075    }
2076
2077    /// Clears the contents of `self` even if it's so deep that simply dropping it would cause a
2078    /// stack overflow in `drop_in_place`.
2079    ///
2080    /// Leaves `self` in an unusable state, so this should only be used if `self` is about to be
2081    /// dropped or otherwise overwritten.
2082    pub fn destroy_carefully(&mut self) {
2083        let mut todo = vec![self.take_dangerous()];
2084        while let Some(mut expr) = todo.pop() {
2085            for child in expr.children_mut() {
2086                todo.push(child.take_dangerous());
2087            }
2088        }
2089    }
2090
2091    /// Computes the size (total number of nodes) and maximum depth of a MirRelationExpr for
2092    /// debug printing purposes.
2093    pub fn debug_size_and_depth(&self) -> (usize, usize) {
2094        let mut size = 0;
2095        let mut max_depth = 0;
2096        let mut todo = vec![(self, 1)];
2097        while let Some((expr, depth)) = todo.pop() {
2098            size += 1;
2099            max_depth = max(max_depth, depth);
2100            todo.extend(expr.children().map(|c| (c, depth + 1)));
2101        }
2102        (size, max_depth)
2103    }
2104
2105    /// The MirRelationExpr is considered potentially expensive if and only if
2106    /// at least one of the following conditions is true:
2107    ///
2108    ///  - It contains at least one FlatMap or a Reduce operator.
2109    ///  - It contains at least one MirScalarExpr with a function call.
2110    ///
2111    /// !!!WARNING!!!: this method has an HirRelationExpr counterpart. The two
2112    /// should be kept in sync w.r.t. HIR ⇒ MIR lowering!
2113    pub fn could_run_expensive_function(&self) -> bool {
2114        let mut result = false;
2115        self.visit_pre(|e: &MirRelationExpr| {
2116            use MirRelationExpr::*;
2117            use MirScalarExpr::*;
2118            if let Err(_) = self.try_visit_scalars::<_, RecursionLimitError>(&mut |scalar| {
2119                result |= match scalar {
2120                    Column(_, _) | Literal(_, _) | CallUnmaterializable(_) | If { .. } => false,
2121                    // Function calls are considered expensive
2122                    CallUnary { .. } | CallBinary { .. } | CallVariadic { .. } => true,
2123                };
2124                Ok(())
2125            }) {
2126                // Conservatively set `true` if on RecursionLimitError.
2127                result = true;
2128            }
2129            // FlatMap has a table function; Reduce has an aggregate function.
2130            // Other constructs use MirScalarExpr to run a function
2131            result |= matches!(e, FlatMap { .. } | Reduce { .. });
2132        });
2133        result
2134    }
2135
2136    /// Hash to an u64 using Rust's default Hasher. (Which is a somewhat slower, but better Hasher
2137    /// than what `Hashable::hashed` would give us.)
2138    pub fn hash_to_u64(&self) -> u64 {
2139        let mut h = DefaultHasher::new();
2140        self.hash(&mut h);
2141        h.finish()
2142    }
2143}
2144
2145// `LetRec` helpers
2146impl MirRelationExpr {
2147    /// True when `expr` contains a `LetRec` AST node.
2148    pub fn is_recursive(self: &MirRelationExpr) -> bool {
2149        let mut worklist = vec![self];
2150        while let Some(expr) = worklist.pop() {
2151            if let MirRelationExpr::LetRec { .. } = expr {
2152                return true;
2153            }
2154            worklist.extend(expr.children());
2155        }
2156        false
2157    }
2158
2159    /// Return the number of sub-expressions in the tree (including self).
2160    pub fn size(&self) -> usize {
2161        let mut size = 0;
2162        self.visit_pre(|_| size += 1);
2163        size
2164    }
2165
2166    /// Given the ids and values of a LetRec, it computes the subset of ids that are used across
2167    /// iterations. These are those ids that have a reference before they are defined, when reading
2168    /// all the bindings in order.
2169    ///
2170    /// For example:
2171    /// ```SQL
2172    /// WITH MUTUALLY RECURSIVE
2173    ///     x(...) AS f(z),
2174    ///     y(...) AS g(x),
2175    ///     z(...) AS h(y)
2176    /// ...;
2177    /// ```
2178    /// Here, only `z` is returned, because `x` and `y` are referenced only within the same
2179    /// iteration.
2180    ///
2181    /// Note that if a binding references itself, that is also returned.
2182    pub fn recursive_ids(ids: &[LocalId], values: &[MirRelationExpr]) -> BTreeSet<LocalId> {
2183        let mut used_across_iterations = BTreeSet::new();
2184        let mut defined = BTreeSet::new();
2185        for (binding_id, value) in itertools::zip_eq(ids.iter(), values.iter()) {
2186            value.visit_pre(|expr| {
2187                if let MirRelationExpr::Get {
2188                    id: Local(get_id), ..
2189                } = expr
2190                {
2191                    // If we haven't seen a definition for it yet, then this will refer
2192                    // to the previous iteration.
2193                    // The `ids.contains` part of the condition is needed to exclude
2194                    // those ids that are not really in this LetRec, but either an inner
2195                    // or outer one.
2196                    if !defined.contains(get_id) && ids.contains(get_id) {
2197                        used_across_iterations.insert(*get_id);
2198                    }
2199                }
2200            });
2201            defined.insert(*binding_id);
2202        }
2203        used_across_iterations
2204    }
2205
2206    /// Replaces `LetRec` nodes with a stack of `Let` nodes.
2207    ///
2208    /// In each `Let` binding, uses of `Get` in `value` that are not at strictly greater
2209    /// identifiers are rewritten to be the constant collection.
2210    /// This makes the computation perform exactly "one" iteration.
2211    ///
2212    /// This was used only temporarily while developing `LetRec`.
2213    pub fn make_nonrecursive(self: &mut MirRelationExpr) {
2214        let mut deadlist = BTreeSet::new();
2215        let mut worklist = vec![self];
2216        while let Some(expr) = worklist.pop() {
2217            if let MirRelationExpr::LetRec {
2218                ids,
2219                values,
2220                limits: _,
2221                body,
2222            } = expr
2223            {
2224                let ids_values = values
2225                    .drain(..)
2226                    .zip_eq(ids)
2227                    .map(|(value, id)| (*id, value))
2228                    .collect::<Vec<_>>();
2229                *expr = body.take_dangerous();
2230                for (id, mut value) in ids_values.into_iter().rev() {
2231                    // Remove references to potentially recursive identifiers.
2232                    deadlist.insert(id);
2233                    value.visit_pre_mut(|e| {
2234                        if let MirRelationExpr::Get {
2235                            id: crate::Id::Local(id),
2236                            typ,
2237                            ..
2238                        } = e
2239                        {
2240                            let typ = typ.clone();
2241                            if deadlist.contains(id) {
2242                                e.take_safely(Some(typ));
2243                            }
2244                        }
2245                    });
2246                    *expr = MirRelationExpr::Let {
2247                        id,
2248                        value: Box::new(value),
2249                        body: Box::new(expr.take_dangerous()),
2250                    };
2251                }
2252                worklist.push(expr);
2253            } else {
2254                worklist.extend(expr.children_mut().rev());
2255            }
2256        }
2257    }
2258
2259    /// For each Id `id'` referenced in `expr`, if it is larger or equal than `id`, then record in
2260    /// `expire_whens` that when `id'` is redefined, then we should expire the information that
2261    /// we are holding about `id`. Call `do_expirations` with `expire_whens` at each Id
2262    /// redefinition.
2263    ///
2264    /// IMPORTANT: Relies on the numbering of Ids to be what `renumber_bindings` gives.
2265    pub fn collect_expirations(
2266        id: LocalId,
2267        expr: &MirRelationExpr,
2268        expire_whens: &mut BTreeMap<LocalId, Vec<LocalId>>,
2269    ) {
2270        expr.visit_pre(|e| {
2271            if let MirRelationExpr::Get {
2272                id: Id::Local(referenced_id),
2273                ..
2274            } = e
2275            {
2276                // The following check needs `renumber_bindings` to have run recently
2277                if referenced_id >= &id {
2278                    expire_whens
2279                        .entry(*referenced_id)
2280                        .or_insert_with(Vec::new)
2281                        .push(id);
2282                }
2283            }
2284        });
2285    }
2286
2287    /// Call this function when `id` is redefined. It modifies `id_infos` by removing information
2288    /// about such Ids whose information depended on the earlier definition of `id`, according to
2289    /// `expire_whens`. Also modifies `expire_whens`: it removes the currently processed entry.
2290    pub fn do_expirations<I>(
2291        redefined_id: LocalId,
2292        expire_whens: &mut BTreeMap<LocalId, Vec<LocalId>>,
2293        id_infos: &mut BTreeMap<LocalId, I>,
2294    ) -> Vec<(LocalId, I)> {
2295        let mut expired_infos = Vec::new();
2296        if let Some(expirations) = expire_whens.remove(&redefined_id) {
2297            for expired_id in expirations.into_iter() {
2298                if let Some(offer) = id_infos.remove(&expired_id) {
2299                    expired_infos.push((expired_id, offer));
2300                }
2301            }
2302        }
2303        expired_infos
2304    }
2305}
2306/// Augment non-nullability of columns, by observing either
2307/// 1. Predicates that explicitly test for null values, and
2308/// 2. Columns that if null would make a predicate be null.
2309pub fn non_nullable_columns(predicates: &[MirScalarExpr]) -> BTreeSet<usize> {
2310    let mut nonnull_required_columns = BTreeSet::new();
2311    for predicate in predicates {
2312        // Add any columns that being null would force the predicate to be null.
2313        // Should that happen, the row would be discarded.
2314        predicate.non_null_requirements(&mut nonnull_required_columns);
2315
2316        /*
2317        Test for explicit checks that a column is non-null.
2318
2319        This analysis is ad hoc, and will miss things:
2320
2321        materialize=> create table a(x int, y int);
2322        CREATE TABLE
2323        materialize=> explain with(types) select x from a where (y=x and y is not null) or x is not null;
2324        Optimized Plan
2325        --------------------------------------------------------------------------------------------------------
2326        Explained Query:                                                                                      +
2327        Project (#0) // { types: "(integer?)" }                                                             +
2328        Filter ((#0) IS NOT NULL OR ((#1) IS NOT NULL AND (#0 = #1))) // { types: "(integer?, integer?)" }+
2329        Get materialize.public.a // { types: "(integer?, integer?)" }                                   +
2330                                                                                  +
2331        Source materialize.public.a                                                                           +
2332        filter=(((#0) IS NOT NULL OR ((#1) IS NOT NULL AND (#0 = #1))))                                     +
2333
2334        (1 row)
2335        */
2336
2337        if let MirScalarExpr::CallUnary {
2338            func: UnaryFunc::Not(scalar_func::Not),
2339            expr,
2340        } = predicate
2341        {
2342            if let MirScalarExpr::CallUnary {
2343                func: UnaryFunc::IsNull(scalar_func::IsNull),
2344                expr,
2345            } = &**expr
2346            {
2347                if let MirScalarExpr::Column(c, _name) = &**expr {
2348                    nonnull_required_columns.insert(*c);
2349                }
2350            }
2351        }
2352    }
2353
2354    nonnull_required_columns
2355}
2356
2357impl CollectionPlan for MirRelationExpr {
2358    /// Collects the global collections that this MIR expression directly depends on, i.e., that it
2359    /// has a `Get` for. (It does _not_ traverse view definitions transitively.)
2360    ///
2361    /// !!!WARNING!!!: this method has an HirRelationExpr counterpart. The two
2362    /// should be kept in sync w.r.t. HIR ⇒ MIR lowering!
2363    fn depends_on_into(&self, out: &mut BTreeSet<GlobalId>) {
2364        if let MirRelationExpr::Get {
2365            id: Id::Global(id), ..
2366        } = self
2367        {
2368            out.insert(*id);
2369        }
2370        self.visit_children(|expr| expr.depends_on_into(out))
2371    }
2372}
2373
2374impl MirRelationExpr {
2375    /// Iterates through references to child expressions.
2376    pub fn children(&self) -> impl DoubleEndedIterator<Item = &Self> {
2377        let mut first = None;
2378        let mut second = None;
2379        let mut rest = None;
2380        let mut last = None;
2381
2382        use MirRelationExpr::*;
2383        match self {
2384            Constant { .. } | Get { .. } => (),
2385            Let { value, body, .. } => {
2386                first = Some(&**value);
2387                second = Some(&**body);
2388            }
2389            LetRec { values, body, .. } => {
2390                rest = Some(values);
2391                last = Some(&**body);
2392            }
2393            Project { input, .. }
2394            | Map { input, .. }
2395            | FlatMap { input, .. }
2396            | Filter { input, .. }
2397            | Reduce { input, .. }
2398            | TopK { input, .. }
2399            | Negate { input }
2400            | Threshold { input }
2401            | ArrangeBy { input, .. } => {
2402                first = Some(&**input);
2403            }
2404            Join { inputs, .. } => {
2405                rest = Some(inputs);
2406            }
2407            Union { base, inputs } => {
2408                first = Some(&**base);
2409                rest = Some(inputs);
2410            }
2411        }
2412
2413        first
2414            .into_iter()
2415            .chain(second)
2416            .chain(rest.into_iter().flatten())
2417            .chain(last)
2418    }
2419
2420    /// Iterates through mutable references to child expressions.
2421    pub fn children_mut(&mut self) -> impl DoubleEndedIterator<Item = &mut Self> {
2422        let mut first = None;
2423        let mut second = None;
2424        let mut rest = None;
2425        let mut last = None;
2426
2427        use MirRelationExpr::*;
2428        match self {
2429            Constant { .. } | Get { .. } => (),
2430            Let { value, body, .. } => {
2431                first = Some(&mut **value);
2432                second = Some(&mut **body);
2433            }
2434            LetRec { values, body, .. } => {
2435                rest = Some(values);
2436                last = Some(&mut **body);
2437            }
2438            Project { input, .. }
2439            | Map { input, .. }
2440            | FlatMap { input, .. }
2441            | Filter { input, .. }
2442            | Reduce { input, .. }
2443            | TopK { input, .. }
2444            | Negate { input }
2445            | Threshold { input }
2446            | ArrangeBy { input, .. } => {
2447                first = Some(&mut **input);
2448            }
2449            Join { inputs, .. } => {
2450                rest = Some(inputs);
2451            }
2452            Union { base, inputs } => {
2453                first = Some(&mut **base);
2454                rest = Some(inputs);
2455            }
2456        }
2457
2458        first
2459            .into_iter()
2460            .chain(second)
2461            .chain(rest.into_iter().flatten())
2462            .chain(last)
2463    }
2464
2465    /// Iterative pre-order visitor.
2466    pub fn visit_pre<'a, F: FnMut(&'a Self)>(&'a self, mut f: F) {
2467        let mut worklist = vec![self];
2468        while let Some(expr) = worklist.pop() {
2469            f(expr);
2470            worklist.extend(expr.children().rev());
2471        }
2472    }
2473
2474    /// Iterative pre-order visitor.
2475    pub fn visit_pre_mut<F: FnMut(&mut Self)>(&mut self, mut f: F) {
2476        let mut worklist = vec![self];
2477        while let Some(expr) = worklist.pop() {
2478            f(expr);
2479            worklist.extend(expr.children_mut().rev());
2480        }
2481    }
2482
2483    /// Return a vector of references to the subtrees of this expression
2484    /// in post-visit order (the last element is `&self`).
2485    pub fn post_order_vec(&self) -> Vec<&Self> {
2486        let mut stack = vec![self];
2487        let mut result = vec![];
2488        while let Some(expr) = stack.pop() {
2489            result.push(expr);
2490            stack.extend(expr.children());
2491        }
2492        result.reverse();
2493        result
2494    }
2495}
2496
2497impl VisitChildren<Self> for MirRelationExpr {
2498    fn visit_children<F>(&self, mut f: F)
2499    where
2500        F: FnMut(&Self),
2501    {
2502        for child in self.children() {
2503            f(child)
2504        }
2505    }
2506
2507    fn visit_mut_children<F>(&mut self, mut f: F)
2508    where
2509        F: FnMut(&mut Self),
2510    {
2511        for child in self.children_mut() {
2512            f(child)
2513        }
2514    }
2515
2516    fn try_visit_children<F, E>(&self, mut f: F) -> Result<(), E>
2517    where
2518        F: FnMut(&Self) -> Result<(), E>,
2519        E: From<RecursionLimitError>,
2520    {
2521        for child in self.children() {
2522            f(child)?
2523        }
2524        Ok(())
2525    }
2526
2527    fn try_visit_mut_children<F, E>(&mut self, mut f: F) -> Result<(), E>
2528    where
2529        F: FnMut(&mut Self) -> Result<(), E>,
2530        E: From<RecursionLimitError>,
2531    {
2532        for child in self.children_mut() {
2533            f(child)?
2534        }
2535        Ok(())
2536    }
2537}
2538
/// Specification for an ordering by a column.
#[derive(
    Debug, Clone, Copy, Eq, PartialEq, Ord, PartialOrd, Serialize, Deserialize, Hash, MzReflect,
)]
pub struct ColumnOrder {
    /// The column index.
    pub column: usize,
    /// Whether to sort in descending order.
    /// (Defaults to `false`, i.e., ascending, when absent in serialized form.)
    #[serde(default)]
    pub desc: bool,
    /// Whether to sort nulls last.
    /// (Defaults to `false`, i.e., nulls first, when absent in serialized form.)
    #[serde(default)]
    pub nulls_last: bool,
}
2553
impl Columnation for ColumnOrder {
    // `ColumnOrder` derives `Copy`, so a plain copy-based region suffices.
    type InnerRegion = CopyRegion<Self>;
}
2557
2558impl<'a, M> fmt::Display for HumanizedExpr<'a, ColumnOrder, M>
2559where
2560    M: HumanizerMode,
2561{
2562    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
2563        // If you modify this, then please also attend to Display for ColumnOrderWithExpr!
2564        write!(
2565            f,
2566            "{} {} {}",
2567            self.child(&self.expr.column),
2568            if self.expr.desc { "desc" } else { "asc" },
2569            if self.expr.nulls_last {
2570                "nulls_last"
2571            } else {
2572                "nulls_first"
2573            },
2574        )
2575    }
2576}
2577
/// Describes an aggregation expression.
#[derive(Clone, Debug, Eq, PartialEq, Ord, PartialOrd, Serialize, Deserialize, Hash, MzReflect)]
pub struct AggregateExpr {
    /// Names the aggregation function.
    pub func: AggregateFunc,
    /// An expression which extracts from each row the input to `func`.
    pub expr: MirScalarExpr,
    /// Should the aggregation be applied only to distinct results in each group.
    /// (Defaults to `false` when absent in serialized form.)
    #[serde(default)]
    pub distinct: bool,
}
2589
2590impl AggregateExpr {
2591    /// Computes the type of this `AggregateExpr`.
2592    pub fn typ(&self, column_types: &[SqlColumnType]) -> SqlColumnType {
2593        self.func.output_type(self.expr.typ(column_types))
2594    }
2595
2596    /// Computes the type of this `AggregateExpr`.
2597    pub fn repr_typ(&self, column_types: &[ReprColumnType]) -> ReprColumnType {
2598        ReprColumnType::from(
2599            &self
2600                .func
2601                .output_type(SqlColumnType::from_repr(&self.expr.repr_typ(column_types))),
2602        )
2603    }
2604
2605    /// Returns whether the expression has a constant result.
2606    pub fn is_constant(&self) -> bool {
2607        match self.func {
2608            AggregateFunc::MaxInt16
2609            | AggregateFunc::MaxInt32
2610            | AggregateFunc::MaxInt64
2611            | AggregateFunc::MaxUInt16
2612            | AggregateFunc::MaxUInt32
2613            | AggregateFunc::MaxUInt64
2614            | AggregateFunc::MaxMzTimestamp
2615            | AggregateFunc::MaxFloat32
2616            | AggregateFunc::MaxFloat64
2617            | AggregateFunc::MaxBool
2618            | AggregateFunc::MaxString
2619            | AggregateFunc::MaxDate
2620            | AggregateFunc::MaxTimestamp
2621            | AggregateFunc::MaxTimestampTz
2622            | AggregateFunc::MinInt16
2623            | AggregateFunc::MinInt32
2624            | AggregateFunc::MinInt64
2625            | AggregateFunc::MinUInt16
2626            | AggregateFunc::MinUInt32
2627            | AggregateFunc::MinUInt64
2628            | AggregateFunc::MinMzTimestamp
2629            | AggregateFunc::MinFloat32
2630            | AggregateFunc::MinFloat64
2631            | AggregateFunc::MinBool
2632            | AggregateFunc::MinString
2633            | AggregateFunc::MinDate
2634            | AggregateFunc::MinTimestamp
2635            | AggregateFunc::MinTimestampTz
2636            | AggregateFunc::Any
2637            | AggregateFunc::All
2638            | AggregateFunc::Dummy => self.expr.is_literal(),
2639            AggregateFunc::Count => self.expr.is_literal_null(),
2640            _ => self.expr.is_literal_err(),
2641        }
2642    }
2643
2644    /// Returns an expression that computes `self` on a group that has exactly one row.
2645    /// Instead of performing a `Reduce` with `self`, one can perform a `Map` with the expression
2646    /// returned by `on_unique`, which is cheaper. (See `ReduceElision`.)
2647    pub fn on_unique(&self, input_type: &[SqlColumnType]) -> MirScalarExpr {
2648        match &self.func {
2649            // Count is one if non-null, and zero if null.
2650            AggregateFunc::Count => self
2651                .expr
2652                .clone()
2653                .call_unary(UnaryFunc::IsNull(crate::func::IsNull))
2654                .if_then_else(
2655                    MirScalarExpr::literal_ok(Datum::Int64(0), SqlScalarType::Int64),
2656                    MirScalarExpr::literal_ok(Datum::Int64(1), SqlScalarType::Int64),
2657                ),
2658
2659            // SumInt16 takes Int16s as input, but outputs Int64s.
2660            AggregateFunc::SumInt16 => self
2661                .expr
2662                .clone()
2663                .call_unary(UnaryFunc::CastInt16ToInt64(scalar_func::CastInt16ToInt64)),
2664
2665            // SumInt32 takes Int32s as input, but outputs Int64s.
2666            AggregateFunc::SumInt32 => self
2667                .expr
2668                .clone()
2669                .call_unary(UnaryFunc::CastInt32ToInt64(scalar_func::CastInt32ToInt64)),
2670
2671            // SumInt64 takes Int64s as input, but outputs numerics.
2672            AggregateFunc::SumInt64 => self.expr.clone().call_unary(UnaryFunc::CastInt64ToNumeric(
2673                scalar_func::CastInt64ToNumeric(Some(NumericMaxScale::ZERO)),
2674            )),
2675
2676            // SumUInt16 takes UInt16s as input, but outputs UInt64s.
2677            AggregateFunc::SumUInt16 => self.expr.clone().call_unary(
2678                UnaryFunc::CastUint16ToUint64(scalar_func::CastUint16ToUint64),
2679            ),
2680
2681            // SumUInt32 takes UInt32s as input, but outputs UInt64s.
2682            AggregateFunc::SumUInt32 => self.expr.clone().call_unary(
2683                UnaryFunc::CastUint32ToUint64(scalar_func::CastUint32ToUint64),
2684            ),
2685
2686            // SumUInt64 takes UInt64s as input, but outputs numerics.
2687            AggregateFunc::SumUInt64 => {
2688                self.expr.clone().call_unary(UnaryFunc::CastUint64ToNumeric(
2689                    scalar_func::CastUint64ToNumeric(Some(NumericMaxScale::ZERO)),
2690                ))
2691            }
2692
2693            // JsonbAgg takes _anything_ as input, but must output a Jsonb array.
2694            AggregateFunc::JsonbAgg { .. } => MirScalarExpr::CallVariadic {
2695                func: VariadicFunc::JsonbBuildArray,
2696                exprs: vec![
2697                    self.expr
2698                        .clone()
2699                        .call_unary(UnaryFunc::RecordGet(scalar_func::RecordGet(0))),
2700                ],
2701            },
2702
2703            // JsonbAgg takes _anything_ as input, but must output a Jsonb object.
2704            AggregateFunc::JsonbObjectAgg { .. } => {
2705                let record = self
2706                    .expr
2707                    .clone()
2708                    .call_unary(UnaryFunc::RecordGet(scalar_func::RecordGet(0)));
2709                MirScalarExpr::CallVariadic {
2710                    func: VariadicFunc::JsonbBuildObject,
2711                    exprs: (0..2)
2712                        .map(|i| {
2713                            record
2714                                .clone()
2715                                .call_unary(UnaryFunc::RecordGet(scalar_func::RecordGet(i)))
2716                        })
2717                        .collect(),
2718                }
2719            }
2720
2721            AggregateFunc::MapAgg { value_type, .. } => {
2722                let record = self
2723                    .expr
2724                    .clone()
2725                    .call_unary(UnaryFunc::RecordGet(scalar_func::RecordGet(0)));
2726                MirScalarExpr::CallVariadic {
2727                    func: VariadicFunc::MapBuild {
2728                        value_type: value_type.clone(),
2729                    },
2730                    exprs: (0..2)
2731                        .map(|i| {
2732                            record
2733                                .clone()
2734                                .call_unary(UnaryFunc::RecordGet(scalar_func::RecordGet(i)))
2735                        })
2736                        .collect(),
2737                }
2738            }
2739
2740            // StringAgg takes nested records of strings and outputs a string
2741            AggregateFunc::StringAgg { .. } => self
2742                .expr
2743                .clone()
2744                .call_unary(UnaryFunc::RecordGet(scalar_func::RecordGet(0)))
2745                .call_unary(UnaryFunc::RecordGet(scalar_func::RecordGet(0))),
2746
2747            // ListConcat and ArrayConcat take a single level of records and output a list containing exactly 1 element
2748            AggregateFunc::ListConcat { .. } | AggregateFunc::ArrayConcat { .. } => self
2749                .expr
2750                .clone()
2751                .call_unary(UnaryFunc::RecordGet(scalar_func::RecordGet(0))),
2752
2753            // RowNumber, Rank, DenseRank take a list of records and output a list containing exactly 1 element
2754            AggregateFunc::RowNumber { .. } => {
2755                self.on_unique_ranking_window_funcs(input_type, "?row_number?")
2756            }
2757            AggregateFunc::Rank { .. } => self.on_unique_ranking_window_funcs(input_type, "?rank?"),
2758            AggregateFunc::DenseRank { .. } => {
2759                self.on_unique_ranking_window_funcs(input_type, "?dense_rank?")
2760            }
2761
2762            // The input type for LagLead is ((OriginalRow, (InputValue, Offset, Default)), OrderByExprs...)
2763            AggregateFunc::LagLead { lag_lead, .. } => {
2764                let tuple = self
2765                    .expr
2766                    .clone()
2767                    .call_unary(UnaryFunc::RecordGet(scalar_func::RecordGet(0)));
2768
2769                // Get the overall return type
2770                let return_type_with_orig_row = self
2771                    .typ(input_type)
2772                    .scalar_type
2773                    .unwrap_list_element_type()
2774                    .clone();
2775                let lag_lead_return_type =
2776                    return_type_with_orig_row.unwrap_record_element_type()[0].clone();
2777
2778                // Extract the original row
2779                let original_row = tuple
2780                    .clone()
2781                    .call_unary(UnaryFunc::RecordGet(scalar_func::RecordGet(0)));
2782
2783                // Extract the encoded args
2784                let encoded_args =
2785                    tuple.call_unary(UnaryFunc::RecordGet(scalar_func::RecordGet(1)));
2786
2787                let (result_expr, column_name) =
2788                    Self::on_unique_lag_lead(lag_lead, encoded_args, lag_lead_return_type.clone());
2789
2790                MirScalarExpr::CallVariadic {
2791                    func: VariadicFunc::ListCreate {
2792                        elem_type: return_type_with_orig_row,
2793                    },
2794                    exprs: vec![MirScalarExpr::CallVariadic {
2795                        func: VariadicFunc::RecordCreate {
2796                            field_names: vec![column_name, ColumnName::from("?record?")],
2797                        },
2798                        exprs: vec![result_expr, original_row],
2799                    }],
2800                }
2801            }
2802
2803            // The input type for FirstValue is ((OriginalRow, InputValue), OrderByExprs...)
2804            AggregateFunc::FirstValue { window_frame, .. } => {
2805                let tuple = self
2806                    .expr
2807                    .clone()
2808                    .call_unary(UnaryFunc::RecordGet(scalar_func::RecordGet(0)));
2809
2810                // Get the overall return type
2811                let return_type_with_orig_row = self
2812                    .typ(input_type)
2813                    .scalar_type
2814                    .unwrap_list_element_type()
2815                    .clone();
2816                let first_value_return_type =
2817                    return_type_with_orig_row.unwrap_record_element_type()[0].clone();
2818
2819                // Extract the original row
2820                let original_row = tuple
2821                    .clone()
2822                    .call_unary(UnaryFunc::RecordGet(scalar_func::RecordGet(0)));
2823
2824                // Extract the input value
2825                let arg = tuple.call_unary(UnaryFunc::RecordGet(scalar_func::RecordGet(1)));
2826
2827                let (result_expr, column_name) = Self::on_unique_first_value_last_value(
2828                    window_frame,
2829                    arg,
2830                    first_value_return_type,
2831                );
2832
2833                MirScalarExpr::CallVariadic {
2834                    func: VariadicFunc::ListCreate {
2835                        elem_type: return_type_with_orig_row,
2836                    },
2837                    exprs: vec![MirScalarExpr::CallVariadic {
2838                        func: VariadicFunc::RecordCreate {
2839                            field_names: vec![column_name, ColumnName::from("?record?")],
2840                        },
2841                        exprs: vec![result_expr, original_row],
2842                    }],
2843                }
2844            }
2845
2846            // The input type for LastValue is ((OriginalRow, InputValue), OrderByExprs...)
2847            AggregateFunc::LastValue { window_frame, .. } => {
2848                let tuple = self
2849                    .expr
2850                    .clone()
2851                    .call_unary(UnaryFunc::RecordGet(scalar_func::RecordGet(0)));
2852
2853                // Get the overall return type
2854                let return_type_with_orig_row = self
2855                    .typ(input_type)
2856                    .scalar_type
2857                    .unwrap_list_element_type()
2858                    .clone();
2859                let last_value_return_type =
2860                    return_type_with_orig_row.unwrap_record_element_type()[0].clone();
2861
2862                // Extract the original row
2863                let original_row = tuple
2864                    .clone()
2865                    .call_unary(UnaryFunc::RecordGet(scalar_func::RecordGet(0)));
2866
2867                // Extract the input value
2868                let arg = tuple.call_unary(UnaryFunc::RecordGet(scalar_func::RecordGet(1)));
2869
2870                let (result_expr, column_name) = Self::on_unique_first_value_last_value(
2871                    window_frame,
2872                    arg,
2873                    last_value_return_type,
2874                );
2875
2876                MirScalarExpr::CallVariadic {
2877                    func: VariadicFunc::ListCreate {
2878                        elem_type: return_type_with_orig_row,
2879                    },
2880                    exprs: vec![MirScalarExpr::CallVariadic {
2881                        func: VariadicFunc::RecordCreate {
2882                            field_names: vec![column_name, ColumnName::from("?record?")],
2883                        },
2884                        exprs: vec![result_expr, original_row],
2885                    }],
2886                }
2887            }
2888
2889            // The input type for window aggs is ((OriginalRow, InputValue), OrderByExprs...)
2890            // See an example MIR in `window_func_applied_to`.
2891            AggregateFunc::WindowAggregate {
2892                wrapped_aggregate,
2893                window_frame,
2894                order_by: _,
2895            } => {
2896                // TODO: deduplicate code between the various window function cases.
2897
2898                let tuple = self
2899                    .expr
2900                    .clone()
2901                    .call_unary(UnaryFunc::RecordGet(scalar_func::RecordGet(0)));
2902
2903                // Get the overall return type
2904                let return_type = self
2905                    .typ(input_type)
2906                    .scalar_type
2907                    .unwrap_list_element_type()
2908                    .clone();
2909                let window_agg_return_type = return_type.unwrap_record_element_type()[0].clone();
2910
2911                // Extract the original row
2912                let original_row = tuple
2913                    .clone()
2914                    .call_unary(UnaryFunc::RecordGet(scalar_func::RecordGet(0)));
2915
2916                // Extract the input value
2917                let arg_expr = tuple.call_unary(UnaryFunc::RecordGet(scalar_func::RecordGet(1)));
2918
2919                let (result, column_name) = Self::on_unique_window_agg(
2920                    window_frame,
2921                    arg_expr,
2922                    input_type,
2923                    window_agg_return_type,
2924                    wrapped_aggregate,
2925                );
2926
2927                MirScalarExpr::CallVariadic {
2928                    func: VariadicFunc::ListCreate {
2929                        elem_type: return_type,
2930                    },
2931                    exprs: vec![MirScalarExpr::CallVariadic {
2932                        func: VariadicFunc::RecordCreate {
2933                            field_names: vec![column_name, ColumnName::from("?record?")],
2934                        },
2935                        exprs: vec![result, original_row],
2936                    }],
2937                }
2938            }
2939
2940            // The input type is ((OriginalRow, (Arg1, Arg2, ...)), OrderByExprs...)
2941            AggregateFunc::FusedWindowAggregate {
2942                wrapped_aggregates,
2943                order_by: _,
2944                window_frame,
2945            } => {
2946                // Throw away OrderByExprs
2947                let tuple = self
2948                    .expr
2949                    .clone()
2950                    .call_unary(UnaryFunc::RecordGet(scalar_func::RecordGet(0)));
2951
2952                // Extract the original row
2953                let original_row = tuple
2954                    .clone()
2955                    .call_unary(UnaryFunc::RecordGet(scalar_func::RecordGet(0)));
2956
2957                // Extract the args of the fused call
2958                let all_args = tuple.call_unary(UnaryFunc::RecordGet(scalar_func::RecordGet(1)));
2959
2960                let return_type_with_orig_row = self
2961                    .typ(input_type)
2962                    .scalar_type
2963                    .unwrap_list_element_type()
2964                    .clone();
2965
2966                let all_func_return_types =
2967                    return_type_with_orig_row.unwrap_record_element_type()[0].clone();
2968                let mut func_result_exprs = Vec::new();
2969                let mut col_names = Vec::new();
2970                for (idx, wrapped_aggr) in wrapped_aggregates.iter().enumerate() {
2971                    let arg = all_args
2972                        .clone()
2973                        .call_unary(UnaryFunc::RecordGet(scalar_func::RecordGet(idx)));
2974                    let return_type =
2975                        all_func_return_types.unwrap_record_element_type()[idx].clone();
2976                    let (result, column_name) = Self::on_unique_window_agg(
2977                        window_frame,
2978                        arg,
2979                        input_type,
2980                        return_type,
2981                        wrapped_aggr,
2982                    );
2983                    func_result_exprs.push(result);
2984                    col_names.push(column_name);
2985                }
2986
2987                MirScalarExpr::CallVariadic {
2988                    func: VariadicFunc::ListCreate {
2989                        elem_type: return_type_with_orig_row,
2990                    },
2991                    exprs: vec![MirScalarExpr::CallVariadic {
2992                        func: VariadicFunc::RecordCreate {
2993                            field_names: vec![
2994                                ColumnName::from("?fused_window_aggr?"),
2995                                ColumnName::from("?record?"),
2996                            ],
2997                        },
2998                        exprs: vec![
2999                            MirScalarExpr::CallVariadic {
3000                                func: VariadicFunc::RecordCreate {
3001                                    field_names: col_names,
3002                                },
3003                                exprs: func_result_exprs,
3004                            },
3005                            original_row,
3006                        ],
3007                    }],
3008                }
3009            }
3010
3011            // The input type is ((OriginalRow, (Args1, Args2, ...)), OrderByExprs...)
3012            AggregateFunc::FusedValueWindowFunc {
3013                funcs,
3014                order_by: outer_order_by,
3015            } => {
3016                // Throw away OrderByExprs
3017                let tuple = self
3018                    .expr
3019                    .clone()
3020                    .call_unary(UnaryFunc::RecordGet(scalar_func::RecordGet(0)));
3021
3022                // Extract the original row
3023                let original_row = tuple
3024                    .clone()
3025                    .call_unary(UnaryFunc::RecordGet(scalar_func::RecordGet(0)));
3026
3027                // Extract the encoded args of the fused call
3028                let all_encoded_args =
3029                    tuple.call_unary(UnaryFunc::RecordGet(scalar_func::RecordGet(1)));
3030
3031                let return_type_with_orig_row = self
3032                    .typ(input_type)
3033                    .scalar_type
3034                    .unwrap_list_element_type()
3035                    .clone();
3036
3037                let all_func_return_types =
3038                    return_type_with_orig_row.unwrap_record_element_type()[0].clone();
3039                let mut func_result_exprs = Vec::new();
3040                let mut col_names = Vec::new();
3041                for (idx, func) in funcs.iter().enumerate() {
3042                    let args_for_func = all_encoded_args
3043                        .clone()
3044                        .call_unary(UnaryFunc::RecordGet(scalar_func::RecordGet(idx)));
3045                    let return_type_for_func =
3046                        all_func_return_types.unwrap_record_element_type()[idx].clone();
3047                    let (result, column_name) = match func {
3048                        AggregateFunc::LagLead {
3049                            lag_lead,
3050                            order_by,
3051                            ignore_nulls: _,
3052                        } => {
3053                            assert_eq!(order_by, outer_order_by);
3054                            Self::on_unique_lag_lead(lag_lead, args_for_func, return_type_for_func)
3055                        }
3056                        AggregateFunc::FirstValue {
3057                            window_frame,
3058                            order_by,
3059                        } => {
3060                            assert_eq!(order_by, outer_order_by);
3061                            Self::on_unique_first_value_last_value(
3062                                window_frame,
3063                                args_for_func,
3064                                return_type_for_func,
3065                            )
3066                        }
3067                        AggregateFunc::LastValue {
3068                            window_frame,
3069                            order_by,
3070                        } => {
3071                            assert_eq!(order_by, outer_order_by);
3072                            Self::on_unique_first_value_last_value(
3073                                window_frame,
3074                                args_for_func,
3075                                return_type_for_func,
3076                            )
3077                        }
3078                        _ => panic!("unknown function in FusedValueWindowFunc"),
3079                    };
3080                    func_result_exprs.push(result);
3081                    col_names.push(column_name);
3082                }
3083
3084                MirScalarExpr::CallVariadic {
3085                    func: VariadicFunc::ListCreate {
3086                        elem_type: return_type_with_orig_row,
3087                    },
3088                    exprs: vec![MirScalarExpr::CallVariadic {
3089                        func: VariadicFunc::RecordCreate {
3090                            field_names: vec![
3091                                ColumnName::from("?fused_value_window_func?"),
3092                                ColumnName::from("?record?"),
3093                            ],
3094                        },
3095                        exprs: vec![
3096                            MirScalarExpr::CallVariadic {
3097                                func: VariadicFunc::RecordCreate {
3098                                    field_names: col_names,
3099                                },
3100                                exprs: func_result_exprs,
3101                            },
3102                            original_row,
3103                        ],
3104                    }],
3105                }
3106            }
3107
3108            // All other variants should return the argument to the aggregation.
3109            AggregateFunc::MaxNumeric
3110            | AggregateFunc::MaxInt16
3111            | AggregateFunc::MaxInt32
3112            | AggregateFunc::MaxInt64
3113            | AggregateFunc::MaxUInt16
3114            | AggregateFunc::MaxUInt32
3115            | AggregateFunc::MaxUInt64
3116            | AggregateFunc::MaxMzTimestamp
3117            | AggregateFunc::MaxFloat32
3118            | AggregateFunc::MaxFloat64
3119            | AggregateFunc::MaxBool
3120            | AggregateFunc::MaxString
3121            | AggregateFunc::MaxDate
3122            | AggregateFunc::MaxTimestamp
3123            | AggregateFunc::MaxTimestampTz
3124            | AggregateFunc::MaxInterval
3125            | AggregateFunc::MaxTime
3126            | AggregateFunc::MinNumeric
3127            | AggregateFunc::MinInt16
3128            | AggregateFunc::MinInt32
3129            | AggregateFunc::MinInt64
3130            | AggregateFunc::MinUInt16
3131            | AggregateFunc::MinUInt32
3132            | AggregateFunc::MinUInt64
3133            | AggregateFunc::MinMzTimestamp
3134            | AggregateFunc::MinFloat32
3135            | AggregateFunc::MinFloat64
3136            | AggregateFunc::MinBool
3137            | AggregateFunc::MinString
3138            | AggregateFunc::MinDate
3139            | AggregateFunc::MinTimestamp
3140            | AggregateFunc::MinTimestampTz
3141            | AggregateFunc::MinInterval
3142            | AggregateFunc::MinTime
3143            | AggregateFunc::SumFloat32
3144            | AggregateFunc::SumFloat64
3145            | AggregateFunc::SumNumeric
3146            | AggregateFunc::Any
3147            | AggregateFunc::All
3148            | AggregateFunc::Dummy => self.expr.clone(),
3149        }
3150    }
3151
3152    /// `on_unique` for ROW_NUMBER, RANK, DENSE_RANK
3153    fn on_unique_ranking_window_funcs(
3154        &self,
3155        input_type: &[SqlColumnType],
3156        col_name: &str,
3157    ) -> MirScalarExpr {
3158        let list = self
3159            .expr
3160            .clone()
3161            // extract the list within the record
3162            .call_unary(UnaryFunc::RecordGet(scalar_func::RecordGet(0)));
3163
3164        // extract the expression within the list
3165        let record = MirScalarExpr::CallVariadic {
3166            func: VariadicFunc::ListIndex,
3167            exprs: vec![
3168                list,
3169                MirScalarExpr::literal_ok(Datum::Int64(1), SqlScalarType::Int64),
3170            ],
3171        };
3172
3173        MirScalarExpr::CallVariadic {
3174            func: VariadicFunc::ListCreate {
3175                elem_type: self
3176                    .typ(input_type)
3177                    .scalar_type
3178                    .unwrap_list_element_type()
3179                    .clone(),
3180            },
3181            exprs: vec![MirScalarExpr::CallVariadic {
3182                func: VariadicFunc::RecordCreate {
3183                    field_names: vec![ColumnName::from(col_name), ColumnName::from("?record?")],
3184                },
3185                exprs: vec![
3186                    MirScalarExpr::literal_ok(Datum::Int64(1), SqlScalarType::Int64),
3187                    record,
3188                ],
3189            }],
3190        }
3191    }
3192
3193    /// `on_unique` for `lag` and `lead`
3194    fn on_unique_lag_lead(
3195        lag_lead: &LagLeadType,
3196        encoded_args: MirScalarExpr,
3197        return_type: SqlScalarType,
3198    ) -> (MirScalarExpr, ColumnName) {
3199        let expr = encoded_args
3200            .clone()
3201            .call_unary(UnaryFunc::RecordGet(scalar_func::RecordGet(0)));
3202        let offset = encoded_args
3203            .clone()
3204            .call_unary(UnaryFunc::RecordGet(scalar_func::RecordGet(1)));
3205        let default_value =
3206            encoded_args.call_unary(UnaryFunc::RecordGet(scalar_func::RecordGet(2)));
3207
3208        // In this case, the window always has only one element, so if the offset is not null and
3209        // not zero, the default value should be returned instead.
3210        let value = offset
3211            .clone()
3212            .call_binary(
3213                MirScalarExpr::literal_ok(Datum::Int32(0), SqlScalarType::Int32),
3214                crate::func::Eq,
3215            )
3216            .if_then_else(expr, default_value);
3217        let result_expr = offset
3218            .call_unary(UnaryFunc::IsNull(crate::func::IsNull))
3219            .if_then_else(MirScalarExpr::literal_null(return_type), value);
3220
3221        let column_name = ColumnName::from(match lag_lead {
3222            LagLeadType::Lag => "?lag?",
3223            LagLeadType::Lead => "?lead?",
3224        });
3225
3226        (result_expr, column_name)
3227    }
3228
3229    /// `on_unique` for `first_value` and `last_value`
3230    fn on_unique_first_value_last_value(
3231        window_frame: &WindowFrame,
3232        arg: MirScalarExpr,
3233        return_type: SqlScalarType,
3234    ) -> (MirScalarExpr, ColumnName) {
3235        // If the window frame includes the current (single) row, return its value, null otherwise
3236        let result_expr = if window_frame.includes_current_row() {
3237            arg
3238        } else {
3239            MirScalarExpr::literal_null(return_type)
3240        };
3241        (result_expr, ColumnName::from("?first_value?"))
3242    }
3243
3244    /// `on_unique` for window aggregations
3245    fn on_unique_window_agg(
3246        window_frame: &WindowFrame,
3247        arg_expr: MirScalarExpr,
3248        input_type: &[SqlColumnType],
3249        return_type: SqlScalarType,
3250        wrapped_aggr: &AggregateFunc,
3251    ) -> (MirScalarExpr, ColumnName) {
3252        // If the window frame includes the current (single) row, evaluate the wrapped aggregate on
3253        // that row. Otherwise, return the default value for the aggregate.
3254        let result_expr = if window_frame.includes_current_row() {
3255            AggregateExpr {
3256                func: wrapped_aggr.clone(),
3257                expr: arg_expr,
3258                distinct: false, // We have just one input element; DISTINCT doesn't matter.
3259            }
3260            .on_unique(input_type)
3261        } else {
3262            MirScalarExpr::literal_ok(wrapped_aggr.default(), return_type)
3263        };
3264        (result_expr, ColumnName::from("?window_agg?"))
3265    }
3266
3267    /// Returns whether the expression is COUNT(*) or not.  Note that
3268    /// when we define the count builtin in sql::func, we convert
3269    /// COUNT(*) to COUNT(true), making it indistinguishable from
3270    /// literal COUNT(true), but we prefer to consider this as the
3271    /// former.
3272    ///
3273    /// (HIR has the same `is_count_asterisk`.)
3274    pub fn is_count_asterisk(&self) -> bool {
3275        self.func == AggregateFunc::Count && self.expr.is_literal_true() && !self.distinct
3276    }
3277}
3278
/// Describe a join implementation in dataflow.
///
/// Attached to a join by the optimizer; [`JoinImplementation::Unimplemented`]
/// means no implementation has been selected yet.
#[derive(Clone, Debug, Eq, PartialEq, Ord, PartialOrd, Serialize, Deserialize, Hash, MzReflect)]
pub enum JoinImplementation {
    /// Perform a sequence of binary differential dataflow joins.
    ///
    /// The first argument indicates
    /// 1) the index of the starting collection,
    /// 2) if it should be arranged, the keys to arrange it by, and
    /// 3) the characteristics of the starting collection (for EXPLAINing).
    ///
    /// The sequence that follows lists other relation indexes, and the key for
    /// the arrangement we should use when joining it in.
    /// The `JoinInputCharacteristics` are for EXPLAINing the characteristics that
    /// were used for join ordering.
    ///
    /// Each collection index should occur exactly once, either as the starting collection
    /// or somewhere in the list.
    Differential(
        (
            usize,
            Option<Vec<MirScalarExpr>>,
            Option<JoinInputCharacteristics>,
        ),
        Vec<(usize, Vec<MirScalarExpr>, Option<JoinInputCharacteristics>)>,
    ),
    /// Perform independent delta query dataflows for each input.
    ///
    /// The argument is a sequence of plans, for the input collections in order.
    /// Each plan starts from the corresponding index, and then in sequence joins
    /// against collections identified by index and with the specified arrangement key.
    /// The `JoinInputCharacteristics` are for EXPLAINing the characteristics that were
    /// used for join ordering.
    DeltaQuery(Vec<Vec<(usize, Vec<MirScalarExpr>, Option<JoinInputCharacteristics>)>>),
    /// Join a user-created index with a constant collection to speed up the evaluation of a
    /// predicate such as `(f1 = 3 AND f2 = 5) OR (f1 = 7 AND f2 = 9)`.
    /// This gets translated to a Differential join during MIR -> LIR lowering, but we still want
    /// to represent it in MIR, because the fast path detection wants to match on this.
    ///
    /// Consists of (`<coll_id>`, `<index_id>`, `<index_key>`, `<constants>`)
    IndexedFilter(
        GlobalId,
        GlobalId,
        Vec<MirScalarExpr>,
        #[mzreflect(ignore)] Vec<Row>,
    ),
    /// No implementation yet selected.
    Unimplemented,
}
3326
3327impl Default for JoinImplementation {
3328    fn default() -> Self {
3329        JoinImplementation::Unimplemented
3330    }
3331}
3332
3333impl JoinImplementation {
3334    /// Returns `true` iff the value is not [`JoinImplementation::Unimplemented`].
3335    pub fn is_implemented(&self) -> bool {
3336        match self {
3337            Self::Unimplemented => false,
3338            _ => true,
3339        }
3340    }
3341
3342    /// Returns an optional implementation name if the value is not [`JoinImplementation::Unimplemented`].
3343    pub fn name(&self) -> Option<&'static str> {
3344        match self {
3345            Self::Differential(..) => Some("differential"),
3346            Self::DeltaQuery(..) => Some("delta"),
3347            Self::IndexedFilter(..) => Some("indexed_filter"),
3348            Self::Unimplemented => None,
3349        }
3350    }
3351}
3352
/// Characteristics of a join order candidate collection.
///
/// A candidate is described by a collection and a key, and may have various liabilities.
/// Primarily, the candidate may risk substantial inflation of records, which is something
/// that concerns us greatly. Additionally, the candidate may be unarranged, and we would
/// prefer candidates that do not require additional memory. Finally, we prefer lower id
/// collections in the interest of consistent tie-breaking. For more characteristics, see
/// comments on individual fields.
///
/// This has more than one version. `new` instantiates the appropriate version based on a
/// feature flag (`enable_join_prioritize_arranged`): `V2` when the flag is on, `V1`
/// otherwise.
#[derive(Eq, PartialEq, Ord, PartialOrd, Debug, Clone, Serialize, Deserialize, Hash, MzReflect)]
pub enum JoinInputCharacteristics {
    /// Old version, with `enable_join_prioritize_arranged` turned off.
    V1(JoinInputCharacteristicsV1),
    /// Newer version, with `enable_join_prioritize_arranged` turned on.
    V2(JoinInputCharacteristicsV2),
}
3371
3372impl JoinInputCharacteristics {
3373    /// Creates a new instance with the given characteristics.
3374    pub fn new(
3375        unique_key: bool,
3376        key_length: usize,
3377        arranged: bool,
3378        cardinality: Option<usize>,
3379        filters: FilterCharacteristics,
3380        input: usize,
3381        enable_join_prioritize_arranged: bool,
3382    ) -> Self {
3383        if enable_join_prioritize_arranged {
3384            Self::V2(JoinInputCharacteristicsV2::new(
3385                unique_key,
3386                key_length,
3387                arranged,
3388                cardinality,
3389                filters,
3390                input,
3391            ))
3392        } else {
3393            Self::V1(JoinInputCharacteristicsV1::new(
3394                unique_key,
3395                key_length,
3396                arranged,
3397                cardinality,
3398                filters,
3399                input,
3400            ))
3401        }
3402    }
3403
3404    /// Turns the instance into a String to be printed in EXPLAIN.
3405    pub fn explain(&self) -> String {
3406        match self {
3407            Self::V1(jic) => jic.explain(),
3408            Self::V2(jic) => jic.explain(),
3409        }
3410    }
3411
3412    /// Whether the join input described by `self` is arranged.
3413    pub fn arranged(&self) -> bool {
3414        match self {
3415            Self::V1(jic) => jic.arranged,
3416            Self::V2(jic) => jic.arranged,
3417        }
3418    }
3419
3420    /// Returns the `FilterCharacteristics` for the join input described by `self`.
3421    pub fn filters(&mut self) -> &mut FilterCharacteristics {
3422        match self {
3423            Self::V1(jic) => &mut jic.filters,
3424            Self::V2(jic) => &mut jic.filters,
3425        }
3426    }
3427}
3428
/// Newer version of `JoinInputCharacteristics`, with `enable_join_prioritize_arranged` turned on.
///
/// NOTE(review): the derived `Ord`/`PartialOrd` compares fields in declaration
/// order, so the ordering of the fields below encodes each characteristic's
/// priority when candidates are compared — confirm before reordering fields.
#[derive(Eq, PartialEq, Ord, PartialOrd, Debug, Clone, Serialize, Deserialize, Hash, MzReflect)]
pub struct JoinInputCharacteristicsV2 {
    /// An excellent indication that record count will not increase.
    pub unique_key: bool,
    /// Cross joins are bad.
    /// (`key_length > 0` also implies that it is not a cross join. However, we need to note cross
    /// joins in a separate field, because not being a cross join is more important than `arranged`,
    /// but otherwise `key_length` is less important than `arranged`.)
    pub not_cross: bool,
    /// Indicates that there will be no additional in-memory footprint.
    pub arranged: bool,
    /// A weaker signal that record count will not increase.
    pub key_length: usize,
    /// Estimated cardinality (lower is better; `Reverse` makes lower compare as greater).
    pub cardinality: Option<std::cmp::Reverse<usize>>,
    /// Characteristics of the filter that is applied at this input.
    pub filters: FilterCharacteristics,
    /// We want to prefer input earlier in the input list, for stability of ordering.
    pub input: std::cmp::Reverse<usize>,
}
3450
3451impl JoinInputCharacteristicsV2 {
3452    /// Creates a new instance with the given characteristics.
3453    pub fn new(
3454        unique_key: bool,
3455        key_length: usize,
3456        arranged: bool,
3457        cardinality: Option<usize>,
3458        filters: FilterCharacteristics,
3459        input: usize,
3460    ) -> Self {
3461        Self {
3462            unique_key,
3463            not_cross: key_length > 0,
3464            arranged,
3465            key_length,
3466            cardinality: cardinality.map(std::cmp::Reverse),
3467            filters,
3468            input: std::cmp::Reverse(input),
3469        }
3470    }
3471
3472    /// Turns the instance into a String to be printed in EXPLAIN.
3473    pub fn explain(&self) -> String {
3474        let mut e = "".to_owned();
3475        if self.unique_key {
3476            e.push_str("U");
3477        }
3478        // Don't need to print `not_cross`, because that is visible in the printed key.
3479        // if !self.not_cross {
3480        //     e.push_str("C");
3481        // }
3482        for _ in 0..self.key_length {
3483            e.push_str("K");
3484        }
3485        if self.arranged {
3486            e.push_str("A");
3487        }
3488        if let Some(std::cmp::Reverse(cardinality)) = self.cardinality {
3489            e.push_str(&format!("|{cardinality}|"));
3490        }
3491        e.push_str(&self.filters.explain());
3492        e
3493    }
3494}
3495
/// Old version of `JoinInputCharacteristics`, with `enable_join_prioritize_arranged` turned off.
///
/// NOTE(review): the derived `Ord`/`PartialOrd` compares fields in declaration
/// order, so the ordering of the fields below encodes each characteristic's
/// priority when candidates are compared — confirm before reordering fields.
#[derive(Eq, PartialEq, Ord, PartialOrd, Debug, Clone, Serialize, Deserialize, Hash, MzReflect)]
pub struct JoinInputCharacteristicsV1 {
    /// An excellent indication that record count will not increase.
    pub unique_key: bool,
    /// A weaker signal that record count will not increase.
    pub key_length: usize,
    /// Indicates that there will be no additional in-memory footprint.
    pub arranged: bool,
    /// Estimated cardinality (lower is better; `Reverse` makes lower compare as greater).
    pub cardinality: Option<std::cmp::Reverse<usize>>,
    /// Characteristics of the filter that is applied at this input.
    pub filters: FilterCharacteristics,
    /// We want to prefer input earlier in the input list, for stability of ordering.
    pub input: std::cmp::Reverse<usize>,
}
3512
3513impl JoinInputCharacteristicsV1 {
3514    /// Creates a new instance with the given characteristics.
3515    pub fn new(
3516        unique_key: bool,
3517        key_length: usize,
3518        arranged: bool,
3519        cardinality: Option<usize>,
3520        filters: FilterCharacteristics,
3521        input: usize,
3522    ) -> Self {
3523        Self {
3524            unique_key,
3525            key_length,
3526            arranged,
3527            cardinality: cardinality.map(std::cmp::Reverse),
3528            filters,
3529            input: std::cmp::Reverse(input),
3530        }
3531    }
3532
3533    /// Turns the instance into a String to be printed in EXPLAIN.
3534    pub fn explain(&self) -> String {
3535        let mut e = "".to_owned();
3536        if self.unique_key {
3537            e.push_str("U");
3538        }
3539        for _ in 0..self.key_length {
3540            e.push_str("K");
3541        }
3542        if self.arranged {
3543            e.push_str("A");
3544        }
3545        if let Some(std::cmp::Reverse(cardinality)) = self.cardinality {
3546            e.push_str(&format!("|{cardinality}|"));
3547        }
3548        e.push_str(&self.filters.explain());
3549        e
3550    }
3551}
3552
/// Instructions for finishing the result of a query.
///
/// The primary reason for the existence of this structure and attendant code
/// is that SQL's ORDER BY requires sorting rows (as already implied by the
/// keywords), whereas much of the rest of SQL is defined in terms of unordered
/// multisets. But as it turns out, the same idea can be used to optimize
/// trivial peeks.
///
/// The generic parameters are for accommodating prepared statement parameters in
/// `limit` and `offset`: the planner can hold these fields as HirScalarExpr long enough to call
/// `bind_parameters` on them. Once bound, the concrete instantiation is
/// `RowSetFinishing<NonNeg<i64>, usize>` (the defaults).
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
pub struct RowSetFinishing<L = NonNeg<i64>, O = usize> {
    /// Order rows by the given columns.
    pub order_by: Vec<ColumnOrder>,
    /// Include only as many rows (after offset).
    pub limit: Option<L>,
    /// Omit as many rows (applied before the limit).
    pub offset: O,
    /// Include only given columns.
    pub project: Vec<usize>,
}
3575
3576impl<L> RowSetFinishing<L> {
3577    /// Returns a trivial finishing, i.e., that does nothing to the result set.
3578    pub fn trivial(arity: usize) -> RowSetFinishing<L> {
3579        RowSetFinishing {
3580            order_by: Vec::new(),
3581            limit: None,
3582            offset: 0,
3583            project: (0..arity).collect(),
3584        }
3585    }
3586    /// True if the finishing does nothing to any result set.
3587    pub fn is_trivial(&self, arity: usize) -> bool {
3588        self.limit.is_none()
3589            && self.order_by.is_empty()
3590            && self.offset == 0
3591            && self.project.iter().copied().eq(0..arity)
3592    }
3593    /// True if the finishing does not require an ORDER BY.
3594    ///
3595    /// LIMIT and OFFSET without an ORDER BY _are_ streamable: without an
3596    /// explicit ordering we will skip an arbitrary bag of elements and return
3597    /// the first arbitrary elements in the remaining bag. The result semantics
3598    /// are still correct but maybe surprising for some users.
3599    pub fn is_streamable(&self, arity: usize) -> bool {
3600        self.order_by.is_empty() && self.project.iter().copied().eq(0..arity)
3601    }
3602}
3603
3604impl RowSetFinishing<NonNeg<i64>, usize> {
3605    /// The number of rows needed from before the finishing to evaluate the finishing:
3606    /// offset + limit.
3607    ///
3608    /// If it returns None, then we need all the rows.
3609    pub fn num_rows_needed(&self) -> Option<usize> {
3610        self.limit
3611            .as_ref()
3612            .map(|l| usize::cast_from(u64::from(l.clone())) + self.offset)
3613    }
3614}
3615
impl RowSetFinishing {
    /// Applies finishing actions to a [`RowCollection`], and reports the total
    /// time it took to run.
    ///
    /// Returns a [`SortedRowCollectionIter`] that contains all of the response data, as
    /// well as the size of the response in bytes.
    ///
    /// # Errors
    ///
    /// Returns `Err` with a human-readable message when the intermediate sorted
    /// view would exceed `max_result_size`, or when the finished response would
    /// exceed `max_returned_query_size`.
    pub fn finish(
        &self,
        rows: RowCollection,
        max_result_size: u64,
        max_returned_query_size: Option<u64>,
        duration_histogram: &Histogram,
    ) -> Result<(SortedRowCollectionIter, usize), String> {
        let now = Instant::now();
        let result = self.finish_inner(rows, max_result_size, max_returned_query_size);
        let duration = now.elapsed();
        // Record how long finishing took, on both the success and error paths.
        duration_histogram.observe(duration.as_secs_f64());

        result
    }

    /// Implementation for [`RowSetFinishing::finish`].
    fn finish_inner(
        &self,
        rows: RowCollection,
        max_result_size: u64,
        max_returned_query_size: Option<u64>,
    ) -> Result<(SortedRowCollectionIter, usize), String> {
        // How much additional memory is required to make a sorted view.
        // (The sorted view is an index of one `usize` per entry over `rows`.)
        let sorted_view_mem = rows.entries().saturating_mul(std::mem::size_of::<usize>());
        let required_memory = rows.byte_len().saturating_add(sorted_view_mem);

        // Bail if creating the sorted view would require us to use too much memory.
        // This check happens *before* sorting so we never build an oversized view.
        if required_memory > usize::cast_from(max_result_size) {
            let max_bytes = ByteSize::b(max_result_size);
            return Err(format!("result exceeds max size of {max_bytes}",));
        }

        // Sort, then apply OFFSET and the projection; LIMIT is applied below.
        let sorted_view = rows.sorted_view(&self.order_by);
        let mut iter = sorted_view
            .into_row_iter()
            .apply_offset(self.offset)
            .with_projection(self.project.clone());

        if let Some(limit) = self.limit {
            let limit = u64::from(limit);
            let limit = usize::cast_from(limit);
            iter = iter.with_limit(limit);
        };

        // TODO(parkmycar): Re-think how we can calculate the total response size without
        // having to iterate through the entire collection of Rows, while still
        // respecting the LIMIT, OFFSET, and projections.
        //
        // Note: It feels a bit bad always calculating the response size, but we almost
        // always need it to either check the `max_returned_query_size`, or for reporting
        // in the query history.
        let response_size: usize = iter.clone().map(|row| row.data().len()).sum();

        // Bail if we would end up returning more data to the client than they can support.
        if let Some(max) = max_returned_query_size {
            if response_size > usize::cast_from(max) {
                let max_bytes = ByteSize::b(max);
                return Err(format!("result exceeds max size of {max_bytes}"));
            }
        }

        Ok((iter, response_size))
    }
}
3686
/// A [RowSetFinishing] that can be repeatedly applied to batches of updates (in
/// a [RowCollection]) and keeps track of the remaining limit, offset, and cap
/// on query result size.
#[derive(Debug)]
pub struct RowSetFinishingIncremental {
    /// How many more rows to include, counted after the remaining offset has
    /// been consumed. `None` means unlimited.
    pub remaining_limit: Option<usize>,
    /// How many more rows to skip before emitting any.
    pub remaining_offset: usize,
    /// The maximum allowed result size, as requested by the client.
    pub max_returned_query_size: Option<u64>,
    /// Tracks our remaining allowed budget for result size.
    pub remaining_max_returned_query_size: Option<u64>,
    /// Include only given columns.
    pub project: Vec<usize>,
}
3703
3704impl RowSetFinishingIncremental {
3705    /// Turns the given [RowSetFinishing] into a [RowSetFinishingIncremental].
3706    /// Can only be used when [is_streamable](RowSetFinishing::is_streamable) is
3707    /// `true`.
3708    ///
3709    /// # Panics
3710    ///
3711    /// Panics if the result is not streamable, that is it has an ORDER BY.
3712    pub fn new(
3713        offset: usize,
3714        limit: Option<NonNeg<i64>>,
3715        project: Vec<usize>,
3716        max_returned_query_size: Option<u64>,
3717    ) -> Self {
3718        let limit = limit.map(|l| {
3719            let l = u64::from(l);
3720            let l = usize::cast_from(l);
3721            l
3722        });
3723
3724        RowSetFinishingIncremental {
3725            remaining_limit: limit,
3726            remaining_offset: offset,
3727            max_returned_query_size,
3728            remaining_max_returned_query_size: max_returned_query_size,
3729            project,
3730        }
3731    }
3732
3733    /// Applies finishing actions to the given [`RowCollection`], and reports
3734    /// the total time it took to run.
3735    ///
3736    /// Returns a [`SortedRowCollectionIter`] that contains all of the response
3737    /// data.
3738    pub fn finish_incremental(
3739        &mut self,
3740        rows: RowCollection,
3741        max_result_size: u64,
3742        duration_histogram: &Histogram,
3743    ) -> Result<SortedRowCollectionIter, String> {
3744        let now = Instant::now();
3745        let result = self.finish_incremental_inner(rows, max_result_size);
3746        let duration = now.elapsed();
3747        duration_histogram.observe(duration.as_secs_f64());
3748
3749        result
3750    }
3751
3752    fn finish_incremental_inner(
3753        &mut self,
3754        rows: RowCollection,
3755        max_result_size: u64,
3756    ) -> Result<SortedRowCollectionIter, String> {
3757        // How much additional memory is required to make a sorted view.
3758        let sorted_view_mem = rows.entries().saturating_mul(std::mem::size_of::<usize>());
3759        let required_memory = rows.byte_len().saturating_add(sorted_view_mem);
3760
3761        // Bail if creating the sorted view would require us to use too much memory.
3762        if required_memory > usize::cast_from(max_result_size) {
3763            let max_bytes = ByteSize::b(max_result_size);
3764            return Err(format!("total result exceeds max size of {max_bytes}",));
3765        }
3766
3767        let batch_num_rows = rows.count(0, None);
3768
3769        let sorted_view = rows.sorted_view(&[]);
3770        let mut iter = sorted_view
3771            .into_row_iter()
3772            .apply_offset(self.remaining_offset)
3773            .with_projection(self.project.clone());
3774
3775        if let Some(limit) = self.remaining_limit {
3776            iter = iter.with_limit(limit);
3777        };
3778
3779        self.remaining_offset = self.remaining_offset.saturating_sub(batch_num_rows);
3780        if let Some(remaining_limit) = self.remaining_limit.as_mut() {
3781            *remaining_limit -= iter.count();
3782        }
3783
3784        // TODO(parkmycar): Re-think how we can calculate the total response size without
3785        // having to iterate through the entire collection of Rows, while still
3786        // respecting the LIMIT, OFFSET, and projections.
3787        //
3788        // Note: It feels a bit bad always calculating the response size, but we almost
3789        // always need it to either check the `max_returned_query_size`, or for reporting
3790        // in the query history.
3791        let response_size: usize = iter.clone().map(|row| row.data().len()).sum();
3792
3793        // Bail if we would end up returning more data to the client than they can support.
3794        if let Some(max) = self.remaining_max_returned_query_size {
3795            if response_size > usize::cast_from(max) {
3796                let max_bytes = ByteSize::b(self.max_returned_query_size.expect("known to exist"));
3797                return Err(format!("total result exceeds max size of {max_bytes}"));
3798            }
3799        }
3800
3801        Ok(iter)
3802    }
3803}
3804
3805/// Compare `left` and `right` using `order`. If that doesn't produce a strict
3806/// ordering, call `tiebreaker`.
3807pub fn compare_columns<F>(
3808    order: &[ColumnOrder],
3809    left: &[Datum],
3810    right: &[Datum],
3811    tiebreaker: F,
3812) -> Ordering
3813where
3814    F: Fn() -> Ordering,
3815{
3816    for order in order {
3817        let cmp = match (&left[order.column], &right[order.column]) {
3818            (Datum::Null, Datum::Null) => Ordering::Equal,
3819            (Datum::Null, _) => {
3820                if order.nulls_last {
3821                    Ordering::Greater
3822                } else {
3823                    Ordering::Less
3824                }
3825            }
3826            (_, Datum::Null) => {
3827                if order.nulls_last {
3828                    Ordering::Less
3829                } else {
3830                    Ordering::Greater
3831                }
3832            }
3833            (lval, rval) => {
3834                if order.desc {
3835                    rval.cmp(lval)
3836                } else {
3837                    lval.cmp(rval)
3838                }
3839            }
3840        };
3841        if cmp != Ordering::Equal {
3842            return cmp;
3843        }
3844    }
3845    tiebreaker()
3846}
3847
/// Describe a window frame, e.g. `RANGE UNBOUNDED PRECEDING` or
/// `ROWS BETWEEN 5 PRECEDING AND CURRENT ROW`.
///
/// Window frames define a subset of the partition, and only a subset of
/// window functions make use of the window frame.
#[derive(Debug, Clone, Eq, PartialEq, Ord, PartialOrd, Serialize, Deserialize, Hash, MzReflect)]
pub struct WindowFrame {
    /// ROWS, RANGE or GROUPS
    pub units: WindowFrameUnits,
    /// Where the frame starts
    pub start_bound: WindowFrameBound,
    /// Where the frame ends
    pub end_bound: WindowFrameBound,
}
3862
3863impl Display for WindowFrame {
3864    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
3865        write!(
3866            f,
3867            "{} between {} and {}",
3868            self.units, self.start_bound, self.end_bound
3869        )
3870    }
3871}
3872
impl WindowFrame {
    /// Return the default window frame used when one is not explicitly defined
    /// (`RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW`).
    ///
    /// Note: this is an inherent method, not an implementation of the
    /// `Default` trait.
    pub fn default() -> Self {
        WindowFrame {
            units: WindowFrameUnits::Range,
            start_bound: WindowFrameBound::UnboundedPreceding,
            end_bound: WindowFrameBound::CurrentRow,
        }
    }

    /// Whether the frame always contains the row for which the window function
    /// is being evaluated.
    ///
    /// The `unreachable!()` arms encode the assumption that the end bound
    /// never precedes the start bound (e.g., a frame can't start at
    /// `N FOLLOWING` and end at `CURRENT ROW`) -- presumably enforced earlier,
    /// during planning; TODO(review) confirm.
    fn includes_current_row(&self) -> bool {
        use WindowFrameBound::*;
        match self.start_bound {
            UnboundedPreceding => match self.end_bound {
                UnboundedPreceding => false,
                // `0 PRECEDING` ends exactly at the current row.
                OffsetPreceding(0) => true,
                OffsetPreceding(_) => false,
                CurrentRow => true,
                OffsetFollowing(_) => true,
                UnboundedFollowing => true,
            },
            OffsetPreceding(0) => match self.end_bound {
                UnboundedPreceding => unreachable!(),
                OffsetPreceding(0) => true,
                // Any nonzero offsets here will create an empty window
                OffsetPreceding(_) => false,
                CurrentRow => true,
                OffsetFollowing(_) => true,
                UnboundedFollowing => true,
            },
            OffsetPreceding(_) => match self.end_bound {
                UnboundedPreceding => unreachable!(),
                // Window ends at the current row
                OffsetPreceding(0) => true,
                OffsetPreceding(_) => false,
                CurrentRow => true,
                OffsetFollowing(_) => true,
                UnboundedFollowing => true,
            },
            CurrentRow => true,
            OffsetFollowing(0) => match self.end_bound {
                UnboundedPreceding => unreachable!(),
                OffsetPreceding(_) => unreachable!(),
                CurrentRow => unreachable!(),
                // `0 FOLLOWING` starts exactly at the current row.
                OffsetFollowing(_) => true,
                UnboundedFollowing => true,
            },
            OffsetFollowing(_) => match self.end_bound {
                UnboundedPreceding => unreachable!(),
                OffsetPreceding(_) => unreachable!(),
                CurrentRow => unreachable!(),
                // The frame starts strictly after the current row.
                OffsetFollowing(_) => false,
                UnboundedFollowing => false,
            },
            UnboundedFollowing => false,
        }
    }
}
3931
/// Describe how frame bounds are interpreted
#[derive(Debug, Clone, Eq, PartialEq, Ord, PartialOrd, Serialize, Deserialize, Hash, MzReflect)]
pub enum WindowFrameUnits {
    /// Each row is treated as the unit of work for bounds
    Rows,
    /// Each peer group is treated as the unit of work for bounds,
    /// and offset-based bounds use the value of the ORDER BY expression
    Range,
    /// Each peer group is treated as the unit of work for bounds,
    /// with offsets counted in peer groups.
    /// Groups is currently not supported, and it is rejected during planning.
    Groups,
}
3944
3945impl Display for WindowFrameUnits {
3946    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
3947        match self {
3948            WindowFrameUnits::Rows => write!(f, "rows"),
3949            WindowFrameUnits::Range => write!(f, "range"),
3950            WindowFrameUnits::Groups => write!(f, "groups"),
3951        }
3952    }
3953}
3954
/// Specifies [WindowFrame]'s `start_bound` and `end_bound`
///
/// The order between frame bounds is significant, as Postgres enforces
/// some restrictions there (the derived `Ord` follows variant order).
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, Hash, MzReflect, PartialOrd, Ord)]
pub enum WindowFrameBound {
    /// `UNBOUNDED PRECEDING`
    UnboundedPreceding,
    /// `<N> PRECEDING`
    OffsetPreceding(u64),
    /// `CURRENT ROW`
    CurrentRow,
    /// `<N> FOLLOWING`
    OffsetFollowing(u64),
    /// `UNBOUNDED FOLLOWING`.
    UnboundedFollowing,
}
3972
3973impl Display for WindowFrameBound {
3974    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
3975        match self {
3976            WindowFrameBound::UnboundedPreceding => write!(f, "unbounded preceding"),
3977            WindowFrameBound::OffsetPreceding(offset) => write!(f, "{} preceding", offset),
3978            WindowFrameBound::CurrentRow => write!(f, "current row"),
3979            WindowFrameBound::OffsetFollowing(offset) => write!(f, "{} following", offset),
3980            WindowFrameBound::UnboundedFollowing => write!(f, "unbounded following"),
3981        }
3982    }
3983}
3984
/// Maximum iterations for a LetRec.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash, Serialize, Deserialize)]
pub struct LetRecLimit {
    /// Maximum number of iterations to evaluate.
    pub max_iters: NonZeroU64,
    /// Whether to return (rather than raise an error) when reaching the above
    /// limit: if true, we simply use the current contents of each Id as the
    /// final result.
    pub return_at_limit: bool,
}
3994
3995impl LetRecLimit {
3996    /// Compute the smallest limit from a Vec of `LetRecLimit`s.
3997    pub fn min_max_iter(limits: &Vec<Option<LetRecLimit>>) -> Option<u64> {
3998        limits
3999            .iter()
4000            .filter_map(|l| l.as_ref().map(|l| l.max_iters.get()))
4001            .min()
4002    }
4003
4004    /// The default value of `LetRecLimit::return_at_limit` when using the RECURSION LIMIT option of
4005    /// WMR without ERROR AT or RETURN AT.
4006    pub const RETURN_AT_LIMIT_DEFAULT: bool = false;
4007}
4008
4009impl Display for LetRecLimit {
4010    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
4011        write!(f, "[recursion_limit={}", self.max_iters)?;
4012        if self.return_at_limit != LetRecLimit::RETURN_AT_LIMIT_DEFAULT {
4013            write!(f, ", return_at_limit")?;
4014        }
4015        write!(f, "]")
4016    }
4017}
4018
/// For a global Get, this indicates whether we are going to read from Persist or from an index.
/// (See comment in MirRelationExpr::Get.)
#[derive(Clone, Debug, Eq, PartialEq, Ord, PartialOrd, Serialize, Deserialize, Hash)]
pub enum AccessStrategy {
    /// It's either a local Get (a CTE), or unknown at the time.
    /// `prune_and_annotate_dataflow_index_imports` decides it for global Gets, and thus switches to
    /// one of the other variants.
    UnknownOrLocal,
    /// The Get will read from Persist.
    Persist,
    /// The Get will read from an index or indexes: (index id, how the index will be used).
    Index(Vec<(GlobalId, IndexUsageType)>),
    /// The Get will read a collection that is computed by the same dataflow, but in a different
    /// `BuildDesc` in `objects_to_build`.
    SameDataflow,
}
4035
#[cfg(test)]
mod tests {
    use mz_repr::explain::text::text_string_at;

    use crate::explain::HumanizedExplain;

    use super::*;

    // Checks the one-line `Finish ...` rendering of a `RowSetFinishing` in
    // EXPLAIN output, including order_by/limit and the `#3..=#5` range
    // compaction of consecutive projected columns.
    #[mz_ore::test]
    fn test_row_set_finishing_as_text() {
        let finishing = RowSetFinishing {
            order_by: vec![ColumnOrder {
                column: 4,
                desc: true,
                nulls_last: true,
            }],
            limit: Some(NonNeg::try_from(7).unwrap()),
            offset: Default::default(),
            project: vec![1, 3, 4, 5],
        };

        let mode = HumanizedExplain::new(false);
        let expr = mode.expr(&finishing, None);

        let act = text_string_at(&expr, mz_ore::str::Indent::default);

        let exp = {
            // NOTE(review): `FormatBuffer` is what makes these `write!` calls
            // usable without handling a `Result` -- confirm before restyling.
            use mz_ore::fmt::FormatBuffer;
            let mut s = String::new();
            write!(&mut s, "Finish");
            write!(&mut s, " order_by=[#4 desc nulls_last]");
            write!(&mut s, " limit=7");
            write!(&mut s, " output=[#1, #3..=#5]");
            writeln!(&mut s, "");
            s
        };

        assert_eq!(act, exp);
    }
}
4076
/// An iterator over AST structures, which calls out nodes that differ.
///
/// The iterators visit two ASTs in tandem, continuing as long as the AST node data matches,
/// and yielding an output pair as soon as the AST nodes do not match. Their intent is to call
/// attention to the moments in the ASTs where they differ, and incidentally a stack-free way
/// to compare two ASTs.
mod structured_diff {

    use super::MirRelationExpr;
    use itertools::Itertools;

    /// An iterator over structured differences between two `MirRelationExpr` instances.
    pub struct MreDiff<'a> {
        /// Pairs of expressions that must still be compared.
        todo: Vec<(&'a MirRelationExpr, &'a MirRelationExpr)>,
    }

    impl<'a> MreDiff<'a> {
        /// Create a new `MirRelationExpr` structured difference.
        pub fn new(expr1: &'a MirRelationExpr, expr2: &'a MirRelationExpr) -> Self {
            MreDiff {
                todo: vec![(expr1, expr2)],
            }
        }
    }

    impl<'a> Iterator for MreDiff<'a> {
        // Pairs of expressions that do not match.
        type Item = (&'a MirRelationExpr, &'a MirRelationExpr);

        // Pops pairs off the worklist; for each pair, either the non-child
        // fields differ (yield the pair) or they match (push the children for
        // later comparison). The explicit stack keeps the traversal
        // stack-safe for deep expressions.
        fn next(&mut self) -> Option<Self::Item> {
            while let Some((expr1, expr2)) = self.todo.pop() {
                match (expr1, expr2) {
                    (
                        MirRelationExpr::Constant {
                            rows: rows1,
                            typ: typ1,
                        },
                        MirRelationExpr::Constant {
                            rows: rows2,
                            typ: typ2,
                        },
                    ) => {
                        if rows1 != rows2 || typ1 != typ2 {
                            return Some((expr1, expr2));
                        }
                    }
                    (
                        MirRelationExpr::Get {
                            id: id1,
                            typ: typ1,
                            access_strategy: as1,
                        },
                        MirRelationExpr::Get {
                            id: id2,
                            typ: typ2,
                            access_strategy: as2,
                        },
                    ) => {
                        if id1 != id2 || typ1 != typ2 || as1 != as2 {
                            return Some((expr1, expr2));
                        }
                    }
                    (
                        MirRelationExpr::Let {
                            id: id1,
                            body: body1,
                            value: value1,
                        },
                        MirRelationExpr::Let {
                            id: id2,
                            body: body2,
                            value: value2,
                        },
                    ) => {
                        if id1 != id2 {
                            return Some((expr1, expr2));
                        } else {
                            self.todo.push((body1, body2));
                            self.todo.push((value1, value2));
                        }
                    }
                    (
                        MirRelationExpr::LetRec {
                            ids: ids1,
                            body: body1,
                            values: values1,
                            limits: limits1,
                        },
                        MirRelationExpr::LetRec {
                            ids: ids2,
                            body: body2,
                            values: values2,
                            limits: limits2,
                        },
                    ) => {
                        // The length check makes the `zip_eq` below safe.
                        if ids1 != ids2 || values1.len() != values2.len() || limits1 != limits2 {
                            return Some((expr1, expr2));
                        } else {
                            self.todo.push((body1, body2));
                            self.todo.extend(values1.iter().zip_eq(values2.iter()));
                        }
                    }
                    (
                        MirRelationExpr::Project {
                            outputs: outputs1,
                            input: input1,
                        },
                        MirRelationExpr::Project {
                            outputs: outputs2,
                            input: input2,
                        },
                    ) => {
                        if outputs1 != outputs2 {
                            return Some((expr1, expr2));
                        } else {
                            self.todo.push((input1, input2));
                        }
                    }
                    (
                        MirRelationExpr::Map {
                            scalars: scalars1,
                            input: input1,
                        },
                        MirRelationExpr::Map {
                            scalars: scalars2,
                            input: input2,
                        },
                    ) => {
                        if scalars1 != scalars2 {
                            return Some((expr1, expr2));
                        } else {
                            self.todo.push((input1, input2));
                        }
                    }
                    (
                        MirRelationExpr::Filter {
                            predicates: predicates1,
                            input: input1,
                        },
                        MirRelationExpr::Filter {
                            predicates: predicates2,
                            input: input2,
                        },
                    ) => {
                        if predicates1 != predicates2 {
                            return Some((expr1, expr2));
                        } else {
                            self.todo.push((input1, input2));
                        }
                    }
                    (
                        MirRelationExpr::FlatMap {
                            input: input1,
                            func: func1,
                            exprs: exprs1,
                        },
                        MirRelationExpr::FlatMap {
                            input: input2,
                            func: func2,
                            exprs: exprs2,
                        },
                    ) => {
                        if func1 != func2 || exprs1 != exprs2 {
                            return Some((expr1, expr2));
                        } else {
                            self.todo.push((input1, input2));
                        }
                    }
                    (
                        MirRelationExpr::Join {
                            inputs: inputs1,
                            equivalences: eq1,
                            implementation: impl1,
                        },
                        MirRelationExpr::Join {
                            inputs: inputs2,
                            equivalences: eq2,
                            implementation: impl2,
                        },
                    ) => {
                        // The length check makes the `zip_eq` below safe.
                        if inputs1.len() != inputs2.len() || eq1 != eq2 || impl1 != impl2 {
                            return Some((expr1, expr2));
                        } else {
                            self.todo.extend(inputs1.iter().zip_eq(inputs2.iter()));
                        }
                    }
                    (
                        MirRelationExpr::Reduce {
                            aggregates: aggregates1,
                            input: inputs1,
                            group_key: gk1,
                            monotonic: m1,
                            expected_group_size: egs1,
                        },
                        MirRelationExpr::Reduce {
                            aggregates: aggregates2,
                            input: inputs2,
                            group_key: gk2,
                            monotonic: m2,
                            expected_group_size: egs2,
                        },
                    ) => {
                        if aggregates1 != aggregates2 || gk1 != gk2 || m1 != m2 || egs1 != egs2 {
                            return Some((expr1, expr2));
                        } else {
                            self.todo.push((inputs1, inputs2));
                        }
                    }
                    (
                        MirRelationExpr::TopK {
                            group_key: gk1,
                            order_key: order1,
                            input: input1,
                            limit: l1,
                            offset: o1,
                            monotonic: m1,
                            expected_group_size: egs1,
                        },
                        MirRelationExpr::TopK {
                            group_key: gk2,
                            order_key: order2,
                            input: input2,
                            limit: l2,
                            offset: o2,
                            monotonic: m2,
                            expected_group_size: egs2,
                        },
                    ) => {
                        if order1 != order2
                            || gk1 != gk2
                            || l1 != l2
                            || o1 != o2
                            || m1 != m2
                            || egs1 != egs2
                        {
                            return Some((expr1, expr2));
                        } else {
                            self.todo.push((input1, input2));
                        }
                    }
                    (
                        MirRelationExpr::Negate { input: input1 },
                        MirRelationExpr::Negate { input: input2 },
                    ) => {
                        self.todo.push((input1, input2));
                    }
                    (
                        MirRelationExpr::Threshold { input: input1 },
                        MirRelationExpr::Threshold { input: input2 },
                    ) => {
                        self.todo.push((input1, input2));
                    }
                    (
                        MirRelationExpr::Union {
                            base: base1,
                            inputs: inputs1,
                        },
                        MirRelationExpr::Union {
                            base: base2,
                            inputs: inputs2,
                        },
                    ) => {
                        // The length check makes the `zip_eq` below safe.
                        if inputs1.len() != inputs2.len() {
                            return Some((expr1, expr2));
                        } else {
                            self.todo.push((base1, base2));
                            self.todo.extend(inputs1.iter().zip_eq(inputs2.iter()));
                        }
                    }
                    (
                        MirRelationExpr::ArrangeBy {
                            keys: keys1,
                            input: input1,
                        },
                        MirRelationExpr::ArrangeBy {
                            keys: keys2,
                            input: input2,
                        },
                    ) => {
                        if keys1 != keys2 {
                            return Some((expr1, expr2));
                        } else {
                            self.todo.push((input1, input2));
                        }
                    }
                    // Different variants: always a difference.
                    _ => {
                        return Some((expr1, expr2));
                    }
                }
            }
            None
        }
    }
}