Skip to main content

mz_expr/
relation.rs

1// Copyright Materialize, Inc. and contributors. All rights reserved.
2//
3// Use of this software is governed by the Business Source License
4// included in the LICENSE file.
5//
6// As of the Change Date specified in that file, in accordance with
7// the Business Source License, use of this software will be governed
8// by the Apache License, Version 2.0.
9
10#![warn(missing_docs)]
11
12use std::cell::RefCell;
13use std::cmp::{Ordering, max};
14use std::collections::{BTreeMap, BTreeSet};
15use std::fmt;
16use std::fmt::{Display, Formatter};
17use std::hash::{DefaultHasher, Hash, Hasher};
18use std::num::NonZeroU64;
19use std::time::Instant;
20
21use bytesize::ByteSize;
22use columnation::{Columnation, CopyRegion};
23use itertools::Itertools;
24use mz_lowertest::MzReflect;
25use mz_ore::cast::{CastFrom, CastInto};
26use mz_ore::collections::CollectionExt;
27use mz_ore::id_gen::IdGen;
28use mz_ore::metrics::Histogram;
29use mz_ore::num::NonNeg;
30use mz_ore::soft_assert_no_log;
31use mz_ore::stack::RecursionLimitError;
32use mz_ore::str::Indent;
33use mz_repr::adt::numeric::NumericMaxScale;
34use mz_repr::explain::text::text_string_at;
35use mz_repr::explain::{
36    DummyHumanizer, ExplainConfig, ExprHumanizer, IndexUsageType, PlanRenderingContext,
37};
38use mz_repr::{
39    ColumnName, Datum, DatumVec, Diff, GlobalId, IntoRowIterator, ReprColumnType, ReprRelationType,
40    ReprScalarType, Row, RowIterator, RowRef, SqlColumnType, SqlRelationType, SqlScalarType,
41};
42use serde::{Deserialize, Serialize};
43
44use crate::Id::Local;
45use crate::explain::{HumanizedExpr, HumanizerMode};
46use crate::relation::func::{AggregateFunc, LagLeadType, TableFunc};
47use crate::row::{RowCollection, RowCollectionIter};
48use crate::scalar::columns::Columns;
49use crate::scalar::func::variadic::{
50    JsonbBuildArray, JsonbBuildObject, ListCreate, ListIndex, MapBuild, RecordCreate,
51};
52use crate::visit::{Visit, VisitChildren};
53use crate::{
54    EvalError, FilterCharacteristics, Id, LocalId, MirScalarExpr, UnaryFunc, func as scalar_func,
55};
56
57pub mod canonicalize;
58pub mod func;
59pub mod join_input_mapper;
60
61/// A recursion limit to be used for stack-safe traversals of [`MirRelationExpr`] trees.
62///
63/// The recursion limit must be large enough to accommodate the linear representation
64/// of some pathological but frequently occurring query fragments.
65///
66/// For example, in MIR we could have long chains of
67/// - (1) `Let` bindings,
68/// - (2) `CallBinary` calls with associative functions such as `+`.
69///
70/// Until we fix those, we need to stick with the larger recursion limit.
71pub const RECURSION_LIMIT: usize = 2048;
72
73/// A trait for types that describe how to build a collection.
74pub trait CollectionPlan {
75    /// Collects the set of global identifiers from dataflows referenced in Get.
76    fn depends_on_into(&self, out: &mut BTreeSet<GlobalId>);
77
78    /// Returns the set of global identifiers from dataflows referenced in Get.
79    ///
80    /// See [`CollectionPlan::depends_on_into`] to reuse an existing `BTreeSet`.
81    fn depends_on(&self) -> BTreeSet<GlobalId> {
82        let mut out = BTreeSet::new();
83        self.depends_on_into(&mut out);
84        out
85    }
86}
87
88/// An abstract syntax tree which defines a collection.
89///
90/// The AST is meant to reflect the capabilities of the `differential_dataflow::Collection` type,
91/// written generically enough to avoid run-time compilation work.
92///
93/// `derived_hash_with_manual_eq` was complaining for the wrong reason: This lint exists because
94/// it's bad when `Eq` doesn't agree with `Hash`, which is often quite likely if one of them is
95/// implemented manually. However, our manual implementation of `Eq` _will_ agree with the derived
96/// one. This is because the reason for the manual implementation is not to change the semantics
97/// from the derived one, but to avoid stack overflows.
98#[allow(clippy::derived_hash_with_manual_eq)]
99#[derive(Clone, Debug, Ord, PartialOrd, Serialize, Deserialize, MzReflect, Hash)]
100pub enum MirRelationExpr {
101    /// A constant relation containing specified rows.
102    ///
103    /// The runtime memory footprint of this operator is zero.
104    ///
105    /// When you would like to pattern match on this, consider using `MirRelationExpr::as_const`
106    /// instead, which looks behind `ArrangeBy`s. You might want this matching behavior because
107    /// constant folding doesn't remove `ArrangeBy`s.
108    Constant {
109        /// Rows of the constant collection and their multiplicities.
110        rows: Result<Vec<(Row, Diff)>, EvalError>,
111        /// Schema of the collection.
112        typ: ReprRelationType,
113    },
114    /// Get an existing dataflow.
115    ///
116    /// The runtime memory footprint of this operator is zero.
117    Get {
118        /// The identifier for the collection to load.
119        #[mzreflect(ignore)]
120        id: Id,
121        /// Schema of the collection.
122        typ: ReprRelationType,
123        /// If this is a global Get, this will indicate whether we are going to read from Persist or
124        /// from an index, or from a different object in `objects_to_build`. If it's an index, then
125        /// how downstream dataflow operations will use this index is also recorded. This is filled
126        /// by `prune_and_annotate_dataflow_index_imports`. Note that this is not used by the
127        /// lowering to LIR, but is used only by EXPLAIN.
128        #[mzreflect(ignore)]
129        access_strategy: AccessStrategy,
130    },
131    /// Introduce a temporary dataflow.
132    ///
133    /// The runtime memory footprint of this operator is zero.
134    Let {
135        /// The identifier to be used in `Get` variants to retrieve `value`.
136        #[mzreflect(ignore)]
137        id: LocalId,
138        /// The collection to be bound to `id`.
139        value: Box<MirRelationExpr>,
140        /// The result of the `Let`, evaluated with `id` bound to `value`.
141        body: Box<MirRelationExpr>,
142    },
143    /// Introduce mutually recursive bindings.
144    ///
145    /// Each `LocalId` is immediately bound to an initially empty collection
146    /// with the type of its corresponding `MirRelationExpr`. Repeatedly, each
147    /// binding is evaluated using the current contents of each other binding,
148    /// and is refreshed to contain the new evaluation. This process continues
149    /// through all bindings, and repeats as long as changes continue to occur.
150    ///
151    /// The resulting value of the expression is `body` evaluated once in the
152    /// context of the final iterates.
153    ///
154    /// A zero-binding instance can be replaced by `body`.
155    /// A single-binding instance is equivalent to `MirRelationExpr::Let`.
156    ///
157    /// The runtime memory footprint of this operator is zero.
158    LetRec {
159        /// The identifiers to be used in `Get` variants to retrieve each `value`.
160        #[mzreflect(ignore)]
161        ids: Vec<LocalId>,
162        /// The collections to be bound to each `id`.
163        values: Vec<MirRelationExpr>,
164        /// Maximum number of iterations, after which we should artificially force a fixpoint.
165        /// (Whether we error or just stop is configured by `LetRecLimit::return_at_limit`.)
166        /// The per-`LetRec` limit that the user specified is initially copied to each binding to
167        /// accommodate slicing and merging of `LetRec`s in MIR transforms (e.g., `NormalizeLets`).
168        #[mzreflect(ignore)]
169        limits: Vec<Option<LetRecLimit>>,
170        /// The result of the `LetRec`, evaluated with each of `ids` bound to its final iterate.
171        body: Box<MirRelationExpr>,
172    },
173    /// Project out some columns from a dataflow
174    ///
175    /// The runtime memory footprint of this operator is zero.
176    Project {
177        /// The source collection.
178        input: Box<MirRelationExpr>,
179        /// Indices of columns to retain.
180        outputs: Vec<usize>,
181    },
182    /// Append new columns to a dataflow
183    ///
184    /// The runtime memory footprint of this operator is zero.
185    Map {
186        /// The source collection.
187        input: Box<MirRelationExpr>,
188        /// Expressions which determine values to append to each row.
189        /// An expression may refer to columns in `input` or
190        /// expressions defined earlier in the vector
191        scalars: Vec<MirScalarExpr>,
192    },
193    /// Like Map, but yields zero-or-more output rows per input row
194    ///
195    /// The runtime memory footprint of this operator is zero.
196    FlatMap {
197        /// The source collection
198        input: Box<MirRelationExpr>,
199        /// The table func to apply
200        func: TableFunc,
201        /// The arguments to the table func
202        exprs: Vec<MirScalarExpr>,
203    },
204    /// Keep rows from a dataflow where all the predicates are true
205    ///
206    /// The runtime memory footprint of this operator is zero.
207    Filter {
208        /// The source collection.
209        input: Box<MirRelationExpr>,
210        /// Predicates, each of which must be true.
211        predicates: Vec<MirScalarExpr>,
212    },
213    /// Join several collections, where some columns must be equal.
214    ///
215    /// For further details consult the documentation for [`MirRelationExpr::join`].
216    ///
217    /// The runtime memory footprint of this operator can be proportional to
218    /// the sizes of all inputs and the size of all joins of prefixes.
219    /// This may be reduced due to arrangements available at rendering time.
220    Join {
221        /// A sequence of input relations.
222        inputs: Vec<MirRelationExpr>,
223        /// A sequence of equivalence classes of expressions on the cross product of inputs.
224        ///
225        /// Each equivalence class is a list of scalar expressions, where for each class the
226        /// intended interpretation is that all evaluated expressions should be equal.
227        ///
228        /// Each scalar expression is to be evaluated over the cross-product of all records
229        /// from all inputs. In many cases this may just be column selection from specific
230        /// inputs, but more general cases exist (e.g. complex functions of multiple columns
231        /// from multiple inputs, or just constant literals).
232        equivalences: Vec<Vec<MirScalarExpr>>,
233        /// Join implementation information.
234        #[serde(default)]
235        implementation: JoinImplementation,
236    },
237    /// Group a dataflow by some columns and aggregate over each group
238    ///
239    /// The runtime memory footprint of this operator is at most proportional to the
240    /// number of distinct records in the input and output. The actual requirements
241    /// can be less: the number of distinct inputs to each aggregate, summed across
242    /// each aggregate, plus the output size. For more details consult the code that
243    /// builds the associated dataflow.
244    Reduce {
245        /// The source collection.
246        input: Box<MirRelationExpr>,
247        /// Expressions used to form groups.
248        group_key: Vec<MirScalarExpr>,
249        /// Expressions which determine values to append to each row, after the group keys.
250        aggregates: Vec<AggregateExpr>,
251        /// True iff the input is known to monotonically increase (only addition of records).
252        #[serde(default)]
253        monotonic: bool,
254        /// User hint: expected number of values per group key. Used to optimize physical rendering.
255        #[serde(default)]
256        expected_group_size: Option<u64>,
257    },
258    /// Groups and orders within each group, limiting output.
259    ///
260    /// The runtime memory footprint of this operator is proportional to its input and output.
261    TopK {
262        /// The source collection.
263        input: Box<MirRelationExpr>,
264        /// Column indices used to form groups.
265        group_key: Vec<usize>,
266        /// Column indices used to order rows within groups.
267        order_key: Vec<ColumnOrder>,
268        /// Number of records to retain
269        #[serde(default)]
270        limit: Option<MirScalarExpr>,
271        /// Number of records to skip
272        #[serde(default)]
273        offset: usize,
274        /// True iff the input is known to monotonically increase (only addition of records).
275        #[serde(default)]
276        monotonic: bool,
277        /// User-supplied hint: how many rows will have the same group key.
278        #[serde(default)]
279        expected_group_size: Option<u64>,
280    },
281    /// Return a dataflow where the row counts are negated
282    ///
283    /// The runtime memory footprint of this operator is zero.
284    Negate {
285        /// The source collection.
286        input: Box<MirRelationExpr>,
287    },
288    /// Keep rows from a dataflow where the row counts are positive
289    ///
290    /// The runtime memory footprint of this operator is proportional to its input and output.
291    Threshold {
292        /// The source collection.
293        input: Box<MirRelationExpr>,
294    },
295    /// Adds the frequencies of elements in contained sets.
296    ///
297    /// The runtime memory footprint of this operator is zero.
298    Union {
299        /// A source collection.
300        base: Box<MirRelationExpr>,
301        /// Source collections to union.
302        inputs: Vec<MirRelationExpr>,
303    },
304    /// Technically a no-op. Used to render an index. Will be used to optimize queries
305    /// on finer grain. Each item of `keys` represents a different index that should be
306    /// produced from `input`.
307    ///
308    /// The runtime memory footprint of this operator is proportional to its input.
309    ArrangeBy {
310        /// The source collection
311        input: Box<MirRelationExpr>,
312        /// Columns to arrange `input` by, in order of decreasing primacy
313        keys: Vec<Vec<MirScalarExpr>>,
314    },
315}
316
317impl PartialEq for MirRelationExpr {
318    fn eq(&self, other: &Self) -> bool {
319        // Capture the result and test it wrt `Ord` implementation in test environments.
320        let result = structured_diff::MreDiff::new(self, other).next().is_none();
321        mz_ore::soft_assert_eq_no_log!(result, self.cmp(other) == Ordering::Equal);
322        result
323    }
324}
325impl Eq for MirRelationExpr {}
326
327impl MirRelationExpr {
328    /// Reports the schema of the relation.
329    ///
330    /// This is the SQL-type parallel of [`Self::typ`]; it is merely
331    /// a wrapper around it, returning a [`SqlRelationType`] instead of
332    /// a [`ReprRelationType`].
333    pub fn sql_typ(&self) -> SqlRelationType {
334        let repr_typ = self.typ();
335        SqlRelationType::from_repr(&repr_typ)
336    }
337
338    /// Reports the repr schema of the relation.
339    ///
340    /// This method determines the type through recursive traversal of the
341    /// relation expression, drawing from the types of base collections.
342    /// As such, this is not an especially cheap method, and should be used
343    /// judiciously.
344    ///
345    /// The relation type is computed incrementally with a recursive post-order
346    /// traversal, that accumulates the input types for the relations yet to be
347    /// visited in `type_stack`.
348    pub fn typ(&self) -> ReprRelationType {
349        let mut type_stack = Vec::new();
350        self.visit_pre_post(
351            &mut |e: &MirRelationExpr| -> Option<Vec<&MirRelationExpr>> {
352                match &e {
353                    MirRelationExpr::Let { body, .. } => Some(vec![&*body]),
354                    MirRelationExpr::LetRec { body, .. } => Some(vec![&*body]),
355                    _ => None,
356                }
357            },
358            &mut |e: &MirRelationExpr| {
359                match e {
360                    MirRelationExpr::Let { .. } => {
361                        let body_typ = type_stack.pop().unwrap();
362                        // Insert a dummy relation type for the value, since `typ_with_input_types`
363                        // won't look at it, but expects the relation type of the body to be second.
364                        type_stack.push(ReprRelationType::empty());
365                        type_stack.push(body_typ);
366                    }
367                    MirRelationExpr::LetRec { values, .. } => {
368                        let body_typ = type_stack.pop().unwrap();
369                        type_stack.extend(
370                            std::iter::repeat(ReprRelationType::empty()).take(values.len()),
371                        );
372                        // Insert dummy relation types for the values, since `typ_with_input_types`
373                        // won't look at them, but expects the relation type of the body to be last.
374                        type_stack.push(body_typ);
375                    }
376                    _ => {}
377                }
378                let num_inputs = e.num_inputs();
379                let relation_type =
380                    e.typ_with_input_types(&type_stack[type_stack.len() - num_inputs..]);
381                type_stack.truncate(type_stack.len() - num_inputs);
382                type_stack.push(relation_type);
383            },
384        );
385        assert_eq!(type_stack.len(), 1);
386        type_stack.pop().unwrap()
387    }
388
389    /// Reports the repr schema of the relation given the repr schema of the input relations.
390    pub fn typ_with_input_types(&self, input_types: &[ReprRelationType]) -> ReprRelationType {
391        let column_types = self.col_with_input_cols(input_types.iter().map(|i| &i.column_types));
392        let unique_keys = self.keys_with_input_keys(
393            input_types.iter().map(|i| i.arity()),
394            input_types.iter().map(|i| &i.keys),
395        );
396        ReprRelationType::new(column_types).with_keys(unique_keys)
397    }
398
399    /// Reports the column types of the relation given the column types of the
400    /// input relations.
401    ///
402    /// This method delegates to `try_col_with_input_cols`, panicking if an `Err`
403    /// variant is returned.
404    pub fn col_with_input_cols<'a, I>(&self, input_types: I) -> Vec<ReprColumnType>
405    where
406        I: Iterator<Item = &'a Vec<ReprColumnType>>,
407    {
408        match self.try_col_with_input_cols(input_types) {
409            Ok(col_types) => col_types,
410            Err(err) => panic!("{err}"),
411        }
412    }
413
414    /// Reports the column types of the relation given the column types of the input relations.
415    ///
416    /// `input_types` is required to contain the column types for the input relations of
417    /// the current relation in the same order as they are visited by `try_visit_children`
418    /// method, even though not all may be used for computing the schema of the
419    /// current relation. For example, `Let` expects two input types, one for the
420    /// value relation and one for the body, in that order, but only the one for the
421    /// body is used to determine the type of the `Let` relation.
422    ///
423    /// It is meant to be used during post-order traversals to compute column types
424    /// incrementally.
425    pub fn try_col_with_input_cols<'a, I>(
426        &self,
427        mut input_types: I,
428    ) -> Result<Vec<ReprColumnType>, String>
429    where
430        I: Iterator<Item = &'a Vec<ReprColumnType>>,
431    {
432        use MirRelationExpr::*;
433
434        let col_types = match self {
435            Constant { rows, typ } => {
436                let mut col_types = typ.column_types.clone();
437                let mut seen_null = vec![false; typ.arity()];
438                if let Ok(rows) = rows {
439                    for (row, _diff) in rows {
440                        for (datum, i) in row.iter().zip_eq(0..typ.arity()) {
441                            if datum.is_null() {
442                                seen_null[i] = true;
443                            }
444                        }
445                    }
446                }
447                for (&seen_null, i) in seen_null.iter().zip_eq(0..typ.arity()) {
448                    if !seen_null {
449                        col_types[i].nullable = false;
450                    } else {
451                        assert!(col_types[i].nullable);
452                    }
453                }
454                col_types
455            }
456            Get { typ, .. } => typ.column_types.clone(),
457            Project { outputs, .. } => {
458                let input = input_types.next().unwrap();
459                outputs.iter().map(|&i| input[i].clone()).collect()
460            }
461            Map { scalars, .. } => {
462                let mut result = input_types.next().unwrap().clone();
463                for scalar in scalars.iter() {
464                    result.push(scalar.typ(&result))
465                }
466                result
467            }
468            FlatMap { func, .. } => {
469                let mut result = input_types.next().unwrap().clone();
470                result.extend(
471                    func.output_sql_type()
472                        .column_types
473                        .iter()
474                        .map(ReprColumnType::from),
475                );
476                result
477            }
478            Filter { predicates, .. } => {
479                let mut result = input_types.next().unwrap().clone();
480
481                // Set as nonnull any columns where null values would cause
482                // any predicate to evaluate to null.
483                for column in non_nullable_columns(predicates) {
484                    result[column].nullable = false;
485                }
486                result
487            }
488            Join { equivalences, .. } => {
489                // Concatenate input column types
490                let mut types = input_types.flat_map(|cols| cols.to_owned()).collect_vec();
491                // In an equivalence class, if any column is non-null, then make all non-null
492                for equivalence in equivalences {
493                    let col_inds = equivalence
494                        .iter()
495                        .filter_map(|expr| match expr {
496                            MirScalarExpr::Column(col, _name) => Some(*col),
497                            _ => None,
498                        })
499                        .collect_vec();
500                    if col_inds.iter().any(|i| !types.get(*i).unwrap().nullable) {
501                        for i in col_inds {
502                            types.get_mut(i).unwrap().nullable = false;
503                        }
504                    }
505                }
506                types
507            }
508            Reduce {
509                group_key,
510                aggregates,
511                ..
512            } => {
513                let input = input_types.next().unwrap();
514                group_key
515                    .iter()
516                    .map(|e| e.typ(input))
517                    .chain(aggregates.iter().map(|agg| agg.typ(input)))
518                    .collect()
519            }
520            TopK { .. } | Negate { .. } | Threshold { .. } | ArrangeBy { .. } => {
521                input_types.next().unwrap().clone()
522            }
523            Let { .. } => {
524                // skip over the input types for `value`.
525                input_types.nth(1).unwrap().clone()
526            }
527            LetRec { values, .. } => {
528                // skip over the input types for `values`.
529                input_types.nth(values.len()).unwrap().clone()
530            }
531            Union { .. } => {
532                let mut result = input_types.next().unwrap().clone();
533                for input_col_types in input_types {
534                    for (base_col, col) in result.iter_mut().zip_eq(input_col_types) {
535                        *base_col = base_col
536                            .union(col)
537                            .map_err(|e| format!("{}\nin plan:\n{}", e, self.pretty()))?;
538                    }
539                }
540                result
541            }
542        };
543
544        Ok(col_types)
545    }
546
547    /// Reports the unique keys of the relation given the arities and the unique
548    /// keys of the input relations.
549    ///
550    /// `input_arities` and `input_keys` are required to contain the
551    /// corresponding info for the input relations of
552    /// the current relation in the same order as they are visited by `try_visit_children`
553    /// method, even though not all may be used for computing the schema of the
554    /// current relation. For example, `Let` expects two input types, one for the
555    /// value relation and one for the body, in that order, but only the one for the
556    /// body is used to determine the type of the `Let` relation.
557    ///
558    /// It is meant to be used during post-order traversals to compute unique keys
559    /// incrementally.
560    pub fn keys_with_input_keys<'a, I, J>(
561        &self,
562        mut input_arities: I,
563        mut input_keys: J,
564    ) -> Vec<Vec<usize>>
565    where
566        I: Iterator<Item = usize>,
567        J: Iterator<Item = &'a Vec<Vec<usize>>>,
568    {
569        use MirRelationExpr::*;
570
571        let mut keys = match self {
572            Constant {
573                rows: Ok(rows),
574                typ,
575            } => {
576                let n_cols = typ.arity();
577                // If the `i`th entry is `Some`, then we have not yet observed non-uniqueness in the `i`th column.
578                let mut unique_values_per_col = vec![Some(BTreeSet::<Datum>::default()); n_cols];
579                for (row, diff) in rows {
580                    for (i, datum) in row.iter().enumerate() {
581                        if datum != Datum::Dummy {
582                            if let Some(unique_vals) = &mut unique_values_per_col[i] {
583                                let is_dupe = *diff != Diff::ONE || !unique_vals.insert(datum);
584                                if is_dupe {
585                                    unique_values_per_col[i] = None;
586                                }
587                            }
588                        }
589                    }
590                }
591                if rows.len() == 0 || (rows.len() == 1 && rows[0].1 == Diff::ONE) {
592                    vec![vec![]]
593                } else {
594                    // XXX - Multi-column keys are not detected.
595                    typ.keys
596                        .iter()
597                        .cloned()
598                        .chain(
599                            unique_values_per_col
600                                .into_iter()
601                                .enumerate()
602                                .filter(|(_idx, unique_vals)| unique_vals.is_some())
603                                .map(|(idx, _)| vec![idx]),
604                        )
605                        .collect()
606                }
607            }
608            Constant { rows: Err(_), typ } | Get { typ, .. } => typ.keys.clone(),
609            Threshold { .. } | ArrangeBy { .. } => input_keys.next().unwrap().clone(),
610            Let { .. } => {
611                // skip over the unique keys for value
612                input_keys.nth(1).unwrap().clone()
613            }
614            LetRec { values, .. } => {
615                // skip over the unique keys for value
616                input_keys.nth(values.len()).unwrap().clone()
617            }
618            Project { outputs, .. } => {
619                let input = input_keys.next().unwrap();
620                input
621                    .iter()
622                    .filter_map(|key_set| {
623                        if key_set.iter().all(|k| outputs.contains(k)) {
624                            Some(
625                                key_set
626                                    .iter()
627                                    .map(|c| outputs.iter().position(|o| o == c).unwrap())
628                                    .collect(),
629                            )
630                        } else {
631                            None
632                        }
633                    })
634                    .collect()
635            }
636            Map { scalars, .. } => {
637                let mut remappings = Vec::new();
638                let arity = input_arities.next().unwrap();
639                for (column, scalar) in scalars.iter().enumerate() {
640                    // assess whether the scalar preserves uniqueness,
641                    // and could participate in a key!
642
643                    fn uniqueness(expr: &MirScalarExpr) -> Option<usize> {
644                        match expr {
645                            MirScalarExpr::CallUnary { func, expr } => {
646                                if func.preserves_uniqueness() {
647                                    uniqueness(expr)
648                                } else {
649                                    None
650                                }
651                            }
652                            MirScalarExpr::Column(c, _name) => Some(*c),
653                            _ => None,
654                        }
655                    }
656
657                    if let Some(c) = uniqueness(scalar) {
658                        remappings.push((c, column + arity));
659                    }
660                }
661
662                let mut result = input_keys.next().unwrap().clone();
663                let mut new_keys = Vec::new();
664                // Any column in `remappings` could be replaced in a key
665                // by the corresponding c. This could lead to combinatorial
666                // explosion using our current representation, so we wont
667                // do that. Instead, we'll handle the case of one remapping.
668                if remappings.len() == 1 {
669                    let (old, new) = remappings.pop().unwrap();
670                    for key in &result {
671                        if key.contains(&old) {
672                            let mut new_key: Vec<usize> =
673                                key.iter().cloned().filter(|k| k != &old).collect();
674                            new_key.push(new);
675                            new_key.sort_unstable();
676                            new_keys.push(new_key);
677                        }
678                    }
679                    result.append(&mut new_keys);
680                }
681                result
682            }
683            FlatMap { .. } => {
684                // FlatMap can add duplicate rows, so input keys are no longer
685                // valid
686                vec![]
687            }
688            Negate { .. } => {
689                // Although negate may have distinct records for each key,
690                // the multiplicity is -1 rather than 1. This breaks many
691                // of the optimization uses of "keys".
692                vec![]
693            }
694            Filter { predicates, .. } => {
695                // A filter inherits the keys of its input unless the filters
696                // have reduced the input to a single row, in which case the
697                // keys of the input are `()`.
698                let mut input = input_keys.next().unwrap().clone();
699
700                if !input.is_empty() {
701                    // Track columns equated to literals, which we can prune.
702                    let mut cols_equal_to_literal = BTreeSet::new();
703
704                    // Perform union find on `col1 = col2` to establish
705                    // connected components of equated columns. Absent any
706                    // equalities, this will be `0 .. #c` (where #c is the
707                    // greatest column referenced by a predicate), but each
708                    // equality will orient the root of the greater to the root
709                    // of the lesser.
710                    let mut union_find = Vec::new();
711
712                    for expr in predicates.iter() {
713                        if let MirScalarExpr::CallBinary {
714                            func: crate::BinaryFunc::Eq(_),
715                            expr1,
716                            expr2,
717                        } = expr
718                        {
719                            if let MirScalarExpr::Column(c, _name) = &**expr1 {
720                                if expr2.is_literal_ok() {
721                                    cols_equal_to_literal.insert(c);
722                                }
723                            }
724                            if let MirScalarExpr::Column(c, _name) = &**expr2 {
725                                if expr1.is_literal_ok() {
726                                    cols_equal_to_literal.insert(c);
727                                }
728                            }
729                            // Perform union-find to equate columns.
730                            if let (Some(c1), Some(c2)) = (expr1.as_column(), expr2.as_column()) {
731                                if c1 != c2 {
732                                    // Ensure union_find has entries up to
733                                    // max(c1, c2) by filling up missing
734                                    // positions with identity mappings.
735                                    while union_find.len() <= std::cmp::max(c1, c2) {
736                                        union_find.push(union_find.len());
737                                    }
738                                    let mut r1 = c1; // Find the representative column of [c1].
739                                    while r1 != union_find[r1] {
740                                        assert!(union_find[r1] < r1);
741                                        r1 = union_find[r1];
742                                    }
743                                    let mut r2 = c2; // Find the representative column of [c2].
744                                    while r2 != union_find[r2] {
745                                        assert!(union_find[r2] < r2);
746                                        r2 = union_find[r2];
747                                    }
748                                    // Union [c1] and [c2] by pointing the
749                                    // larger to the smaller representative (we
750                                    // update the remaining equivalence class
751                                    // members only once after this for-loop).
752                                    union_find[std::cmp::max(r1, r2)] = std::cmp::min(r1, r2);
753                                }
754                            }
755                        }
756                    }
757
758                    // Complete union-find by pointing each element at its representative column.
759                    for i in 0..union_find.len() {
760                        // Iteration not required, as each prior already references the right column.
761                        union_find[i] = union_find[union_find[i]];
762                    }
763
764                    // Remove columns bound to literals, and remap columns equated to earlier columns.
765                    // We will re-expand remapped columns in a moment, but this avoids exponential work.
766                    for key_set in &mut input {
767                        key_set.retain(|k| !cols_equal_to_literal.contains(&k));
768                        for col in key_set.iter_mut() {
769                            if let Some(equiv) = union_find.get(*col) {
770                                *col = *equiv;
771                            }
772                        }
773                        key_set.sort();
774                        key_set.dedup();
775                    }
776                    input.sort();
777                    input.dedup();
778
779                    // Expand out each key to each of its equivalent forms.
780                    // Each instance of `col` can be replaced by any equivalent column.
781                    // This has the potential to result in exponentially sized number of unique keys,
782                    // and in the future we should probably maintain unique keys modulo equivalence.
783
784                    // First, compute an inverse map from each representative
785                    // column `sub` to all other equivalent columns `col`.
786                    let mut subs = Vec::new();
787                    for (col, sub) in union_find.iter().enumerate() {
788                        if *sub != col {
789                            assert!(*sub < col);
790                            while subs.len() <= *sub {
791                                subs.push(Vec::new());
792                            }
793                            subs[*sub].push(col);
794                        }
795                    }
796                    // For each column, substitute for it in each occurrence.
797                    let mut to_add = Vec::new();
798                    for (col, subs) in subs.iter().enumerate() {
799                        if !subs.is_empty() {
800                            for key_set in input.iter() {
801                                if key_set.contains(&col) {
802                                    let mut to_extend = key_set.clone();
803                                    to_extend.retain(|c| c != &col);
804                                    for sub in subs {
805                                        to_extend.push(*sub);
806                                        to_add.push(to_extend.clone());
807                                        to_extend.pop();
808                                    }
809                                }
810                            }
811                        }
812                        // No deduplication, as we cannot introduce duplicates.
813                        input.append(&mut to_add);
814                    }
815                    for key_set in input.iter_mut() {
816                        key_set.sort();
817                        key_set.dedup();
818                    }
819                }
820                input
821            }
822            Join { equivalences, .. } => {
823                // It is important the `new_from_input_arities` constructor is
824                // used. Otherwise, Materialize may potentially end up in an
825                // infinite loop.
826                let input_mapper = crate::JoinInputMapper::new_from_input_arities(input_arities);
827
828                input_mapper.global_keys(input_keys, equivalences)
829            }
830            Reduce { group_key, .. } => {
831                // The group key should form a key, but we might already have
832                // keys that are subsets of the group key, and should retain
833                // those instead, if so.
834                let mut result = Vec::new();
835                for key_set in input_keys.next().unwrap() {
836                    if key_set
837                        .iter()
838                        .all(|k| group_key.contains(&MirScalarExpr::column(*k)))
839                    {
840                        result.push(
841                            key_set
842                                .iter()
843                                .map(|i| {
844                                    group_key
845                                        .iter()
846                                        .position(|k| k == &MirScalarExpr::column(*i))
847                                        .unwrap()
848                                })
849                                .collect::<Vec<_>>(),
850                        );
851                    }
852                }
853                if result.is_empty() {
854                    result.push((0..group_key.len()).collect());
855                }
856                result
857            }
858            TopK {
859                group_key, limit, ..
860            } => {
861                // If `limit` is `Some(1)` then the group key will become
862                // a unique key, as there will be only one record with that key.
863                let mut result = input_keys.next().unwrap().clone();
864                if limit.as_ref().and_then(|x| x.as_literal_int64()) == Some(1) {
865                    result.push(group_key.clone())
866                }
867                result
868            }
869            Union { base, inputs } => {
870                // Generally, unions do not have any unique keys, because
871                // each input might duplicate some. However, there is at
872                // least one idiomatic structure that does preserve keys,
873                // which results from SQL aggregations that must populate
874                // absent records with default values. In that pattern,
875                // the union of one GET with its negation, which has first
876                // been subjected to a projection and map, we can remove
877                // their influence on the key structure.
878                //
879                // If there are A, B, each with a unique `key` such that
880                // we are looking at
881                //
882                //     A.proj(set_containing_key) + (B - A.proj(key)).map(stuff)
883                //
884                // Then we can report `key` as a unique key.
885                //
886                // TODO: make unique key structure an optimization analysis
887                // rather than part of the type information.
888                // TODO: perhaps ensure that (above) A.proj(key) is a
889                // subset of B, as otherwise there are negative records
890                // and who knows what is true (not expected, but again
891                // who knows what the query plan might look like).
892
893                let arity = input_arities.next().unwrap();
894                let (base_projection, base_with_project_stripped) =
895                    if let MirRelationExpr::Project { input, outputs } = &**base {
896                        (outputs.clone(), &**input)
897                    } else {
898                        // A input without a project is equivalent to an input
899                        // with the project being all columns in the input in order.
900                        ((0..arity).collect::<Vec<_>>(), &**base)
901                    };
902                let mut result = Vec::new();
903                if let MirRelationExpr::Get {
904                    id: first_id,
905                    typ: _,
906                    ..
907                } = base_with_project_stripped
908                {
909                    if inputs.len() == 1 {
910                        if let MirRelationExpr::Map { input, .. } = &inputs[0] {
911                            if let MirRelationExpr::Union { base, inputs } = &**input {
912                                if inputs.len() == 1 {
913                                    if let Some((input, outputs)) = base.is_negated_project() {
914                                        if let MirRelationExpr::Get {
915                                            id: second_id,
916                                            typ: _,
917                                            ..
918                                        } = input
919                                        {
920                                            if first_id == second_id {
921                                                result.extend(
922                                                    input_keys
923                                                        .next()
924                                                        .unwrap()
925                                                        .into_iter()
926                                                        .filter(|key| {
927                                                            key.iter().all(|c| {
928                                                                outputs.get(*c) == Some(c)
929                                                                    && base_projection.get(*c)
930                                                                        == Some(c)
931                                                            })
932                                                        })
933                                                        .cloned(),
934                                                );
935                                            }
936                                        }
937                                    }
938                                }
939                            }
940                        }
941                    }
942                }
943                // Important: do not inherit keys of either input, as not unique.
944                result
945            }
946        };
947        keys.sort();
948        keys.dedup();
949        keys
950    }
951
952    /// The number of columns in the relation.
953    ///
954    /// This number is determined from the type, which is determined recursively
955    /// at non-trivial cost.
956    ///
957    /// The arity is computed incrementally with a recursive post-order
958    /// traversal, that accumulates the arities for the relations yet to be
959    /// visited in `arity_stack`.
960    pub fn arity(&self) -> usize {
961        let mut arity_stack = Vec::new();
962        self.visit_pre_post(
963            &mut |e: &MirRelationExpr| -> Option<Vec<&MirRelationExpr>> {
964                match &e {
965                    MirRelationExpr::Let { body, .. } => {
966                        // Do not traverse the value sub-graph, since it's not relevant for
967                        // determining the arity of Let operators.
968                        Some(vec![&*body])
969                    }
970                    MirRelationExpr::LetRec { body, .. } => {
971                        // Do not traverse the value sub-graph, since it's not relevant for
972                        // determining the arity of Let operators.
973                        Some(vec![&*body])
974                    }
975                    MirRelationExpr::Project { .. } | MirRelationExpr::Reduce { .. } => {
976                        // No further traversal is required; these operators know their arity.
977                        Some(Vec::new())
978                    }
979                    _ => None,
980                }
981            },
982            &mut |e: &MirRelationExpr| {
983                match &e {
984                    MirRelationExpr::Let { .. } => {
985                        let body_arity = arity_stack.pop().unwrap();
986                        arity_stack.push(0);
987                        arity_stack.push(body_arity);
988                    }
989                    MirRelationExpr::LetRec { values, .. } => {
990                        let body_arity = arity_stack.pop().unwrap();
991                        arity_stack.extend(std::iter::repeat(0).take(values.len()));
992                        arity_stack.push(body_arity);
993                    }
994                    MirRelationExpr::Project { .. } | MirRelationExpr::Reduce { .. } => {
995                        arity_stack.push(0);
996                    }
997                    _ => {}
998                }
999                let num_inputs = e.num_inputs();
1000                let input_arities = arity_stack.drain(arity_stack.len() - num_inputs..);
1001                let arity = e.arity_with_input_arities(input_arities);
1002                arity_stack.push(arity);
1003            },
1004        );
1005        assert_eq!(arity_stack.len(), 1);
1006        arity_stack.pop().unwrap()
1007    }
1008
1009    /// Reports the arity of the relation given the schema of the input relations.
1010    ///
1011    /// `input_arities` is required to contain the arities for the input relations of
1012    /// the current relation in the same order as they are visited by `try_visit_children`
1013    /// method, even though not all may be used for computing the schema of the
1014    /// current relation. For example, `Let` expects two input types, one for the
1015    /// value relation and one for the body, in that order, but only the one for the
1016    /// body is used to determine the type of the `Let` relation.
1017    ///
1018    /// It is meant to be used during post-order traversals to compute arities
1019    /// incrementally.
1020    pub fn arity_with_input_arities<I>(&self, mut input_arities: I) -> usize
1021    where
1022        I: Iterator<Item = usize>,
1023    {
1024        use MirRelationExpr::*;
1025
1026        match self {
1027            Constant { rows: _, typ } => typ.arity(),
1028            Get { typ, .. } => typ.arity(),
1029            Let { .. } => {
1030                input_arities.next();
1031                input_arities.next().unwrap()
1032            }
1033            LetRec { values, .. } => {
1034                for _ in 0..values.len() {
1035                    input_arities.next();
1036                }
1037                input_arities.next().unwrap()
1038            }
1039            Project { outputs, .. } => outputs.len(),
1040            Map { scalars, .. } => input_arities.next().unwrap() + scalars.len(),
1041            FlatMap { func, .. } => input_arities.next().unwrap() + func.output_arity(),
1042            Join { .. } => input_arities.sum(),
1043            Reduce {
1044                input: _,
1045                group_key,
1046                aggregates,
1047                ..
1048            } => group_key.len() + aggregates.len(),
1049            Filter { .. }
1050            | TopK { .. }
1051            | Negate { .. }
1052            | Threshold { .. }
1053            | Union { .. }
1054            | ArrangeBy { .. } => input_arities.next().unwrap(),
1055        }
1056    }
1057
1058    /// The number of child relations this relation has.
1059    pub fn num_inputs(&self) -> usize {
1060        let mut count = 0;
1061
1062        self.visit_children(|_| count += 1);
1063
1064        count
1065    }
1066
1067    /// Constructs a constant collection from specific rows and schema, where
1068    /// each row will have a multiplicity of one.
1069    pub fn constant(rows: Vec<Vec<Datum>>, typ: ReprRelationType) -> Self {
1070        let rows = rows.into_iter().map(|row| (row, Diff::ONE)).collect();
1071        MirRelationExpr::constant_diff(rows, typ)
1072    }
1073
1074    /// Constructs a constant collection from specific rows and schema, where
1075    /// each row can have an arbitrary multiplicity.
1076    pub fn constant_diff(rows: Vec<(Vec<Datum>, Diff)>, typ: ReprRelationType) -> Self {
1077        for (row, _diff) in &rows {
1078            for (datum, column_typ) in row.iter().zip_eq(typ.column_types.iter()) {
1079                assert!(
1080                    datum.is_instance_of(column_typ),
1081                    "Expected datum of type {:?}, got value {:?}",
1082                    column_typ,
1083                    datum
1084                );
1085            }
1086        }
1087        let rows = Ok(rows
1088            .into_iter()
1089            .map(move |(row, diff)| (Row::pack_slice(&row), diff))
1090            .collect());
1091        MirRelationExpr::Constant { rows, typ }
1092    }
1093
1094    /// If self is a constant, return the value and the type, otherwise `None`.
1095    /// Looks behind `ArrangeBy`s.
1096    pub fn as_const(&self) -> Option<(&Result<Vec<(Row, Diff)>, EvalError>, &ReprRelationType)> {
1097        match self {
1098            MirRelationExpr::Constant { rows, typ } => Some((rows, typ)),
1099            MirRelationExpr::ArrangeBy { input, .. } => input.as_const(),
1100            _ => None,
1101        }
1102    }
1103
1104    /// If self is a constant, mutably return the value and the type, otherwise `None`.
1105    /// Looks behind `ArrangeBy`s.
1106    pub fn as_const_mut(
1107        &mut self,
1108    ) -> Option<(
1109        &mut Result<Vec<(Row, Diff)>, EvalError>,
1110        &mut ReprRelationType,
1111    )> {
1112        match self {
1113            MirRelationExpr::Constant { rows, typ } => Some((rows, typ)),
1114            MirRelationExpr::ArrangeBy { input, .. } => input.as_const_mut(),
1115            _ => None,
1116        }
1117    }
1118
1119    /// If self is a constant error, return the error, otherwise `None`.
1120    /// Looks behind `ArrangeBy`s.
1121    pub fn as_const_err(&self) -> Option<&EvalError> {
1122        match self {
1123            MirRelationExpr::Constant { rows: Err(e), .. } => Some(e),
1124            MirRelationExpr::ArrangeBy { input, .. } => input.as_const_err(),
1125            _ => None,
1126        }
1127    }
1128
1129    /// Checks if `self` is the single element collection with no columns.
1130    pub fn is_constant_singleton(&self) -> bool {
1131        if let Some((Ok(rows), typ)) = self.as_const() {
1132            rows.len() == 1 && typ.column_types.len() == 0 && rows[0].1 == Diff::ONE
1133        } else {
1134            false
1135        }
1136    }
1137
1138    /// Constructs the expression for getting a local collection.
1139    pub fn local_get(id: LocalId, typ: ReprRelationType) -> Self {
1140        MirRelationExpr::Get {
1141            id: Id::Local(id),
1142            typ,
1143            access_strategy: AccessStrategy::UnknownOrLocal,
1144        }
1145    }
1146
1147    /// Constructs the expression for getting a global collection
1148    pub fn global_get(id: GlobalId, typ: ReprRelationType) -> Self {
1149        MirRelationExpr::Get {
1150            id: Id::Global(id),
1151            typ,
1152            access_strategy: AccessStrategy::UnknownOrLocal,
1153        }
1154    }
1155
1156    /// Retains only the columns specified by `output`.
1157    pub fn project(mut self, mut outputs: Vec<usize>) -> Self {
1158        if let MirRelationExpr::Project {
1159            outputs: columns, ..
1160        } = &mut self
1161        {
1162            // Update `outputs` to reference base columns of `input`.
1163            for column in outputs.iter_mut() {
1164                *column = columns[*column];
1165            }
1166            *columns = outputs;
1167            self
1168        } else {
1169            MirRelationExpr::Project {
1170                input: Box::new(self),
1171                outputs,
1172            }
1173        }
1174    }
1175
1176    /// Append to each row the results of applying elements of `scalar`.
1177    pub fn map(mut self, scalars: Vec<MirScalarExpr>) -> Self {
1178        if let MirRelationExpr::Map { scalars: s, .. } = &mut self {
1179            s.extend(scalars);
1180            self
1181        } else if !scalars.is_empty() {
1182            MirRelationExpr::Map {
1183                input: Box::new(self),
1184                scalars,
1185            }
1186        } else {
1187            self
1188        }
1189    }
1190
1191    /// Append to each row a single `scalar`.
1192    pub fn map_one(self, scalar: MirScalarExpr) -> Self {
1193        self.map(vec![scalar])
1194    }
1195
1196    /// Like `map`, but yields zero-or-more output rows per input row
1197    pub fn flat_map(self, func: TableFunc, exprs: Vec<MirScalarExpr>) -> Self {
1198        MirRelationExpr::FlatMap {
1199            input: Box::new(self),
1200            func,
1201            exprs,
1202        }
1203    }
1204
1205    /// Retain only the rows satisfying each of several predicates.
1206    pub fn filter<I>(mut self, predicates: I) -> Self
1207    where
1208        I: IntoIterator<Item = MirScalarExpr>,
1209    {
1210        // Extract existing predicates
1211        let mut new_predicates = if let MirRelationExpr::Filter { input, predicates } = self {
1212            self = *input;
1213            predicates
1214        } else {
1215            Vec::new()
1216        };
1217        // Normalize collection of predicates.
1218        new_predicates.extend(predicates);
1219        new_predicates.retain(|p| !p.is_literal_true());
1220        new_predicates.sort();
1221        new_predicates.dedup();
1222        // Introduce a `Filter` only if we have predicates.
1223        if !new_predicates.is_empty() {
1224            self = MirRelationExpr::Filter {
1225                input: Box::new(self),
1226                predicates: new_predicates,
1227            };
1228        }
1229
1230        self
1231    }
1232
1233    /// Form the Cartesian outer-product of rows in both inputs.
1234    pub fn product(mut self, right: Self) -> Self {
1235        if right.is_constant_singleton() {
1236            self
1237        } else if self.is_constant_singleton() {
1238            right
1239        } else if let MirRelationExpr::Join { inputs, .. } = &mut self {
1240            inputs.push(right);
1241            self
1242        } else {
1243            MirRelationExpr::join(vec![self, right], vec![])
1244        }
1245    }
1246
1247    /// Performs a relational equijoin among the input collections.
1248    ///
1249    /// The sequence `inputs` each describe different input collections, and the sequence `variables` describes
1250    /// equality constraints that some of their columns must satisfy. Each element in `variable` describes a set
1251    /// of pairs  `(input_index, column_index)` where every value described by that set must be equal.
1252    ///
1253    /// For example, the pair `(input, column)` indexes into `inputs[input][column]`, extracting the `input`th
1254    /// input collection and for each row examining its `column`th column.
1255    ///
1256    /// # Example
1257    ///
1258    /// ```rust
1259    /// use mz_repr::{Datum, SqlColumnType, ReprRelationType, ReprScalarType};
1260    /// use mz_expr::MirRelationExpr;
1261    ///
1262    /// // A common schema for each input.
1263    /// let schema = ReprRelationType::new(vec![
1264    ///     ReprScalarType::Int32.nullable(false),
1265    ///     ReprScalarType::Int32.nullable(false),
1266    /// ]);
1267    ///
1268    /// // the specific data are not important here.
1269    /// let data = vec![Datum::Int32(0), Datum::Int32(1)];
1270    ///
1271    /// // Three collections that could have been different.
1272    /// let input0 = MirRelationExpr::constant(vec![data.clone()], schema.clone());
1273    /// let input1 = MirRelationExpr::constant(vec![data.clone()], schema.clone());
1274    /// let input2 = MirRelationExpr::constant(vec![data.clone()], schema.clone());
1275    ///
1276    /// // Join the three relations looking for triangles, like so.
1277    /// //
1278    /// //     Output(A,B,C) := Input0(A,B), Input1(B,C), Input2(A,C)
1279    /// let joined = MirRelationExpr::join(
1280    ///     vec![input0, input1, input2],
1281    ///     vec![
1282    ///         vec![(0,0), (2,0)], // fields A of inputs 0 and 2.
1283    ///         vec![(0,1), (1,0)], // fields B of inputs 0 and 1.
1284    ///         vec![(1,1), (2,1)], // fields C of inputs 1 and 2.
1285    ///     ],
1286    /// );
1287    ///
1288    /// // Technically the above produces `Output(A,B,B,C,A,C)` because the columns are concatenated.
1289    /// // A projection resolves this and produces the correct output.
1290    /// let result = joined.project(vec![0, 1, 3]);
1291    /// ```
1292    pub fn join(inputs: Vec<MirRelationExpr>, variables: Vec<Vec<(usize, usize)>>) -> Self {
1293        let input_mapper = join_input_mapper::JoinInputMapper::new(&inputs);
1294
1295        let equivalences = variables
1296            .into_iter()
1297            .map(|vs| {
1298                vs.into_iter()
1299                    .map(|(r, c)| input_mapper.map_expr_to_global(MirScalarExpr::column(c), r))
1300                    .collect::<Vec<_>>()
1301            })
1302            .collect::<Vec<_>>();
1303
1304        Self::join_scalars(inputs, equivalences)
1305    }
1306
1307    /// Constructs a join operator from inputs and required-equal scalar expressions.
1308    pub fn join_scalars(
1309        mut inputs: Vec<MirRelationExpr>,
1310        equivalences: Vec<Vec<MirScalarExpr>>,
1311    ) -> Self {
1312        // Remove all constant inputs that are the identity for join.
1313        // They neither introduce nor modify any column references.
1314        inputs.retain(|i| !i.is_constant_singleton());
1315        MirRelationExpr::Join {
1316            inputs,
1317            equivalences,
1318            implementation: JoinImplementation::Unimplemented,
1319        }
1320    }
1321
1322    /// Perform a key-wise reduction / aggregation.
1323    ///
1324    /// The `group_key` argument indicates columns in the input collection that should
1325    /// be grouped, and `aggregates` lists aggregation functions each of which produces
1326    /// one output column in addition to the keys.
1327    pub fn reduce(
1328        self,
1329        group_key: Vec<usize>,
1330        aggregates: Vec<AggregateExpr>,
1331        expected_group_size: Option<u64>,
1332    ) -> Self {
1333        MirRelationExpr::Reduce {
1334            input: Box::new(self),
1335            group_key: group_key.into_iter().map(MirScalarExpr::column).collect(),
1336            aggregates,
1337            monotonic: false,
1338            expected_group_size,
1339        }
1340    }
1341
1342    /// Perform a key-wise reduction order by and limit.
1343    ///
1344    /// The `group_key` argument indicates columns in the input collection that should
1345    /// be grouped, the `order_key` argument indicates columns that should be further
1346    /// used to order records within groups, and the `limit` argument constrains the
1347    /// total number of records that should be produced in each group.
1348    pub fn top_k(
1349        self,
1350        group_key: Vec<usize>,
1351        order_key: Vec<ColumnOrder>,
1352        limit: Option<MirScalarExpr>,
1353        offset: usize,
1354        expected_group_size: Option<u64>,
1355    ) -> Self {
1356        MirRelationExpr::TopK {
1357            input: Box::new(self),
1358            group_key,
1359            order_key,
1360            limit,
1361            offset,
1362            expected_group_size,
1363            monotonic: false,
1364        }
1365    }
1366
1367    /// Negates the occurrences of each row.
1368    pub fn negate(self) -> Self {
1369        if let MirRelationExpr::Negate { input } = self {
1370            *input
1371        } else {
1372            MirRelationExpr::Negate {
1373                input: Box::new(self),
1374            }
1375        }
1376    }
1377
1378    /// Removes all but the first occurrence of each row.
1379    pub fn distinct(self) -> Self {
1380        let arity = self.arity();
1381        self.distinct_by((0..arity).collect())
1382    }
1383
1384    /// Removes all but the first occurrence of each key. Columns not included
1385    /// in the `group_key` are discarded.
1386    pub fn distinct_by(self, group_key: Vec<usize>) -> Self {
1387        self.reduce(group_key, vec![], None)
1388    }
1389
1390    /// Discards rows with a negative frequency.
1391    pub fn threshold(self) -> Self {
1392        if let MirRelationExpr::Threshold { .. } = &self {
1393            self
1394        } else {
1395            MirRelationExpr::Threshold {
1396                input: Box::new(self),
1397            }
1398        }
1399    }
1400
1401    /// Unions together any number inputs.
1402    ///
1403    /// If `inputs` is empty, then an empty relation of type `typ` is
1404    /// constructed.
1405    pub fn union_many(mut inputs: Vec<Self>, typ: ReprRelationType) -> Self {
1406        // Deconstruct `inputs` as `Union`s and reconstitute.
1407        let mut flat_inputs = Vec::with_capacity(inputs.len());
1408        for input in inputs {
1409            if let MirRelationExpr::Union { base, inputs } = input {
1410                flat_inputs.push(*base);
1411                flat_inputs.extend(inputs);
1412            } else {
1413                flat_inputs.push(input);
1414            }
1415        }
1416        inputs = flat_inputs;
1417        if inputs.len() == 0 {
1418            MirRelationExpr::Constant {
1419                rows: Ok(vec![]),
1420                typ,
1421            }
1422        } else if inputs.len() == 1 {
1423            inputs.into_element()
1424        } else {
1425            MirRelationExpr::Union {
1426                base: Box::new(inputs.remove(0)),
1427                inputs,
1428            }
1429        }
1430    }
1431
1432    /// Produces one collection where each row is present with the sum of its frequencies in each input.
1433    pub fn union(self, other: Self) -> Self {
1434        // Deconstruct `self` and `other` as `Union`s and reconstitute.
1435        let mut flat_inputs = Vec::with_capacity(2);
1436        if let MirRelationExpr::Union { base, inputs } = self {
1437            flat_inputs.push(*base);
1438            flat_inputs.extend(inputs);
1439        } else {
1440            flat_inputs.push(self);
1441        }
1442        if let MirRelationExpr::Union { base, inputs } = other {
1443            flat_inputs.push(*base);
1444            flat_inputs.extend(inputs);
1445        } else {
1446            flat_inputs.push(other);
1447        }
1448
1449        MirRelationExpr::Union {
1450            base: Box::new(flat_inputs.remove(0)),
1451            inputs: flat_inputs,
1452        }
1453    }
1454
1455    /// Arranges the collection by the specified columns
1456    pub fn arrange_by(self, keys: &[Vec<MirScalarExpr>]) -> Self {
1457        MirRelationExpr::ArrangeBy {
1458            input: Box::new(self),
1459            keys: keys.to_owned(),
1460        }
1461    }
1462
1463    /// Indicates if this is a constant empty collection.
1464    ///
1465    /// A false value does not mean the collection is known to be non-empty,
1466    /// only that we cannot currently determine that it is statically empty.
1467    pub fn is_empty(&self) -> bool {
1468        if let Some((Ok(rows), ..)) = self.as_const() {
1469            rows.is_empty()
1470        } else {
1471            false
1472        }
1473    }
1474
1475    /// If the expression is a negated project, return the input and the projection.
1476    pub fn is_negated_project(&self) -> Option<(&MirRelationExpr, &[usize])> {
1477        if let MirRelationExpr::Negate { input } = self {
1478            if let MirRelationExpr::Project { input, outputs } = &**input {
1479                return Some((&**input, outputs));
1480            }
1481        }
1482        if let MirRelationExpr::Project { input, outputs } = self {
1483            if let MirRelationExpr::Negate { input } = &**input {
1484                return Some((&**input, outputs));
1485            }
1486        }
1487        None
1488    }
1489
1490    /// Pretty-print this [MirRelationExpr] to a string.
1491    pub fn pretty(&self) -> String {
1492        let config = ExplainConfig::default();
1493        self.debug_explain(&config, None)
1494    }
1495
1496    /// Pretty-print this [MirRelationExpr] to a string using a custom
1497    /// [ExplainConfig] and an optionally provided [ExprHumanizer].
1498    /// This is intended for debugging and tests, not users.
1499    pub fn debug_explain(
1500        &self,
1501        config: &ExplainConfig,
1502        humanizer: Option<&dyn ExprHumanizer>,
1503    ) -> String {
1504        text_string_at(self, || PlanRenderingContext {
1505            indent: Indent::default(),
1506            humanizer: humanizer.unwrap_or(&DummyHumanizer),
1507            annotations: BTreeMap::default(),
1508            config,
1509            ambiguous_ids: BTreeSet::default(),
1510        })
1511    }
1512
1513    /// Take ownership of `self`, leaving an empty `MirRelationExpr::Constant` with the optionally
1514    /// given scalar types. The given scalar types should be `base_eq` with the types that `typ()`
1515    /// would find. Keys and nullability are ignored in the given `SqlRelationType`, and instead we set
1516    /// the best possible key and nullability, since we are making an empty collection.
1517    ///
1518    /// If `typ` is not given, then this calls `.typ()` (which is possibly expensive) to determine
1519    /// the correct type.
1520    pub fn take_safely(&mut self, typ: Option<ReprRelationType>) -> MirRelationExpr {
1521        if let Some(typ) = &typ {
1522            let self_typ = self.typ();
1523            soft_assert_no_log!(
1524                self_typ
1525                    .column_types
1526                    .iter()
1527                    .zip_eq(typ.column_types.iter())
1528                    .all(|(t1, t2)| t1.scalar_type == t2.scalar_type)
1529            );
1530        }
1531        let mut typ = typ.unwrap_or_else(|| self.typ());
1532        typ.keys = vec![vec![]];
1533        for ct in typ.column_types.iter_mut() {
1534            ct.nullable = false;
1535        }
1536        std::mem::replace(
1537            self,
1538            MirRelationExpr::Constant {
1539                rows: Ok(vec![]),
1540                typ,
1541            },
1542        )
1543    }
1544
1545    /// Take ownership of `self`, leaving an empty `MirRelationExpr::Constant` with the given scalar
1546    /// types. Nullability is ignored in the given `SqlColumnType`s, and instead we set the best
1547    /// possible nullability, since we are making an empty collection.
1548    pub fn take_safely_with_sql_col_types(&mut self, typ: Vec<SqlColumnType>) -> MirRelationExpr {
1549        self.take_safely(Some(ReprRelationType::from(&SqlRelationType::new(typ))))
1550    }
1551
1552    /// Like [`Self::take_safely_with_col_types`], but accepts `Vec<ReprColumnType>`.
1553    ///
1554    /// This is the preferred entry point for optimizer transforms, where repr
1555    /// types are the native currency. Internally converts to [`SqlColumnType`]
1556    /// and delegates to [`Self::take_safely_with_col_types`].
1557    pub fn take_safely_with_col_types(&mut self, typ: Vec<ReprColumnType>) -> MirRelationExpr {
1558        self.take_safely(Some(ReprRelationType::new(typ)))
1559    }
1560
1561    /// Take ownership of `self`, leaving an empty `MirRelationExpr::Constant` with an **incorrect** type.
1562    ///
1563    /// This should only be used if `self` is about to be dropped or otherwise overwritten.
1564    pub fn take_dangerous(&mut self) -> MirRelationExpr {
1565        let empty = MirRelationExpr::Constant {
1566            rows: Ok(vec![]),
1567            typ: ReprRelationType::new(Vec::new()),
1568        };
1569        std::mem::replace(self, empty)
1570    }
1571
1572    /// Replaces `self` with some logic applied to `self`.
1573    pub fn replace_using<F>(&mut self, logic: F)
1574    where
1575        F: FnOnce(MirRelationExpr) -> MirRelationExpr,
1576    {
1577        let empty = MirRelationExpr::Constant {
1578            rows: Ok(vec![]),
1579            typ: ReprRelationType::new(Vec::new()),
1580        };
1581        let expr = std::mem::replace(self, empty);
1582        *self = logic(expr);
1583    }
1584
1585    /// Store `self` in a `Let` and pass the corresponding `Get` to `body`.
1586    pub fn let_in<Body, E>(self, id_gen: &mut IdGen, body: Body) -> Result<MirRelationExpr, E>
1587    where
1588        Body: FnOnce(&mut IdGen, MirRelationExpr) -> Result<MirRelationExpr, E>,
1589    {
1590        if let MirRelationExpr::Get { .. } = self {
1591            // already done
1592            body(id_gen, self)
1593        } else {
1594            let id = LocalId::new(id_gen.allocate_id());
1595            let get = MirRelationExpr::Get {
1596                id: Id::Local(id),
1597                typ: self.typ(),
1598                access_strategy: AccessStrategy::UnknownOrLocal,
1599            };
1600            let body = (body)(id_gen, get)?;
1601            Ok(MirRelationExpr::Let {
1602                id,
1603                value: Box::new(self),
1604                body: Box::new(body),
1605            })
1606        }
1607    }
1608
1609    /// Return every row in `self` that does not have a matching row in the first columns of `keys_and_values`, using `default` to fill in the remaining columns
1610    /// (If `default` is a row of nulls, this is the 'outer' part of LEFT OUTER JOIN)
1611    pub fn anti_lookup<E>(
1612        self,
1613        id_gen: &mut IdGen,
1614        keys_and_values: MirRelationExpr,
1615        default: Vec<(Datum, ReprScalarType)>,
1616    ) -> Result<MirRelationExpr, E> {
1617        let (data, column_types): (Vec<_>, Vec<_>) = default
1618            .into_iter()
1619            .map(|(datum, scalar_type)| {
1620                (
1621                    datum,
1622                    ReprColumnType {
1623                        scalar_type,
1624                        nullable: datum.is_null(),
1625                    },
1626                )
1627            })
1628            .unzip();
1629        assert_eq!(keys_and_values.arity() - self.arity(), data.len());
1630        self.let_in(id_gen, |_id_gen, get_keys| {
1631            let get_keys_arity = get_keys.arity();
1632            Ok(MirRelationExpr::join(
1633                vec![
1634                    // all the missing keys (with count 1)
1635                    keys_and_values
1636                        .distinct_by((0..get_keys_arity).collect())
1637                        .negate()
1638                        .union(get_keys.clone().distinct()),
1639                    // join with keys to get the correct counts
1640                    get_keys.clone(),
1641                ],
1642                (0..get_keys_arity).map(|i| vec![(0, i), (1, i)]).collect(),
1643            )
1644            // get rid of the extra copies of columns from keys
1645            .project((0..get_keys_arity).collect())
1646            // This join is logically equivalent to
1647            // `.map(<default_expr>)`, but using a join allows for
1648            // potential predicate pushdown and elision in the
1649            // optimizer.
1650            .product(MirRelationExpr::constant(
1651                vec![data],
1652                ReprRelationType::new(column_types),
1653            )))
1654        })
1655    }
1656
1657    /// Return:
1658    /// * every row in keys_and_values
1659    /// * every row in `self` that does not have a matching row in the first columns of
1660    ///   `keys_and_values`, using `default` to fill in the remaining columns
1661    /// (This is LEFT OUTER JOIN if:
1662    /// 1) `default` is a row of null
1663    /// 2) matching rows in `keys_and_values` and `self` have the same multiplicity.)
1664    pub fn lookup<E>(
1665        self,
1666        id_gen: &mut IdGen,
1667        keys_and_values: MirRelationExpr,
1668        default: Vec<(Datum<'static>, ReprScalarType)>,
1669    ) -> Result<MirRelationExpr, E> {
1670        keys_and_values.let_in(id_gen, |id_gen, get_keys_and_values| {
1671            Ok(get_keys_and_values.clone().union(self.anti_lookup(
1672                id_gen,
1673                get_keys_and_values,
1674                default,
1675            )?))
1676        })
1677    }
1678
1679    /// True iff the expression contains a `NullaryFunc::MzLogicalTimestamp`.
1680    pub fn contains_temporal(&self) -> bool {
1681        let mut contains = false;
1682        self.visit_scalars(&mut |e| contains = contains || e.contains_temporal());
1683        contains
1684    }
1685
1686    /// Fallible visitor for the [`MirScalarExpr`]s directly owned by this relation expression.
1687    ///
1688    /// The `f` visitor should not recursively descend into owned [`MirRelationExpr`]s.
1689    pub fn try_visit_scalars_mut1<F, E>(&mut self, f: &mut F) -> Result<(), E>
1690    where
1691        F: FnMut(&mut MirScalarExpr) -> Result<(), E>,
1692    {
1693        use MirRelationExpr::*;
1694        match self {
1695            Map { scalars, .. } => {
1696                for s in scalars {
1697                    f(s)?;
1698                }
1699            }
1700            Filter { predicates, .. } => {
1701                for p in predicates {
1702                    f(p)?;
1703                }
1704            }
1705            FlatMap { exprs, .. } => {
1706                for expr in exprs {
1707                    f(expr)?;
1708                }
1709            }
1710            Join {
1711                inputs: _,
1712                equivalences,
1713                implementation,
1714            } => {
1715                for equivalence in equivalences {
1716                    for expr in equivalence {
1717                        f(expr)?;
1718                    }
1719                }
1720                match implementation {
1721                    JoinImplementation::Differential((_, start_key, _), order) => {
1722                        if let Some(start_key) = start_key {
1723                            for k in start_key {
1724                                f(k)?;
1725                            }
1726                        }
1727                        for (_, lookup_key, _) in order {
1728                            for k in lookup_key {
1729                                f(k)?;
1730                            }
1731                        }
1732                    }
1733                    JoinImplementation::DeltaQuery(paths) => {
1734                        for path in paths {
1735                            for (_, lookup_key, _) in path {
1736                                for k in lookup_key {
1737                                    f(k)?;
1738                                }
1739                            }
1740                        }
1741                    }
1742                    JoinImplementation::IndexedFilter(_coll_id, _idx_id, index_key, _) => {
1743                        for k in index_key {
1744                            f(k)?;
1745                        }
1746                    }
1747                    JoinImplementation::Unimplemented => {} // No scalar exprs
1748                }
1749            }
1750            ArrangeBy { keys, .. } => {
1751                for key in keys {
1752                    for s in key {
1753                        f(s)?;
1754                    }
1755                }
1756            }
1757            Reduce {
1758                group_key,
1759                aggregates,
1760                ..
1761            } => {
1762                for s in group_key {
1763                    f(s)?;
1764                }
1765                for agg in aggregates {
1766                    f(&mut agg.expr)?;
1767                }
1768            }
1769            TopK { limit, .. } => {
1770                if let Some(s) = limit {
1771                    f(s)?;
1772                }
1773            }
1774            Constant { .. }
1775            | Get { .. }
1776            | Let { .. }
1777            | LetRec { .. }
1778            | Project { .. }
1779            | Negate { .. }
1780            | Threshold { .. }
1781            | Union { .. } => (),
1782        }
1783        Ok(())
1784    }
1785
1786    /// Fallible mutable visitor for the [`MirScalarExpr`]s in the [`MirRelationExpr`] subtree
1787    /// rooted at `self`.
1788    ///
1789    /// Note that this does not recurse into [`MirRelationExpr`] subtrees within [`MirScalarExpr`]
1790    /// nodes.
1791    pub fn try_visit_scalars_mut<F, E>(&mut self, f: &mut F) -> Result<(), E>
1792    where
1793        F: FnMut(&mut MirScalarExpr) -> Result<(), E>,
1794    {
1795        self.try_visit_mut_post(&mut |expr| expr.try_visit_scalars_mut1(f))
1796    }
1797
1798    /// Infallible mutable visitor for the [`MirScalarExpr`]s in the [`MirRelationExpr`] subtree
1799    /// rooted at `self`.
1800    ///
1801    /// Note that this does not recurse into [`MirRelationExpr`] subtrees within [`MirScalarExpr`]
1802    /// nodes.
1803    pub fn visit_scalars_mut<F>(&mut self, f: &mut F)
1804    where
1805        F: FnMut(&mut MirScalarExpr),
1806    {
1807        self.try_visit_scalars_mut(&mut |s| {
1808            f(s);
1809            Ok::<_, RecursionLimitError>(())
1810        })
1811        .expect("Unexpected error in `visit_scalars_mut` call");
1812    }
1813
1814    /// Fallible visitor for the [`MirScalarExpr`]s directly owned by this relation expression.
1815    ///
1816    /// The `f` visitor should not recursively descend into owned [`MirRelationExpr`]s.
1817    pub fn try_visit_scalars_1<F, E>(&self, f: &mut F) -> Result<(), E>
1818    where
1819        F: FnMut(&MirScalarExpr) -> Result<(), E>,
1820    {
1821        use MirRelationExpr::*;
1822        match self {
1823            Map { scalars, .. } => {
1824                for s in scalars {
1825                    f(s)?;
1826                }
1827            }
1828            Filter { predicates, .. } => {
1829                for p in predicates {
1830                    f(p)?;
1831                }
1832            }
1833            FlatMap { exprs, .. } => {
1834                for expr in exprs {
1835                    f(expr)?;
1836                }
1837            }
1838            Join {
1839                inputs: _,
1840                equivalences,
1841                implementation,
1842            } => {
1843                for equivalence in equivalences {
1844                    for expr in equivalence {
1845                        f(expr)?;
1846                    }
1847                }
1848                match implementation {
1849                    JoinImplementation::Differential((_, start_key, _), order) => {
1850                        if let Some(start_key) = start_key {
1851                            for k in start_key {
1852                                f(k)?;
1853                            }
1854                        }
1855                        for (_, lookup_key, _) in order {
1856                            for k in lookup_key {
1857                                f(k)?;
1858                            }
1859                        }
1860                    }
1861                    JoinImplementation::DeltaQuery(paths) => {
1862                        for path in paths {
1863                            for (_, lookup_key, _) in path {
1864                                for k in lookup_key {
1865                                    f(k)?;
1866                                }
1867                            }
1868                        }
1869                    }
1870                    JoinImplementation::IndexedFilter(_coll_id, _idx_id, index_key, _) => {
1871                        for k in index_key {
1872                            f(k)?;
1873                        }
1874                    }
1875                    JoinImplementation::Unimplemented => {} // No scalar exprs
1876                }
1877            }
1878            ArrangeBy { keys, .. } => {
1879                for key in keys {
1880                    for s in key {
1881                        f(s)?;
1882                    }
1883                }
1884            }
1885            Reduce {
1886                group_key,
1887                aggregates,
1888                ..
1889            } => {
1890                for s in group_key {
1891                    f(s)?;
1892                }
1893                for agg in aggregates {
1894                    f(&agg.expr)?;
1895                }
1896            }
1897            TopK { limit, .. } => {
1898                if let Some(s) = limit {
1899                    f(s)?;
1900                }
1901            }
1902            Constant { .. }
1903            | Get { .. }
1904            | Let { .. }
1905            | LetRec { .. }
1906            | Project { .. }
1907            | Negate { .. }
1908            | Threshold { .. }
1909            | Union { .. } => (),
1910        }
1911        Ok(())
1912    }
1913
1914    /// Fallible immutable visitor for the [`MirScalarExpr`]s in the [`MirRelationExpr`] subtree
1915    /// rooted at `self`.
1916    ///
1917    /// Note that this does not recurse into [`MirRelationExpr`] subtrees within [`MirScalarExpr`]
1918    /// nodes.
1919    pub fn try_visit_scalars<F, E>(&self, f: &mut F) -> Result<(), E>
1920    where
1921        F: FnMut(&MirScalarExpr) -> Result<(), E>,
1922    {
1923        self.try_visit_post(&mut |expr| expr.try_visit_scalars_1(f))
1924    }
1925
1926    /// Infallible immutable visitor for the [`MirScalarExpr`]s in the [`MirRelationExpr`] subtree
1927    /// rooted at `self`.
1928    ///
1929    /// Note that this does not recurse into [`MirRelationExpr`] subtrees within [`MirScalarExpr`]
1930    /// nodes.
1931    pub fn visit_scalars<F>(&self, f: &mut F)
1932    where
1933        F: FnMut(&MirScalarExpr),
1934    {
1935        self.try_visit_scalars(&mut |s| {
1936            f(s);
1937            Ok::<_, RecursionLimitError>(())
1938        })
1939        .expect("Unexpected error in `visit_scalars` call");
1940    }
1941
1942    /// Clears the contents of `self` even if it's so deep that simply dropping it would cause a
1943    /// stack overflow in `drop_in_place`.
1944    ///
1945    /// Leaves `self` in an unusable state, so this should only be used if `self` is about to be
1946    /// dropped or otherwise overwritten.
1947    pub fn destroy_carefully(&mut self) {
1948        let mut todo = vec![self.take_dangerous()];
1949        while let Some(mut expr) = todo.pop() {
1950            for child in expr.children_mut() {
1951                todo.push(child.take_dangerous());
1952            }
1953        }
1954    }
1955
1956    /// Computes the size (total number of nodes) and maximum depth of a MirRelationExpr for
1957    /// debug printing purposes.
1958    pub fn debug_size_and_depth(&self) -> (usize, usize) {
1959        let mut size = 0;
1960        let mut max_depth = 0;
1961        let mut todo = vec![(self, 1)];
1962        while let Some((expr, depth)) = todo.pop() {
1963            size += 1;
1964            max_depth = max(max_depth, depth);
1965            todo.extend(expr.children().map(|c| (c, depth + 1)));
1966        }
1967        (size, max_depth)
1968    }
1969
1970    /// The MirRelationExpr is considered potentially expensive if and only if
1971    /// at least one of the following conditions is true:
1972    ///
1973    ///  - It contains at least one MirScalarExpr with a function call.
1974    ///  - It contains at least one FlatMap or a Reduce operator.
1975    ///  - We run into a RecursionLimitError while analyzing the expression.
1976    ///
1977    /// !!!WARNING!!!: this method has an HirRelationExpr counterpart. The two
1978    /// should be kept in sync w.r.t. HIR ⇒ MIR lowering!
1979    pub fn could_run_expensive_function(&self) -> bool {
1980        let mut result = false;
1981        use MirRelationExpr::*;
1982        use MirScalarExpr::*;
1983        if let Err(_) = self.try_visit_scalars::<_, RecursionLimitError>(&mut |scalar| {
1984            result |= match scalar {
1985                Column(_, _) | Literal(_, _) | CallUnmaterializable(_) | If { .. } => false,
1986                // Function calls are considered expensive
1987                CallUnary { .. } | CallBinary { .. } | CallVariadic { .. } => true,
1988            };
1989            Ok(())
1990        }) {
1991            // Conservatively set `true` if on RecursionLimitError.
1992            result = true;
1993        }
1994        self.visit_pre(|e: &MirRelationExpr| {
1995            // FlatMap has a table function; Reduce has an aggregate function.
1996            // Other constructs use MirScalarExpr to run a function
1997            result |= matches!(e, FlatMap { .. } | Reduce { .. });
1998        });
1999        result
2000    }
2001
2002    /// Hash to an u64 using Rust's default Hasher. (Which is a somewhat slower, but better Hasher
2003    /// than what `Hashable::hashed` would give us.)
2004    pub fn hash_to_u64(&self) -> u64 {
2005        let mut h = DefaultHasher::new();
2006        self.hash(&mut h);
2007        h.finish()
2008    }
2009}
2010
2011// `LetRec` helpers
2012impl MirRelationExpr {
2013    /// True when `expr` contains a `LetRec` AST node.
2014    pub fn is_recursive(self: &MirRelationExpr) -> bool {
2015        let mut worklist = vec![self];
2016        while let Some(expr) = worklist.pop() {
2017            if let MirRelationExpr::LetRec { .. } = expr {
2018                return true;
2019            }
2020            worklist.extend(expr.children());
2021        }
2022        false
2023    }
2024
2025    /// Return the number of sub-expressions in the tree (including self).
2026    pub fn size(&self) -> usize {
2027        let mut size = 0;
2028        self.visit_pre(|_| size += 1);
2029        size
2030    }
2031
2032    /// Given the ids and values of a LetRec, it computes the subset of ids that are used across
2033    /// iterations. These are those ids that have a reference before they are defined, when reading
2034    /// all the bindings in order.
2035    ///
2036    /// For example:
2037    /// ```SQL
2038    /// WITH MUTUALLY RECURSIVE
2039    ///     x(...) AS f(z),
2040    ///     y(...) AS g(x),
2041    ///     z(...) AS h(y)
2042    /// ...;
2043    /// ```
2044    /// Here, only `z` is returned, because `x` and `y` are referenced only within the same
2045    /// iteration.
2046    ///
2047    /// Note that if a binding references itself, that is also returned.
2048    pub fn recursive_ids(ids: &[LocalId], values: &[MirRelationExpr]) -> BTreeSet<LocalId> {
2049        let mut used_across_iterations = BTreeSet::new();
2050        let mut defined = BTreeSet::new();
2051        for (binding_id, value) in itertools::zip_eq(ids.iter(), values.iter()) {
2052            value.visit_pre(|expr| {
2053                if let MirRelationExpr::Get {
2054                    id: Local(get_id), ..
2055                } = expr
2056                {
2057                    // If we haven't seen a definition for it yet, then this will refer
2058                    // to the previous iteration.
2059                    // The `ids.contains` part of the condition is needed to exclude
2060                    // those ids that are not really in this LetRec, but either an inner
2061                    // or outer one.
2062                    if !defined.contains(get_id) && ids.contains(get_id) {
2063                        used_across_iterations.insert(*get_id);
2064                    }
2065                }
2066            });
2067            defined.insert(*binding_id);
2068        }
2069        used_across_iterations
2070    }
2071
2072    /// Replaces `LetRec` nodes with a stack of `Let` nodes.
2073    ///
2074    /// In each `Let` binding, uses of `Get` in `value` that are not at strictly greater
2075    /// identifiers are rewritten to be the constant collection.
2076    /// This makes the computation perform exactly "one" iteration.
2077    ///
2078    /// This was used only temporarily while developing `LetRec`.
2079    pub fn make_nonrecursive(self: &mut MirRelationExpr) {
2080        let mut deadlist = BTreeSet::new();
2081        let mut worklist = vec![self];
2082        while let Some(expr) = worklist.pop() {
2083            if let MirRelationExpr::LetRec {
2084                ids,
2085                values,
2086                limits: _,
2087                body,
2088            } = expr
2089            {
2090                let ids_values = values
2091                    .drain(..)
2092                    .zip_eq(ids)
2093                    .map(|(value, id)| (*id, value))
2094                    .collect::<Vec<_>>();
2095                *expr = body.take_dangerous();
2096                for (id, mut value) in ids_values.into_iter().rev() {
2097                    // Remove references to potentially recursive identifiers.
2098                    deadlist.insert(id);
2099                    value.visit_pre_mut(|e| {
2100                        if let MirRelationExpr::Get {
2101                            id: crate::Id::Local(id),
2102                            typ,
2103                            ..
2104                        } = e
2105                        {
2106                            let typ = typ.clone();
2107                            if deadlist.contains(id) {
2108                                e.take_safely(Some(typ));
2109                            }
2110                        }
2111                    });
2112                    *expr = MirRelationExpr::Let {
2113                        id,
2114                        value: Box::new(value),
2115                        body: Box::new(expr.take_dangerous()),
2116                    };
2117                }
2118                worklist.push(expr);
2119            } else {
2120                worklist.extend(expr.children_mut().rev());
2121            }
2122        }
2123    }
2124
2125    /// For each Id `id'` referenced in `expr`, if it is larger or equal than `id`, then record in
2126    /// `expire_whens` that when `id'` is redefined, then we should expire the information that
2127    /// we are holding about `id`. Call `do_expirations` with `expire_whens` at each Id
2128    /// redefinition.
2129    ///
2130    /// IMPORTANT: Relies on the numbering of Ids to be what `renumber_bindings` gives.
2131    pub fn collect_expirations(
2132        id: LocalId,
2133        expr: &MirRelationExpr,
2134        expire_whens: &mut BTreeMap<LocalId, Vec<LocalId>>,
2135    ) {
2136        expr.visit_pre(|e| {
2137            if let MirRelationExpr::Get {
2138                id: Id::Local(referenced_id),
2139                ..
2140            } = e
2141            {
2142                // The following check needs `renumber_bindings` to have run recently
2143                if referenced_id >= &id {
2144                    expire_whens
2145                        .entry(*referenced_id)
2146                        .or_insert_with(Vec::new)
2147                        .push(id);
2148                }
2149            }
2150        });
2151    }
2152
2153    /// Call this function when `id` is redefined. It modifies `id_infos` by removing information
2154    /// about such Ids whose information depended on the earlier definition of `id`, according to
2155    /// `expire_whens`. Also modifies `expire_whens`: it removes the currently processed entry.
2156    pub fn do_expirations<I>(
2157        redefined_id: LocalId,
2158        expire_whens: &mut BTreeMap<LocalId, Vec<LocalId>>,
2159        id_infos: &mut BTreeMap<LocalId, I>,
2160    ) -> Vec<(LocalId, I)> {
2161        let mut expired_infos = Vec::new();
2162        if let Some(expirations) = expire_whens.remove(&redefined_id) {
2163            for expired_id in expirations.into_iter() {
2164                if let Some(offer) = id_infos.remove(&expired_id) {
2165                    expired_infos.push((expired_id, offer));
2166                }
2167            }
2168        }
2169        expired_infos
2170    }
2171}
2172/// Augment non-nullability of columns, by observing either
2173/// 1. Predicates that explicitly test for null values, and
2174/// 2. Columns that if null would make a predicate be null.
2175pub fn non_nullable_columns(predicates: &[MirScalarExpr]) -> BTreeSet<usize> {
2176    let mut nonnull_required_columns = BTreeSet::new();
2177    for predicate in predicates {
2178        // Add any columns that being null would force the predicate to be null.
2179        // Should that happen, the row would be discarded.
2180        predicate.non_null_requirements(&mut nonnull_required_columns);
2181
2182        /*
2183        Test for explicit checks that a column is non-null.
2184
2185        This analysis is ad hoc, and will miss things:
2186
2187        materialize=> create table a(x int, y int);
2188        CREATE TABLE
2189        materialize=> explain with(types) select x from a where (y=x and y is not null) or x is not null;
2190        Optimized Plan
2191        --------------------------------------------------------------------------------------------------------
2192        Explained Query:                                                                                      +
2193        Project (#0) // { types: "(integer?)" }                                                             +
2194        Filter ((#0) IS NOT NULL OR ((#1) IS NOT NULL AND (#0 = #1))) // { types: "(integer?, integer?)" }+
2195        Get materialize.public.a // { types: "(integer?, integer?)" }                                   +
2196                                                                                  +
2197        Source materialize.public.a                                                                           +
2198        filter=(((#0) IS NOT NULL OR ((#1) IS NOT NULL AND (#0 = #1))))                                     +
2199
2200        (1 row)
2201        */
2202
2203        if let MirScalarExpr::CallUnary {
2204            func: UnaryFunc::Not(scalar_func::Not),
2205            expr,
2206        } = predicate
2207        {
2208            if let MirScalarExpr::CallUnary {
2209                func: UnaryFunc::IsNull(scalar_func::IsNull),
2210                expr,
2211            } = &**expr
2212            {
2213                if let MirScalarExpr::Column(c, _name) = &**expr {
2214                    nonnull_required_columns.insert(*c);
2215                }
2216            }
2217        }
2218    }
2219
2220    nonnull_required_columns
2221}
2222
2223impl CollectionPlan for MirRelationExpr {
2224    /// Collects the global collections that this MIR expression directly depends on, i.e., that it
2225    /// has a `Get` for. (It does _not_ traverse view definitions transitively.)
2226    ///
2227    /// !!!WARNING!!!: this method has an HirRelationExpr counterpart. The two
2228    /// should be kept in sync w.r.t. HIR ⇒ MIR lowering!
2229    fn depends_on_into(&self, out: &mut BTreeSet<GlobalId>) {
2230        if let MirRelationExpr::Get {
2231            id: Id::Global(id), ..
2232        } = self
2233        {
2234            out.insert(*id);
2235        }
2236        self.visit_children(|expr| expr.depends_on_into(out))
2237    }
2238}
2239
2240impl MirRelationExpr {
2241    /// Iterates through references to child expressions.
2242    pub fn children(&self) -> impl DoubleEndedIterator<Item = &Self> {
2243        let mut first = None;
2244        let mut second = None;
2245        let mut rest = None;
2246        let mut last = None;
2247
2248        use MirRelationExpr::*;
2249        match self {
2250            Constant { .. } | Get { .. } => (),
2251            Let { value, body, .. } => {
2252                first = Some(&**value);
2253                second = Some(&**body);
2254            }
2255            LetRec { values, body, .. } => {
2256                rest = Some(values);
2257                last = Some(&**body);
2258            }
2259            Project { input, .. }
2260            | Map { input, .. }
2261            | FlatMap { input, .. }
2262            | Filter { input, .. }
2263            | Reduce { input, .. }
2264            | TopK { input, .. }
2265            | Negate { input }
2266            | Threshold { input }
2267            | ArrangeBy { input, .. } => {
2268                first = Some(&**input);
2269            }
2270            Join { inputs, .. } => {
2271                rest = Some(inputs);
2272            }
2273            Union { base, inputs } => {
2274                first = Some(&**base);
2275                rest = Some(inputs);
2276            }
2277        }
2278
2279        first
2280            .into_iter()
2281            .chain(second)
2282            .chain(rest.into_iter().flatten())
2283            .chain(last)
2284    }
2285
2286    /// Iterates through mutable references to child expressions.
2287    pub fn children_mut(&mut self) -> impl DoubleEndedIterator<Item = &mut Self> {
2288        let mut first = None;
2289        let mut second = None;
2290        let mut rest = None;
2291        let mut last = None;
2292
2293        use MirRelationExpr::*;
2294        match self {
2295            Constant { .. } | Get { .. } => (),
2296            Let { value, body, .. } => {
2297                first = Some(&mut **value);
2298                second = Some(&mut **body);
2299            }
2300            LetRec { values, body, .. } => {
2301                rest = Some(values);
2302                last = Some(&mut **body);
2303            }
2304            Project { input, .. }
2305            | Map { input, .. }
2306            | FlatMap { input, .. }
2307            | Filter { input, .. }
2308            | Reduce { input, .. }
2309            | TopK { input, .. }
2310            | Negate { input }
2311            | Threshold { input }
2312            | ArrangeBy { input, .. } => {
2313                first = Some(&mut **input);
2314            }
2315            Join { inputs, .. } => {
2316                rest = Some(inputs);
2317            }
2318            Union { base, inputs } => {
2319                first = Some(&mut **base);
2320                rest = Some(inputs);
2321            }
2322        }
2323
2324        first
2325            .into_iter()
2326            .chain(second)
2327            .chain(rest.into_iter().flatten())
2328            .chain(last)
2329    }
2330
2331    /// Iterative pre-order visitor.
2332    pub fn visit_pre<'a, F: FnMut(&'a Self)>(&'a self, mut f: F) {
2333        let mut worklist = vec![self];
2334        while let Some(expr) = worklist.pop() {
2335            f(expr);
2336            worklist.extend(expr.children().rev());
2337        }
2338    }
2339
2340    /// Iterative pre-order visitor.
2341    pub fn visit_pre_mut<F: FnMut(&mut Self)>(&mut self, mut f: F) {
2342        let mut worklist = vec![self];
2343        while let Some(expr) = worklist.pop() {
2344            f(expr);
2345            worklist.extend(expr.children_mut().rev());
2346        }
2347    }
2348
2349    /// Return a vector of references to the subtrees of this expression
2350    /// in post-visit order (the last element is `&self`).
2351    pub fn post_order_vec(&self) -> Vec<&Self> {
2352        let mut stack = vec![self];
2353        let mut result = vec![];
2354        while let Some(expr) = stack.pop() {
2355            result.push(expr);
2356            stack.extend(expr.children());
2357        }
2358        result.reverse();
2359        result
2360    }
2361}
2362
2363impl VisitChildren<Self> for MirRelationExpr {
2364    fn visit_children<F>(&self, mut f: F)
2365    where
2366        F: FnMut(&Self),
2367    {
2368        for child in self.children() {
2369            f(child)
2370        }
2371    }
2372
2373    fn visit_mut_children<F>(&mut self, mut f: F)
2374    where
2375        F: FnMut(&mut Self),
2376    {
2377        for child in self.children_mut() {
2378            f(child)
2379        }
2380    }
2381
2382    fn try_visit_children<F, E>(&self, mut f: F) -> Result<(), E>
2383    where
2384        F: FnMut(&Self) -> Result<(), E>,
2385    {
2386        for child in self.children() {
2387            f(child)?
2388        }
2389        Ok(())
2390    }
2391
2392    fn try_visit_mut_children<F, E>(&mut self, mut f: F) -> Result<(), E>
2393    where
2394        F: FnMut(&mut Self) -> Result<(), E>,
2395    {
2396        for child in self.children_mut() {
2397            f(child)?
2398        }
2399        Ok(())
2400    }
2401
2402    fn children<'a>(&'a self) -> impl DoubleEndedIterator<Item = &'a MirRelationExpr>
2403    where
2404        Self: 'a,
2405    {
2406        self.children()
2407    }
2408
2409    fn children_mut<'a>(&'a mut self) -> impl DoubleEndedIterator<Item = &'a mut MirRelationExpr>
2410    where
2411        Self: 'a,
2412    {
2413        self.children_mut()
2414    }
2415}
2416
/// Specification for an ordering by a column.
///
/// Note that the derived `Ord`/`PartialOrd` compare the fields in declaration order
/// (`column`, then `desc`, then `nulls_last`), so reordering fields changes comparisons.
#[derive(
    Debug,
    Clone,
    Copy,
    Eq,
    PartialEq,
    Ord,
    PartialOrd,
    Serialize,
    Deserialize,
    Hash,
    MzReflect
)]
pub struct ColumnOrder {
    /// The index of the column to order by.
    pub column: usize,
    /// Whether to sort in descending order.
    ///
    /// Deserializes to `false` when the field is absent (`#[serde(default)]`).
    #[serde(default)]
    pub desc: bool,
    /// Whether to sort nulls last.
    ///
    /// Deserializes to `false` when the field is absent (`#[serde(default)]`).
    #[serde(default)]
    pub nulls_last: bool,
}
2441
// `ColumnOrder` derives `Copy`; presumably that is what makes `CopyRegion` the appropriate
// region type here -- NOTE(review): confirm against the `columnation` crate documentation.
impl Columnation for ColumnOrder {
    type InnerRegion = CopyRegion<Self>;
}
2445
2446impl<'a, M> fmt::Display for HumanizedExpr<'a, ColumnOrder, M>
2447where
2448    M: HumanizerMode,
2449{
2450    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
2451        // If you modify this, then please also attend to Display for ColumnOrderWithExpr!
2452        write!(
2453            f,
2454            "{} {} {}",
2455            self.child(&self.expr.column),
2456            if self.expr.desc { "desc" } else { "asc" },
2457            if self.expr.nulls_last {
2458                "nulls_last"
2459            } else {
2460                "nulls_first"
2461            },
2462        )
2463    }
2464}
2465
/// Describes an aggregation expression.
///
/// An aggregation consists of the aggregate function to apply (`func`), the scalar
/// expression that produces the function's input from each row (`expr`), and a flag for
/// whether only distinct inputs should be aggregated (`distinct`).
#[derive(
    Clone,
    Debug,
    Eq,
    PartialEq,
    Ord,
    PartialOrd,
    Serialize,
    Deserialize,
    Hash,
    MzReflect
)]
pub struct AggregateExpr {
    /// Names the aggregation function.
    pub func: AggregateFunc,
    /// An expression which extracts from each row the input to `func`.
    pub expr: MirScalarExpr,
    /// Should the aggregation be applied only to distinct results in each group.
    ///
    /// Deserializes to `false` when the field is absent (`#[serde(default)]`).
    #[serde(default)]
    pub distinct: bool,
}
2488
2489impl AggregateExpr {
2490    /// Computes the type of this `AggregateExpr`.
2491    pub fn sql_typ(&self, column_types: &[SqlColumnType]) -> SqlColumnType {
2492        self.func.output_sql_type(self.expr.sql_typ(column_types))
2493    }
2494
2495    /// Computes the type of this `AggregateExpr`.
2496    pub fn typ(&self, column_types: &[ReprColumnType]) -> ReprColumnType {
2497        self.func.output_type(self.expr.typ(column_types))
2498    }
2499
2500    /// Returns whether the expression has a constant result.
2501    pub fn is_constant(&self) -> bool {
2502        match self.func {
2503            AggregateFunc::MaxNumeric
2504            | AggregateFunc::MaxInt16
2505            | AggregateFunc::MaxInt32
2506            | AggregateFunc::MaxInt64
2507            | AggregateFunc::MaxUInt16
2508            | AggregateFunc::MaxUInt32
2509            | AggregateFunc::MaxUInt64
2510            | AggregateFunc::MaxMzTimestamp
2511            | AggregateFunc::MaxFloat32
2512            | AggregateFunc::MaxFloat64
2513            | AggregateFunc::MaxBool
2514            | AggregateFunc::MaxString
2515            | AggregateFunc::MaxDate
2516            | AggregateFunc::MaxTimestamp
2517            | AggregateFunc::MaxTimestampTz
2518            | AggregateFunc::MaxInterval
2519            | AggregateFunc::MaxTime
2520            | AggregateFunc::MinNumeric
2521            | AggregateFunc::MinInt16
2522            | AggregateFunc::MinInt32
2523            | AggregateFunc::MinInt64
2524            | AggregateFunc::MinUInt16
2525            | AggregateFunc::MinUInt32
2526            | AggregateFunc::MinUInt64
2527            | AggregateFunc::MinMzTimestamp
2528            | AggregateFunc::MinFloat32
2529            | AggregateFunc::MinFloat64
2530            | AggregateFunc::MinBool
2531            | AggregateFunc::MinString
2532            | AggregateFunc::MinDate
2533            | AggregateFunc::MinTimestamp
2534            | AggregateFunc::MinTimestampTz
2535            | AggregateFunc::MinInterval
2536            | AggregateFunc::MinTime
2537            | AggregateFunc::Any
2538            | AggregateFunc::All
2539            | AggregateFunc::Dummy => self.expr.is_literal(),
2540            AggregateFunc::Count => self.expr.is_literal_null(),
2541            AggregateFunc::SumInt16
2542            | AggregateFunc::SumInt32
2543            | AggregateFunc::SumInt64
2544            | AggregateFunc::SumUInt16
2545            | AggregateFunc::SumUInt32
2546            | AggregateFunc::SumUInt64
2547            | AggregateFunc::SumFloat32
2548            | AggregateFunc::SumFloat64
2549            | AggregateFunc::SumNumeric
2550            | AggregateFunc::JsonbAgg { .. }
2551            | AggregateFunc::JsonbObjectAgg { .. }
2552            | AggregateFunc::MapAgg { .. }
2553            | AggregateFunc::ArrayConcat { .. }
2554            | AggregateFunc::ListConcat { .. }
2555            | AggregateFunc::StringAgg { .. }
2556            | AggregateFunc::RowNumber { .. }
2557            | AggregateFunc::Rank { .. }
2558            | AggregateFunc::DenseRank { .. }
2559            | AggregateFunc::LagLead { .. }
2560            | AggregateFunc::FirstValue { .. }
2561            | AggregateFunc::LastValue { .. }
2562            | AggregateFunc::FusedValueWindowFunc { .. }
2563            | AggregateFunc::WindowAggregate { .. }
2564            | AggregateFunc::FusedWindowAggregate { .. } => self.expr.is_literal_err(),
2565        }
2566    }
2567
2568    /// Returns an expression that computes `self` on a group that has exactly one row.
2569    /// Instead of performing a `Reduce` with `self`, one can perform a `Map` with the expression
2570    /// returned by `on_unique`, which is cheaper. (See `ReduceElision`.)
2571    pub fn on_unique(&self, input_type: &[ReprColumnType]) -> MirScalarExpr {
2572        match &self.func {
2573            // Count is one if non-null, and zero if null.
2574            AggregateFunc::Count => self
2575                .expr
2576                .clone()
2577                .call_unary(UnaryFunc::IsNull(crate::func::IsNull))
2578                .if_then_else(
2579                    MirScalarExpr::literal_ok(Datum::Int64(0), ReprScalarType::Int64),
2580                    MirScalarExpr::literal_ok(Datum::Int64(1), ReprScalarType::Int64),
2581                ),
2582
2583            // SumInt16 takes Int16s as input, but outputs Int64s.
2584            AggregateFunc::SumInt16 => self
2585                .expr
2586                .clone()
2587                .call_unary(UnaryFunc::CastInt16ToInt64(scalar_func::CastInt16ToInt64)),
2588
2589            // SumInt32 takes Int32s as input, but outputs Int64s.
2590            AggregateFunc::SumInt32 => self
2591                .expr
2592                .clone()
2593                .call_unary(UnaryFunc::CastInt32ToInt64(scalar_func::CastInt32ToInt64)),
2594
2595            // SumInt64 takes Int64s as input, but outputs numerics.
2596            AggregateFunc::SumInt64 => self.expr.clone().call_unary(UnaryFunc::CastInt64ToNumeric(
2597                scalar_func::CastInt64ToNumeric(Some(NumericMaxScale::ZERO)),
2598            )),
2599
2600            // SumUInt16 takes UInt16s as input, but outputs UInt64s.
2601            AggregateFunc::SumUInt16 => self.expr.clone().call_unary(
2602                UnaryFunc::CastUint16ToUint64(scalar_func::CastUint16ToUint64),
2603            ),
2604
2605            // SumUInt32 takes UInt32s as input, but outputs UInt64s.
2606            AggregateFunc::SumUInt32 => self.expr.clone().call_unary(
2607                UnaryFunc::CastUint32ToUint64(scalar_func::CastUint32ToUint64),
2608            ),
2609
2610            // SumUInt64 takes UInt64s as input, but outputs numerics.
2611            AggregateFunc::SumUInt64 => {
2612                self.expr.clone().call_unary(UnaryFunc::CastUint64ToNumeric(
2613                    scalar_func::CastUint64ToNumeric(Some(NumericMaxScale::ZERO)),
2614                ))
2615            }
2616
2617            // JsonbAgg takes _anything_ as input, but must output a Jsonb array.
2618            AggregateFunc::JsonbAgg { .. } => MirScalarExpr::call_variadic(
2619                JsonbBuildArray,
2620                vec![
2621                    self.expr
2622                        .clone()
2623                        .call_unary(UnaryFunc::RecordGet(scalar_func::RecordGet(0))),
2624                ],
2625            ),
2626
            // JsonbObjectAgg takes _anything_ as input, but must output a Jsonb object.
2628            AggregateFunc::JsonbObjectAgg { .. } => {
2629                let record = self
2630                    .expr
2631                    .clone()
2632                    .call_unary(UnaryFunc::RecordGet(scalar_func::RecordGet(0)));
2633                MirScalarExpr::call_variadic(
2634                    JsonbBuildObject,
2635                    (0..2)
2636                        .map(|i| {
2637                            record
2638                                .clone()
2639                                .call_unary(UnaryFunc::RecordGet(scalar_func::RecordGet(i)))
2640                        })
2641                        .collect(),
2642                )
2643            }
2644
2645            AggregateFunc::MapAgg { value_type, .. } => {
2646                let record = self
2647                    .expr
2648                    .clone()
2649                    .call_unary(UnaryFunc::RecordGet(scalar_func::RecordGet(0)));
2650                MirScalarExpr::call_variadic(
2651                    MapBuild {
2652                        value_type: value_type.clone(),
2653                    },
2654                    (0..2)
2655                        .map(|i| {
2656                            record
2657                                .clone()
2658                                .call_unary(UnaryFunc::RecordGet(scalar_func::RecordGet(i)))
2659                        })
2660                        .collect(),
2661                )
2662            }
2663
2664            // StringAgg takes nested records of strings and outputs a string
2665            AggregateFunc::StringAgg { .. } => self
2666                .expr
2667                .clone()
2668                .call_unary(UnaryFunc::RecordGet(scalar_func::RecordGet(0)))
2669                .call_unary(UnaryFunc::RecordGet(scalar_func::RecordGet(0))),
2670
2671            // ListConcat and ArrayConcat take a single level of records and output a list containing exactly 1 element
2672            AggregateFunc::ListConcat { .. } | AggregateFunc::ArrayConcat { .. } => self
2673                .expr
2674                .clone()
2675                .call_unary(UnaryFunc::RecordGet(scalar_func::RecordGet(0))),
2676
2677            // RowNumber, Rank, DenseRank take a list of records and output a list containing exactly 1 element
2678            AggregateFunc::RowNumber { .. } => {
2679                self.on_unique_ranking_window_funcs(input_type, "?row_number?")
2680            }
2681            AggregateFunc::Rank { .. } => self.on_unique_ranking_window_funcs(input_type, "?rank?"),
2682            AggregateFunc::DenseRank { .. } => {
2683                self.on_unique_ranking_window_funcs(input_type, "?dense_rank?")
2684            }
2685
2686            // The input type for LagLead is ((OriginalRow, (InputValue, Offset, Default)), OrderByExprs...)
2687            AggregateFunc::LagLead { lag_lead, .. } => {
2688                let tuple = self
2689                    .expr
2690                    .clone()
2691                    .call_unary(UnaryFunc::RecordGet(scalar_func::RecordGet(0)));
2692
2693                // Get the overall return type
2694                let return_type_with_orig_row = self
2695                    .typ(input_type)
2696                    .scalar_type
2697                    .unwrap_list_element_type()
2698                    .clone();
2699                let lag_lead_return_type =
2700                    return_type_with_orig_row.unwrap_record_element_type()[0].clone();
2701
2702                // Extract the original row
2703                let original_row = tuple
2704                    .clone()
2705                    .call_unary(UnaryFunc::RecordGet(scalar_func::RecordGet(0)));
2706
2707                // Extract the encoded args
2708                let encoded_args =
2709                    tuple.call_unary(UnaryFunc::RecordGet(scalar_func::RecordGet(1)));
2710
2711                let (result_expr, column_name) =
2712                    Self::on_unique_lag_lead(lag_lead, encoded_args, lag_lead_return_type.clone());
2713
2714                MirScalarExpr::call_variadic(
2715                    ListCreate {
2716                        elem_type: SqlScalarType::from_repr(&return_type_with_orig_row),
2717                    },
2718                    vec![MirScalarExpr::call_variadic(
2719                        RecordCreate {
2720                            field_names: vec![column_name, ColumnName::from("?record?")],
2721                        },
2722                        vec![result_expr, original_row],
2723                    )],
2724                )
2725            }
2726
2727            // The input type for FirstValue is ((OriginalRow, InputValue), OrderByExprs...)
2728            AggregateFunc::FirstValue { window_frame, .. } => {
2729                let tuple = self
2730                    .expr
2731                    .clone()
2732                    .call_unary(UnaryFunc::RecordGet(scalar_func::RecordGet(0)));
2733
2734                // Get the overall return type
2735                let return_type_with_orig_row = self
2736                    .typ(input_type)
2737                    .scalar_type
2738                    .unwrap_list_element_type()
2739                    .clone();
2740                let first_value_return_type =
2741                    return_type_with_orig_row.unwrap_record_element_type()[0].clone();
2742
2743                // Extract the original row
2744                let original_row = tuple
2745                    .clone()
2746                    .call_unary(UnaryFunc::RecordGet(scalar_func::RecordGet(0)));
2747
2748                // Extract the input value
2749                let arg = tuple.call_unary(UnaryFunc::RecordGet(scalar_func::RecordGet(1)));
2750
2751                let (result_expr, column_name) = Self::on_unique_first_value_last_value(
2752                    window_frame,
2753                    arg,
2754                    first_value_return_type,
2755                );
2756
2757                MirScalarExpr::call_variadic(
2758                    ListCreate {
2759                        elem_type: SqlScalarType::from_repr(&return_type_with_orig_row),
2760                    },
2761                    vec![MirScalarExpr::call_variadic(
2762                        RecordCreate {
2763                            field_names: vec![column_name, ColumnName::from("?record?")],
2764                        },
2765                        vec![result_expr, original_row],
2766                    )],
2767                )
2768            }
2769
2770            // The input type for LastValue is ((OriginalRow, InputValue), OrderByExprs...)
2771            AggregateFunc::LastValue { window_frame, .. } => {
2772                let tuple = self
2773                    .expr
2774                    .clone()
2775                    .call_unary(UnaryFunc::RecordGet(scalar_func::RecordGet(0)));
2776
2777                // Get the overall return type
2778                let return_type_with_orig_row = self
2779                    .typ(input_type)
2780                    .scalar_type
2781                    .unwrap_list_element_type()
2782                    .clone();
2783                let last_value_return_type =
2784                    return_type_with_orig_row.unwrap_record_element_type()[0].clone();
2785
2786                // Extract the original row
2787                let original_row = tuple
2788                    .clone()
2789                    .call_unary(UnaryFunc::RecordGet(scalar_func::RecordGet(0)));
2790
2791                // Extract the input value
2792                let arg = tuple.call_unary(UnaryFunc::RecordGet(scalar_func::RecordGet(1)));
2793
2794                let (result_expr, column_name) = Self::on_unique_first_value_last_value(
2795                    window_frame,
2796                    arg,
2797                    last_value_return_type,
2798                );
2799
2800                MirScalarExpr::call_variadic(
2801                    ListCreate {
2802                        elem_type: SqlScalarType::from_repr(&return_type_with_orig_row),
2803                    },
2804                    vec![MirScalarExpr::call_variadic(
2805                        RecordCreate {
2806                            field_names: vec![column_name, ColumnName::from("?record?")],
2807                        },
2808                        vec![result_expr, original_row],
2809                    )],
2810                )
2811            }
2812
2813            // The input type for window aggs is ((OriginalRow, InputValue), OrderByExprs...)
2814            // See an example MIR in `window_func_applied_to`.
2815            AggregateFunc::WindowAggregate {
2816                wrapped_aggregate,
2817                window_frame,
2818                order_by: _,
2819            } => {
2820                // TODO: deduplicate code between the various window function cases.
2821
2822                let tuple = self
2823                    .expr
2824                    .clone()
2825                    .call_unary(UnaryFunc::RecordGet(scalar_func::RecordGet(0)));
2826
2827                // Get the overall return type
2828                let return_type = self
2829                    .typ(input_type)
2830                    .scalar_type
2831                    .unwrap_list_element_type()
2832                    .clone();
2833                let window_agg_return_type = return_type.unwrap_record_element_type()[0].clone();
2834
2835                // Extract the original row
2836                let original_row = tuple
2837                    .clone()
2838                    .call_unary(UnaryFunc::RecordGet(scalar_func::RecordGet(0)));
2839
2840                // Extract the input value
2841                let arg_expr = tuple.call_unary(UnaryFunc::RecordGet(scalar_func::RecordGet(1)));
2842
2843                let (result, column_name) = Self::on_unique_window_agg(
2844                    window_frame,
2845                    arg_expr,
2846                    input_type,
2847                    window_agg_return_type,
2848                    wrapped_aggregate,
2849                );
2850
2851                MirScalarExpr::call_variadic(
2852                    ListCreate {
2853                        elem_type: SqlScalarType::from_repr(&return_type),
2854                    },
2855                    vec![MirScalarExpr::call_variadic(
2856                        RecordCreate {
2857                            field_names: vec![column_name, ColumnName::from("?record?")],
2858                        },
2859                        vec![result, original_row],
2860                    )],
2861                )
2862            }
2863
2864            // The input type is ((OriginalRow, (Arg1, Arg2, ...)), OrderByExprs...)
2865            AggregateFunc::FusedWindowAggregate {
2866                wrapped_aggregates,
2867                order_by: _,
2868                window_frame,
2869            } => {
2870                // Throw away OrderByExprs
2871                let tuple = self
2872                    .expr
2873                    .clone()
2874                    .call_unary(UnaryFunc::RecordGet(scalar_func::RecordGet(0)));
2875
2876                // Extract the original row
2877                let original_row = tuple
2878                    .clone()
2879                    .call_unary(UnaryFunc::RecordGet(scalar_func::RecordGet(0)));
2880
2881                // Extract the args of the fused call
2882                let all_args = tuple.call_unary(UnaryFunc::RecordGet(scalar_func::RecordGet(1)));
2883
2884                let return_type_with_orig_row = self
2885                    .typ(input_type)
2886                    .scalar_type
2887                    .unwrap_list_element_type()
2888                    .clone();
2889
2890                let all_func_return_types =
2891                    return_type_with_orig_row.unwrap_record_element_type()[0].clone();
2892                let mut func_result_exprs = Vec::new();
2893                let mut col_names = Vec::new();
2894                for (idx, wrapped_aggr) in wrapped_aggregates.iter().enumerate() {
2895                    let arg = all_args
2896                        .clone()
2897                        .call_unary(UnaryFunc::RecordGet(scalar_func::RecordGet(idx)));
2898                    let return_type =
2899                        all_func_return_types.unwrap_record_element_type()[idx].clone();
2900                    let (result, column_name) = Self::on_unique_window_agg(
2901                        window_frame,
2902                        arg,
2903                        input_type,
2904                        return_type,
2905                        wrapped_aggr,
2906                    );
2907                    func_result_exprs.push(result);
2908                    col_names.push(column_name);
2909                }
2910
2911                MirScalarExpr::call_variadic(
2912                    ListCreate {
2913                        elem_type: SqlScalarType::from_repr(&return_type_with_orig_row),
2914                    },
2915                    vec![MirScalarExpr::call_variadic(
2916                        RecordCreate {
2917                            field_names: vec![
2918                                ColumnName::from("?fused_window_aggr?"),
2919                                ColumnName::from("?record?"),
2920                            ],
2921                        },
2922                        vec![
2923                            MirScalarExpr::call_variadic(
2924                                RecordCreate {
2925                                    field_names: col_names,
2926                                },
2927                                func_result_exprs,
2928                            ),
2929                            original_row,
2930                        ],
2931                    )],
2932                )
2933            }
2934
2935            // The input type is ((OriginalRow, (Args1, Args2, ...)), OrderByExprs...)
2936            AggregateFunc::FusedValueWindowFunc {
2937                funcs,
2938                order_by: outer_order_by,
2939            } => {
2940                // Throw away OrderByExprs
2941                let tuple = self
2942                    .expr
2943                    .clone()
2944                    .call_unary(UnaryFunc::RecordGet(scalar_func::RecordGet(0)));
2945
2946                // Extract the original row
2947                let original_row = tuple
2948                    .clone()
2949                    .call_unary(UnaryFunc::RecordGet(scalar_func::RecordGet(0)));
2950
2951                // Extract the encoded args of the fused call
2952                let all_encoded_args =
2953                    tuple.call_unary(UnaryFunc::RecordGet(scalar_func::RecordGet(1)));
2954
2955                let return_type_with_orig_row = self
2956                    .typ(input_type)
2957                    .scalar_type
2958                    .unwrap_list_element_type()
2959                    .clone();
2960
2961                let all_func_return_types =
2962                    return_type_with_orig_row.unwrap_record_element_type()[0].clone();
2963                let mut func_result_exprs = Vec::new();
2964                let mut col_names = Vec::new();
2965                for (idx, func) in funcs.iter().enumerate() {
2966                    let args_for_func = all_encoded_args
2967                        .clone()
2968                        .call_unary(UnaryFunc::RecordGet(scalar_func::RecordGet(idx)));
2969                    let return_type_for_func =
2970                        all_func_return_types.unwrap_record_element_type()[idx].clone();
2971                    let (result, column_name) = match func {
2972                        AggregateFunc::LagLead {
2973                            lag_lead,
2974                            order_by,
2975                            ignore_nulls: _,
2976                        } => {
2977                            assert_eq!(order_by, outer_order_by);
2978                            Self::on_unique_lag_lead(lag_lead, args_for_func, return_type_for_func)
2979                        }
2980                        AggregateFunc::FirstValue {
2981                            window_frame,
2982                            order_by,
2983                        } => {
2984                            assert_eq!(order_by, outer_order_by);
2985                            Self::on_unique_first_value_last_value(
2986                                window_frame,
2987                                args_for_func,
2988                                return_type_for_func,
2989                            )
2990                        }
2991                        AggregateFunc::LastValue {
2992                            window_frame,
2993                            order_by,
2994                        } => {
2995                            assert_eq!(order_by, outer_order_by);
2996                            Self::on_unique_first_value_last_value(
2997                                window_frame,
2998                                args_for_func,
2999                                return_type_for_func,
3000                            )
3001                        }
3002                        _ => panic!("unknown function in FusedValueWindowFunc"),
3003                    };
3004                    func_result_exprs.push(result);
3005                    col_names.push(column_name);
3006                }
3007
3008                MirScalarExpr::call_variadic(
3009                    ListCreate {
3010                        elem_type: SqlScalarType::from_repr(&return_type_with_orig_row),
3011                    },
3012                    vec![MirScalarExpr::call_variadic(
3013                        RecordCreate {
3014                            field_names: vec![
3015                                ColumnName::from("?fused_value_window_func?"),
3016                                ColumnName::from("?record?"),
3017                            ],
3018                        },
3019                        vec![
3020                            MirScalarExpr::call_variadic(
3021                                RecordCreate {
3022                                    field_names: col_names,
3023                                },
3024                                func_result_exprs,
3025                            ),
3026                            original_row,
3027                        ],
3028                    )],
3029                )
3030            }
3031
3032            // All other variants should return the argument to the aggregation.
3033            AggregateFunc::MaxNumeric
3034            | AggregateFunc::MaxInt16
3035            | AggregateFunc::MaxInt32
3036            | AggregateFunc::MaxInt64
3037            | AggregateFunc::MaxUInt16
3038            | AggregateFunc::MaxUInt32
3039            | AggregateFunc::MaxUInt64
3040            | AggregateFunc::MaxMzTimestamp
3041            | AggregateFunc::MaxFloat32
3042            | AggregateFunc::MaxFloat64
3043            | AggregateFunc::MaxBool
3044            | AggregateFunc::MaxString
3045            | AggregateFunc::MaxDate
3046            | AggregateFunc::MaxTimestamp
3047            | AggregateFunc::MaxTimestampTz
3048            | AggregateFunc::MaxInterval
3049            | AggregateFunc::MaxTime
3050            | AggregateFunc::MinNumeric
3051            | AggregateFunc::MinInt16
3052            | AggregateFunc::MinInt32
3053            | AggregateFunc::MinInt64
3054            | AggregateFunc::MinUInt16
3055            | AggregateFunc::MinUInt32
3056            | AggregateFunc::MinUInt64
3057            | AggregateFunc::MinMzTimestamp
3058            | AggregateFunc::MinFloat32
3059            | AggregateFunc::MinFloat64
3060            | AggregateFunc::MinBool
3061            | AggregateFunc::MinString
3062            | AggregateFunc::MinDate
3063            | AggregateFunc::MinTimestamp
3064            | AggregateFunc::MinTimestampTz
3065            | AggregateFunc::MinInterval
3066            | AggregateFunc::MinTime
3067            | AggregateFunc::SumFloat32
3068            | AggregateFunc::SumFloat64
3069            | AggregateFunc::SumNumeric
3070            | AggregateFunc::Any
3071            | AggregateFunc::All
3072            | AggregateFunc::Dummy => self.expr.clone(),
3073        }
3074    }
3075
3076    /// `on_unique` for ROW_NUMBER, RANK, DENSE_RANK
3077    fn on_unique_ranking_window_funcs(
3078        &self,
3079        input_type: &[ReprColumnType],
3080        col_name: &str,
3081    ) -> MirScalarExpr {
3082        let sql_input_type: Vec<SqlColumnType> =
3083            input_type.iter().map(SqlColumnType::from_repr).collect();
3084        let list = self
3085            .expr
3086            .clone()
3087            // extract the list within the record
3088            .call_unary(UnaryFunc::RecordGet(scalar_func::RecordGet(0)));
3089
3090        // extract the expression within the list
3091        let record = MirScalarExpr::call_variadic(
3092            ListIndex,
3093            vec![
3094                list,
3095                MirScalarExpr::literal_ok(Datum::Int64(1), ReprScalarType::Int64),
3096            ],
3097        );
3098
3099        MirScalarExpr::call_variadic(
3100            ListCreate {
3101                elem_type: self
3102                    .sql_typ(&sql_input_type)
3103                    .scalar_type
3104                    .unwrap_list_element_type()
3105                    .clone(),
3106            },
3107            vec![MirScalarExpr::call_variadic(
3108                RecordCreate {
3109                    field_names: vec![ColumnName::from(col_name), ColumnName::from("?record?")],
3110                },
3111                vec![
3112                    MirScalarExpr::literal_ok(Datum::Int64(1), ReprScalarType::Int64),
3113                    record,
3114                ],
3115            )],
3116        )
3117    }
3118
3119    /// `on_unique` for `lag` and `lead`
3120    fn on_unique_lag_lead(
3121        lag_lead: &LagLeadType,
3122        encoded_args: MirScalarExpr,
3123        return_type: ReprScalarType,
3124    ) -> (MirScalarExpr, ColumnName) {
3125        let expr = encoded_args
3126            .clone()
3127            .call_unary(UnaryFunc::RecordGet(scalar_func::RecordGet(0)));
3128        let offset = encoded_args
3129            .clone()
3130            .call_unary(UnaryFunc::RecordGet(scalar_func::RecordGet(1)));
3131        let default_value =
3132            encoded_args.call_unary(UnaryFunc::RecordGet(scalar_func::RecordGet(2)));
3133
3134        // In this case, the window always has only one element, so if the offset is not null and
3135        // not zero, the default value should be returned instead.
3136        let value = offset
3137            .clone()
3138            .call_binary(
3139                MirScalarExpr::literal_ok(Datum::Int32(0), ReprScalarType::Int32),
3140                crate::func::Eq,
3141            )
3142            .if_then_else(expr, default_value);
3143        let result_expr = offset
3144            .call_unary(UnaryFunc::IsNull(crate::func::IsNull))
3145            .if_then_else(MirScalarExpr::literal_null(return_type), value);
3146
3147        let column_name = ColumnName::from(match lag_lead {
3148            LagLeadType::Lag => "?lag?",
3149            LagLeadType::Lead => "?lead?",
3150        });
3151
3152        (result_expr, column_name)
3153    }
3154
3155    /// `on_unique` for `first_value` and `last_value`
3156    fn on_unique_first_value_last_value(
3157        window_frame: &WindowFrame,
3158        arg: MirScalarExpr,
3159        return_type: ReprScalarType,
3160    ) -> (MirScalarExpr, ColumnName) {
3161        // If the window frame includes the current (single) row, return its value, null otherwise
3162        let result_expr = if window_frame.includes_current_row() {
3163            arg
3164        } else {
3165            MirScalarExpr::literal_null(return_type)
3166        };
3167        (result_expr, ColumnName::from("?first_value?"))
3168    }
3169
3170    /// `on_unique` for window aggregations
3171    fn on_unique_window_agg(
3172        window_frame: &WindowFrame,
3173        arg_expr: MirScalarExpr,
3174        input_type: &[ReprColumnType],
3175        return_type: ReprScalarType,
3176        wrapped_aggr: &AggregateFunc,
3177    ) -> (MirScalarExpr, ColumnName) {
3178        // If the window frame includes the current (single) row, evaluate the wrapped aggregate on
3179        // that row. Otherwise, return the default value for the aggregate.
3180        let result_expr = if window_frame.includes_current_row() {
3181            AggregateExpr {
3182                func: wrapped_aggr.clone(),
3183                expr: arg_expr,
3184                distinct: false, // We have just one input element; DISTINCT doesn't matter.
3185            }
3186            .on_unique(input_type)
3187        } else {
3188            MirScalarExpr::literal_ok(wrapped_aggr.default(), return_type)
3189        };
3190        (result_expr, ColumnName::from("?window_agg?"))
3191    }
3192
3193    /// Returns whether the expression is COUNT(*) or not.  Note that
3194    /// when we define the count builtin in sql::func, we convert
3195    /// COUNT(*) to COUNT(true), making it indistinguishable from
3196    /// literal COUNT(true), but we prefer to consider this as the
3197    /// former.
3198    ///
3199    /// (HIR has the same `is_count_asterisk`.)
3200    pub fn is_count_asterisk(&self) -> bool {
3201        self.func == AggregateFunc::Count && self.expr.is_literal_true() && !self.distinct
3202    }
3203}
3204
/// Describe a join implementation in dataflow.
///
/// Note that `Ord`/`PartialOrd` are derived, so the relative order of values of this
/// type follows the declaration order of the variants below.
#[derive(
    Clone,
    Debug,
    Eq,
    PartialEq,
    Ord,
    PartialOrd,
    Serialize,
    Deserialize,
    Hash,
    MzReflect
)]
pub enum JoinImplementation {
    /// Perform a sequence of binary differential dataflow joins.
    ///
    /// The first argument indicates
    /// 1) the index of the starting collection,
    /// 2) if it should be arranged, the keys to arrange it by, and
    /// 3) the characteristics of the starting collection (for EXPLAINing).
    ///
    /// The sequence that follows lists other relation indexes, and the key for
    /// the arrangement we should use when joining it in.
    /// The JoinInputCharacteristics are for EXPLAINing the characteristics that
    /// were used for join ordering.
    ///
    /// Each collection index should occur exactly once, either as the starting collection
    /// or somewhere in the list.
    Differential(
        (
            usize,
            Option<Vec<MirScalarExpr>>,
            Option<JoinInputCharacteristics>,
        ),
        Vec<(usize, Vec<MirScalarExpr>, Option<JoinInputCharacteristics>)>,
    ),
    /// Perform independent delta query dataflows for each input.
    ///
    /// The argument is a sequence of plans, for the input collections in order.
    /// Each plan starts from the corresponding index, and then in sequence joins
    /// against collections identified by index and with the specified arrangement key.
    /// The JoinInputCharacteristics are for EXPLAINing the characteristics that were
    /// used for join ordering.
    DeltaQuery(Vec<Vec<(usize, Vec<MirScalarExpr>, Option<JoinInputCharacteristics>)>>),
    /// Join a user-created index with a constant collection to speed up the evaluation of a
    /// predicate such as `(f1 = 3 AND f2 = 5) OR (f1 = 7 AND f2 = 9)`.
    /// This gets translated to a Differential join during MIR -> LIR lowering, but we still want
    /// to represent it in MIR, because the fast path detection wants to match on this.
    ///
    /// Consists of (`<coll_id>`, `<index_id>`, `<index_key>`, `<constants>`)
    IndexedFilter(
        GlobalId,
        GlobalId,
        Vec<MirScalarExpr>,
        #[mzreflect(ignore)] Vec<Row>,
    ),
    /// No implementation yet selected.
    ///
    /// This is the value produced by `Default`.
    Unimplemented,
}
3263
3264impl Default for JoinImplementation {
3265    fn default() -> Self {
3266        JoinImplementation::Unimplemented
3267    }
3268}
3269
3270impl JoinImplementation {
3271    /// Returns `true` iff the value is not [`JoinImplementation::Unimplemented`].
3272    pub fn is_implemented(&self) -> bool {
3273        match self {
3274            Self::Unimplemented => false,
3275            _ => true,
3276        }
3277    }
3278
3279    /// Returns an optional implementation name if the value is not [`JoinImplementation::Unimplemented`].
3280    pub fn name(&self) -> Option<&'static str> {
3281        match self {
3282            Self::Differential(..) => Some("differential"),
3283            Self::DeltaQuery(..) => Some("delta"),
3284            Self::IndexedFilter(..) => Some("indexed_filter"),
3285            Self::Unimplemented => None,
3286        }
3287    }
3288}
3289
/// Characteristics of a join order candidate collection.
///
/// A candidate is described by a collection and a key, and may have various liabilities.
/// Primarily, the candidate may risk substantial inflation of records, which is something
/// that concerns us greatly. Additionally, the candidate may be unarranged, and we would
/// prefer candidates that do not require additional memory. Finally, we prefer lower id
/// collections in the interest of consistent tie-breaking. For more characteristics, see
/// comments on individual fields.
///
/// This has more than one version. `new` instantiates the appropriate version based on a
/// feature flag.
///
/// The two versions carry the same kinds of information but declare their fields in a
/// different order (and `V2` adds `not_cross`), which changes the derived ordering.
#[derive(
    Eq,
    PartialEq,
    Ord,
    PartialOrd,
    Debug,
    Clone,
    Serialize,
    Deserialize,
    Hash,
    MzReflect
)]
pub enum JoinInputCharacteristics {
    /// Old version, with `enable_join_prioritize_arranged` turned off.
    V1(JoinInputCharacteristicsV1),
    /// Newer version, with `enable_join_prioritize_arranged` turned on.
    V2(JoinInputCharacteristicsV2),
}
3319
3320impl JoinInputCharacteristics {
3321    /// Creates a new instance with the given characteristics.
3322    pub fn new(
3323        unique_key: bool,
3324        key_length: usize,
3325        arranged: bool,
3326        cardinality: Option<usize>,
3327        filters: FilterCharacteristics,
3328        input: usize,
3329        enable_join_prioritize_arranged: bool,
3330    ) -> Self {
3331        if enable_join_prioritize_arranged {
3332            Self::V2(JoinInputCharacteristicsV2::new(
3333                unique_key,
3334                key_length,
3335                arranged,
3336                cardinality,
3337                filters,
3338                input,
3339            ))
3340        } else {
3341            Self::V1(JoinInputCharacteristicsV1::new(
3342                unique_key,
3343                key_length,
3344                arranged,
3345                cardinality,
3346                filters,
3347                input,
3348            ))
3349        }
3350    }
3351
3352    /// Turns the instance into a String to be printed in EXPLAIN.
3353    pub fn explain(&self) -> String {
3354        match self {
3355            Self::V1(jic) => jic.explain(),
3356            Self::V2(jic) => jic.explain(),
3357        }
3358    }
3359
3360    /// Whether the join input described by `self` is arranged.
3361    pub fn arranged(&self) -> bool {
3362        match self {
3363            Self::V1(jic) => jic.arranged,
3364            Self::V2(jic) => jic.arranged,
3365        }
3366    }
3367
3368    /// Returns the `FilterCharacteristics` for the join input described by `self`.
3369    pub fn filters(&mut self) -> &mut FilterCharacteristics {
3370        match self {
3371            Self::V1(jic) => &mut jic.filters,
3372            Self::V2(jic) => &mut jic.filters,
3373        }
3374    }
3375}
3376
/// Newer version of `JoinInputCharacteristics`, with `enable_join_prioritize_arranged` turned on.
///
/// `Ord` is derived, so values are compared field by field in declaration order: the
/// order of the fields below is the priority order of the characteristics. Do not
/// reorder fields without intending to change join ordering.
#[derive(
    Eq,
    PartialEq,
    Ord,
    PartialOrd,
    Debug,
    Clone,
    Serialize,
    Deserialize,
    Hash,
    MzReflect
)]
pub struct JoinInputCharacteristicsV2 {
    /// An excellent indication that record count will not increase.
    pub unique_key: bool,
    /// Cross joins are bad.
    /// (`key_length > 0` also implies that it is not a cross join. However, we need to note cross
    /// joins in a separate field, because not being a cross join is more important than `arranged`,
    /// but otherwise `key_length` is less important than `arranged`.)
    pub not_cross: bool,
    /// Indicates that there will be no additional in-memory footprint.
    pub arranged: bool,
    /// A weaker signal that record count will not increase.
    pub key_length: usize,
    /// Estimated cardinality (lower is better)
    ///
    /// Wrapped in `Reverse` so that a smaller estimate compares as greater. `None`
    /// (no estimate) compares as less than any `Some`.
    pub cardinality: Option<std::cmp::Reverse<usize>>,
    /// Characteristics of the filter that is applied at this input.
    pub filters: FilterCharacteristics,
    /// We want to prefer input earlier in the input list, for stability of ordering.
    ///
    /// Wrapped in `Reverse` so that a smaller input index compares as greater.
    pub input: std::cmp::Reverse<usize>,
}
3409
3410impl JoinInputCharacteristicsV2 {
3411    /// Creates a new instance with the given characteristics.
3412    pub fn new(
3413        unique_key: bool,
3414        key_length: usize,
3415        arranged: bool,
3416        cardinality: Option<usize>,
3417        filters: FilterCharacteristics,
3418        input: usize,
3419    ) -> Self {
3420        Self {
3421            unique_key,
3422            not_cross: key_length > 0,
3423            arranged,
3424            key_length,
3425            cardinality: cardinality.map(std::cmp::Reverse),
3426            filters,
3427            input: std::cmp::Reverse(input),
3428        }
3429    }
3430
3431    /// Turns the instance into a String to be printed in EXPLAIN.
3432    pub fn explain(&self) -> String {
3433        let mut e = "".to_owned();
3434        if self.unique_key {
3435            e.push_str("U");
3436        }
3437        // Don't need to print `not_cross`, because that is visible in the printed key.
3438        // if !self.not_cross {
3439        //     e.push_str("C");
3440        // }
3441        for _ in 0..self.key_length {
3442            e.push_str("K");
3443        }
3444        if self.arranged {
3445            e.push_str("A");
3446        }
3447        if let Some(std::cmp::Reverse(cardinality)) = self.cardinality {
3448            e.push_str(&format!("|{cardinality}|"));
3449        }
3450        e.push_str(&self.filters.explain());
3451        e
3452    }
3453}
3454
/// Old version of `JoinInputCharacteristics`, with `enable_join_prioritize_arranged` turned off.
///
/// `Ord` is derived, so values are compared field by field in declaration order: the
/// order of the fields below is the priority order of the characteristics. In
/// particular, `key_length` takes precedence over `arranged` here.
#[derive(
    Eq,
    PartialEq,
    Ord,
    PartialOrd,
    Debug,
    Clone,
    Serialize,
    Deserialize,
    Hash,
    MzReflect
)]
pub struct JoinInputCharacteristicsV1 {
    /// An excellent indication that record count will not increase.
    pub unique_key: bool,
    /// A weaker signal that record count will not increase.
    pub key_length: usize,
    /// Indicates that there will be no additional in-memory footprint.
    pub arranged: bool,
    /// Estimated cardinality (lower is better)
    ///
    /// Wrapped in `Reverse` so that a smaller estimate compares as greater. `None`
    /// (no estimate) compares as less than any `Some`.
    pub cardinality: Option<std::cmp::Reverse<usize>>,
    /// Characteristics of the filter that is applied at this input.
    pub filters: FilterCharacteristics,
    /// We want to prefer input earlier in the input list, for stability of ordering.
    ///
    /// Wrapped in `Reverse` so that a smaller input index compares as greater.
    pub input: std::cmp::Reverse<usize>,
}
3482
3483impl JoinInputCharacteristicsV1 {
3484    /// Creates a new instance with the given characteristics.
3485    pub fn new(
3486        unique_key: bool,
3487        key_length: usize,
3488        arranged: bool,
3489        cardinality: Option<usize>,
3490        filters: FilterCharacteristics,
3491        input: usize,
3492    ) -> Self {
3493        Self {
3494            unique_key,
3495            key_length,
3496            arranged,
3497            cardinality: cardinality.map(std::cmp::Reverse),
3498            filters,
3499            input: std::cmp::Reverse(input),
3500        }
3501    }
3502
3503    /// Turns the instance into a String to be printed in EXPLAIN.
3504    pub fn explain(&self) -> String {
3505        let mut e = "".to_owned();
3506        if self.unique_key {
3507            e.push_str("U");
3508        }
3509        for _ in 0..self.key_length {
3510            e.push_str("K");
3511        }
3512        if self.arranged {
3513            e.push_str("A");
3514        }
3515        if let Some(std::cmp::Reverse(cardinality)) = self.cardinality {
3516            e.push_str(&format!("|{cardinality}|"));
3517        }
3518        e.push_str(&self.filters.explain());
3519        e
3520    }
3521}
3522
/// Instructions for finishing the result of a query.
///
/// The primary reason for the existence of this structure and attendant code
/// is that SQL's ORDER BY requires sorting rows (as already implied by the
/// keywords), whereas much of the rest of SQL is defined in terms of unordered
/// multisets. But as it turns out, the same idea can be used to optimize
/// trivial peeks.
///
/// The generic parameters are for accommodating prepared statement parameters in
/// `limit` and `offset`: the planner can hold these fields as HirScalarExpr long enough to call
/// `bind_parameters` on them.
///
/// `L` is the type of the limit and defaults to `NonNeg<i64>`; `O` is the type of the
/// offset and defaults to `usize`.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
pub struct RowSetFinishing<L = NonNeg<i64>, O = usize> {
    /// Order rows by the given columns.
    pub order_by: Vec<ColumnOrder>,
    /// Include only as many rows (after offset).
    ///
    /// `None` means no limit.
    pub limit: Option<L>,
    /// Omit as many rows.
    pub offset: O,
    /// Include only given columns.
    pub project: Vec<usize>,
}
3545
3546impl<L> RowSetFinishing<L> {
3547    /// Returns a trivial finishing, i.e., that does nothing to the result set.
3548    pub fn trivial(arity: usize) -> RowSetFinishing<L> {
3549        RowSetFinishing {
3550            order_by: Vec::new(),
3551            limit: None,
3552            offset: 0,
3553            project: (0..arity).collect(),
3554        }
3555    }
3556    /// True if the finishing does nothing to any result set.
3557    pub fn is_trivial(&self, arity: usize) -> bool {
3558        self.limit.is_none()
3559            && self.order_by.is_empty()
3560            && self.offset == 0
3561            && self.project.iter().copied().eq(0..arity)
3562    }
3563    /// True if the finishing does not require an ORDER BY.
3564    ///
3565    /// LIMIT and OFFSET without an ORDER BY _are_ streamable: without an
3566    /// explicit ordering we will skip an arbitrary bag of elements and return
3567    /// the first arbitrary elements in the remaining bag. The result semantics
3568    /// are still correct but maybe surprising for some users.
3569    pub fn is_streamable(&self, arity: usize) -> bool {
3570        self.order_by.is_empty() && self.project.iter().copied().eq(0..arity)
3571    }
3572}
3573
3574impl RowSetFinishing<NonNeg<i64>, usize> {
3575    /// The number of rows needed from before the finishing to evaluate the finishing:
3576    /// offset + limit.
3577    ///
3578    /// If it returns None, then we need all the rows.
3579    pub fn num_rows_needed(&self) -> Option<usize> {
3580        self.limit
3581            .as_ref()
3582            .map(|l| usize::cast_from(u64::from(l.clone())) + self.offset)
3583    }
3584}
3585
3586impl RowSetFinishing {
3587    /// Applies finishing actions to a [`RowCollection`], and reports the total
3588    /// time it took to run.
3589    ///
3590    /// Returns a [`RowCollectionIter`] that contains all of the response data, as
3591    /// well as the size of the response in bytes.
3592    pub fn finish(
3593        &self,
3594        rows: RowCollection,
3595        max_result_size: u64,
3596        max_returned_query_size: Option<u64>,
3597        duration_histogram: &Histogram,
3598    ) -> Result<(RowCollectionIter, usize), String> {
3599        let now = Instant::now();
3600        let result = self.finish_inner(rows, max_result_size, max_returned_query_size);
3601        let duration = now.elapsed();
3602        duration_histogram.observe(duration.as_secs_f64());
3603
3604        result
3605    }
3606
3607    /// Implementation for [`RowSetFinishing::finish`].
3608    fn finish_inner(
3609        &self,
3610        rows: RowCollection,
3611        max_result_size: u64,
3612        max_returned_query_size: Option<u64>,
3613    ) -> Result<(RowCollectionIter, usize), String> {
3614        // How much additional memory is required to make a sorted view.
3615        let sorted_view_mem = rows.entries().saturating_mul(std::mem::size_of::<usize>());
3616        let required_memory = rows.byte_len().saturating_add(sorted_view_mem);
3617
3618        // Bail if creating the sorted view would require us to use too much memory.
3619        if required_memory > usize::cast_from(max_result_size) {
3620            let max_bytes = ByteSize::b(max_result_size);
3621            return Err(format!("result exceeds max size of {max_bytes}",));
3622        }
3623
3624        let sorted_view = rows;
3625        let mut iter = sorted_view
3626            .into_row_iter()
3627            .apply_offset(self.offset)
3628            .with_projection(self.project.clone());
3629
3630        if let Some(limit) = self.limit {
3631            let limit = u64::from(limit);
3632            let limit = usize::cast_from(limit);
3633            iter = iter.with_limit(limit);
3634        };
3635
3636        // TODO(parkmycar): Re-think how we can calculate the total response size without
3637        // having to iterate through the entire collection of Rows, while still
3638        // respecting the LIMIT, OFFSET, and projections.
3639        //
3640        // Note: It feels a bit bad always calculating the response size, but we almost
3641        // always need it to either check the `max_returned_query_size`, or for reporting
3642        // in the query history.
3643        let response_size: usize = iter.clone().map(|row| row.data().len()).sum();
3644
3645        // Bail if we would end up returning more data to the client than they can support.
3646        if let Some(max) = max_returned_query_size {
3647            if response_size > usize::cast_from(max) {
3648                let max_bytes = ByteSize::b(max);
3649                return Err(format!("result exceeds max size of {max_bytes}"));
3650            }
3651        }
3652
3653        Ok((iter, response_size))
3654    }
3655}
3656
/// A [RowSetFinishing] that can be repeatedly applied to batches of updates (in
/// a [RowCollection]) and keeps track of the remaining limit, offset, and cap
/// on query result size.
///
/// There is no ordering field: this finishing only carries an offset, a limit, a
/// projection, and a size budget.
#[derive(Debug)]
pub struct RowSetFinishingIncremental {
    /// Include only as many rows (after offset).
    ///
    /// Decreases as batches are finished; `None` means no limit.
    pub remaining_limit: Option<usize>,
    /// Omit as many rows.
    ///
    /// Decreases (saturating at zero) as batches are finished.
    pub remaining_offset: usize,
    /// The maximum allowed result size, as requested by the client.
    ///
    /// Never updated after construction; used for error reporting.
    pub max_returned_query_size: Option<u64>,
    /// Tracks our remaining allowed budget for result size.
    pub remaining_max_returned_query_size: Option<u64>,
    /// Include only given columns.
    pub project: Vec<usize>,
}
3673
3674impl RowSetFinishingIncremental {
3675    /// Turns the given [RowSetFinishing] into a [RowSetFinishingIncremental].
3676    /// Can only be used when [is_streamable](RowSetFinishing::is_streamable) is
3677    /// `true`.
3678    ///
3679    /// # Panics
3680    ///
3681    /// Panics if the result is not streamable, that is it has an ORDER BY.
3682    pub fn new(
3683        offset: usize,
3684        limit: Option<NonNeg<i64>>,
3685        project: Vec<usize>,
3686        max_returned_query_size: Option<u64>,
3687    ) -> Self {
3688        let limit = limit.map(|l| {
3689            let l = u64::from(l);
3690            let l = usize::cast_from(l);
3691            l
3692        });
3693
3694        RowSetFinishingIncremental {
3695            remaining_limit: limit,
3696            remaining_offset: offset,
3697            max_returned_query_size,
3698            remaining_max_returned_query_size: max_returned_query_size,
3699            project,
3700        }
3701    }
3702
3703    /// Applies finishing actions to the given [`RowCollection`], and reports
3704    /// the total time it took to run.
3705    ///
3706    /// Returns a [`RowCollectionIter`] that contains all of the response
3707    /// data.
3708    pub fn finish_incremental(
3709        &mut self,
3710        rows: RowCollection,
3711        max_result_size: u64,
3712        duration_histogram: &Histogram,
3713    ) -> Result<RowCollectionIter, String> {
3714        let now = Instant::now();
3715        let result = self.finish_incremental_inner(rows, max_result_size);
3716        let duration = now.elapsed();
3717        duration_histogram.observe(duration.as_secs_f64());
3718
3719        result
3720    }
3721
3722    fn finish_incremental_inner(
3723        &mut self,
3724        rows: RowCollection,
3725        max_result_size: u64,
3726    ) -> Result<RowCollectionIter, String> {
3727        // How much additional memory is required to make a sorted view.
3728        let sorted_view_mem = rows.entries().saturating_mul(std::mem::size_of::<usize>());
3729        let required_memory = rows.byte_len().saturating_add(sorted_view_mem);
3730
3731        // Bail if creating the sorted view would require us to use too much memory.
3732        if required_memory > usize::cast_from(max_result_size) {
3733            let max_bytes = ByteSize::b(max_result_size);
3734            return Err(format!("total result exceeds max size of {max_bytes}",));
3735        }
3736
3737        let batch_num_rows = rows.count();
3738
3739        let sorted_view = rows;
3740        let mut iter = sorted_view
3741            .into_row_iter()
3742            .apply_offset(self.remaining_offset)
3743            .with_projection(self.project.clone());
3744
3745        if let Some(limit) = self.remaining_limit {
3746            iter = iter.with_limit(limit);
3747        };
3748
3749        self.remaining_offset = self.remaining_offset.saturating_sub(batch_num_rows);
3750        if let Some(remaining_limit) = self.remaining_limit.as_mut() {
3751            *remaining_limit -= iter.count();
3752        }
3753
3754        // TODO(parkmycar): Re-think how we can calculate the total response size without
3755        // having to iterate through the entire collection of Rows, while still
3756        // respecting the LIMIT, OFFSET, and projections.
3757        //
3758        // Note: It feels a bit bad always calculating the response size, but we almost
3759        // always need it to either check the `max_returned_query_size`, or for reporting
3760        // in the query history.
3761        let response_size: usize = iter.clone().map(|row| row.data().len()).sum();
3762
3763        // Bail if we would end up returning more data to the client than they can support.
3764        if let Some(max) = &mut self.remaining_max_returned_query_size {
3765            if let Some(remaining) = max.checked_sub(response_size.cast_into()) {
3766                *max = remaining;
3767            } else {
3768                let max_bytes = ByteSize::b(self.max_returned_query_size.expect("known to exist"));
3769                return Err(format!("total result exceeds max size of {max_bytes}"));
3770            }
3771        }
3772
3773        Ok(iter)
3774    }
3775}
3776
/// Compares two rows columnwise, using [compare_columns].
///
/// Compared to the naive implementation, this allows sharing some memory and implements some
/// optimizations that avoid unnecessary row unpacking.
///
/// The two scratch vectors sit behind [`RefCell`]s so that rows can be unpacked into them
/// while the comparator is only borrowed immutably.
#[derive(Debug, Clone)]
pub struct RowComparator<O: AsRef<[ColumnOrder]> = Vec<ColumnOrder>> {
    /// The column ordering that rows are compared by.
    order: O,
    /// Invariant: all column references in the order are less than this limit.
    /// This allows for partial unpacking of rows.
    limit: usize,
    /// Scratch space into which the left-hand row is unpacked.
    left_vec: RefCell<DatumVec>,
    /// Scratch space into which the right-hand row is unpacked.
    right_vec: RefCell<DatumVec>,
}
3790
3791impl<O: AsRef<[ColumnOrder]>> RowComparator<O> {
3792    /// Create a new row comparator from the given column ordering.
3793    pub fn new(order: O) -> Self {
3794        let limit = order
3795            .as_ref()
3796            .iter()
3797            .map(|o| o.column + 1)
3798            .max()
3799            .unwrap_or(0);
3800        Self {
3801            order,
3802            limit,
3803            left_vec: Default::default(),
3804            right_vec: Default::default(),
3805        }
3806    }
3807
3808    /// Compare two (references to) rows.
3809    pub fn compare_rows(
3810        &self,
3811        left_row: &RowRef,
3812        right_row: &RowRef,
3813        tiebreaker: impl Fn() -> Ordering,
3814    ) -> Ordering {
3815        let order = if self.limit == 0 {
3816            Ordering::Equal
3817        } else {
3818            // These borrows should never fail, since this struct is non-sync and this function
3819            // is non-recursive.
3820            let mut left_ref = self.left_vec.borrow_mut();
3821            let mut right_ref = self.right_vec.borrow_mut();
3822            let left_cols = left_ref.borrow_with_limit(left_row, self.limit);
3823            let right_cols = right_ref.borrow_with_limit(right_row, self.limit);
3824            compare_columns(self.order.as_ref(), &left_cols, &right_cols, || {
3825                Ordering::Equal
3826            })
3827        };
3828        // Tiebreak without the vecs borrowed, in case that recursively invokes this function.
3829        order.then_with(tiebreaker)
3830    }
3831}
3832
3833/// Compare `left` and `right` using `order`. If that doesn't produce a strict
3834/// ordering, call `tiebreaker`.
3835pub fn compare_columns<F>(
3836    order: &[ColumnOrder],
3837    left: &[Datum],
3838    right: &[Datum],
3839    tiebreaker: F,
3840) -> Ordering
3841where
3842    F: Fn() -> Ordering,
3843{
3844    for order in order {
3845        let cmp = match (&left[order.column], &right[order.column]) {
3846            (Datum::Null, Datum::Null) => Ordering::Equal,
3847            (Datum::Null, _) => {
3848                if order.nulls_last {
3849                    Ordering::Greater
3850                } else {
3851                    Ordering::Less
3852                }
3853            }
3854            (_, Datum::Null) => {
3855                if order.nulls_last {
3856                    Ordering::Less
3857                } else {
3858                    Ordering::Greater
3859                }
3860            }
3861            (lval, rval) => {
3862                if order.desc {
3863                    rval.cmp(lval)
3864                } else {
3865                    lval.cmp(rval)
3866                }
3867            }
3868        };
3869        if cmp != Ordering::Equal {
3870            return cmp;
3871        }
3872    }
3873    tiebreaker()
3874}
3875
/// Describe a window frame, e.g. `RANGE UNBOUNDED PRECEDING` or
/// `ROWS BETWEEN 5 PRECEDING AND CURRENT ROW`.
///
/// Window frames define a subset of the partition, and only a subset of
/// window functions make use of the window frame.
#[derive(
    Debug,
    Clone,
    Eq,
    PartialEq,
    Ord,
    PartialOrd,
    Serialize,
    Deserialize,
    Hash,
    MzReflect
)]
pub struct WindowFrame {
    /// The units in which the bounds are interpreted: ROWS, RANGE or GROUPS
    pub units: WindowFrameUnits,
    /// Where the frame starts
    pub start_bound: WindowFrameBound,
    /// Where the frame ends
    pub end_bound: WindowFrameBound,
}
3901
3902impl Display for WindowFrame {
3903    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
3904        write!(
3905            f,
3906            "{} between {} and {}",
3907            self.units, self.start_bound, self.end_bound
3908        )
3909    }
3910}
3911
3912impl WindowFrame {
3913    /// Return the default window frame used when one is not explicitly defined
3914    pub fn default() -> Self {
3915        WindowFrame {
3916            units: WindowFrameUnits::Range,
3917            start_bound: WindowFrameBound::UnboundedPreceding,
3918            end_bound: WindowFrameBound::CurrentRow,
3919        }
3920    }
3921
3922    fn includes_current_row(&self) -> bool {
3923        use WindowFrameBound::*;
3924        match self.start_bound {
3925            UnboundedPreceding => match self.end_bound {
3926                UnboundedPreceding => false,
3927                OffsetPreceding(0) => true,
3928                OffsetPreceding(_) => false,
3929                CurrentRow => true,
3930                OffsetFollowing(_) => true,
3931                UnboundedFollowing => true,
3932            },
3933            OffsetPreceding(0) => match self.end_bound {
3934                UnboundedPreceding => unreachable!(),
3935                OffsetPreceding(0) => true,
3936                // Any nonzero offsets here will create an empty window
3937                OffsetPreceding(_) => false,
3938                CurrentRow => true,
3939                OffsetFollowing(_) => true,
3940                UnboundedFollowing => true,
3941            },
3942            OffsetPreceding(_) => match self.end_bound {
3943                UnboundedPreceding => unreachable!(),
3944                // Window ends at the current row
3945                OffsetPreceding(0) => true,
3946                OffsetPreceding(_) => false,
3947                CurrentRow => true,
3948                OffsetFollowing(_) => true,
3949                UnboundedFollowing => true,
3950            },
3951            CurrentRow => true,
3952            OffsetFollowing(0) => match self.end_bound {
3953                UnboundedPreceding => unreachable!(),
3954                OffsetPreceding(_) => unreachable!(),
3955                CurrentRow => unreachable!(),
3956                OffsetFollowing(_) => true,
3957                UnboundedFollowing => true,
3958            },
3959            OffsetFollowing(_) => match self.end_bound {
3960                UnboundedPreceding => unreachable!(),
3961                OffsetPreceding(_) => unreachable!(),
3962                CurrentRow => unreachable!(),
3963                OffsetFollowing(_) => false,
3964                UnboundedFollowing => false,
3965            },
3966            UnboundedFollowing => false,
3967        }
3968    }
3969}
3970
/// Describe how frame bounds are interpreted
///
/// The `Display` impl renders each variant as its lowercase keyword
/// (`rows`, `range`, `groups`).
#[derive(
    Debug,
    Clone,
    Eq,
    PartialEq,
    Ord,
    PartialOrd,
    Serialize,
    Deserialize,
    Hash,
    MzReflect
)]
pub enum WindowFrameUnits {
    /// Each row is treated as the unit of work for bounds
    Rows,
    /// Each peer group is treated as the unit of work for bounds,
    /// and offset-based bounds use the value of the ORDER BY expression
    Range,
    /// Each peer group is treated as the unit of work for bounds.
    /// Groups is currently not supported, and it is rejected during planning.
    Groups,
}
3994
3995impl Display for WindowFrameUnits {
3996    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
3997        match self {
3998            WindowFrameUnits::Rows => write!(f, "rows"),
3999            WindowFrameUnits::Range => write!(f, "range"),
4000            WindowFrameUnits::Groups => write!(f, "groups"),
4001        }
4002    }
4003}
4004
/// Specifies [WindowFrame]'s `start_bound` and `end_bound`
///
/// The order between frame bounds is significant, as Postgres enforces
/// some restrictions there.
///
/// The variants are declared from `UNBOUNDED PRECEDING` through to
/// `UNBOUNDED FOLLOWING`, and the derived `Ord` follows that declaration order.
#[derive(
    Debug,
    Clone,
    Serialize,
    Deserialize,
    PartialEq,
    Eq,
    Hash,
    MzReflect,
    PartialOrd,
    Ord
)]
pub enum WindowFrameBound {
    /// `UNBOUNDED PRECEDING`
    UnboundedPreceding,
    /// `<N> PRECEDING`, carrying the offset `N`
    OffsetPreceding(u64),
    /// `CURRENT ROW`
    CurrentRow,
    /// `<N> FOLLOWING`, carrying the offset `N`
    OffsetFollowing(u64),
    /// `UNBOUNDED FOLLOWING`.
    UnboundedFollowing,
}
4033
4034impl Display for WindowFrameBound {
4035    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
4036        match self {
4037            WindowFrameBound::UnboundedPreceding => write!(f, "unbounded preceding"),
4038            WindowFrameBound::OffsetPreceding(offset) => write!(f, "{} preceding", offset),
4039            WindowFrameBound::CurrentRow => write!(f, "current row"),
4040            WindowFrameBound::OffsetFollowing(offset) => write!(f, "{} following", offset),
4041            WindowFrameBound::UnboundedFollowing => write!(f, "unbounded following"),
4042        }
4043    }
4044}
4045
/// Maximum iterations for a LetRec.
#[derive(
    Debug,
    Clone,
    Copy,
    PartialEq,
    Eq,
    PartialOrd,
    Ord,
    Hash,
    Serialize,
    Deserialize
)]
pub struct LetRecLimit {
    /// Maximum number of iterations to evaluate.
    pub max_iters: NonZeroU64,
    /// Whether to return, rather than throw an error, when reaching the above limit.
    /// If true, we simply use the current contents of each Id as the final result.
    pub return_at_limit: bool,
}
4066
4067impl LetRecLimit {
4068    /// Compute the smallest limit from a Vec of `LetRecLimit`s.
4069    pub fn min_max_iter(limits: &Vec<Option<LetRecLimit>>) -> Option<u64> {
4070        limits
4071            .iter()
4072            .filter_map(|l| l.as_ref().map(|l| l.max_iters.get()))
4073            .min()
4074    }
4075
4076    /// The default value of `LetRecLimit::return_at_limit` when using the RECURSION LIMIT option of
4077    /// WMR without ERROR AT or RETURN AT.
4078    pub const RETURN_AT_LIMIT_DEFAULT: bool = false;
4079}
4080
4081impl Display for LetRecLimit {
4082    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
4083        write!(f, "[recursion_limit={}", self.max_iters)?;
4084        if self.return_at_limit != LetRecLimit::RETURN_AT_LIMIT_DEFAULT {
4085            write!(f, ", return_at_limit")?;
4086        }
4087        write!(f, "]")
4088    }
4089}
4090
/// For a global Get, this indicates whether we are going to read from Persist or from an index.
/// (See comment in MirRelationExpr::Get.)
#[derive(
    Clone,
    Debug,
    Eq,
    PartialEq,
    Ord,
    PartialOrd,
    Serialize,
    Deserialize,
    Hash
)]
pub enum AccessStrategy {
    /// It's either a local Get (a CTE), or unknown at the time.
    /// `prune_and_annotate_dataflow_index_imports` decides it for global Gets, and thus switches to
    /// one of the other variants.
    UnknownOrLocal,
    /// The Get will read from Persist.
    Persist,
    /// The Get will read from an index or indexes, given as one
    /// `(index id, how the index will be used)` pair per index.
    Index(Vec<(GlobalId, IndexUsageType)>),
    /// The Get will read a collection that is computed by the same dataflow, but in a different
    /// `BuildDesc` in `objects_to_build`.
    SameDataflow,
}
4117
#[cfg(test)]
mod tests {
    use std::num::NonZeroUsize;

    use mz_repr::explain::text::text_string_at;

    use crate::explain::HumanizedExplain;

    use super::*;

    /// A `RowSetFinishing` renders as a single `Finish ...` line in text explain output.
    #[mz_ore::test]
    fn test_row_set_finishing_as_text() {
        let order_by = vec![ColumnOrder {
            column: 4,
            desc: true,
            nulls_last: true,
        }];
        let finishing = RowSetFinishing {
            order_by,
            limit: Some(NonNeg::try_from(7).unwrap()),
            offset: Default::default(),
            project: vec![1, 3, 4, 5],
        };

        let mode = HumanizedExplain::new(false);
        let expr = mode.expr(&finishing, None);

        let actual = text_string_at(&expr, mz_ore::str::Indent::default);

        // One piece per rendered attribute, followed by the trailing newline.
        let expected = [
            "Finish",
            " order_by=[#4 desc nulls_last]",
            " limit=7",
            " output=[#1, #3..=#5]",
            "\n",
        ]
        .concat();

        assert_eq!(actual, expected);
    }

    /// The `max_returned_query_size` budget is consumed across successive batches, and
    /// the batch that no longer fits is rejected.
    #[mz_ore::test]
    fn test_row_set_finishing_incremental_max_returned_query_size() {
        let row = Row::pack_slice(&[Datum::String("hello")]);
        let row_size = u64::cast_from(row.data().len());
        let diff = NonZeroUsize::new(1).unwrap();
        let batch = RowCollection::new(vec![(row, diff)], &[]);

        // Set max_returned_query_size to hold exactly 2 batches worth of rows.
        let mut finishing = RowSetFinishingIncremental::new(0, None, vec![0], Some(row_size * 2));

        let max_result_size = u64::MAX;

        // The first two batches fit, leaving one row's worth and then nothing.
        for expected_remaining in [row_size, 0] {
            let r = finishing.finish_incremental_inner(batch.clone(), max_result_size);
            assert!(r.is_ok());
            assert_eq!(
                finishing.remaining_max_returned_query_size,
                Some(expected_remaining)
            );
        }

        // The third batch exceeds the budget.
        let r = finishing.finish_incremental_inner(batch, max_result_size);
        assert!(r.unwrap_err().contains("total result exceeds max size"));
    }
}
4184
4185/// An iterator over AST structures, which calls out nodes in difference.
4186///
4187/// The iterators visit two ASTs in tandem, continuing as long as the AST node data matches,
4188/// and yielding an output pair as soon as the AST nodes do not match. Their intent is to call
4189/// attention to the moments in the ASTs where they differ, and incidentally a stack-free way
4190/// to compare two ASTs.
4191mod structured_diff {
4192
4193    use super::MirRelationExpr;
4194    use itertools::Itertools;
4195
4196    ///  An iterator over structured differences between two `MirRelationExpr` instances.
4197    pub struct MreDiff<'a> {
4198        /// Pairs of expressions that must still be compared.
4199        todo: Vec<(&'a MirRelationExpr, &'a MirRelationExpr)>,
4200    }
4201
4202    impl<'a> MreDiff<'a> {
4203        /// Create a new `MirRelationExpr` structured difference.
4204        pub fn new(expr1: &'a MirRelationExpr, expr2: &'a MirRelationExpr) -> Self {
4205            MreDiff {
4206                todo: vec![(expr1, expr2)],
4207            }
4208        }
4209    }
4210
4211    impl<'a> Iterator for MreDiff<'a> {
4212        // Pairs of expressions that do not match.
4213        type Item = (&'a MirRelationExpr, &'a MirRelationExpr);
4214
4215        fn next(&mut self) -> Option<Self::Item> {
4216            while let Some((expr1, expr2)) = self.todo.pop() {
4217                match (expr1, expr2) {
4218                    (
4219                        MirRelationExpr::Constant {
4220                            rows: rows1,
4221                            typ: typ1,
4222                        },
4223                        MirRelationExpr::Constant {
4224                            rows: rows2,
4225                            typ: typ2,
4226                        },
4227                    ) => {
4228                        if rows1 != rows2 || typ1 != typ2 {
4229                            return Some((expr1, expr2));
4230                        }
4231                    }
4232                    (
4233                        MirRelationExpr::Get {
4234                            id: id1,
4235                            typ: typ1,
4236                            access_strategy: as1,
4237                        },
4238                        MirRelationExpr::Get {
4239                            id: id2,
4240                            typ: typ2,
4241                            access_strategy: as2,
4242                        },
4243                    ) => {
4244                        if id1 != id2 || typ1 != typ2 || as1 != as2 {
4245                            return Some((expr1, expr2));
4246                        }
4247                    }
4248                    (
4249                        MirRelationExpr::Let {
4250                            id: id1,
4251                            body: body1,
4252                            value: value1,
4253                        },
4254                        MirRelationExpr::Let {
4255                            id: id2,
4256                            body: body2,
4257                            value: value2,
4258                        },
4259                    ) => {
4260                        if id1 != id2 {
4261                            return Some((expr1, expr2));
4262                        } else {
4263                            self.todo.push((body1, body2));
4264                            self.todo.push((value1, value2));
4265                        }
4266                    }
4267                    (
4268                        MirRelationExpr::LetRec {
4269                            ids: ids1,
4270                            body: body1,
4271                            values: values1,
4272                            limits: limits1,
4273                        },
4274                        MirRelationExpr::LetRec {
4275                            ids: ids2,
4276                            body: body2,
4277                            values: values2,
4278                            limits: limits2,
4279                        },
4280                    ) => {
4281                        if ids1 != ids2 || values1.len() != values2.len() || limits1 != limits2 {
4282                            return Some((expr1, expr2));
4283                        } else {
4284                            self.todo.push((body1, body2));
4285                            self.todo.extend(values1.iter().zip_eq(values2.iter()));
4286                        }
4287                    }
4288                    (
4289                        MirRelationExpr::Project {
4290                            outputs: outputs1,
4291                            input: input1,
4292                        },
4293                        MirRelationExpr::Project {
4294                            outputs: outputs2,
4295                            input: input2,
4296                        },
4297                    ) => {
4298                        if outputs1 != outputs2 {
4299                            return Some((expr1, expr2));
4300                        } else {
4301                            self.todo.push((input1, input2));
4302                        }
4303                    }
4304                    (
4305                        MirRelationExpr::Map {
4306                            scalars: scalars1,
4307                            input: input1,
4308                        },
4309                        MirRelationExpr::Map {
4310                            scalars: scalars2,
4311                            input: input2,
4312                        },
4313                    ) => {
4314                        if scalars1 != scalars2 {
4315                            return Some((expr1, expr2));
4316                        } else {
4317                            self.todo.push((input1, input2));
4318                        }
4319                    }
4320                    (
4321                        MirRelationExpr::Filter {
4322                            predicates: predicates1,
4323                            input: input1,
4324                        },
4325                        MirRelationExpr::Filter {
4326                            predicates: predicates2,
4327                            input: input2,
4328                        },
4329                    ) => {
4330                        if predicates1 != predicates2 {
4331                            return Some((expr1, expr2));
4332                        } else {
4333                            self.todo.push((input1, input2));
4334                        }
4335                    }
4336                    (
4337                        MirRelationExpr::FlatMap {
4338                            input: input1,
4339                            func: func1,
4340                            exprs: exprs1,
4341                        },
4342                        MirRelationExpr::FlatMap {
4343                            input: input2,
4344                            func: func2,
4345                            exprs: exprs2,
4346                        },
4347                    ) => {
4348                        if func1 != func2 || exprs1 != exprs2 {
4349                            return Some((expr1, expr2));
4350                        } else {
4351                            self.todo.push((input1, input2));
4352                        }
4353                    }
4354                    (
4355                        MirRelationExpr::Join {
4356                            inputs: inputs1,
4357                            equivalences: eq1,
4358                            implementation: impl1,
4359                        },
4360                        MirRelationExpr::Join {
4361                            inputs: inputs2,
4362                            equivalences: eq2,
4363                            implementation: impl2,
4364                        },
4365                    ) => {
4366                        if inputs1.len() != inputs2.len() || eq1 != eq2 || impl1 != impl2 {
4367                            return Some((expr1, expr2));
4368                        } else {
4369                            self.todo.extend(inputs1.iter().zip_eq(inputs2.iter()));
4370                        }
4371                    }
4372                    (
4373                        MirRelationExpr::Reduce {
4374                            aggregates: aggregates1,
4375                            input: inputs1,
4376                            group_key: gk1,
4377                            monotonic: m1,
4378                            expected_group_size: egs1,
4379                        },
4380                        MirRelationExpr::Reduce {
4381                            aggregates: aggregates2,
4382                            input: inputs2,
4383                            group_key: gk2,
4384                            monotonic: m2,
4385                            expected_group_size: egs2,
4386                        },
4387                    ) => {
4388                        if aggregates1 != aggregates2 || gk1 != gk2 || m1 != m2 || egs1 != egs2 {
4389                            return Some((expr1, expr2));
4390                        } else {
4391                            self.todo.push((inputs1, inputs2));
4392                        }
4393                    }
4394                    (
4395                        MirRelationExpr::TopK {
4396                            group_key: gk1,
4397                            order_key: order1,
4398                            input: input1,
4399                            limit: l1,
4400                            offset: o1,
4401                            monotonic: m1,
4402                            expected_group_size: egs1,
4403                        },
4404                        MirRelationExpr::TopK {
4405                            group_key: gk2,
4406                            order_key: order2,
4407                            input: input2,
4408                            limit: l2,
4409                            offset: o2,
4410                            monotonic: m2,
4411                            expected_group_size: egs2,
4412                        },
4413                    ) => {
4414                        if order1 != order2
4415                            || gk1 != gk2
4416                            || l1 != l2
4417                            || o1 != o2
4418                            || m1 != m2
4419                            || egs1 != egs2
4420                        {
4421                            return Some((expr1, expr2));
4422                        } else {
4423                            self.todo.push((input1, input2));
4424                        }
4425                    }
4426                    (
4427                        MirRelationExpr::Negate { input: input1 },
4428                        MirRelationExpr::Negate { input: input2 },
4429                    ) => {
4430                        self.todo.push((input1, input2));
4431                    }
4432                    (
4433                        MirRelationExpr::Threshold { input: input1 },
4434                        MirRelationExpr::Threshold { input: input2 },
4435                    ) => {
4436                        self.todo.push((input1, input2));
4437                    }
4438                    (
4439                        MirRelationExpr::Union {
4440                            base: base1,
4441                            inputs: inputs1,
4442                        },
4443                        MirRelationExpr::Union {
4444                            base: base2,
4445                            inputs: inputs2,
4446                        },
4447                    ) => {
4448                        if inputs1.len() != inputs2.len() {
4449                            return Some((expr1, expr2));
4450                        } else {
4451                            self.todo.push((base1, base2));
4452                            self.todo.extend(inputs1.iter().zip_eq(inputs2.iter()));
4453                        }
4454                    }
4455                    (
4456                        MirRelationExpr::ArrangeBy {
4457                            keys: keys1,
4458                            input: input1,
4459                        },
4460                        MirRelationExpr::ArrangeBy {
4461                            keys: keys2,
4462                            input: input2,
4463                        },
4464                    ) => {
4465                        if keys1 != keys2 {
4466                            return Some((expr1, expr2));
4467                        } else {
4468                            self.todo.push((input1, input2));
4469                        }
4470                    }
4471                    _ => {
4472                        return Some((expr1, expr2));
4473                    }
4474                }
4475            }
4476            None
4477        }
4478    }
4479}