Skip to main content

mz_expr/
relation.rs

1// Copyright Materialize, Inc. and contributors. All rights reserved.
2//
3// Use of this software is governed by the Business Source License
4// included in the LICENSE file.
5//
6// As of the Change Date specified in that file, in accordance with
7// the Business Source License, use of this software will be governed
8// by the Apache License, Version 2.0.
9
10#![warn(missing_docs)]
11
12use std::cell::RefCell;
13use std::cmp::{Ordering, max};
14use std::collections::{BTreeMap, BTreeSet};
15use std::fmt;
16use std::fmt::{Display, Formatter};
17use std::hash::{DefaultHasher, Hash, Hasher};
18use std::num::NonZeroU64;
19use std::time::Instant;
20
21use bytesize::ByteSize;
22use columnation::{Columnation, CopyRegion};
23use itertools::Itertools;
24use mz_lowertest::MzReflect;
25use mz_ore::cast::{CastFrom, CastInto};
26use mz_ore::collections::CollectionExt;
27use mz_ore::id_gen::IdGen;
28use mz_ore::metrics::Histogram;
29use mz_ore::num::NonNeg;
30use mz_ore::soft_assert_no_log;
31use mz_ore::stack::RecursionLimitError;
32use mz_ore::str::Indent;
33use mz_repr::adt::numeric::NumericMaxScale;
34use mz_repr::explain::text::text_string_at;
35use mz_repr::explain::{
36    DummyHumanizer, ExplainConfig, ExprHumanizer, IndexUsageType, PlanRenderingContext,
37};
38use mz_repr::{
39    ColumnName, Datum, DatumVec, Diff, GlobalId, IntoRowIterator, ReprColumnType, ReprRelationType,
40    ReprScalarType, Row, RowIterator, RowRef, SqlColumnType, SqlRelationType, SqlScalarType,
41};
42use serde::{Deserialize, Serialize};
43
44use crate::Id::Local;
45use crate::explain::{HumanizedExpr, HumanizerMode};
46use crate::relation::func::{AggregateFunc, LagLeadType, TableFunc};
47use crate::row::{RowCollection, RowCollectionIter};
48use crate::scalar::func::variadic::{
49    JsonbBuildArray, JsonbBuildObject, ListCreate, ListIndex, MapBuild, RecordCreate,
50};
51use crate::visit::{Visit, VisitChildren};
52use crate::{
53    EvalError, FilterCharacteristics, Id, LocalId, MirScalarExpr, UnaryFunc, func as scalar_func,
54};
55
56pub mod canonicalize;
57pub mod func;
58pub mod join_input_mapper;
59
/// A recursion limit to be used for stack-safe traversals of [`MirRelationExpr`] trees.
///
/// The recursion limit must be large enough to accommodate for the linear representation
/// of some pathological but frequently occurring query fragments.
///
/// For example, in MIR we could have long chains of
/// - (1) `Let` bindings,
/// - (2) `CallBinary` calls with associative functions such as `+`
///
/// Until we fix those, we need to stick with the larger recursion limit.
pub const RECURSION_LIMIT: usize = 2048;
71
72/// A trait for types that describe how to build a collection.
73pub trait CollectionPlan {
74    /// Collects the set of global identifiers from dataflows referenced in Get.
75    fn depends_on_into(&self, out: &mut BTreeSet<GlobalId>);
76
77    /// Returns the set of global identifiers from dataflows referenced in Get.
78    ///
79    /// See [`CollectionPlan::depends_on_into`] to reuse an existing `BTreeSet`.
80    fn depends_on(&self) -> BTreeSet<GlobalId> {
81        let mut out = BTreeSet::new();
82        self.depends_on_into(&mut out);
83        out
84    }
85}
86
/// An abstract syntax tree which defines a collection.
///
/// The AST is meant to reflect the capabilities of the `differential_dataflow::Collection` type,
/// written generically enough to avoid run-time compilation work.
///
/// `derived_hash_with_manual_eq` was complaining for the wrong reason: This lint exists because
/// it's bad when `Eq` doesn't agree with `Hash`, which is often quite likely if one of them is
/// implemented manually. However, our manual implementation of `Eq` _will_ agree with the derived
/// one. This is because the reason for the manual implementation is not to change the semantics
/// from the derived one, but to avoid stack overflows.
#[allow(clippy::derived_hash_with_manual_eq)]
#[derive(Clone, Debug, Ord, PartialOrd, Serialize, Deserialize, MzReflect, Hash)]
pub enum MirRelationExpr {
    /// A constant relation containing specified rows.
    ///
    /// The runtime memory footprint of this operator is zero.
    ///
    /// When you would like to pattern match on this, consider using `MirRelationExpr::as_const`
    /// instead, which looks behind `ArrangeBy`s. You might want this matching behavior because
    /// constant folding doesn't remove `ArrangeBy`s.
    Constant {
        /// Rows of the constant collection and their multiplicities.
        rows: Result<Vec<(Row, Diff)>, EvalError>,
        /// Schema of the collection.
        typ: ReprRelationType,
    },
    /// Get an existing dataflow.
    ///
    /// The runtime memory footprint of this operator is zero.
    Get {
        /// The identifier for the collection to load.
        #[mzreflect(ignore)]
        id: Id,
        /// Schema of the collection.
        typ: ReprRelationType,
        /// If this is a global Get, this will indicate whether we are going to read from Persist or
        /// from an index, or from a different object in `objects_to_build`. If it's an index, then
        /// how downstream dataflow operations will use this index is also recorded. This is filled
        /// by `prune_and_annotate_dataflow_index_imports`. Note that this is not used by the
        /// lowering to LIR, but is used only by EXPLAIN.
        #[mzreflect(ignore)]
        access_strategy: AccessStrategy,
    },
    /// Introduce a temporary dataflow.
    ///
    /// The runtime memory footprint of this operator is zero.
    Let {
        /// The identifier to be used in `Get` variants to retrieve `value`.
        #[mzreflect(ignore)]
        id: LocalId,
        /// The collection to be bound to `id`.
        value: Box<MirRelationExpr>,
        /// The result of the `Let`, evaluated with `id` bound to `value`.
        body: Box<MirRelationExpr>,
    },
    /// Introduce mutually recursive bindings.
    ///
    /// Each `LocalId` is immediately bound to an initially empty collection
    /// with the type of its corresponding `MirRelationExpr`. Repeatedly, each
    /// binding is evaluated using the current contents of each other binding,
    /// and is refreshed to contain the new evaluation. This process continues
    /// through all bindings, and repeats as long as changes continue to occur.
    ///
    /// The resulting value of the expression is `body` evaluated once in the
    /// context of the final iterates.
    ///
    /// A zero-binding instance can be replaced by `body`.
    /// A single-binding instance is equivalent to `MirRelationExpr::Let`.
    ///
    /// The runtime memory footprint of this operator is zero.
    LetRec {
        /// The identifiers to be used in `Get` variants to retrieve each `value`.
        #[mzreflect(ignore)]
        ids: Vec<LocalId>,
        /// The collections to be bound to each `id`.
        values: Vec<MirRelationExpr>,
        /// Maximum number of iterations, after which we should artificially force a fixpoint.
        /// (Whether we error or just stop is configured by `LetRecLimit::return_at_limit`.)
        /// The per-`LetRec` limit that the user specified is initially copied to each binding to
        /// accommodate slicing and merging of `LetRec`s in MIR transforms (e.g., `NormalizeLets`).
        #[mzreflect(ignore)]
        limits: Vec<Option<LetRecLimit>>,
        /// The result of the `LetRec`, evaluated with each `id` bound to its `value`.
        body: Box<MirRelationExpr>,
    },
    /// Project out some columns from a dataflow
    ///
    /// The runtime memory footprint of this operator is zero.
    Project {
        /// The source collection.
        input: Box<MirRelationExpr>,
        /// Indices of columns to retain.
        outputs: Vec<usize>,
    },
    /// Append new columns to a dataflow
    ///
    /// The runtime memory footprint of this operator is zero.
    Map {
        /// The source collection.
        input: Box<MirRelationExpr>,
        /// Expressions which determine values to append to each row.
        /// An expression may refer to columns in `input` or
        /// expressions defined earlier in the vector
        scalars: Vec<MirScalarExpr>,
    },
    /// Like Map, but yields zero-or-more output rows per input row
    ///
    /// The runtime memory footprint of this operator is zero.
    FlatMap {
        /// The source collection
        input: Box<MirRelationExpr>,
        /// The table func to apply
        func: TableFunc,
        /// The arguments to the table func
        exprs: Vec<MirScalarExpr>,
    },
    /// Keep rows from a dataflow where all the predicates are true
    ///
    /// The runtime memory footprint of this operator is zero.
    Filter {
        /// The source collection.
        input: Box<MirRelationExpr>,
        /// Predicates, each of which must be true.
        predicates: Vec<MirScalarExpr>,
    },
    /// Join several collections, where some columns must be equal.
    ///
    /// For further details consult the documentation for [`MirRelationExpr::join`].
    ///
    /// The runtime memory footprint of this operator can be proportional to
    /// the sizes of all inputs and the size of all joins of prefixes.
    /// This may be reduced due to arrangements available at rendering time.
    Join {
        /// A sequence of input relations.
        inputs: Vec<MirRelationExpr>,
        /// A sequence of equivalence classes of expressions on the cross product of inputs.
        ///
        /// Each equivalence class is a list of scalar expressions, where for each class the
        /// intended interpretation is that all evaluated expressions should be equal.
        ///
        /// Each scalar expression is to be evaluated over the cross-product of all records
        /// from all inputs. In many cases this may just be column selection from specific
        /// inputs, but more general cases exist (e.g. complex functions of multiple columns
        /// from multiple inputs, or just constant literals).
        equivalences: Vec<Vec<MirScalarExpr>>,
        /// Join implementation information.
        #[serde(default)]
        implementation: JoinImplementation,
    },
    /// Group a dataflow by some columns and aggregate over each group
    ///
    /// The runtime memory footprint of this operator is at most proportional to the
    /// number of distinct records in the input and output. The actual requirements
    /// can be less: the number of distinct inputs to each aggregate, summed across
    /// each aggregate, plus the output size. For more details consult the code that
    /// builds the associated dataflow.
    Reduce {
        /// The source collection.
        input: Box<MirRelationExpr>,
        /// Column indices used to form groups.
        group_key: Vec<MirScalarExpr>,
        /// Expressions which determine values to append to each row, after the group keys.
        aggregates: Vec<AggregateExpr>,
        /// True iff the input is known to monotonically increase (only addition of records).
        #[serde(default)]
        monotonic: bool,
        /// User hint: expected number of values per group key. Used to optimize physical rendering.
        #[serde(default)]
        expected_group_size: Option<u64>,
    },
    /// Groups and orders within each group, limiting output.
    ///
    /// The runtime memory footprint of this operator is proportional to its input and output.
    TopK {
        /// The source collection.
        input: Box<MirRelationExpr>,
        /// Column indices used to form groups.
        group_key: Vec<usize>,
        /// Column indices used to order rows within groups.
        order_key: Vec<ColumnOrder>,
        /// Number of records to retain
        #[serde(default)]
        limit: Option<MirScalarExpr>,
        /// Number of records to skip
        #[serde(default)]
        offset: usize,
        /// True iff the input is known to monotonically increase (only addition of records).
        #[serde(default)]
        monotonic: bool,
        /// User-supplied hint: how many rows will have the same group key.
        #[serde(default)]
        expected_group_size: Option<u64>,
    },
    /// Return a dataflow where the row counts are negated
    ///
    /// The runtime memory footprint of this operator is zero.
    Negate {
        /// The source collection.
        input: Box<MirRelationExpr>,
    },
    /// Keep rows from a dataflow where the row counts are positive
    ///
    /// The runtime memory footprint of this operator is proportional to its input and output.
    Threshold {
        /// The source collection.
        input: Box<MirRelationExpr>,
    },
    /// Adds the frequencies of elements in contained sets.
    ///
    /// The runtime memory footprint of this operator is zero.
    Union {
        /// A source collection.
        base: Box<MirRelationExpr>,
        /// Source collections to union.
        inputs: Vec<MirRelationExpr>,
    },
    /// Technically a no-op. Used to render an index. Will be used to optimize queries
    /// on finer grain. Each `keys` item represents a different index that should be
    /// produced from the `keys`.
    ///
    /// The runtime memory footprint of this operator is proportional to its input.
    ArrangeBy {
        /// The source collection
        input: Box<MirRelationExpr>,
        /// Columns to arrange `input` by, in order of decreasing primacy
        keys: Vec<Vec<MirScalarExpr>>,
    },
}
315
316impl PartialEq for MirRelationExpr {
317    fn eq(&self, other: &Self) -> bool {
318        // Capture the result and test it wrt `Ord` implementation in test environments.
319        let result = structured_diff::MreDiff::new(self, other).next().is_none();
320        mz_ore::soft_assert_eq_no_log!(result, self.cmp(other) == Ordering::Equal);
321        result
322    }
323}
324impl Eq for MirRelationExpr {}
325
326impl MirRelationExpr {
327    /// Reports the schema of the relation.
328    ///
329    /// This is the SQL-type parallel of [`Self::typ`]; it is merely
330    /// a wrapper around it, returning a [`SqlRelationType`] instead of
331    /// a [`ReprRelationType`].
332    pub fn sql_typ(&self) -> SqlRelationType {
333        let repr_typ = self.typ();
334        SqlRelationType::from_repr(&repr_typ)
335    }
336
    /// Reports the repr schema of the relation.
    ///
    /// This method determines the type through recursive traversal of the
    /// relation expression, drawing from the types of base collections.
    /// As such, this is not an especially cheap method, and should be used
    /// judiciously.
    ///
    /// The relation type is computed incrementally with a recursive post-order
    /// traversal, that accumulates the input types for the relations yet to be
    /// visited in `type_stack`.
    pub fn typ(&self) -> ReprRelationType {
        let mut type_stack = Vec::new();
        #[allow(deprecated)]
        self.visit_pre_post_nolimit(
            // Pre-visitor: for `Let`/`LetRec`, restrict the traversal to the body only;
            // the bound values' types are never inspected, and dummies are pushed for
            // them in the post-visitor below. All other operators visit all children.
            &mut |e: &MirRelationExpr| -> Option<Vec<&MirRelationExpr>> {
                match &e {
                    MirRelationExpr::Let { body, .. } => Some(vec![&*body]),
                    MirRelationExpr::LetRec { body, .. } => Some(vec![&*body]),
                    _ => None,
                }
            },
            // Post-visitor: pop this node's input types, compute its own type, and
            // push the result for the parent to consume.
            &mut |e: &MirRelationExpr| {
                match e {
                    MirRelationExpr::Let { .. } => {
                        let body_typ = type_stack.pop().unwrap();
                        // Insert a dummy relation type for the value, since `typ_with_input_types`
                        // won't look at it, but expects the relation type of the body to be second.
                        type_stack.push(ReprRelationType::empty());
                        type_stack.push(body_typ);
                    }
                    MirRelationExpr::LetRec { values, .. } => {
                        let body_typ = type_stack.pop().unwrap();
                        type_stack.extend(
                            std::iter::repeat(ReprRelationType::empty()).take(values.len()),
                        );
                        // Insert dummy relation types for the values, since `typ_with_input_types`
                        // won't look at them, but expects the relation type of the body to be last.
                        type_stack.push(body_typ);
                    }
                    _ => {}
                }
                let num_inputs = e.num_inputs();
                let relation_type =
                    e.typ_with_input_types(&type_stack[type_stack.len() - num_inputs..]);
                type_stack.truncate(type_stack.len() - num_inputs);
                type_stack.push(relation_type);
            },
        );
        // Exactly the root's type must remain on the stack.
        assert_eq!(type_stack.len(), 1);
        type_stack.pop().unwrap()
    }
388
389    /// Reports the repr schema of the relation given the repr schema of the input relations.
390    pub fn typ_with_input_types(&self, input_types: &[ReprRelationType]) -> ReprRelationType {
391        let column_types = self.col_with_input_cols(input_types.iter().map(|i| &i.column_types));
392        let unique_keys = self.keys_with_input_keys(
393            input_types.iter().map(|i| i.arity()),
394            input_types.iter().map(|i| &i.keys),
395        );
396        ReprRelationType::new(column_types).with_keys(unique_keys)
397    }
398
399    /// Reports the column types of the relation given the column types of the
400    /// input relations.
401    ///
402    /// This method delegates to `try_col_with_input_cols`, panicking if an `Err`
403    /// variant is returned.
404    pub fn col_with_input_cols<'a, I>(&self, input_types: I) -> Vec<ReprColumnType>
405    where
406        I: Iterator<Item = &'a Vec<ReprColumnType>>,
407    {
408        match self.try_col_with_input_cols(input_types) {
409            Ok(col_types) => col_types,
410            Err(err) => panic!("{err}"),
411        }
412    }
413
    /// Reports the column types of the relation given the column types of the input relations.
    ///
    /// `input_types` is required to contain the column types for the input relations of
    /// the current relation in the same order as they are visited by `try_visit_children`
    /// method, even though not all may be used for computing the schema of the
    /// current relation. For example, `Let` expects two input types, one for the
    /// value relation and one for the body, in that order, but only the one for the
    /// body is used to determine the type of the `Let` relation.
    ///
    /// It is meant to be used during post-order traversals to compute column types
    /// incrementally.
    ///
    /// Returns an `Err` with a rendered message when input column types cannot be
    /// reconciled (currently only possible in the `Union` case).
    pub fn try_col_with_input_cols<'a, I>(
        &self,
        mut input_types: I,
    ) -> Result<Vec<ReprColumnType>, String>
    where
        I: Iterator<Item = &'a Vec<ReprColumnType>>,
    {
        use MirRelationExpr::*;

        let col_types = match self {
            Constant { rows, typ } => {
                // Tighten nullability: a column remains nullable only if some
                // row actually contains a null in that position.
                let mut col_types = typ.column_types.clone();
                let mut seen_null = vec![false; typ.arity()];
                if let Ok(rows) = rows {
                    for (row, _diff) in rows {
                        for (datum, i) in row.iter().zip_eq(0..typ.arity()) {
                            if datum.is_null() {
                                seen_null[i] = true;
                            }
                        }
                    }
                }
                for (&seen_null, i) in seen_null.iter().zip_eq(0..typ.arity()) {
                    if !seen_null {
                        col_types[i].nullable = false;
                    } else {
                        // A null was observed, so the declared type must have
                        // already permitted nulls.
                        assert!(col_types[i].nullable);
                    }
                }
                col_types
            }
            Get { typ, .. } => typ.column_types.clone(),
            Project { outputs, .. } => {
                let input = input_types.next().unwrap();
                outputs.iter().map(|&i| input[i].clone()).collect()
            }
            Map { scalars, .. } => {
                // Each scalar may reference columns appended by earlier scalars,
                // so types are computed and appended one at a time.
                let mut result = input_types.next().unwrap().clone();
                for scalar in scalars.iter() {
                    result.push(scalar.typ(&result))
                }
                result
            }
            FlatMap { func, .. } => {
                // Input columns followed by the table function's output columns.
                let mut result = input_types.next().unwrap().clone();
                result.extend(
                    func.output_sql_type()
                        .column_types
                        .iter()
                        .map(ReprColumnType::from),
                );
                result
            }
            Filter { predicates, .. } => {
                let mut result = input_types.next().unwrap().clone();

                // Set as nonnull any columns where null values would cause
                // any predicate to evaluate to null.
                for column in non_nullable_columns(predicates) {
                    result[column].nullable = false;
                }
                result
            }
            Join { equivalences, .. } => {
                // Concatenate input column types
                let mut types = input_types.flat_map(|cols| cols.to_owned()).collect_vec();
                // In an equivalence class, if any column is non-null, then make all non-null
                for equivalence in equivalences {
                    let col_inds = equivalence
                        .iter()
                        .filter_map(|expr| match expr {
                            MirScalarExpr::Column(col, _name) => Some(*col),
                            _ => None,
                        })
                        .collect_vec();
                    if col_inds.iter().any(|i| !types.get(*i).unwrap().nullable) {
                        for i in col_inds {
                            types.get_mut(i).unwrap().nullable = false;
                        }
                    }
                }
                types
            }
            Reduce {
                group_key,
                aggregates,
                ..
            } => {
                // Output columns are the group key expressions followed by the
                // aggregate values, in that order.
                let input = input_types.next().unwrap();
                group_key
                    .iter()
                    .map(|e| e.typ(input))
                    .chain(aggregates.iter().map(|agg| agg.typ(input)))
                    .collect()
            }
            TopK { .. } | Negate { .. } | Threshold { .. } | ArrangeBy { .. } => {
                // These operators pass their input's columns through unchanged.
                input_types.next().unwrap().clone()
            }
            Let { .. } => {
                // skip over the input types for `value`.
                input_types.nth(1).unwrap().clone()
            }
            LetRec { values, .. } => {
                // skip over the input types for `values`.
                input_types.nth(values.len()).unwrap().clone()
            }
            Union { .. } => {
                // Column-wise union of the input types; fails if a pair of
                // column types cannot be reconciled.
                let mut result = input_types.next().unwrap().clone();
                for input_col_types in input_types {
                    for (base_col, col) in result.iter_mut().zip_eq(input_col_types) {
                        *base_col = base_col
                            .union(col)
                            .map_err(|e| format!("{}\nin plan:\n{}", e, self.pretty()))?;
                    }
                }
                result
            }
        };

        Ok(col_types)
    }
546
547    /// Reports the unique keys of the relation given the arities and the unique
548    /// keys of the input relations.
549    ///
550    /// `input_arities` and `input_keys` are required to contain the
551    /// corresponding info for the input relations of
552    /// the current relation in the same order as they are visited by `try_visit_children`
553    /// method, even though not all may be used for computing the schema of the
554    /// current relation. For example, `Let` expects two input types, one for the
555    /// value relation and one for the body, in that order, but only the one for the
556    /// body is used to determine the type of the `Let` relation.
557    ///
558    /// It is meant to be used during post-order traversals to compute unique keys
559    /// incrementally.
560    pub fn keys_with_input_keys<'a, I, J>(
561        &self,
562        mut input_arities: I,
563        mut input_keys: J,
564    ) -> Vec<Vec<usize>>
565    where
566        I: Iterator<Item = usize>,
567        J: Iterator<Item = &'a Vec<Vec<usize>>>,
568    {
569        use MirRelationExpr::*;
570
571        let mut keys = match self {
572            Constant {
573                rows: Ok(rows),
574                typ,
575            } => {
576                let n_cols = typ.arity();
577                // If the `i`th entry is `Some`, then we have not yet observed non-uniqueness in the `i`th column.
578                let mut unique_values_per_col = vec![Some(BTreeSet::<Datum>::default()); n_cols];
579                for (row, diff) in rows {
580                    for (i, datum) in row.iter().enumerate() {
581                        if datum != Datum::Dummy {
582                            if let Some(unique_vals) = &mut unique_values_per_col[i] {
583                                let is_dupe = *diff != Diff::ONE || !unique_vals.insert(datum);
584                                if is_dupe {
585                                    unique_values_per_col[i] = None;
586                                }
587                            }
588                        }
589                    }
590                }
591                if rows.len() == 0 || (rows.len() == 1 && rows[0].1 == Diff::ONE) {
592                    vec![vec![]]
593                } else {
594                    // XXX - Multi-column keys are not detected.
595                    typ.keys
596                        .iter()
597                        .cloned()
598                        .chain(
599                            unique_values_per_col
600                                .into_iter()
601                                .enumerate()
602                                .filter(|(_idx, unique_vals)| unique_vals.is_some())
603                                .map(|(idx, _)| vec![idx]),
604                        )
605                        .collect()
606                }
607            }
608            Constant { rows: Err(_), typ } | Get { typ, .. } => typ.keys.clone(),
609            Threshold { .. } | ArrangeBy { .. } => input_keys.next().unwrap().clone(),
610            Let { .. } => {
611                // skip over the unique keys for value
612                input_keys.nth(1).unwrap().clone()
613            }
614            LetRec { values, .. } => {
615                // skip over the unique keys for value
616                input_keys.nth(values.len()).unwrap().clone()
617            }
618            Project { outputs, .. } => {
619                let input = input_keys.next().unwrap();
620                input
621                    .iter()
622                    .filter_map(|key_set| {
623                        if key_set.iter().all(|k| outputs.contains(k)) {
624                            Some(
625                                key_set
626                                    .iter()
627                                    .map(|c| outputs.iter().position(|o| o == c).unwrap())
628                                    .collect(),
629                            )
630                        } else {
631                            None
632                        }
633                    })
634                    .collect()
635            }
636            Map { scalars, .. } => {
637                let mut remappings = Vec::new();
638                let arity = input_arities.next().unwrap();
639                for (column, scalar) in scalars.iter().enumerate() {
640                    // assess whether the scalar preserves uniqueness,
641                    // and could participate in a key!
642
643                    fn uniqueness(expr: &MirScalarExpr) -> Option<usize> {
644                        match expr {
645                            MirScalarExpr::CallUnary { func, expr } => {
646                                if func.preserves_uniqueness() {
647                                    uniqueness(expr)
648                                } else {
649                                    None
650                                }
651                            }
652                            MirScalarExpr::Column(c, _name) => Some(*c),
653                            _ => None,
654                        }
655                    }
656
657                    if let Some(c) = uniqueness(scalar) {
658                        remappings.push((c, column + arity));
659                    }
660                }
661
662                let mut result = input_keys.next().unwrap().clone();
663                let mut new_keys = Vec::new();
664                // Any column in `remappings` could be replaced in a key
665                // by the corresponding c. This could lead to combinatorial
666                // explosion using our current representation, so we wont
667                // do that. Instead, we'll handle the case of one remapping.
668                if remappings.len() == 1 {
669                    let (old, new) = remappings.pop().unwrap();
670                    for key in &result {
671                        if key.contains(&old) {
672                            let mut new_key: Vec<usize> =
673                                key.iter().cloned().filter(|k| k != &old).collect();
674                            new_key.push(new);
675                            new_key.sort_unstable();
676                            new_keys.push(new_key);
677                        }
678                    }
679                    result.append(&mut new_keys);
680                }
681                result
682            }
683            FlatMap { .. } => {
684                // FlatMap can add duplicate rows, so input keys are no longer
685                // valid
686                vec![]
687            }
688            Negate { .. } => {
689                // Although negate may have distinct records for each key,
690                // the multiplicity is -1 rather than 1. This breaks many
691                // of the optimization uses of "keys".
692                vec![]
693            }
694            Filter { predicates, .. } => {
695                // A filter inherits the keys of its input unless the filters
696                // have reduced the input to a single row, in which case the
697                // keys of the input are `()`.
698                let mut input = input_keys.next().unwrap().clone();
699
700                if !input.is_empty() {
701                    // Track columns equated to literals, which we can prune.
702                    let mut cols_equal_to_literal = BTreeSet::new();
703
704                    // Perform union find on `col1 = col2` to establish
705                    // connected components of equated columns. Absent any
706                    // equalities, this will be `0 .. #c` (where #c is the
707                    // greatest column referenced by a predicate), but each
708                    // equality will orient the root of the greater to the root
709                    // of the lesser.
710                    let mut union_find = Vec::new();
711
712                    for expr in predicates.iter() {
713                        if let MirScalarExpr::CallBinary {
714                            func: crate::BinaryFunc::Eq(_),
715                            expr1,
716                            expr2,
717                        } = expr
718                        {
719                            if let MirScalarExpr::Column(c, _name) = &**expr1 {
720                                if expr2.is_literal_ok() {
721                                    cols_equal_to_literal.insert(c);
722                                }
723                            }
724                            if let MirScalarExpr::Column(c, _name) = &**expr2 {
725                                if expr1.is_literal_ok() {
726                                    cols_equal_to_literal.insert(c);
727                                }
728                            }
729                            // Perform union-find to equate columns.
730                            if let (Some(c1), Some(c2)) = (expr1.as_column(), expr2.as_column()) {
731                                if c1 != c2 {
732                                    // Ensure union_find has entries up to
733                                    // max(c1, c2) by filling up missing
734                                    // positions with identity mappings.
735                                    while union_find.len() <= std::cmp::max(c1, c2) {
736                                        union_find.push(union_find.len());
737                                    }
738                                    let mut r1 = c1; // Find the representative column of [c1].
739                                    while r1 != union_find[r1] {
740                                        assert!(union_find[r1] < r1);
741                                        r1 = union_find[r1];
742                                    }
743                                    let mut r2 = c2; // Find the representative column of [c2].
744                                    while r2 != union_find[r2] {
745                                        assert!(union_find[r2] < r2);
746                                        r2 = union_find[r2];
747                                    }
748                                    // Union [c1] and [c2] by pointing the
749                                    // larger to the smaller representative (we
750                                    // update the remaining equivalence class
751                                    // members only once after this for-loop).
752                                    union_find[std::cmp::max(r1, r2)] = std::cmp::min(r1, r2);
753                                }
754                            }
755                        }
756                    }
757
758                    // Complete union-find by pointing each element at its representative column.
759                    for i in 0..union_find.len() {
760                        // Iteration not required, as each prior already references the right column.
761                        union_find[i] = union_find[union_find[i]];
762                    }
763
764                    // Remove columns bound to literals, and remap columns equated to earlier columns.
765                    // We will re-expand remapped columns in a moment, but this avoids exponential work.
766                    for key_set in &mut input {
767                        key_set.retain(|k| !cols_equal_to_literal.contains(&k));
768                        for col in key_set.iter_mut() {
769                            if let Some(equiv) = union_find.get(*col) {
770                                *col = *equiv;
771                            }
772                        }
773                        key_set.sort();
774                        key_set.dedup();
775                    }
776                    input.sort();
777                    input.dedup();
778
779                    // Expand out each key to each of its equivalent forms.
780                    // Each instance of `col` can be replaced by any equivalent column.
781                    // This has the potential to result in exponentially sized number of unique keys,
782                    // and in the future we should probably maintain unique keys modulo equivalence.
783
784                    // First, compute an inverse map from each representative
785                    // column `sub` to all other equivalent columns `col`.
786                    let mut subs = Vec::new();
787                    for (col, sub) in union_find.iter().enumerate() {
788                        if *sub != col {
789                            assert!(*sub < col);
790                            while subs.len() <= *sub {
791                                subs.push(Vec::new());
792                            }
793                            subs[*sub].push(col);
794                        }
795                    }
796                    // For each column, substitute for it in each occurrence.
797                    let mut to_add = Vec::new();
798                    for (col, subs) in subs.iter().enumerate() {
799                        if !subs.is_empty() {
800                            for key_set in input.iter() {
801                                if key_set.contains(&col) {
802                                    let mut to_extend = key_set.clone();
803                                    to_extend.retain(|c| c != &col);
804                                    for sub in subs {
805                                        to_extend.push(*sub);
806                                        to_add.push(to_extend.clone());
807                                        to_extend.pop();
808                                    }
809                                }
810                            }
811                        }
812                        // No deduplication, as we cannot introduce duplicates.
813                        input.append(&mut to_add);
814                    }
815                    for key_set in input.iter_mut() {
816                        key_set.sort();
817                        key_set.dedup();
818                    }
819                }
820                input
821            }
822            Join { equivalences, .. } => {
823                // It is important the `new_from_input_arities` constructor is
824                // used. Otherwise, Materialize may potentially end up in an
825                // infinite loop.
826                let input_mapper = crate::JoinInputMapper::new_from_input_arities(input_arities);
827
828                input_mapper.global_keys(input_keys, equivalences)
829            }
830            Reduce { group_key, .. } => {
831                // The group key should form a key, but we might already have
832                // keys that are subsets of the group key, and should retain
833                // those instead, if so.
834                let mut result = Vec::new();
835                for key_set in input_keys.next().unwrap() {
836                    if key_set
837                        .iter()
838                        .all(|k| group_key.contains(&MirScalarExpr::column(*k)))
839                    {
840                        result.push(
841                            key_set
842                                .iter()
843                                .map(|i| {
844                                    group_key
845                                        .iter()
846                                        .position(|k| k == &MirScalarExpr::column(*i))
847                                        .unwrap()
848                                })
849                                .collect::<Vec<_>>(),
850                        );
851                    }
852                }
853                if result.is_empty() {
854                    result.push((0..group_key.len()).collect());
855                }
856                result
857            }
858            TopK {
859                group_key, limit, ..
860            } => {
861                // If `limit` is `Some(1)` then the group key will become
862                // a unique key, as there will be only one record with that key.
863                let mut result = input_keys.next().unwrap().clone();
864                if limit.as_ref().and_then(|x| x.as_literal_int64()) == Some(1) {
865                    result.push(group_key.clone())
866                }
867                result
868            }
869            Union { base, inputs } => {
870                // Generally, unions do not have any unique keys, because
871                // each input might duplicate some. However, there is at
872                // least one idiomatic structure that does preserve keys,
873                // which results from SQL aggregations that must populate
874                // absent records with default values. In that pattern,
875                // the union of one GET with its negation, which has first
876                // been subjected to a projection and map, we can remove
877                // their influence on the key structure.
878                //
879                // If there are A, B, each with a unique `key` such that
880                // we are looking at
881                //
882                //     A.proj(set_containing_key) + (B - A.proj(key)).map(stuff)
883                //
884                // Then we can report `key` as a unique key.
885                //
886                // TODO: make unique key structure an optimization analysis
887                // rather than part of the type information.
888                // TODO: perhaps ensure that (above) A.proj(key) is a
889                // subset of B, as otherwise there are negative records
890                // and who knows what is true (not expected, but again
891                // who knows what the query plan might look like).
892
893                let arity = input_arities.next().unwrap();
894                let (base_projection, base_with_project_stripped) =
895                    if let MirRelationExpr::Project { input, outputs } = &**base {
896                        (outputs.clone(), &**input)
897                    } else {
898                        // A input without a project is equivalent to an input
899                        // with the project being all columns in the input in order.
900                        ((0..arity).collect::<Vec<_>>(), &**base)
901                    };
902                let mut result = Vec::new();
903                if let MirRelationExpr::Get {
904                    id: first_id,
905                    typ: _,
906                    ..
907                } = base_with_project_stripped
908                {
909                    if inputs.len() == 1 {
910                        if let MirRelationExpr::Map { input, .. } = &inputs[0] {
911                            if let MirRelationExpr::Union { base, inputs } = &**input {
912                                if inputs.len() == 1 {
913                                    if let Some((input, outputs)) = base.is_negated_project() {
914                                        if let MirRelationExpr::Get {
915                                            id: second_id,
916                                            typ: _,
917                                            ..
918                                        } = input
919                                        {
920                                            if first_id == second_id {
921                                                result.extend(
922                                                    input_keys
923                                                        .next()
924                                                        .unwrap()
925                                                        .into_iter()
926                                                        .filter(|key| {
927                                                            key.iter().all(|c| {
928                                                                outputs.get(*c) == Some(c)
929                                                                    && base_projection.get(*c)
930                                                                        == Some(c)
931                                                            })
932                                                        })
933                                                        .cloned(),
934                                                );
935                                            }
936                                        }
937                                    }
938                                }
939                            }
940                        }
941                    }
942                }
943                // Important: do not inherit keys of either input, as not unique.
944                result
945            }
946        };
947        keys.sort();
948        keys.dedup();
949        keys
950    }
951
    /// The number of columns in the relation.
    ///
    /// This number is determined from the type, which is determined recursively
    /// at non-trivial cost.
    ///
    /// The arity is computed incrementally with a recursive post-order
    /// traversal, that accumulates the arities for the relations yet to be
    /// visited in `arity_stack`.
    pub fn arity(&self) -> usize {
        // Stack of arities for already-visited subtrees; each node's
        // post-visit consumes its children's entries and pushes its own.
        let mut arity_stack = Vec::new();
        #[allow(deprecated)]
        self.visit_pre_post_nolimit(
            // Pre-visit: returning `Some(children)` overrides which children
            // are traversed, letting us skip subtrees that cannot affect the
            // arity of the root.
            &mut |e: &MirRelationExpr| -> Option<Vec<&MirRelationExpr>> {
                match &e {
                    MirRelationExpr::Let { body, .. } => {
                        // Do not traverse the value sub-graph, since it's not relevant for
                        // determining the arity of Let operators.
                        Some(vec![&*body])
                    }
                    MirRelationExpr::LetRec { body, .. } => {
                        // Do not traverse the value sub-graph, since it's not relevant for
                        // determining the arity of Let operators.
                        Some(vec![&*body])
                    }
                    MirRelationExpr::Project { .. } | MirRelationExpr::Reduce { .. } => {
                        // No further traversal is required; these operators know their arity.
                        Some(Vec::new())
                    }
                    _ => None,
                }
            },
            // Post-visit: re-insert placeholder arities (0) for any children
            // skipped above, so `arity_with_input_arities` always sees exactly
            // `num_inputs` entries, then fold them into this node's arity.
            &mut |e: &MirRelationExpr| {
                match &e {
                    MirRelationExpr::Let { .. } => {
                        // Only the body was visited; insert a placeholder for
                        // the skipped value, keeping the body's arity on top.
                        let body_arity = arity_stack.pop().unwrap();
                        arity_stack.push(0);
                        arity_stack.push(body_arity);
                    }
                    MirRelationExpr::LetRec { values, .. } => {
                        // Only the body was visited; insert one placeholder
                        // per skipped bound value, body's arity on top.
                        let body_arity = arity_stack.pop().unwrap();
                        arity_stack.extend(std::iter::repeat(0).take(values.len()));
                        arity_stack.push(body_arity);
                    }
                    MirRelationExpr::Project { .. } | MirRelationExpr::Reduce { .. } => {
                        // No children were visited; push a placeholder for the
                        // single input, whose arity these operators ignore.
                        arity_stack.push(0);
                    }
                    _ => {}
                }
                // Replace the inputs' arities with this node's own arity.
                let num_inputs = e.num_inputs();
                let input_arities = arity_stack.drain(arity_stack.len() - num_inputs..);
                let arity = e.arity_with_input_arities(input_arities);
                arity_stack.push(arity);
            },
        );
        // After the traversal only the root's arity remains on the stack.
        assert_eq!(arity_stack.len(), 1);
        arity_stack.pop().unwrap()
    }
1009
1010    /// Reports the arity of the relation given the schema of the input relations.
1011    ///
1012    /// `input_arities` is required to contain the arities for the input relations of
1013    /// the current relation in the same order as they are visited by `try_visit_children`
1014    /// method, even though not all may be used for computing the schema of the
1015    /// current relation. For example, `Let` expects two input types, one for the
1016    /// value relation and one for the body, in that order, but only the one for the
1017    /// body is used to determine the type of the `Let` relation.
1018    ///
1019    /// It is meant to be used during post-order traversals to compute arities
1020    /// incrementally.
1021    pub fn arity_with_input_arities<I>(&self, mut input_arities: I) -> usize
1022    where
1023        I: Iterator<Item = usize>,
1024    {
1025        use MirRelationExpr::*;
1026
1027        match self {
1028            Constant { rows: _, typ } => typ.arity(),
1029            Get { typ, .. } => typ.arity(),
1030            Let { .. } => {
1031                input_arities.next();
1032                input_arities.next().unwrap()
1033            }
1034            LetRec { values, .. } => {
1035                for _ in 0..values.len() {
1036                    input_arities.next();
1037                }
1038                input_arities.next().unwrap()
1039            }
1040            Project { outputs, .. } => outputs.len(),
1041            Map { scalars, .. } => input_arities.next().unwrap() + scalars.len(),
1042            FlatMap { func, .. } => input_arities.next().unwrap() + func.output_arity(),
1043            Join { .. } => input_arities.sum(),
1044            Reduce {
1045                input: _,
1046                group_key,
1047                aggregates,
1048                ..
1049            } => group_key.len() + aggregates.len(),
1050            Filter { .. }
1051            | TopK { .. }
1052            | Negate { .. }
1053            | Threshold { .. }
1054            | Union { .. }
1055            | ArrangeBy { .. } => input_arities.next().unwrap(),
1056        }
1057    }
1058
1059    /// The number of child relations this relation has.
1060    pub fn num_inputs(&self) -> usize {
1061        let mut count = 0;
1062
1063        self.visit_children(|_| count += 1);
1064
1065        count
1066    }
1067
1068    /// Constructs a constant collection from specific rows and schema, where
1069    /// each row will have a multiplicity of one.
1070    pub fn constant(rows: Vec<Vec<Datum>>, typ: ReprRelationType) -> Self {
1071        let rows = rows.into_iter().map(|row| (row, Diff::ONE)).collect();
1072        MirRelationExpr::constant_diff(rows, typ)
1073    }
1074
1075    /// Constructs a constant collection from specific rows and schema, where
1076    /// each row can have an arbitrary multiplicity.
1077    pub fn constant_diff(rows: Vec<(Vec<Datum>, Diff)>, typ: ReprRelationType) -> Self {
1078        for (row, _diff) in &rows {
1079            for (datum, column_typ) in row.iter().zip_eq(typ.column_types.iter()) {
1080                assert!(
1081                    datum.is_instance_of(column_typ),
1082                    "Expected datum of type {:?}, got value {:?}",
1083                    column_typ,
1084                    datum
1085                );
1086            }
1087        }
1088        let rows = Ok(rows
1089            .into_iter()
1090            .map(move |(row, diff)| (Row::pack_slice(&row), diff))
1091            .collect());
1092        MirRelationExpr::Constant { rows, typ }
1093    }
1094
1095    /// If self is a constant, return the value and the type, otherwise `None`.
1096    /// Looks behind `ArrangeBy`s.
1097    pub fn as_const(&self) -> Option<(&Result<Vec<(Row, Diff)>, EvalError>, &ReprRelationType)> {
1098        match self {
1099            MirRelationExpr::Constant { rows, typ } => Some((rows, typ)),
1100            MirRelationExpr::ArrangeBy { input, .. } => input.as_const(),
1101            _ => None,
1102        }
1103    }
1104
1105    /// If self is a constant, mutably return the value and the type, otherwise `None`.
1106    /// Looks behind `ArrangeBy`s.
1107    pub fn as_const_mut(
1108        &mut self,
1109    ) -> Option<(
1110        &mut Result<Vec<(Row, Diff)>, EvalError>,
1111        &mut ReprRelationType,
1112    )> {
1113        match self {
1114            MirRelationExpr::Constant { rows, typ } => Some((rows, typ)),
1115            MirRelationExpr::ArrangeBy { input, .. } => input.as_const_mut(),
1116            _ => None,
1117        }
1118    }
1119
1120    /// If self is a constant error, return the error, otherwise `None`.
1121    /// Looks behind `ArrangeBy`s.
1122    pub fn as_const_err(&self) -> Option<&EvalError> {
1123        match self {
1124            MirRelationExpr::Constant { rows: Err(e), .. } => Some(e),
1125            MirRelationExpr::ArrangeBy { input, .. } => input.as_const_err(),
1126            _ => None,
1127        }
1128    }
1129
1130    /// Checks if `self` is the single element collection with no columns.
1131    pub fn is_constant_singleton(&self) -> bool {
1132        if let Some((Ok(rows), typ)) = self.as_const() {
1133            rows.len() == 1 && typ.column_types.len() == 0 && rows[0].1 == Diff::ONE
1134        } else {
1135            false
1136        }
1137    }
1138
1139    /// Constructs the expression for getting a local collection.
1140    pub fn local_get(id: LocalId, typ: ReprRelationType) -> Self {
1141        MirRelationExpr::Get {
1142            id: Id::Local(id),
1143            typ,
1144            access_strategy: AccessStrategy::UnknownOrLocal,
1145        }
1146    }
1147
1148    /// Constructs the expression for getting a global collection
1149    pub fn global_get(id: GlobalId, typ: ReprRelationType) -> Self {
1150        MirRelationExpr::Get {
1151            id: Id::Global(id),
1152            typ,
1153            access_strategy: AccessStrategy::UnknownOrLocal,
1154        }
1155    }
1156
1157    /// Retains only the columns specified by `output`.
1158    pub fn project(mut self, mut outputs: Vec<usize>) -> Self {
1159        if let MirRelationExpr::Project {
1160            outputs: columns, ..
1161        } = &mut self
1162        {
1163            // Update `outputs` to reference base columns of `input`.
1164            for column in outputs.iter_mut() {
1165                *column = columns[*column];
1166            }
1167            *columns = outputs;
1168            self
1169        } else {
1170            MirRelationExpr::Project {
1171                input: Box::new(self),
1172                outputs,
1173            }
1174        }
1175    }
1176
1177    /// Append to each row the results of applying elements of `scalar`.
1178    pub fn map(mut self, scalars: Vec<MirScalarExpr>) -> Self {
1179        if let MirRelationExpr::Map { scalars: s, .. } = &mut self {
1180            s.extend(scalars);
1181            self
1182        } else if !scalars.is_empty() {
1183            MirRelationExpr::Map {
1184                input: Box::new(self),
1185                scalars,
1186            }
1187        } else {
1188            self
1189        }
1190    }
1191
1192    /// Append to each row a single `scalar`.
1193    pub fn map_one(self, scalar: MirScalarExpr) -> Self {
1194        self.map(vec![scalar])
1195    }
1196
1197    /// Like `map`, but yields zero-or-more output rows per input row
1198    pub fn flat_map(self, func: TableFunc, exprs: Vec<MirScalarExpr>) -> Self {
1199        MirRelationExpr::FlatMap {
1200            input: Box::new(self),
1201            func,
1202            exprs,
1203        }
1204    }
1205
1206    /// Retain only the rows satisfying each of several predicates.
1207    pub fn filter<I>(mut self, predicates: I) -> Self
1208    where
1209        I: IntoIterator<Item = MirScalarExpr>,
1210    {
1211        // Extract existing predicates
1212        let mut new_predicates = if let MirRelationExpr::Filter { input, predicates } = self {
1213            self = *input;
1214            predicates
1215        } else {
1216            Vec::new()
1217        };
1218        // Normalize collection of predicates.
1219        new_predicates.extend(predicates);
1220        new_predicates.retain(|p| !p.is_literal_true());
1221        new_predicates.sort();
1222        new_predicates.dedup();
1223        // Introduce a `Filter` only if we have predicates.
1224        if !new_predicates.is_empty() {
1225            self = MirRelationExpr::Filter {
1226                input: Box::new(self),
1227                predicates: new_predicates,
1228            };
1229        }
1230
1231        self
1232    }
1233
1234    /// Form the Cartesian outer-product of rows in both inputs.
1235    pub fn product(mut self, right: Self) -> Self {
1236        if right.is_constant_singleton() {
1237            self
1238        } else if self.is_constant_singleton() {
1239            right
1240        } else if let MirRelationExpr::Join { inputs, .. } = &mut self {
1241            inputs.push(right);
1242            self
1243        } else {
1244            MirRelationExpr::join(vec![self, right], vec![])
1245        }
1246    }
1247
1248    /// Performs a relational equijoin among the input collections.
1249    ///
1250    /// The sequence `inputs` each describe different input collections, and the sequence `variables` describes
1251    /// equality constraints that some of their columns must satisfy. Each element in `variable` describes a set
1252    /// of pairs  `(input_index, column_index)` where every value described by that set must be equal.
1253    ///
1254    /// For example, the pair `(input, column)` indexes into `inputs[input][column]`, extracting the `input`th
1255    /// input collection and for each row examining its `column`th column.
1256    ///
1257    /// # Example
1258    ///
1259    /// ```rust
1260    /// use mz_repr::{Datum, SqlColumnType, ReprRelationType, ReprScalarType};
1261    /// use mz_expr::MirRelationExpr;
1262    ///
1263    /// // A common schema for each input.
1264    /// let schema = ReprRelationType::new(vec![
1265    ///     ReprScalarType::Int32.nullable(false),
1266    ///     ReprScalarType::Int32.nullable(false),
1267    /// ]);
1268    ///
1269    /// // the specific data are not important here.
1270    /// let data = vec![Datum::Int32(0), Datum::Int32(1)];
1271    ///
1272    /// // Three collections that could have been different.
1273    /// let input0 = MirRelationExpr::constant(vec![data.clone()], schema.clone());
1274    /// let input1 = MirRelationExpr::constant(vec![data.clone()], schema.clone());
1275    /// let input2 = MirRelationExpr::constant(vec![data.clone()], schema.clone());
1276    ///
1277    /// // Join the three relations looking for triangles, like so.
1278    /// //
1279    /// //     Output(A,B,C) := Input0(A,B), Input1(B,C), Input2(A,C)
1280    /// let joined = MirRelationExpr::join(
1281    ///     vec![input0, input1, input2],
1282    ///     vec![
1283    ///         vec![(0,0), (2,0)], // fields A of inputs 0 and 2.
1284    ///         vec![(0,1), (1,0)], // fields B of inputs 0 and 1.
1285    ///         vec![(1,1), (2,1)], // fields C of inputs 1 and 2.
1286    ///     ],
1287    /// );
1288    ///
1289    /// // Technically the above produces `Output(A,B,B,C,A,C)` because the columns are concatenated.
1290    /// // A projection resolves this and produces the correct output.
1291    /// let result = joined.project(vec![0, 1, 3]);
1292    /// ```
1293    pub fn join(inputs: Vec<MirRelationExpr>, variables: Vec<Vec<(usize, usize)>>) -> Self {
1294        let input_mapper = join_input_mapper::JoinInputMapper::new(&inputs);
1295
1296        let equivalences = variables
1297            .into_iter()
1298            .map(|vs| {
1299                vs.into_iter()
1300                    .map(|(r, c)| input_mapper.map_expr_to_global(MirScalarExpr::column(c), r))
1301                    .collect::<Vec<_>>()
1302            })
1303            .collect::<Vec<_>>();
1304
1305        Self::join_scalars(inputs, equivalences)
1306    }
1307
1308    /// Constructs a join operator from inputs and required-equal scalar expressions.
1309    pub fn join_scalars(
1310        mut inputs: Vec<MirRelationExpr>,
1311        equivalences: Vec<Vec<MirScalarExpr>>,
1312    ) -> Self {
1313        // Remove all constant inputs that are the identity for join.
1314        // They neither introduce nor modify any column references.
1315        inputs.retain(|i| !i.is_constant_singleton());
1316        MirRelationExpr::Join {
1317            inputs,
1318            equivalences,
1319            implementation: JoinImplementation::Unimplemented,
1320        }
1321    }
1322
1323    /// Perform a key-wise reduction / aggregation.
1324    ///
1325    /// The `group_key` argument indicates columns in the input collection that should
1326    /// be grouped, and `aggregates` lists aggregation functions each of which produces
1327    /// one output column in addition to the keys.
1328    pub fn reduce(
1329        self,
1330        group_key: Vec<usize>,
1331        aggregates: Vec<AggregateExpr>,
1332        expected_group_size: Option<u64>,
1333    ) -> Self {
1334        MirRelationExpr::Reduce {
1335            input: Box::new(self),
1336            group_key: group_key.into_iter().map(MirScalarExpr::column).collect(),
1337            aggregates,
1338            monotonic: false,
1339            expected_group_size,
1340        }
1341    }
1342
1343    /// Perform a key-wise reduction order by and limit.
1344    ///
1345    /// The `group_key` argument indicates columns in the input collection that should
1346    /// be grouped, the `order_key` argument indicates columns that should be further
1347    /// used to order records within groups, and the `limit` argument constrains the
1348    /// total number of records that should be produced in each group.
1349    pub fn top_k(
1350        self,
1351        group_key: Vec<usize>,
1352        order_key: Vec<ColumnOrder>,
1353        limit: Option<MirScalarExpr>,
1354        offset: usize,
1355        expected_group_size: Option<u64>,
1356    ) -> Self {
1357        MirRelationExpr::TopK {
1358            input: Box::new(self),
1359            group_key,
1360            order_key,
1361            limit,
1362            offset,
1363            expected_group_size,
1364            monotonic: false,
1365        }
1366    }
1367
1368    /// Negates the occurrences of each row.
1369    pub fn negate(self) -> Self {
1370        if let MirRelationExpr::Negate { input } = self {
1371            *input
1372        } else {
1373            MirRelationExpr::Negate {
1374                input: Box::new(self),
1375            }
1376        }
1377    }
1378
1379    /// Removes all but the first occurrence of each row.
1380    pub fn distinct(self) -> Self {
1381        let arity = self.arity();
1382        self.distinct_by((0..arity).collect())
1383    }
1384
1385    /// Removes all but the first occurrence of each key. Columns not included
1386    /// in the `group_key` are discarded.
1387    pub fn distinct_by(self, group_key: Vec<usize>) -> Self {
1388        self.reduce(group_key, vec![], None)
1389    }
1390
1391    /// Discards rows with a negative frequency.
1392    pub fn threshold(self) -> Self {
1393        if let MirRelationExpr::Threshold { .. } = &self {
1394            self
1395        } else {
1396            MirRelationExpr::Threshold {
1397                input: Box::new(self),
1398            }
1399        }
1400    }
1401
1402    /// Unions together any number inputs.
1403    ///
1404    /// If `inputs` is empty, then an empty relation of type `typ` is
1405    /// constructed.
1406    pub fn union_many(mut inputs: Vec<Self>, typ: ReprRelationType) -> Self {
1407        // Deconstruct `inputs` as `Union`s and reconstitute.
1408        let mut flat_inputs = Vec::with_capacity(inputs.len());
1409        for input in inputs {
1410            if let MirRelationExpr::Union { base, inputs } = input {
1411                flat_inputs.push(*base);
1412                flat_inputs.extend(inputs);
1413            } else {
1414                flat_inputs.push(input);
1415            }
1416        }
1417        inputs = flat_inputs;
1418        if inputs.len() == 0 {
1419            MirRelationExpr::Constant {
1420                rows: Ok(vec![]),
1421                typ,
1422            }
1423        } else if inputs.len() == 1 {
1424            inputs.into_element()
1425        } else {
1426            MirRelationExpr::Union {
1427                base: Box::new(inputs.remove(0)),
1428                inputs,
1429            }
1430        }
1431    }
1432
1433    /// Produces one collection where each row is present with the sum of its frequencies in each input.
1434    pub fn union(self, other: Self) -> Self {
1435        // Deconstruct `self` and `other` as `Union`s and reconstitute.
1436        let mut flat_inputs = Vec::with_capacity(2);
1437        if let MirRelationExpr::Union { base, inputs } = self {
1438            flat_inputs.push(*base);
1439            flat_inputs.extend(inputs);
1440        } else {
1441            flat_inputs.push(self);
1442        }
1443        if let MirRelationExpr::Union { base, inputs } = other {
1444            flat_inputs.push(*base);
1445            flat_inputs.extend(inputs);
1446        } else {
1447            flat_inputs.push(other);
1448        }
1449
1450        MirRelationExpr::Union {
1451            base: Box::new(flat_inputs.remove(0)),
1452            inputs: flat_inputs,
1453        }
1454    }
1455
1456    /// Arranges the collection by the specified columns
1457    pub fn arrange_by(self, keys: &[Vec<MirScalarExpr>]) -> Self {
1458        MirRelationExpr::ArrangeBy {
1459            input: Box::new(self),
1460            keys: keys.to_owned(),
1461        }
1462    }
1463
1464    /// Indicates if this is a constant empty collection.
1465    ///
1466    /// A false value does not mean the collection is known to be non-empty,
1467    /// only that we cannot currently determine that it is statically empty.
1468    pub fn is_empty(&self) -> bool {
1469        if let Some((Ok(rows), ..)) = self.as_const() {
1470            rows.is_empty()
1471        } else {
1472            false
1473        }
1474    }
1475
1476    /// If the expression is a negated project, return the input and the projection.
1477    pub fn is_negated_project(&self) -> Option<(&MirRelationExpr, &[usize])> {
1478        if let MirRelationExpr::Negate { input } = self {
1479            if let MirRelationExpr::Project { input, outputs } = &**input {
1480                return Some((&**input, outputs));
1481            }
1482        }
1483        if let MirRelationExpr::Project { input, outputs } = self {
1484            if let MirRelationExpr::Negate { input } = &**input {
1485                return Some((&**input, outputs));
1486            }
1487        }
1488        None
1489    }
1490
1491    /// Pretty-print this [MirRelationExpr] to a string.
1492    pub fn pretty(&self) -> String {
1493        let config = ExplainConfig::default();
1494        self.debug_explain(&config, None)
1495    }
1496
1497    /// Pretty-print this [MirRelationExpr] to a string using a custom
1498    /// [ExplainConfig] and an optionally provided [ExprHumanizer].
1499    /// This is intended for debugging and tests, not users.
1500    pub fn debug_explain(
1501        &self,
1502        config: &ExplainConfig,
1503        humanizer: Option<&dyn ExprHumanizer>,
1504    ) -> String {
1505        text_string_at(self, || PlanRenderingContext {
1506            indent: Indent::default(),
1507            humanizer: humanizer.unwrap_or(&DummyHumanizer),
1508            annotations: BTreeMap::default(),
1509            config,
1510            ambiguous_ids: BTreeSet::default(),
1511        })
1512    }
1513
1514    /// Take ownership of `self`, leaving an empty `MirRelationExpr::Constant` with the optionally
1515    /// given scalar types. The given scalar types should be `base_eq` with the types that `typ()`
1516    /// would find. Keys and nullability are ignored in the given `SqlRelationType`, and instead we set
1517    /// the best possible key and nullability, since we are making an empty collection.
1518    ///
1519    /// If `typ` is not given, then this calls `.typ()` (which is possibly expensive) to determine
1520    /// the correct type.
1521    pub fn take_safely(&mut self, typ: Option<ReprRelationType>) -> MirRelationExpr {
1522        if let Some(typ) = &typ {
1523            let self_typ = self.typ();
1524            soft_assert_no_log!(
1525                self_typ
1526                    .column_types
1527                    .iter()
1528                    .zip_eq(typ.column_types.iter())
1529                    .all(|(t1, t2)| t1.scalar_type == t2.scalar_type)
1530            );
1531        }
1532        let mut typ = typ.unwrap_or_else(|| self.typ());
1533        typ.keys = vec![vec![]];
1534        for ct in typ.column_types.iter_mut() {
1535            ct.nullable = false;
1536        }
1537        std::mem::replace(
1538            self,
1539            MirRelationExpr::Constant {
1540                rows: Ok(vec![]),
1541                typ,
1542            },
1543        )
1544    }
1545
1546    /// Take ownership of `self`, leaving an empty `MirRelationExpr::Constant` with the given scalar
1547    /// types. Nullability is ignored in the given `SqlColumnType`s, and instead we set the best
1548    /// possible nullability, since we are making an empty collection.
1549    pub fn take_safely_with_sql_col_types(&mut self, typ: Vec<SqlColumnType>) -> MirRelationExpr {
1550        self.take_safely(Some(ReprRelationType::from(&SqlRelationType::new(typ))))
1551    }
1552
1553    /// Like [`Self::take_safely_with_col_types`], but accepts `Vec<ReprColumnType>`.
1554    ///
1555    /// This is the preferred entry point for optimizer transforms, where repr
1556    /// types are the native currency. Internally converts to [`SqlColumnType`]
1557    /// and delegates to [`Self::take_safely_with_col_types`].
1558    pub fn take_safely_with_col_types(&mut self, typ: Vec<ReprColumnType>) -> MirRelationExpr {
1559        self.take_safely(Some(ReprRelationType::new(typ)))
1560    }
1561
1562    /// Take ownership of `self`, leaving an empty `MirRelationExpr::Constant` with an **incorrect** type.
1563    ///
1564    /// This should only be used if `self` is about to be dropped or otherwise overwritten.
1565    pub fn take_dangerous(&mut self) -> MirRelationExpr {
1566        let empty = MirRelationExpr::Constant {
1567            rows: Ok(vec![]),
1568            typ: ReprRelationType::new(Vec::new()),
1569        };
1570        std::mem::replace(self, empty)
1571    }
1572
1573    /// Replaces `self` with some logic applied to `self`.
1574    pub fn replace_using<F>(&mut self, logic: F)
1575    where
1576        F: FnOnce(MirRelationExpr) -> MirRelationExpr,
1577    {
1578        let empty = MirRelationExpr::Constant {
1579            rows: Ok(vec![]),
1580            typ: ReprRelationType::new(Vec::new()),
1581        };
1582        let expr = std::mem::replace(self, empty);
1583        *self = logic(expr);
1584    }
1585
1586    /// Store `self` in a `Let` and pass the corresponding `Get` to `body`.
1587    pub fn let_in<Body, E>(self, id_gen: &mut IdGen, body: Body) -> Result<MirRelationExpr, E>
1588    where
1589        Body: FnOnce(&mut IdGen, MirRelationExpr) -> Result<MirRelationExpr, E>,
1590    {
1591        if let MirRelationExpr::Get { .. } = self {
1592            // already done
1593            body(id_gen, self)
1594        } else {
1595            let id = LocalId::new(id_gen.allocate_id());
1596            let get = MirRelationExpr::Get {
1597                id: Id::Local(id),
1598                typ: self.typ(),
1599                access_strategy: AccessStrategy::UnknownOrLocal,
1600            };
1601            let body = (body)(id_gen, get)?;
1602            Ok(MirRelationExpr::Let {
1603                id,
1604                value: Box::new(self),
1605                body: Box::new(body),
1606            })
1607        }
1608    }
1609
    /// Return every row in `self` that does not have a matching row in the first columns of `keys_and_values`, using `default` to fill in the remaining columns
    /// (If `default` is a row of nulls, this is the 'outer' part of LEFT OUTER JOIN)
    ///
    /// The first `self.arity()` columns of `keys_and_values` are treated as the
    /// key; `default` must supply one `(Datum, ReprScalarType)` per remaining
    /// column (checked by the `assert_eq!` below). The error type `E` is only
    /// threaded through `let_in`; this function itself never fails.
    pub fn anti_lookup<E>(
        self,
        id_gen: &mut IdGen,
        keys_and_values: MirRelationExpr,
        default: Vec<(Datum, ReprScalarType)>,
    ) -> Result<MirRelationExpr, E> {
        // Split `default` into the padding row and its column types; a column
        // is nullable exactly when its default datum is null.
        let (data, column_types): (Vec<_>, Vec<_>) = default
            .into_iter()
            .map(|(datum, scalar_type)| {
                (
                    datum,
                    ReprColumnType {
                        scalar_type,
                        nullable: datum.is_null(),
                    },
                )
            })
            .unzip();
        assert_eq!(keys_and_values.arity() - self.arity(), data.len());
        self.let_in(id_gen, |_id_gen, get_keys| {
            let get_keys_arity = get_keys.arity();
            Ok(MirRelationExpr::join(
                vec![
                    // all the missing keys (with count 1)
                    keys_and_values
                        .distinct_by((0..get_keys_arity).collect())
                        .negate()
                        .union(get_keys.clone().distinct()),
                    // join with keys to get the correct counts
                    get_keys.clone(),
                ],
                (0..get_keys_arity).map(|i| vec![(0, i), (1, i)]).collect(),
            )
            // get rid of the extra copies of columns from keys
            .project((0..get_keys_arity).collect())
            // This join is logically equivalent to
            // `.map(<default_expr>)`, but using a join allows for
            // potential predicate pushdown and elision in the
            // optimizer.
            .product(MirRelationExpr::constant(
                vec![data],
                ReprRelationType::new(column_types),
            )))
        })
    }
1657
1658    /// Return:
1659    /// * every row in keys_and_values
1660    /// * every row in `self` that does not have a matching row in the first columns of
1661    ///   `keys_and_values`, using `default` to fill in the remaining columns
1662    /// (This is LEFT OUTER JOIN if:
1663    /// 1) `default` is a row of null
1664    /// 2) matching rows in `keys_and_values` and `self` have the same multiplicity.)
1665    pub fn lookup<E>(
1666        self,
1667        id_gen: &mut IdGen,
1668        keys_and_values: MirRelationExpr,
1669        default: Vec<(Datum<'static>, ReprScalarType)>,
1670    ) -> Result<MirRelationExpr, E> {
1671        keys_and_values.let_in(id_gen, |id_gen, get_keys_and_values| {
1672            Ok(get_keys_and_values.clone().union(self.anti_lookup(
1673                id_gen,
1674                get_keys_and_values,
1675                default,
1676            )?))
1677        })
1678    }
1679
1680    /// True iff the expression contains a `NullaryFunc::MzLogicalTimestamp`.
1681    pub fn contains_temporal(&self) -> bool {
1682        let mut contains = false;
1683        self.visit_scalars(&mut |e| contains = contains || e.contains_temporal());
1684        contains
1685    }
1686
    /// Fallible visitor for the [`MirScalarExpr`]s directly owned by this relation expression.
    ///
    /// The `f` visitor should not recursively descend into owned [`MirRelationExpr`]s.
    ///
    /// Returns the first error produced by `f`, short-circuiting the traversal.
    pub fn try_visit_scalars_mut1<F, E>(&mut self, f: &mut F) -> Result<(), E>
    where
        F: FnMut(&mut MirScalarExpr) -> Result<(), E>,
    {
        use MirRelationExpr::*;
        match self {
            Map { scalars, .. } => {
                for s in scalars {
                    f(s)?;
                }
            }
            Filter { predicates, .. } => {
                for p in predicates {
                    f(p)?;
                }
            }
            FlatMap { exprs, .. } => {
                for expr in exprs {
                    f(expr)?;
                }
            }
            // Joins own scalars in two places: the equivalence classes and the
            // key expressions recorded in the chosen implementation.
            Join {
                inputs: _,
                equivalences,
                implementation,
            } => {
                for equivalence in equivalences {
                    for expr in equivalence {
                        f(expr)?;
                    }
                }
                match implementation {
                    JoinImplementation::Differential((_, start_key, _), order) => {
                        if let Some(start_key) = start_key {
                            for k in start_key {
                                f(k)?;
                            }
                        }
                        for (_, lookup_key, _) in order {
                            for k in lookup_key {
                                f(k)?;
                            }
                        }
                    }
                    JoinImplementation::DeltaQuery(paths) => {
                        for path in paths {
                            for (_, lookup_key, _) in path {
                                for k in lookup_key {
                                    f(k)?;
                                }
                            }
                        }
                    }
                    JoinImplementation::IndexedFilter(_coll_id, _idx_id, index_key, _) => {
                        for k in index_key {
                            f(k)?;
                        }
                    }
                    JoinImplementation::Unimplemented => {} // No scalar exprs
                }
            }
            ArrangeBy { keys, .. } => {
                for key in keys {
                    for s in key {
                        f(s)?;
                    }
                }
            }
            Reduce {
                group_key,
                aggregates,
                ..
            } => {
                for s in group_key {
                    f(s)?;
                }
                for agg in aggregates {
                    f(&mut agg.expr)?;
                }
            }
            TopK { limit, .. } => {
                if let Some(s) = limit {
                    f(s)?;
                }
            }
            // These variants own no scalar expressions directly.
            Constant { .. }
            | Get { .. }
            | Let { .. }
            | LetRec { .. }
            | Project { .. }
            | Negate { .. }
            | Threshold { .. }
            | Union { .. } => (),
        }
        Ok(())
    }
1786
1787    /// Fallible mutable visitor for the [`MirScalarExpr`]s in the [`MirRelationExpr`] subtree
1788    /// rooted at `self`.
1789    ///
1790    /// Note that this does not recurse into [`MirRelationExpr`] subtrees within [`MirScalarExpr`]
1791    /// nodes.
1792    pub fn try_visit_scalars_mut<F, E>(&mut self, f: &mut F) -> Result<(), E>
1793    where
1794        F: FnMut(&mut MirScalarExpr) -> Result<(), E>,
1795        E: From<RecursionLimitError>,
1796    {
1797        self.try_visit_mut_post(&mut |expr| expr.try_visit_scalars_mut1(f))
1798    }
1799
1800    /// Infallible mutable visitor for the [`MirScalarExpr`]s in the [`MirRelationExpr`] subtree
1801    /// rooted at `self`.
1802    ///
1803    /// Note that this does not recurse into [`MirRelationExpr`] subtrees within [`MirScalarExpr`]
1804    /// nodes.
1805    pub fn visit_scalars_mut<F>(&mut self, f: &mut F)
1806    where
1807        F: FnMut(&mut MirScalarExpr),
1808    {
1809        self.try_visit_scalars_mut(&mut |s| {
1810            f(s);
1811            Ok::<_, RecursionLimitError>(())
1812        })
1813        .expect("Unexpected error in `visit_scalars_mut` call");
1814    }
1815
    /// Fallible visitor for the [`MirScalarExpr`]s directly owned by this relation expression.
    ///
    /// The `f` visitor should not recursively descend into owned [`MirRelationExpr`]s.
    ///
    /// Immutable twin of [`MirRelationExpr::try_visit_scalars_mut1`]; returns the first
    /// error produced by `f`, short-circuiting the traversal.
    pub fn try_visit_scalars_1<F, E>(&self, f: &mut F) -> Result<(), E>
    where
        F: FnMut(&MirScalarExpr) -> Result<(), E>,
    {
        use MirRelationExpr::*;
        match self {
            Map { scalars, .. } => {
                for s in scalars {
                    f(s)?;
                }
            }
            Filter { predicates, .. } => {
                for p in predicates {
                    f(p)?;
                }
            }
            FlatMap { exprs, .. } => {
                for expr in exprs {
                    f(expr)?;
                }
            }
            // Joins own scalars in two places: the equivalence classes and the
            // key expressions recorded in the chosen implementation.
            Join {
                inputs: _,
                equivalences,
                implementation,
            } => {
                for equivalence in equivalences {
                    for expr in equivalence {
                        f(expr)?;
                    }
                }
                match implementation {
                    JoinImplementation::Differential((_, start_key, _), order) => {
                        if let Some(start_key) = start_key {
                            for k in start_key {
                                f(k)?;
                            }
                        }
                        for (_, lookup_key, _) in order {
                            for k in lookup_key {
                                f(k)?;
                            }
                        }
                    }
                    JoinImplementation::DeltaQuery(paths) => {
                        for path in paths {
                            for (_, lookup_key, _) in path {
                                for k in lookup_key {
                                    f(k)?;
                                }
                            }
                        }
                    }
                    JoinImplementation::IndexedFilter(_coll_id, _idx_id, index_key, _) => {
                        for k in index_key {
                            f(k)?;
                        }
                    }
                    JoinImplementation::Unimplemented => {} // No scalar exprs
                }
            }
            ArrangeBy { keys, .. } => {
                for key in keys {
                    for s in key {
                        f(s)?;
                    }
                }
            }
            Reduce {
                group_key,
                aggregates,
                ..
            } => {
                for s in group_key {
                    f(s)?;
                }
                for agg in aggregates {
                    f(&agg.expr)?;
                }
            }
            TopK { limit, .. } => {
                if let Some(s) = limit {
                    f(s)?;
                }
            }
            // These variants own no scalar expressions directly.
            Constant { .. }
            | Get { .. }
            | Let { .. }
            | LetRec { .. }
            | Project { .. }
            | Negate { .. }
            | Threshold { .. }
            | Union { .. } => (),
        }
        Ok(())
    }
1915
1916    /// Fallible immutable visitor for the [`MirScalarExpr`]s in the [`MirRelationExpr`] subtree
1917    /// rooted at `self`.
1918    ///
1919    /// Note that this does not recurse into [`MirRelationExpr`] subtrees within [`MirScalarExpr`]
1920    /// nodes.
1921    pub fn try_visit_scalars<F, E>(&self, f: &mut F) -> Result<(), E>
1922    where
1923        F: FnMut(&MirScalarExpr) -> Result<(), E>,
1924        E: From<RecursionLimitError>,
1925    {
1926        self.try_visit_post(&mut |expr| expr.try_visit_scalars_1(f))
1927    }
1928
1929    /// Infallible immutable visitor for the [`MirScalarExpr`]s in the [`MirRelationExpr`] subtree
1930    /// rooted at `self`.
1931    ///
1932    /// Note that this does not recurse into [`MirRelationExpr`] subtrees within [`MirScalarExpr`]
1933    /// nodes.
1934    pub fn visit_scalars<F>(&self, f: &mut F)
1935    where
1936        F: FnMut(&MirScalarExpr),
1937    {
1938        self.try_visit_scalars(&mut |s| {
1939            f(s);
1940            Ok::<_, RecursionLimitError>(())
1941        })
1942        .expect("Unexpected error in `visit_scalars` call");
1943    }
1944
1945    /// Clears the contents of `self` even if it's so deep that simply dropping it would cause a
1946    /// stack overflow in `drop_in_place`.
1947    ///
1948    /// Leaves `self` in an unusable state, so this should only be used if `self` is about to be
1949    /// dropped or otherwise overwritten.
1950    pub fn destroy_carefully(&mut self) {
1951        let mut todo = vec![self.take_dangerous()];
1952        while let Some(mut expr) = todo.pop() {
1953            for child in expr.children_mut() {
1954                todo.push(child.take_dangerous());
1955            }
1956        }
1957    }
1958
1959    /// Computes the size (total number of nodes) and maximum depth of a MirRelationExpr for
1960    /// debug printing purposes.
1961    pub fn debug_size_and_depth(&self) -> (usize, usize) {
1962        let mut size = 0;
1963        let mut max_depth = 0;
1964        let mut todo = vec![(self, 1)];
1965        while let Some((expr, depth)) = todo.pop() {
1966            size += 1;
1967            max_depth = max(max_depth, depth);
1968            todo.extend(expr.children().map(|c| (c, depth + 1)));
1969        }
1970        (size, max_depth)
1971    }
1972
1973    /// The MirRelationExpr is considered potentially expensive if and only if
1974    /// at least one of the following conditions is true:
1975    ///
1976    ///  - It contains at least one MirScalarExpr with a function call.
1977    ///  - It contains at least one FlatMap or a Reduce operator.
1978    ///  - We run into a RecursionLimitError while analyzing the expression.
1979    ///
1980    /// !!!WARNING!!!: this method has an HirRelationExpr counterpart. The two
1981    /// should be kept in sync w.r.t. HIR ⇒ MIR lowering!
1982    pub fn could_run_expensive_function(&self) -> bool {
1983        let mut result = false;
1984        use MirRelationExpr::*;
1985        use MirScalarExpr::*;
1986        if let Err(_) = self.try_visit_scalars::<_, RecursionLimitError>(&mut |scalar| {
1987            result |= match scalar {
1988                Column(_, _) | Literal(_, _) | CallUnmaterializable(_) | If { .. } => false,
1989                // Function calls are considered expensive
1990                CallUnary { .. } | CallBinary { .. } | CallVariadic { .. } => true,
1991            };
1992            Ok(())
1993        }) {
1994            // Conservatively set `true` if on RecursionLimitError.
1995            result = true;
1996        }
1997        self.visit_pre(|e: &MirRelationExpr| {
1998            // FlatMap has a table function; Reduce has an aggregate function.
1999            // Other constructs use MirScalarExpr to run a function
2000            result |= matches!(e, FlatMap { .. } | Reduce { .. });
2001        });
2002        result
2003    }
2004
2005    /// Hash to an u64 using Rust's default Hasher. (Which is a somewhat slower, but better Hasher
2006    /// than what `Hashable::hashed` would give us.)
2007    pub fn hash_to_u64(&self) -> u64 {
2008        let mut h = DefaultHasher::new();
2009        self.hash(&mut h);
2010        h.finish()
2011    }
2012}
2013
2014// `LetRec` helpers
2015impl MirRelationExpr {
2016    /// True when `expr` contains a `LetRec` AST node.
2017    pub fn is_recursive(self: &MirRelationExpr) -> bool {
2018        let mut worklist = vec![self];
2019        while let Some(expr) = worklist.pop() {
2020            if let MirRelationExpr::LetRec { .. } = expr {
2021                return true;
2022            }
2023            worklist.extend(expr.children());
2024        }
2025        false
2026    }
2027
2028    /// Return the number of sub-expressions in the tree (including self).
2029    pub fn size(&self) -> usize {
2030        let mut size = 0;
2031        self.visit_pre(|_| size += 1);
2032        size
2033    }
2034
2035    /// Given the ids and values of a LetRec, it computes the subset of ids that are used across
2036    /// iterations. These are those ids that have a reference before they are defined, when reading
2037    /// all the bindings in order.
2038    ///
2039    /// For example:
2040    /// ```SQL
2041    /// WITH MUTUALLY RECURSIVE
2042    ///     x(...) AS f(z),
2043    ///     y(...) AS g(x),
2044    ///     z(...) AS h(y)
2045    /// ...;
2046    /// ```
2047    /// Here, only `z` is returned, because `x` and `y` are referenced only within the same
2048    /// iteration.
2049    ///
2050    /// Note that if a binding references itself, that is also returned.
2051    pub fn recursive_ids(ids: &[LocalId], values: &[MirRelationExpr]) -> BTreeSet<LocalId> {
2052        let mut used_across_iterations = BTreeSet::new();
2053        let mut defined = BTreeSet::new();
2054        for (binding_id, value) in itertools::zip_eq(ids.iter(), values.iter()) {
2055            value.visit_pre(|expr| {
2056                if let MirRelationExpr::Get {
2057                    id: Local(get_id), ..
2058                } = expr
2059                {
2060                    // If we haven't seen a definition for it yet, then this will refer
2061                    // to the previous iteration.
2062                    // The `ids.contains` part of the condition is needed to exclude
2063                    // those ids that are not really in this LetRec, but either an inner
2064                    // or outer one.
2065                    if !defined.contains(get_id) && ids.contains(get_id) {
2066                        used_across_iterations.insert(*get_id);
2067                    }
2068                }
2069            });
2070            defined.insert(*binding_id);
2071        }
2072        used_across_iterations
2073    }
2074
    /// Replaces `LetRec` nodes with a stack of `Let` nodes.
    ///
    /// In each `Let` binding, uses of `Get` in `value` that are not at strictly greater
    /// identifiers are rewritten to be the constant collection.
    /// This makes the computation perform exactly "one" iteration.
    ///
    /// This was used only temporarily while developing `LetRec`.
    pub fn make_nonrecursive(self: &mut MirRelationExpr) {
        // Ids whose recursive uses must be replaced by empty constants.
        let mut deadlist = BTreeSet::new();
        let mut worklist = vec![self];
        while let Some(expr) = worklist.pop() {
            if let MirRelationExpr::LetRec {
                ids,
                values,
                limits: _,
                body,
            } = expr
            {
                // Pair each binding with its id; draining `values` lets us
                // take ownership while `ids` stays borrowed.
                let ids_values = values
                    .drain(..)
                    .zip_eq(ids)
                    .map(|(value, id)| (*id, value))
                    .collect::<Vec<_>>();
                *expr = body.take_dangerous();
                // Rebuild in reverse so the last binding becomes the innermost
                // `Let`, preserving the original visibility order.
                for (id, mut value) in ids_values.into_iter().rev() {
                    // Remove references to potentially recursive identifiers.
                    deadlist.insert(id);
                    value.visit_pre_mut(|e| {
                        if let MirRelationExpr::Get {
                            id: crate::Id::Local(id),
                            typ,
                            ..
                        } = e
                        {
                            let typ = typ.clone();
                            if deadlist.contains(id) {
                                // Replace the recursive `Get` with an empty
                                // constant of the same type.
                                e.take_safely(Some(typ));
                            }
                        }
                    });
                    *expr = MirRelationExpr::Let {
                        id,
                        value: Box::new(value),
                        body: Box::new(expr.take_dangerous()),
                    };
                }
                // Revisit the rewritten subtree for nested `LetRec`s.
                worklist.push(expr);
            } else {
                worklist.extend(expr.children_mut().rev());
            }
        }
    }
2127
2128    /// For each Id `id'` referenced in `expr`, if it is larger or equal than `id`, then record in
2129    /// `expire_whens` that when `id'` is redefined, then we should expire the information that
2130    /// we are holding about `id`. Call `do_expirations` with `expire_whens` at each Id
2131    /// redefinition.
2132    ///
2133    /// IMPORTANT: Relies on the numbering of Ids to be what `renumber_bindings` gives.
2134    pub fn collect_expirations(
2135        id: LocalId,
2136        expr: &MirRelationExpr,
2137        expire_whens: &mut BTreeMap<LocalId, Vec<LocalId>>,
2138    ) {
2139        expr.visit_pre(|e| {
2140            if let MirRelationExpr::Get {
2141                id: Id::Local(referenced_id),
2142                ..
2143            } = e
2144            {
2145                // The following check needs `renumber_bindings` to have run recently
2146                if referenced_id >= &id {
2147                    expire_whens
2148                        .entry(*referenced_id)
2149                        .or_insert_with(Vec::new)
2150                        .push(id);
2151                }
2152            }
2153        });
2154    }
2155
2156    /// Call this function when `id` is redefined. It modifies `id_infos` by removing information
2157    /// about such Ids whose information depended on the earlier definition of `id`, according to
2158    /// `expire_whens`. Also modifies `expire_whens`: it removes the currently processed entry.
2159    pub fn do_expirations<I>(
2160        redefined_id: LocalId,
2161        expire_whens: &mut BTreeMap<LocalId, Vec<LocalId>>,
2162        id_infos: &mut BTreeMap<LocalId, I>,
2163    ) -> Vec<(LocalId, I)> {
2164        let mut expired_infos = Vec::new();
2165        if let Some(expirations) = expire_whens.remove(&redefined_id) {
2166            for expired_id in expirations.into_iter() {
2167                if let Some(offer) = id_infos.remove(&expired_id) {
2168                    expired_infos.push((expired_id, offer));
2169                }
2170            }
2171        }
2172        expired_infos
2173    }
2174}
2175/// Augment non-nullability of columns, by observing either
2176/// 1. Predicates that explicitly test for null values, and
2177/// 2. Columns that if null would make a predicate be null.
2178pub fn non_nullable_columns(predicates: &[MirScalarExpr]) -> BTreeSet<usize> {
2179    let mut nonnull_required_columns = BTreeSet::new();
2180    for predicate in predicates {
2181        // Add any columns that being null would force the predicate to be null.
2182        // Should that happen, the row would be discarded.
2183        predicate.non_null_requirements(&mut nonnull_required_columns);
2184
2185        /*
2186        Test for explicit checks that a column is non-null.
2187
2188        This analysis is ad hoc, and will miss things:
2189
2190        materialize=> create table a(x int, y int);
2191        CREATE TABLE
2192        materialize=> explain with(types) select x from a where (y=x and y is not null) or x is not null;
2193        Optimized Plan
2194        --------------------------------------------------------------------------------------------------------
2195        Explained Query:                                                                                      +
2196        Project (#0) // { types: "(integer?)" }                                                             +
2197        Filter ((#0) IS NOT NULL OR ((#1) IS NOT NULL AND (#0 = #1))) // { types: "(integer?, integer?)" }+
2198        Get materialize.public.a // { types: "(integer?, integer?)" }                                   +
2199                                                                                  +
2200        Source materialize.public.a                                                                           +
2201        filter=(((#0) IS NOT NULL OR ((#1) IS NOT NULL AND (#0 = #1))))                                     +
2202
2203        (1 row)
2204        */
2205
2206        if let MirScalarExpr::CallUnary {
2207            func: UnaryFunc::Not(scalar_func::Not),
2208            expr,
2209        } = predicate
2210        {
2211            if let MirScalarExpr::CallUnary {
2212                func: UnaryFunc::IsNull(scalar_func::IsNull),
2213                expr,
2214            } = &**expr
2215            {
2216                if let MirScalarExpr::Column(c, _name) = &**expr {
2217                    nonnull_required_columns.insert(*c);
2218                }
2219            }
2220        }
2221    }
2222
2223    nonnull_required_columns
2224}
2225
2226impl CollectionPlan for MirRelationExpr {
2227    /// Collects the global collections that this MIR expression directly depends on, i.e., that it
2228    /// has a `Get` for. (It does _not_ traverse view definitions transitively.)
2229    ///
2230    /// !!!WARNING!!!: this method has an HirRelationExpr counterpart. The two
2231    /// should be kept in sync w.r.t. HIR ⇒ MIR lowering!
2232    fn depends_on_into(&self, out: &mut BTreeSet<GlobalId>) {
2233        if let MirRelationExpr::Get {
2234            id: Id::Global(id), ..
2235        } = self
2236        {
2237            out.insert(*id);
2238        }
2239        self.visit_children(|expr| expr.depends_on_into(out))
2240    }
2241}
2242
impl MirRelationExpr {
    /// Iterates through references to child expressions.
    ///
    /// Children are yielded in input order. The four `Option`/`Vec` slots below let every
    /// variant be expressed as one concrete (unboxed) chained iterator type: up to two
    /// leading children (`first`, `second`), a middle slice of children (`rest`), and one
    /// trailing child (`last`).
    pub fn children(&self) -> impl DoubleEndedIterator<Item = &Self> {
        let mut first = None;
        let mut second = None;
        let mut rest = None;
        let mut last = None;

        use MirRelationExpr::*;
        match self {
            // Leaves: no children.
            Constant { .. } | Get { .. } => (),
            Let { value, body, .. } => {
                first = Some(&**value);
                second = Some(&**body);
            }
            // `body` comes after all the bindings' values.
            LetRec { values, body, .. } => {
                rest = Some(values);
                last = Some(&**body);
            }
            // Unary operators: the single input.
            Project { input, .. }
            | Map { input, .. }
            | FlatMap { input, .. }
            | Filter { input, .. }
            | Reduce { input, .. }
            | TopK { input, .. }
            | Negate { input }
            | Threshold { input }
            | ArrangeBy { input, .. } => {
                first = Some(&**input);
            }
            Join { inputs, .. } => {
                rest = Some(inputs);
            }
            // `base` precedes the remaining union inputs.
            Union { base, inputs } => {
                first = Some(&**base);
                rest = Some(inputs);
            }
        }

        first
            .into_iter()
            .chain(second)
            .chain(rest.into_iter().flatten())
            .chain(last)
    }

    /// Iterates through mutable references to child expressions.
    ///
    /// Mirror of [`MirRelationExpr::children`]; keep the two in sync.
    pub fn children_mut(&mut self) -> impl DoubleEndedIterator<Item = &mut Self> {
        let mut first = None;
        let mut second = None;
        let mut rest = None;
        let mut last = None;

        use MirRelationExpr::*;
        match self {
            Constant { .. } | Get { .. } => (),
            Let { value, body, .. } => {
                first = Some(&mut **value);
                second = Some(&mut **body);
            }
            LetRec { values, body, .. } => {
                rest = Some(values);
                last = Some(&mut **body);
            }
            Project { input, .. }
            | Map { input, .. }
            | FlatMap { input, .. }
            | Filter { input, .. }
            | Reduce { input, .. }
            | TopK { input, .. }
            | Negate { input }
            | Threshold { input }
            | ArrangeBy { input, .. } => {
                first = Some(&mut **input);
            }
            Join { inputs, .. } => {
                rest = Some(inputs);
            }
            Union { base, inputs } => {
                first = Some(&mut **base);
                rest = Some(inputs);
            }
        }

        first
            .into_iter()
            .chain(second)
            .chain(rest.into_iter().flatten())
            .chain(last)
    }

    /// Iterative pre-order visitor.
    pub fn visit_pre<'a, F: FnMut(&'a Self)>(&'a self, mut f: F) {
        let mut worklist = vec![self];
        while let Some(expr) = worklist.pop() {
            f(expr);
            // Push children reversed so the leftmost child is popped (visited) first.
            worklist.extend(expr.children().rev());
        }
    }

    /// Iterative pre-order visitor.
    pub fn visit_pre_mut<F: FnMut(&mut Self)>(&mut self, mut f: F) {
        let mut worklist = vec![self];
        while let Some(expr) = worklist.pop() {
            f(expr);
            // Same reversal trick as `visit_pre` to preserve left-to-right order.
            worklist.extend(expr.children_mut().rev());
        }
    }

    /// Return a vector of references to the subtrees of this expression
    /// in post-visit order (the last element is `&self`).
    pub fn post_order_vec(&self) -> Vec<&Self> {
        // Classic two-pass trick: a pre-order traversal that expands children
        // right-to-left, reversed at the end, yields left-to-right post-order.
        let mut stack = vec![self];
        let mut result = vec![];
        while let Some(expr) = stack.pop() {
            result.push(expr);
            stack.extend(expr.children());
        }
        result.reverse();
        result
    }
}
2365
2366impl VisitChildren<Self> for MirRelationExpr {
2367    fn visit_children<F>(&self, mut f: F)
2368    where
2369        F: FnMut(&Self),
2370    {
2371        for child in self.children() {
2372            f(child)
2373        }
2374    }
2375
2376    fn visit_mut_children<F>(&mut self, mut f: F)
2377    where
2378        F: FnMut(&mut Self),
2379    {
2380        for child in self.children_mut() {
2381            f(child)
2382        }
2383    }
2384
2385    fn try_visit_children<F, E>(&self, mut f: F) -> Result<(), E>
2386    where
2387        F: FnMut(&Self) -> Result<(), E>,
2388        E: From<RecursionLimitError>,
2389    {
2390        for child in self.children() {
2391            f(child)?
2392        }
2393        Ok(())
2394    }
2395
2396    fn try_visit_mut_children<F, E>(&mut self, mut f: F) -> Result<(), E>
2397    where
2398        F: FnMut(&mut Self) -> Result<(), E>,
2399        E: From<RecursionLimitError>,
2400    {
2401        for child in self.children_mut() {
2402            f(child)?
2403        }
2404        Ok(())
2405    }
2406}
2407
/// Specification for an ordering by a column.
#[derive(
    Debug,
    Clone,
    Copy,
    Eq,
    PartialEq,
    Ord,
    PartialOrd,
    Serialize,
    Deserialize,
    Hash,
    MzReflect
)]
pub struct ColumnOrder {
    /// The column index.
    pub column: usize,
    /// Whether to sort in descending order.
    ///
    /// When absent from a serialized representation, defaults to `false` (ascending).
    #[serde(default)]
    pub desc: bool,
    /// Whether to sort nulls last.
    ///
    /// When absent from a serialized representation, defaults to `false` (nulls first).
    #[serde(default)]
    pub nulls_last: bool,
}
2432
impl Columnation for ColumnOrder {
    // `ColumnOrder` derives `Copy` (see above), so the trivial copy-based region suffices.
    type InnerRegion = CopyRegion<Self>;
}
2436
2437impl<'a, M> fmt::Display for HumanizedExpr<'a, ColumnOrder, M>
2438where
2439    M: HumanizerMode,
2440{
2441    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
2442        // If you modify this, then please also attend to Display for ColumnOrderWithExpr!
2443        write!(
2444            f,
2445            "{} {} {}",
2446            self.child(&self.expr.column),
2447            if self.expr.desc { "desc" } else { "asc" },
2448            if self.expr.nulls_last {
2449                "nulls_last"
2450            } else {
2451                "nulls_first"
2452            },
2453        )
2454    }
2455}
2456
/// Describes an aggregation expression.
#[derive(
    Clone,
    Debug,
    Eq,
    PartialEq,
    Ord,
    PartialOrd,
    Serialize,
    Deserialize,
    Hash,
    MzReflect
)]
pub struct AggregateExpr {
    /// Names the aggregation function.
    pub func: AggregateFunc,
    /// An expression which extracts from each row the input to `func`.
    pub expr: MirScalarExpr,
    /// Should the aggregation be applied only to distinct results in each group.
    ///
    /// When absent from a serialized representation, defaults to `false`.
    #[serde(default)]
    pub distinct: bool,
}
2479
2480impl AggregateExpr {
    /// Computes the SQL-level output column type of this `AggregateExpr`, given the
    /// SQL column types of the input relation. (See [`AggregateExpr::typ`] for the
    /// repr-level counterpart.)
    pub fn sql_typ(&self, column_types: &[SqlColumnType]) -> SqlColumnType {
        self.func.output_sql_type(self.expr.sql_typ(column_types))
    }
2485
    /// Computes the repr-level output column type of this `AggregateExpr`, given the
    /// repr column types of the input relation. (See [`AggregateExpr::sql_typ`] for the
    /// SQL-level counterpart.)
    pub fn typ(&self, column_types: &[ReprColumnType]) -> ReprColumnType {
        self.func.output_type(self.expr.typ(column_types))
    }
2490
    /// Returns whether the expression has a constant result.
    pub fn is_constant(&self) -> bool {
        match self.func {
            // Selection-style aggregates (min/max, `Any`, `All`, `Dummy`) count as constant
            // when their input expression is a literal.
            AggregateFunc::MaxNumeric
            | AggregateFunc::MaxInt16
            | AggregateFunc::MaxInt32
            | AggregateFunc::MaxInt64
            | AggregateFunc::MaxUInt16
            | AggregateFunc::MaxUInt32
            | AggregateFunc::MaxUInt64
            | AggregateFunc::MaxMzTimestamp
            | AggregateFunc::MaxFloat32
            | AggregateFunc::MaxFloat64
            | AggregateFunc::MaxBool
            | AggregateFunc::MaxString
            | AggregateFunc::MaxDate
            | AggregateFunc::MaxTimestamp
            | AggregateFunc::MaxTimestampTz
            | AggregateFunc::MaxInterval
            | AggregateFunc::MaxTime
            | AggregateFunc::MinNumeric
            | AggregateFunc::MinInt16
            | AggregateFunc::MinInt32
            | AggregateFunc::MinInt64
            | AggregateFunc::MinUInt16
            | AggregateFunc::MinUInt32
            | AggregateFunc::MinUInt64
            | AggregateFunc::MinMzTimestamp
            | AggregateFunc::MinFloat32
            | AggregateFunc::MinFloat64
            | AggregateFunc::MinBool
            | AggregateFunc::MinString
            | AggregateFunc::MinDate
            | AggregateFunc::MinTimestamp
            | AggregateFunc::MinTimestampTz
            | AggregateFunc::MinInterval
            | AggregateFunc::MinTime
            | AggregateFunc::Any
            | AggregateFunc::All
            | AggregateFunc::Dummy => self.expr.is_literal(),
            // `Count` counts as constant only when its input is a literal `NULL`
            // (SQL `count` ignores nulls, so the result is then a fixed 0).
            AggregateFunc::Count => self.expr.is_literal_null(),
            // For the remaining aggregates (sums, collection-building aggregates, and window
            // functions), the result counts as constant only when the input is a literal
            // error — presumably because the error then propagates regardless of the group.
            AggregateFunc::SumInt16
            | AggregateFunc::SumInt32
            | AggregateFunc::SumInt64
            | AggregateFunc::SumUInt16
            | AggregateFunc::SumUInt32
            | AggregateFunc::SumUInt64
            | AggregateFunc::SumFloat32
            | AggregateFunc::SumFloat64
            | AggregateFunc::SumNumeric
            | AggregateFunc::JsonbAgg { .. }
            | AggregateFunc::JsonbObjectAgg { .. }
            | AggregateFunc::MapAgg { .. }
            | AggregateFunc::ArrayConcat { .. }
            | AggregateFunc::ListConcat { .. }
            | AggregateFunc::StringAgg { .. }
            | AggregateFunc::RowNumber { .. }
            | AggregateFunc::Rank { .. }
            | AggregateFunc::DenseRank { .. }
            | AggregateFunc::LagLead { .. }
            | AggregateFunc::FirstValue { .. }
            | AggregateFunc::LastValue { .. }
            | AggregateFunc::FusedValueWindowFunc { .. }
            | AggregateFunc::WindowAggregate { .. }
            | AggregateFunc::FusedWindowAggregate { .. } => self.expr.is_literal_err(),
        }
    }
2558
2559    /// Returns an expression that computes `self` on a group that has exactly one row.
2560    /// Instead of performing a `Reduce` with `self`, one can perform a `Map` with the expression
2561    /// returned by `on_unique`, which is cheaper. (See `ReduceElision`.)
2562    pub fn on_unique(&self, input_type: &[ReprColumnType]) -> MirScalarExpr {
2563        match &self.func {
2564            // Count is one if non-null, and zero if null.
2565            AggregateFunc::Count => self
2566                .expr
2567                .clone()
2568                .call_unary(UnaryFunc::IsNull(crate::func::IsNull))
2569                .if_then_else(
2570                    MirScalarExpr::literal_ok(Datum::Int64(0), ReprScalarType::Int64),
2571                    MirScalarExpr::literal_ok(Datum::Int64(1), ReprScalarType::Int64),
2572                ),
2573
2574            // SumInt16 takes Int16s as input, but outputs Int64s.
2575            AggregateFunc::SumInt16 => self
2576                .expr
2577                .clone()
2578                .call_unary(UnaryFunc::CastInt16ToInt64(scalar_func::CastInt16ToInt64)),
2579
2580            // SumInt32 takes Int32s as input, but outputs Int64s.
2581            AggregateFunc::SumInt32 => self
2582                .expr
2583                .clone()
2584                .call_unary(UnaryFunc::CastInt32ToInt64(scalar_func::CastInt32ToInt64)),
2585
2586            // SumInt64 takes Int64s as input, but outputs numerics.
2587            AggregateFunc::SumInt64 => self.expr.clone().call_unary(UnaryFunc::CastInt64ToNumeric(
2588                scalar_func::CastInt64ToNumeric(Some(NumericMaxScale::ZERO)),
2589            )),
2590
2591            // SumUInt16 takes UInt16s as input, but outputs UInt64s.
2592            AggregateFunc::SumUInt16 => self.expr.clone().call_unary(
2593                UnaryFunc::CastUint16ToUint64(scalar_func::CastUint16ToUint64),
2594            ),
2595
2596            // SumUInt32 takes UInt32s as input, but outputs UInt64s.
2597            AggregateFunc::SumUInt32 => self.expr.clone().call_unary(
2598                UnaryFunc::CastUint32ToUint64(scalar_func::CastUint32ToUint64),
2599            ),
2600
2601            // SumUInt64 takes UInt64s as input, but outputs numerics.
2602            AggregateFunc::SumUInt64 => {
2603                self.expr.clone().call_unary(UnaryFunc::CastUint64ToNumeric(
2604                    scalar_func::CastUint64ToNumeric(Some(NumericMaxScale::ZERO)),
2605                ))
2606            }
2607
2608            // JsonbAgg takes _anything_ as input, but must output a Jsonb array.
2609            AggregateFunc::JsonbAgg { .. } => MirScalarExpr::call_variadic(
2610                JsonbBuildArray,
2611                vec![
2612                    self.expr
2613                        .clone()
2614                        .call_unary(UnaryFunc::RecordGet(scalar_func::RecordGet(0))),
2615                ],
2616            ),
2617
2618            // JsonbAgg takes _anything_ as input, but must output a Jsonb object.
2619            AggregateFunc::JsonbObjectAgg { .. } => {
2620                let record = self
2621                    .expr
2622                    .clone()
2623                    .call_unary(UnaryFunc::RecordGet(scalar_func::RecordGet(0)));
2624                MirScalarExpr::call_variadic(
2625                    JsonbBuildObject,
2626                    (0..2)
2627                        .map(|i| {
2628                            record
2629                                .clone()
2630                                .call_unary(UnaryFunc::RecordGet(scalar_func::RecordGet(i)))
2631                        })
2632                        .collect(),
2633                )
2634            }
2635
2636            AggregateFunc::MapAgg { value_type, .. } => {
2637                let record = self
2638                    .expr
2639                    .clone()
2640                    .call_unary(UnaryFunc::RecordGet(scalar_func::RecordGet(0)));
2641                MirScalarExpr::call_variadic(
2642                    MapBuild {
2643                        value_type: value_type.clone(),
2644                    },
2645                    (0..2)
2646                        .map(|i| {
2647                            record
2648                                .clone()
2649                                .call_unary(UnaryFunc::RecordGet(scalar_func::RecordGet(i)))
2650                        })
2651                        .collect(),
2652                )
2653            }
2654
2655            // StringAgg takes nested records of strings and outputs a string
2656            AggregateFunc::StringAgg { .. } => self
2657                .expr
2658                .clone()
2659                .call_unary(UnaryFunc::RecordGet(scalar_func::RecordGet(0)))
2660                .call_unary(UnaryFunc::RecordGet(scalar_func::RecordGet(0))),
2661
2662            // ListConcat and ArrayConcat take a single level of records and output a list containing exactly 1 element
2663            AggregateFunc::ListConcat { .. } | AggregateFunc::ArrayConcat { .. } => self
2664                .expr
2665                .clone()
2666                .call_unary(UnaryFunc::RecordGet(scalar_func::RecordGet(0))),
2667
2668            // RowNumber, Rank, DenseRank take a list of records and output a list containing exactly 1 element
2669            AggregateFunc::RowNumber { .. } => {
2670                self.on_unique_ranking_window_funcs(input_type, "?row_number?")
2671            }
2672            AggregateFunc::Rank { .. } => self.on_unique_ranking_window_funcs(input_type, "?rank?"),
2673            AggregateFunc::DenseRank { .. } => {
2674                self.on_unique_ranking_window_funcs(input_type, "?dense_rank?")
2675            }
2676
2677            // The input type for LagLead is ((OriginalRow, (InputValue, Offset, Default)), OrderByExprs...)
2678            AggregateFunc::LagLead { lag_lead, .. } => {
2679                let tuple = self
2680                    .expr
2681                    .clone()
2682                    .call_unary(UnaryFunc::RecordGet(scalar_func::RecordGet(0)));
2683
2684                // Get the overall return type
2685                let return_type_with_orig_row = self
2686                    .typ(input_type)
2687                    .scalar_type
2688                    .unwrap_list_element_type()
2689                    .clone();
2690                let lag_lead_return_type =
2691                    return_type_with_orig_row.unwrap_record_element_type()[0].clone();
2692
2693                // Extract the original row
2694                let original_row = tuple
2695                    .clone()
2696                    .call_unary(UnaryFunc::RecordGet(scalar_func::RecordGet(0)));
2697
2698                // Extract the encoded args
2699                let encoded_args =
2700                    tuple.call_unary(UnaryFunc::RecordGet(scalar_func::RecordGet(1)));
2701
2702                let (result_expr, column_name) =
2703                    Self::on_unique_lag_lead(lag_lead, encoded_args, lag_lead_return_type.clone());
2704
2705                MirScalarExpr::call_variadic(
2706                    ListCreate {
2707                        elem_type: SqlScalarType::from_repr(&return_type_with_orig_row),
2708                    },
2709                    vec![MirScalarExpr::call_variadic(
2710                        RecordCreate {
2711                            field_names: vec![column_name, ColumnName::from("?record?")],
2712                        },
2713                        vec![result_expr, original_row],
2714                    )],
2715                )
2716            }
2717
2718            // The input type for FirstValue is ((OriginalRow, InputValue), OrderByExprs...)
2719            AggregateFunc::FirstValue { window_frame, .. } => {
2720                let tuple = self
2721                    .expr
2722                    .clone()
2723                    .call_unary(UnaryFunc::RecordGet(scalar_func::RecordGet(0)));
2724
2725                // Get the overall return type
2726                let return_type_with_orig_row = self
2727                    .typ(input_type)
2728                    .scalar_type
2729                    .unwrap_list_element_type()
2730                    .clone();
2731                let first_value_return_type =
2732                    return_type_with_orig_row.unwrap_record_element_type()[0].clone();
2733
2734                // Extract the original row
2735                let original_row = tuple
2736                    .clone()
2737                    .call_unary(UnaryFunc::RecordGet(scalar_func::RecordGet(0)));
2738
2739                // Extract the input value
2740                let arg = tuple.call_unary(UnaryFunc::RecordGet(scalar_func::RecordGet(1)));
2741
2742                let (result_expr, column_name) = Self::on_unique_first_value_last_value(
2743                    window_frame,
2744                    arg,
2745                    first_value_return_type,
2746                );
2747
2748                MirScalarExpr::call_variadic(
2749                    ListCreate {
2750                        elem_type: SqlScalarType::from_repr(&return_type_with_orig_row),
2751                    },
2752                    vec![MirScalarExpr::call_variadic(
2753                        RecordCreate {
2754                            field_names: vec![column_name, ColumnName::from("?record?")],
2755                        },
2756                        vec![result_expr, original_row],
2757                    )],
2758                )
2759            }
2760
2761            // The input type for LastValue is ((OriginalRow, InputValue), OrderByExprs...)
2762            AggregateFunc::LastValue { window_frame, .. } => {
2763                let tuple = self
2764                    .expr
2765                    .clone()
2766                    .call_unary(UnaryFunc::RecordGet(scalar_func::RecordGet(0)));
2767
2768                // Get the overall return type
2769                let return_type_with_orig_row = self
2770                    .typ(input_type)
2771                    .scalar_type
2772                    .unwrap_list_element_type()
2773                    .clone();
2774                let last_value_return_type =
2775                    return_type_with_orig_row.unwrap_record_element_type()[0].clone();
2776
2777                // Extract the original row
2778                let original_row = tuple
2779                    .clone()
2780                    .call_unary(UnaryFunc::RecordGet(scalar_func::RecordGet(0)));
2781
2782                // Extract the input value
2783                let arg = tuple.call_unary(UnaryFunc::RecordGet(scalar_func::RecordGet(1)));
2784
2785                let (result_expr, column_name) = Self::on_unique_first_value_last_value(
2786                    window_frame,
2787                    arg,
2788                    last_value_return_type,
2789                );
2790
2791                MirScalarExpr::call_variadic(
2792                    ListCreate {
2793                        elem_type: SqlScalarType::from_repr(&return_type_with_orig_row),
2794                    },
2795                    vec![MirScalarExpr::call_variadic(
2796                        RecordCreate {
2797                            field_names: vec![column_name, ColumnName::from("?record?")],
2798                        },
2799                        vec![result_expr, original_row],
2800                    )],
2801                )
2802            }
2803
2804            // The input type for window aggs is ((OriginalRow, InputValue), OrderByExprs...)
2805            // See an example MIR in `window_func_applied_to`.
2806            AggregateFunc::WindowAggregate {
2807                wrapped_aggregate,
2808                window_frame,
2809                order_by: _,
2810            } => {
2811                // TODO: deduplicate code between the various window function cases.
2812
2813                let tuple = self
2814                    .expr
2815                    .clone()
2816                    .call_unary(UnaryFunc::RecordGet(scalar_func::RecordGet(0)));
2817
2818                // Get the overall return type
2819                let return_type = self
2820                    .typ(input_type)
2821                    .scalar_type
2822                    .unwrap_list_element_type()
2823                    .clone();
2824                let window_agg_return_type = return_type.unwrap_record_element_type()[0].clone();
2825
2826                // Extract the original row
2827                let original_row = tuple
2828                    .clone()
2829                    .call_unary(UnaryFunc::RecordGet(scalar_func::RecordGet(0)));
2830
2831                // Extract the input value
2832                let arg_expr = tuple.call_unary(UnaryFunc::RecordGet(scalar_func::RecordGet(1)));
2833
2834                let (result, column_name) = Self::on_unique_window_agg(
2835                    window_frame,
2836                    arg_expr,
2837                    input_type,
2838                    window_agg_return_type,
2839                    wrapped_aggregate,
2840                );
2841
2842                MirScalarExpr::call_variadic(
2843                    ListCreate {
2844                        elem_type: SqlScalarType::from_repr(&return_type),
2845                    },
2846                    vec![MirScalarExpr::call_variadic(
2847                        RecordCreate {
2848                            field_names: vec![column_name, ColumnName::from("?record?")],
2849                        },
2850                        vec![result, original_row],
2851                    )],
2852                )
2853            }
2854
2855            // The input type is ((OriginalRow, (Arg1, Arg2, ...)), OrderByExprs...)
2856            AggregateFunc::FusedWindowAggregate {
2857                wrapped_aggregates,
2858                order_by: _,
2859                window_frame,
2860            } => {
2861                // Throw away OrderByExprs
2862                let tuple = self
2863                    .expr
2864                    .clone()
2865                    .call_unary(UnaryFunc::RecordGet(scalar_func::RecordGet(0)));
2866
2867                // Extract the original row
2868                let original_row = tuple
2869                    .clone()
2870                    .call_unary(UnaryFunc::RecordGet(scalar_func::RecordGet(0)));
2871
2872                // Extract the args of the fused call
2873                let all_args = tuple.call_unary(UnaryFunc::RecordGet(scalar_func::RecordGet(1)));
2874
2875                let return_type_with_orig_row = self
2876                    .typ(input_type)
2877                    .scalar_type
2878                    .unwrap_list_element_type()
2879                    .clone();
2880
2881                let all_func_return_types =
2882                    return_type_with_orig_row.unwrap_record_element_type()[0].clone();
2883                let mut func_result_exprs = Vec::new();
2884                let mut col_names = Vec::new();
2885                for (idx, wrapped_aggr) in wrapped_aggregates.iter().enumerate() {
2886                    let arg = all_args
2887                        .clone()
2888                        .call_unary(UnaryFunc::RecordGet(scalar_func::RecordGet(idx)));
2889                    let return_type =
2890                        all_func_return_types.unwrap_record_element_type()[idx].clone();
2891                    let (result, column_name) = Self::on_unique_window_agg(
2892                        window_frame,
2893                        arg,
2894                        input_type,
2895                        return_type,
2896                        wrapped_aggr,
2897                    );
2898                    func_result_exprs.push(result);
2899                    col_names.push(column_name);
2900                }
2901
2902                MirScalarExpr::call_variadic(
2903                    ListCreate {
2904                        elem_type: SqlScalarType::from_repr(&return_type_with_orig_row),
2905                    },
2906                    vec![MirScalarExpr::call_variadic(
2907                        RecordCreate {
2908                            field_names: vec![
2909                                ColumnName::from("?fused_window_aggr?"),
2910                                ColumnName::from("?record?"),
2911                            ],
2912                        },
2913                        vec![
2914                            MirScalarExpr::call_variadic(
2915                                RecordCreate {
2916                                    field_names: col_names,
2917                                },
2918                                func_result_exprs,
2919                            ),
2920                            original_row,
2921                        ],
2922                    )],
2923                )
2924            }
2925
2926            // The input type is ((OriginalRow, (Args1, Args2, ...)), OrderByExprs...)
2927            AggregateFunc::FusedValueWindowFunc {
2928                funcs,
2929                order_by: outer_order_by,
2930            } => {
2931                // Throw away OrderByExprs
2932                let tuple = self
2933                    .expr
2934                    .clone()
2935                    .call_unary(UnaryFunc::RecordGet(scalar_func::RecordGet(0)));
2936
2937                // Extract the original row
2938                let original_row = tuple
2939                    .clone()
2940                    .call_unary(UnaryFunc::RecordGet(scalar_func::RecordGet(0)));
2941
2942                // Extract the encoded args of the fused call
2943                let all_encoded_args =
2944                    tuple.call_unary(UnaryFunc::RecordGet(scalar_func::RecordGet(1)));
2945
2946                let return_type_with_orig_row = self
2947                    .typ(input_type)
2948                    .scalar_type
2949                    .unwrap_list_element_type()
2950                    .clone();
2951
2952                let all_func_return_types =
2953                    return_type_with_orig_row.unwrap_record_element_type()[0].clone();
2954                let mut func_result_exprs = Vec::new();
2955                let mut col_names = Vec::new();
2956                for (idx, func) in funcs.iter().enumerate() {
2957                    let args_for_func = all_encoded_args
2958                        .clone()
2959                        .call_unary(UnaryFunc::RecordGet(scalar_func::RecordGet(idx)));
2960                    let return_type_for_func =
2961                        all_func_return_types.unwrap_record_element_type()[idx].clone();
2962                    let (result, column_name) = match func {
2963                        AggregateFunc::LagLead {
2964                            lag_lead,
2965                            order_by,
2966                            ignore_nulls: _,
2967                        } => {
2968                            assert_eq!(order_by, outer_order_by);
2969                            Self::on_unique_lag_lead(lag_lead, args_for_func, return_type_for_func)
2970                        }
2971                        AggregateFunc::FirstValue {
2972                            window_frame,
2973                            order_by,
2974                        } => {
2975                            assert_eq!(order_by, outer_order_by);
2976                            Self::on_unique_first_value_last_value(
2977                                window_frame,
2978                                args_for_func,
2979                                return_type_for_func,
2980                            )
2981                        }
2982                        AggregateFunc::LastValue {
2983                            window_frame,
2984                            order_by,
2985                        } => {
2986                            assert_eq!(order_by, outer_order_by);
2987                            Self::on_unique_first_value_last_value(
2988                                window_frame,
2989                                args_for_func,
2990                                return_type_for_func,
2991                            )
2992                        }
2993                        _ => panic!("unknown function in FusedValueWindowFunc"),
2994                    };
2995                    func_result_exprs.push(result);
2996                    col_names.push(column_name);
2997                }
2998
2999                MirScalarExpr::call_variadic(
3000                    ListCreate {
3001                        elem_type: SqlScalarType::from_repr(&return_type_with_orig_row),
3002                    },
3003                    vec![MirScalarExpr::call_variadic(
3004                        RecordCreate {
3005                            field_names: vec![
3006                                ColumnName::from("?fused_value_window_func?"),
3007                                ColumnName::from("?record?"),
3008                            ],
3009                        },
3010                        vec![
3011                            MirScalarExpr::call_variadic(
3012                                RecordCreate {
3013                                    field_names: col_names,
3014                                },
3015                                func_result_exprs,
3016                            ),
3017                            original_row,
3018                        ],
3019                    )],
3020                )
3021            }
3022
3023            // All other variants should return the argument to the aggregation.
3024            AggregateFunc::MaxNumeric
3025            | AggregateFunc::MaxInt16
3026            | AggregateFunc::MaxInt32
3027            | AggregateFunc::MaxInt64
3028            | AggregateFunc::MaxUInt16
3029            | AggregateFunc::MaxUInt32
3030            | AggregateFunc::MaxUInt64
3031            | AggregateFunc::MaxMzTimestamp
3032            | AggregateFunc::MaxFloat32
3033            | AggregateFunc::MaxFloat64
3034            | AggregateFunc::MaxBool
3035            | AggregateFunc::MaxString
3036            | AggregateFunc::MaxDate
3037            | AggregateFunc::MaxTimestamp
3038            | AggregateFunc::MaxTimestampTz
3039            | AggregateFunc::MaxInterval
3040            | AggregateFunc::MaxTime
3041            | AggregateFunc::MinNumeric
3042            | AggregateFunc::MinInt16
3043            | AggregateFunc::MinInt32
3044            | AggregateFunc::MinInt64
3045            | AggregateFunc::MinUInt16
3046            | AggregateFunc::MinUInt32
3047            | AggregateFunc::MinUInt64
3048            | AggregateFunc::MinMzTimestamp
3049            | AggregateFunc::MinFloat32
3050            | AggregateFunc::MinFloat64
3051            | AggregateFunc::MinBool
3052            | AggregateFunc::MinString
3053            | AggregateFunc::MinDate
3054            | AggregateFunc::MinTimestamp
3055            | AggregateFunc::MinTimestampTz
3056            | AggregateFunc::MinInterval
3057            | AggregateFunc::MinTime
3058            | AggregateFunc::SumFloat32
3059            | AggregateFunc::SumFloat64
3060            | AggregateFunc::SumNumeric
3061            | AggregateFunc::Any
3062            | AggregateFunc::All
3063            | AggregateFunc::Dummy => self.expr.clone(),
3064        }
3065    }
3066
    /// `on_unique` for ROW_NUMBER, RANK, DENSE_RANK
    ///
    /// With a unique key the window contains exactly one row, so each of these
    /// ranking functions evaluates to `1`. The output mirrors the general
    /// window-function result shape: a one-element list containing a record of
    /// `(<rank>, <original row>)`.
    fn on_unique_ranking_window_funcs(
        &self,
        input_type: &[ReprColumnType],
        col_name: &str,
    ) -> MirScalarExpr {
        let sql_input_type: Vec<SqlColumnType> =
            input_type.iter().map(SqlColumnType::from_repr).collect();
        let list = self
            .expr
            .clone()
            // extract the list within the record
            .call_unary(UnaryFunc::RecordGet(scalar_func::RecordGet(0)));

        // extract the expression within the list
        // (the literal 1 here is the list index; presumably `ListIndex` is
        // 1-based — confirm against its evaluation)
        let record = MirScalarExpr::call_variadic(
            ListIndex,
            vec![
                list,
                MirScalarExpr::literal_ok(Datum::Int64(1), ReprScalarType::Int64),
            ],
        );

        // Rebuild the one-element list-of-record result, with the rank column
        // (always 1 for a single-row window) prepended to the original record.
        MirScalarExpr::call_variadic(
            ListCreate {
                elem_type: self
                    .sql_typ(&sql_input_type)
                    .scalar_type
                    .unwrap_list_element_type()
                    .clone(),
            },
            vec![MirScalarExpr::call_variadic(
                RecordCreate {
                    field_names: vec![ColumnName::from(col_name), ColumnName::from("?record?")],
                },
                vec![
                    // The rank of the single row is always 1.
                    MirScalarExpr::literal_ok(Datum::Int64(1), ReprScalarType::Int64),
                    record,
                ],
            )],
        )
    }
3109
    /// `on_unique` for `lag` and `lead`
    ///
    /// The window contains only the current row, so `lag`/`lead` can never see
    /// a neighboring row: the result is the row's own value when the offset is
    /// 0, the default value for any other non-null offset, and null when the
    /// offset itself is null.
    fn on_unique_lag_lead(
        lag_lead: &LagLeadType,
        encoded_args: MirScalarExpr,
        return_type: ReprScalarType,
    ) -> (MirScalarExpr, ColumnName) {
        // The arguments arrive encoded as a record of (expr, offset, default).
        let expr = encoded_args
            .clone()
            .call_unary(UnaryFunc::RecordGet(scalar_func::RecordGet(0)));
        let offset = encoded_args
            .clone()
            .call_unary(UnaryFunc::RecordGet(scalar_func::RecordGet(1)));
        let default_value =
            encoded_args.call_unary(UnaryFunc::RecordGet(scalar_func::RecordGet(2)));

        // In this case, the window always has only one element, so if the offset is not null and
        // not zero, the default value should be returned instead.
        let value = offset
            .clone()
            .call_binary(
                MirScalarExpr::literal_ok(Datum::Int32(0), ReprScalarType::Int32),
                crate::func::Eq,
            )
            .if_then_else(expr, default_value);
        // A null offset yields a null result.
        let result_expr = offset
            .call_unary(UnaryFunc::IsNull(crate::func::IsNull))
            .if_then_else(MirScalarExpr::literal_null(return_type), value);

        // Column name used in EXPLAIN output.
        let column_name = ColumnName::from(match lag_lead {
            LagLeadType::Lag => "?lag?",
            LagLeadType::Lead => "?lead?",
        });

        (result_expr, column_name)
    }
3145
3146    /// `on_unique` for `first_value` and `last_value`
3147    fn on_unique_first_value_last_value(
3148        window_frame: &WindowFrame,
3149        arg: MirScalarExpr,
3150        return_type: ReprScalarType,
3151    ) -> (MirScalarExpr, ColumnName) {
3152        // If the window frame includes the current (single) row, return its value, null otherwise
3153        let result_expr = if window_frame.includes_current_row() {
3154            arg
3155        } else {
3156            MirScalarExpr::literal_null(return_type)
3157        };
3158        (result_expr, ColumnName::from("?first_value?"))
3159    }
3160
3161    /// `on_unique` for window aggregations
3162    fn on_unique_window_agg(
3163        window_frame: &WindowFrame,
3164        arg_expr: MirScalarExpr,
3165        input_type: &[ReprColumnType],
3166        return_type: ReprScalarType,
3167        wrapped_aggr: &AggregateFunc,
3168    ) -> (MirScalarExpr, ColumnName) {
3169        // If the window frame includes the current (single) row, evaluate the wrapped aggregate on
3170        // that row. Otherwise, return the default value for the aggregate.
3171        let result_expr = if window_frame.includes_current_row() {
3172            AggregateExpr {
3173                func: wrapped_aggr.clone(),
3174                expr: arg_expr,
3175                distinct: false, // We have just one input element; DISTINCT doesn't matter.
3176            }
3177            .on_unique(input_type)
3178        } else {
3179            MirScalarExpr::literal_ok(wrapped_aggr.default(), return_type)
3180        };
3181        (result_expr, ColumnName::from("?window_agg?"))
3182    }
3183
3184    /// Returns whether the expression is COUNT(*) or not.  Note that
3185    /// when we define the count builtin in sql::func, we convert
3186    /// COUNT(*) to COUNT(true), making it indistinguishable from
3187    /// literal COUNT(true), but we prefer to consider this as the
3188    /// former.
3189    ///
3190    /// (HIR has the same `is_count_asterisk`.)
3191    pub fn is_count_asterisk(&self) -> bool {
3192        self.func == AggregateFunc::Count && self.expr.is_literal_true() && !self.distinct
3193    }
3194}
3195
/// Describe a join implementation in dataflow.
///
/// Chosen during join planning; the default value is
/// [`JoinImplementation::Unimplemented`].
#[derive(
    Clone,
    Debug,
    Eq,
    PartialEq,
    Ord,
    PartialOrd,
    Serialize,
    Deserialize,
    Hash,
    MzReflect
)]
pub enum JoinImplementation {
    /// Perform a sequence of binary differential dataflow joins.
    ///
    /// The first argument indicates
    /// 1) the index of the starting collection,
    /// 2) if it should be arranged, the keys to arrange it by, and
    /// 3) the characteristics of the starting collection (for EXPLAINing).
    /// The sequence that follows lists other relation indexes, and the key for
    /// the arrangement we should use when joining it in.
    /// The JoinInputCharacteristics are for EXPLAINing the characteristics that
    /// were used for join ordering.
    ///
    /// Each collection index should occur exactly once, either as the starting collection
    /// or somewhere in the list.
    Differential(
        (
            usize,
            Option<Vec<MirScalarExpr>>,
            Option<JoinInputCharacteristics>,
        ),
        Vec<(usize, Vec<MirScalarExpr>, Option<JoinInputCharacteristics>)>,
    ),
    /// Perform independent delta query dataflows for each input.
    ///
    /// The argument is a sequence of plans, for the input collections in order.
    /// Each plan starts from the corresponding index, and then in sequence joins
    /// against collections identified by index and with the specified arrangement key.
    /// The JoinInputCharacteristics are for EXPLAINing the characteristics that were
    /// used for join ordering.
    DeltaQuery(Vec<Vec<(usize, Vec<MirScalarExpr>, Option<JoinInputCharacteristics>)>>),
    /// Join a user-created index with a constant collection to speed up the evaluation of a
    /// predicate such as `(f1 = 3 AND f2 = 5) OR (f1 = 7 AND f2 = 9)`.
    /// This gets translated to a Differential join during MIR -> LIR lowering, but we still want
    /// to represent it in MIR, because the fast path detection wants to match on this.
    ///
    /// Consists of (`<coll_id>`, `<index_id>`, `<index_key>`, `<constants>`)
    IndexedFilter(
        GlobalId,
        GlobalId,
        Vec<MirScalarExpr>,
        #[mzreflect(ignore)] Vec<Row>,
    ),
    /// No implementation yet selected.
    Unimplemented,
}
3254
3255impl Default for JoinImplementation {
3256    fn default() -> Self {
3257        JoinImplementation::Unimplemented
3258    }
3259}
3260
3261impl JoinImplementation {
3262    /// Returns `true` iff the value is not [`JoinImplementation::Unimplemented`].
3263    pub fn is_implemented(&self) -> bool {
3264        match self {
3265            Self::Unimplemented => false,
3266            _ => true,
3267        }
3268    }
3269
3270    /// Returns an optional implementation name if the value is not [`JoinImplementation::Unimplemented`].
3271    pub fn name(&self) -> Option<&'static str> {
3272        match self {
3273            Self::Differential(..) => Some("differential"),
3274            Self::DeltaQuery(..) => Some("delta"),
3275            Self::IndexedFilter(..) => Some("indexed_filter"),
3276            Self::Unimplemented => None,
3277        }
3278    }
3279}
3280
/// Characteristics of a join order candidate collection.
///
/// A candidate is described by a collection and a key, and may have various liabilities.
/// Primarily, the candidate may risk substantial inflation of records, which is something
/// that concerns us greatly. Additionally, the candidate may be unarranged, and we would
/// prefer candidates that do not require additional memory. Finally, we prefer lower id
/// collections in the interest of consistent tie-breaking. For more characteristics, see
/// comments on individual fields.
///
/// This has more than one version. `new` instantiates the appropriate version based on a
/// feature flag.
///
/// NOTE(review): the derived `Ord` compares the variant tag first, so any `V1`
/// orders before any `V2`; presumably all candidates of one join are created
/// with the same flag value and thus the same version — confirm.
#[derive(
    Eq,
    PartialEq,
    Ord,
    PartialOrd,
    Debug,
    Clone,
    Serialize,
    Deserialize,
    Hash,
    MzReflect
)]
pub enum JoinInputCharacteristics {
    /// Old version, with `enable_join_prioritize_arranged` turned off.
    V1(JoinInputCharacteristicsV1),
    /// Newer version, with `enable_join_prioritize_arranged` turned on.
    V2(JoinInputCharacteristicsV2),
}
3310
3311impl JoinInputCharacteristics {
3312    /// Creates a new instance with the given characteristics.
3313    pub fn new(
3314        unique_key: bool,
3315        key_length: usize,
3316        arranged: bool,
3317        cardinality: Option<usize>,
3318        filters: FilterCharacteristics,
3319        input: usize,
3320        enable_join_prioritize_arranged: bool,
3321    ) -> Self {
3322        if enable_join_prioritize_arranged {
3323            Self::V2(JoinInputCharacteristicsV2::new(
3324                unique_key,
3325                key_length,
3326                arranged,
3327                cardinality,
3328                filters,
3329                input,
3330            ))
3331        } else {
3332            Self::V1(JoinInputCharacteristicsV1::new(
3333                unique_key,
3334                key_length,
3335                arranged,
3336                cardinality,
3337                filters,
3338                input,
3339            ))
3340        }
3341    }
3342
3343    /// Turns the instance into a String to be printed in EXPLAIN.
3344    pub fn explain(&self) -> String {
3345        match self {
3346            Self::V1(jic) => jic.explain(),
3347            Self::V2(jic) => jic.explain(),
3348        }
3349    }
3350
3351    /// Whether the join input described by `self` is arranged.
3352    pub fn arranged(&self) -> bool {
3353        match self {
3354            Self::V1(jic) => jic.arranged,
3355            Self::V2(jic) => jic.arranged,
3356        }
3357    }
3358
3359    /// Returns the `FilterCharacteristics` for the join input described by `self`.
3360    pub fn filters(&mut self) -> &mut FilterCharacteristics {
3361        match self {
3362            Self::V1(jic) => &mut jic.filters,
3363            Self::V2(jic) => &mut jic.filters,
3364        }
3365    }
3366}
3367
/// Newer version of `JoinInputCharacteristics`, with `enable_join_prioritize_arranged` turned on.
///
/// Field order matters: the derived `Ord`/`PartialOrd` compares fields top to
/// bottom, so earlier fields dominate the candidate ranking.
#[derive(
    Eq,
    PartialEq,
    Ord,
    PartialOrd,
    Debug,
    Clone,
    Serialize,
    Deserialize,
    Hash,
    MzReflect
)]
pub struct JoinInputCharacteristicsV2 {
    /// An excellent indication that record count will not increase.
    pub unique_key: bool,
    /// Cross joins are bad.
    /// (`key_length > 0` also implies that it is not a cross join. However, we need to note cross
    /// joins in a separate field, because not being a cross join is more important than `arranged`,
    /// but otherwise `key_length` is less important than `arranged`.)
    pub not_cross: bool,
    /// Indicates that there will be no additional in-memory footprint.
    pub arranged: bool,
    /// A weaker signal that record count will not increase.
    pub key_length: usize,
    /// Estimated cardinality (lower is better; wrapped in `Reverse` so that a
    /// smaller estimate compares as greater, i.e., is preferred).
    pub cardinality: Option<std::cmp::Reverse<usize>>,
    /// Characteristics of the filter that is applied at this input.
    pub filters: FilterCharacteristics,
    /// We want to prefer input earlier in the input list, for stability of ordering.
    pub input: std::cmp::Reverse<usize>,
}
3400
3401impl JoinInputCharacteristicsV2 {
3402    /// Creates a new instance with the given characteristics.
3403    pub fn new(
3404        unique_key: bool,
3405        key_length: usize,
3406        arranged: bool,
3407        cardinality: Option<usize>,
3408        filters: FilterCharacteristics,
3409        input: usize,
3410    ) -> Self {
3411        Self {
3412            unique_key,
3413            not_cross: key_length > 0,
3414            arranged,
3415            key_length,
3416            cardinality: cardinality.map(std::cmp::Reverse),
3417            filters,
3418            input: std::cmp::Reverse(input),
3419        }
3420    }
3421
3422    /// Turns the instance into a String to be printed in EXPLAIN.
3423    pub fn explain(&self) -> String {
3424        let mut e = "".to_owned();
3425        if self.unique_key {
3426            e.push_str("U");
3427        }
3428        // Don't need to print `not_cross`, because that is visible in the printed key.
3429        // if !self.not_cross {
3430        //     e.push_str("C");
3431        // }
3432        for _ in 0..self.key_length {
3433            e.push_str("K");
3434        }
3435        if self.arranged {
3436            e.push_str("A");
3437        }
3438        if let Some(std::cmp::Reverse(cardinality)) = self.cardinality {
3439            e.push_str(&format!("|{cardinality}|"));
3440        }
3441        e.push_str(&self.filters.explain());
3442        e
3443    }
3444}
3445
/// Old version of `JoinInputCharacteristics`, with `enable_join_prioritize_arranged` turned off.
///
/// Field order matters: the derived `Ord`/`PartialOrd` compares fields top to
/// bottom, so earlier fields dominate the candidate ranking.
#[derive(
    Eq,
    PartialEq,
    Ord,
    PartialOrd,
    Debug,
    Clone,
    Serialize,
    Deserialize,
    Hash,
    MzReflect
)]
pub struct JoinInputCharacteristicsV1 {
    /// An excellent indication that record count will not increase.
    pub unique_key: bool,
    /// A weaker signal that record count will not increase.
    pub key_length: usize,
    /// Indicates that there will be no additional in-memory footprint.
    pub arranged: bool,
    /// Estimated cardinality (lower is better; wrapped in `Reverse` so that a
    /// smaller estimate compares as greater, i.e., is preferred).
    pub cardinality: Option<std::cmp::Reverse<usize>>,
    /// Characteristics of the filter that is applied at this input.
    pub filters: FilterCharacteristics,
    /// We want to prefer input earlier in the input list, for stability of ordering.
    pub input: std::cmp::Reverse<usize>,
}
3473
3474impl JoinInputCharacteristicsV1 {
3475    /// Creates a new instance with the given characteristics.
3476    pub fn new(
3477        unique_key: bool,
3478        key_length: usize,
3479        arranged: bool,
3480        cardinality: Option<usize>,
3481        filters: FilterCharacteristics,
3482        input: usize,
3483    ) -> Self {
3484        Self {
3485            unique_key,
3486            key_length,
3487            arranged,
3488            cardinality: cardinality.map(std::cmp::Reverse),
3489            filters,
3490            input: std::cmp::Reverse(input),
3491        }
3492    }
3493
3494    /// Turns the instance into a String to be printed in EXPLAIN.
3495    pub fn explain(&self) -> String {
3496        let mut e = "".to_owned();
3497        if self.unique_key {
3498            e.push_str("U");
3499        }
3500        for _ in 0..self.key_length {
3501            e.push_str("K");
3502        }
3503        if self.arranged {
3504            e.push_str("A");
3505        }
3506        if let Some(std::cmp::Reverse(cardinality)) = self.cardinality {
3507            e.push_str(&format!("|{cardinality}|"));
3508        }
3509        e.push_str(&self.filters.explain());
3510        e
3511    }
3512}
3513
/// Instructions for finishing the result of a query.
///
/// The primary reason for the existence of this structure and attendant code
/// is that SQL's ORDER BY requires sorting rows (as already implied by the
/// keywords), whereas much of the rest of SQL is defined in terms of unordered
/// multisets. But as it turns out, the same idea can be used to optimize
/// trivial peeks.
///
/// The generic parameters are for accommodating prepared statement parameters in
/// `limit` and `offset`: the planner can hold these fields as HirScalarExpr long enough to call
/// `bind_parameters` on them.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
pub struct RowSetFinishing<L = NonNeg<i64>, O = usize> {
    /// Order rows by the given columns.
    pub order_by: Vec<ColumnOrder>,
    /// Include only as many rows (after offset).
    pub limit: Option<L>,
    /// Omit as many rows (applied before `limit`).
    pub offset: O,
    /// Include only given columns.
    pub project: Vec<usize>,
}
3536
3537impl<L> RowSetFinishing<L> {
3538    /// Returns a trivial finishing, i.e., that does nothing to the result set.
3539    pub fn trivial(arity: usize) -> RowSetFinishing<L> {
3540        RowSetFinishing {
3541            order_by: Vec::new(),
3542            limit: None,
3543            offset: 0,
3544            project: (0..arity).collect(),
3545        }
3546    }
3547    /// True if the finishing does nothing to any result set.
3548    pub fn is_trivial(&self, arity: usize) -> bool {
3549        self.limit.is_none()
3550            && self.order_by.is_empty()
3551            && self.offset == 0
3552            && self.project.iter().copied().eq(0..arity)
3553    }
3554    /// True if the finishing does not require an ORDER BY.
3555    ///
3556    /// LIMIT and OFFSET without an ORDER BY _are_ streamable: without an
3557    /// explicit ordering we will skip an arbitrary bag of elements and return
3558    /// the first arbitrary elements in the remaining bag. The result semantics
3559    /// are still correct but maybe surprising for some users.
3560    pub fn is_streamable(&self, arity: usize) -> bool {
3561        self.order_by.is_empty() && self.project.iter().copied().eq(0..arity)
3562    }
3563}
3564
3565impl RowSetFinishing<NonNeg<i64>, usize> {
3566    /// The number of rows needed from before the finishing to evaluate the finishing:
3567    /// offset + limit.
3568    ///
3569    /// If it returns None, then we need all the rows.
3570    pub fn num_rows_needed(&self) -> Option<usize> {
3571        self.limit
3572            .as_ref()
3573            .map(|l| usize::cast_from(u64::from(l.clone())) + self.offset)
3574    }
3575}
3576
impl RowSetFinishing {
    /// Applies finishing actions to a [`RowCollection`], and reports the total
    /// time it took to run.
    ///
    /// Returns a [`RowCollectionIter`] that contains all of the response data, as
    /// well as the size of the response in bytes.
    pub fn finish(
        &self,
        rows: RowCollection,
        max_result_size: u64,
        max_returned_query_size: Option<u64>,
        duration_histogram: &Histogram,
    ) -> Result<(RowCollectionIter, usize), String> {
        // Time the whole finishing operation and record the duration in the
        // histogram regardless of success or failure.
        let now = Instant::now();
        let result = self.finish_inner(rows, max_result_size, max_returned_query_size);
        let duration = now.elapsed();
        duration_histogram.observe(duration.as_secs_f64());

        result
    }

    /// Implementation for [`RowSetFinishing::finish`].
    ///
    /// Errors (with a human-readable message) when the result would exceed
    /// `max_result_size` or `max_returned_query_size`.
    fn finish_inner(
        &self,
        rows: RowCollection,
        max_result_size: u64,
        max_returned_query_size: Option<u64>,
    ) -> Result<(RowCollectionIter, usize), String> {
        // How much additional memory is required to make a sorted view.
        let sorted_view_mem = rows.entries().saturating_mul(std::mem::size_of::<usize>());
        let required_memory = rows.byte_len().saturating_add(sorted_view_mem);

        // Bail if creating the sorted view would require us to use too much memory.
        if required_memory > usize::cast_from(max_result_size) {
            let max_bytes = ByteSize::b(max_result_size);
            return Err(format!("result exceeds max size of {max_bytes}",));
        }

        // NOTE(review): despite the name, no sorting is visible here;
        // presumably ordering is handled inside the iterator or upstream —
        // confirm.
        let sorted_view = rows;
        // Apply OFFSET and the projection lazily via the iterator.
        let mut iter = sorted_view
            .into_row_iter()
            .apply_offset(self.offset)
            .with_projection(self.project.clone());

        if let Some(limit) = self.limit {
            let limit = u64::from(limit);
            let limit = usize::cast_from(limit);
            iter = iter.with_limit(limit);
        };

        // TODO(parkmycar): Re-think how we can calculate the total response size without
        // having to iterate through the entire collection of Rows, while still
        // respecting the LIMIT, OFFSET, and projections.
        //
        // Note: It feels a bit bad always calculating the response size, but we almost
        // always need it to either check the `max_returned_query_size`, or for reporting
        // in the query history.
        let response_size: usize = iter.clone().map(|row| row.data().len()).sum();

        // Bail if we would end up returning more data to the client than they can support.
        if let Some(max) = max_returned_query_size {
            if response_size > usize::cast_from(max) {
                let max_bytes = ByteSize::b(max);
                return Err(format!("result exceeds max size of {max_bytes}"));
            }
        }

        Ok((iter, response_size))
    }
}
3647
/// A [RowSetFinishing] that can be repeatedly applied to batches of updates (in
/// a [RowCollection]) and keeps track of the remaining limit, offset, and cap
/// on query result size.
///
/// There is no `order_by`: this can only be constructed from a streamable
/// finishing (one without an ORDER BY); see
/// [RowSetFinishingIncremental::new]. The `remaining_*` fields are budgets
/// that shrink as batches are processed.
#[derive(Debug)]
pub struct RowSetFinishingIncremental {
    /// Include only as many rows (after offset). Decremented as rows are emitted.
    pub remaining_limit: Option<usize>,
    /// Omit as many rows. Decremented as batches are consumed.
    pub remaining_offset: usize,
    /// The maximum allowed result size, as requested by the client.
    pub max_returned_query_size: Option<u64>,
    /// Tracks our remaining allowed budget for result size.
    pub remaining_max_returned_query_size: Option<u64>,
    /// Include only given columns.
    pub project: Vec<usize>,
}
3664
impl RowSetFinishingIncremental {
    /// Turns the given [RowSetFinishing] into a [RowSetFinishingIncremental].
    /// Can only be used when [is_streamable](RowSetFinishing::is_streamable) is
    /// `true`.
    ///
    /// # Panics
    ///
    /// Panics if the result is not streamable, that is it has an ORDER BY.
    pub fn new(
        offset: usize,
        limit: Option<NonNeg<i64>>,
        project: Vec<usize>,
        max_returned_query_size: Option<u64>,
    ) -> Self {
        // Convert the limit from NonNeg<i64> to usize (via u64, since the
        // value is known to be non-negative).
        let limit = limit.map(|l| {
            let l = u64::from(l);
            let l = usize::cast_from(l);
            l
        });

        RowSetFinishingIncremental {
            remaining_limit: limit,
            remaining_offset: offset,
            max_returned_query_size,
            // The remaining budget starts out at the full client-requested cap.
            remaining_max_returned_query_size: max_returned_query_size,
            project,
        }
    }

    /// Applies finishing actions to the given [`RowCollection`], and reports
    /// the total time it took to run.
    ///
    /// Returns a [`RowCollectionIter`] that contains all of the response
    /// data.
    pub fn finish_incremental(
        &mut self,
        rows: RowCollection,
        max_result_size: u64,
        duration_histogram: &Histogram,
    ) -> Result<RowCollectionIter, String> {
        // Time the whole operation and record the duration in the histogram
        // regardless of success or failure.
        let now = Instant::now();
        let result = self.finish_incremental_inner(rows, max_result_size);
        let duration = now.elapsed();
        duration_histogram.observe(duration.as_secs_f64());

        result
    }

    /// Implementation for [`RowSetFinishingIncremental::finish_incremental`].
    /// Updates the remaining offset/limit/size budgets as a side effect.
    fn finish_incremental_inner(
        &mut self,
        rows: RowCollection,
        max_result_size: u64,
    ) -> Result<RowCollectionIter, String> {
        // How much additional memory is required to make a sorted view.
        let sorted_view_mem = rows.entries().saturating_mul(std::mem::size_of::<usize>());
        let required_memory = rows.byte_len().saturating_add(sorted_view_mem);

        // Bail if creating the sorted view would require us to use too much memory.
        if required_memory > usize::cast_from(max_result_size) {
            let max_bytes = ByteSize::b(max_result_size);
            return Err(format!("total result exceeds max size of {max_bytes}",));
        }

        // Size of this batch before offset/limit, used to consume the offset below.
        let batch_num_rows = rows.count();

        let sorted_view = rows;
        let mut iter = sorted_view
            .into_row_iter()
            .apply_offset(self.remaining_offset)
            .with_projection(self.project.clone());

        if let Some(limit) = self.remaining_limit {
            iter = iter.with_limit(limit);
        };

        // Consume the offset against the full batch size; once it reaches zero
        // later batches are no longer skipped.
        self.remaining_offset = self.remaining_offset.saturating_sub(batch_num_rows);
        // Decrement the limit by the rows this batch yields. `with_limit`
        // above caps the yield at `remaining_limit`, so this cannot underflow.
        // NOTE(review): `iter.count()` here must not consume `iter` (it is
        // cloned/returned below); presumably an inherent borrowing `count` —
        // confirm.
        if let Some(remaining_limit) = self.remaining_limit.as_mut() {
            *remaining_limit -= iter.count();
        }

        // TODO(parkmycar): Re-think how we can calculate the total response size without
        // having to iterate through the entire collection of Rows, while still
        // respecting the LIMIT, OFFSET, and projections.
        //
        // Note: It feels a bit bad always calculating the response size, but we almost
        // always need it to either check the `max_returned_query_size`, or for reporting
        // in the query history.
        let response_size: usize = iter.clone().map(|row| row.data().len()).sum();

        // Bail if we would end up returning more data to the client than they can support.
        // Otherwise deduct this batch's size from the remaining budget.
        if let Some(max) = &mut self.remaining_max_returned_query_size {
            if let Some(remaining) = max.checked_sub(response_size.cast_into()) {
                *max = remaining;
            } else {
                let max_bytes = ByteSize::b(self.max_returned_query_size.expect("known to exist"));
                return Err(format!("total result exceeds max size of {max_bytes}"));
            }
        }

        Ok(iter)
    }
}
3767
/// Compares two rows columnwise, using [compare_columns].
///
/// Compared to the naive implementation, this allows sharing some memory and implements some
/// optimizations that avoid unnecessary row unpacking.
#[derive(Debug, Clone)]
pub struct RowComparator<O: AsRef<[ColumnOrder]> = Vec<ColumnOrder>> {
    /// The column ordering that comparisons are performed against.
    order: O,
    /// Invariant: all column references in the order are less than this limit.
    /// This allows for partial unpacking of rows.
    limit: usize,
    /// Scratch space for unpacking the left-hand row of a comparison.
    left_vec: RefCell<DatumVec>,
    /// Scratch space for unpacking the right-hand row of a comparison.
    right_vec: RefCell<DatumVec>,
}
3781
3782impl<O: AsRef<[ColumnOrder]>> RowComparator<O> {
3783    /// Create a new row comparator from the given column ordering.
3784    pub fn new(order: O) -> Self {
3785        let limit = order
3786            .as_ref()
3787            .iter()
3788            .map(|o| o.column + 1)
3789            .max()
3790            .unwrap_or(0);
3791        Self {
3792            order,
3793            limit,
3794            left_vec: Default::default(),
3795            right_vec: Default::default(),
3796        }
3797    }
3798
3799    /// Compare two (references to) rows.
3800    pub fn compare_rows(
3801        &self,
3802        left_row: &RowRef,
3803        right_row: &RowRef,
3804        tiebreaker: impl Fn() -> Ordering,
3805    ) -> Ordering {
3806        let order = if self.limit == 0 {
3807            Ordering::Equal
3808        } else {
3809            // These borrows should never fail, since this struct is non-sync and this function
3810            // is non-recursive.
3811            let mut left_ref = self.left_vec.borrow_mut();
3812            let mut right_ref = self.right_vec.borrow_mut();
3813            let left_cols = left_ref.borrow_with_limit(left_row, self.limit);
3814            let right_cols = right_ref.borrow_with_limit(right_row, self.limit);
3815            compare_columns(self.order.as_ref(), &left_cols, &right_cols, || {
3816                Ordering::Equal
3817            })
3818        };
3819        // Tiebreak without the vecs borrowed, in case that recursively invokes this function.
3820        order.then_with(tiebreaker)
3821    }
3822}
3823
3824/// Compare `left` and `right` using `order`. If that doesn't produce a strict
3825/// ordering, call `tiebreaker`.
3826pub fn compare_columns<F>(
3827    order: &[ColumnOrder],
3828    left: &[Datum],
3829    right: &[Datum],
3830    tiebreaker: F,
3831) -> Ordering
3832where
3833    F: Fn() -> Ordering,
3834{
3835    for order in order {
3836        let cmp = match (&left[order.column], &right[order.column]) {
3837            (Datum::Null, Datum::Null) => Ordering::Equal,
3838            (Datum::Null, _) => {
3839                if order.nulls_last {
3840                    Ordering::Greater
3841                } else {
3842                    Ordering::Less
3843                }
3844            }
3845            (_, Datum::Null) => {
3846                if order.nulls_last {
3847                    Ordering::Less
3848                } else {
3849                    Ordering::Greater
3850                }
3851            }
3852            (lval, rval) => {
3853                if order.desc {
3854                    rval.cmp(lval)
3855                } else {
3856                    lval.cmp(rval)
3857                }
3858            }
3859        };
3860        if cmp != Ordering::Equal {
3861            return cmp;
3862        }
3863    }
3864    tiebreaker()
3865}
3866
/// Describe a window frame, e.g. `RANGE UNBOUNDED PRECEDING` or
/// `ROWS BETWEEN 5 PRECEDING AND CURRENT ROW`.
///
/// Window frames define a subset of the partition, and only a subset of
/// window functions make use of the window frame.
#[derive(
    Debug,
    Clone,
    Eq,
    PartialEq,
    Ord,
    PartialOrd,
    Serialize,
    Deserialize,
    Hash,
    MzReflect
)]
pub struct WindowFrame {
    /// ROWS, RANGE or GROUPS
    pub units: WindowFrameUnits,
    /// Where the frame starts
    pub start_bound: WindowFrameBound,
    /// Where the frame ends
    pub end_bound: WindowFrameBound,
}
3892
3893impl Display for WindowFrame {
3894    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
3895        write!(
3896            f,
3897            "{} between {} and {}",
3898            self.units, self.start_bound, self.end_bound
3899        )
3900    }
3901}
3902
3903impl WindowFrame {
3904    /// Return the default window frame used when one is not explicitly defined
3905    pub fn default() -> Self {
3906        WindowFrame {
3907            units: WindowFrameUnits::Range,
3908            start_bound: WindowFrameBound::UnboundedPreceding,
3909            end_bound: WindowFrameBound::CurrentRow,
3910        }
3911    }
3912
3913    fn includes_current_row(&self) -> bool {
3914        use WindowFrameBound::*;
3915        match self.start_bound {
3916            UnboundedPreceding => match self.end_bound {
3917                UnboundedPreceding => false,
3918                OffsetPreceding(0) => true,
3919                OffsetPreceding(_) => false,
3920                CurrentRow => true,
3921                OffsetFollowing(_) => true,
3922                UnboundedFollowing => true,
3923            },
3924            OffsetPreceding(0) => match self.end_bound {
3925                UnboundedPreceding => unreachable!(),
3926                OffsetPreceding(0) => true,
3927                // Any nonzero offsets here will create an empty window
3928                OffsetPreceding(_) => false,
3929                CurrentRow => true,
3930                OffsetFollowing(_) => true,
3931                UnboundedFollowing => true,
3932            },
3933            OffsetPreceding(_) => match self.end_bound {
3934                UnboundedPreceding => unreachable!(),
3935                // Window ends at the current row
3936                OffsetPreceding(0) => true,
3937                OffsetPreceding(_) => false,
3938                CurrentRow => true,
3939                OffsetFollowing(_) => true,
3940                UnboundedFollowing => true,
3941            },
3942            CurrentRow => true,
3943            OffsetFollowing(0) => match self.end_bound {
3944                UnboundedPreceding => unreachable!(),
3945                OffsetPreceding(_) => unreachable!(),
3946                CurrentRow => unreachable!(),
3947                OffsetFollowing(_) => true,
3948                UnboundedFollowing => true,
3949            },
3950            OffsetFollowing(_) => match self.end_bound {
3951                UnboundedPreceding => unreachable!(),
3952                OffsetPreceding(_) => unreachable!(),
3953                CurrentRow => unreachable!(),
3954                OffsetFollowing(_) => false,
3955                UnboundedFollowing => false,
3956            },
3957            UnboundedFollowing => false,
3958        }
3959    }
3960}
3961
/// Describe how frame bounds are interpreted.
#[derive(
    Debug,
    Clone,
    Eq,
    PartialEq,
    Ord,
    PartialOrd,
    Serialize,
    Deserialize,
    Hash,
    MzReflect
)]
pub enum WindowFrameUnits {
    /// Each row is treated as the unit of work for bounds.
    Rows,
    /// Each peer group is treated as the unit of work for bounds,
    /// and offset-based bounds use the value of the ORDER BY expression.
    Range,
    /// Each peer group is treated as the unit of work for bounds.
    /// Groups is currently not supported, and it is rejected during planning.
    Groups,
}
3985
3986impl Display for WindowFrameUnits {
3987    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
3988        match self {
3989            WindowFrameUnits::Rows => write!(f, "rows"),
3990            WindowFrameUnits::Range => write!(f, "range"),
3991            WindowFrameUnits::Groups => write!(f, "groups"),
3992        }
3993    }
3994}
3995
/// Specifies [WindowFrame]'s `start_bound` and `end_bound`.
///
/// The order between frame bounds is significant, as Postgres enforces
/// some restrictions there.
#[derive(
    Debug,
    Clone,
    Serialize,
    Deserialize,
    PartialEq,
    Eq,
    Hash,
    MzReflect,
    PartialOrd,
    Ord
)]
pub enum WindowFrameBound {
    /// `UNBOUNDED PRECEDING`
    UnboundedPreceding,
    /// `<N> PRECEDING`
    OffsetPreceding(u64),
    /// `CURRENT ROW`
    CurrentRow,
    /// `<N> FOLLOWING`
    OffsetFollowing(u64),
    /// `UNBOUNDED FOLLOWING`
    UnboundedFollowing,
}
4024
4025impl Display for WindowFrameBound {
4026    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
4027        match self {
4028            WindowFrameBound::UnboundedPreceding => write!(f, "unbounded preceding"),
4029            WindowFrameBound::OffsetPreceding(offset) => write!(f, "{} preceding", offset),
4030            WindowFrameBound::CurrentRow => write!(f, "current row"),
4031            WindowFrameBound::OffsetFollowing(offset) => write!(f, "{} following", offset),
4032            WindowFrameBound::UnboundedFollowing => write!(f, "unbounded following"),
4033        }
4034    }
4035}
4036
/// Maximum iterations for a LetRec.
#[derive(
    Debug,
    Clone,
    Copy,
    PartialEq,
    Eq,
    PartialOrd,
    Ord,
    Hash,
    Serialize,
    Deserialize
)]
pub struct LetRecLimit {
    /// Maximum number of iterations to evaluate.
    pub max_iters: NonZeroU64,
    /// Whether to return normally (rather than throw an error) when reaching
    /// the above limit.
    /// If true, we simply use the current contents of each Id as the final result.
    pub return_at_limit: bool,
}
4057
4058impl LetRecLimit {
4059    /// Compute the smallest limit from a Vec of `LetRecLimit`s.
4060    pub fn min_max_iter(limits: &Vec<Option<LetRecLimit>>) -> Option<u64> {
4061        limits
4062            .iter()
4063            .filter_map(|l| l.as_ref().map(|l| l.max_iters.get()))
4064            .min()
4065    }
4066
4067    /// The default value of `LetRecLimit::return_at_limit` when using the RECURSION LIMIT option of
4068    /// WMR without ERROR AT or RETURN AT.
4069    pub const RETURN_AT_LIMIT_DEFAULT: bool = false;
4070}
4071
4072impl Display for LetRecLimit {
4073    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
4074        write!(f, "[recursion_limit={}", self.max_iters)?;
4075        if self.return_at_limit != LetRecLimit::RETURN_AT_LIMIT_DEFAULT {
4076            write!(f, ", return_at_limit")?;
4077        }
4078        write!(f, "]")
4079    }
4080}
4081
/// For a global Get, this indicates whether we are going to read from Persist or from an index.
/// (See comment in MirRelationExpr::Get.)
#[derive(
    Clone,
    Debug,
    Eq,
    PartialEq,
    Ord,
    PartialOrd,
    Serialize,
    Deserialize,
    Hash
)]
pub enum AccessStrategy {
    /// It's either a local Get (a CTE), or unknown at the time.
    /// `prune_and_annotate_dataflow_index_imports` decides it for global Gets, and thus switches to
    /// one of the other variants.
    UnknownOrLocal,
    /// The Get will read from Persist.
    Persist,
    /// The Get will read from an index or indexes: (index id, how the index will be used).
    Index(Vec<(GlobalId, IndexUsageType)>),
    /// The Get will read a collection that is computed by the same dataflow, but in a different
    /// `BuildDesc` in `objects_to_build`.
    SameDataflow,
}
4108
#[cfg(test)]
mod tests {
    use std::num::NonZeroUsize;

    use mz_repr::explain::text::text_string_at;

    use crate::explain::HumanizedExplain;

    use super::*;

    #[mz_ore::test]
    fn test_row_set_finishing_as_text() {
        // A finishing with ordering, a limit, and a projection, exercising
        // every rendered attribute except offset.
        let finishing = RowSetFinishing {
            order_by: vec![ColumnOrder {
                column: 4,
                desc: true,
                nulls_last: true,
            }],
            limit: Some(NonNeg::try_from(7).unwrap()),
            offset: Default::default(),
            project: vec![1, 3, 4, 5],
        };

        let mode = HumanizedExplain::new(false);
        let expr = mode.expr(&finishing, None);
        let actual = text_string_at(&expr, mz_ore::str::Indent::default);

        // Build the expected rendering piece by piece.
        let expected = {
            use mz_ore::fmt::FormatBuffer;
            let mut buf = String::new();
            write!(&mut buf, "Finish");
            write!(&mut buf, " order_by=[#4 desc nulls_last]");
            write!(&mut buf, " limit=7");
            write!(&mut buf, " output=[#1, #3..=#5]");
            writeln!(&mut buf, "");
            buf
        };

        assert_eq!(actual, expected);
    }

    #[mz_ore::test]
    fn test_row_set_finishing_incremental_max_returned_query_size() {
        let row = Row::pack_slice(&[Datum::String("hello")]);
        let row_size = u64::cast_from(row.data().len());
        let diff = NonZeroUsize::new(1).unwrap();
        let batch = RowCollection::new(vec![(row, diff)], &[]);

        // Set max_returned_query_size to hold exactly 2 batches worth of rows.
        let mut finishing = RowSetFinishingIncremental::new(0, None, vec![0], Some(row_size * 2));

        let max_result_size = u64::MAX;

        // The first batch fits, leaving one row's worth of budget.
        let result = finishing.finish_incremental_inner(batch.clone(), max_result_size);
        assert!(result.is_ok());
        assert_eq!(finishing.remaining_max_returned_query_size, Some(row_size));

        // The second batch exhausts the budget exactly.
        let result = finishing.finish_incremental_inner(batch.clone(), max_result_size);
        assert!(result.is_ok());
        assert_eq!(finishing.remaining_max_returned_query_size, Some(0));

        // The third batch exceeds the budget and must error.
        let result = finishing.finish_incremental_inner(batch, max_result_size);
        assert!(result.unwrap_err().contains("total result exceeds max size"));
    }
}
4175
/// An iterator over AST structures, which calls out nodes that differ.
///
/// The iterators visit two ASTs in tandem, continuing as long as the AST node data matches,
/// and yielding an output pair as soon as the AST nodes do not match. Their intent is to call
/// attention to the moments in the ASTs where they differ, and they incidentally provide a
/// stack-free way to compare two ASTs.
mod structured_diff {

    use super::MirRelationExpr;
    use itertools::Itertools;

    /// An iterator over structured differences between two `MirRelationExpr` instances.
    pub struct MreDiff<'a> {
        /// Pairs of expressions that must still be compared.
        todo: Vec<(&'a MirRelationExpr, &'a MirRelationExpr)>,
    }

    impl<'a> MreDiff<'a> {
        /// Create a new `MirRelationExpr` structured difference.
        pub fn new(expr1: &'a MirRelationExpr, expr2: &'a MirRelationExpr) -> Self {
            MreDiff {
                todo: vec![(expr1, expr2)],
            }
        }
    }

    impl<'a> Iterator for MreDiff<'a> {
        // Pairs of expressions that do not match.
        type Item = (&'a MirRelationExpr, &'a MirRelationExpr);

        /// Worklist-based traversal: pop a pair of nodes, compare their local
        /// (non-child) data, and either yield the pair as a difference or
        /// enqueue the children for comparison.
        fn next(&mut self) -> Option<Self::Item> {
            while let Some((expr1, expr2)) = self.todo.pop() {
                match (expr1, expr2) {
                    // Leaf variants: all data is compared in place.
                    (
                        MirRelationExpr::Constant {
                            rows: rows1,
                            typ: typ1,
                        },
                        MirRelationExpr::Constant {
                            rows: rows2,
                            typ: typ2,
                        },
                    ) => {
                        if rows1 != rows2 || typ1 != typ2 {
                            return Some((expr1, expr2));
                        }
                    }
                    (
                        MirRelationExpr::Get {
                            id: id1,
                            typ: typ1,
                            access_strategy: as1,
                        },
                        MirRelationExpr::Get {
                            id: id2,
                            typ: typ2,
                            access_strategy: as2,
                        },
                    ) => {
                        if id1 != id2 || typ1 != typ2 || as1 != as2 {
                            return Some((expr1, expr2));
                        }
                    }
                    // Interior variants: compare non-child data here, and
                    // enqueue the child expressions for later comparison.
                    (
                        MirRelationExpr::Let {
                            id: id1,
                            body: body1,
                            value: value1,
                        },
                        MirRelationExpr::Let {
                            id: id2,
                            body: body2,
                            value: value2,
                        },
                    ) => {
                        if id1 != id2 {
                            return Some((expr1, expr2));
                        } else {
                            self.todo.push((body1, body2));
                            self.todo.push((value1, value2));
                        }
                    }
                    (
                        MirRelationExpr::LetRec {
                            ids: ids1,
                            body: body1,
                            values: values1,
                            limits: limits1,
                        },
                        MirRelationExpr::LetRec {
                            ids: ids2,
                            body: body2,
                            values: values2,
                            limits: limits2,
                        },
                    ) => {
                        // The length check makes the `zip_eq` below safe.
                        if ids1 != ids2 || values1.len() != values2.len() || limits1 != limits2 {
                            return Some((expr1, expr2));
                        } else {
                            self.todo.push((body1, body2));
                            self.todo.extend(values1.iter().zip_eq(values2.iter()));
                        }
                    }
                    (
                        MirRelationExpr::Project {
                            outputs: outputs1,
                            input: input1,
                        },
                        MirRelationExpr::Project {
                            outputs: outputs2,
                            input: input2,
                        },
                    ) => {
                        if outputs1 != outputs2 {
                            return Some((expr1, expr2));
                        } else {
                            self.todo.push((input1, input2));
                        }
                    }
                    (
                        MirRelationExpr::Map {
                            scalars: scalars1,
                            input: input1,
                        },
                        MirRelationExpr::Map {
                            scalars: scalars2,
                            input: input2,
                        },
                    ) => {
                        if scalars1 != scalars2 {
                            return Some((expr1, expr2));
                        } else {
                            self.todo.push((input1, input2));
                        }
                    }
                    (
                        MirRelationExpr::Filter {
                            predicates: predicates1,
                            input: input1,
                        },
                        MirRelationExpr::Filter {
                            predicates: predicates2,
                            input: input2,
                        },
                    ) => {
                        if predicates1 != predicates2 {
                            return Some((expr1, expr2));
                        } else {
                            self.todo.push((input1, input2));
                        }
                    }
                    (
                        MirRelationExpr::FlatMap {
                            input: input1,
                            func: func1,
                            exprs: exprs1,
                        },
                        MirRelationExpr::FlatMap {
                            input: input2,
                            func: func2,
                            exprs: exprs2,
                        },
                    ) => {
                        if func1 != func2 || exprs1 != exprs2 {
                            return Some((expr1, expr2));
                        } else {
                            self.todo.push((input1, input2));
                        }
                    }
                    (
                        MirRelationExpr::Join {
                            inputs: inputs1,
                            equivalences: eq1,
                            implementation: impl1,
                        },
                        MirRelationExpr::Join {
                            inputs: inputs2,
                            equivalences: eq2,
                            implementation: impl2,
                        },
                    ) => {
                        // The length check makes the `zip_eq` below safe.
                        if inputs1.len() != inputs2.len() || eq1 != eq2 || impl1 != impl2 {
                            return Some((expr1, expr2));
                        } else {
                            self.todo.extend(inputs1.iter().zip_eq(inputs2.iter()));
                        }
                    }
                    (
                        MirRelationExpr::Reduce {
                            aggregates: aggregates1,
                            input: inputs1,
                            group_key: gk1,
                            monotonic: m1,
                            expected_group_size: egs1,
                        },
                        MirRelationExpr::Reduce {
                            aggregates: aggregates2,
                            input: inputs2,
                            group_key: gk2,
                            monotonic: m2,
                            expected_group_size: egs2,
                        },
                    ) => {
                        if aggregates1 != aggregates2 || gk1 != gk2 || m1 != m2 || egs1 != egs2 {
                            return Some((expr1, expr2));
                        } else {
                            self.todo.push((inputs1, inputs2));
                        }
                    }
                    (
                        MirRelationExpr::TopK {
                            group_key: gk1,
                            order_key: order1,
                            input: input1,
                            limit: l1,
                            offset: o1,
                            monotonic: m1,
                            expected_group_size: egs1,
                        },
                        MirRelationExpr::TopK {
                            group_key: gk2,
                            order_key: order2,
                            input: input2,
                            limit: l2,
                            offset: o2,
                            monotonic: m2,
                            expected_group_size: egs2,
                        },
                    ) => {
                        if order1 != order2
                            || gk1 != gk2
                            || l1 != l2
                            || o1 != o2
                            || m1 != m2
                            || egs1 != egs2
                        {
                            return Some((expr1, expr2));
                        } else {
                            self.todo.push((input1, input2));
                        }
                    }
                    // Variants with no non-child data: only compare children.
                    (
                        MirRelationExpr::Negate { input: input1 },
                        MirRelationExpr::Negate { input: input2 },
                    ) => {
                        self.todo.push((input1, input2));
                    }
                    (
                        MirRelationExpr::Threshold { input: input1 },
                        MirRelationExpr::Threshold { input: input2 },
                    ) => {
                        self.todo.push((input1, input2));
                    }
                    (
                        MirRelationExpr::Union {
                            base: base1,
                            inputs: inputs1,
                        },
                        MirRelationExpr::Union {
                            base: base2,
                            inputs: inputs2,
                        },
                    ) => {
                        // The length check makes the `zip_eq` below safe.
                        if inputs1.len() != inputs2.len() {
                            return Some((expr1, expr2));
                        } else {
                            self.todo.push((base1, base2));
                            self.todo.extend(inputs1.iter().zip_eq(inputs2.iter()));
                        }
                    }
                    (
                        MirRelationExpr::ArrangeBy {
                            keys: keys1,
                            input: input1,
                        },
                        MirRelationExpr::ArrangeBy {
                            keys: keys2,
                            input: input2,
                        },
                    ) => {
                        if keys1 != keys2 {
                            return Some((expr1, expr2));
                        } else {
                            self.todo.push((input1, input2));
                        }
                    }
                    // Mismatched variants are always a difference.
                    _ => {
                        return Some((expr1, expr2));
                    }
                }
            }
            None
        }
    }
}