Skip to main content

mz_expr/
relation.rs

1// Copyright Materialize, Inc. and contributors. All rights reserved.
2//
3// Use of this software is governed by the Business Source License
4// included in the LICENSE file.
5//
6// As of the Change Date specified in that file, in accordance with
7// the Business Source License, use of this software will be governed
8// by the Apache License, Version 2.0.
9
10#![warn(missing_docs)]
11
12use std::cell::RefCell;
13use std::cmp::{Ordering, max};
14use std::collections::{BTreeMap, BTreeSet};
15use std::fmt;
16use std::fmt::{Display, Formatter};
17use std::hash::{DefaultHasher, Hash, Hasher};
18use std::num::NonZeroU64;
19use std::time::Instant;
20
21use bytesize::ByteSize;
22use columnation::{Columnation, CopyRegion};
23use itertools::Itertools;
24use mz_lowertest::MzReflect;
25use mz_ore::cast::{CastFrom, CastInto};
26use mz_ore::collections::CollectionExt;
27use mz_ore::id_gen::IdGen;
28use mz_ore::metrics::Histogram;
29use mz_ore::num::NonNeg;
30use mz_ore::soft_assert_no_log;
31use mz_ore::stack::RecursionLimitError;
32use mz_ore::str::Indent;
33use mz_repr::adt::numeric::NumericMaxScale;
34use mz_repr::explain::text::text_string_at;
35use mz_repr::explain::{
36    DummyHumanizer, ExplainConfig, ExprHumanizer, IndexUsageType, PlanRenderingContext,
37};
38use mz_repr::{
39    ColumnName, Datum, DatumVec, Diff, GlobalId, IntoRowIterator, ReprColumnType, ReprRelationType,
40    ReprScalarType, Row, RowIterator, RowRef, SqlColumnType, SqlRelationType, SqlScalarType,
41};
42use serde::{Deserialize, Serialize};
43
44use crate::Id::Local;
45use crate::explain::{HumanizedExpr, HumanizerMode};
46use crate::relation::func::{AggregateFunc, LagLeadType, TableFunc};
47use crate::row::{RowCollection, RowCollectionIter};
48use crate::scalar::columns::Columns;
49use crate::scalar::func::variadic::{
50    JsonbBuildArray, JsonbBuildObject, ListCreate, ListIndex, MapBuild, RecordCreate,
51};
52use crate::visit::{Visit, VisitChildren};
53use crate::{
54    EvalError, FilterCharacteristics, Id, LocalId, MirScalarExpr, UnaryFunc, func as scalar_func,
55};
56
57pub mod canonicalize;
58pub mod func;
59pub mod join_input_mapper;
60
61/// A recursion limit to be used for stack-safe traversals of [`MirRelationExpr`] trees.
62///
63/// The recursion limit must be large enough to accommodate the linear representation
64/// of some pathological but frequently occurring query fragments.
65///
66/// For example, in MIR we could have long chains of
67/// - (1) `Let` bindings,
68/// - (2) `CallBinary` calls with associative functions such as `+`.
69///
70/// Until those chains are fixed, we need to stick with this larger recursion limit.
71pub const RECURSION_LIMIT: usize = 2048;
72
73/// A trait for types that describe how to build a collection.
74pub trait CollectionPlan {
75    /// Collects the set of global identifiers from dataflows referenced in Get.
76    fn depends_on_into(&self, out: &mut BTreeSet<GlobalId>);
77
78    /// Returns the set of global identifiers from dataflows referenced in Get.
79    ///
80    /// See [`CollectionPlan::depends_on_into`] to reuse an existing `BTreeSet`.
81    fn depends_on(&self) -> BTreeSet<GlobalId> {
82        let mut out = BTreeSet::new();
83        self.depends_on_into(&mut out);
84        out
85    }
86}
87
88/// An abstract syntax tree which defines a collection.
89///
90/// The AST is meant to reflect the capabilities of the `differential_dataflow::Collection` type,
91/// written generically enough to avoid run-time compilation work.
92///
93/// The `derived_hash_with_manual_eq` lint is allowed here: it exists because it's bad when `Eq`
94/// doesn't agree with `Hash`, which is often quite likely if one of them is implemented manually.
95/// However, our manual implementation of `Eq` _will_ agree with the derived one. This is because
96/// the reason for the manual implementation is not to change the semantics from the derived one,
97/// but to avoid stack overflows.
98#[allow(clippy::derived_hash_with_manual_eq)]
99#[derive(Clone, Debug, Ord, PartialOrd, Serialize, Deserialize, MzReflect, Hash)]
100pub enum MirRelationExpr {
101    /// A constant relation containing specified rows.
102    ///
103    /// The runtime memory footprint of this operator is zero.
104    ///
105    /// When you would like to pattern match on this, consider using `MirRelationExpr::as_const`
106    /// instead, which looks behind `ArrangeBy`s. You might want this matching behavior because
107    /// constant folding doesn't remove `ArrangeBy`s.
108    Constant {
109        /// Rows of the constant collection and their multiplicities.
110        rows: Result<Vec<(Row, Diff)>, EvalError>,
111        /// Schema of the collection.
112        typ: ReprRelationType,
113    },
114    /// Get an existing dataflow.
115    ///
116    /// The runtime memory footprint of this operator is zero.
117    Get {
118        /// The identifier for the collection to load.
119        #[mzreflect(ignore)]
120        id: Id,
121        /// Schema of the collection.
122        typ: ReprRelationType,
123        /// If this is a global Get, this will indicate whether we are going to read from Persist or
124        /// from an index, or from a different object in `objects_to_build`. If it's an index, then
125        /// how downstream dataflow operations will use this index is also recorded. This is filled
126        /// by `prune_and_annotate_dataflow_index_imports`. Note that this is not used by the
127        /// lowering to LIR, but is used only by EXPLAIN.
128        #[mzreflect(ignore)]
129        access_strategy: AccessStrategy,
130    },
131    /// Introduce a temporary dataflow.
132    ///
133    /// The runtime memory footprint of this operator is zero.
134    Let {
135        /// The identifier to be used in `Get` variants to retrieve `value`.
136        #[mzreflect(ignore)]
137        id: LocalId,
138        /// The collection to be bound to `id`.
139        value: Box<MirRelationExpr>,
140        /// The result of the `Let`, evaluated with `id` bound to `value`.
141        body: Box<MirRelationExpr>,
142    },
143    /// Introduce mutually recursive bindings.
144    ///
145    /// Each `LocalId` is immediately bound to an initially empty collection
146    /// with the type of its corresponding `MirRelationExpr`. Repeatedly, each
147    /// binding is evaluated using the current contents of each other binding,
148    /// and is refreshed to contain the new evaluation. This process continues
149    /// through all bindings, and repeats as long as changes continue to occur.
150    ///
151    /// The resulting value of the expression is `body` evaluated once in the
152    /// context of the final iterates.
153    ///
154    /// A zero-binding instance can be replaced by `body`.
155    /// A single-binding instance is equivalent to `MirRelationExpr::Let`.
156    ///
157    /// The runtime memory footprint of this operator is zero.
158    LetRec {
159        /// The identifiers to be used in `Get` variants to retrieve each `value`.
160        #[mzreflect(ignore)]
161        ids: Vec<LocalId>,
162        /// The collections to be bound to each `id`.
163        values: Vec<MirRelationExpr>,
164        /// Maximum number of iterations, after which we should artificially force a fixpoint.
165        /// (Whether we error or just stop is configured by `LetRecLimit::return_at_limit`.)
166        /// The per-`LetRec` limit that the user specified is initially copied to each binding to
167        /// accommodate slicing and merging of `LetRec`s in MIR transforms (e.g., `NormalizeLets`).
168        #[mzreflect(ignore)]
169        limits: Vec<Option<LetRecLimit>>,
170        /// The result of the `LetRec`, evaluated with each of `ids` bound to its final iterate.
171        body: Box<MirRelationExpr>,
172    },
173    /// Project out some columns from a dataflow.
174    ///
175    /// The runtime memory footprint of this operator is zero.
176    Project {
177        /// The source collection.
178        input: Box<MirRelationExpr>,
179        /// Indices of columns to retain.
180        outputs: Vec<usize>,
181    },
182    /// Append new columns to a dataflow.
183    ///
184    /// The runtime memory footprint of this operator is zero.
185    Map {
186        /// The source collection.
187        input: Box<MirRelationExpr>,
188        /// Expressions which determine values to append to each row.
189        /// An expression may refer to columns in `input` or
190        /// expressions defined earlier in the vector.
191        scalars: Vec<MirScalarExpr>,
192    },
193    /// Like Map, but yields zero-or-more output rows per input row.
194    ///
195    /// The runtime memory footprint of this operator is zero.
196    FlatMap {
197        /// The source collection.
198        input: Box<MirRelationExpr>,
199        /// The table func to apply.
200        func: TableFunc,
201        /// The arguments to the table func.
202        exprs: Vec<MirScalarExpr>,
203    },
204    /// Keep rows from a dataflow where all the predicates are true.
205    ///
206    /// The runtime memory footprint of this operator is zero.
207    Filter {
208        /// The source collection.
209        input: Box<MirRelationExpr>,
210        /// Predicates, each of which must be true.
211        predicates: Vec<MirScalarExpr>,
212    },
213    /// Join several collections, where some columns must be equal.
214    ///
215    /// For further details consult the documentation for [`MirRelationExpr::join`].
216    ///
217    /// The runtime memory footprint of this operator can be proportional to
218    /// the sizes of all inputs and the size of all joins of prefixes.
219    /// This may be reduced due to arrangements available at rendering time.
220    Join {
221        /// A sequence of input relations.
222        inputs: Vec<MirRelationExpr>,
223        /// A sequence of equivalence classes of expressions on the cross product of inputs.
224        ///
225        /// Each equivalence class is a list of scalar expressions, where for each class the
226        /// intended interpretation is that all evaluated expressions should be equal.
227        ///
228        /// Each scalar expression is to be evaluated over the cross-product of all records
229        /// from all inputs. In many cases this may just be column selection from specific
230        /// inputs, but more general cases exist (e.g. complex functions of multiple columns
231        /// from multiple inputs, or just constant literals).
232        equivalences: Vec<Vec<MirScalarExpr>>,
233        /// Join implementation information.
234        #[serde(default)]
235        implementation: JoinImplementation,
236    },
237    /// Group a dataflow by some columns and aggregate over each group.
238    ///
239    /// The runtime memory footprint of this operator is at most proportional to the
240    /// number of distinct records in the input and output. The actual requirements
241    /// can be less: the number of distinct inputs to each aggregate, summed across
242    /// each aggregate, plus the output size. For more details consult the code that
243    /// builds the associated dataflow.
244    Reduce {
245        /// The source collection.
246        input: Box<MirRelationExpr>,
247        /// Scalar expressions whose values are used to form groups.
248        group_key: Vec<MirScalarExpr>,
249        /// Expressions which determine values to append to each row, after the group keys.
250        aggregates: Vec<AggregateExpr>,
251        /// True iff the input is known to monotonically increase (only addition of records).
252        #[serde(default)]
253        monotonic: bool,
254        /// User hint: expected number of values per group key. Used to optimize physical rendering.
255        #[serde(default)]
256        expected_group_size: Option<u64>,
257    },
258    /// Groups and orders within each group, limiting output.
259    ///
260    /// The runtime memory footprint of this operator is proportional to its input and output.
261    TopK {
262        /// The source collection.
263        input: Box<MirRelationExpr>,
264        /// Column indices used to form groups.
265        group_key: Vec<usize>,
266        /// Column indices used to order rows within groups.
267        order_key: Vec<ColumnOrder>,
268        /// Number of records to retain.
269        #[serde(default)]
270        limit: Option<MirScalarExpr>,
271        /// Number of records to skip.
272        #[serde(default)]
273        offset: usize,
274        /// True iff the input is known to monotonically increase (only addition of records).
275        #[serde(default)]
276        monotonic: bool,
277        /// User-supplied hint: how many rows will have the same group key.
278        #[serde(default)]
279        expected_group_size: Option<u64>,
280    },
281    /// Return a dataflow where the row counts are negated.
282    ///
283    /// The runtime memory footprint of this operator is zero.
284    Negate {
285        /// The source collection.
286        input: Box<MirRelationExpr>,
287    },
288    /// Keep rows from a dataflow where the row counts are positive.
289    ///
290    /// The runtime memory footprint of this operator is proportional to its input and output.
291    Threshold {
292        /// The source collection.
293        input: Box<MirRelationExpr>,
294    },
295    /// Adds the frequencies of elements in contained sets.
296    ///
297    /// The runtime memory footprint of this operator is zero.
298    Union {
299        /// A source collection.
300        base: Box<MirRelationExpr>,
301        /// Source collections to union.
302        inputs: Vec<MirRelationExpr>,
303    },
304    /// Technically a no-op. Used to render an index. Will be used to optimize queries
305    /// on finer grain. Each item of `keys` represents a different index that should be
306    /// produced from `input`.
307    ///
308    /// The runtime memory footprint of this operator is proportional to its input.
309    ArrangeBy {
310        /// The source collection.
311        input: Box<MirRelationExpr>,
312        /// Columns to arrange `input` by, in order of decreasing primacy.
313        keys: Vec<Vec<MirScalarExpr>>,
314    },
315}
316
317impl PartialEq for MirRelationExpr {
318    fn eq(&self, other: &Self) -> bool {
319        // Capture the result and test it wrt `Ord` implementation in test environments.
320        let result = structured_diff::MreDiff::new(self, other).next().is_none();
321        mz_ore::soft_assert_eq_no_log!(result, self.cmp(other) == Ordering::Equal);
322        result
323    }
324}
325impl Eq for MirRelationExpr {}
326
327impl MirRelationExpr {
328    /// Reports the schema of the relation.
329    ///
330    /// This is the SQL-type parallel of [`Self::typ`]; it is merely
331    /// a wrapper around it, returning a [`SqlRelationType`] instead of
332    /// a [`ReprRelationType`].
333    pub fn sql_typ(&self) -> SqlRelationType {
334        let repr_typ = self.typ();
335        SqlRelationType::from_repr(&repr_typ)
336    }
337
338    /// Reports the repr schema of the relation.
339    ///
340    /// This method determines the type through recursive traversal of the
341    /// relation expression, drawing from the types of base collections.
342    /// As such, this is not an especially cheap method, and should be used
343    /// judiciously.
344    ///
345    /// The relation type is computed incrementally with a recursive post-order
346    /// traversal, that accumulates the input types for the relations yet to be
347    /// visited in `type_stack`.
348    pub fn typ(&self) -> ReprRelationType {
349        let mut type_stack = Vec::new();
350        #[allow(deprecated)]
351        self.visit_pre_post_nolimit(
352            &mut |e: &MirRelationExpr| -> Option<Vec<&MirRelationExpr>> {
353                match &e {
354                    MirRelationExpr::Let { body, .. } => Some(vec![&*body]),
355                    MirRelationExpr::LetRec { body, .. } => Some(vec![&*body]),
356                    _ => None,
357                }
358            },
359            &mut |e: &MirRelationExpr| {
360                match e {
361                    MirRelationExpr::Let { .. } => {
362                        let body_typ = type_stack.pop().unwrap();
363                        // Insert a dummy relation type for the value, since `typ_with_input_types`
364                        // won't look at it, but expects the relation type of the body to be second.
365                        type_stack.push(ReprRelationType::empty());
366                        type_stack.push(body_typ);
367                    }
368                    MirRelationExpr::LetRec { values, .. } => {
369                        let body_typ = type_stack.pop().unwrap();
370                        type_stack.extend(
371                            std::iter::repeat(ReprRelationType::empty()).take(values.len()),
372                        );
373                        // Insert dummy relation types for the values, since `typ_with_input_types`
374                        // won't look at them, but expects the relation type of the body to be last.
375                        type_stack.push(body_typ);
376                    }
377                    _ => {}
378                }
379                let num_inputs = e.num_inputs();
380                let relation_type =
381                    e.typ_with_input_types(&type_stack[type_stack.len() - num_inputs..]);
382                type_stack.truncate(type_stack.len() - num_inputs);
383                type_stack.push(relation_type);
384            },
385        );
386        assert_eq!(type_stack.len(), 1);
387        type_stack.pop().unwrap()
388    }
389
390    /// Reports the repr schema of the relation given the repr schema of the input relations.
391    pub fn typ_with_input_types(&self, input_types: &[ReprRelationType]) -> ReprRelationType {
392        let column_types = self.col_with_input_cols(input_types.iter().map(|i| &i.column_types));
393        let unique_keys = self.keys_with_input_keys(
394            input_types.iter().map(|i| i.arity()),
395            input_types.iter().map(|i| &i.keys),
396        );
397        ReprRelationType::new(column_types).with_keys(unique_keys)
398    }
399
400    /// Reports the column types of the relation given the column types of the
401    /// input relations.
402    ///
403    /// This method delegates to `try_col_with_input_cols`, panicking if an `Err`
404    /// variant is returned.
405    pub fn col_with_input_cols<'a, I>(&self, input_types: I) -> Vec<ReprColumnType>
406    where
407        I: Iterator<Item = &'a Vec<ReprColumnType>>,
408    {
409        match self.try_col_with_input_cols(input_types) {
410            Ok(col_types) => col_types,
411            Err(err) => panic!("{err}"),
412        }
413    }
414
415    /// Reports the column types of the relation given the column types of the input relations.
416    ///
417    /// `input_types` is required to contain the column types for the input relations of
418    /// the current relation in the same order as they are visited by `try_visit_children`
419    /// method, even though not all may be used for computing the schema of the
420    /// current relation. For example, `Let` expects two input types, one for the
421    /// value relation and one for the body, in that order, but only the one for the
422    /// body is used to determine the type of the `Let` relation.
423    ///
424    /// It is meant to be used during post-order traversals to compute column types
425    /// incrementally.
426    pub fn try_col_with_input_cols<'a, I>(
427        &self,
428        mut input_types: I,
429    ) -> Result<Vec<ReprColumnType>, String>
430    where
431        I: Iterator<Item = &'a Vec<ReprColumnType>>,
432    {
433        use MirRelationExpr::*;
434
435        let col_types = match self {
436            Constant { rows, typ } => {
437                let mut col_types = typ.column_types.clone();
438                let mut seen_null = vec![false; typ.arity()];
439                if let Ok(rows) = rows {
440                    for (row, _diff) in rows {
441                        for (datum, i) in row.iter().zip_eq(0..typ.arity()) {
442                            if datum.is_null() {
443                                seen_null[i] = true;
444                            }
445                        }
446                    }
447                }
448                for (&seen_null, i) in seen_null.iter().zip_eq(0..typ.arity()) {
449                    if !seen_null {
450                        col_types[i].nullable = false;
451                    } else {
452                        assert!(col_types[i].nullable);
453                    }
454                }
455                col_types
456            }
457            Get { typ, .. } => typ.column_types.clone(),
458            Project { outputs, .. } => {
459                let input = input_types.next().unwrap();
460                outputs.iter().map(|&i| input[i].clone()).collect()
461            }
462            Map { scalars, .. } => {
463                let mut result = input_types.next().unwrap().clone();
464                for scalar in scalars.iter() {
465                    result.push(scalar.typ(&result))
466                }
467                result
468            }
469            FlatMap { func, .. } => {
470                let mut result = input_types.next().unwrap().clone();
471                result.extend(
472                    func.output_sql_type()
473                        .column_types
474                        .iter()
475                        .map(ReprColumnType::from),
476                );
477                result
478            }
479            Filter { predicates, .. } => {
480                let mut result = input_types.next().unwrap().clone();
481
482                // Set as nonnull any columns where null values would cause
483                // any predicate to evaluate to null.
484                for column in non_nullable_columns(predicates) {
485                    result[column].nullable = false;
486                }
487                result
488            }
489            Join { equivalences, .. } => {
490                // Concatenate input column types
491                let mut types = input_types.flat_map(|cols| cols.to_owned()).collect_vec();
492                // In an equivalence class, if any column is non-null, then make all non-null
493                for equivalence in equivalences {
494                    let col_inds = equivalence
495                        .iter()
496                        .filter_map(|expr| match expr {
497                            MirScalarExpr::Column(col, _name) => Some(*col),
498                            _ => None,
499                        })
500                        .collect_vec();
501                    if col_inds.iter().any(|i| !types.get(*i).unwrap().nullable) {
502                        for i in col_inds {
503                            types.get_mut(i).unwrap().nullable = false;
504                        }
505                    }
506                }
507                types
508            }
509            Reduce {
510                group_key,
511                aggregates,
512                ..
513            } => {
514                let input = input_types.next().unwrap();
515                group_key
516                    .iter()
517                    .map(|e| e.typ(input))
518                    .chain(aggregates.iter().map(|agg| agg.typ(input)))
519                    .collect()
520            }
521            TopK { .. } | Negate { .. } | Threshold { .. } | ArrangeBy { .. } => {
522                input_types.next().unwrap().clone()
523            }
524            Let { .. } => {
525                // skip over the input types for `value`.
526                input_types.nth(1).unwrap().clone()
527            }
528            LetRec { values, .. } => {
529                // skip over the input types for `values`.
530                input_types.nth(values.len()).unwrap().clone()
531            }
532            Union { .. } => {
533                let mut result = input_types.next().unwrap().clone();
534                for input_col_types in input_types {
535                    for (base_col, col) in result.iter_mut().zip_eq(input_col_types) {
536                        *base_col = base_col
537                            .union(col)
538                            .map_err(|e| format!("{}\nin plan:\n{}", e, self.pretty()))?;
539                    }
540                }
541                result
542            }
543        };
544
545        Ok(col_types)
546    }
547
548    /// Reports the unique keys of the relation given the arities and the unique
549    /// keys of the input relations.
550    ///
551    /// `input_arities` and `input_keys` are required to contain the
552    /// corresponding info for the input relations of
553    /// the current relation in the same order as they are visited by `try_visit_children`
554    /// method, even though not all may be used for computing the schema of the
555    /// current relation. For example, `Let` expects two input types, one for the
556    /// value relation and one for the body, in that order, but only the one for the
557    /// body is used to determine the type of the `Let` relation.
558    ///
559    /// It is meant to be used during post-order traversals to compute unique keys
560    /// incrementally.
561    pub fn keys_with_input_keys<'a, I, J>(
562        &self,
563        mut input_arities: I,
564        mut input_keys: J,
565    ) -> Vec<Vec<usize>>
566    where
567        I: Iterator<Item = usize>,
568        J: Iterator<Item = &'a Vec<Vec<usize>>>,
569    {
570        use MirRelationExpr::*;
571
572        let mut keys = match self {
573            Constant {
574                rows: Ok(rows),
575                typ,
576            } => {
577                let n_cols = typ.arity();
578                // If the `i`th entry is `Some`, then we have not yet observed non-uniqueness in the `i`th column.
579                let mut unique_values_per_col = vec![Some(BTreeSet::<Datum>::default()); n_cols];
580                for (row, diff) in rows {
581                    for (i, datum) in row.iter().enumerate() {
582                        if datum != Datum::Dummy {
583                            if let Some(unique_vals) = &mut unique_values_per_col[i] {
584                                let is_dupe = *diff != Diff::ONE || !unique_vals.insert(datum);
585                                if is_dupe {
586                                    unique_values_per_col[i] = None;
587                                }
588                            }
589                        }
590                    }
591                }
592                if rows.len() == 0 || (rows.len() == 1 && rows[0].1 == Diff::ONE) {
593                    vec![vec![]]
594                } else {
595                    // XXX - Multi-column keys are not detected.
596                    typ.keys
597                        .iter()
598                        .cloned()
599                        .chain(
600                            unique_values_per_col
601                                .into_iter()
602                                .enumerate()
603                                .filter(|(_idx, unique_vals)| unique_vals.is_some())
604                                .map(|(idx, _)| vec![idx]),
605                        )
606                        .collect()
607                }
608            }
609            Constant { rows: Err(_), typ } | Get { typ, .. } => typ.keys.clone(),
610            Threshold { .. } | ArrangeBy { .. } => input_keys.next().unwrap().clone(),
611            Let { .. } => {
612                // skip over the unique keys for value
613                input_keys.nth(1).unwrap().clone()
614            }
615            LetRec { values, .. } => {
616                // skip over the unique keys for value
617                input_keys.nth(values.len()).unwrap().clone()
618            }
619            Project { outputs, .. } => {
620                let input = input_keys.next().unwrap();
621                input
622                    .iter()
623                    .filter_map(|key_set| {
624                        if key_set.iter().all(|k| outputs.contains(k)) {
625                            Some(
626                                key_set
627                                    .iter()
628                                    .map(|c| outputs.iter().position(|o| o == c).unwrap())
629                                    .collect(),
630                            )
631                        } else {
632                            None
633                        }
634                    })
635                    .collect()
636            }
637            Map { scalars, .. } => {
638                let mut remappings = Vec::new();
639                let arity = input_arities.next().unwrap();
640                for (column, scalar) in scalars.iter().enumerate() {
641                    // assess whether the scalar preserves uniqueness,
642                    // and could participate in a key!
643
644                    fn uniqueness(expr: &MirScalarExpr) -> Option<usize> {
645                        match expr {
646                            MirScalarExpr::CallUnary { func, expr } => {
647                                if func.preserves_uniqueness() {
648                                    uniqueness(expr)
649                                } else {
650                                    None
651                                }
652                            }
653                            MirScalarExpr::Column(c, _name) => Some(*c),
654                            _ => None,
655                        }
656                    }
657
658                    if let Some(c) = uniqueness(scalar) {
659                        remappings.push((c, column + arity));
660                    }
661                }
662
663                let mut result = input_keys.next().unwrap().clone();
664                let mut new_keys = Vec::new();
665                // Any column in `remappings` could be replaced in a key
666                // by the corresponding c. This could lead to combinatorial
667                // explosion using our current representation, so we wont
668                // do that. Instead, we'll handle the case of one remapping.
669                if remappings.len() == 1 {
670                    let (old, new) = remappings.pop().unwrap();
671                    for key in &result {
672                        if key.contains(&old) {
673                            let mut new_key: Vec<usize> =
674                                key.iter().cloned().filter(|k| k != &old).collect();
675                            new_key.push(new);
676                            new_key.sort_unstable();
677                            new_keys.push(new_key);
678                        }
679                    }
680                    result.append(&mut new_keys);
681                }
682                result
683            }
684            FlatMap { .. } => {
685                // FlatMap can add duplicate rows, so input keys are no longer
686                // valid
687                vec![]
688            }
689            Negate { .. } => {
690                // Although negate may have distinct records for each key,
691                // the multiplicity is -1 rather than 1. This breaks many
692                // of the optimization uses of "keys".
693                vec![]
694            }
695            Filter { predicates, .. } => {
696                // A filter inherits the keys of its input unless the filters
697                // have reduced the input to a single row, in which case the
698                // keys of the input are `()`.
699                let mut input = input_keys.next().unwrap().clone();
700
701                if !input.is_empty() {
702                    // Track columns equated to literals, which we can prune.
703                    let mut cols_equal_to_literal = BTreeSet::new();
704
705                    // Perform union find on `col1 = col2` to establish
706                    // connected components of equated columns. Absent any
707                    // equalities, this will be `0 .. #c` (where #c is the
708                    // greatest column referenced by a predicate), but each
709                    // equality will orient the root of the greater to the root
710                    // of the lesser.
711                    let mut union_find = Vec::new();
712
713                    for expr in predicates.iter() {
714                        if let MirScalarExpr::CallBinary {
715                            func: crate::BinaryFunc::Eq(_),
716                            expr1,
717                            expr2,
718                        } = expr
719                        {
720                            if let MirScalarExpr::Column(c, _name) = &**expr1 {
721                                if expr2.is_literal_ok() {
722                                    cols_equal_to_literal.insert(c);
723                                }
724                            }
725                            if let MirScalarExpr::Column(c, _name) = &**expr2 {
726                                if expr1.is_literal_ok() {
727                                    cols_equal_to_literal.insert(c);
728                                }
729                            }
730                            // Perform union-find to equate columns.
731                            if let (Some(c1), Some(c2)) = (expr1.as_column(), expr2.as_column()) {
732                                if c1 != c2 {
733                                    // Ensure union_find has entries up to
734                                    // max(c1, c2) by filling up missing
735                                    // positions with identity mappings.
736                                    while union_find.len() <= std::cmp::max(c1, c2) {
737                                        union_find.push(union_find.len());
738                                    }
739                                    let mut r1 = c1; // Find the representative column of [c1].
740                                    while r1 != union_find[r1] {
741                                        assert!(union_find[r1] < r1);
742                                        r1 = union_find[r1];
743                                    }
744                                    let mut r2 = c2; // Find the representative column of [c2].
745                                    while r2 != union_find[r2] {
746                                        assert!(union_find[r2] < r2);
747                                        r2 = union_find[r2];
748                                    }
749                                    // Union [c1] and [c2] by pointing the
750                                    // larger to the smaller representative (we
751                                    // update the remaining equivalence class
752                                    // members only once after this for-loop).
753                                    union_find[std::cmp::max(r1, r2)] = std::cmp::min(r1, r2);
754                                }
755                            }
756                        }
757                    }
758
759                    // Complete union-find by pointing each element at its representative column.
760                    for i in 0..union_find.len() {
761                        // Iteration not required, as each prior already references the right column.
762                        union_find[i] = union_find[union_find[i]];
763                    }
764
765                    // Remove columns bound to literals, and remap columns equated to earlier columns.
766                    // We will re-expand remapped columns in a moment, but this avoids exponential work.
767                    for key_set in &mut input {
768                        key_set.retain(|k| !cols_equal_to_literal.contains(&k));
769                        for col in key_set.iter_mut() {
770                            if let Some(equiv) = union_find.get(*col) {
771                                *col = *equiv;
772                            }
773                        }
774                        key_set.sort();
775                        key_set.dedup();
776                    }
777                    input.sort();
778                    input.dedup();
779
780                    // Expand out each key to each of its equivalent forms.
781                    // Each instance of `col` can be replaced by any equivalent column.
782                    // This has the potential to result in exponentially sized number of unique keys,
783                    // and in the future we should probably maintain unique keys modulo equivalence.
784
785                    // First, compute an inverse map from each representative
786                    // column `sub` to all other equivalent columns `col`.
787                    let mut subs = Vec::new();
788                    for (col, sub) in union_find.iter().enumerate() {
789                        if *sub != col {
790                            assert!(*sub < col);
791                            while subs.len() <= *sub {
792                                subs.push(Vec::new());
793                            }
794                            subs[*sub].push(col);
795                        }
796                    }
797                    // For each column, substitute for it in each occurrence.
798                    let mut to_add = Vec::new();
799                    for (col, subs) in subs.iter().enumerate() {
800                        if !subs.is_empty() {
801                            for key_set in input.iter() {
802                                if key_set.contains(&col) {
803                                    let mut to_extend = key_set.clone();
804                                    to_extend.retain(|c| c != &col);
805                                    for sub in subs {
806                                        to_extend.push(*sub);
807                                        to_add.push(to_extend.clone());
808                                        to_extend.pop();
809                                    }
810                                }
811                            }
812                        }
813                        // No deduplication, as we cannot introduce duplicates.
814                        input.append(&mut to_add);
815                    }
816                    for key_set in input.iter_mut() {
817                        key_set.sort();
818                        key_set.dedup();
819                    }
820                }
821                input
822            }
823            Join { equivalences, .. } => {
824                // It is important the `new_from_input_arities` constructor is
825                // used. Otherwise, Materialize may potentially end up in an
826                // infinite loop.
827                let input_mapper = crate::JoinInputMapper::new_from_input_arities(input_arities);
828
829                input_mapper.global_keys(input_keys, equivalences)
830            }
831            Reduce { group_key, .. } => {
832                // The group key should form a key, but we might already have
833                // keys that are subsets of the group key, and should retain
834                // those instead, if so.
835                let mut result = Vec::new();
836                for key_set in input_keys.next().unwrap() {
837                    if key_set
838                        .iter()
839                        .all(|k| group_key.contains(&MirScalarExpr::column(*k)))
840                    {
841                        result.push(
842                            key_set
843                                .iter()
844                                .map(|i| {
845                                    group_key
846                                        .iter()
847                                        .position(|k| k == &MirScalarExpr::column(*i))
848                                        .unwrap()
849                                })
850                                .collect::<Vec<_>>(),
851                        );
852                    }
853                }
854                if result.is_empty() {
855                    result.push((0..group_key.len()).collect());
856                }
857                result
858            }
859            TopK {
860                group_key, limit, ..
861            } => {
862                // If `limit` is `Some(1)` then the group key will become
863                // a unique key, as there will be only one record with that key.
864                let mut result = input_keys.next().unwrap().clone();
865                if limit.as_ref().and_then(|x| x.as_literal_int64()) == Some(1) {
866                    result.push(group_key.clone())
867                }
868                result
869            }
870            Union { base, inputs } => {
871                // Generally, unions do not have any unique keys, because
872                // each input might duplicate some. However, there is at
873                // least one idiomatic structure that does preserve keys,
874                // which results from SQL aggregations that must populate
875                // absent records with default values. In that pattern,
876                // the union of one GET with its negation, which has first
877                // been subjected to a projection and map, we can remove
878                // their influence on the key structure.
879                //
880                // If there are A, B, each with a unique `key` such that
881                // we are looking at
882                //
883                //     A.proj(set_containing_key) + (B - A.proj(key)).map(stuff)
884                //
885                // Then we can report `key` as a unique key.
886                //
887                // TODO: make unique key structure an optimization analysis
888                // rather than part of the type information.
889                // TODO: perhaps ensure that (above) A.proj(key) is a
890                // subset of B, as otherwise there are negative records
891                // and who knows what is true (not expected, but again
892                // who knows what the query plan might look like).
893
894                let arity = input_arities.next().unwrap();
895                let (base_projection, base_with_project_stripped) =
896                    if let MirRelationExpr::Project { input, outputs } = &**base {
897                        (outputs.clone(), &**input)
898                    } else {
899                        // A input without a project is equivalent to an input
900                        // with the project being all columns in the input in order.
901                        ((0..arity).collect::<Vec<_>>(), &**base)
902                    };
903                let mut result = Vec::new();
904                if let MirRelationExpr::Get {
905                    id: first_id,
906                    typ: _,
907                    ..
908                } = base_with_project_stripped
909                {
910                    if inputs.len() == 1 {
911                        if let MirRelationExpr::Map { input, .. } = &inputs[0] {
912                            if let MirRelationExpr::Union { base, inputs } = &**input {
913                                if inputs.len() == 1 {
914                                    if let Some((input, outputs)) = base.is_negated_project() {
915                                        if let MirRelationExpr::Get {
916                                            id: second_id,
917                                            typ: _,
918                                            ..
919                                        } = input
920                                        {
921                                            if first_id == second_id {
922                                                result.extend(
923                                                    input_keys
924                                                        .next()
925                                                        .unwrap()
926                                                        .into_iter()
927                                                        .filter(|key| {
928                                                            key.iter().all(|c| {
929                                                                outputs.get(*c) == Some(c)
930                                                                    && base_projection.get(*c)
931                                                                        == Some(c)
932                                                            })
933                                                        })
934                                                        .cloned(),
935                                                );
936                                            }
937                                        }
938                                    }
939                                }
940                            }
941                        }
942                    }
943                }
944                // Important: do not inherit keys of either input, as not unique.
945                result
946            }
947        };
948        keys.sort();
949        keys.dedup();
950        keys
951    }
952
953    /// The number of columns in the relation.
954    ///
955    /// This number is determined from the type, which is determined recursively
956    /// at non-trivial cost.
957    ///
958    /// The arity is computed incrementally with a recursive post-order
959    /// traversal, that accumulates the arities for the relations yet to be
960    /// visited in `arity_stack`.
961    pub fn arity(&self) -> usize {
962        let mut arity_stack = Vec::new();
963        #[allow(deprecated)]
964        self.visit_pre_post_nolimit(
965            &mut |e: &MirRelationExpr| -> Option<Vec<&MirRelationExpr>> {
966                match &e {
967                    MirRelationExpr::Let { body, .. } => {
968                        // Do not traverse the value sub-graph, since it's not relevant for
969                        // determining the arity of Let operators.
970                        Some(vec![&*body])
971                    }
972                    MirRelationExpr::LetRec { body, .. } => {
973                        // Do not traverse the value sub-graph, since it's not relevant for
974                        // determining the arity of Let operators.
975                        Some(vec![&*body])
976                    }
977                    MirRelationExpr::Project { .. } | MirRelationExpr::Reduce { .. } => {
978                        // No further traversal is required; these operators know their arity.
979                        Some(Vec::new())
980                    }
981                    _ => None,
982                }
983            },
984            &mut |e: &MirRelationExpr| {
985                match &e {
986                    MirRelationExpr::Let { .. } => {
987                        let body_arity = arity_stack.pop().unwrap();
988                        arity_stack.push(0);
989                        arity_stack.push(body_arity);
990                    }
991                    MirRelationExpr::LetRec { values, .. } => {
992                        let body_arity = arity_stack.pop().unwrap();
993                        arity_stack.extend(std::iter::repeat(0).take(values.len()));
994                        arity_stack.push(body_arity);
995                    }
996                    MirRelationExpr::Project { .. } | MirRelationExpr::Reduce { .. } => {
997                        arity_stack.push(0);
998                    }
999                    _ => {}
1000                }
1001                let num_inputs = e.num_inputs();
1002                let input_arities = arity_stack.drain(arity_stack.len() - num_inputs..);
1003                let arity = e.arity_with_input_arities(input_arities);
1004                arity_stack.push(arity);
1005            },
1006        );
1007        assert_eq!(arity_stack.len(), 1);
1008        arity_stack.pop().unwrap()
1009    }
1010
1011    /// Reports the arity of the relation given the schema of the input relations.
1012    ///
1013    /// `input_arities` is required to contain the arities for the input relations of
1014    /// the current relation in the same order as they are visited by `try_visit_children`
1015    /// method, even though not all may be used for computing the schema of the
1016    /// current relation. For example, `Let` expects two input types, one for the
1017    /// value relation and one for the body, in that order, but only the one for the
1018    /// body is used to determine the type of the `Let` relation.
1019    ///
1020    /// It is meant to be used during post-order traversals to compute arities
1021    /// incrementally.
1022    pub fn arity_with_input_arities<I>(&self, mut input_arities: I) -> usize
1023    where
1024        I: Iterator<Item = usize>,
1025    {
1026        use MirRelationExpr::*;
1027
1028        match self {
1029            Constant { rows: _, typ } => typ.arity(),
1030            Get { typ, .. } => typ.arity(),
1031            Let { .. } => {
1032                input_arities.next();
1033                input_arities.next().unwrap()
1034            }
1035            LetRec { values, .. } => {
1036                for _ in 0..values.len() {
1037                    input_arities.next();
1038                }
1039                input_arities.next().unwrap()
1040            }
1041            Project { outputs, .. } => outputs.len(),
1042            Map { scalars, .. } => input_arities.next().unwrap() + scalars.len(),
1043            FlatMap { func, .. } => input_arities.next().unwrap() + func.output_arity(),
1044            Join { .. } => input_arities.sum(),
1045            Reduce {
1046                input: _,
1047                group_key,
1048                aggregates,
1049                ..
1050            } => group_key.len() + aggregates.len(),
1051            Filter { .. }
1052            | TopK { .. }
1053            | Negate { .. }
1054            | Threshold { .. }
1055            | Union { .. }
1056            | ArrangeBy { .. } => input_arities.next().unwrap(),
1057        }
1058    }
1059
1060    /// The number of child relations this relation has.
1061    pub fn num_inputs(&self) -> usize {
1062        let mut count = 0;
1063
1064        self.visit_children(|_| count += 1);
1065
1066        count
1067    }
1068
1069    /// Constructs a constant collection from specific rows and schema, where
1070    /// each row will have a multiplicity of one.
1071    pub fn constant(rows: Vec<Vec<Datum>>, typ: ReprRelationType) -> Self {
1072        let rows = rows.into_iter().map(|row| (row, Diff::ONE)).collect();
1073        MirRelationExpr::constant_diff(rows, typ)
1074    }
1075
1076    /// Constructs a constant collection from specific rows and schema, where
1077    /// each row can have an arbitrary multiplicity.
1078    pub fn constant_diff(rows: Vec<(Vec<Datum>, Diff)>, typ: ReprRelationType) -> Self {
1079        for (row, _diff) in &rows {
1080            for (datum, column_typ) in row.iter().zip_eq(typ.column_types.iter()) {
1081                assert!(
1082                    datum.is_instance_of(column_typ),
1083                    "Expected datum of type {:?}, got value {:?}",
1084                    column_typ,
1085                    datum
1086                );
1087            }
1088        }
1089        let rows = Ok(rows
1090            .into_iter()
1091            .map(move |(row, diff)| (Row::pack_slice(&row), diff))
1092            .collect());
1093        MirRelationExpr::Constant { rows, typ }
1094    }
1095
1096    /// If self is a constant, return the value and the type, otherwise `None`.
1097    /// Looks behind `ArrangeBy`s.
1098    pub fn as_const(&self) -> Option<(&Result<Vec<(Row, Diff)>, EvalError>, &ReprRelationType)> {
1099        match self {
1100            MirRelationExpr::Constant { rows, typ } => Some((rows, typ)),
1101            MirRelationExpr::ArrangeBy { input, .. } => input.as_const(),
1102            _ => None,
1103        }
1104    }
1105
1106    /// If self is a constant, mutably return the value and the type, otherwise `None`.
1107    /// Looks behind `ArrangeBy`s.
1108    pub fn as_const_mut(
1109        &mut self,
1110    ) -> Option<(
1111        &mut Result<Vec<(Row, Diff)>, EvalError>,
1112        &mut ReprRelationType,
1113    )> {
1114        match self {
1115            MirRelationExpr::Constant { rows, typ } => Some((rows, typ)),
1116            MirRelationExpr::ArrangeBy { input, .. } => input.as_const_mut(),
1117            _ => None,
1118        }
1119    }
1120
1121    /// If self is a constant error, return the error, otherwise `None`.
1122    /// Looks behind `ArrangeBy`s.
1123    pub fn as_const_err(&self) -> Option<&EvalError> {
1124        match self {
1125            MirRelationExpr::Constant { rows: Err(e), .. } => Some(e),
1126            MirRelationExpr::ArrangeBy { input, .. } => input.as_const_err(),
1127            _ => None,
1128        }
1129    }
1130
1131    /// Checks if `self` is the single element collection with no columns.
1132    pub fn is_constant_singleton(&self) -> bool {
1133        if let Some((Ok(rows), typ)) = self.as_const() {
1134            rows.len() == 1 && typ.column_types.len() == 0 && rows[0].1 == Diff::ONE
1135        } else {
1136            false
1137        }
1138    }
1139
1140    /// Constructs the expression for getting a local collection.
1141    pub fn local_get(id: LocalId, typ: ReprRelationType) -> Self {
1142        MirRelationExpr::Get {
1143            id: Id::Local(id),
1144            typ,
1145            access_strategy: AccessStrategy::UnknownOrLocal,
1146        }
1147    }
1148
1149    /// Constructs the expression for getting a global collection
1150    pub fn global_get(id: GlobalId, typ: ReprRelationType) -> Self {
1151        MirRelationExpr::Get {
1152            id: Id::Global(id),
1153            typ,
1154            access_strategy: AccessStrategy::UnknownOrLocal,
1155        }
1156    }
1157
1158    /// Retains only the columns specified by `output`.
1159    pub fn project(mut self, mut outputs: Vec<usize>) -> Self {
1160        if let MirRelationExpr::Project {
1161            outputs: columns, ..
1162        } = &mut self
1163        {
1164            // Update `outputs` to reference base columns of `input`.
1165            for column in outputs.iter_mut() {
1166                *column = columns[*column];
1167            }
1168            *columns = outputs;
1169            self
1170        } else {
1171            MirRelationExpr::Project {
1172                input: Box::new(self),
1173                outputs,
1174            }
1175        }
1176    }
1177
1178    /// Append to each row the results of applying elements of `scalar`.
1179    pub fn map(mut self, scalars: Vec<MirScalarExpr>) -> Self {
1180        if let MirRelationExpr::Map { scalars: s, .. } = &mut self {
1181            s.extend(scalars);
1182            self
1183        } else if !scalars.is_empty() {
1184            MirRelationExpr::Map {
1185                input: Box::new(self),
1186                scalars,
1187            }
1188        } else {
1189            self
1190        }
1191    }
1192
1193    /// Append to each row a single `scalar`.
1194    pub fn map_one(self, scalar: MirScalarExpr) -> Self {
1195        self.map(vec![scalar])
1196    }
1197
1198    /// Like `map`, but yields zero-or-more output rows per input row
1199    pub fn flat_map(self, func: TableFunc, exprs: Vec<MirScalarExpr>) -> Self {
1200        MirRelationExpr::FlatMap {
1201            input: Box::new(self),
1202            func,
1203            exprs,
1204        }
1205    }
1206
1207    /// Retain only the rows satisfying each of several predicates.
1208    pub fn filter<I>(mut self, predicates: I) -> Self
1209    where
1210        I: IntoIterator<Item = MirScalarExpr>,
1211    {
1212        // Extract existing predicates
1213        let mut new_predicates = if let MirRelationExpr::Filter { input, predicates } = self {
1214            self = *input;
1215            predicates
1216        } else {
1217            Vec::new()
1218        };
1219        // Normalize collection of predicates.
1220        new_predicates.extend(predicates);
1221        new_predicates.retain(|p| !p.is_literal_true());
1222        new_predicates.sort();
1223        new_predicates.dedup();
1224        // Introduce a `Filter` only if we have predicates.
1225        if !new_predicates.is_empty() {
1226            self = MirRelationExpr::Filter {
1227                input: Box::new(self),
1228                predicates: new_predicates,
1229            };
1230        }
1231
1232        self
1233    }
1234
1235    /// Form the Cartesian outer-product of rows in both inputs.
1236    pub fn product(mut self, right: Self) -> Self {
1237        if right.is_constant_singleton() {
1238            self
1239        } else if self.is_constant_singleton() {
1240            right
1241        } else if let MirRelationExpr::Join { inputs, .. } = &mut self {
1242            inputs.push(right);
1243            self
1244        } else {
1245            MirRelationExpr::join(vec![self, right], vec![])
1246        }
1247    }
1248
1249    /// Performs a relational equijoin among the input collections.
1250    ///
1251    /// The sequence `inputs` each describe different input collections, and the sequence `variables` describes
1252    /// equality constraints that some of their columns must satisfy. Each element in `variable` describes a set
1253    /// of pairs  `(input_index, column_index)` where every value described by that set must be equal.
1254    ///
1255    /// For example, the pair `(input, column)` indexes into `inputs[input][column]`, extracting the `input`th
1256    /// input collection and for each row examining its `column`th column.
1257    ///
1258    /// # Example
1259    ///
1260    /// ```rust
1261    /// use mz_repr::{Datum, SqlColumnType, ReprRelationType, ReprScalarType};
1262    /// use mz_expr::MirRelationExpr;
1263    ///
1264    /// // A common schema for each input.
1265    /// let schema = ReprRelationType::new(vec![
1266    ///     ReprScalarType::Int32.nullable(false),
1267    ///     ReprScalarType::Int32.nullable(false),
1268    /// ]);
1269    ///
1270    /// // the specific data are not important here.
1271    /// let data = vec![Datum::Int32(0), Datum::Int32(1)];
1272    ///
1273    /// // Three collections that could have been different.
1274    /// let input0 = MirRelationExpr::constant(vec![data.clone()], schema.clone());
1275    /// let input1 = MirRelationExpr::constant(vec![data.clone()], schema.clone());
1276    /// let input2 = MirRelationExpr::constant(vec![data.clone()], schema.clone());
1277    ///
1278    /// // Join the three relations looking for triangles, like so.
1279    /// //
1280    /// //     Output(A,B,C) := Input0(A,B), Input1(B,C), Input2(A,C)
1281    /// let joined = MirRelationExpr::join(
1282    ///     vec![input0, input1, input2],
1283    ///     vec![
1284    ///         vec![(0,0), (2,0)], // fields A of inputs 0 and 2.
1285    ///         vec![(0,1), (1,0)], // fields B of inputs 0 and 1.
1286    ///         vec![(1,1), (2,1)], // fields C of inputs 1 and 2.
1287    ///     ],
1288    /// );
1289    ///
1290    /// // Technically the above produces `Output(A,B,B,C,A,C)` because the columns are concatenated.
1291    /// // A projection resolves this and produces the correct output.
1292    /// let result = joined.project(vec![0, 1, 3]);
1293    /// ```
1294    pub fn join(inputs: Vec<MirRelationExpr>, variables: Vec<Vec<(usize, usize)>>) -> Self {
1295        let input_mapper = join_input_mapper::JoinInputMapper::new(&inputs);
1296
1297        let equivalences = variables
1298            .into_iter()
1299            .map(|vs| {
1300                vs.into_iter()
1301                    .map(|(r, c)| input_mapper.map_expr_to_global(MirScalarExpr::column(c), r))
1302                    .collect::<Vec<_>>()
1303            })
1304            .collect::<Vec<_>>();
1305
1306        Self::join_scalars(inputs, equivalences)
1307    }
1308
1309    /// Constructs a join operator from inputs and required-equal scalar expressions.
1310    pub fn join_scalars(
1311        mut inputs: Vec<MirRelationExpr>,
1312        equivalences: Vec<Vec<MirScalarExpr>>,
1313    ) -> Self {
1314        // Remove all constant inputs that are the identity for join.
1315        // They neither introduce nor modify any column references.
1316        inputs.retain(|i| !i.is_constant_singleton());
1317        MirRelationExpr::Join {
1318            inputs,
1319            equivalences,
1320            implementation: JoinImplementation::Unimplemented,
1321        }
1322    }
1323
1324    /// Perform a key-wise reduction / aggregation.
1325    ///
1326    /// The `group_key` argument indicates columns in the input collection that should
1327    /// be grouped, and `aggregates` lists aggregation functions each of which produces
1328    /// one output column in addition to the keys.
1329    pub fn reduce(
1330        self,
1331        group_key: Vec<usize>,
1332        aggregates: Vec<AggregateExpr>,
1333        expected_group_size: Option<u64>,
1334    ) -> Self {
1335        MirRelationExpr::Reduce {
1336            input: Box::new(self),
1337            group_key: group_key.into_iter().map(MirScalarExpr::column).collect(),
1338            aggregates,
1339            monotonic: false,
1340            expected_group_size,
1341        }
1342    }
1343
1344    /// Perform a key-wise reduction order by and limit.
1345    ///
1346    /// The `group_key` argument indicates columns in the input collection that should
1347    /// be grouped, the `order_key` argument indicates columns that should be further
1348    /// used to order records within groups, and the `limit` argument constrains the
1349    /// total number of records that should be produced in each group.
1350    pub fn top_k(
1351        self,
1352        group_key: Vec<usize>,
1353        order_key: Vec<ColumnOrder>,
1354        limit: Option<MirScalarExpr>,
1355        offset: usize,
1356        expected_group_size: Option<u64>,
1357    ) -> Self {
1358        MirRelationExpr::TopK {
1359            input: Box::new(self),
1360            group_key,
1361            order_key,
1362            limit,
1363            offset,
1364            expected_group_size,
1365            monotonic: false,
1366        }
1367    }
1368
1369    /// Negates the occurrences of each row.
1370    pub fn negate(self) -> Self {
1371        if let MirRelationExpr::Negate { input } = self {
1372            *input
1373        } else {
1374            MirRelationExpr::Negate {
1375                input: Box::new(self),
1376            }
1377        }
1378    }
1379
1380    /// Removes all but the first occurrence of each row.
1381    pub fn distinct(self) -> Self {
1382        let arity = self.arity();
1383        self.distinct_by((0..arity).collect())
1384    }
1385
1386    /// Removes all but the first occurrence of each key. Columns not included
1387    /// in the `group_key` are discarded.
1388    pub fn distinct_by(self, group_key: Vec<usize>) -> Self {
1389        self.reduce(group_key, vec![], None)
1390    }
1391
1392    /// Discards rows with a negative frequency.
1393    pub fn threshold(self) -> Self {
1394        if let MirRelationExpr::Threshold { .. } = &self {
1395            self
1396        } else {
1397            MirRelationExpr::Threshold {
1398                input: Box::new(self),
1399            }
1400        }
1401    }
1402
1403    /// Unions together any number inputs.
1404    ///
1405    /// If `inputs` is empty, then an empty relation of type `typ` is
1406    /// constructed.
1407    pub fn union_many(mut inputs: Vec<Self>, typ: ReprRelationType) -> Self {
1408        // Deconstruct `inputs` as `Union`s and reconstitute.
1409        let mut flat_inputs = Vec::with_capacity(inputs.len());
1410        for input in inputs {
1411            if let MirRelationExpr::Union { base, inputs } = input {
1412                flat_inputs.push(*base);
1413                flat_inputs.extend(inputs);
1414            } else {
1415                flat_inputs.push(input);
1416            }
1417        }
1418        inputs = flat_inputs;
1419        if inputs.len() == 0 {
1420            MirRelationExpr::Constant {
1421                rows: Ok(vec![]),
1422                typ,
1423            }
1424        } else if inputs.len() == 1 {
1425            inputs.into_element()
1426        } else {
1427            MirRelationExpr::Union {
1428                base: Box::new(inputs.remove(0)),
1429                inputs,
1430            }
1431        }
1432    }
1433
1434    /// Produces one collection where each row is present with the sum of its frequencies in each input.
1435    pub fn union(self, other: Self) -> Self {
1436        // Deconstruct `self` and `other` as `Union`s and reconstitute.
1437        let mut flat_inputs = Vec::with_capacity(2);
1438        if let MirRelationExpr::Union { base, inputs } = self {
1439            flat_inputs.push(*base);
1440            flat_inputs.extend(inputs);
1441        } else {
1442            flat_inputs.push(self);
1443        }
1444        if let MirRelationExpr::Union { base, inputs } = other {
1445            flat_inputs.push(*base);
1446            flat_inputs.extend(inputs);
1447        } else {
1448            flat_inputs.push(other);
1449        }
1450
1451        MirRelationExpr::Union {
1452            base: Box::new(flat_inputs.remove(0)),
1453            inputs: flat_inputs,
1454        }
1455    }
1456
1457    /// Arranges the collection by the specified columns
1458    pub fn arrange_by(self, keys: &[Vec<MirScalarExpr>]) -> Self {
1459        MirRelationExpr::ArrangeBy {
1460            input: Box::new(self),
1461            keys: keys.to_owned(),
1462        }
1463    }
1464
1465    /// Indicates if this is a constant empty collection.
1466    ///
1467    /// A false value does not mean the collection is known to be non-empty,
1468    /// only that we cannot currently determine that it is statically empty.
1469    pub fn is_empty(&self) -> bool {
1470        if let Some((Ok(rows), ..)) = self.as_const() {
1471            rows.is_empty()
1472        } else {
1473            false
1474        }
1475    }
1476
1477    /// If the expression is a negated project, return the input and the projection.
1478    pub fn is_negated_project(&self) -> Option<(&MirRelationExpr, &[usize])> {
1479        if let MirRelationExpr::Negate { input } = self {
1480            if let MirRelationExpr::Project { input, outputs } = &**input {
1481                return Some((&**input, outputs));
1482            }
1483        }
1484        if let MirRelationExpr::Project { input, outputs } = self {
1485            if let MirRelationExpr::Negate { input } = &**input {
1486                return Some((&**input, outputs));
1487            }
1488        }
1489        None
1490    }
1491
1492    /// Pretty-print this [MirRelationExpr] to a string.
1493    pub fn pretty(&self) -> String {
1494        let config = ExplainConfig::default();
1495        self.debug_explain(&config, None)
1496    }
1497
1498    /// Pretty-print this [MirRelationExpr] to a string using a custom
1499    /// [ExplainConfig] and an optionally provided [ExprHumanizer].
1500    /// This is intended for debugging and tests, not users.
1501    pub fn debug_explain(
1502        &self,
1503        config: &ExplainConfig,
1504        humanizer: Option<&dyn ExprHumanizer>,
1505    ) -> String {
1506        text_string_at(self, || PlanRenderingContext {
1507            indent: Indent::default(),
1508            humanizer: humanizer.unwrap_or(&DummyHumanizer),
1509            annotations: BTreeMap::default(),
1510            config,
1511            ambiguous_ids: BTreeSet::default(),
1512        })
1513    }
1514
1515    /// Take ownership of `self`, leaving an empty `MirRelationExpr::Constant` with the optionally
1516    /// given scalar types. The given scalar types should be `base_eq` with the types that `typ()`
1517    /// would find. Keys and nullability are ignored in the given `SqlRelationType`, and instead we set
1518    /// the best possible key and nullability, since we are making an empty collection.
1519    ///
1520    /// If `typ` is not given, then this calls `.typ()` (which is possibly expensive) to determine
1521    /// the correct type.
1522    pub fn take_safely(&mut self, typ: Option<ReprRelationType>) -> MirRelationExpr {
1523        if let Some(typ) = &typ {
1524            let self_typ = self.typ();
1525            soft_assert_no_log!(
1526                self_typ
1527                    .column_types
1528                    .iter()
1529                    .zip_eq(typ.column_types.iter())
1530                    .all(|(t1, t2)| t1.scalar_type == t2.scalar_type)
1531            );
1532        }
1533        let mut typ = typ.unwrap_or_else(|| self.typ());
1534        typ.keys = vec![vec![]];
1535        for ct in typ.column_types.iter_mut() {
1536            ct.nullable = false;
1537        }
1538        std::mem::replace(
1539            self,
1540            MirRelationExpr::Constant {
1541                rows: Ok(vec![]),
1542                typ,
1543            },
1544        )
1545    }
1546
1547    /// Take ownership of `self`, leaving an empty `MirRelationExpr::Constant` with the given scalar
1548    /// types. Nullability is ignored in the given `SqlColumnType`s, and instead we set the best
1549    /// possible nullability, since we are making an empty collection.
1550    pub fn take_safely_with_sql_col_types(&mut self, typ: Vec<SqlColumnType>) -> MirRelationExpr {
1551        self.take_safely(Some(ReprRelationType::from(&SqlRelationType::new(typ))))
1552    }
1553
1554    /// Like [`Self::take_safely_with_col_types`], but accepts `Vec<ReprColumnType>`.
1555    ///
1556    /// This is the preferred entry point for optimizer transforms, where repr
1557    /// types are the native currency. Internally converts to [`SqlColumnType`]
1558    /// and delegates to [`Self::take_safely_with_col_types`].
1559    pub fn take_safely_with_col_types(&mut self, typ: Vec<ReprColumnType>) -> MirRelationExpr {
1560        self.take_safely(Some(ReprRelationType::new(typ)))
1561    }
1562
1563    /// Take ownership of `self`, leaving an empty `MirRelationExpr::Constant` with an **incorrect** type.
1564    ///
1565    /// This should only be used if `self` is about to be dropped or otherwise overwritten.
1566    pub fn take_dangerous(&mut self) -> MirRelationExpr {
1567        let empty = MirRelationExpr::Constant {
1568            rows: Ok(vec![]),
1569            typ: ReprRelationType::new(Vec::new()),
1570        };
1571        std::mem::replace(self, empty)
1572    }
1573
1574    /// Replaces `self` with some logic applied to `self`.
1575    pub fn replace_using<F>(&mut self, logic: F)
1576    where
1577        F: FnOnce(MirRelationExpr) -> MirRelationExpr,
1578    {
1579        let empty = MirRelationExpr::Constant {
1580            rows: Ok(vec![]),
1581            typ: ReprRelationType::new(Vec::new()),
1582        };
1583        let expr = std::mem::replace(self, empty);
1584        *self = logic(expr);
1585    }
1586
1587    /// Store `self` in a `Let` and pass the corresponding `Get` to `body`.
1588    pub fn let_in<Body, E>(self, id_gen: &mut IdGen, body: Body) -> Result<MirRelationExpr, E>
1589    where
1590        Body: FnOnce(&mut IdGen, MirRelationExpr) -> Result<MirRelationExpr, E>,
1591    {
1592        if let MirRelationExpr::Get { .. } = self {
1593            // already done
1594            body(id_gen, self)
1595        } else {
1596            let id = LocalId::new(id_gen.allocate_id());
1597            let get = MirRelationExpr::Get {
1598                id: Id::Local(id),
1599                typ: self.typ(),
1600                access_strategy: AccessStrategy::UnknownOrLocal,
1601            };
1602            let body = (body)(id_gen, get)?;
1603            Ok(MirRelationExpr::Let {
1604                id,
1605                value: Box::new(self),
1606                body: Box::new(body),
1607            })
1608        }
1609    }
1610
1611    /// Return every row in `self` that does not have a matching row in the first columns of `keys_and_values`, using `default` to fill in the remaining columns
1612    /// (If `default` is a row of nulls, this is the 'outer' part of LEFT OUTER JOIN)
1613    pub fn anti_lookup<E>(
1614        self,
1615        id_gen: &mut IdGen,
1616        keys_and_values: MirRelationExpr,
1617        default: Vec<(Datum, ReprScalarType)>,
1618    ) -> Result<MirRelationExpr, E> {
1619        let (data, column_types): (Vec<_>, Vec<_>) = default
1620            .into_iter()
1621            .map(|(datum, scalar_type)| {
1622                (
1623                    datum,
1624                    ReprColumnType {
1625                        scalar_type,
1626                        nullable: datum.is_null(),
1627                    },
1628                )
1629            })
1630            .unzip();
1631        assert_eq!(keys_and_values.arity() - self.arity(), data.len());
1632        self.let_in(id_gen, |_id_gen, get_keys| {
1633            let get_keys_arity = get_keys.arity();
1634            Ok(MirRelationExpr::join(
1635                vec![
1636                    // all the missing keys (with count 1)
1637                    keys_and_values
1638                        .distinct_by((0..get_keys_arity).collect())
1639                        .negate()
1640                        .union(get_keys.clone().distinct()),
1641                    // join with keys to get the correct counts
1642                    get_keys.clone(),
1643                ],
1644                (0..get_keys_arity).map(|i| vec![(0, i), (1, i)]).collect(),
1645            )
1646            // get rid of the extra copies of columns from keys
1647            .project((0..get_keys_arity).collect())
1648            // This join is logically equivalent to
1649            // `.map(<default_expr>)`, but using a join allows for
1650            // potential predicate pushdown and elision in the
1651            // optimizer.
1652            .product(MirRelationExpr::constant(
1653                vec![data],
1654                ReprRelationType::new(column_types),
1655            )))
1656        })
1657    }
1658
1659    /// Return:
1660    /// * every row in keys_and_values
1661    /// * every row in `self` that does not have a matching row in the first columns of
1662    ///   `keys_and_values`, using `default` to fill in the remaining columns
1663    /// (This is LEFT OUTER JOIN if:
1664    /// 1) `default` is a row of null
1665    /// 2) matching rows in `keys_and_values` and `self` have the same multiplicity.)
1666    pub fn lookup<E>(
1667        self,
1668        id_gen: &mut IdGen,
1669        keys_and_values: MirRelationExpr,
1670        default: Vec<(Datum<'static>, ReprScalarType)>,
1671    ) -> Result<MirRelationExpr, E> {
1672        keys_and_values.let_in(id_gen, |id_gen, get_keys_and_values| {
1673            Ok(get_keys_and_values.clone().union(self.anti_lookup(
1674                id_gen,
1675                get_keys_and_values,
1676                default,
1677            )?))
1678        })
1679    }
1680
1681    /// True iff the expression contains a `NullaryFunc::MzLogicalTimestamp`.
1682    pub fn contains_temporal(&self) -> bool {
1683        let mut contains = false;
1684        self.visit_scalars(&mut |e| contains = contains || e.contains_temporal());
1685        contains
1686    }
1687
1688    /// Fallible visitor for the [`MirScalarExpr`]s directly owned by this relation expression.
1689    ///
1690    /// The `f` visitor should not recursively descend into owned [`MirRelationExpr`]s.
1691    pub fn try_visit_scalars_mut1<F, E>(&mut self, f: &mut F) -> Result<(), E>
1692    where
1693        F: FnMut(&mut MirScalarExpr) -> Result<(), E>,
1694    {
1695        use MirRelationExpr::*;
1696        match self {
1697            Map { scalars, .. } => {
1698                for s in scalars {
1699                    f(s)?;
1700                }
1701            }
1702            Filter { predicates, .. } => {
1703                for p in predicates {
1704                    f(p)?;
1705                }
1706            }
1707            FlatMap { exprs, .. } => {
1708                for expr in exprs {
1709                    f(expr)?;
1710                }
1711            }
1712            Join {
1713                inputs: _,
1714                equivalences,
1715                implementation,
1716            } => {
1717                for equivalence in equivalences {
1718                    for expr in equivalence {
1719                        f(expr)?;
1720                    }
1721                }
1722                match implementation {
1723                    JoinImplementation::Differential((_, start_key, _), order) => {
1724                        if let Some(start_key) = start_key {
1725                            for k in start_key {
1726                                f(k)?;
1727                            }
1728                        }
1729                        for (_, lookup_key, _) in order {
1730                            for k in lookup_key {
1731                                f(k)?;
1732                            }
1733                        }
1734                    }
1735                    JoinImplementation::DeltaQuery(paths) => {
1736                        for path in paths {
1737                            for (_, lookup_key, _) in path {
1738                                for k in lookup_key {
1739                                    f(k)?;
1740                                }
1741                            }
1742                        }
1743                    }
1744                    JoinImplementation::IndexedFilter(_coll_id, _idx_id, index_key, _) => {
1745                        for k in index_key {
1746                            f(k)?;
1747                        }
1748                    }
1749                    JoinImplementation::Unimplemented => {} // No scalar exprs
1750                }
1751            }
1752            ArrangeBy { keys, .. } => {
1753                for key in keys {
1754                    for s in key {
1755                        f(s)?;
1756                    }
1757                }
1758            }
1759            Reduce {
1760                group_key,
1761                aggregates,
1762                ..
1763            } => {
1764                for s in group_key {
1765                    f(s)?;
1766                }
1767                for agg in aggregates {
1768                    f(&mut agg.expr)?;
1769                }
1770            }
1771            TopK { limit, .. } => {
1772                if let Some(s) = limit {
1773                    f(s)?;
1774                }
1775            }
1776            Constant { .. }
1777            | Get { .. }
1778            | Let { .. }
1779            | LetRec { .. }
1780            | Project { .. }
1781            | Negate { .. }
1782            | Threshold { .. }
1783            | Union { .. } => (),
1784        }
1785        Ok(())
1786    }
1787
1788    /// Fallible mutable visitor for the [`MirScalarExpr`]s in the [`MirRelationExpr`] subtree
1789    /// rooted at `self`.
1790    ///
1791    /// Note that this does not recurse into [`MirRelationExpr`] subtrees within [`MirScalarExpr`]
1792    /// nodes.
1793    pub fn try_visit_scalars_mut<F, E>(&mut self, f: &mut F) -> Result<(), E>
1794    where
1795        F: FnMut(&mut MirScalarExpr) -> Result<(), E>,
1796        E: From<RecursionLimitError>,
1797    {
1798        self.try_visit_mut_post(&mut |expr| expr.try_visit_scalars_mut1(f))
1799    }
1800
1801    /// Infallible mutable visitor for the [`MirScalarExpr`]s in the [`MirRelationExpr`] subtree
1802    /// rooted at `self`.
1803    ///
1804    /// Note that this does not recurse into [`MirRelationExpr`] subtrees within [`MirScalarExpr`]
1805    /// nodes.
1806    pub fn visit_scalars_mut<F>(&mut self, f: &mut F)
1807    where
1808        F: FnMut(&mut MirScalarExpr),
1809    {
1810        self.try_visit_scalars_mut(&mut |s| {
1811            f(s);
1812            Ok::<_, RecursionLimitError>(())
1813        })
1814        .expect("Unexpected error in `visit_scalars_mut` call");
1815    }
1816
1817    /// Fallible visitor for the [`MirScalarExpr`]s directly owned by this relation expression.
1818    ///
1819    /// The `f` visitor should not recursively descend into owned [`MirRelationExpr`]s.
1820    pub fn try_visit_scalars_1<F, E>(&self, f: &mut F) -> Result<(), E>
1821    where
1822        F: FnMut(&MirScalarExpr) -> Result<(), E>,
1823    {
1824        use MirRelationExpr::*;
1825        match self {
1826            Map { scalars, .. } => {
1827                for s in scalars {
1828                    f(s)?;
1829                }
1830            }
1831            Filter { predicates, .. } => {
1832                for p in predicates {
1833                    f(p)?;
1834                }
1835            }
1836            FlatMap { exprs, .. } => {
1837                for expr in exprs {
1838                    f(expr)?;
1839                }
1840            }
1841            Join {
1842                inputs: _,
1843                equivalences,
1844                implementation,
1845            } => {
1846                for equivalence in equivalences {
1847                    for expr in equivalence {
1848                        f(expr)?;
1849                    }
1850                }
1851                match implementation {
1852                    JoinImplementation::Differential((_, start_key, _), order) => {
1853                        if let Some(start_key) = start_key {
1854                            for k in start_key {
1855                                f(k)?;
1856                            }
1857                        }
1858                        for (_, lookup_key, _) in order {
1859                            for k in lookup_key {
1860                                f(k)?;
1861                            }
1862                        }
1863                    }
1864                    JoinImplementation::DeltaQuery(paths) => {
1865                        for path in paths {
1866                            for (_, lookup_key, _) in path {
1867                                for k in lookup_key {
1868                                    f(k)?;
1869                                }
1870                            }
1871                        }
1872                    }
1873                    JoinImplementation::IndexedFilter(_coll_id, _idx_id, index_key, _) => {
1874                        for k in index_key {
1875                            f(k)?;
1876                        }
1877                    }
1878                    JoinImplementation::Unimplemented => {} // No scalar exprs
1879                }
1880            }
1881            ArrangeBy { keys, .. } => {
1882                for key in keys {
1883                    for s in key {
1884                        f(s)?;
1885                    }
1886                }
1887            }
1888            Reduce {
1889                group_key,
1890                aggregates,
1891                ..
1892            } => {
1893                for s in group_key {
1894                    f(s)?;
1895                }
1896                for agg in aggregates {
1897                    f(&agg.expr)?;
1898                }
1899            }
1900            TopK { limit, .. } => {
1901                if let Some(s) = limit {
1902                    f(s)?;
1903                }
1904            }
1905            Constant { .. }
1906            | Get { .. }
1907            | Let { .. }
1908            | LetRec { .. }
1909            | Project { .. }
1910            | Negate { .. }
1911            | Threshold { .. }
1912            | Union { .. } => (),
1913        }
1914        Ok(())
1915    }
1916
1917    /// Fallible immutable visitor for the [`MirScalarExpr`]s in the [`MirRelationExpr`] subtree
1918    /// rooted at `self`.
1919    ///
1920    /// Note that this does not recurse into [`MirRelationExpr`] subtrees within [`MirScalarExpr`]
1921    /// nodes.
1922    pub fn try_visit_scalars<F, E>(&self, f: &mut F) -> Result<(), E>
1923    where
1924        F: FnMut(&MirScalarExpr) -> Result<(), E>,
1925        E: From<RecursionLimitError>,
1926    {
1927        self.try_visit_post(&mut |expr| expr.try_visit_scalars_1(f))
1928    }
1929
1930    /// Infallible immutable visitor for the [`MirScalarExpr`]s in the [`MirRelationExpr`] subtree
1931    /// rooted at `self`.
1932    ///
1933    /// Note that this does not recurse into [`MirRelationExpr`] subtrees within [`MirScalarExpr`]
1934    /// nodes.
1935    pub fn visit_scalars<F>(&self, f: &mut F)
1936    where
1937        F: FnMut(&MirScalarExpr),
1938    {
1939        self.try_visit_scalars(&mut |s| {
1940            f(s);
1941            Ok::<_, RecursionLimitError>(())
1942        })
1943        .expect("Unexpected error in `visit_scalars` call");
1944    }
1945
1946    /// Clears the contents of `self` even if it's so deep that simply dropping it would cause a
1947    /// stack overflow in `drop_in_place`.
1948    ///
1949    /// Leaves `self` in an unusable state, so this should only be used if `self` is about to be
1950    /// dropped or otherwise overwritten.
1951    pub fn destroy_carefully(&mut self) {
1952        let mut todo = vec![self.take_dangerous()];
1953        while let Some(mut expr) = todo.pop() {
1954            for child in expr.children_mut() {
1955                todo.push(child.take_dangerous());
1956            }
1957        }
1958    }
1959
1960    /// Computes the size (total number of nodes) and maximum depth of a MirRelationExpr for
1961    /// debug printing purposes.
1962    pub fn debug_size_and_depth(&self) -> (usize, usize) {
1963        let mut size = 0;
1964        let mut max_depth = 0;
1965        let mut todo = vec![(self, 1)];
1966        while let Some((expr, depth)) = todo.pop() {
1967            size += 1;
1968            max_depth = max(max_depth, depth);
1969            todo.extend(expr.children().map(|c| (c, depth + 1)));
1970        }
1971        (size, max_depth)
1972    }
1973
1974    /// The MirRelationExpr is considered potentially expensive if and only if
1975    /// at least one of the following conditions is true:
1976    ///
1977    ///  - It contains at least one MirScalarExpr with a function call.
1978    ///  - It contains at least one FlatMap or a Reduce operator.
1979    ///  - We run into a RecursionLimitError while analyzing the expression.
1980    ///
1981    /// !!!WARNING!!!: this method has an HirRelationExpr counterpart. The two
1982    /// should be kept in sync w.r.t. HIR ⇒ MIR lowering!
1983    pub fn could_run_expensive_function(&self) -> bool {
1984        let mut result = false;
1985        use MirRelationExpr::*;
1986        use MirScalarExpr::*;
1987        if let Err(_) = self.try_visit_scalars::<_, RecursionLimitError>(&mut |scalar| {
1988            result |= match scalar {
1989                Column(_, _) | Literal(_, _) | CallUnmaterializable(_) | If { .. } => false,
1990                // Function calls are considered expensive
1991                CallUnary { .. } | CallBinary { .. } | CallVariadic { .. } => true,
1992            };
1993            Ok(())
1994        }) {
1995            // Conservatively set `true` if on RecursionLimitError.
1996            result = true;
1997        }
1998        self.visit_pre(|e: &MirRelationExpr| {
1999            // FlatMap has a table function; Reduce has an aggregate function.
2000            // Other constructs use MirScalarExpr to run a function
2001            result |= matches!(e, FlatMap { .. } | Reduce { .. });
2002        });
2003        result
2004    }
2005
2006    /// Hash to an u64 using Rust's default Hasher. (Which is a somewhat slower, but better Hasher
2007    /// than what `Hashable::hashed` would give us.)
2008    pub fn hash_to_u64(&self) -> u64 {
2009        let mut h = DefaultHasher::new();
2010        self.hash(&mut h);
2011        h.finish()
2012    }
2013}
2014
2015// `LetRec` helpers
2016impl MirRelationExpr {
2017    /// True when `expr` contains a `LetRec` AST node.
2018    pub fn is_recursive(self: &MirRelationExpr) -> bool {
2019        let mut worklist = vec![self];
2020        while let Some(expr) = worklist.pop() {
2021            if let MirRelationExpr::LetRec { .. } = expr {
2022                return true;
2023            }
2024            worklist.extend(expr.children());
2025        }
2026        false
2027    }
2028
2029    /// Return the number of sub-expressions in the tree (including self).
2030    pub fn size(&self) -> usize {
2031        let mut size = 0;
2032        self.visit_pre(|_| size += 1);
2033        size
2034    }
2035
2036    /// Given the ids and values of a LetRec, it computes the subset of ids that are used across
2037    /// iterations. These are those ids that have a reference before they are defined, when reading
2038    /// all the bindings in order.
2039    ///
2040    /// For example:
2041    /// ```SQL
2042    /// WITH MUTUALLY RECURSIVE
2043    ///     x(...) AS f(z),
2044    ///     y(...) AS g(x),
2045    ///     z(...) AS h(y)
2046    /// ...;
2047    /// ```
2048    /// Here, only `z` is returned, because `x` and `y` are referenced only within the same
2049    /// iteration.
2050    ///
2051    /// Note that if a binding references itself, that is also returned.
2052    pub fn recursive_ids(ids: &[LocalId], values: &[MirRelationExpr]) -> BTreeSet<LocalId> {
2053        let mut used_across_iterations = BTreeSet::new();
2054        let mut defined = BTreeSet::new();
2055        for (binding_id, value) in itertools::zip_eq(ids.iter(), values.iter()) {
2056            value.visit_pre(|expr| {
2057                if let MirRelationExpr::Get {
2058                    id: Local(get_id), ..
2059                } = expr
2060                {
2061                    // If we haven't seen a definition for it yet, then this will refer
2062                    // to the previous iteration.
2063                    // The `ids.contains` part of the condition is needed to exclude
2064                    // those ids that are not really in this LetRec, but either an inner
2065                    // or outer one.
2066                    if !defined.contains(get_id) && ids.contains(get_id) {
2067                        used_across_iterations.insert(*get_id);
2068                    }
2069                }
2070            });
2071            defined.insert(*binding_id);
2072        }
2073        used_across_iterations
2074    }
2075
    /// Replaces `LetRec` nodes with a stack of `Let` nodes.
    ///
    /// In each `Let` binding, uses of `Get` in `value` that are not at strictly greater
    /// identifiers are rewritten to be the constant collection.
    /// This makes the computation perform exactly "one" iteration.
    ///
    /// Concretely, the bindings of a `LetRec` are processed from last to first; each
    /// binding's id is recorded before its value is scanned, and every local `Get` in
    /// the value whose id has been recorded so far is replaced by an empty constant
    /// of the `Get`'s type.
    ///
    /// This was used only temporarily while developing `LetRec`.
    pub fn make_nonrecursive(self: &mut MirRelationExpr) {
        // Ids whose `Get`s must be replaced. Shared across all `LetRec` nodes handled
        // by this call, so it only ever grows.
        let mut deadlist = BTreeSet::new();
        let mut worklist = vec![self];
        while let Some(expr) = worklist.pop() {
            if let MirRelationExpr::LetRec {
                ids,
                values,
                limits: _,
                body,
            } = expr
            {
                // Move the bindings out of the node as `(id, value)` pairs.
                let ids_values = values
                    .drain(..)
                    .zip_eq(ids)
                    .map(|(value, id)| (*id, value))
                    .collect::<Vec<_>>();
                // Replace the `LetRec` node by its body; the `Let`s are wrapped around
                // it below.
                *expr = body.take_dangerous();
                // Reverse order: the last binding is wrapped first, so the first
                // binding ends up as the outermost `Let`.
                for (id, mut value) in ids_values.into_iter().rev() {
                    // Remove references to potentially recursive identifiers.
                    deadlist.insert(id);
                    value.visit_pre_mut(|e| {
                        if let MirRelationExpr::Get {
                            id: crate::Id::Local(id),
                            typ,
                            ..
                        } = e
                        {
                            // Clone the type first so the borrow of `e` taken by the
                            // pattern has ended before `e` is overwritten.
                            let typ = typ.clone();
                            if deadlist.contains(id) {
                                e.take_safely(Some(typ));
                            }
                        }
                    });
                    *expr = MirRelationExpr::Let {
                        id,
                        value: Box::new(value),
                        body: Box::new(expr.take_dangerous()),
                    };
                }
                // Revisit the rewritten node: it is no longer a `LetRec`, so the next
                // pass takes the `else` branch and descends into its children.
                worklist.push(expr);
            } else {
                worklist.extend(expr.children_mut().rev());
            }
        }
    }
2128
2129    /// For each Id `id'` referenced in `expr`, if it is larger or equal than `id`, then record in
2130    /// `expire_whens` that when `id'` is redefined, then we should expire the information that
2131    /// we are holding about `id`. Call `do_expirations` with `expire_whens` at each Id
2132    /// redefinition.
2133    ///
2134    /// IMPORTANT: Relies on the numbering of Ids to be what `renumber_bindings` gives.
2135    pub fn collect_expirations(
2136        id: LocalId,
2137        expr: &MirRelationExpr,
2138        expire_whens: &mut BTreeMap<LocalId, Vec<LocalId>>,
2139    ) {
2140        expr.visit_pre(|e| {
2141            if let MirRelationExpr::Get {
2142                id: Id::Local(referenced_id),
2143                ..
2144            } = e
2145            {
2146                // The following check needs `renumber_bindings` to have run recently
2147                if referenced_id >= &id {
2148                    expire_whens
2149                        .entry(*referenced_id)
2150                        .or_insert_with(Vec::new)
2151                        .push(id);
2152                }
2153            }
2154        });
2155    }
2156
2157    /// Call this function when `id` is redefined. It modifies `id_infos` by removing information
2158    /// about such Ids whose information depended on the earlier definition of `id`, according to
2159    /// `expire_whens`. Also modifies `expire_whens`: it removes the currently processed entry.
2160    pub fn do_expirations<I>(
2161        redefined_id: LocalId,
2162        expire_whens: &mut BTreeMap<LocalId, Vec<LocalId>>,
2163        id_infos: &mut BTreeMap<LocalId, I>,
2164    ) -> Vec<(LocalId, I)> {
2165        let mut expired_infos = Vec::new();
2166        if let Some(expirations) = expire_whens.remove(&redefined_id) {
2167            for expired_id in expirations.into_iter() {
2168                if let Some(offer) = id_infos.remove(&expired_id) {
2169                    expired_infos.push((expired_id, offer));
2170                }
2171            }
2172        }
2173        expired_infos
2174    }
2175}
2176/// Augment non-nullability of columns, by observing either
2177/// 1. Predicates that explicitly test for null values, and
2178/// 2. Columns that if null would make a predicate be null.
2179pub fn non_nullable_columns(predicates: &[MirScalarExpr]) -> BTreeSet<usize> {
2180    let mut nonnull_required_columns = BTreeSet::new();
2181    for predicate in predicates {
2182        // Add any columns that being null would force the predicate to be null.
2183        // Should that happen, the row would be discarded.
2184        predicate.non_null_requirements(&mut nonnull_required_columns);
2185
2186        /*
2187        Test for explicit checks that a column is non-null.
2188
2189        This analysis is ad hoc, and will miss things:
2190
2191        materialize=> create table a(x int, y int);
2192        CREATE TABLE
2193        materialize=> explain with(types) select x from a where (y=x and y is not null) or x is not null;
2194        Optimized Plan
2195        --------------------------------------------------------------------------------------------------------
2196        Explained Query:                                                                                      +
2197        Project (#0) // { types: "(integer?)" }                                                             +
2198        Filter ((#0) IS NOT NULL OR ((#1) IS NOT NULL AND (#0 = #1))) // { types: "(integer?, integer?)" }+
2199        Get materialize.public.a // { types: "(integer?, integer?)" }                                   +
2200                                                                                  +
2201        Source materialize.public.a                                                                           +
2202        filter=(((#0) IS NOT NULL OR ((#1) IS NOT NULL AND (#0 = #1))))                                     +
2203
2204        (1 row)
2205        */
2206
2207        if let MirScalarExpr::CallUnary {
2208            func: UnaryFunc::Not(scalar_func::Not),
2209            expr,
2210        } = predicate
2211        {
2212            if let MirScalarExpr::CallUnary {
2213                func: UnaryFunc::IsNull(scalar_func::IsNull),
2214                expr,
2215            } = &**expr
2216            {
2217                if let MirScalarExpr::Column(c, _name) = &**expr {
2218                    nonnull_required_columns.insert(*c);
2219                }
2220            }
2221        }
2222    }
2223
2224    nonnull_required_columns
2225}
2226
2227impl CollectionPlan for MirRelationExpr {
2228    /// Collects the global collections that this MIR expression directly depends on, i.e., that it
2229    /// has a `Get` for. (It does _not_ traverse view definitions transitively.)
2230    ///
2231    /// !!!WARNING!!!: this method has an HirRelationExpr counterpart. The two
2232    /// should be kept in sync w.r.t. HIR ⇒ MIR lowering!
2233    fn depends_on_into(&self, out: &mut BTreeSet<GlobalId>) {
2234        if let MirRelationExpr::Get {
2235            id: Id::Global(id), ..
2236        } = self
2237        {
2238            out.insert(*id);
2239        }
2240        self.visit_children(|expr| expr.depends_on_into(out))
2241    }
2242}
2243
2244impl MirRelationExpr {
2245    /// Iterates through references to child expressions.
2246    pub fn children(&self) -> impl DoubleEndedIterator<Item = &Self> {
2247        let mut first = None;
2248        let mut second = None;
2249        let mut rest = None;
2250        let mut last = None;
2251
2252        use MirRelationExpr::*;
2253        match self {
2254            Constant { .. } | Get { .. } => (),
2255            Let { value, body, .. } => {
2256                first = Some(&**value);
2257                second = Some(&**body);
2258            }
2259            LetRec { values, body, .. } => {
2260                rest = Some(values);
2261                last = Some(&**body);
2262            }
2263            Project { input, .. }
2264            | Map { input, .. }
2265            | FlatMap { input, .. }
2266            | Filter { input, .. }
2267            | Reduce { input, .. }
2268            | TopK { input, .. }
2269            | Negate { input }
2270            | Threshold { input }
2271            | ArrangeBy { input, .. } => {
2272                first = Some(&**input);
2273            }
2274            Join { inputs, .. } => {
2275                rest = Some(inputs);
2276            }
2277            Union { base, inputs } => {
2278                first = Some(&**base);
2279                rest = Some(inputs);
2280            }
2281        }
2282
2283        first
2284            .into_iter()
2285            .chain(second)
2286            .chain(rest.into_iter().flatten())
2287            .chain(last)
2288    }
2289
2290    /// Iterates through mutable references to child expressions.
2291    pub fn children_mut(&mut self) -> impl DoubleEndedIterator<Item = &mut Self> {
2292        let mut first = None;
2293        let mut second = None;
2294        let mut rest = None;
2295        let mut last = None;
2296
2297        use MirRelationExpr::*;
2298        match self {
2299            Constant { .. } | Get { .. } => (),
2300            Let { value, body, .. } => {
2301                first = Some(&mut **value);
2302                second = Some(&mut **body);
2303            }
2304            LetRec { values, body, .. } => {
2305                rest = Some(values);
2306                last = Some(&mut **body);
2307            }
2308            Project { input, .. }
2309            | Map { input, .. }
2310            | FlatMap { input, .. }
2311            | Filter { input, .. }
2312            | Reduce { input, .. }
2313            | TopK { input, .. }
2314            | Negate { input }
2315            | Threshold { input }
2316            | ArrangeBy { input, .. } => {
2317                first = Some(&mut **input);
2318            }
2319            Join { inputs, .. } => {
2320                rest = Some(inputs);
2321            }
2322            Union { base, inputs } => {
2323                first = Some(&mut **base);
2324                rest = Some(inputs);
2325            }
2326        }
2327
2328        first
2329            .into_iter()
2330            .chain(second)
2331            .chain(rest.into_iter().flatten())
2332            .chain(last)
2333    }
2334
2335    /// Iterative pre-order visitor.
2336    pub fn visit_pre<'a, F: FnMut(&'a Self)>(&'a self, mut f: F) {
2337        let mut worklist = vec![self];
2338        while let Some(expr) = worklist.pop() {
2339            f(expr);
2340            worklist.extend(expr.children().rev());
2341        }
2342    }
2343
2344    /// Iterative pre-order visitor.
2345    pub fn visit_pre_mut<F: FnMut(&mut Self)>(&mut self, mut f: F) {
2346        let mut worklist = vec![self];
2347        while let Some(expr) = worklist.pop() {
2348            f(expr);
2349            worklist.extend(expr.children_mut().rev());
2350        }
2351    }
2352
2353    /// Return a vector of references to the subtrees of this expression
2354    /// in post-visit order (the last element is `&self`).
2355    pub fn post_order_vec(&self) -> Vec<&Self> {
2356        let mut stack = vec![self];
2357        let mut result = vec![];
2358        while let Some(expr) = stack.pop() {
2359            result.push(expr);
2360            stack.extend(expr.children());
2361        }
2362        result.reverse();
2363        result
2364    }
2365}
2366
2367impl VisitChildren<Self> for MirRelationExpr {
2368    fn visit_children<F>(&self, mut f: F)
2369    where
2370        F: FnMut(&Self),
2371    {
2372        for child in self.children() {
2373            f(child)
2374        }
2375    }
2376
2377    fn visit_mut_children<F>(&mut self, mut f: F)
2378    where
2379        F: FnMut(&mut Self),
2380    {
2381        for child in self.children_mut() {
2382            f(child)
2383        }
2384    }
2385
2386    fn try_visit_children<F, E>(&self, mut f: F) -> Result<(), E>
2387    where
2388        F: FnMut(&Self) -> Result<(), E>,
2389        E: From<RecursionLimitError>,
2390    {
2391        for child in self.children() {
2392            f(child)?
2393        }
2394        Ok(())
2395    }
2396
2397    fn try_visit_mut_children<F, E>(&mut self, mut f: F) -> Result<(), E>
2398    where
2399        F: FnMut(&mut Self) -> Result<(), E>,
2400        E: From<RecursionLimitError>,
2401    {
2402        for child in self.children_mut() {
2403            f(child)?
2404        }
2405        Ok(())
2406    }
2407}
2408
/// Specification for an ordering by a column.
///
/// The derived `Ord`/`PartialOrd` implementations compare the fields in declaration order:
/// `column` first, then `desc`, then `nulls_last`.
#[derive(
    Debug,
    Clone,
    Copy,
    Eq,
    PartialEq,
    Ord,
    PartialOrd,
    Serialize,
    Deserialize,
    Hash,
    MzReflect
)]
pub struct ColumnOrder {
    /// The column index.
    pub column: usize,
    /// Whether to sort in descending order.
    ///
    /// Deserializes to `false` when the field is absent from the input (`#[serde(default)]`).
    #[serde(default)]
    pub desc: bool,
    /// Whether to sort nulls last.
    ///
    /// Deserializes to `false` when the field is absent from the input (`#[serde(default)]`).
    #[serde(default)]
    pub nulls_last: bool,
}
2433
// `ColumnOrder` is `Copy` (see its derive list); its columnation region is `CopyRegion`.
// NOTE(review): presumably `CopyRegion` stores values by plain copy — confirm against the
// `columnation` crate documentation.
impl Columnation for ColumnOrder {
    type InnerRegion = CopyRegion<Self>;
}
2437
2438impl<'a, M> fmt::Display for HumanizedExpr<'a, ColumnOrder, M>
2439where
2440    M: HumanizerMode,
2441{
2442    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
2443        // If you modify this, then please also attend to Display for ColumnOrderWithExpr!
2444        write!(
2445            f,
2446            "{} {} {}",
2447            self.child(&self.expr.column),
2448            if self.expr.desc { "desc" } else { "asc" },
2449            if self.expr.nulls_last {
2450                "nulls_last"
2451            } else {
2452                "nulls_first"
2453            },
2454        )
2455    }
2456}
2457
/// Describes an aggregation expression.
///
/// An aggregation is the function `func` applied to the values that `expr` extracts from the
/// input rows, optionally restricted to distinct values.
#[derive(
    Clone,
    Debug,
    Eq,
    PartialEq,
    Ord,
    PartialOrd,
    Serialize,
    Deserialize,
    Hash,
    MzReflect
)]
pub struct AggregateExpr {
    /// Names the aggregation function.
    pub func: AggregateFunc,
    /// An expression which extracts from each row the input to `func`.
    pub expr: MirScalarExpr,
    /// Should the aggregation be applied only to distinct results in each group.
    ///
    /// Deserializes to `false` when the field is absent from the input (`#[serde(default)]`).
    #[serde(default)]
    pub distinct: bool,
}
2480
2481impl AggregateExpr {
2482    /// Computes the type of this `AggregateExpr`.
2483    pub fn sql_typ(&self, column_types: &[SqlColumnType]) -> SqlColumnType {
2484        self.func.output_sql_type(self.expr.sql_typ(column_types))
2485    }
2486
2487    /// Computes the type of this `AggregateExpr`.
2488    pub fn typ(&self, column_types: &[ReprColumnType]) -> ReprColumnType {
2489        self.func.output_type(self.expr.typ(column_types))
2490    }
2491
2492    /// Returns whether the expression has a constant result.
2493    pub fn is_constant(&self) -> bool {
2494        match self.func {
2495            AggregateFunc::MaxNumeric
2496            | AggregateFunc::MaxInt16
2497            | AggregateFunc::MaxInt32
2498            | AggregateFunc::MaxInt64
2499            | AggregateFunc::MaxUInt16
2500            | AggregateFunc::MaxUInt32
2501            | AggregateFunc::MaxUInt64
2502            | AggregateFunc::MaxMzTimestamp
2503            | AggregateFunc::MaxFloat32
2504            | AggregateFunc::MaxFloat64
2505            | AggregateFunc::MaxBool
2506            | AggregateFunc::MaxString
2507            | AggregateFunc::MaxDate
2508            | AggregateFunc::MaxTimestamp
2509            | AggregateFunc::MaxTimestampTz
2510            | AggregateFunc::MaxInterval
2511            | AggregateFunc::MaxTime
2512            | AggregateFunc::MinNumeric
2513            | AggregateFunc::MinInt16
2514            | AggregateFunc::MinInt32
2515            | AggregateFunc::MinInt64
2516            | AggregateFunc::MinUInt16
2517            | AggregateFunc::MinUInt32
2518            | AggregateFunc::MinUInt64
2519            | AggregateFunc::MinMzTimestamp
2520            | AggregateFunc::MinFloat32
2521            | AggregateFunc::MinFloat64
2522            | AggregateFunc::MinBool
2523            | AggregateFunc::MinString
2524            | AggregateFunc::MinDate
2525            | AggregateFunc::MinTimestamp
2526            | AggregateFunc::MinTimestampTz
2527            | AggregateFunc::MinInterval
2528            | AggregateFunc::MinTime
2529            | AggregateFunc::Any
2530            | AggregateFunc::All
2531            | AggregateFunc::Dummy => self.expr.is_literal(),
2532            AggregateFunc::Count => self.expr.is_literal_null(),
2533            AggregateFunc::SumInt16
2534            | AggregateFunc::SumInt32
2535            | AggregateFunc::SumInt64
2536            | AggregateFunc::SumUInt16
2537            | AggregateFunc::SumUInt32
2538            | AggregateFunc::SumUInt64
2539            | AggregateFunc::SumFloat32
2540            | AggregateFunc::SumFloat64
2541            | AggregateFunc::SumNumeric
2542            | AggregateFunc::JsonbAgg { .. }
2543            | AggregateFunc::JsonbObjectAgg { .. }
2544            | AggregateFunc::MapAgg { .. }
2545            | AggregateFunc::ArrayConcat { .. }
2546            | AggregateFunc::ListConcat { .. }
2547            | AggregateFunc::StringAgg { .. }
2548            | AggregateFunc::RowNumber { .. }
2549            | AggregateFunc::Rank { .. }
2550            | AggregateFunc::DenseRank { .. }
2551            | AggregateFunc::LagLead { .. }
2552            | AggregateFunc::FirstValue { .. }
2553            | AggregateFunc::LastValue { .. }
2554            | AggregateFunc::FusedValueWindowFunc { .. }
2555            | AggregateFunc::WindowAggregate { .. }
2556            | AggregateFunc::FusedWindowAggregate { .. } => self.expr.is_literal_err(),
2557        }
2558    }
2559
2560    /// Returns an expression that computes `self` on a group that has exactly one row.
2561    /// Instead of performing a `Reduce` with `self`, one can perform a `Map` with the expression
2562    /// returned by `on_unique`, which is cheaper. (See `ReduceElision`.)
2563    pub fn on_unique(&self, input_type: &[ReprColumnType]) -> MirScalarExpr {
2564        match &self.func {
2565            // Count is one if non-null, and zero if null.
2566            AggregateFunc::Count => self
2567                .expr
2568                .clone()
2569                .call_unary(UnaryFunc::IsNull(crate::func::IsNull))
2570                .if_then_else(
2571                    MirScalarExpr::literal_ok(Datum::Int64(0), ReprScalarType::Int64),
2572                    MirScalarExpr::literal_ok(Datum::Int64(1), ReprScalarType::Int64),
2573                ),
2574
2575            // SumInt16 takes Int16s as input, but outputs Int64s.
2576            AggregateFunc::SumInt16 => self
2577                .expr
2578                .clone()
2579                .call_unary(UnaryFunc::CastInt16ToInt64(scalar_func::CastInt16ToInt64)),
2580
2581            // SumInt32 takes Int32s as input, but outputs Int64s.
2582            AggregateFunc::SumInt32 => self
2583                .expr
2584                .clone()
2585                .call_unary(UnaryFunc::CastInt32ToInt64(scalar_func::CastInt32ToInt64)),
2586
2587            // SumInt64 takes Int64s as input, but outputs numerics.
2588            AggregateFunc::SumInt64 => self.expr.clone().call_unary(UnaryFunc::CastInt64ToNumeric(
2589                scalar_func::CastInt64ToNumeric(Some(NumericMaxScale::ZERO)),
2590            )),
2591
2592            // SumUInt16 takes UInt16s as input, but outputs UInt64s.
2593            AggregateFunc::SumUInt16 => self.expr.clone().call_unary(
2594                UnaryFunc::CastUint16ToUint64(scalar_func::CastUint16ToUint64),
2595            ),
2596
2597            // SumUInt32 takes UInt32s as input, but outputs UInt64s.
2598            AggregateFunc::SumUInt32 => self.expr.clone().call_unary(
2599                UnaryFunc::CastUint32ToUint64(scalar_func::CastUint32ToUint64),
2600            ),
2601
2602            // SumUInt64 takes UInt64s as input, but outputs numerics.
2603            AggregateFunc::SumUInt64 => {
2604                self.expr.clone().call_unary(UnaryFunc::CastUint64ToNumeric(
2605                    scalar_func::CastUint64ToNumeric(Some(NumericMaxScale::ZERO)),
2606                ))
2607            }
2608
2609            // JsonbAgg takes _anything_ as input, but must output a Jsonb array.
2610            AggregateFunc::JsonbAgg { .. } => MirScalarExpr::call_variadic(
2611                JsonbBuildArray,
2612                vec![
2613                    self.expr
2614                        .clone()
2615                        .call_unary(UnaryFunc::RecordGet(scalar_func::RecordGet(0))),
2616                ],
2617            ),
2618
            // JsonbObjectAgg takes _anything_ as input, but must output a Jsonb object.
2620            AggregateFunc::JsonbObjectAgg { .. } => {
2621                let record = self
2622                    .expr
2623                    .clone()
2624                    .call_unary(UnaryFunc::RecordGet(scalar_func::RecordGet(0)));
2625                MirScalarExpr::call_variadic(
2626                    JsonbBuildObject,
2627                    (0..2)
2628                        .map(|i| {
2629                            record
2630                                .clone()
2631                                .call_unary(UnaryFunc::RecordGet(scalar_func::RecordGet(i)))
2632                        })
2633                        .collect(),
2634                )
2635            }
2636
2637            AggregateFunc::MapAgg { value_type, .. } => {
2638                let record = self
2639                    .expr
2640                    .clone()
2641                    .call_unary(UnaryFunc::RecordGet(scalar_func::RecordGet(0)));
2642                MirScalarExpr::call_variadic(
2643                    MapBuild {
2644                        value_type: value_type.clone(),
2645                    },
2646                    (0..2)
2647                        .map(|i| {
2648                            record
2649                                .clone()
2650                                .call_unary(UnaryFunc::RecordGet(scalar_func::RecordGet(i)))
2651                        })
2652                        .collect(),
2653                )
2654            }
2655
2656            // StringAgg takes nested records of strings and outputs a string
2657            AggregateFunc::StringAgg { .. } => self
2658                .expr
2659                .clone()
2660                .call_unary(UnaryFunc::RecordGet(scalar_func::RecordGet(0)))
2661                .call_unary(UnaryFunc::RecordGet(scalar_func::RecordGet(0))),
2662
2663            // ListConcat and ArrayConcat take a single level of records and output a list containing exactly 1 element
2664            AggregateFunc::ListConcat { .. } | AggregateFunc::ArrayConcat { .. } => self
2665                .expr
2666                .clone()
2667                .call_unary(UnaryFunc::RecordGet(scalar_func::RecordGet(0))),
2668
2669            // RowNumber, Rank, DenseRank take a list of records and output a list containing exactly 1 element
2670            AggregateFunc::RowNumber { .. } => {
2671                self.on_unique_ranking_window_funcs(input_type, "?row_number?")
2672            }
2673            AggregateFunc::Rank { .. } => self.on_unique_ranking_window_funcs(input_type, "?rank?"),
2674            AggregateFunc::DenseRank { .. } => {
2675                self.on_unique_ranking_window_funcs(input_type, "?dense_rank?")
2676            }
2677
2678            // The input type for LagLead is ((OriginalRow, (InputValue, Offset, Default)), OrderByExprs...)
2679            AggregateFunc::LagLead { lag_lead, .. } => {
2680                let tuple = self
2681                    .expr
2682                    .clone()
2683                    .call_unary(UnaryFunc::RecordGet(scalar_func::RecordGet(0)));
2684
2685                // Get the overall return type
2686                let return_type_with_orig_row = self
2687                    .typ(input_type)
2688                    .scalar_type
2689                    .unwrap_list_element_type()
2690                    .clone();
2691                let lag_lead_return_type =
2692                    return_type_with_orig_row.unwrap_record_element_type()[0].clone();
2693
2694                // Extract the original row
2695                let original_row = tuple
2696                    .clone()
2697                    .call_unary(UnaryFunc::RecordGet(scalar_func::RecordGet(0)));
2698
2699                // Extract the encoded args
2700                let encoded_args =
2701                    tuple.call_unary(UnaryFunc::RecordGet(scalar_func::RecordGet(1)));
2702
2703                let (result_expr, column_name) =
2704                    Self::on_unique_lag_lead(lag_lead, encoded_args, lag_lead_return_type.clone());
2705
2706                MirScalarExpr::call_variadic(
2707                    ListCreate {
2708                        elem_type: SqlScalarType::from_repr(&return_type_with_orig_row),
2709                    },
2710                    vec![MirScalarExpr::call_variadic(
2711                        RecordCreate {
2712                            field_names: vec![column_name, ColumnName::from("?record?")],
2713                        },
2714                        vec![result_expr, original_row],
2715                    )],
2716                )
2717            }
2718
2719            // The input type for FirstValue is ((OriginalRow, InputValue), OrderByExprs...)
2720            AggregateFunc::FirstValue { window_frame, .. } => {
2721                let tuple = self
2722                    .expr
2723                    .clone()
2724                    .call_unary(UnaryFunc::RecordGet(scalar_func::RecordGet(0)));
2725
2726                // Get the overall return type
2727                let return_type_with_orig_row = self
2728                    .typ(input_type)
2729                    .scalar_type
2730                    .unwrap_list_element_type()
2731                    .clone();
2732                let first_value_return_type =
2733                    return_type_with_orig_row.unwrap_record_element_type()[0].clone();
2734
2735                // Extract the original row
2736                let original_row = tuple
2737                    .clone()
2738                    .call_unary(UnaryFunc::RecordGet(scalar_func::RecordGet(0)));
2739
2740                // Extract the input value
2741                let arg = tuple.call_unary(UnaryFunc::RecordGet(scalar_func::RecordGet(1)));
2742
2743                let (result_expr, column_name) = Self::on_unique_first_value_last_value(
2744                    window_frame,
2745                    arg,
2746                    first_value_return_type,
2747                );
2748
2749                MirScalarExpr::call_variadic(
2750                    ListCreate {
2751                        elem_type: SqlScalarType::from_repr(&return_type_with_orig_row),
2752                    },
2753                    vec![MirScalarExpr::call_variadic(
2754                        RecordCreate {
2755                            field_names: vec![column_name, ColumnName::from("?record?")],
2756                        },
2757                        vec![result_expr, original_row],
2758                    )],
2759                )
2760            }
2761
2762            // The input type for LastValue is ((OriginalRow, InputValue), OrderByExprs...)
2763            AggregateFunc::LastValue { window_frame, .. } => {
2764                let tuple = self
2765                    .expr
2766                    .clone()
2767                    .call_unary(UnaryFunc::RecordGet(scalar_func::RecordGet(0)));
2768
2769                // Get the overall return type
2770                let return_type_with_orig_row = self
2771                    .typ(input_type)
2772                    .scalar_type
2773                    .unwrap_list_element_type()
2774                    .clone();
2775                let last_value_return_type =
2776                    return_type_with_orig_row.unwrap_record_element_type()[0].clone();
2777
2778                // Extract the original row
2779                let original_row = tuple
2780                    .clone()
2781                    .call_unary(UnaryFunc::RecordGet(scalar_func::RecordGet(0)));
2782
2783                // Extract the input value
2784                let arg = tuple.call_unary(UnaryFunc::RecordGet(scalar_func::RecordGet(1)));
2785
2786                let (result_expr, column_name) = Self::on_unique_first_value_last_value(
2787                    window_frame,
2788                    arg,
2789                    last_value_return_type,
2790                );
2791
2792                MirScalarExpr::call_variadic(
2793                    ListCreate {
2794                        elem_type: SqlScalarType::from_repr(&return_type_with_orig_row),
2795                    },
2796                    vec![MirScalarExpr::call_variadic(
2797                        RecordCreate {
2798                            field_names: vec![column_name, ColumnName::from("?record?")],
2799                        },
2800                        vec![result_expr, original_row],
2801                    )],
2802                )
2803            }
2804
2805            // The input type for window aggs is ((OriginalRow, InputValue), OrderByExprs...)
2806            // See an example MIR in `window_func_applied_to`.
2807            AggregateFunc::WindowAggregate {
2808                wrapped_aggregate,
2809                window_frame,
2810                order_by: _,
2811            } => {
2812                // TODO: deduplicate code between the various window function cases.
2813
2814                let tuple = self
2815                    .expr
2816                    .clone()
2817                    .call_unary(UnaryFunc::RecordGet(scalar_func::RecordGet(0)));
2818
2819                // Get the overall return type
2820                let return_type = self
2821                    .typ(input_type)
2822                    .scalar_type
2823                    .unwrap_list_element_type()
2824                    .clone();
2825                let window_agg_return_type = return_type.unwrap_record_element_type()[0].clone();
2826
2827                // Extract the original row
2828                let original_row = tuple
2829                    .clone()
2830                    .call_unary(UnaryFunc::RecordGet(scalar_func::RecordGet(0)));
2831
2832                // Extract the input value
2833                let arg_expr = tuple.call_unary(UnaryFunc::RecordGet(scalar_func::RecordGet(1)));
2834
2835                let (result, column_name) = Self::on_unique_window_agg(
2836                    window_frame,
2837                    arg_expr,
2838                    input_type,
2839                    window_agg_return_type,
2840                    wrapped_aggregate,
2841                );
2842
2843                MirScalarExpr::call_variadic(
2844                    ListCreate {
2845                        elem_type: SqlScalarType::from_repr(&return_type),
2846                    },
2847                    vec![MirScalarExpr::call_variadic(
2848                        RecordCreate {
2849                            field_names: vec![column_name, ColumnName::from("?record?")],
2850                        },
2851                        vec![result, original_row],
2852                    )],
2853                )
2854            }
2855
2856            // The input type is ((OriginalRow, (Arg1, Arg2, ...)), OrderByExprs...)
2857            AggregateFunc::FusedWindowAggregate {
2858                wrapped_aggregates,
2859                order_by: _,
2860                window_frame,
2861            } => {
2862                // Throw away OrderByExprs
2863                let tuple = self
2864                    .expr
2865                    .clone()
2866                    .call_unary(UnaryFunc::RecordGet(scalar_func::RecordGet(0)));
2867
2868                // Extract the original row
2869                let original_row = tuple
2870                    .clone()
2871                    .call_unary(UnaryFunc::RecordGet(scalar_func::RecordGet(0)));
2872
2873                // Extract the args of the fused call
2874                let all_args = tuple.call_unary(UnaryFunc::RecordGet(scalar_func::RecordGet(1)));
2875
2876                let return_type_with_orig_row = self
2877                    .typ(input_type)
2878                    .scalar_type
2879                    .unwrap_list_element_type()
2880                    .clone();
2881
2882                let all_func_return_types =
2883                    return_type_with_orig_row.unwrap_record_element_type()[0].clone();
2884                let mut func_result_exprs = Vec::new();
2885                let mut col_names = Vec::new();
2886                for (idx, wrapped_aggr) in wrapped_aggregates.iter().enumerate() {
2887                    let arg = all_args
2888                        .clone()
2889                        .call_unary(UnaryFunc::RecordGet(scalar_func::RecordGet(idx)));
2890                    let return_type =
2891                        all_func_return_types.unwrap_record_element_type()[idx].clone();
2892                    let (result, column_name) = Self::on_unique_window_agg(
2893                        window_frame,
2894                        arg,
2895                        input_type,
2896                        return_type,
2897                        wrapped_aggr,
2898                    );
2899                    func_result_exprs.push(result);
2900                    col_names.push(column_name);
2901                }
2902
2903                MirScalarExpr::call_variadic(
2904                    ListCreate {
2905                        elem_type: SqlScalarType::from_repr(&return_type_with_orig_row),
2906                    },
2907                    vec![MirScalarExpr::call_variadic(
2908                        RecordCreate {
2909                            field_names: vec![
2910                                ColumnName::from("?fused_window_aggr?"),
2911                                ColumnName::from("?record?"),
2912                            ],
2913                        },
2914                        vec![
2915                            MirScalarExpr::call_variadic(
2916                                RecordCreate {
2917                                    field_names: col_names,
2918                                },
2919                                func_result_exprs,
2920                            ),
2921                            original_row,
2922                        ],
2923                    )],
2924                )
2925            }
2926
2927            // The input type is ((OriginalRow, (Args1, Args2, ...)), OrderByExprs...)
2928            AggregateFunc::FusedValueWindowFunc {
2929                funcs,
2930                order_by: outer_order_by,
2931            } => {
2932                // Throw away OrderByExprs
2933                let tuple = self
2934                    .expr
2935                    .clone()
2936                    .call_unary(UnaryFunc::RecordGet(scalar_func::RecordGet(0)));
2937
2938                // Extract the original row
2939                let original_row = tuple
2940                    .clone()
2941                    .call_unary(UnaryFunc::RecordGet(scalar_func::RecordGet(0)));
2942
2943                // Extract the encoded args of the fused call
2944                let all_encoded_args =
2945                    tuple.call_unary(UnaryFunc::RecordGet(scalar_func::RecordGet(1)));
2946
2947                let return_type_with_orig_row = self
2948                    .typ(input_type)
2949                    .scalar_type
2950                    .unwrap_list_element_type()
2951                    .clone();
2952
2953                let all_func_return_types =
2954                    return_type_with_orig_row.unwrap_record_element_type()[0].clone();
2955                let mut func_result_exprs = Vec::new();
2956                let mut col_names = Vec::new();
2957                for (idx, func) in funcs.iter().enumerate() {
2958                    let args_for_func = all_encoded_args
2959                        .clone()
2960                        .call_unary(UnaryFunc::RecordGet(scalar_func::RecordGet(idx)));
2961                    let return_type_for_func =
2962                        all_func_return_types.unwrap_record_element_type()[idx].clone();
2963                    let (result, column_name) = match func {
2964                        AggregateFunc::LagLead {
2965                            lag_lead,
2966                            order_by,
2967                            ignore_nulls: _,
2968                        } => {
2969                            assert_eq!(order_by, outer_order_by);
2970                            Self::on_unique_lag_lead(lag_lead, args_for_func, return_type_for_func)
2971                        }
2972                        AggregateFunc::FirstValue {
2973                            window_frame,
2974                            order_by,
2975                        } => {
2976                            assert_eq!(order_by, outer_order_by);
2977                            Self::on_unique_first_value_last_value(
2978                                window_frame,
2979                                args_for_func,
2980                                return_type_for_func,
2981                            )
2982                        }
2983                        AggregateFunc::LastValue {
2984                            window_frame,
2985                            order_by,
2986                        } => {
2987                            assert_eq!(order_by, outer_order_by);
2988                            Self::on_unique_first_value_last_value(
2989                                window_frame,
2990                                args_for_func,
2991                                return_type_for_func,
2992                            )
2993                        }
2994                        _ => panic!("unknown function in FusedValueWindowFunc"),
2995                    };
2996                    func_result_exprs.push(result);
2997                    col_names.push(column_name);
2998                }
2999
3000                MirScalarExpr::call_variadic(
3001                    ListCreate {
3002                        elem_type: SqlScalarType::from_repr(&return_type_with_orig_row),
3003                    },
3004                    vec![MirScalarExpr::call_variadic(
3005                        RecordCreate {
3006                            field_names: vec![
3007                                ColumnName::from("?fused_value_window_func?"),
3008                                ColumnName::from("?record?"),
3009                            ],
3010                        },
3011                        vec![
3012                            MirScalarExpr::call_variadic(
3013                                RecordCreate {
3014                                    field_names: col_names,
3015                                },
3016                                func_result_exprs,
3017                            ),
3018                            original_row,
3019                        ],
3020                    )],
3021                )
3022            }
3023
3024            // All other variants should return the argument to the aggregation.
3025            AggregateFunc::MaxNumeric
3026            | AggregateFunc::MaxInt16
3027            | AggregateFunc::MaxInt32
3028            | AggregateFunc::MaxInt64
3029            | AggregateFunc::MaxUInt16
3030            | AggregateFunc::MaxUInt32
3031            | AggregateFunc::MaxUInt64
3032            | AggregateFunc::MaxMzTimestamp
3033            | AggregateFunc::MaxFloat32
3034            | AggregateFunc::MaxFloat64
3035            | AggregateFunc::MaxBool
3036            | AggregateFunc::MaxString
3037            | AggregateFunc::MaxDate
3038            | AggregateFunc::MaxTimestamp
3039            | AggregateFunc::MaxTimestampTz
3040            | AggregateFunc::MaxInterval
3041            | AggregateFunc::MaxTime
3042            | AggregateFunc::MinNumeric
3043            | AggregateFunc::MinInt16
3044            | AggregateFunc::MinInt32
3045            | AggregateFunc::MinInt64
3046            | AggregateFunc::MinUInt16
3047            | AggregateFunc::MinUInt32
3048            | AggregateFunc::MinUInt64
3049            | AggregateFunc::MinMzTimestamp
3050            | AggregateFunc::MinFloat32
3051            | AggregateFunc::MinFloat64
3052            | AggregateFunc::MinBool
3053            | AggregateFunc::MinString
3054            | AggregateFunc::MinDate
3055            | AggregateFunc::MinTimestamp
3056            | AggregateFunc::MinTimestampTz
3057            | AggregateFunc::MinInterval
3058            | AggregateFunc::MinTime
3059            | AggregateFunc::SumFloat32
3060            | AggregateFunc::SumFloat64
3061            | AggregateFunc::SumNumeric
3062            | AggregateFunc::Any
3063            | AggregateFunc::All
3064            | AggregateFunc::Dummy => self.expr.clone(),
3065        }
3066    }
3067
3068    /// `on_unique` for ROW_NUMBER, RANK, DENSE_RANK
3069    fn on_unique_ranking_window_funcs(
3070        &self,
3071        input_type: &[ReprColumnType],
3072        col_name: &str,
3073    ) -> MirScalarExpr {
3074        let sql_input_type: Vec<SqlColumnType> =
3075            input_type.iter().map(SqlColumnType::from_repr).collect();
3076        let list = self
3077            .expr
3078            .clone()
3079            // extract the list within the record
3080            .call_unary(UnaryFunc::RecordGet(scalar_func::RecordGet(0)));
3081
3082        // extract the expression within the list
3083        let record = MirScalarExpr::call_variadic(
3084            ListIndex,
3085            vec![
3086                list,
3087                MirScalarExpr::literal_ok(Datum::Int64(1), ReprScalarType::Int64),
3088            ],
3089        );
3090
3091        MirScalarExpr::call_variadic(
3092            ListCreate {
3093                elem_type: self
3094                    .sql_typ(&sql_input_type)
3095                    .scalar_type
3096                    .unwrap_list_element_type()
3097                    .clone(),
3098            },
3099            vec![MirScalarExpr::call_variadic(
3100                RecordCreate {
3101                    field_names: vec![ColumnName::from(col_name), ColumnName::from("?record?")],
3102                },
3103                vec![
3104                    MirScalarExpr::literal_ok(Datum::Int64(1), ReprScalarType::Int64),
3105                    record,
3106                ],
3107            )],
3108        )
3109    }
3110
3111    /// `on_unique` for `lag` and `lead`
3112    fn on_unique_lag_lead(
3113        lag_lead: &LagLeadType,
3114        encoded_args: MirScalarExpr,
3115        return_type: ReprScalarType,
3116    ) -> (MirScalarExpr, ColumnName) {
3117        let expr = encoded_args
3118            .clone()
3119            .call_unary(UnaryFunc::RecordGet(scalar_func::RecordGet(0)));
3120        let offset = encoded_args
3121            .clone()
3122            .call_unary(UnaryFunc::RecordGet(scalar_func::RecordGet(1)));
3123        let default_value =
3124            encoded_args.call_unary(UnaryFunc::RecordGet(scalar_func::RecordGet(2)));
3125
3126        // In this case, the window always has only one element, so if the offset is not null and
3127        // not zero, the default value should be returned instead.
3128        let value = offset
3129            .clone()
3130            .call_binary(
3131                MirScalarExpr::literal_ok(Datum::Int32(0), ReprScalarType::Int32),
3132                crate::func::Eq,
3133            )
3134            .if_then_else(expr, default_value);
3135        let result_expr = offset
3136            .call_unary(UnaryFunc::IsNull(crate::func::IsNull))
3137            .if_then_else(MirScalarExpr::literal_null(return_type), value);
3138
3139        let column_name = ColumnName::from(match lag_lead {
3140            LagLeadType::Lag => "?lag?",
3141            LagLeadType::Lead => "?lead?",
3142        });
3143
3144        (result_expr, column_name)
3145    }
3146
3147    /// `on_unique` for `first_value` and `last_value`
3148    fn on_unique_first_value_last_value(
3149        window_frame: &WindowFrame,
3150        arg: MirScalarExpr,
3151        return_type: ReprScalarType,
3152    ) -> (MirScalarExpr, ColumnName) {
3153        // If the window frame includes the current (single) row, return its value, null otherwise
3154        let result_expr = if window_frame.includes_current_row() {
3155            arg
3156        } else {
3157            MirScalarExpr::literal_null(return_type)
3158        };
3159        (result_expr, ColumnName::from("?first_value?"))
3160    }
3161
3162    /// `on_unique` for window aggregations
3163    fn on_unique_window_agg(
3164        window_frame: &WindowFrame,
3165        arg_expr: MirScalarExpr,
3166        input_type: &[ReprColumnType],
3167        return_type: ReprScalarType,
3168        wrapped_aggr: &AggregateFunc,
3169    ) -> (MirScalarExpr, ColumnName) {
3170        // If the window frame includes the current (single) row, evaluate the wrapped aggregate on
3171        // that row. Otherwise, return the default value for the aggregate.
3172        let result_expr = if window_frame.includes_current_row() {
3173            AggregateExpr {
3174                func: wrapped_aggr.clone(),
3175                expr: arg_expr,
3176                distinct: false, // We have just one input element; DISTINCT doesn't matter.
3177            }
3178            .on_unique(input_type)
3179        } else {
3180            MirScalarExpr::literal_ok(wrapped_aggr.default(), return_type)
3181        };
3182        (result_expr, ColumnName::from("?window_agg?"))
3183    }
3184
3185    /// Returns whether the expression is COUNT(*) or not.  Note that
3186    /// when we define the count builtin in sql::func, we convert
3187    /// COUNT(*) to COUNT(true), making it indistinguishable from
3188    /// literal COUNT(true), but we prefer to consider this as the
3189    /// former.
3190    ///
3191    /// (HIR has the same `is_count_asterisk`.)
3192    pub fn is_count_asterisk(&self) -> bool {
3193        self.func == AggregateFunc::Count && self.expr.is_literal_true() && !self.distinct
3194    }
3195}
3196
/// Describes how a join is implemented in dataflow.
#[derive(
    Clone,
    Debug,
    Eq,
    PartialEq,
    Ord,
    PartialOrd,
    Serialize,
    Deserialize,
    Hash,
    MzReflect
)]
pub enum JoinImplementation {
    /// Perform a sequence of binary differential dataflow joins.
    ///
    /// The first argument indicates
    /// 1) the index of the starting collection,
    /// 2) if it should be arranged, the keys to arrange it by, and
    /// 3) the characteristics of the starting collection (for EXPLAINing).
    ///
    /// The sequence that follows lists other relation indexes, and the key for
    /// the arrangement we should use when joining it in.
    /// The JoinInputCharacteristics are for EXPLAINing the characteristics that
    /// were used for join ordering.
    ///
    /// Each collection index should occur exactly once, either as the starting collection
    /// or somewhere in the list.
    Differential(
        (
            usize,
            Option<Vec<MirScalarExpr>>,
            Option<JoinInputCharacteristics>,
        ),
        Vec<(usize, Vec<MirScalarExpr>, Option<JoinInputCharacteristics>)>,
    ),
    /// Perform independent delta query dataflows for each input.
    ///
    /// The argument is a sequence of plans, for the input collections in order.
    /// Each plan starts from the corresponding index, and then in sequence joins
    /// against collections identified by index and with the specified arrangement key.
    /// The JoinInputCharacteristics are for EXPLAINing the characteristics that were
    /// used for join ordering.
    DeltaQuery(Vec<Vec<(usize, Vec<MirScalarExpr>, Option<JoinInputCharacteristics>)>>),
    /// Join a user-created index with a constant collection to speed up the evaluation of a
    /// predicate such as `(f1 = 3 AND f2 = 5) OR (f1 = 7 AND f2 = 9)`.
    ///
    /// This gets translated to a Differential join during MIR -> LIR lowering, but we still want
    /// to represent it in MIR, because the fast path detection wants to match on this.
    ///
    /// Consists of (`<coll_id>`, `<index_id>`, `<index_key>`, `<constants>`).
    IndexedFilter(
        GlobalId,
        GlobalId,
        Vec<MirScalarExpr>,
        #[mzreflect(ignore)] Vec<Row>,
    ),
    /// No implementation yet selected.
    Unimplemented,
}
3255
3256impl Default for JoinImplementation {
3257    fn default() -> Self {
3258        JoinImplementation::Unimplemented
3259    }
3260}
3261
3262impl JoinImplementation {
3263    /// Returns `true` iff the value is not [`JoinImplementation::Unimplemented`].
3264    pub fn is_implemented(&self) -> bool {
3265        match self {
3266            Self::Unimplemented => false,
3267            _ => true,
3268        }
3269    }
3270
3271    /// Returns an optional implementation name if the value is not [`JoinImplementation::Unimplemented`].
3272    pub fn name(&self) -> Option<&'static str> {
3273        match self {
3274            Self::Differential(..) => Some("differential"),
3275            Self::DeltaQuery(..) => Some("delta"),
3276            Self::IndexedFilter(..) => Some("indexed_filter"),
3277            Self::Unimplemented => None,
3278        }
3279    }
3280}
3281
/// Characteristics of a join order candidate collection.
///
/// A candidate is described by a collection and a key, and may have various liabilities.
/// Primarily, the candidate may risk substantial inflation of records, which is something
/// that concerns us greatly. Additionally, the candidate may be unarranged, and we would
/// prefer candidates that do not require additional memory. Finally, we prefer lower id
/// collections in the interest of consistent tie-breaking. For more characteristics, see
/// comments on individual fields.
///
/// This has more than one version. [`JoinInputCharacteristics::new`] instantiates the
/// appropriate version based on a feature flag.
#[derive(
    Eq,
    PartialEq,
    Ord,
    PartialOrd,
    Debug,
    Clone,
    Serialize,
    Deserialize,
    Hash,
    MzReflect
)]
pub enum JoinInputCharacteristics {
    /// Old version, with `enable_join_prioritize_arranged` turned off.
    V1(JoinInputCharacteristicsV1),
    /// Newer version, with `enable_join_prioritize_arranged` turned on.
    V2(JoinInputCharacteristicsV2),
}
3311
3312impl JoinInputCharacteristics {
3313    /// Creates a new instance with the given characteristics.
3314    pub fn new(
3315        unique_key: bool,
3316        key_length: usize,
3317        arranged: bool,
3318        cardinality: Option<usize>,
3319        filters: FilterCharacteristics,
3320        input: usize,
3321        enable_join_prioritize_arranged: bool,
3322    ) -> Self {
3323        if enable_join_prioritize_arranged {
3324            Self::V2(JoinInputCharacteristicsV2::new(
3325                unique_key,
3326                key_length,
3327                arranged,
3328                cardinality,
3329                filters,
3330                input,
3331            ))
3332        } else {
3333            Self::V1(JoinInputCharacteristicsV1::new(
3334                unique_key,
3335                key_length,
3336                arranged,
3337                cardinality,
3338                filters,
3339                input,
3340            ))
3341        }
3342    }
3343
3344    /// Turns the instance into a String to be printed in EXPLAIN.
3345    pub fn explain(&self) -> String {
3346        match self {
3347            Self::V1(jic) => jic.explain(),
3348            Self::V2(jic) => jic.explain(),
3349        }
3350    }
3351
3352    /// Whether the join input described by `self` is arranged.
3353    pub fn arranged(&self) -> bool {
3354        match self {
3355            Self::V1(jic) => jic.arranged,
3356            Self::V2(jic) => jic.arranged,
3357        }
3358    }
3359
3360    /// Returns the `FilterCharacteristics` for the join input described by `self`.
3361    pub fn filters(&mut self) -> &mut FilterCharacteristics {
3362        match self {
3363            Self::V1(jic) => &mut jic.filters,
3364            Self::V2(jic) => &mut jic.filters,
3365        }
3366    }
3367}
3368
/// Newer version of `JoinInputCharacteristics`, with `enable_join_prioritize_arranged` turned on.
///
/// `Ord` is derived, so two instances are compared field by field in declaration order: an
/// earlier field always takes priority over a later one. The order of the fields below is
/// therefore significant.
#[derive(
    Eq,
    PartialEq,
    Ord,
    PartialOrd,
    Debug,
    Clone,
    Serialize,
    Deserialize,
    Hash,
    MzReflect
)]
pub struct JoinInputCharacteristicsV2 {
    /// An excellent indication that record count will not increase.
    pub unique_key: bool,
    /// Cross joins are bad.
    /// (`key_length > 0` also implies that it is not a cross join. However, we need to note cross
    /// joins in a separate field, because not being a cross join is more important than `arranged`,
    /// but otherwise `key_length` is less important than `arranged`.)
    pub not_cross: bool,
    /// Indicates that there will be no additional in-memory footprint.
    pub arranged: bool,
    /// A weaker signal that record count will not increase.
    pub key_length: usize,
    /// Estimated cardinality (lower is better).
    ///
    /// Wrapped in `Reverse` so that a lower estimate compares as greater. `None` compares as
    /// less than any `Some`.
    pub cardinality: Option<std::cmp::Reverse<usize>>,
    /// Characteristics of the filter that is applied at this input.
    pub filters: FilterCharacteristics,
    /// We want to prefer input earlier in the input list, for stability of ordering.
    /// (Wrapped in `Reverse` so that a lower input index compares as greater.)
    pub input: std::cmp::Reverse<usize>,
}
3401
3402impl JoinInputCharacteristicsV2 {
3403    /// Creates a new instance with the given characteristics.
3404    pub fn new(
3405        unique_key: bool,
3406        key_length: usize,
3407        arranged: bool,
3408        cardinality: Option<usize>,
3409        filters: FilterCharacteristics,
3410        input: usize,
3411    ) -> Self {
3412        Self {
3413            unique_key,
3414            not_cross: key_length > 0,
3415            arranged,
3416            key_length,
3417            cardinality: cardinality.map(std::cmp::Reverse),
3418            filters,
3419            input: std::cmp::Reverse(input),
3420        }
3421    }
3422
3423    /// Turns the instance into a String to be printed in EXPLAIN.
3424    pub fn explain(&self) -> String {
3425        let mut e = "".to_owned();
3426        if self.unique_key {
3427            e.push_str("U");
3428        }
3429        // Don't need to print `not_cross`, because that is visible in the printed key.
3430        // if !self.not_cross {
3431        //     e.push_str("C");
3432        // }
3433        for _ in 0..self.key_length {
3434            e.push_str("K");
3435        }
3436        if self.arranged {
3437            e.push_str("A");
3438        }
3439        if let Some(std::cmp::Reverse(cardinality)) = self.cardinality {
3440            e.push_str(&format!("|{cardinality}|"));
3441        }
3442        e.push_str(&self.filters.explain());
3443        e
3444    }
3445}
3446
/// Old version of `JoinInputCharacteristics`, with `enable_join_prioritize_arranged` turned off.
///
/// `Ord` is derived, so two instances are compared field by field in declaration order: an
/// earlier field always takes priority over a later one. The order of the fields below is
/// therefore significant.
#[derive(
    Eq,
    PartialEq,
    Ord,
    PartialOrd,
    Debug,
    Clone,
    Serialize,
    Deserialize,
    Hash,
    MzReflect
)]
pub struct JoinInputCharacteristicsV1 {
    /// An excellent indication that record count will not increase.
    pub unique_key: bool,
    /// A weaker signal that record count will not increase.
    pub key_length: usize,
    /// Indicates that there will be no additional in-memory footprint.
    pub arranged: bool,
    /// Estimated cardinality (lower is better).
    ///
    /// Wrapped in `Reverse` so that a lower estimate compares as greater. `None` compares as
    /// less than any `Some`.
    pub cardinality: Option<std::cmp::Reverse<usize>>,
    /// Characteristics of the filter that is applied at this input.
    pub filters: FilterCharacteristics,
    /// We want to prefer input earlier in the input list, for stability of ordering.
    /// (Wrapped in `Reverse` so that a lower input index compares as greater.)
    pub input: std::cmp::Reverse<usize>,
}
3474
3475impl JoinInputCharacteristicsV1 {
3476    /// Creates a new instance with the given characteristics.
3477    pub fn new(
3478        unique_key: bool,
3479        key_length: usize,
3480        arranged: bool,
3481        cardinality: Option<usize>,
3482        filters: FilterCharacteristics,
3483        input: usize,
3484    ) -> Self {
3485        Self {
3486            unique_key,
3487            key_length,
3488            arranged,
3489            cardinality: cardinality.map(std::cmp::Reverse),
3490            filters,
3491            input: std::cmp::Reverse(input),
3492        }
3493    }
3494
3495    /// Turns the instance into a String to be printed in EXPLAIN.
3496    pub fn explain(&self) -> String {
3497        let mut e = "".to_owned();
3498        if self.unique_key {
3499            e.push_str("U");
3500        }
3501        for _ in 0..self.key_length {
3502            e.push_str("K");
3503        }
3504        if self.arranged {
3505            e.push_str("A");
3506        }
3507        if let Some(std::cmp::Reverse(cardinality)) = self.cardinality {
3508            e.push_str(&format!("|{cardinality}|"));
3509        }
3510        e.push_str(&self.filters.explain());
3511        e
3512    }
3513}
3514
/// Instructions for finishing the result of a query.
///
/// The primary reason for the existence of this structure and attendant code
/// is that SQL's ORDER BY requires sorting rows (as already implied by the
/// keywords), whereas much of the rest of SQL is defined in terms of unordered
/// multisets. But as it turns out, the same idea can be used to optimize
/// trivial peeks.
///
/// The generic parameters are for accommodating prepared statement parameters in
/// `limit` and `offset`: the planner can hold these fields as HirScalarExpr long enough to call
/// `bind_parameters` on them. By default, `L` is `NonNeg<i64>` and `O` is `usize`.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
pub struct RowSetFinishing<L = NonNeg<i64>, O = usize> {
    /// Order rows by the given columns.
    pub order_by: Vec<ColumnOrder>,
    /// Include only as many rows (after offset). `None` means no limit.
    pub limit: Option<L>,
    /// Omit as many rows.
    pub offset: O,
    /// Include only given columns.
    pub project: Vec<usize>,
}
3537
3538impl<L> RowSetFinishing<L> {
3539    /// Returns a trivial finishing, i.e., that does nothing to the result set.
3540    pub fn trivial(arity: usize) -> RowSetFinishing<L> {
3541        RowSetFinishing {
3542            order_by: Vec::new(),
3543            limit: None,
3544            offset: 0,
3545            project: (0..arity).collect(),
3546        }
3547    }
3548    /// True if the finishing does nothing to any result set.
3549    pub fn is_trivial(&self, arity: usize) -> bool {
3550        self.limit.is_none()
3551            && self.order_by.is_empty()
3552            && self.offset == 0
3553            && self.project.iter().copied().eq(0..arity)
3554    }
3555    /// True if the finishing does not require an ORDER BY.
3556    ///
3557    /// LIMIT and OFFSET without an ORDER BY _are_ streamable: without an
3558    /// explicit ordering we will skip an arbitrary bag of elements and return
3559    /// the first arbitrary elements in the remaining bag. The result semantics
3560    /// are still correct but maybe surprising for some users.
3561    pub fn is_streamable(&self, arity: usize) -> bool {
3562        self.order_by.is_empty() && self.project.iter().copied().eq(0..arity)
3563    }
3564}
3565
3566impl RowSetFinishing<NonNeg<i64>, usize> {
3567    /// The number of rows needed from before the finishing to evaluate the finishing:
3568    /// offset + limit.
3569    ///
3570    /// If it returns None, then we need all the rows.
3571    pub fn num_rows_needed(&self) -> Option<usize> {
3572        self.limit
3573            .as_ref()
3574            .map(|l| usize::cast_from(u64::from(l.clone())) + self.offset)
3575    }
3576}
3577
3578impl RowSetFinishing {
3579    /// Applies finishing actions to a [`RowCollection`], and reports the total
3580    /// time it took to run.
3581    ///
3582    /// Returns a [`RowCollectionIter`] that contains all of the response data, as
3583    /// well as the size of the response in bytes.
3584    pub fn finish(
3585        &self,
3586        rows: RowCollection,
3587        max_result_size: u64,
3588        max_returned_query_size: Option<u64>,
3589        duration_histogram: &Histogram,
3590    ) -> Result<(RowCollectionIter, usize), String> {
3591        let now = Instant::now();
3592        let result = self.finish_inner(rows, max_result_size, max_returned_query_size);
3593        let duration = now.elapsed();
3594        duration_histogram.observe(duration.as_secs_f64());
3595
3596        result
3597    }
3598
3599    /// Implementation for [`RowSetFinishing::finish`].
3600    fn finish_inner(
3601        &self,
3602        rows: RowCollection,
3603        max_result_size: u64,
3604        max_returned_query_size: Option<u64>,
3605    ) -> Result<(RowCollectionIter, usize), String> {
3606        // How much additional memory is required to make a sorted view.
3607        let sorted_view_mem = rows.entries().saturating_mul(std::mem::size_of::<usize>());
3608        let required_memory = rows.byte_len().saturating_add(sorted_view_mem);
3609
3610        // Bail if creating the sorted view would require us to use too much memory.
3611        if required_memory > usize::cast_from(max_result_size) {
3612            let max_bytes = ByteSize::b(max_result_size);
3613            return Err(format!("result exceeds max size of {max_bytes}",));
3614        }
3615
3616        let sorted_view = rows;
3617        let mut iter = sorted_view
3618            .into_row_iter()
3619            .apply_offset(self.offset)
3620            .with_projection(self.project.clone());
3621
3622        if let Some(limit) = self.limit {
3623            let limit = u64::from(limit);
3624            let limit = usize::cast_from(limit);
3625            iter = iter.with_limit(limit);
3626        };
3627
3628        // TODO(parkmycar): Re-think how we can calculate the total response size without
3629        // having to iterate through the entire collection of Rows, while still
3630        // respecting the LIMIT, OFFSET, and projections.
3631        //
3632        // Note: It feels a bit bad always calculating the response size, but we almost
3633        // always need it to either check the `max_returned_query_size`, or for reporting
3634        // in the query history.
3635        let response_size: usize = iter.clone().map(|row| row.data().len()).sum();
3636
3637        // Bail if we would end up returning more data to the client than they can support.
3638        if let Some(max) = max_returned_query_size {
3639            if response_size > usize::cast_from(max) {
3640                let max_bytes = ByteSize::b(max);
3641                return Err(format!("result exceeds max size of {max_bytes}"));
3642            }
3643        }
3644
3645        Ok((iter, response_size))
3646    }
3647}
3648
/// A [RowSetFinishing] that can be repeatedly applied to batches of updates (in
/// a [RowCollection]) and keeps track of the remaining limit, offset, and cap
/// on query result size.
///
/// There is no ordering field: only the limit, offset, projection and size cap are tracked.
#[derive(Debug)]
pub struct RowSetFinishingIncremental {
    /// Include only as many rows (after offset). Decreases as batches are finished; `None`
    /// means no limit.
    pub remaining_limit: Option<usize>,
    /// Omit as many rows. Decreases as batches are finished.
    pub remaining_offset: usize,
    /// The maximum allowed result size, as requested by the client.
    pub max_returned_query_size: Option<u64>,
    /// Tracks our remaining allowed budget for result size.
    pub remaining_max_returned_query_size: Option<u64>,
    /// Include only given columns.
    pub project: Vec<usize>,
}
3665
3666impl RowSetFinishingIncremental {
3667    /// Turns the given [RowSetFinishing] into a [RowSetFinishingIncremental].
3668    /// Can only be used when [is_streamable](RowSetFinishing::is_streamable) is
3669    /// `true`.
3670    ///
3671    /// # Panics
3672    ///
3673    /// Panics if the result is not streamable, that is it has an ORDER BY.
3674    pub fn new(
3675        offset: usize,
3676        limit: Option<NonNeg<i64>>,
3677        project: Vec<usize>,
3678        max_returned_query_size: Option<u64>,
3679    ) -> Self {
3680        let limit = limit.map(|l| {
3681            let l = u64::from(l);
3682            let l = usize::cast_from(l);
3683            l
3684        });
3685
3686        RowSetFinishingIncremental {
3687            remaining_limit: limit,
3688            remaining_offset: offset,
3689            max_returned_query_size,
3690            remaining_max_returned_query_size: max_returned_query_size,
3691            project,
3692        }
3693    }
3694
3695    /// Applies finishing actions to the given [`RowCollection`], and reports
3696    /// the total time it took to run.
3697    ///
3698    /// Returns a [`RowCollectionIter`] that contains all of the response
3699    /// data.
3700    pub fn finish_incremental(
3701        &mut self,
3702        rows: RowCollection,
3703        max_result_size: u64,
3704        duration_histogram: &Histogram,
3705    ) -> Result<RowCollectionIter, String> {
3706        let now = Instant::now();
3707        let result = self.finish_incremental_inner(rows, max_result_size);
3708        let duration = now.elapsed();
3709        duration_histogram.observe(duration.as_secs_f64());
3710
3711        result
3712    }
3713
3714    fn finish_incremental_inner(
3715        &mut self,
3716        rows: RowCollection,
3717        max_result_size: u64,
3718    ) -> Result<RowCollectionIter, String> {
3719        // How much additional memory is required to make a sorted view.
3720        let sorted_view_mem = rows.entries().saturating_mul(std::mem::size_of::<usize>());
3721        let required_memory = rows.byte_len().saturating_add(sorted_view_mem);
3722
3723        // Bail if creating the sorted view would require us to use too much memory.
3724        if required_memory > usize::cast_from(max_result_size) {
3725            let max_bytes = ByteSize::b(max_result_size);
3726            return Err(format!("total result exceeds max size of {max_bytes}",));
3727        }
3728
3729        let batch_num_rows = rows.count();
3730
3731        let sorted_view = rows;
3732        let mut iter = sorted_view
3733            .into_row_iter()
3734            .apply_offset(self.remaining_offset)
3735            .with_projection(self.project.clone());
3736
3737        if let Some(limit) = self.remaining_limit {
3738            iter = iter.with_limit(limit);
3739        };
3740
3741        self.remaining_offset = self.remaining_offset.saturating_sub(batch_num_rows);
3742        if let Some(remaining_limit) = self.remaining_limit.as_mut() {
3743            *remaining_limit -= iter.count();
3744        }
3745
3746        // TODO(parkmycar): Re-think how we can calculate the total response size without
3747        // having to iterate through the entire collection of Rows, while still
3748        // respecting the LIMIT, OFFSET, and projections.
3749        //
3750        // Note: It feels a bit bad always calculating the response size, but we almost
3751        // always need it to either check the `max_returned_query_size`, or for reporting
3752        // in the query history.
3753        let response_size: usize = iter.clone().map(|row| row.data().len()).sum();
3754
3755        // Bail if we would end up returning more data to the client than they can support.
3756        if let Some(max) = &mut self.remaining_max_returned_query_size {
3757            if let Some(remaining) = max.checked_sub(response_size.cast_into()) {
3758                *max = remaining;
3759            } else {
3760                let max_bytes = ByteSize::b(self.max_returned_query_size.expect("known to exist"));
3761                return Err(format!("total result exceeds max size of {max_bytes}"));
3762            }
3763        }
3764
3765        Ok(iter)
3766    }
3767}
3768
/// Compares two rows columnwise, using [compare_columns].
///
/// Compared to the naive implementation, this allows sharing some memory and implements some
/// optimizations that avoid unnecessary row unpacking.
///
/// The type parameter `O` is any container that can be viewed as a slice of
/// [`ColumnOrder`]s; it defaults to `Vec<ColumnOrder>`.
#[derive(Debug, Clone)]
pub struct RowComparator<O: AsRef<[ColumnOrder]> = Vec<ColumnOrder>> {
    /// The column ordering that rows are compared by.
    order: O,
    /// Invariant: all column references in the order are less than this limit.
    /// This allows for partial unpacking of rows.
    limit: usize,
    /// Reusable buffer that the left-hand row is unpacked into during a comparison.
    left_vec: RefCell<DatumVec>,
    /// Reusable buffer that the right-hand row is unpacked into during a comparison.
    right_vec: RefCell<DatumVec>,
}
3782
3783impl<O: AsRef<[ColumnOrder]>> RowComparator<O> {
3784    /// Create a new row comparator from the given column ordering.
3785    pub fn new(order: O) -> Self {
3786        let limit = order
3787            .as_ref()
3788            .iter()
3789            .map(|o| o.column + 1)
3790            .max()
3791            .unwrap_or(0);
3792        Self {
3793            order,
3794            limit,
3795            left_vec: Default::default(),
3796            right_vec: Default::default(),
3797        }
3798    }
3799
3800    /// Compare two (references to) rows.
3801    pub fn compare_rows(
3802        &self,
3803        left_row: &RowRef,
3804        right_row: &RowRef,
3805        tiebreaker: impl Fn() -> Ordering,
3806    ) -> Ordering {
3807        let order = if self.limit == 0 {
3808            Ordering::Equal
3809        } else {
3810            // These borrows should never fail, since this struct is non-sync and this function
3811            // is non-recursive.
3812            let mut left_ref = self.left_vec.borrow_mut();
3813            let mut right_ref = self.right_vec.borrow_mut();
3814            let left_cols = left_ref.borrow_with_limit(left_row, self.limit);
3815            let right_cols = right_ref.borrow_with_limit(right_row, self.limit);
3816            compare_columns(self.order.as_ref(), &left_cols, &right_cols, || {
3817                Ordering::Equal
3818            })
3819        };
3820        // Tiebreak without the vecs borrowed, in case that recursively invokes this function.
3821        order.then_with(tiebreaker)
3822    }
3823}
3824
3825/// Compare `left` and `right` using `order`. If that doesn't produce a strict
3826/// ordering, call `tiebreaker`.
3827pub fn compare_columns<F>(
3828    order: &[ColumnOrder],
3829    left: &[Datum],
3830    right: &[Datum],
3831    tiebreaker: F,
3832) -> Ordering
3833where
3834    F: Fn() -> Ordering,
3835{
3836    for order in order {
3837        let cmp = match (&left[order.column], &right[order.column]) {
3838            (Datum::Null, Datum::Null) => Ordering::Equal,
3839            (Datum::Null, _) => {
3840                if order.nulls_last {
3841                    Ordering::Greater
3842                } else {
3843                    Ordering::Less
3844                }
3845            }
3846            (_, Datum::Null) => {
3847                if order.nulls_last {
3848                    Ordering::Less
3849                } else {
3850                    Ordering::Greater
3851                }
3852            }
3853            (lval, rval) => {
3854                if order.desc {
3855                    rval.cmp(lval)
3856                } else {
3857                    lval.cmp(rval)
3858                }
3859            }
3860        };
3861        if cmp != Ordering::Equal {
3862            return cmp;
3863        }
3864    }
3865    tiebreaker()
3866}
3867
/// Describe a window frame, e.g. `RANGE UNBOUNDED PRECEDING` or
/// `ROWS BETWEEN 5 PRECEDING AND CURRENT ROW`.
///
/// Window frames define a subset of the partition, and only a subset of
/// window functions make use of the window frame.
#[derive(
    Debug,
    Clone,
    Eq,
    PartialEq,
    Ord,
    PartialOrd,
    Serialize,
    Deserialize,
    Hash,
    MzReflect
)]
pub struct WindowFrame {
    /// ROWS, RANGE or GROUPS
    pub units: WindowFrameUnits,
    /// Where the frame starts
    pub start_bound: WindowFrameBound,
    /// Where the frame ends
    pub end_bound: WindowFrameBound,
}
3893
3894impl Display for WindowFrame {
3895    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
3896        write!(
3897            f,
3898            "{} between {} and {}",
3899            self.units, self.start_bound, self.end_bound
3900        )
3901    }
3902}
3903
3904impl WindowFrame {
3905    /// Return the default window frame used when one is not explicitly defined
3906    pub fn default() -> Self {
3907        WindowFrame {
3908            units: WindowFrameUnits::Range,
3909            start_bound: WindowFrameBound::UnboundedPreceding,
3910            end_bound: WindowFrameBound::CurrentRow,
3911        }
3912    }
3913
3914    fn includes_current_row(&self) -> bool {
3915        use WindowFrameBound::*;
3916        match self.start_bound {
3917            UnboundedPreceding => match self.end_bound {
3918                UnboundedPreceding => false,
3919                OffsetPreceding(0) => true,
3920                OffsetPreceding(_) => false,
3921                CurrentRow => true,
3922                OffsetFollowing(_) => true,
3923                UnboundedFollowing => true,
3924            },
3925            OffsetPreceding(0) => match self.end_bound {
3926                UnboundedPreceding => unreachable!(),
3927                OffsetPreceding(0) => true,
3928                // Any nonzero offsets here will create an empty window
3929                OffsetPreceding(_) => false,
3930                CurrentRow => true,
3931                OffsetFollowing(_) => true,
3932                UnboundedFollowing => true,
3933            },
3934            OffsetPreceding(_) => match self.end_bound {
3935                UnboundedPreceding => unreachable!(),
3936                // Window ends at the current row
3937                OffsetPreceding(0) => true,
3938                OffsetPreceding(_) => false,
3939                CurrentRow => true,
3940                OffsetFollowing(_) => true,
3941                UnboundedFollowing => true,
3942            },
3943            CurrentRow => true,
3944            OffsetFollowing(0) => match self.end_bound {
3945                UnboundedPreceding => unreachable!(),
3946                OffsetPreceding(_) => unreachable!(),
3947                CurrentRow => unreachable!(),
3948                OffsetFollowing(_) => true,
3949                UnboundedFollowing => true,
3950            },
3951            OffsetFollowing(_) => match self.end_bound {
3952                UnboundedPreceding => unreachable!(),
3953                OffsetPreceding(_) => unreachable!(),
3954                CurrentRow => unreachable!(),
3955                OffsetFollowing(_) => false,
3956                UnboundedFollowing => false,
3957            },
3958            UnboundedFollowing => false,
3959        }
3960    }
3961}
3962
/// Describe how frame bounds are interpreted.
///
/// The `Display` implementation renders the variants in lowercase
/// (`rows`, `range`, `groups`).
#[derive(
    Debug,
    Clone,
    Eq,
    PartialEq,
    Ord,
    PartialOrd,
    Serialize,
    Deserialize,
    Hash,
    MzReflect
)]
pub enum WindowFrameUnits {
    /// Each row is treated as the unit of work for bounds
    Rows,
    /// Each peer group is treated as the unit of work for bounds,
    /// and offset-based bounds use the value of the ORDER BY expression
    Range,
    /// Each peer group is treated as the unit of work for bounds.
    /// Groups is currently not supported, and it is rejected during planning.
    Groups,
}
3986
3987impl Display for WindowFrameUnits {
3988    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
3989        match self {
3990            WindowFrameUnits::Rows => write!(f, "rows"),
3991            WindowFrameUnits::Range => write!(f, "range"),
3992            WindowFrameUnits::Groups => write!(f, "groups"),
3993        }
3994    }
3995}
3996
/// Specifies [WindowFrame]'s `start_bound` and `end_bound`
///
/// The order between frame bounds is significant, as Postgres enforces
/// some restrictions there. The derived `Ord`/`PartialOrd` compare variants
/// in their declaration order, so do not reorder the variants.
#[derive(
    Debug,
    Clone,
    Serialize,
    Deserialize,
    PartialEq,
    Eq,
    Hash,
    MzReflect,
    PartialOrd,
    Ord
)]
pub enum WindowFrameBound {
    /// `UNBOUNDED PRECEDING`
    UnboundedPreceding,
    /// `<N> PRECEDING`
    OffsetPreceding(u64),
    /// `CURRENT ROW`
    CurrentRow,
    /// `<N> FOLLOWING`
    OffsetFollowing(u64),
    /// `UNBOUNDED FOLLOWING`.
    UnboundedFollowing,
}
4025
4026impl Display for WindowFrameBound {
4027    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
4028        match self {
4029            WindowFrameBound::UnboundedPreceding => write!(f, "unbounded preceding"),
4030            WindowFrameBound::OffsetPreceding(offset) => write!(f, "{} preceding", offset),
4031            WindowFrameBound::CurrentRow => write!(f, "current row"),
4032            WindowFrameBound::OffsetFollowing(offset) => write!(f, "{} following", offset),
4033            WindowFrameBound::UnboundedFollowing => write!(f, "unbounded following"),
4034        }
4035    }
4036}
4037
/// Maximum iterations for a LetRec.
#[derive(
    Debug,
    Clone,
    Copy,
    PartialEq,
    Eq,
    PartialOrd,
    Ord,
    Hash,
    Serialize,
    Deserialize
)]
pub struct LetRecLimit {
    /// Maximum number of iterations to evaluate.
    pub max_iters: NonZeroU64,
    /// Controls what happens when `max_iters` is reached: if true, we simply
    /// use the current contents of each Id as the final result instead of
    /// throwing an error.
    pub return_at_limit: bool,
}
4058
4059impl LetRecLimit {
4060    /// Compute the smallest limit from a Vec of `LetRecLimit`s.
4061    pub fn min_max_iter(limits: &Vec<Option<LetRecLimit>>) -> Option<u64> {
4062        limits
4063            .iter()
4064            .filter_map(|l| l.as_ref().map(|l| l.max_iters.get()))
4065            .min()
4066    }
4067
4068    /// The default value of `LetRecLimit::return_at_limit` when using the RECURSION LIMIT option of
4069    /// WMR without ERROR AT or RETURN AT.
4070    pub const RETURN_AT_LIMIT_DEFAULT: bool = false;
4071}
4072
4073impl Display for LetRecLimit {
4074    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
4075        write!(f, "[recursion_limit={}", self.max_iters)?;
4076        if self.return_at_limit != LetRecLimit::RETURN_AT_LIMIT_DEFAULT {
4077            write!(f, ", return_at_limit")?;
4078        }
4079        write!(f, "]")
4080    }
4081}
4082
/// For a global Get, this indicates whether we are going to read from Persist or from an index.
/// (See comment in MirRelationExpr::Get.)
#[derive(
    Clone,
    Debug,
    Eq,
    PartialEq,
    Ord,
    PartialOrd,
    Serialize,
    Deserialize,
    Hash
)]
pub enum AccessStrategy {
    /// It's either a local Get (a CTE), or unknown at the time.
    /// `prune_and_annotate_dataflow_index_imports` decides it for global Gets, and thus switches to
    /// one of the other variants.
    UnknownOrLocal,
    /// The Get will read from Persist.
    Persist,
    /// The Get will read from an index or indexes. Each entry is a pair of
    /// (index id, how the index will be used).
    Index(Vec<(GlobalId, IndexUsageType)>),
    /// The Get will read a collection that is computed by the same dataflow, but in a different
    /// `BuildDesc` in `objects_to_build`.
    SameDataflow,
}
4109
#[cfg(test)]
mod tests {
    use std::num::NonZeroUsize;

    use mz_repr::explain::text::text_string_at;

    use crate::explain::HumanizedExplain;

    use super::*;

    /// The text rendering of a `RowSetFinishing` lists its ordering, limit,
    /// and projected outputs on a single line.
    #[mz_ore::test]
    fn test_row_set_finishing_as_text() {
        let finishing = RowSetFinishing {
            order_by: vec![ColumnOrder {
                column: 4,
                desc: true,
                nulls_last: true,
            }],
            limit: Some(NonNeg::try_from(7).unwrap()),
            offset: Default::default(),
            project: vec![1, 3, 4, 5],
        };

        let mode = HumanizedExplain::new(false);
        let expr = mode.expr(&finishing, None);

        let act = text_string_at(&expr, mz_ore::str::Indent::default);

        // One piece per rendered attribute, followed by the trailing newline.
        let exp = concat!(
            "Finish",
            " order_by=[#4 desc nulls_last]",
            " limit=7",
            " output=[#1, #3..=#5]",
            "\n",
        );

        assert_eq!(act, exp);
    }

    /// Each finished batch is deducted from the returned-size budget, and the
    /// batch that no longer fits is rejected.
    #[mz_ore::test]
    fn test_row_set_finishing_incremental_max_returned_query_size() {
        let row = Row::pack_slice(&[Datum::String("hello")]);
        let row_size = u64::cast_from(row.data().len());
        let diff = NonZeroUsize::new(1).unwrap();
        let batch = RowCollection::new(vec![(row, diff)], &[]);

        // Set max_returned_query_size to hold exactly 2 batches worth of rows.
        let mut finishing = RowSetFinishingIncremental::new(0, None, vec![0], Some(row_size * 2));

        let max_result_size = u64::MAX;

        // The first two batches fit, leaving one row's worth and then nothing.
        for budget_left in [row_size, 0] {
            let r = finishing.finish_incremental_inner(batch.clone(), max_result_size);
            assert!(r.is_ok());
            assert_eq!(
                finishing.remaining_max_returned_query_size,
                Some(budget_left)
            );
        }

        // The third batch exceeds the exhausted budget.
        let err = finishing
            .finish_incremental_inner(batch, max_result_size)
            .unwrap_err();
        assert!(err.contains("total result exceeds max size"));
    }
}
4176
/// Iterators over AST structures that call out the nodes where two ASTs differ.
///
/// The iterators visit two ASTs in tandem, continuing as long as the AST node data matches,
/// and yielding an output pair as soon as the AST nodes do not match. Their intent is to call
/// attention to the places in the ASTs where they differ; incidentally, they also provide a
/// stack-free way to compare two ASTs.
4183mod structured_diff {
4184
4185    use super::MirRelationExpr;
4186    use itertools::Itertools;
4187
    /// An iterator over structured differences between two `MirRelationExpr` instances.
    ///
    /// Each yielded item is a pair of sub-expressions, one from each side,
    /// that were found not to match.
    pub struct MreDiff<'a> {
        /// Pairs of expressions that must still be compared.
        todo: Vec<(&'a MirRelationExpr, &'a MirRelationExpr)>,
    }
4193
4194    impl<'a> MreDiff<'a> {
4195        /// Create a new `MirRelationExpr` structured difference.
4196        pub fn new(expr1: &'a MirRelationExpr, expr2: &'a MirRelationExpr) -> Self {
4197            MreDiff {
4198                todo: vec![(expr1, expr2)],
4199            }
4200        }
4201    }
4202
    impl<'a> Iterator for MreDiff<'a> {
        /// Pairs of expressions that do not match.
        type Item = (&'a MirRelationExpr, &'a MirRelationExpr);

        /// Returns the next pair of sub-expressions that differ, or `None`
        /// once every pending pair has been compared.
        ///
        /// Pending pairs are popped from `todo` one at a time. When both
        /// nodes are the same variant and their own (non-child) data is
        /// equal, their children are pushed as new pending pairs and the
        /// search continues. Otherwise the pair itself is returned and its
        /// children are not descended into. Pairs still pending at that point
        /// remain in `todo` and are examined by later calls.
        fn next(&mut self) -> Option<Self::Item> {
            while let Some((expr1, expr2)) = self.todo.pop() {
                match (expr1, expr2) {
                    // Leaf: no children, so only the node data is compared.
                    (
                        MirRelationExpr::Constant {
                            rows: rows1,
                            typ: typ1,
                        },
                        MirRelationExpr::Constant {
                            rows: rows2,
                            typ: typ2,
                        },
                    ) => {
                        if rows1 != rows2 || typ1 != typ2 {
                            return Some((expr1, expr2));
                        }
                    }
                    // Leaf: no children, so only the node data is compared.
                    (
                        MirRelationExpr::Get {
                            id: id1,
                            typ: typ1,
                            access_strategy: as1,
                        },
                        MirRelationExpr::Get {
                            id: id2,
                            typ: typ2,
                            access_strategy: as2,
                        },
                    ) => {
                        if id1 != id2 || typ1 != typ2 || as1 != as2 {
                            return Some((expr1, expr2));
                        }
                    }
                    // Same binding id: compare both the body and the bound value.
                    (
                        MirRelationExpr::Let {
                            id: id1,
                            body: body1,
                            value: value1,
                        },
                        MirRelationExpr::Let {
                            id: id2,
                            body: body2,
                            value: value2,
                        },
                    ) => {
                        if id1 != id2 {
                            return Some((expr1, expr2));
                        } else {
                            self.todo.push((body1, body2));
                            self.todo.push((value1, value2));
                        }
                    }
                    // The value counts are checked up front so that `zip_eq`
                    // below always sees equally long iterators.
                    (
                        MirRelationExpr::LetRec {
                            ids: ids1,
                            body: body1,
                            values: values1,
                            limits: limits1,
                        },
                        MirRelationExpr::LetRec {
                            ids: ids2,
                            body: body2,
                            values: values2,
                            limits: limits2,
                        },
                    ) => {
                        if ids1 != ids2 || values1.len() != values2.len() || limits1 != limits2 {
                            return Some((expr1, expr2));
                        } else {
                            self.todo.push((body1, body2));
                            self.todo.extend(values1.iter().zip_eq(values2.iter()));
                        }
                    }
                    // Single-input operators: compare the operator's own
                    // data, then descend into the input.
                    (
                        MirRelationExpr::Project {
                            outputs: outputs1,
                            input: input1,
                        },
                        MirRelationExpr::Project {
                            outputs: outputs2,
                            input: input2,
                        },
                    ) => {
                        if outputs1 != outputs2 {
                            return Some((expr1, expr2));
                        } else {
                            self.todo.push((input1, input2));
                        }
                    }
                    (
                        MirRelationExpr::Map {
                            scalars: scalars1,
                            input: input1,
                        },
                        MirRelationExpr::Map {
                            scalars: scalars2,
                            input: input2,
                        },
                    ) => {
                        if scalars1 != scalars2 {
                            return Some((expr1, expr2));
                        } else {
                            self.todo.push((input1, input2));
                        }
                    }
                    (
                        MirRelationExpr::Filter {
                            predicates: predicates1,
                            input: input1,
                        },
                        MirRelationExpr::Filter {
                            predicates: predicates2,
                            input: input2,
                        },
                    ) => {
                        if predicates1 != predicates2 {
                            return Some((expr1, expr2));
                        } else {
                            self.todo.push((input1, input2));
                        }
                    }
                    (
                        MirRelationExpr::FlatMap {
                            input: input1,
                            func: func1,
                            exprs: exprs1,
                        },
                        MirRelationExpr::FlatMap {
                            input: input2,
                            func: func2,
                            exprs: exprs2,
                        },
                    ) => {
                        if func1 != func2 || exprs1 != exprs2 {
                            return Some((expr1, expr2));
                        } else {
                            self.todo.push((input1, input2));
                        }
                    }
                    // The input counts are checked up front so that `zip_eq`
                    // below always sees equally long iterators.
                    (
                        MirRelationExpr::Join {
                            inputs: inputs1,
                            equivalences: eq1,
                            implementation: impl1,
                        },
                        MirRelationExpr::Join {
                            inputs: inputs2,
                            equivalences: eq2,
                            implementation: impl2,
                        },
                    ) => {
                        if inputs1.len() != inputs2.len() || eq1 != eq2 || impl1 != impl2 {
                            return Some((expr1, expr2));
                        } else {
                            self.todo.extend(inputs1.iter().zip_eq(inputs2.iter()));
                        }
                    }
                    (
                        MirRelationExpr::Reduce {
                            aggregates: aggregates1,
                            input: inputs1,
                            group_key: gk1,
                            monotonic: m1,
                            expected_group_size: egs1,
                        },
                        MirRelationExpr::Reduce {
                            aggregates: aggregates2,
                            input: inputs2,
                            group_key: gk2,
                            monotonic: m2,
                            expected_group_size: egs2,
                        },
                    ) => {
                        if aggregates1 != aggregates2 || gk1 != gk2 || m1 != m2 || egs1 != egs2 {
                            return Some((expr1, expr2));
                        } else {
                            self.todo.push((inputs1, inputs2));
                        }
                    }
                    (
                        MirRelationExpr::TopK {
                            group_key: gk1,
                            order_key: order1,
                            input: input1,
                            limit: l1,
                            offset: o1,
                            monotonic: m1,
                            expected_group_size: egs1,
                        },
                        MirRelationExpr::TopK {
                            group_key: gk2,
                            order_key: order2,
                            input: input2,
                            limit: l2,
                            offset: o2,
                            monotonic: m2,
                            expected_group_size: egs2,
                        },
                    ) => {
                        if order1 != order2
                            || gk1 != gk2
                            || l1 != l2
                            || o1 != o2
                            || m1 != m2
                            || egs1 != egs2
                        {
                            return Some((expr1, expr2));
                        } else {
                            self.todo.push((input1, input2));
                        }
                    }
                    // No data of their own: always descend into the input.
                    (
                        MirRelationExpr::Negate { input: input1 },
                        MirRelationExpr::Negate { input: input2 },
                    ) => {
                        self.todo.push((input1, input2));
                    }
                    // No data of their own: always descend into the input.
                    (
                        MirRelationExpr::Threshold { input: input1 },
                        MirRelationExpr::Threshold { input: input2 },
                    ) => {
                        self.todo.push((input1, input2));
                    }
                    // The input counts are checked up front so that `zip_eq`
                    // below always sees equally long iterators.
                    (
                        MirRelationExpr::Union {
                            base: base1,
                            inputs: inputs1,
                        },
                        MirRelationExpr::Union {
                            base: base2,
                            inputs: inputs2,
                        },
                    ) => {
                        if inputs1.len() != inputs2.len() {
                            return Some((expr1, expr2));
                        } else {
                            self.todo.push((base1, base2));
                            self.todo.extend(inputs1.iter().zip_eq(inputs2.iter()));
                        }
                    }
                    (
                        MirRelationExpr::ArrangeBy {
                            keys: keys1,
                            input: input1,
                        },
                        MirRelationExpr::ArrangeBy {
                            keys: keys2,
                            input: input2,
                        },
                    ) => {
                        if keys1 != keys2 {
                            return Some((expr1, expr2));
                        } else {
                            self.todo.push((input1, input2));
                        }
                    }
                    // Any pair not matched above (including two nodes of
                    // different variants) is reported as a difference.
                    _ => {
                        return Some((expr1, expr2));
                    }
                }
            }
            // Every pending pair matched: no further differences.
            None
        }
    }
4471}