mz_expr/
relation.rs

1// Copyright Materialize, Inc. and contributors. All rights reserved.
2//
3// Use of this software is governed by the Business Source License
4// included in the LICENSE file.
5//
6// As of the Change Date specified in that file, in accordance with
7// the Business Source License, use of this software will be governed
8// by the Apache License, Version 2.0.
9
10#![warn(missing_docs)]
11
12use std::cmp::{Ordering, max};
13use std::collections::{BTreeMap, BTreeSet};
14use std::fmt;
15use std::fmt::{Display, Formatter};
16use std::hash::{DefaultHasher, Hash, Hasher};
17use std::num::NonZeroU64;
18use std::time::Instant;
19
20use bytesize::ByteSize;
21use differential_dataflow::containers::{Columnation, CopyRegion};
22use itertools::Itertools;
23use mz_lowertest::MzReflect;
24use mz_ore::cast::CastFrom;
25use mz_ore::collections::CollectionExt;
26use mz_ore::id_gen::IdGen;
27use mz_ore::metrics::Histogram;
28use mz_ore::num::NonNeg;
29use mz_ore::soft_assert_no_log;
30use mz_ore::stack::RecursionLimitError;
31use mz_ore::str::Indent;
32use mz_proto::{IntoRustIfSome, ProtoType, RustType, TryFromProtoError};
33use mz_repr::adt::numeric::NumericMaxScale;
34use mz_repr::explain::text::text_string_at;
35use mz_repr::explain::{
36    DummyHumanizer, ExplainConfig, ExprHumanizer, IndexUsageType, PlanRenderingContext,
37};
38use mz_repr::{
39    ColumnName, ColumnType, Datum, Diff, GlobalId, IntoRowIterator, RelationType, Row, RowIterator,
40    ScalarType,
41};
42use proptest::prelude::{Arbitrary, BoxedStrategy, any};
43use proptest::strategy::{Strategy, Union};
44use proptest_derive::Arbitrary;
45use serde::{Deserialize, Serialize};
46
47use crate::Id::Local;
48use crate::explain::{HumanizedExpr, HumanizerMode};
49use crate::relation::func::{AggregateFunc, LagLeadType, TableFunc};
50use crate::row::{RowCollection, SortedRowCollectionIter};
51use crate::visit::{Visit, VisitChildren};
52use crate::{
53    EvalError, FilterCharacteristics, Id, LocalId, MirScalarExpr, UnaryFunc, VariadicFunc,
54    func as scalar_func,
55};
56
57pub mod canonicalize;
58pub mod func;
59pub mod join_input_mapper;
60
61include!(concat!(env!("OUT_DIR"), "/mz_expr.relation.rs"));
62
/// A recursion limit to be used for stack-safe traversals of [`MirRelationExpr`] trees.
///
/// The recursion limit must be large enough to accommodate the linear representation
/// of some pathological but frequently occurring query fragments.
///
/// For example, in MIR we could have long chains of
/// - (1) `Let` bindings,
/// - (2) `CallBinary` calls with associative functions such as `+`.
///
/// Until we fix those, we need to stick with the larger recursion limit.
pub const RECURSION_LIMIT: usize = 2048;
74
75/// A trait for types that describe how to build a collection.
76pub trait CollectionPlan {
77    /// Collects the set of global identifiers from dataflows referenced in Get.
78    fn depends_on_into(&self, out: &mut BTreeSet<GlobalId>);
79
80    /// Returns the set of global identifiers from dataflows referenced in Get.
81    ///
82    /// See [`CollectionPlan::depends_on_into`] to reuse an existing `BTreeSet`.
83    fn depends_on(&self) -> BTreeSet<GlobalId> {
84        let mut out = BTreeSet::new();
85        self.depends_on_into(&mut out);
86        out
87    }
88}
89
/// An abstract syntax tree which defines a collection.
///
/// The AST is meant to reflect the capabilities of the `differential_dataflow::Collection` type,
/// written generically enough to avoid run-time compilation work.
///
/// `derived_hash_with_manual_eq` was complaining for the wrong reason: This lint exists because
/// it's bad when `Eq` doesn't agree with `Hash`, which is often quite likely if one of them is
/// implemented manually. However, our manual implementation of `Eq` _will_ agree with the derived
/// one. This is because the reason for the manual implementation is not to change the semantics
/// from the derived one, but to avoid stack overflows.
#[allow(clippy::derived_hash_with_manual_eq)]
#[derive(Clone, Debug, Ord, PartialOrd, Serialize, Deserialize, MzReflect, Hash)]
pub enum MirRelationExpr {
    /// A constant relation containing specified rows.
    ///
    /// The runtime memory footprint of this operator is zero.
    ///
    /// When you would like to pattern match on this, consider using `MirRelationExpr::as_const`
    /// instead, which looks behind `ArrangeBy`s. You might want this matching behavior because
    /// constant folding doesn't remove `ArrangeBy`s.
    Constant {
        /// Rows of the constant collection and their multiplicities, or the error that the
        /// collection evaluates to.
        rows: Result<Vec<(Row, Diff)>, EvalError>,
        /// Schema of the collection.
        typ: RelationType,
    },
    /// Get an existing dataflow.
    ///
    /// The runtime memory footprint of this operator is zero.
    Get {
        /// The identifier for the collection to load.
        #[mzreflect(ignore)]
        id: Id,
        /// Schema of the collection.
        typ: RelationType,
        /// If this is a global Get, this will indicate whether we are going to read from Persist or
        /// from an index, or from a different object in `objects_to_build`. If it's an index, then
        /// how downstream dataflow operations will use this index is also recorded. This is filled
        /// by `prune_and_annotate_dataflow_index_imports`. Note that this is not used by the
        /// lowering to LIR, but is used only by EXPLAIN.
        #[mzreflect(ignore)]
        access_strategy: AccessStrategy,
    },
    /// Introduce a temporary dataflow.
    ///
    /// The runtime memory footprint of this operator is zero.
    Let {
        /// The identifier to be used in `Get` variants to retrieve `value`.
        #[mzreflect(ignore)]
        id: LocalId,
        /// The collection to be bound to `id`.
        value: Box<MirRelationExpr>,
        /// The result of the `Let`, evaluated with `id` bound to `value`.
        body: Box<MirRelationExpr>,
    },
    /// Introduce mutually recursive bindings.
    ///
    /// Each `LocalId` is immediately bound to an initially empty collection
    /// with the type of its corresponding `MirRelationExpr`. Repeatedly, each
    /// binding is evaluated using the current contents of each other binding,
    /// and is refreshed to contain the new evaluation. This process continues
    /// through all bindings, and repeats as long as changes continue to occur.
    ///
    /// The resulting value of the expression is `body` evaluated once in the
    /// context of the final iterates.
    ///
    /// A zero-binding instance can be replaced by `body`.
    /// A single-binding instance is equivalent to `MirRelationExpr::Let`.
    ///
    /// The runtime memory footprint of this operator is zero.
    LetRec {
        /// The identifiers to be used in `Get` variants to retrieve each `value`.
        #[mzreflect(ignore)]
        ids: Vec<LocalId>,
        /// The collections to be bound to each identifier in `ids`.
        values: Vec<MirRelationExpr>,
        /// Maximum number of iterations, after which we should artificially force a fixpoint.
        /// (Whether we error or just stop is configured by `LetRecLimit::return_at_limit`.)
        /// The per-`LetRec` limit that the user specified is initially copied to each binding to
        /// accommodate slicing and merging of `LetRec`s in MIR transforms (e.g., `NormalizeLets`).
        #[mzreflect(ignore)]
        limits: Vec<Option<LetRecLimit>>,
        /// The result of the `LetRec`, evaluated once in the context of the final iterates of
        /// the bindings in `ids`.
        body: Box<MirRelationExpr>,
    },
    /// Project out some columns from a dataflow.
    ///
    /// The runtime memory footprint of this operator is zero.
    Project {
        /// The source collection.
        input: Box<MirRelationExpr>,
        /// Indices of columns to retain.
        outputs: Vec<usize>,
    },
    /// Append new columns to a dataflow.
    ///
    /// The runtime memory footprint of this operator is zero.
    Map {
        /// The source collection.
        input: Box<MirRelationExpr>,
        /// Expressions which determine values to append to each row.
        /// An expression may refer to columns in `input` or
        /// expressions defined earlier in the vector.
        scalars: Vec<MirScalarExpr>,
    },
    /// Like Map, but yields zero-or-more output rows per input row.
    ///
    /// The runtime memory footprint of this operator is zero.
    FlatMap {
        /// The source collection.
        input: Box<MirRelationExpr>,
        /// The table func to apply.
        func: TableFunc,
        /// The arguments to the table func.
        exprs: Vec<MirScalarExpr>,
    },
    /// Keep rows from a dataflow where all the predicates are true.
    ///
    /// The runtime memory footprint of this operator is zero.
    Filter {
        /// The source collection.
        input: Box<MirRelationExpr>,
        /// Predicates, each of which must be true.
        predicates: Vec<MirScalarExpr>,
    },
    /// Join several collections, where some columns must be equal.
    ///
    /// For further details consult the documentation for [`MirRelationExpr::join`].
    ///
    /// The runtime memory footprint of this operator can be proportional to
    /// the sizes of all inputs and the size of all joins of prefixes.
    /// This may be reduced due to arrangements available at rendering time.
    Join {
        /// A sequence of input relations.
        inputs: Vec<MirRelationExpr>,
        /// A sequence of equivalence classes of expressions on the cross product of inputs.
        ///
        /// Each equivalence class is a list of scalar expressions, where for each class the
        /// intended interpretation is that all evaluated expressions should be equal.
        ///
        /// Each scalar expression is to be evaluated over the cross-product of all records
        /// from all inputs. In many cases this may just be column selection from specific
        /// inputs, but more general cases exist (e.g. complex functions of multiple columns
        /// from multiple inputs, or just constant literals).
        equivalences: Vec<Vec<MirScalarExpr>>,
        /// Join implementation information.
        #[serde(default)]
        implementation: JoinImplementation,
    },
    /// Group a dataflow by some columns and aggregate over each group.
    ///
    /// The runtime memory footprint of this operator is at most proportional to the
    /// number of distinct records in the input and output. The actual requirements
    /// can be less: the number of distinct inputs to each aggregate, summed across
    /// each aggregate, plus the output size. For more details consult the code that
    /// builds the associated dataflow.
    Reduce {
        /// The source collection.
        input: Box<MirRelationExpr>,
        /// Expressions over the input columns used to form groups.
        group_key: Vec<MirScalarExpr>,
        /// Expressions which determine values to append to each row, after the group keys.
        aggregates: Vec<AggregateExpr>,
        /// True iff the input is known to monotonically increase (only addition of records).
        #[serde(default)]
        monotonic: bool,
        /// User hint: expected number of values per group key. Used to optimize physical rendering.
        #[serde(default)]
        expected_group_size: Option<u64>,
    },
    /// Groups and orders within each group, limiting output.
    ///
    /// The runtime memory footprint of this operator is proportional to its input and output.
    TopK {
        /// The source collection.
        input: Box<MirRelationExpr>,
        /// Column indices used to form groups.
        group_key: Vec<usize>,
        /// Column indices used to order rows within groups.
        order_key: Vec<ColumnOrder>,
        /// Number of records to retain, given as a scalar expression; `None` means no limit
        /// was specified.
        #[serde(default)]
        limit: Option<MirScalarExpr>,
        /// Number of records to skip.
        #[serde(default)]
        offset: usize,
        /// True iff the input is known to monotonically increase (only addition of records).
        #[serde(default)]
        monotonic: bool,
        /// User-supplied hint: how many rows will have the same group key.
        #[serde(default)]
        expected_group_size: Option<u64>,
    },
    /// Return a dataflow where the row counts are negated.
    ///
    /// The runtime memory footprint of this operator is zero.
    Negate {
        /// The source collection.
        input: Box<MirRelationExpr>,
    },
    /// Keep rows from a dataflow where the row counts are positive.
    ///
    /// The runtime memory footprint of this operator is proportional to its input and output.
    Threshold {
        /// The source collection.
        input: Box<MirRelationExpr>,
    },
    /// Adds the frequencies of elements in contained sets.
    ///
    /// The runtime memory footprint of this operator is zero.
    Union {
        /// A source collection.
        base: Box<MirRelationExpr>,
        /// Source collections to union.
        inputs: Vec<MirRelationExpr>,
    },
    /// Technically a no-op. Used to render an index. Will be used to optimize queries
    /// on finer grain. Each `keys` item represents a different index that should be
    /// produced from the `keys`.
    ///
    /// The runtime memory footprint of this operator is proportional to its input.
    ArrangeBy {
        /// The source collection.
        input: Box<MirRelationExpr>,
        /// Columns to arrange `input` by, in order of decreasing primacy.
        keys: Vec<Vec<MirScalarExpr>>,
    },
}
318
319impl PartialEq for MirRelationExpr {
320    fn eq(&self, other: &Self) -> bool {
321        // Capture the result and test it wrt `Ord` implementation in test environments.
322        let result = structured_diff::MreDiff::new(self, other).next().is_none();
323        mz_ore::soft_assert_eq_no_log!(result, self.cmp(other) == Ordering::Equal);
324        result
325    }
326}
327impl Eq for MirRelationExpr {}
328
/// Property-testing support: generates small, random [`MirRelationExpr`] trees.
///
/// Leaves are `Constant` and `Get` expressions; every other variant is produced by the
/// recursive step, which wraps previously generated sub-expressions.
impl Arbitrary for MirRelationExpr {
    type Parameters = ();

    /// Builds the strategy that generates arbitrary expressions. Takes no parameters.
    fn arbitrary_with((): Self::Parameters) -> Self::Strategy {
        // Exclusive upper bound for the lengths of generated input vectors. It is used as
        // `1..VEC_LEN`, a half-open range, so with the value 2 each such vector has exactly
        // one element.
        const VEC_LEN: usize = 2;

        // A strategy for generating the leaf cases.
        let leaf = Union::new([
            // Constant
            (
                any::<Result<Vec<(Row, Diff)>, EvalError>>(),
                any::<RelationType>(),
            )
                .prop_map(|(rows, typ)| MirRelationExpr::Constant { rows, typ })
                .boxed(),
            // Get
            (any::<Id>(), any::<RelationType>(), any::<AccessStrategy>())
                .prop_map(|(id, typ, access_strategy)| MirRelationExpr::Get {
                    id,
                    typ,
                    access_strategy,
                })
                .boxed(),
        ])
        .boxed();

        // The arguments to `prop_recursive` are, in order: the maximum depth, the desired
        // total size, and the expected branch size. All are kept small here.
        // `inner` is the strategy for sub-expressions of the node being built.
        leaf.prop_recursive(2, 2, 2, |inner| {
            Union::new([
                // Let
                (any::<LocalId>(), inner.clone(), inner.clone())
                    .prop_map(|(id, value, body)| MirRelationExpr::Let {
                        id,
                        value: Box::new(value),
                        body: Box::new(body),
                    })
                    .boxed(),
                // LetRec
                // NOTE(review): `ids`, `values` and `limits` are generated independently, so
                // their lengths are not tied to each other here.
                (
                    any::<Vec<LocalId>>(),
                    proptest::collection::vec(inner.clone(), 1..VEC_LEN),
                    any::<Vec<Option<LetRecLimit>>>(),
                    inner.clone(),
                )
                    .prop_map(|(ids, values, limits, body)| MirRelationExpr::LetRec {
                        ids,
                        values,
                        limits,
                        body: Box::new(body),
                    })
                    .boxed(),
                // Project
                (inner.clone(), any::<Vec<usize>>())
                    .prop_map(|(input, outputs)| MirRelationExpr::Project {
                        input: Box::new(input),
                        outputs,
                    })
                    .boxed(),
                // Map
                (inner.clone(), any::<Vec<MirScalarExpr>>())
                    .prop_map(|(input, scalars)| MirRelationExpr::Map {
                        input: Box::new(input),
                        scalars,
                    })
                    .boxed(),
                // FlatMap
                (
                    inner.clone(),
                    any::<TableFunc>(),
                    any::<Vec<MirScalarExpr>>(),
                )
                    .prop_map(|(input, func, exprs)| MirRelationExpr::FlatMap {
                        input: Box::new(input),
                        func,
                        exprs,
                    })
                    .boxed(),
                // Filter
                (inner.clone(), any::<Vec<MirScalarExpr>>())
                    .prop_map(|(input, predicates)| MirRelationExpr::Filter {
                        input: Box::new(input),
                        predicates,
                    })
                    .boxed(),
                // Join
                (
                    proptest::collection::vec(inner.clone(), 1..VEC_LEN),
                    any::<Vec<Vec<MirScalarExpr>>>(),
                    any::<JoinImplementation>(),
                )
                    .prop_map(
                        |(inputs, equivalences, implementation)| MirRelationExpr::Join {
                            inputs,
                            equivalences,
                            implementation,
                        },
                    )
                    .boxed(),
                // Reduce
                (
                    inner.clone(),
                    any::<Vec<MirScalarExpr>>(),
                    any::<Vec<AggregateExpr>>(),
                    any::<bool>(),
                    any::<Option<u64>>(),
                )
                    .prop_map(
                        |(input, group_key, aggregates, monotonic, expected_group_size)| {
                            MirRelationExpr::Reduce {
                                input: Box::new(input),
                                group_key,
                                aggregates,
                                monotonic,
                                expected_group_size,
                            }
                        },
                    )
                    .boxed(),
                // TopK
                (
                    inner.clone(),
                    any::<Vec<usize>>(),
                    any::<Vec<ColumnOrder>>(),
                    any::<Option<MirScalarExpr>>(),
                    any::<usize>(),
                    any::<bool>(),
                    any::<Option<u64>>(),
                )
                    .prop_map(
                        |(
                            input,
                            group_key,
                            order_key,
                            limit,
                            offset,
                            monotonic,
                            expected_group_size,
                        )| MirRelationExpr::TopK {
                            input: Box::new(input),
                            group_key,
                            order_key,
                            limit,
                            offset,
                            monotonic,
                            expected_group_size,
                        },
                    )
                    .boxed(),
                // Negate
                inner
                    .clone()
                    .prop_map(|input| MirRelationExpr::Negate {
                        input: Box::new(input),
                    })
                    .boxed(),
                // Threshold
                inner
                    .clone()
                    .prop_map(|input| MirRelationExpr::Threshold {
                        input: Box::new(input),
                    })
                    .boxed(),
                // Union
                (
                    inner.clone(),
                    proptest::collection::vec(inner.clone(), 1..VEC_LEN),
                )
                    .prop_map(|(base, inputs)| MirRelationExpr::Union {
                        base: Box::new(base),
                        inputs,
                    })
                    .boxed(),
                // ArrangeBy
                (inner.clone(), any::<Vec<Vec<MirScalarExpr>>>())
                    .prop_map(|(input, keys)| MirRelationExpr::ArrangeBy {
                        input: Box::new(input),
                        keys,
                    })
                    .boxed(),
            ])
        })
        .boxed()
    }

    // The strategy is type-erased because it is assembled from many differently-typed parts.
    type Strategy = BoxedStrategy<Self>;
}
512
513impl MirRelationExpr {
514    /// Reports the schema of the relation.
515    ///
516    /// This method determines the type through recursive traversal of the
517    /// relation expression, drawing from the types of base collections.
518    /// As such, this is not an especially cheap method, and should be used
519    /// judiciously.
520    ///
521    /// The relation type is computed incrementally with a recursive post-order
522    /// traversal, that accumulates the input types for the relations yet to be
523    /// visited in `type_stack`.
524    pub fn typ(&self) -> RelationType {
525        let mut type_stack = Vec::new();
526        #[allow(deprecated)]
527        self.visit_pre_post_nolimit(
528            &mut |e: &MirRelationExpr| -> Option<Vec<&MirRelationExpr>> {
529                match &e {
530                    MirRelationExpr::Let { body, .. } => {
531                        // Do not traverse the value sub-graph, since it's not relevant for
532                        // determining the relation type of Let operators.
533                        Some(vec![&*body])
534                    }
535                    MirRelationExpr::LetRec { body, .. } => {
536                        // Do not traverse the value sub-graph, since it's not relevant for
537                        // determining the relation type of Let operators.
538                        Some(vec![&*body])
539                    }
540                    _ => None,
541                }
542            },
543            &mut |e: &MirRelationExpr| {
544                match e {
545                    MirRelationExpr::Let { .. } => {
546                        let body_typ = type_stack.pop().unwrap();
547                        // Insert a dummy relation type for the value, since `typ_with_input_types`
548                        // won't look at it, but expects the relation type of the body to be second.
549                        type_stack.push(RelationType::empty());
550                        type_stack.push(body_typ);
551                    }
552                    MirRelationExpr::LetRec { values, .. } => {
553                        let body_typ = type_stack.pop().unwrap();
554                        // Insert dummy relation types for the values, since `typ_with_input_types`
555                        // won't look at them, but expects the relation type of the body to be last.
556                        type_stack
557                            .extend(std::iter::repeat(RelationType::empty()).take(values.len()));
558                        type_stack.push(body_typ);
559                    }
560                    _ => {}
561                }
562                let num_inputs = e.num_inputs();
563                let relation_type =
564                    e.typ_with_input_types(&type_stack[type_stack.len() - num_inputs..]);
565                type_stack.truncate(type_stack.len() - num_inputs);
566                type_stack.push(relation_type);
567            },
568        );
569        assert_eq!(type_stack.len(), 1);
570        type_stack.pop().unwrap()
571    }
572
573    /// Reports the schema of the relation given the schema of the input relations.
574    ///
575    /// `input_types` is required to contain the schemas for the input relations of
576    /// the current relation in the same order as they are visited by `try_visit_children`
577    /// method, even though not all may be used for computing the schema of the
578    /// current relation. For example, `Let` expects two input types, one for the
579    /// value relation and one for the body, in that order, but only the one for the
580    /// body is used to determine the type of the `Let` relation.
581    ///
582    /// It is meant to be used during post-order traversals to compute relation
583    /// schemas incrementally.
584    pub fn typ_with_input_types(&self, input_types: &[RelationType]) -> RelationType {
585        let column_types = self.col_with_input_cols(input_types.iter().map(|i| &i.column_types));
586        let unique_keys = self.keys_with_input_keys(
587            input_types.iter().map(|i| i.arity()),
588            input_types.iter().map(|i| &i.keys),
589        );
590        RelationType::new(column_types).with_keys(unique_keys)
591    }
592
593    /// Reports the column types of the relation given the column types of the
594    /// input relations.
595    ///
596    /// This method delegates to `try_col_with_input_cols`, panicking if an `Err`
597    /// variant is returned.
598    pub fn col_with_input_cols<'a, I>(&self, input_types: I) -> Vec<ColumnType>
599    where
600        I: Iterator<Item = &'a Vec<ColumnType>>,
601    {
602        match self.try_col_with_input_cols(input_types) {
603            Ok(col_types) => col_types,
604            Err(err) => panic!("{err}"),
605        }
606    }
607
608    /// Reports the column types of the relation given the column types of the input relations.
609    ///
610    /// `input_types` is required to contain the column types for the input relations of
611    /// the current relation in the same order as they are visited by `try_visit_children`
612    /// method, even though not all may be used for computing the schema of the
613    /// current relation. For example, `Let` expects two input types, one for the
614    /// value relation and one for the body, in that order, but only the one for the
615    /// body is used to determine the type of the `Let` relation.
616    ///
617    /// It is meant to be used during post-order traversals to compute column types
618    /// incrementally.
619    pub fn try_col_with_input_cols<'a, I>(
620        &self,
621        mut input_types: I,
622    ) -> Result<Vec<ColumnType>, String>
623    where
624        I: Iterator<Item = &'a Vec<ColumnType>>,
625    {
626        use MirRelationExpr::*;
627
628        let col_types = match self {
629            Constant { rows, typ } => {
630                let mut col_types = typ.column_types.clone();
631                let mut seen_null = vec![false; typ.arity()];
632                if let Ok(rows) = rows {
633                    for (row, _diff) in rows {
634                        for (datum, i) in row.iter().zip_eq(0..typ.arity()) {
635                            if datum.is_null() {
636                                seen_null[i] = true;
637                            }
638                        }
639                    }
640                }
641                for (&seen_null, i) in seen_null.iter().zip_eq(0..typ.arity()) {
642                    if !seen_null {
643                        col_types[i].nullable = false;
644                    } else {
645                        assert!(col_types[i].nullable);
646                    }
647                }
648                col_types
649            }
650            Get { typ, .. } => typ.column_types.clone(),
651            Project { outputs, .. } => {
652                let input = input_types.next().unwrap();
653                outputs.iter().map(|&i| input[i].clone()).collect()
654            }
655            Map { scalars, .. } => {
656                let mut result = input_types.next().unwrap().clone();
657                for scalar in scalars.iter() {
658                    result.push(scalar.typ(&result))
659                }
660                result
661            }
662            FlatMap { func, .. } => {
663                let mut result = input_types.next().unwrap().clone();
664                result.extend(func.output_type().column_types);
665                result
666            }
667            Filter { predicates, .. } => {
668                let mut result = input_types.next().unwrap().clone();
669
670                // Set as nonnull any columns where null values would cause
671                // any predicate to evaluate to null.
672                for column in non_nullable_columns(predicates) {
673                    result[column].nullable = false;
674                }
675                result
676            }
677            Join { equivalences, .. } => {
678                // Concatenate input column types
679                let mut types = input_types.flat_map(|cols| cols.to_owned()).collect_vec();
680                // In an equivalence class, if any column is non-null, then make all non-null
681                for equivalence in equivalences {
682                    let col_inds = equivalence
683                        .iter()
684                        .filter_map(|expr| match expr {
685                            MirScalarExpr::Column(col, _name) => Some(*col),
686                            _ => None,
687                        })
688                        .collect_vec();
689                    if col_inds.iter().any(|i| !types.get(*i).unwrap().nullable) {
690                        for i in col_inds {
691                            types.get_mut(i).unwrap().nullable = false;
692                        }
693                    }
694                }
695                types
696            }
697            Reduce {
698                group_key,
699                aggregates,
700                ..
701            } => {
702                let input = input_types.next().unwrap();
703                group_key
704                    .iter()
705                    .map(|e| e.typ(input))
706                    .chain(aggregates.iter().map(|agg| agg.typ(input)))
707                    .collect()
708            }
709            TopK { .. } | Negate { .. } | Threshold { .. } | ArrangeBy { .. } => {
710                input_types.next().unwrap().clone()
711            }
712            Let { .. } => {
713                // skip over the input types for `value`.
714                input_types.nth(1).unwrap().clone()
715            }
716            LetRec { values, .. } => {
717                // skip over the input types for `values`.
718                input_types.nth(values.len()).unwrap().clone()
719            }
720            Union { .. } => {
721                let mut result = input_types.next().unwrap().clone();
722                for input_col_types in input_types {
723                    for (base_col, col) in result.iter_mut().zip_eq(input_col_types) {
724                        *base_col = base_col
725                            .union(col)
726                            .map_err(|e| format!("{}\nin plan:\n{}", e, self.pretty()))?;
727                    }
728                }
729                result
730            }
731        };
732
733        Ok(col_types)
734    }
735
    /// Reports the unique keys of the relation given the arities and the unique
    /// keys of the input relations.
    ///
    /// `input_arities` and `input_keys` are required to contain the
    /// corresponding info for the input relations of the current relation in
    /// the same order as they are visited by the `try_visit_children` method,
    /// even though not all may be used for computing the keys of the current
    /// relation. For example, `Let` expects two sets of input keys, one for the
    /// value relation and one for the body, in that order, but only the one for
    /// the body is used to determine the keys of the `Let` relation.
    ///
    /// Only the `Map`, `Join` and `Union` cases read from `input_arities`; all
    /// other cases rely on `input_keys` (or on the relation itself) alone.
    ///
    /// Each returned key is a list of column indexes. The list of keys is
    /// sorted and deduplicated before it is returned.
    ///
    /// It is meant to be used during post-order traversals to compute unique keys
    /// incrementally.
    ///
    /// # Panics
    ///
    /// Panics if `input_arities` or `input_keys` yields fewer items than the
    /// variant at hand consumes.
    pub fn keys_with_input_keys<'a, I, J>(
        &self,
        mut input_arities: I,
        mut input_keys: J,
    ) -> Vec<Vec<usize>>
    where
        I: Iterator<Item = usize>,
        J: Iterator<Item = &'a Vec<Vec<usize>>>,
    {
        use MirRelationExpr::*;

        let mut keys = match self {
            Constant {
                rows: Ok(rows),
                typ,
            } => {
                let n_cols = typ.arity();
                // If the `i`th entry is `Some`, then we have not yet observed non-uniqueness in the `i`th column.
                let mut unique_values_per_col = vec![Some(BTreeSet::<Datum>::default()); n_cols];
                for (row, diff) in rows {
                    for (i, datum) in row.iter().enumerate() {
                        // `Datum::Dummy` values are skipped: they are neither
                        // recorded nor do they disqualify the column.
                        if datum != Datum::Dummy {
                            if let Some(unique_vals) = &mut unique_values_per_col[i] {
                                // A row whose multiplicity is not one, or a value
                                // that was already recorded, rules the column out.
                                let is_dupe = *diff != Diff::ONE || !unique_vals.insert(datum);
                                if is_dupe {
                                    unique_values_per_col[i] = None;
                                }
                            }
                        }
                    }
                }
                // With no rows at all, or with exactly one row of multiplicity
                // one, the empty column list is reported as the only key.
                if rows.len() == 0 || (rows.len() == 1 && rows[0].1 == Diff::ONE) {
                    vec![vec![]]
                } else {
                    // XXX - Multi-column keys are not detected.
                    // Report the keys declared on the type, plus one
                    // single-column key per column that stayed unique.
                    typ.keys
                        .iter()
                        .cloned()
                        .chain(
                            unique_values_per_col
                                .into_iter()
                                .enumerate()
                                .filter(|(_idx, unique_vals)| unique_vals.is_some())
                                .map(|(idx, _)| vec![idx]),
                        )
                        .collect()
                }
            }
            // Error constants and `Get`s report the keys recorded in their type.
            Constant { rows: Err(_), typ } | Get { typ, .. } => typ.keys.clone(),
            // These operators pass the keys of their single input through.
            Threshold { .. } | ArrangeBy { .. } => input_keys.next().unwrap().clone(),
            Let { .. } => {
                // skip over the unique keys for value
                input_keys.nth(1).unwrap().clone()
            }
            LetRec { values, .. } => {
                // skip over the unique keys for all of the values
                input_keys.nth(values.len()).unwrap().clone()
            }
            Project { outputs, .. } => {
                let input = input_keys.next().unwrap();
                // Keep only the keys whose columns are all projected, and
                // rewrite each column to the first output position holding it.
                input
                    .iter()
                    .filter_map(|key_set| {
                        if key_set.iter().all(|k| outputs.contains(k)) {
                            Some(
                                key_set
                                    .iter()
                                    .map(|c| outputs.iter().position(|o| o == c).unwrap())
                                    .collect(),
                            )
                        } else {
                            None
                        }
                    })
                    .collect()
            }
            Map { scalars, .. } => {
                // Pairs of `(input column, mapped output column)` for scalars
                // that preserve the uniqueness of a single input column.
                let mut remappings = Vec::new();
                let arity = input_arities.next().unwrap();
                for (column, scalar) in scalars.iter().enumerate() {
                    // assess whether the scalar preserves uniqueness,
                    // and could participate in a key!

                    // Returns the referenced column if `expr` is a column
                    // reference wrapped in zero or more unary functions that
                    // all report `preserves_uniqueness()`.
                    fn uniqueness(expr: &MirScalarExpr) -> Option<usize> {
                        match expr {
                            MirScalarExpr::CallUnary { func, expr } => {
                                if func.preserves_uniqueness() {
                                    uniqueness(expr)
                                } else {
                                    None
                                }
                            }
                            MirScalarExpr::Column(c, _name) => Some(*c),
                            _ => None,
                        }
                    }

                    if let Some(c) = uniqueness(scalar) {
                        // Mapped columns are appended after the input columns.
                        remappings.push((c, column + arity));
                    }
                }

                let mut result = input_keys.next().unwrap().clone();
                let mut new_keys = Vec::new();
                // Any column in `remappings` could be replaced in a key
                // by the corresponding c. This could lead to combinatorial
                // explosion using our current representation, so we won't
                // do that. Instead, we'll handle the case of one remapping.
                if remappings.len() == 1 {
                    let (old, new) = remappings.pop().unwrap();
                    for key in &result {
                        if key.contains(&old) {
                            let mut new_key: Vec<usize> =
                                key.iter().cloned().filter(|k| k != &old).collect();
                            new_key.push(new);
                            new_key.sort_unstable();
                            new_keys.push(new_key);
                        }
                    }
                    result.append(&mut new_keys);
                }
                result
            }
            FlatMap { .. } => {
                // FlatMap can add duplicate rows, so input keys are no longer
                // valid
                vec![]
            }
            Negate { .. } => {
                // Although negate may have distinct records for each key,
                // the multiplicity is -1 rather than 1. This breaks many
                // of the optimization uses of "keys".
                vec![]
            }
            Filter { predicates, .. } => {
                // A filter inherits the keys of its input unless the filters
                // have reduced the input to a single row, in which case the
                // keys of the input are `()`.
                let mut input = input_keys.next().unwrap().clone();

                if !input.is_empty() {
                    // Track columns equated to literals, which we can prune.
                    let mut cols_equal_to_literal = BTreeSet::new();

                    // Perform union find on `col1 = col2` to establish
                    // connected components of equated columns. Absent any
                    // equalities, this will be `0 .. #c` (where #c is the
                    // greatest column referenced by a predicate), but each
                    // equality will orient the root of the greater to the root
                    // of the lesser.
                    let mut union_find = Vec::new();

                    // Only top-level `a = b` predicates are inspected.
                    for expr in predicates.iter() {
                        if let MirScalarExpr::CallBinary {
                            func: crate::BinaryFunc::Eq,
                            expr1,
                            expr2,
                        } = expr
                        {
                            // `column = literal`, in either operand order.
                            if let MirScalarExpr::Column(c, _name) = &**expr1 {
                                if expr2.is_literal_ok() {
                                    cols_equal_to_literal.insert(c);
                                }
                            }
                            if let MirScalarExpr::Column(c, _name) = &**expr2 {
                                if expr1.is_literal_ok() {
                                    cols_equal_to_literal.insert(c);
                                }
                            }
                            // Perform union-find to equate columns.
                            if let (Some(c1), Some(c2)) = (expr1.as_column(), expr2.as_column()) {
                                if c1 != c2 {
                                    // Ensure union_find has entries up to
                                    // max(c1, c2) by filling up missing
                                    // positions with identity mappings.
                                    while union_find.len() <= std::cmp::max(c1, c2) {
                                        union_find.push(union_find.len());
                                    }
                                    let mut r1 = c1; // Find the representative column of [c1].
                                    while r1 != union_find[r1] {
                                        assert!(union_find[r1] < r1);
                                        r1 = union_find[r1];
                                    }
                                    let mut r2 = c2; // Find the representative column of [c2].
                                    while r2 != union_find[r2] {
                                        assert!(union_find[r2] < r2);
                                        r2 = union_find[r2];
                                    }
                                    // Union [c1] and [c2] by pointing the
                                    // larger to the smaller representative (we
                                    // update the remaining equivalence class
                                    // members only once after this for-loop).
                                    union_find[std::cmp::max(r1, r2)] = std::cmp::min(r1, r2);
                                }
                            }
                        }
                    }

                    // Complete union-find by pointing each element at its representative column.
                    for i in 0..union_find.len() {
                        // Iteration not required, as each prior already references the right column.
                        union_find[i] = union_find[union_find[i]];
                    }

                    // Remove columns bound to literals, and remap columns equated to earlier columns.
                    // We will re-expand remapped columns in a moment, but this avoids exponential work.
                    for key_set in &mut input {
                        key_set.retain(|k| !cols_equal_to_literal.contains(&k));
                        for col in key_set.iter_mut() {
                            // Columns beyond the end of `union_find` were never
                            // equated to anything and are left untouched.
                            if let Some(equiv) = union_find.get(*col) {
                                *col = *equiv;
                            }
                        }
                        key_set.sort();
                        key_set.dedup();
                    }
                    input.sort();
                    input.dedup();

                    // Expand out each key to each of its equivalent forms.
                    // Each instance of `col` can be replaced by any equivalent column.
                    // This has the potential to result in exponentially sized number of unique keys,
                    // and in the future we should probably maintain unique keys modulo equivalence.

                    // First, compute an inverse map from each representative
                    // column `sub` to all other equivalent columns `col`.
                    let mut subs = Vec::new();
                    for (col, sub) in union_find.iter().enumerate() {
                        if *sub != col {
                            assert!(*sub < col);
                            while subs.len() <= *sub {
                                subs.push(Vec::new());
                            }
                            subs[*sub].push(col);
                        }
                    }
                    // For each column, substitute for it in each occurrence.
                    let mut to_add = Vec::new();
                    for (col, subs) in subs.iter().enumerate() {
                        if !subs.is_empty() {
                            for key_set in input.iter() {
                                if key_set.contains(&col) {
                                    let mut to_extend = key_set.clone();
                                    to_extend.retain(|c| c != &col);
                                    // Push each substitute in turn, record the
                                    // resulting key, then pop it again.
                                    for sub in subs {
                                        to_extend.push(*sub);
                                        to_add.push(to_extend.clone());
                                        to_extend.pop();
                                    }
                                }
                            }
                        }
                        // No deduplication, as we cannot introduce duplicates.
                        // Appending inside the loop makes the keys added for
                        // this column visible to later columns.
                        input.append(&mut to_add);
                    }
                    for key_set in input.iter_mut() {
                        key_set.sort();
                        key_set.dedup();
                    }
                }
                input
            }
            Join { equivalences, .. } => {
                // It is important the `new_from_input_arities` constructor is
                // used. Otherwise, Materialize may potentially end up in an
                // infinite loop.
                let input_mapper = crate::JoinInputMapper::new_from_input_arities(input_arities);

                input_mapper.global_keys(input_keys, equivalences)
            }
            Reduce { group_key, .. } => {
                // The group key should form a key, but we might already have
                // keys that are subsets of the group key, and should retain
                // those instead, if so.
                let mut result = Vec::new();
                for key_set in input_keys.next().unwrap() {
                    if key_set
                        .iter()
                        .all(|k| group_key.contains(&MirScalarExpr::column(*k)))
                    {
                        // Rewrite each input column to its position within
                        // `group_key`.
                        result.push(
                            key_set
                                .iter()
                                .map(|i| {
                                    group_key
                                        .iter()
                                        .position(|k| k == &MirScalarExpr::column(*i))
                                        .unwrap()
                                })
                                .collect::<Vec<_>>(),
                        );
                    }
                }
                // Fall back to the full group key if no input key survived.
                if result.is_empty() {
                    result.push((0..group_key.len()).collect());
                }
                result
            }
            TopK {
                group_key, limit, ..
            } => {
                // If `limit` is `Some(1)` then the group key will become
                // a unique key, as there will be only one record with that key.
                let mut result = input_keys.next().unwrap().clone();
                if limit.as_ref().and_then(|x| x.as_literal_int64()) == Some(1) {
                    result.push(group_key.clone())
                }
                result
            }
            Union { base, inputs } => {
                // Generally, unions do not have any unique keys, because
                // each input might duplicate some. However, there is at
                // least one idiomatic structure that does preserve keys,
                // which results from SQL aggregations that must populate
                // absent records with default values. In that pattern,
                // the union of one GET with its negation, which has first
                // been subjected to a projection and map, we can remove
                // their influence on the key structure.
                //
                // If there are A, B, each with a unique `key` such that
                // we are looking at
                //
                //     A.proj(set_containing_key) + (B - A.proj(key)).map(stuff)
                //
                // Then we can report `key` as a unique key.
                //
                // TODO: make unique key structure an optimization analysis
                // rather than part of the type information.
                // TODO: perhaps ensure that (above) A.proj(key) is a
                // subset of B, as otherwise there are negative records
                // and who knows what is true (not expected, but again
                // who knows what the query plan might look like).

                let arity = input_arities.next().unwrap();
                let (base_projection, base_with_project_stripped) =
                    if let MirRelationExpr::Project { input, outputs } = &**base {
                        (outputs.clone(), &**input)
                    } else {
                        // An input without a project is equivalent to an input
                        // with the project being all columns in the input in order.
                        ((0..arity).collect::<Vec<_>>(), &**base)
                    };
                let mut result = Vec::new();
                // Match the shape described above: `base` is a (possibly
                // projected) `Get`, and the only other input is a `Map` over a
                // two-input `Union` whose base is a negated projection of a
                // `Get` with the same id.
                if let MirRelationExpr::Get {
                    id: first_id,
                    typ: _,
                    ..
                } = base_with_project_stripped
                {
                    if inputs.len() == 1 {
                        if let MirRelationExpr::Map { input, .. } = &inputs[0] {
                            if let MirRelationExpr::Union { base, inputs } = &**input {
                                if inputs.len() == 1 {
                                    if let Some((input, outputs)) = base.is_negated_project() {
                                        if let MirRelationExpr::Get {
                                            id: second_id,
                                            typ: _,
                                            ..
                                        } = input
                                        {
                                            if first_id == second_id {
                                                // Keep the keys of the first
                                                // input whose columns sit at
                                                // their own index in both
                                                // projections.
                                                result.extend(
                                                    input_keys
                                                        .next()
                                                        .unwrap()
                                                        .into_iter()
                                                        .filter(|key| {
                                                            key.iter().all(|c| {
                                                                outputs.get(*c) == Some(c)
                                                                    && base_projection.get(*c)
                                                                        == Some(c)
                                                            })
                                                        })
                                                        .cloned(),
                                                );
                                            }
                                        }
                                    }
                                }
                            }
                        }
                    }
                }
                // Important: do not inherit keys of either input, as not unique.
                result
            }
        };
        // Normalize the result: sorted, with duplicate keys removed.
        keys.sort();
        keys.dedup();
        keys
    }
1140
1141    /// The number of columns in the relation.
1142    ///
1143    /// This number is determined from the type, which is determined recursively
1144    /// at non-trivial cost.
1145    ///
1146    /// The arity is computed incrementally with a recursive post-order
1147    /// traversal, that accumulates the arities for the relations yet to be
1148    /// visited in `arity_stack`.
1149    pub fn arity(&self) -> usize {
1150        let mut arity_stack = Vec::new();
1151        #[allow(deprecated)]
1152        self.visit_pre_post_nolimit(
1153            &mut |e: &MirRelationExpr| -> Option<Vec<&MirRelationExpr>> {
1154                match &e {
1155                    MirRelationExpr::Let { body, .. } => {
1156                        // Do not traverse the value sub-graph, since it's not relevant for
1157                        // determining the arity of Let operators.
1158                        Some(vec![&*body])
1159                    }
1160                    MirRelationExpr::LetRec { body, .. } => {
1161                        // Do not traverse the value sub-graph, since it's not relevant for
1162                        // determining the arity of Let operators.
1163                        Some(vec![&*body])
1164                    }
1165                    MirRelationExpr::Project { .. } | MirRelationExpr::Reduce { .. } => {
1166                        // No further traversal is required; these operators know their arity.
1167                        Some(Vec::new())
1168                    }
1169                    _ => None,
1170                }
1171            },
1172            &mut |e: &MirRelationExpr| {
1173                match &e {
1174                    MirRelationExpr::Let { .. } => {
1175                        let body_arity = arity_stack.pop().unwrap();
1176                        arity_stack.push(0);
1177                        arity_stack.push(body_arity);
1178                    }
1179                    MirRelationExpr::LetRec { values, .. } => {
1180                        let body_arity = arity_stack.pop().unwrap();
1181                        arity_stack.extend(std::iter::repeat(0).take(values.len()));
1182                        arity_stack.push(body_arity);
1183                    }
1184                    MirRelationExpr::Project { .. } | MirRelationExpr::Reduce { .. } => {
1185                        arity_stack.push(0);
1186                    }
1187                    _ => {}
1188                }
1189                let num_inputs = e.num_inputs();
1190                let input_arities = arity_stack.drain(arity_stack.len() - num_inputs..);
1191                let arity = e.arity_with_input_arities(input_arities);
1192                arity_stack.push(arity);
1193            },
1194        );
1195        assert_eq!(arity_stack.len(), 1);
1196        arity_stack.pop().unwrap()
1197    }
1198
1199    /// Reports the arity of the relation given the schema of the input relations.
1200    ///
1201    /// `input_arities` is required to contain the arities for the input relations of
1202    /// the current relation in the same order as they are visited by `try_visit_children`
1203    /// method, even though not all may be used for computing the schema of the
1204    /// current relation. For example, `Let` expects two input types, one for the
1205    /// value relation and one for the body, in that order, but only the one for the
1206    /// body is used to determine the type of the `Let` relation.
1207    ///
1208    /// It is meant to be used during post-order traversals to compute arities
1209    /// incrementally.
1210    pub fn arity_with_input_arities<I>(&self, mut input_arities: I) -> usize
1211    where
1212        I: Iterator<Item = usize>,
1213    {
1214        use MirRelationExpr::*;
1215
1216        match self {
1217            Constant { rows: _, typ } => typ.arity(),
1218            Get { typ, .. } => typ.arity(),
1219            Let { .. } => {
1220                input_arities.next();
1221                input_arities.next().unwrap()
1222            }
1223            LetRec { values, .. } => {
1224                for _ in 0..values.len() {
1225                    input_arities.next();
1226                }
1227                input_arities.next().unwrap()
1228            }
1229            Project { outputs, .. } => outputs.len(),
1230            Map { scalars, .. } => input_arities.next().unwrap() + scalars.len(),
1231            FlatMap { func, .. } => input_arities.next().unwrap() + func.output_arity(),
1232            Join { .. } => input_arities.sum(),
1233            Reduce {
1234                input: _,
1235                group_key,
1236                aggregates,
1237                ..
1238            } => group_key.len() + aggregates.len(),
1239            Filter { .. }
1240            | TopK { .. }
1241            | Negate { .. }
1242            | Threshold { .. }
1243            | Union { .. }
1244            | ArrangeBy { .. } => input_arities.next().unwrap(),
1245        }
1246    }
1247
1248    /// The number of child relations this relation has.
1249    pub fn num_inputs(&self) -> usize {
1250        let mut count = 0;
1251
1252        self.visit_children(|_| count += 1);
1253
1254        count
1255    }
1256
1257    /// Constructs a constant collection from specific rows and schema, where
1258    /// each row will have a multiplicity of one.
1259    pub fn constant(rows: Vec<Vec<Datum>>, typ: RelationType) -> Self {
1260        let rows = rows.into_iter().map(|row| (row, Diff::ONE)).collect();
1261        MirRelationExpr::constant_diff(rows, typ)
1262    }
1263
1264    /// Constructs a constant collection from specific rows and schema, where
1265    /// each row can have an arbitrary multiplicity.
1266    pub fn constant_diff(rows: Vec<(Vec<Datum>, Diff)>, typ: RelationType) -> Self {
1267        for (row, _diff) in &rows {
1268            for (datum, column_typ) in row.iter().zip(typ.column_types.iter()) {
1269                assert!(
1270                    datum.is_instance_of(column_typ),
1271                    "Expected datum of type {:?}, got value {:?}",
1272                    column_typ,
1273                    datum
1274                );
1275            }
1276        }
1277        let rows = Ok(rows
1278            .into_iter()
1279            .map(move |(row, diff)| (Row::pack_slice(&row), diff))
1280            .collect());
1281        MirRelationExpr::Constant { rows, typ }
1282    }
1283
1284    /// If self is a constant, return the value and the type, otherwise `None`.
1285    /// Looks behind `ArrangeBy`s.
1286    pub fn as_const(&self) -> Option<(&Result<Vec<(Row, Diff)>, EvalError>, &RelationType)> {
1287        match self {
1288            MirRelationExpr::Constant { rows, typ } => Some((rows, typ)),
1289            MirRelationExpr::ArrangeBy { input, .. } => input.as_const(),
1290            _ => None,
1291        }
1292    }
1293
1294    /// If self is a constant, mutably return the value and the type, otherwise `None`.
1295    /// Looks behind `ArrangeBy`s.
1296    pub fn as_const_mut(
1297        &mut self,
1298    ) -> Option<(&mut Result<Vec<(Row, Diff)>, EvalError>, &mut RelationType)> {
1299        match self {
1300            MirRelationExpr::Constant { rows, typ } => Some((rows, typ)),
1301            MirRelationExpr::ArrangeBy { input, .. } => input.as_const_mut(),
1302            _ => None,
1303        }
1304    }
1305
1306    /// If self is a constant error, return the error, otherwise `None`.
1307    /// Looks behind `ArrangeBy`s.
1308    pub fn as_const_err(&self) -> Option<&EvalError> {
1309        match self {
1310            MirRelationExpr::Constant { rows: Err(e), .. } => Some(e),
1311            MirRelationExpr::ArrangeBy { input, .. } => input.as_const_err(),
1312            _ => None,
1313        }
1314    }
1315
1316    /// Checks if `self` is the single element collection with no columns.
1317    pub fn is_constant_singleton(&self) -> bool {
1318        if let Some((Ok(rows), typ)) = self.as_const() {
1319            rows.len() == 1 && typ.column_types.len() == 0 && rows[0].1 == Diff::ONE
1320        } else {
1321            false
1322        }
1323    }
1324
1325    /// Constructs the expression for getting a local collection.
1326    pub fn local_get(id: LocalId, typ: RelationType) -> Self {
1327        MirRelationExpr::Get {
1328            id: Id::Local(id),
1329            typ,
1330            access_strategy: AccessStrategy::UnknownOrLocal,
1331        }
1332    }
1333
1334    /// Constructs the expression for getting a global collection
1335    pub fn global_get(id: GlobalId, typ: RelationType) -> Self {
1336        MirRelationExpr::Get {
1337            id: Id::Global(id),
1338            typ,
1339            access_strategy: AccessStrategy::UnknownOrLocal,
1340        }
1341    }
1342
1343    /// Retains only the columns specified by `output`.
1344    pub fn project(mut self, mut outputs: Vec<usize>) -> Self {
1345        if let MirRelationExpr::Project {
1346            outputs: columns, ..
1347        } = &mut self
1348        {
1349            // Update `outputs` to reference base columns of `input`.
1350            for column in outputs.iter_mut() {
1351                *column = columns[*column];
1352            }
1353            *columns = outputs;
1354            self
1355        } else {
1356            MirRelationExpr::Project {
1357                input: Box::new(self),
1358                outputs,
1359            }
1360        }
1361    }
1362
1363    /// Append to each row the results of applying elements of `scalar`.
1364    pub fn map(mut self, scalars: Vec<MirScalarExpr>) -> Self {
1365        if let MirRelationExpr::Map { scalars: s, .. } = &mut self {
1366            s.extend(scalars);
1367            self
1368        } else if !scalars.is_empty() {
1369            MirRelationExpr::Map {
1370                input: Box::new(self),
1371                scalars,
1372            }
1373        } else {
1374            self
1375        }
1376    }
1377
1378    /// Append to each row a single `scalar`.
1379    pub fn map_one(self, scalar: MirScalarExpr) -> Self {
1380        self.map(vec![scalar])
1381    }
1382
1383    /// Like `map`, but yields zero-or-more output rows per input row
1384    pub fn flat_map(self, func: TableFunc, exprs: Vec<MirScalarExpr>) -> Self {
1385        MirRelationExpr::FlatMap {
1386            input: Box::new(self),
1387            func,
1388            exprs,
1389        }
1390    }
1391
1392    /// Retain only the rows satisfying each of several predicates.
1393    pub fn filter<I>(mut self, predicates: I) -> Self
1394    where
1395        I: IntoIterator<Item = MirScalarExpr>,
1396    {
1397        // Extract existing predicates
1398        let mut new_predicates = if let MirRelationExpr::Filter { input, predicates } = self {
1399            self = *input;
1400            predicates
1401        } else {
1402            Vec::new()
1403        };
1404        // Normalize collection of predicates.
1405        new_predicates.extend(predicates);
1406        new_predicates.retain(|p| !p.is_literal_true());
1407        new_predicates.sort();
1408        new_predicates.dedup();
1409        // Introduce a `Filter` only if we have predicates.
1410        if !new_predicates.is_empty() {
1411            self = MirRelationExpr::Filter {
1412                input: Box::new(self),
1413                predicates: new_predicates,
1414            };
1415        }
1416
1417        self
1418    }
1419
1420    /// Form the Cartesian outer-product of rows in both inputs.
1421    pub fn product(mut self, right: Self) -> Self {
1422        if right.is_constant_singleton() {
1423            self
1424        } else if self.is_constant_singleton() {
1425            right
1426        } else if let MirRelationExpr::Join { inputs, .. } = &mut self {
1427            inputs.push(right);
1428            self
1429        } else {
1430            MirRelationExpr::join(vec![self, right], vec![])
1431        }
1432    }
1433
1434    /// Performs a relational equijoin among the input collections.
1435    ///
1436    /// The sequence `inputs` each describe different input collections, and the sequence `variables` describes
1437    /// equality constraints that some of their columns must satisfy. Each element in `variable` describes a set
1438    /// of pairs  `(input_index, column_index)` where every value described by that set must be equal.
1439    ///
1440    /// For example, the pair `(input, column)` indexes into `inputs[input][column]`, extracting the `input`th
1441    /// input collection and for each row examining its `column`th column.
1442    ///
1443    /// # Example
1444    ///
1445    /// ```rust
1446    /// use mz_repr::{Datum, ColumnType, RelationType, ScalarType};
1447    /// use mz_expr::MirRelationExpr;
1448    ///
1449    /// // A common schema for each input.
1450    /// let schema = RelationType::new(vec![
1451    ///     ScalarType::Int32.nullable(false),
1452    ///     ScalarType::Int32.nullable(false),
1453    /// ]);
1454    ///
1455    /// // the specific data are not important here.
1456    /// let data = vec![Datum::Int32(0), Datum::Int32(1)];
1457    ///
1458    /// // Three collections that could have been different.
1459    /// let input0 = MirRelationExpr::constant(vec![data.clone()], schema.clone());
1460    /// let input1 = MirRelationExpr::constant(vec![data.clone()], schema.clone());
1461    /// let input2 = MirRelationExpr::constant(vec![data.clone()], schema.clone());
1462    ///
1463    /// // Join the three relations looking for triangles, like so.
1464    /// //
1465    /// //     Output(A,B,C) := Input0(A,B), Input1(B,C), Input2(A,C)
1466    /// let joined = MirRelationExpr::join(
1467    ///     vec![input0, input1, input2],
1468    ///     vec![
1469    ///         vec![(0,0), (2,0)], // fields A of inputs 0 and 2.
1470    ///         vec![(0,1), (1,0)], // fields B of inputs 0 and 1.
1471    ///         vec![(1,1), (2,1)], // fields C of inputs 1 and 2.
1472    ///     ],
1473    /// );
1474    ///
1475    /// // Technically the above produces `Output(A,B,B,C,A,C)` because the columns are concatenated.
1476    /// // A projection resolves this and produces the correct output.
1477    /// let result = joined.project(vec![0, 1, 3]);
1478    /// ```
1479    pub fn join(inputs: Vec<MirRelationExpr>, variables: Vec<Vec<(usize, usize)>>) -> Self {
1480        let input_mapper = join_input_mapper::JoinInputMapper::new(&inputs);
1481
1482        let equivalences = variables
1483            .into_iter()
1484            .map(|vs| {
1485                vs.into_iter()
1486                    .map(|(r, c)| input_mapper.map_expr_to_global(MirScalarExpr::column(c), r))
1487                    .collect::<Vec<_>>()
1488            })
1489            .collect::<Vec<_>>();
1490
1491        Self::join_scalars(inputs, equivalences)
1492    }
1493
1494    /// Constructs a join operator from inputs and required-equal scalar expressions.
1495    pub fn join_scalars(
1496        mut inputs: Vec<MirRelationExpr>,
1497        equivalences: Vec<Vec<MirScalarExpr>>,
1498    ) -> Self {
1499        // Remove all constant inputs that are the identity for join.
1500        // They neither introduce nor modify any column references.
1501        inputs.retain(|i| !i.is_constant_singleton());
1502        MirRelationExpr::Join {
1503            inputs,
1504            equivalences,
1505            implementation: JoinImplementation::Unimplemented,
1506        }
1507    }
1508
1509    /// Perform a key-wise reduction / aggregation.
1510    ///
1511    /// The `group_key` argument indicates columns in the input collection that should
1512    /// be grouped, and `aggregates` lists aggregation functions each of which produces
1513    /// one output column in addition to the keys.
1514    pub fn reduce(
1515        self,
1516        group_key: Vec<usize>,
1517        aggregates: Vec<AggregateExpr>,
1518        expected_group_size: Option<u64>,
1519    ) -> Self {
1520        MirRelationExpr::Reduce {
1521            input: Box::new(self),
1522            group_key: group_key.into_iter().map(MirScalarExpr::column).collect(),
1523            aggregates,
1524            monotonic: false,
1525            expected_group_size,
1526        }
1527    }
1528
1529    /// Perform a key-wise reduction order by and limit.
1530    ///
1531    /// The `group_key` argument indicates columns in the input collection that should
1532    /// be grouped, the `order_key` argument indicates columns that should be further
1533    /// used to order records within groups, and the `limit` argument constrains the
1534    /// total number of records that should be produced in each group.
1535    pub fn top_k(
1536        self,
1537        group_key: Vec<usize>,
1538        order_key: Vec<ColumnOrder>,
1539        limit: Option<MirScalarExpr>,
1540        offset: usize,
1541        expected_group_size: Option<u64>,
1542    ) -> Self {
1543        MirRelationExpr::TopK {
1544            input: Box::new(self),
1545            group_key,
1546            order_key,
1547            limit,
1548            offset,
1549            expected_group_size,
1550            monotonic: false,
1551        }
1552    }
1553
1554    /// Negates the occurrences of each row.
1555    pub fn negate(self) -> Self {
1556        if let MirRelationExpr::Negate { input } = self {
1557            *input
1558        } else {
1559            MirRelationExpr::Negate {
1560                input: Box::new(self),
1561            }
1562        }
1563    }
1564
1565    /// Removes all but the first occurrence of each row.
1566    pub fn distinct(self) -> Self {
1567        let arity = self.arity();
1568        self.distinct_by((0..arity).collect())
1569    }
1570
1571    /// Removes all but the first occurrence of each key. Columns not included
1572    /// in the `group_key` are discarded.
1573    pub fn distinct_by(self, group_key: Vec<usize>) -> Self {
1574        self.reduce(group_key, vec![], None)
1575    }
1576
1577    /// Discards rows with a negative frequency.
1578    pub fn threshold(self) -> Self {
1579        if let MirRelationExpr::Threshold { .. } = &self {
1580            self
1581        } else {
1582            MirRelationExpr::Threshold {
1583                input: Box::new(self),
1584            }
1585        }
1586    }
1587
1588    /// Unions together any number inputs.
1589    ///
1590    /// If `inputs` is empty, then an empty relation of type `typ` is
1591    /// constructed.
1592    pub fn union_many(mut inputs: Vec<Self>, typ: RelationType) -> Self {
1593        // Deconstruct `inputs` as `Union`s and reconstitute.
1594        let mut flat_inputs = Vec::with_capacity(inputs.len());
1595        for input in inputs {
1596            if let MirRelationExpr::Union { base, inputs } = input {
1597                flat_inputs.push(*base);
1598                flat_inputs.extend(inputs);
1599            } else {
1600                flat_inputs.push(input);
1601            }
1602        }
1603        inputs = flat_inputs;
1604        if inputs.len() == 0 {
1605            MirRelationExpr::Constant {
1606                rows: Ok(vec![]),
1607                typ,
1608            }
1609        } else if inputs.len() == 1 {
1610            inputs.into_element()
1611        } else {
1612            MirRelationExpr::Union {
1613                base: Box::new(inputs.remove(0)),
1614                inputs,
1615            }
1616        }
1617    }
1618
1619    /// Produces one collection where each row is present with the sum of its frequencies in each input.
1620    pub fn union(self, other: Self) -> Self {
1621        // Deconstruct `self` and `other` as `Union`s and reconstitute.
1622        let mut flat_inputs = Vec::with_capacity(2);
1623        if let MirRelationExpr::Union { base, inputs } = self {
1624            flat_inputs.push(*base);
1625            flat_inputs.extend(inputs);
1626        } else {
1627            flat_inputs.push(self);
1628        }
1629        if let MirRelationExpr::Union { base, inputs } = other {
1630            flat_inputs.push(*base);
1631            flat_inputs.extend(inputs);
1632        } else {
1633            flat_inputs.push(other);
1634        }
1635
1636        MirRelationExpr::Union {
1637            base: Box::new(flat_inputs.remove(0)),
1638            inputs: flat_inputs,
1639        }
1640    }
1641
1642    /// Arranges the collection by the specified columns
1643    pub fn arrange_by(self, keys: &[Vec<MirScalarExpr>]) -> Self {
1644        MirRelationExpr::ArrangeBy {
1645            input: Box::new(self),
1646            keys: keys.to_owned(),
1647        }
1648    }
1649
1650    /// Indicates if this is a constant empty collection.
1651    ///
1652    /// A false value does not mean the collection is known to be non-empty,
1653    /// only that we cannot currently determine that it is statically empty.
1654    pub fn is_empty(&self) -> bool {
1655        if let Some((Ok(rows), ..)) = self.as_const() {
1656            rows.is_empty()
1657        } else {
1658            false
1659        }
1660    }
1661
1662    /// If the expression is a negated project, return the input and the projection.
1663    pub fn is_negated_project(&self) -> Option<(&MirRelationExpr, &[usize])> {
1664        if let MirRelationExpr::Negate { input } = self {
1665            if let MirRelationExpr::Project { input, outputs } = &**input {
1666                return Some((&**input, outputs));
1667            }
1668        }
1669        if let MirRelationExpr::Project { input, outputs } = self {
1670            if let MirRelationExpr::Negate { input } = &**input {
1671                return Some((&**input, outputs));
1672            }
1673        }
1674        None
1675    }
1676
1677    /// Pretty-print this [MirRelationExpr] to a string.
1678    pub fn pretty(&self) -> String {
1679        let config = ExplainConfig::default();
1680        self.explain(&config, None)
1681    }
1682
1683    /// Pretty-print this [MirRelationExpr] to a string using a custom
1684    /// [ExplainConfig] and an optionally provided [ExprHumanizer].
1685    pub fn explain(&self, config: &ExplainConfig, humanizer: Option<&dyn ExprHumanizer>) -> String {
1686        text_string_at(self, || PlanRenderingContext {
1687            indent: Indent::default(),
1688            humanizer: humanizer.unwrap_or(&DummyHumanizer),
1689            annotations: BTreeMap::default(),
1690            config,
1691        })
1692    }
1693
1694    /// Take ownership of `self`, leaving an empty `MirRelationExpr::Constant` with the optionally
1695    /// given scalar types. The given scalar types should be `base_eq` with the types that `typ()`
1696    /// would find. Keys and nullability are ignored in the given `RelationType`, and instead we set
1697    /// the best possible key and nullability, since we are making an empty collection.
1698    ///
1699    /// If `typ` is not given, then this calls `.typ()` (which is possibly expensive) to determine
1700    /// the correct type.
1701    pub fn take_safely(&mut self, typ: Option<RelationType>) -> MirRelationExpr {
1702        if let Some(typ) = &typ {
1703            soft_assert_no_log!(
1704                self.typ()
1705                    .column_types
1706                    .iter()
1707                    .zip_eq(typ.column_types.iter())
1708                    .all(|(t1, t2)| t1.scalar_type.base_eq(&t2.scalar_type))
1709            );
1710        }
1711        let mut typ = typ.unwrap_or_else(|| self.typ());
1712        typ.keys = vec![vec![]];
1713        for ct in typ.column_types.iter_mut() {
1714            ct.nullable = false;
1715        }
1716        std::mem::replace(
1717            self,
1718            MirRelationExpr::Constant {
1719                rows: Ok(vec![]),
1720                typ,
1721            },
1722        )
1723    }
1724
1725    /// Take ownership of `self`, leaving an empty `MirRelationExpr::Constant` with the given scalar
1726    /// types. Nullability is ignored in the given `ColumnType`s, and instead we set the best
1727    /// possible nullability, since we are making an empty collection.
1728    pub fn take_safely_with_col_types(&mut self, typ: Vec<ColumnType>) -> MirRelationExpr {
1729        self.take_safely(Some(RelationType::new(typ)))
1730    }
1731
1732    /// Take ownership of `self`, leaving an empty `MirRelationExpr::Constant` with an **incorrect** type.
1733    ///
1734    /// This should only be used if `self` is about to be dropped or otherwise overwritten.
1735    pub fn take_dangerous(&mut self) -> MirRelationExpr {
1736        let empty = MirRelationExpr::Constant {
1737            rows: Ok(vec![]),
1738            typ: RelationType::new(Vec::new()),
1739        };
1740        std::mem::replace(self, empty)
1741    }
1742
1743    /// Replaces `self` with some logic applied to `self`.
1744    pub fn replace_using<F>(&mut self, logic: F)
1745    where
1746        F: FnOnce(MirRelationExpr) -> MirRelationExpr,
1747    {
1748        let empty = MirRelationExpr::Constant {
1749            rows: Ok(vec![]),
1750            typ: RelationType::new(Vec::new()),
1751        };
1752        let expr = std::mem::replace(self, empty);
1753        *self = logic(expr);
1754    }
1755
1756    /// Store `self` in a `Let` and pass the corresponding `Get` to `body`.
1757    pub fn let_in<Body, E>(self, id_gen: &mut IdGen, body: Body) -> Result<MirRelationExpr, E>
1758    where
1759        Body: FnOnce(&mut IdGen, MirRelationExpr) -> Result<MirRelationExpr, E>,
1760    {
1761        if let MirRelationExpr::Get { .. } = self {
1762            // already done
1763            body(id_gen, self)
1764        } else {
1765            let id = LocalId::new(id_gen.allocate_id());
1766            let get = MirRelationExpr::Get {
1767                id: Id::Local(id),
1768                typ: self.typ(),
1769                access_strategy: AccessStrategy::UnknownOrLocal,
1770            };
1771            let body = (body)(id_gen, get)?;
1772            Ok(MirRelationExpr::Let {
1773                id,
1774                value: Box::new(self),
1775                body: Box::new(body),
1776            })
1777        }
1778    }
1779
1780    /// Return every row in `self` that does not have a matching row in the first columns of `keys_and_values`, using `default` to fill in the remaining columns
1781    /// (If `default` is a row of nulls, this is the 'outer' part of LEFT OUTER JOIN)
1782    pub fn anti_lookup<E>(
1783        self,
1784        id_gen: &mut IdGen,
1785        keys_and_values: MirRelationExpr,
1786        default: Vec<(Datum, ScalarType)>,
1787    ) -> Result<MirRelationExpr, E> {
1788        let (data, column_types): (Vec<_>, Vec<_>) = default
1789            .into_iter()
1790            .map(|(datum, scalar_type)| (datum, scalar_type.nullable(datum.is_null())))
1791            .unzip();
1792        assert_eq!(keys_and_values.arity() - self.arity(), data.len());
1793        self.let_in(id_gen, |_id_gen, get_keys| {
1794            let get_keys_arity = get_keys.arity();
1795            Ok(MirRelationExpr::join(
1796                vec![
1797                    // all the missing keys (with count 1)
1798                    keys_and_values
1799                        .distinct_by((0..get_keys_arity).collect())
1800                        .negate()
1801                        .union(get_keys.clone().distinct()),
1802                    // join with keys to get the correct counts
1803                    get_keys.clone(),
1804                ],
1805                (0..get_keys_arity).map(|i| vec![(0, i), (1, i)]).collect(),
1806            )
1807            // get rid of the extra copies of columns from keys
1808            .project((0..get_keys_arity).collect())
1809            // This join is logically equivalent to
1810            // `.map(<default_expr>)`, but using a join allows for
1811            // potential predicate pushdown and elision in the
1812            // optimizer.
1813            .product(MirRelationExpr::constant(
1814                vec![data],
1815                RelationType::new(column_types),
1816            )))
1817        })
1818    }
1819
1820    /// Return:
1821    /// * every row in keys_and_values
1822    /// * every row in `self` that does not have a matching row in the first columns of
1823    ///   `keys_and_values`, using `default` to fill in the remaining columns
1824    /// (This is LEFT OUTER JOIN if:
1825    /// 1) `default` is a row of null
1826    /// 2) matching rows in `keys_and_values` and `self` have the same multiplicity.)
1827    pub fn lookup<E>(
1828        self,
1829        id_gen: &mut IdGen,
1830        keys_and_values: MirRelationExpr,
1831        default: Vec<(Datum<'static>, ScalarType)>,
1832    ) -> Result<MirRelationExpr, E> {
1833        keys_and_values.let_in(id_gen, |id_gen, get_keys_and_values| {
1834            Ok(get_keys_and_values.clone().union(self.anti_lookup(
1835                id_gen,
1836                get_keys_and_values,
1837                default,
1838            )?))
1839        })
1840    }
1841
1842    /// True iff the expression contains a `NullaryFunc::MzLogicalTimestamp`.
1843    pub fn contains_temporal(&self) -> bool {
1844        let mut contains = false;
1845        self.visit_scalars(&mut |e| contains = contains || e.contains_temporal());
1846        contains
1847    }
1848
1849    /// Fallible visitor for the [`MirScalarExpr`]s directly owned by this relation expression.
1850    ///
1851    /// The `f` visitor should not recursively descend into owned [`MirRelationExpr`]s.
1852    pub fn try_visit_scalars_mut1<F, E>(&mut self, f: &mut F) -> Result<(), E>
1853    where
1854        F: FnMut(&mut MirScalarExpr) -> Result<(), E>,
1855    {
1856        use MirRelationExpr::*;
1857        match self {
1858            Map { scalars, .. } => {
1859                for s in scalars {
1860                    f(s)?;
1861                }
1862            }
1863            Filter { predicates, .. } => {
1864                for p in predicates {
1865                    f(p)?;
1866                }
1867            }
1868            FlatMap { exprs, .. } => {
1869                for expr in exprs {
1870                    f(expr)?;
1871                }
1872            }
1873            Join {
1874                inputs: _,
1875                equivalences,
1876                implementation,
1877            } => {
1878                for equivalence in equivalences {
1879                    for expr in equivalence {
1880                        f(expr)?;
1881                    }
1882                }
1883                match implementation {
1884                    JoinImplementation::Differential((_, start_key, _), order) => {
1885                        if let Some(start_key) = start_key {
1886                            for k in start_key {
1887                                f(k)?;
1888                            }
1889                        }
1890                        for (_, lookup_key, _) in order {
1891                            for k in lookup_key {
1892                                f(k)?;
1893                            }
1894                        }
1895                    }
1896                    JoinImplementation::DeltaQuery(paths) => {
1897                        for path in paths {
1898                            for (_, lookup_key, _) in path {
1899                                for k in lookup_key {
1900                                    f(k)?;
1901                                }
1902                            }
1903                        }
1904                    }
1905                    JoinImplementation::IndexedFilter(_coll_id, _idx_id, index_key, _) => {
1906                        for k in index_key {
1907                            f(k)?;
1908                        }
1909                    }
1910                    JoinImplementation::Unimplemented => {} // No scalar exprs
1911                }
1912            }
1913            ArrangeBy { keys, .. } => {
1914                for key in keys {
1915                    for s in key {
1916                        f(s)?;
1917                    }
1918                }
1919            }
1920            Reduce {
1921                group_key,
1922                aggregates,
1923                ..
1924            } => {
1925                for s in group_key {
1926                    f(s)?;
1927                }
1928                for agg in aggregates {
1929                    f(&mut agg.expr)?;
1930                }
1931            }
1932            TopK { limit, .. } => {
1933                if let Some(s) = limit {
1934                    f(s)?;
1935                }
1936            }
1937            Constant { .. }
1938            | Get { .. }
1939            | Let { .. }
1940            | LetRec { .. }
1941            | Project { .. }
1942            | Negate { .. }
1943            | Threshold { .. }
1944            | Union { .. } => (),
1945        }
1946        Ok(())
1947    }
1948
1949    /// Fallible mutable visitor for the [`MirScalarExpr`]s in the [`MirRelationExpr`] subtree
1950    /// rooted at `self`.
1951    ///
1952    /// Note that this does not recurse into [`MirRelationExpr`] subtrees within [`MirScalarExpr`]
1953    /// nodes.
1954    pub fn try_visit_scalars_mut<F, E>(&mut self, f: &mut F) -> Result<(), E>
1955    where
1956        F: FnMut(&mut MirScalarExpr) -> Result<(), E>,
1957        E: From<RecursionLimitError>,
1958    {
1959        self.try_visit_mut_post(&mut |expr| expr.try_visit_scalars_mut1(f))
1960    }
1961
1962    /// Infallible mutable visitor for the [`MirScalarExpr`]s in the [`MirRelationExpr`] subtree
1963    /// rooted at `self`.
1964    ///
1965    /// Note that this does not recurse into [`MirRelationExpr`] subtrees within [`MirScalarExpr`]
1966    /// nodes.
1967    pub fn visit_scalars_mut<F>(&mut self, f: &mut F)
1968    where
1969        F: FnMut(&mut MirScalarExpr),
1970    {
1971        self.try_visit_scalars_mut(&mut |s| {
1972            f(s);
1973            Ok::<_, RecursionLimitError>(())
1974        })
1975        .expect("Unexpected error in `visit_scalars_mut` call");
1976    }
1977
1978    /// Fallible visitor for the [`MirScalarExpr`]s directly owned by this relation expression.
1979    ///
1980    /// The `f` visitor should not recursively descend into owned [`MirRelationExpr`]s.
1981    pub fn try_visit_scalars_1<F, E>(&self, f: &mut F) -> Result<(), E>
1982    where
1983        F: FnMut(&MirScalarExpr) -> Result<(), E>,
1984    {
1985        use MirRelationExpr::*;
1986        match self {
1987            Map { scalars, .. } => {
1988                for s in scalars {
1989                    f(s)?;
1990                }
1991            }
1992            Filter { predicates, .. } => {
1993                for p in predicates {
1994                    f(p)?;
1995                }
1996            }
1997            FlatMap { exprs, .. } => {
1998                for expr in exprs {
1999                    f(expr)?;
2000                }
2001            }
2002            Join {
2003                inputs: _,
2004                equivalences,
2005                implementation,
2006            } => {
2007                for equivalence in equivalences {
2008                    for expr in equivalence {
2009                        f(expr)?;
2010                    }
2011                }
2012                match implementation {
2013                    JoinImplementation::Differential((_, start_key, _), order) => {
2014                        if let Some(start_key) = start_key {
2015                            for k in start_key {
2016                                f(k)?;
2017                            }
2018                        }
2019                        for (_, lookup_key, _) in order {
2020                            for k in lookup_key {
2021                                f(k)?;
2022                            }
2023                        }
2024                    }
2025                    JoinImplementation::DeltaQuery(paths) => {
2026                        for path in paths {
2027                            for (_, lookup_key, _) in path {
2028                                for k in lookup_key {
2029                                    f(k)?;
2030                                }
2031                            }
2032                        }
2033                    }
2034                    JoinImplementation::IndexedFilter(_coll_id, _idx_id, index_key, _) => {
2035                        for k in index_key {
2036                            f(k)?;
2037                        }
2038                    }
2039                    JoinImplementation::Unimplemented => {} // No scalar exprs
2040                }
2041            }
2042            ArrangeBy { keys, .. } => {
2043                for key in keys {
2044                    for s in key {
2045                        f(s)?;
2046                    }
2047                }
2048            }
2049            Reduce {
2050                group_key,
2051                aggregates,
2052                ..
2053            } => {
2054                for s in group_key {
2055                    f(s)?;
2056                }
2057                for agg in aggregates {
2058                    f(&agg.expr)?;
2059                }
2060            }
2061            TopK { limit, .. } => {
2062                if let Some(s) = limit {
2063                    f(s)?;
2064                }
2065            }
2066            Constant { .. }
2067            | Get { .. }
2068            | Let { .. }
2069            | LetRec { .. }
2070            | Project { .. }
2071            | Negate { .. }
2072            | Threshold { .. }
2073            | Union { .. } => (),
2074        }
2075        Ok(())
2076    }
2077
2078    /// Fallible immutable visitor for the [`MirScalarExpr`]s in the [`MirRelationExpr`] subtree
2079    /// rooted at `self`.
2080    ///
2081    /// Note that this does not recurse into [`MirRelationExpr`] subtrees within [`MirScalarExpr`]
2082    /// nodes.
2083    pub fn try_visit_scalars<F, E>(&self, f: &mut F) -> Result<(), E>
2084    where
2085        F: FnMut(&MirScalarExpr) -> Result<(), E>,
2086        E: From<RecursionLimitError>,
2087    {
2088        self.try_visit_post(&mut |expr| expr.try_visit_scalars_1(f))
2089    }
2090
2091    /// Infallible immutable visitor for the [`MirScalarExpr`]s in the [`MirRelationExpr`] subtree
2092    /// rooted at `self`.
2093    ///
2094    /// Note that this does not recurse into [`MirRelationExpr`] subtrees within [`MirScalarExpr`]
2095    /// nodes.
2096    pub fn visit_scalars<F>(&self, f: &mut F)
2097    where
2098        F: FnMut(&MirScalarExpr),
2099    {
2100        self.try_visit_scalars(&mut |s| {
2101            f(s);
2102            Ok::<_, RecursionLimitError>(())
2103        })
2104        .expect("Unexpected error in `visit_scalars` call");
2105    }
2106
2107    /// Clears the contents of `self` even if it's so deep that simply dropping it would cause a
2108    /// stack overflow in `drop_in_place`.
2109    ///
2110    /// Leaves `self` in an unusable state, so this should only be used if `self` is about to be
2111    /// dropped or otherwise overwritten.
2112    pub fn destroy_carefully(&mut self) {
2113        let mut todo = vec![self.take_dangerous()];
2114        while let Some(mut expr) = todo.pop() {
2115            for child in expr.children_mut() {
2116                todo.push(child.take_dangerous());
2117            }
2118        }
2119    }
2120
2121    /// Computes the size (total number of nodes) and maximum depth of a MirRelationExpr for
2122    /// debug printing purposes.
2123    pub fn debug_size_and_depth(&self) -> (usize, usize) {
2124        let mut size = 0;
2125        let mut max_depth = 0;
2126        let mut todo = vec![(self, 1)];
2127        while let Some((expr, depth)) = todo.pop() {
2128            size += 1;
2129            max_depth = max(max_depth, depth);
2130            todo.extend(expr.children().map(|c| (c, depth + 1)));
2131        }
2132        (size, max_depth)
2133    }
2134
2135    /// The MirRelationExpr is considered potentially expensive if and only if
2136    /// at least one of the following conditions is true:
2137    ///
2138    ///  - It contains at least one FlatMap or a Reduce operator.
2139    ///  - It contains at least one MirScalarExpr with a function call.
2140    ///
2141    /// !!!WARNING!!!: this method has an HirRelationExpr counterpart. The two
2142    /// should be kept in sync w.r.t. HIR ⇒ MIR lowering!
2143    pub fn could_run_expensive_function(&self) -> bool {
2144        let mut result = false;
2145        self.visit_pre(|e: &MirRelationExpr| {
2146            use MirRelationExpr::*;
2147            use MirScalarExpr::*;
2148            if let Err(_) = self.try_visit_scalars::<_, RecursionLimitError>(&mut |scalar| {
2149                result |= match scalar {
2150                    Column(_, _) | Literal(_, _) | CallUnmaterializable(_) | If { .. } => false,
2151                    // Function calls are considered expensive
2152                    CallUnary { .. } | CallBinary { .. } | CallVariadic { .. } => true,
2153                };
2154                Ok(())
2155            }) {
2156                // Conservatively set `true` if on RecursionLimitError.
2157                result = true;
2158            }
2159            // FlatMap has a table function; Reduce has an aggregate function.
2160            // Other constructs use MirScalarExpr to run a function
2161            result |= matches!(e, FlatMap { .. } | Reduce { .. });
2162        });
2163        result
2164    }
2165
2166    /// Hash to an u64 using Rust's default Hasher. (Which is a somewhat slower, but better Hasher
2167    /// than what `Hashable::hashed` would give us.)
2168    pub fn hash_to_u64(&self) -> u64 {
2169        let mut h = DefaultHasher::new();
2170        self.hash(&mut h);
2171        h.finish()
2172    }
2173}
2174
2175// `LetRec` helpers
2176impl MirRelationExpr {
2177    /// True when `expr` contains a `LetRec` AST node.
2178    pub fn is_recursive(self: &MirRelationExpr) -> bool {
2179        let mut worklist = vec![self];
2180        while let Some(expr) = worklist.pop() {
2181            if let MirRelationExpr::LetRec { .. } = expr {
2182                return true;
2183            }
2184            worklist.extend(expr.children());
2185        }
2186        false
2187    }
2188
2189    /// Return the number of sub-expressions in the tree (including self).
2190    pub fn size(&self) -> usize {
2191        let mut size = 0;
2192        self.visit_pre(|_| size += 1);
2193        size
2194    }
2195
2196    /// Given the ids and values of a LetRec, it computes the subset of ids that are used across
2197    /// iterations. These are those ids that have a reference before they are defined, when reading
2198    /// all the bindings in order.
2199    ///
2200    /// For example:
2201    /// ```SQL
2202    /// WITH MUTUALLY RECURSIVE
2203    ///     x(...) AS f(z),
2204    ///     y(...) AS g(x),
2205    ///     z(...) AS h(y)
2206    /// ...;
2207    /// ```
2208    /// Here, only `z` is returned, because `x` and `y` are referenced only within the same
2209    /// iteration.
2210    ///
2211    /// Note that if a binding references itself, that is also returned.
2212    pub fn recursive_ids(ids: &[LocalId], values: &[MirRelationExpr]) -> BTreeSet<LocalId> {
2213        let mut used_across_iterations = BTreeSet::new();
2214        let mut defined = BTreeSet::new();
2215        for (binding_id, value) in itertools::zip_eq(ids.iter(), values.iter()) {
2216            value.visit_pre(|expr| {
2217                if let MirRelationExpr::Get {
2218                    id: Local(get_id), ..
2219                } = expr
2220                {
2221                    // If we haven't seen a definition for it yet, then this will refer
2222                    // to the previous iteration.
2223                    // The `ids.contains` part of the condition is needed to exclude
2224                    // those ids that are not really in this LetRec, but either an inner
2225                    // or outer one.
2226                    if !defined.contains(get_id) && ids.contains(get_id) {
2227                        used_across_iterations.insert(*get_id);
2228                    }
2229                }
2230            });
2231            defined.insert(*binding_id);
2232        }
2233        used_across_iterations
2234    }
2235
2236    /// Replaces `LetRec` nodes with a stack of `Let` nodes.
2237    ///
2238    /// In each `Let` binding, uses of `Get` in `value` that are not at strictly greater
2239    /// identifiers are rewritten to be the constant collection.
2240    /// This makes the computation perform exactly "one" iteration.
2241    ///
2242    /// This was used only temporarily while developing `LetRec`.
2243    pub fn make_nonrecursive(self: &mut MirRelationExpr) {
2244        let mut deadlist = BTreeSet::new();
2245        let mut worklist = vec![self];
2246        while let Some(expr) = worklist.pop() {
2247            if let MirRelationExpr::LetRec {
2248                ids,
2249                values,
2250                limits: _,
2251                body,
2252            } = expr
2253            {
2254                let ids_values = values
2255                    .drain(..)
2256                    .zip(ids)
2257                    .map(|(value, id)| (*id, value))
2258                    .collect::<Vec<_>>();
2259                *expr = body.take_dangerous();
2260                for (id, mut value) in ids_values.into_iter().rev() {
2261                    // Remove references to potentially recursive identifiers.
2262                    deadlist.insert(id);
2263                    value.visit_pre_mut(|e| {
2264                        if let MirRelationExpr::Get {
2265                            id: crate::Id::Local(id),
2266                            typ,
2267                            ..
2268                        } = e
2269                        {
2270                            let typ = typ.clone();
2271                            if deadlist.contains(id) {
2272                                e.take_safely(Some(typ));
2273                            }
2274                        }
2275                    });
2276                    *expr = MirRelationExpr::Let {
2277                        id,
2278                        value: Box::new(value),
2279                        body: Box::new(expr.take_dangerous()),
2280                    };
2281                }
2282                worklist.push(expr);
2283            } else {
2284                worklist.extend(expr.children_mut().rev());
2285            }
2286        }
2287    }
2288
2289    /// For each Id `id'` referenced in `expr`, if it is larger or equal than `id`, then record in
2290    /// `expire_whens` that when `id'` is redefined, then we should expire the information that
2291    /// we are holding about `id`. Call `do_expirations` with `expire_whens` at each Id
2292    /// redefinition.
2293    ///
2294    /// IMPORTANT: Relies on the numbering of Ids to be what `renumber_bindings` gives.
2295    pub fn collect_expirations(
2296        id: LocalId,
2297        expr: &MirRelationExpr,
2298        expire_whens: &mut BTreeMap<LocalId, Vec<LocalId>>,
2299    ) {
2300        expr.visit_pre(|e| {
2301            if let MirRelationExpr::Get {
2302                id: Id::Local(referenced_id),
2303                ..
2304            } = e
2305            {
2306                // The following check needs `renumber_bindings` to have run recently
2307                if referenced_id >= &id {
2308                    expire_whens
2309                        .entry(*referenced_id)
2310                        .or_insert_with(Vec::new)
2311                        .push(id);
2312                }
2313            }
2314        });
2315    }
2316
2317    /// Call this function when `id` is redefined. It modifies `id_infos` by removing information
2318    /// about such Ids whose information depended on the earlier definition of `id`, according to
2319    /// `expire_whens`. Also modifies `expire_whens`: it removes the currently processed entry.
2320    pub fn do_expirations<I>(
2321        redefined_id: LocalId,
2322        expire_whens: &mut BTreeMap<LocalId, Vec<LocalId>>,
2323        id_infos: &mut BTreeMap<LocalId, I>,
2324    ) -> Vec<(LocalId, I)> {
2325        let mut expired_infos = Vec::new();
2326        if let Some(expirations) = expire_whens.remove(&redefined_id) {
2327            for expired_id in expirations.into_iter() {
2328                if let Some(offer) = id_infos.remove(&expired_id) {
2329                    expired_infos.push((expired_id, offer));
2330                }
2331            }
2332        }
2333        expired_infos
2334    }
2335}
2336/// Augment non-nullability of columns, by observing either
2337/// 1. Predicates that explicitly test for null values, and
2338/// 2. Columns that if null would make a predicate be null.
2339pub fn non_nullable_columns(predicates: &[MirScalarExpr]) -> BTreeSet<usize> {
2340    let mut nonnull_required_columns = BTreeSet::new();
2341    for predicate in predicates {
2342        // Add any columns that being null would force the predicate to be null.
2343        // Should that happen, the row would be discarded.
2344        predicate.non_null_requirements(&mut nonnull_required_columns);
2345
2346        /*
2347        Test for explicit checks that a column is non-null.
2348
2349        This analysis is ad hoc, and will miss things:
2350
2351        materialize=> create table a(x int, y int);
2352        CREATE TABLE
2353        materialize=> explain with(types) select x from a where (y=x and y is not null) or x is not null;
2354        Optimized Plan
2355        --------------------------------------------------------------------------------------------------------
2356        Explained Query:                                                                                      +
2357        Project (#0) // { types: "(integer?)" }                                                             +
2358        Filter ((#0) IS NOT NULL OR ((#1) IS NOT NULL AND (#0 = #1))) // { types: "(integer?, integer?)" }+
2359        Get materialize.public.a // { types: "(integer?, integer?)" }                                   +
2360                                                                                  +
2361        Source materialize.public.a                                                                           +
2362        filter=(((#0) IS NOT NULL OR ((#1) IS NOT NULL AND (#0 = #1))))                                     +
2363
2364        (1 row)
2365        */
2366
2367        if let MirScalarExpr::CallUnary {
2368            func: UnaryFunc::Not(scalar_func::Not),
2369            expr,
2370        } = predicate
2371        {
2372            if let MirScalarExpr::CallUnary {
2373                func: UnaryFunc::IsNull(scalar_func::IsNull),
2374                expr,
2375            } = &**expr
2376            {
2377                if let MirScalarExpr::Column(c, _name) = &**expr {
2378                    nonnull_required_columns.insert(*c);
2379                }
2380            }
2381        }
2382    }
2383
2384    nonnull_required_columns
2385}
2386
2387impl CollectionPlan for MirRelationExpr {
2388    // !!!WARNING!!!: this method has an HirRelationExpr counterpart. The two
2389    // should be kept in sync w.r.t. HIR ⇒ MIR lowering!
2390    fn depends_on_into(&self, out: &mut BTreeSet<GlobalId>) {
2391        if let MirRelationExpr::Get {
2392            id: Id::Global(id), ..
2393        } = self
2394        {
2395            out.insert(*id);
2396        }
2397        self.visit_children(|expr| expr.depends_on_into(out))
2398    }
2399}
2400
2401impl MirRelationExpr {
2402    /// Iterates through references to child expressions.
2403    pub fn children(&self) -> impl DoubleEndedIterator<Item = &Self> {
2404        let mut first = None;
2405        let mut second = None;
2406        let mut rest = None;
2407        let mut last = None;
2408
2409        use MirRelationExpr::*;
2410        match self {
2411            Constant { .. } | Get { .. } => (),
2412            Let { value, body, .. } => {
2413                first = Some(&**value);
2414                second = Some(&**body);
2415            }
2416            LetRec { values, body, .. } => {
2417                rest = Some(values);
2418                last = Some(&**body);
2419            }
2420            Project { input, .. }
2421            | Map { input, .. }
2422            | FlatMap { input, .. }
2423            | Filter { input, .. }
2424            | Reduce { input, .. }
2425            | TopK { input, .. }
2426            | Negate { input }
2427            | Threshold { input }
2428            | ArrangeBy { input, .. } => {
2429                first = Some(&**input);
2430            }
2431            Join { inputs, .. } => {
2432                rest = Some(inputs);
2433            }
2434            Union { base, inputs } => {
2435                first = Some(&**base);
2436                rest = Some(inputs);
2437            }
2438        }
2439
2440        first
2441            .into_iter()
2442            .chain(second)
2443            .chain(rest.into_iter().flatten())
2444            .chain(last)
2445    }
2446
2447    /// Iterates through mutable references to child expressions.
2448    pub fn children_mut(&mut self) -> impl DoubleEndedIterator<Item = &mut Self> {
2449        let mut first = None;
2450        let mut second = None;
2451        let mut rest = None;
2452        let mut last = None;
2453
2454        use MirRelationExpr::*;
2455        match self {
2456            Constant { .. } | Get { .. } => (),
2457            Let { value, body, .. } => {
2458                first = Some(&mut **value);
2459                second = Some(&mut **body);
2460            }
2461            LetRec { values, body, .. } => {
2462                rest = Some(values);
2463                last = Some(&mut **body);
2464            }
2465            Project { input, .. }
2466            | Map { input, .. }
2467            | FlatMap { input, .. }
2468            | Filter { input, .. }
2469            | Reduce { input, .. }
2470            | TopK { input, .. }
2471            | Negate { input }
2472            | Threshold { input }
2473            | ArrangeBy { input, .. } => {
2474                first = Some(&mut **input);
2475            }
2476            Join { inputs, .. } => {
2477                rest = Some(inputs);
2478            }
2479            Union { base, inputs } => {
2480                first = Some(&mut **base);
2481                rest = Some(inputs);
2482            }
2483        }
2484
2485        first
2486            .into_iter()
2487            .chain(second)
2488            .chain(rest.into_iter().flatten())
2489            .chain(last)
2490    }
2491
2492    /// Iterative pre-order visitor.
2493    pub fn visit_pre<'a, F: FnMut(&'a Self)>(&'a self, mut f: F) {
2494        let mut worklist = vec![self];
2495        while let Some(expr) = worklist.pop() {
2496            f(expr);
2497            worklist.extend(expr.children().rev());
2498        }
2499    }
2500
2501    /// Iterative pre-order visitor.
2502    pub fn visit_pre_mut<F: FnMut(&mut Self)>(&mut self, mut f: F) {
2503        let mut worklist = vec![self];
2504        while let Some(expr) = worklist.pop() {
2505            f(expr);
2506            worklist.extend(expr.children_mut().rev());
2507        }
2508    }
2509
2510    /// Return a vector of references to the subtrees of this expression
2511    /// in post-visit order (the last element is `&self`).
2512    pub fn post_order_vec(&self) -> Vec<&Self> {
2513        let mut stack = vec![self];
2514        let mut result = vec![];
2515        while let Some(expr) = stack.pop() {
2516            result.push(expr);
2517            stack.extend(expr.children());
2518        }
2519        result.reverse();
2520        result
2521    }
2522}
2523
2524impl VisitChildren<Self> for MirRelationExpr {
2525    fn visit_children<F>(&self, mut f: F)
2526    where
2527        F: FnMut(&Self),
2528    {
2529        for child in self.children() {
2530            f(child)
2531        }
2532    }
2533
2534    fn visit_mut_children<F>(&mut self, mut f: F)
2535    where
2536        F: FnMut(&mut Self),
2537    {
2538        for child in self.children_mut() {
2539            f(child)
2540        }
2541    }
2542
2543    fn try_visit_children<F, E>(&self, mut f: F) -> Result<(), E>
2544    where
2545        F: FnMut(&Self) -> Result<(), E>,
2546        E: From<RecursionLimitError>,
2547    {
2548        for child in self.children() {
2549            f(child)?
2550        }
2551        Ok(())
2552    }
2553
2554    fn try_visit_mut_children<F, E>(&mut self, mut f: F) -> Result<(), E>
2555    where
2556        F: FnMut(&mut Self) -> Result<(), E>,
2557        E: From<RecursionLimitError>,
2558    {
2559        for child in self.children_mut() {
2560            f(child)?
2561        }
2562        Ok(())
2563    }
2564}
2565
/// Specification for an ordering by a column.
///
/// The derived `Ord`/`PartialOrd` compare the fields in declaration order
/// (`column`, then `desc`, then `nulls_last`), so the field order below is
/// significant.
#[derive(
    Arbitrary,
    Debug,
    Clone,
    Copy,
    Eq,
    PartialEq,
    Ord,
    PartialOrd,
    Serialize,
    Deserialize,
    Hash,
    MzReflect,
)]
pub struct ColumnOrder {
    /// The column index.
    pub column: usize,
    /// Whether to sort in descending order.
    ///
    /// Takes the `bool` default (`false`) when the field is missing during
    /// deserialization.
    #[serde(default)]
    pub desc: bool,
    /// Whether to sort nulls last.
    ///
    /// Takes the `bool` default (`false`) when the field is missing during
    /// deserialization.
    #[serde(default)]
    pub nulls_last: bool,
}
2591
// Makes `ColumnOrder` usable with `Columnation`-based containers.
// NOTE(review): presumably `CopyRegion` is appropriate because `ColumnOrder`
// derives `Copy` and owns no heap data; confirm against the
// `differential_dataflow::containers` documentation.
impl Columnation for ColumnOrder {
    type InnerRegion = CopyRegion<Self>;
}
2595
2596impl RustType<ProtoColumnOrder> for ColumnOrder {
2597    fn into_proto(&self) -> ProtoColumnOrder {
2598        ProtoColumnOrder {
2599            column: self.column.into_proto(),
2600            desc: self.desc,
2601            nulls_last: self.nulls_last,
2602        }
2603    }
2604
2605    fn from_proto(proto: ProtoColumnOrder) -> Result<Self, TryFromProtoError> {
2606        Ok(ColumnOrder {
2607            column: proto.column.into_rust()?,
2608            desc: proto.desc,
2609            nulls_last: proto.nulls_last,
2610        })
2611    }
2612}
2613
2614impl<'a, M> fmt::Display for HumanizedExpr<'a, ColumnOrder, M>
2615where
2616    M: HumanizerMode,
2617{
2618    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
2619        // If you modify this, then please also attend to Display for ColumnOrderWithExpr!
2620        write!(
2621            f,
2622            "{} {} {}",
2623            self.child(&self.expr.column),
2624            if self.expr.desc { "desc" } else { "asc" },
2625            if self.expr.nulls_last {
2626                "nulls_last"
2627            } else {
2628                "nulls_first"
2629            },
2630        )
2631    }
2632}
2633
/// Describes an aggregation expression.
///
/// The derived `Ord`/`PartialOrd` compare the fields in declaration order
/// (`func`, then `expr`, then `distinct`).
#[derive(
    Arbitrary, Clone, Debug, Eq, PartialEq, Ord, PartialOrd, Serialize, Deserialize, Hash, MzReflect,
)]
pub struct AggregateExpr {
    /// Names the aggregation function.
    pub func: AggregateFunc,
    /// An expression which extracts from each row the input to `func`.
    pub expr: MirScalarExpr,
    /// Should the aggregation be applied only to distinct results in each group.
    ///
    /// Takes the `bool` default (`false`) when the field is missing during
    /// deserialization.
    #[serde(default)]
    pub distinct: bool,
}
2647
2648impl RustType<ProtoAggregateExpr> for AggregateExpr {
2649    fn into_proto(&self) -> ProtoAggregateExpr {
2650        ProtoAggregateExpr {
2651            func: Some(self.func.into_proto()),
2652            expr: Some(self.expr.into_proto()),
2653            distinct: self.distinct,
2654        }
2655    }
2656
2657    fn from_proto(proto: ProtoAggregateExpr) -> Result<Self, TryFromProtoError> {
2658        Ok(Self {
2659            func: proto.func.into_rust_if_some("ProtoAggregateExpr::func")?,
2660            expr: proto.expr.into_rust_if_some("ProtoAggregateExpr::expr")?,
2661            distinct: proto.distinct,
2662        })
2663    }
2664}
2665
2666impl AggregateExpr {
2667    /// Computes the type of this `AggregateExpr`.
2668    pub fn typ(&self, column_types: &[ColumnType]) -> ColumnType {
2669        self.func.output_type(self.expr.typ(column_types))
2670    }
2671
2672    /// Returns whether the expression has a constant result.
2673    pub fn is_constant(&self) -> bool {
2674        match self.func {
2675            AggregateFunc::MaxInt16
2676            | AggregateFunc::MaxInt32
2677            | AggregateFunc::MaxInt64
2678            | AggregateFunc::MaxUInt16
2679            | AggregateFunc::MaxUInt32
2680            | AggregateFunc::MaxUInt64
2681            | AggregateFunc::MaxMzTimestamp
2682            | AggregateFunc::MaxFloat32
2683            | AggregateFunc::MaxFloat64
2684            | AggregateFunc::MaxBool
2685            | AggregateFunc::MaxString
2686            | AggregateFunc::MaxDate
2687            | AggregateFunc::MaxTimestamp
2688            | AggregateFunc::MaxTimestampTz
2689            | AggregateFunc::MinInt16
2690            | AggregateFunc::MinInt32
2691            | AggregateFunc::MinInt64
2692            | AggregateFunc::MinUInt16
2693            | AggregateFunc::MinUInt32
2694            | AggregateFunc::MinUInt64
2695            | AggregateFunc::MinMzTimestamp
2696            | AggregateFunc::MinFloat32
2697            | AggregateFunc::MinFloat64
2698            | AggregateFunc::MinBool
2699            | AggregateFunc::MinString
2700            | AggregateFunc::MinDate
2701            | AggregateFunc::MinTimestamp
2702            | AggregateFunc::MinTimestampTz
2703            | AggregateFunc::Any
2704            | AggregateFunc::All
2705            | AggregateFunc::Dummy => self.expr.is_literal(),
2706            AggregateFunc::Count => self.expr.is_literal_null(),
2707            _ => self.expr.is_literal_err(),
2708        }
2709    }
2710
2711    /// Returns an expression that computes `self` on a group that has exactly one row.
2712    /// Instead of performing a `Reduce` with `self`, one can perform a `Map` with the expression
2713    /// returned by `on_unique`, which is cheaper. (See `ReduceElision`.)
2714    pub fn on_unique(&self, input_type: &[ColumnType]) -> MirScalarExpr {
2715        match &self.func {
2716            // Count is one if non-null, and zero if null.
2717            AggregateFunc::Count => self
2718                .expr
2719                .clone()
2720                .call_unary(UnaryFunc::IsNull(crate::func::IsNull))
2721                .if_then_else(
2722                    MirScalarExpr::literal_ok(Datum::Int64(0), ScalarType::Int64),
2723                    MirScalarExpr::literal_ok(Datum::Int64(1), ScalarType::Int64),
2724                ),
2725
2726            // SumInt16 takes Int16s as input, but outputs Int64s.
2727            AggregateFunc::SumInt16 => self
2728                .expr
2729                .clone()
2730                .call_unary(UnaryFunc::CastInt16ToInt64(scalar_func::CastInt16ToInt64)),
2731
2732            // SumInt32 takes Int32s as input, but outputs Int64s.
2733            AggregateFunc::SumInt32 => self
2734                .expr
2735                .clone()
2736                .call_unary(UnaryFunc::CastInt32ToInt64(scalar_func::CastInt32ToInt64)),
2737
2738            // SumInt64 takes Int64s as input, but outputs numerics.
2739            AggregateFunc::SumInt64 => self.expr.clone().call_unary(UnaryFunc::CastInt64ToNumeric(
2740                scalar_func::CastInt64ToNumeric(Some(NumericMaxScale::ZERO)),
2741            )),
2742
2743            // SumUInt16 takes UInt16s as input, but outputs UInt64s.
2744            AggregateFunc::SumUInt16 => self.expr.clone().call_unary(
2745                UnaryFunc::CastUint16ToUint64(scalar_func::CastUint16ToUint64),
2746            ),
2747
2748            // SumUInt32 takes UInt32s as input, but outputs UInt64s.
2749            AggregateFunc::SumUInt32 => self.expr.clone().call_unary(
2750                UnaryFunc::CastUint32ToUint64(scalar_func::CastUint32ToUint64),
2751            ),
2752
2753            // SumUInt64 takes UInt64s as input, but outputs numerics.
2754            AggregateFunc::SumUInt64 => {
2755                self.expr.clone().call_unary(UnaryFunc::CastUint64ToNumeric(
2756                    scalar_func::CastUint64ToNumeric(Some(NumericMaxScale::ZERO)),
2757                ))
2758            }
2759
2760            // JsonbAgg takes _anything_ as input, but must output a Jsonb array.
2761            AggregateFunc::JsonbAgg { .. } => MirScalarExpr::CallVariadic {
2762                func: VariadicFunc::JsonbBuildArray,
2763                exprs: vec![
2764                    self.expr
2765                        .clone()
2766                        .call_unary(UnaryFunc::RecordGet(scalar_func::RecordGet(0))),
2767                ],
2768            },
2769
            // JsonbObjectAgg takes _anything_ as input, but must output a Jsonb object.
2771            AggregateFunc::JsonbObjectAgg { .. } => {
2772                let record = self
2773                    .expr
2774                    .clone()
2775                    .call_unary(UnaryFunc::RecordGet(scalar_func::RecordGet(0)));
2776                MirScalarExpr::CallVariadic {
2777                    func: VariadicFunc::JsonbBuildObject,
2778                    exprs: (0..2)
2779                        .map(|i| {
2780                            record
2781                                .clone()
2782                                .call_unary(UnaryFunc::RecordGet(scalar_func::RecordGet(i)))
2783                        })
2784                        .collect(),
2785                }
2786            }
2787
2788            AggregateFunc::MapAgg { value_type, .. } => {
2789                let record = self
2790                    .expr
2791                    .clone()
2792                    .call_unary(UnaryFunc::RecordGet(scalar_func::RecordGet(0)));
2793                MirScalarExpr::CallVariadic {
2794                    func: VariadicFunc::MapBuild {
2795                        value_type: value_type.clone(),
2796                    },
2797                    exprs: (0..2)
2798                        .map(|i| {
2799                            record
2800                                .clone()
2801                                .call_unary(UnaryFunc::RecordGet(scalar_func::RecordGet(i)))
2802                        })
2803                        .collect(),
2804                }
2805            }
2806
2807            // StringAgg takes nested records of strings and outputs a string
2808            AggregateFunc::StringAgg { .. } => self
2809                .expr
2810                .clone()
2811                .call_unary(UnaryFunc::RecordGet(scalar_func::RecordGet(0)))
2812                .call_unary(UnaryFunc::RecordGet(scalar_func::RecordGet(0))),
2813
2814            // ListConcat and ArrayConcat take a single level of records and output a list containing exactly 1 element
2815            AggregateFunc::ListConcat { .. } | AggregateFunc::ArrayConcat { .. } => self
2816                .expr
2817                .clone()
2818                .call_unary(UnaryFunc::RecordGet(scalar_func::RecordGet(0))),
2819
2820            // RowNumber, Rank, DenseRank take a list of records and output a list containing exactly 1 element
2821            AggregateFunc::RowNumber { .. } => {
2822                self.on_unique_ranking_window_funcs(input_type, "?row_number?")
2823            }
2824            AggregateFunc::Rank { .. } => self.on_unique_ranking_window_funcs(input_type, "?rank?"),
2825            AggregateFunc::DenseRank { .. } => {
2826                self.on_unique_ranking_window_funcs(input_type, "?dense_rank?")
2827            }
2828
2829            // The input type for LagLead is ((OriginalRow, (InputValue, Offset, Default)), OrderByExprs...)
2830            AggregateFunc::LagLead { lag_lead, .. } => {
2831                let tuple = self
2832                    .expr
2833                    .clone()
2834                    .call_unary(UnaryFunc::RecordGet(scalar_func::RecordGet(0)));
2835
2836                // Get the overall return type
2837                let return_type_with_orig_row = self
2838                    .typ(input_type)
2839                    .scalar_type
2840                    .unwrap_list_element_type()
2841                    .clone();
2842                let lag_lead_return_type =
2843                    return_type_with_orig_row.unwrap_record_element_type()[0].clone();
2844
2845                // Extract the original row
2846                let original_row = tuple
2847                    .clone()
2848                    .call_unary(UnaryFunc::RecordGet(scalar_func::RecordGet(0)));
2849
2850                // Extract the encoded args
2851                let encoded_args =
2852                    tuple.call_unary(UnaryFunc::RecordGet(scalar_func::RecordGet(1)));
2853
2854                let (result_expr, column_name) =
2855                    Self::on_unique_lag_lead(lag_lead, encoded_args, lag_lead_return_type.clone());
2856
2857                MirScalarExpr::CallVariadic {
2858                    func: VariadicFunc::ListCreate {
2859                        elem_type: return_type_with_orig_row,
2860                    },
2861                    exprs: vec![MirScalarExpr::CallVariadic {
2862                        func: VariadicFunc::RecordCreate {
2863                            field_names: vec![column_name, ColumnName::from("?record?")],
2864                        },
2865                        exprs: vec![result_expr, original_row],
2866                    }],
2867                }
2868            }
2869
2870            // The input type for FirstValue is ((OriginalRow, InputValue), OrderByExprs...)
2871            AggregateFunc::FirstValue { window_frame, .. } => {
2872                let tuple = self
2873                    .expr
2874                    .clone()
2875                    .call_unary(UnaryFunc::RecordGet(scalar_func::RecordGet(0)));
2876
2877                // Get the overall return type
2878                let return_type_with_orig_row = self
2879                    .typ(input_type)
2880                    .scalar_type
2881                    .unwrap_list_element_type()
2882                    .clone();
2883                let first_value_return_type =
2884                    return_type_with_orig_row.unwrap_record_element_type()[0].clone();
2885
2886                // Extract the original row
2887                let original_row = tuple
2888                    .clone()
2889                    .call_unary(UnaryFunc::RecordGet(scalar_func::RecordGet(0)));
2890
2891                // Extract the input value
2892                let arg = tuple.call_unary(UnaryFunc::RecordGet(scalar_func::RecordGet(1)));
2893
2894                let (result_expr, column_name) = Self::on_unique_first_value_last_value(
2895                    window_frame,
2896                    arg,
2897                    first_value_return_type,
2898                );
2899
2900                MirScalarExpr::CallVariadic {
2901                    func: VariadicFunc::ListCreate {
2902                        elem_type: return_type_with_orig_row,
2903                    },
2904                    exprs: vec![MirScalarExpr::CallVariadic {
2905                        func: VariadicFunc::RecordCreate {
2906                            field_names: vec![column_name, ColumnName::from("?record?")],
2907                        },
2908                        exprs: vec![result_expr, original_row],
2909                    }],
2910                }
2911            }
2912
2913            // The input type for LastValue is ((OriginalRow, InputValue), OrderByExprs...)
2914            AggregateFunc::LastValue { window_frame, .. } => {
2915                let tuple = self
2916                    .expr
2917                    .clone()
2918                    .call_unary(UnaryFunc::RecordGet(scalar_func::RecordGet(0)));
2919
2920                // Get the overall return type
2921                let return_type_with_orig_row = self
2922                    .typ(input_type)
2923                    .scalar_type
2924                    .unwrap_list_element_type()
2925                    .clone();
2926                let last_value_return_type =
2927                    return_type_with_orig_row.unwrap_record_element_type()[0].clone();
2928
2929                // Extract the original row
2930                let original_row = tuple
2931                    .clone()
2932                    .call_unary(UnaryFunc::RecordGet(scalar_func::RecordGet(0)));
2933
2934                // Extract the input value
2935                let arg = tuple.call_unary(UnaryFunc::RecordGet(scalar_func::RecordGet(1)));
2936
2937                let (result_expr, column_name) = Self::on_unique_first_value_last_value(
2938                    window_frame,
2939                    arg,
2940                    last_value_return_type,
2941                );
2942
2943                MirScalarExpr::CallVariadic {
2944                    func: VariadicFunc::ListCreate {
2945                        elem_type: return_type_with_orig_row,
2946                    },
2947                    exprs: vec![MirScalarExpr::CallVariadic {
2948                        func: VariadicFunc::RecordCreate {
2949                            field_names: vec![column_name, ColumnName::from("?record?")],
2950                        },
2951                        exprs: vec![result_expr, original_row],
2952                    }],
2953                }
2954            }
2955
2956            // The input type for window aggs is ((OriginalRow, InputValue), OrderByExprs...)
2957            // See an example MIR in `window_func_applied_to`.
2958            AggregateFunc::WindowAggregate {
2959                wrapped_aggregate,
2960                window_frame,
2961                order_by: _,
2962            } => {
2963                // TODO: deduplicate code between the various window function cases.
2964
2965                let tuple = self
2966                    .expr
2967                    .clone()
2968                    .call_unary(UnaryFunc::RecordGet(scalar_func::RecordGet(0)));
2969
2970                // Get the overall return type
2971                let return_type = self
2972                    .typ(input_type)
2973                    .scalar_type
2974                    .unwrap_list_element_type()
2975                    .clone();
2976                let window_agg_return_type = return_type.unwrap_record_element_type()[0].clone();
2977
2978                // Extract the original row
2979                let original_row = tuple
2980                    .clone()
2981                    .call_unary(UnaryFunc::RecordGet(scalar_func::RecordGet(0)));
2982
2983                // Extract the input value
2984                let arg_expr = tuple.call_unary(UnaryFunc::RecordGet(scalar_func::RecordGet(1)));
2985
2986                let (result, column_name) = Self::on_unique_window_agg(
2987                    window_frame,
2988                    arg_expr,
2989                    input_type,
2990                    window_agg_return_type,
2991                    wrapped_aggregate,
2992                );
2993
2994                MirScalarExpr::CallVariadic {
2995                    func: VariadicFunc::ListCreate {
2996                        elem_type: return_type,
2997                    },
2998                    exprs: vec![MirScalarExpr::CallVariadic {
2999                        func: VariadicFunc::RecordCreate {
3000                            field_names: vec![column_name, ColumnName::from("?record?")],
3001                        },
3002                        exprs: vec![result, original_row],
3003                    }],
3004                }
3005            }
3006
3007            // The input type is ((OriginalRow, (Arg1, Arg2, ...)), OrderByExprs...)
3008            AggregateFunc::FusedWindowAggregate {
3009                wrapped_aggregates,
3010                order_by: _,
3011                window_frame,
3012            } => {
3013                // Throw away OrderByExprs
3014                let tuple = self
3015                    .expr
3016                    .clone()
3017                    .call_unary(UnaryFunc::RecordGet(scalar_func::RecordGet(0)));
3018
3019                // Extract the original row
3020                let original_row = tuple
3021                    .clone()
3022                    .call_unary(UnaryFunc::RecordGet(scalar_func::RecordGet(0)));
3023
3024                // Extract the args of the fused call
3025                let all_args = tuple.call_unary(UnaryFunc::RecordGet(scalar_func::RecordGet(1)));
3026
3027                let return_type_with_orig_row = self
3028                    .typ(input_type)
3029                    .scalar_type
3030                    .unwrap_list_element_type()
3031                    .clone();
3032
3033                let all_func_return_types =
3034                    return_type_with_orig_row.unwrap_record_element_type()[0].clone();
3035                let mut func_result_exprs = Vec::new();
3036                let mut col_names = Vec::new();
3037                for (idx, wrapped_aggr) in wrapped_aggregates.iter().enumerate() {
3038                    let arg = all_args
3039                        .clone()
3040                        .call_unary(UnaryFunc::RecordGet(scalar_func::RecordGet(idx)));
3041                    let return_type =
3042                        all_func_return_types.unwrap_record_element_type()[idx].clone();
3043                    let (result, column_name) = Self::on_unique_window_agg(
3044                        window_frame,
3045                        arg,
3046                        input_type,
3047                        return_type,
3048                        wrapped_aggr,
3049                    );
3050                    func_result_exprs.push(result);
3051                    col_names.push(column_name);
3052                }
3053
3054                MirScalarExpr::CallVariadic {
3055                    func: VariadicFunc::ListCreate {
3056                        elem_type: return_type_with_orig_row,
3057                    },
3058                    exprs: vec![MirScalarExpr::CallVariadic {
3059                        func: VariadicFunc::RecordCreate {
3060                            field_names: vec![
3061                                ColumnName::from("?fused_window_aggr?"),
3062                                ColumnName::from("?record?"),
3063                            ],
3064                        },
3065                        exprs: vec![
3066                            MirScalarExpr::CallVariadic {
3067                                func: VariadicFunc::RecordCreate {
3068                                    field_names: col_names,
3069                                },
3070                                exprs: func_result_exprs,
3071                            },
3072                            original_row,
3073                        ],
3074                    }],
3075                }
3076            }
3077
3078            // The input type is ((OriginalRow, (Args1, Args2, ...)), OrderByExprs...)
3079            AggregateFunc::FusedValueWindowFunc {
3080                funcs,
3081                order_by: outer_order_by,
3082            } => {
3083                // Throw away OrderByExprs
3084                let tuple = self
3085                    .expr
3086                    .clone()
3087                    .call_unary(UnaryFunc::RecordGet(scalar_func::RecordGet(0)));
3088
3089                // Extract the original row
3090                let original_row = tuple
3091                    .clone()
3092                    .call_unary(UnaryFunc::RecordGet(scalar_func::RecordGet(0)));
3093
3094                // Extract the encoded args of the fused call
3095                let all_encoded_args =
3096                    tuple.call_unary(UnaryFunc::RecordGet(scalar_func::RecordGet(1)));
3097
3098                let return_type_with_orig_row = self
3099                    .typ(input_type)
3100                    .scalar_type
3101                    .unwrap_list_element_type()
3102                    .clone();
3103
3104                let all_func_return_types =
3105                    return_type_with_orig_row.unwrap_record_element_type()[0].clone();
3106                let mut func_result_exprs = Vec::new();
3107                let mut col_names = Vec::new();
3108                for (idx, func) in funcs.iter().enumerate() {
3109                    let args_for_func = all_encoded_args
3110                        .clone()
3111                        .call_unary(UnaryFunc::RecordGet(scalar_func::RecordGet(idx)));
3112                    let return_type_for_func =
3113                        all_func_return_types.unwrap_record_element_type()[idx].clone();
3114                    let (result, column_name) = match func {
3115                        AggregateFunc::LagLead {
3116                            lag_lead,
3117                            order_by,
3118                            ignore_nulls: _,
3119                        } => {
3120                            assert_eq!(order_by, outer_order_by);
3121                            Self::on_unique_lag_lead(lag_lead, args_for_func, return_type_for_func)
3122                        }
3123                        AggregateFunc::FirstValue {
3124                            window_frame,
3125                            order_by,
3126                        } => {
3127                            assert_eq!(order_by, outer_order_by);
3128                            Self::on_unique_first_value_last_value(
3129                                window_frame,
3130                                args_for_func,
3131                                return_type_for_func,
3132                            )
3133                        }
3134                        AggregateFunc::LastValue {
3135                            window_frame,
3136                            order_by,
3137                        } => {
3138                            assert_eq!(order_by, outer_order_by);
3139                            Self::on_unique_first_value_last_value(
3140                                window_frame,
3141                                args_for_func,
3142                                return_type_for_func,
3143                            )
3144                        }
3145                        _ => panic!("unknown function in FusedValueWindowFunc"),
3146                    };
3147                    func_result_exprs.push(result);
3148                    col_names.push(column_name);
3149                }
3150
3151                MirScalarExpr::CallVariadic {
3152                    func: VariadicFunc::ListCreate {
3153                        elem_type: return_type_with_orig_row,
3154                    },
3155                    exprs: vec![MirScalarExpr::CallVariadic {
3156                        func: VariadicFunc::RecordCreate {
3157                            field_names: vec![
3158                                ColumnName::from("?fused_value_window_func?"),
3159                                ColumnName::from("?record?"),
3160                            ],
3161                        },
3162                        exprs: vec![
3163                            MirScalarExpr::CallVariadic {
3164                                func: VariadicFunc::RecordCreate {
3165                                    field_names: col_names,
3166                                },
3167                                exprs: func_result_exprs,
3168                            },
3169                            original_row,
3170                        ],
3171                    }],
3172                }
3173            }
3174
3175            // All other variants should return the argument to the aggregation.
3176            AggregateFunc::MaxNumeric
3177            | AggregateFunc::MaxInt16
3178            | AggregateFunc::MaxInt32
3179            | AggregateFunc::MaxInt64
3180            | AggregateFunc::MaxUInt16
3181            | AggregateFunc::MaxUInt32
3182            | AggregateFunc::MaxUInt64
3183            | AggregateFunc::MaxMzTimestamp
3184            | AggregateFunc::MaxFloat32
3185            | AggregateFunc::MaxFloat64
3186            | AggregateFunc::MaxBool
3187            | AggregateFunc::MaxString
3188            | AggregateFunc::MaxDate
3189            | AggregateFunc::MaxTimestamp
3190            | AggregateFunc::MaxTimestampTz
3191            | AggregateFunc::MaxInterval
3192            | AggregateFunc::MaxTime
3193            | AggregateFunc::MinNumeric
3194            | AggregateFunc::MinInt16
3195            | AggregateFunc::MinInt32
3196            | AggregateFunc::MinInt64
3197            | AggregateFunc::MinUInt16
3198            | AggregateFunc::MinUInt32
3199            | AggregateFunc::MinUInt64
3200            | AggregateFunc::MinMzTimestamp
3201            | AggregateFunc::MinFloat32
3202            | AggregateFunc::MinFloat64
3203            | AggregateFunc::MinBool
3204            | AggregateFunc::MinString
3205            | AggregateFunc::MinDate
3206            | AggregateFunc::MinTimestamp
3207            | AggregateFunc::MinTimestampTz
3208            | AggregateFunc::MinInterval
3209            | AggregateFunc::MinTime
3210            | AggregateFunc::SumFloat32
3211            | AggregateFunc::SumFloat64
3212            | AggregateFunc::SumNumeric
3213            | AggregateFunc::Any
3214            | AggregateFunc::All
3215            | AggregateFunc::Dummy => self.expr.clone(),
3216        }
3217    }
3218
3219    /// `on_unique` for ROW_NUMBER, RANK, DENSE_RANK
3220    fn on_unique_ranking_window_funcs(
3221        &self,
3222        input_type: &[ColumnType],
3223        col_name: &str,
3224    ) -> MirScalarExpr {
3225        let list = self
3226            .expr
3227            .clone()
3228            // extract the list within the record
3229            .call_unary(UnaryFunc::RecordGet(scalar_func::RecordGet(0)));
3230
3231        // extract the expression within the list
3232        let record = MirScalarExpr::CallVariadic {
3233            func: VariadicFunc::ListIndex,
3234            exprs: vec![
3235                list,
3236                MirScalarExpr::literal_ok(Datum::Int64(1), ScalarType::Int64),
3237            ],
3238        };
3239
3240        MirScalarExpr::CallVariadic {
3241            func: VariadicFunc::ListCreate {
3242                elem_type: self
3243                    .typ(input_type)
3244                    .scalar_type
3245                    .unwrap_list_element_type()
3246                    .clone(),
3247            },
3248            exprs: vec![MirScalarExpr::CallVariadic {
3249                func: VariadicFunc::RecordCreate {
3250                    field_names: vec![ColumnName::from(col_name), ColumnName::from("?record?")],
3251                },
3252                exprs: vec![
3253                    MirScalarExpr::literal_ok(Datum::Int64(1), ScalarType::Int64),
3254                    record,
3255                ],
3256            }],
3257        }
3258    }
3259
3260    /// `on_unique` for `lag` and `lead`
3261    fn on_unique_lag_lead(
3262        lag_lead: &LagLeadType,
3263        encoded_args: MirScalarExpr,
3264        return_type: ScalarType,
3265    ) -> (MirScalarExpr, ColumnName) {
3266        let expr = encoded_args
3267            .clone()
3268            .call_unary(UnaryFunc::RecordGet(scalar_func::RecordGet(0)));
3269        let offset = encoded_args
3270            .clone()
3271            .call_unary(UnaryFunc::RecordGet(scalar_func::RecordGet(1)));
3272        let default_value =
3273            encoded_args.call_unary(UnaryFunc::RecordGet(scalar_func::RecordGet(2)));
3274
3275        // In this case, the window always has only one element, so if the offset is not null and
3276        // not zero, the default value should be returned instead.
3277        let value = offset
3278            .clone()
3279            .call_binary(
3280                MirScalarExpr::literal_ok(Datum::Int32(0), ScalarType::Int32),
3281                crate::BinaryFunc::Eq,
3282            )
3283            .if_then_else(expr, default_value);
3284        let result_expr = offset
3285            .call_unary(UnaryFunc::IsNull(crate::func::IsNull))
3286            .if_then_else(MirScalarExpr::literal_null(return_type), value);
3287
3288        let column_name = ColumnName::from(match lag_lead {
3289            LagLeadType::Lag => "?lag?",
3290            LagLeadType::Lead => "?lead?",
3291        });
3292
3293        (result_expr, column_name)
3294    }
3295
3296    /// `on_unique` for `first_value` and `last_value`
3297    fn on_unique_first_value_last_value(
3298        window_frame: &WindowFrame,
3299        arg: MirScalarExpr,
3300        return_type: ScalarType,
3301    ) -> (MirScalarExpr, ColumnName) {
3302        // If the window frame includes the current (single) row, return its value, null otherwise
3303        let result_expr = if window_frame.includes_current_row() {
3304            arg
3305        } else {
3306            MirScalarExpr::literal_null(return_type)
3307        };
3308        (result_expr, ColumnName::from("?first_value?"))
3309    }
3310
3311    /// `on_unique` for window aggregations
3312    fn on_unique_window_agg(
3313        window_frame: &WindowFrame,
3314        arg_expr: MirScalarExpr,
3315        input_type: &[ColumnType],
3316        return_type: ScalarType,
3317        wrapped_aggr: &AggregateFunc,
3318    ) -> (MirScalarExpr, ColumnName) {
3319        // If the window frame includes the current (single) row, evaluate the wrapped aggregate on
3320        // that row. Otherwise, return the default value for the aggregate.
3321        let result_expr = if window_frame.includes_current_row() {
3322            AggregateExpr {
3323                func: wrapped_aggr.clone(),
3324                expr: arg_expr,
3325                distinct: false, // We have just one input element; DISTINCT doesn't matter.
3326            }
3327            .on_unique(input_type)
3328        } else {
3329            MirScalarExpr::literal_ok(wrapped_aggr.default(), return_type)
3330        };
3331        (result_expr, ColumnName::from("?window_agg?"))
3332    }
3333
3334    /// Returns whether the expression is COUNT(*) or not.  Note that
3335    /// when we define the count builtin in sql::func, we convert
3336    /// COUNT(*) to COUNT(true), making it indistinguishable from
3337    /// literal COUNT(true), but we prefer to consider this as the
3338    /// former.
3339    ///
3340    /// (HIR has the same `is_count_asterisk`.)
3341    pub fn is_count_asterisk(&self) -> bool {
3342        self.func == AggregateFunc::Count && self.expr.is_literal_true() && !self.distinct
3343    }
3344}
3345
3346/// Describe a join implementation in dataflow.
3347#[derive(
3348    Clone, Debug, Eq, PartialEq, Ord, PartialOrd, Serialize, Deserialize, Hash, MzReflect, Arbitrary,
3349)]
3350pub enum JoinImplementation {
3351    /// Perform a sequence of binary differential dataflow joins.
3352    ///
3353    /// The first argument indicates
3354    /// 1) the index of the starting collection,
3355    /// 2) if it should be arranged, the keys to arrange it by, and
3356    /// 3) the characteristics of the starting collection (for EXPLAINing).
3357    /// The sequence that follows lists other relation indexes, and the key for
3358    /// the arrangement we should use when joining it in.
3359    /// The JoinInputCharacteristics are for EXPLAINing the characteristics that
3360    /// were used for join ordering.
3361    ///
3362    /// Each collection index should occur exactly once, either as the starting collection
3363    /// or somewhere in the list.
3364    Differential(
3365        (
3366            usize,
3367            Option<Vec<MirScalarExpr>>,
3368            Option<JoinInputCharacteristics>,
3369        ),
3370        Vec<(usize, Vec<MirScalarExpr>, Option<JoinInputCharacteristics>)>,
3371    ),
3372    /// Perform independent delta query dataflows for each input.
3373    ///
3374    /// The argument is a sequence of plans, for the input collections in order.
3375    /// Each plan starts from the corresponding index, and then in sequence joins
3376    /// against collections identified by index and with the specified arrangement key.
3377    /// The JoinInputCharacteristics are for EXPLAINing the characteristics that were
3378    /// used for join ordering.
3379    DeltaQuery(Vec<Vec<(usize, Vec<MirScalarExpr>, Option<JoinInputCharacteristics>)>>),
3380    /// Join a user-created index with a constant collection to speed up the evaluation of a
3381    /// predicate such as `(f1 = 3 AND f2 = 5) OR (f1 = 7 AND f2 = 9)`.
3382    /// This gets translated to a Differential join during MIR -> LIR lowering, but we still want
3383    /// to represent it in MIR, because the fast path detection wants to match on this.
3384    ///
3385    /// Consists of (`<coll_id>`, `<index_id>`, `<index_key>`, `<constants>`)
3386    IndexedFilter(
3387        GlobalId,
3388        GlobalId,
3389        Vec<MirScalarExpr>,
3390        #[mzreflect(ignore)] Vec<Row>,
3391    ),
3392    /// No implementation yet selected.
3393    Unimplemented,
3394}
3395
3396impl Default for JoinImplementation {
3397    fn default() -> Self {
3398        JoinImplementation::Unimplemented
3399    }
3400}
3401
3402impl JoinImplementation {
3403    /// Returns `true` iff the value is not [`JoinImplementation::Unimplemented`].
3404    pub fn is_implemented(&self) -> bool {
3405        match self {
3406            Self::Unimplemented => false,
3407            _ => true,
3408        }
3409    }
3410
3411    /// Returns an optional implementation name if the value is not [`JoinImplementation::Unimplemented`].
3412    pub fn name(&self) -> Option<&'static str> {
3413        match self {
3414            Self::Differential(..) => Some("differential"),
3415            Self::DeltaQuery(..) => Some("delta"),
3416            Self::IndexedFilter(..) => Some("indexed_filter"),
3417            Self::Unimplemented => None,
3418        }
3419    }
3420}
3421
3422/// Characteristics of a join order candidate collection.
3423///
3424/// A candidate is described by a collection and a key, and may have various liabilities.
3425/// Primarily, the candidate may risk substantial inflation of records, which is something
3426/// that concerns us greatly. Additionally, the candidate may be unarranged, and we would
3427/// prefer candidates that do not require additional memory. Finally, we prefer lower id
3428/// collections in the interest of consistent tie-breaking. For more characteristics, see
3429/// comments on individual fields.
3430///
3431/// This has more than one version. `new` instantiates the appropriate version based on a
3432/// feature flag.
3433#[derive(
3434    Eq, PartialEq, Ord, PartialOrd, Debug, Clone, Serialize, Deserialize, Hash, MzReflect, Arbitrary,
3435)]
3436pub enum JoinInputCharacteristics {
3437    /// Old version, with `enable_join_prioritize_arranged` turned off.
3438    V1(JoinInputCharacteristicsV1),
3439    /// Newer version, with `enable_join_prioritize_arranged` turned on.
3440    V2(JoinInputCharacteristicsV2),
3441}
3442
3443impl JoinInputCharacteristics {
3444    /// Creates a new instance with the given characteristics.
3445    pub fn new(
3446        unique_key: bool,
3447        key_length: usize,
3448        arranged: bool,
3449        cardinality: Option<usize>,
3450        filters: FilterCharacteristics,
3451        input: usize,
3452        enable_join_prioritize_arranged: bool,
3453    ) -> Self {
3454        if enable_join_prioritize_arranged {
3455            Self::V2(JoinInputCharacteristicsV2::new(
3456                unique_key,
3457                key_length,
3458                arranged,
3459                cardinality,
3460                filters,
3461                input,
3462            ))
3463        } else {
3464            Self::V1(JoinInputCharacteristicsV1::new(
3465                unique_key,
3466                key_length,
3467                arranged,
3468                cardinality,
3469                filters,
3470                input,
3471            ))
3472        }
3473    }
3474
3475    /// Turns the instance into a String to be printed in EXPLAIN.
3476    pub fn explain(&self) -> String {
3477        match self {
3478            Self::V1(jic) => jic.explain(),
3479            Self::V2(jic) => jic.explain(),
3480        }
3481    }
3482
3483    /// Whether the join input described by `self` is arranged.
3484    pub fn arranged(&self) -> bool {
3485        match self {
3486            Self::V1(jic) => jic.arranged,
3487            Self::V2(jic) => jic.arranged,
3488        }
3489    }
3490
3491    /// Returns the `FilterCharacteristics` for the join input described by `self`.
3492    pub fn filters(&mut self) -> &mut FilterCharacteristics {
3493        match self {
3494            Self::V1(jic) => &mut jic.filters,
3495            Self::V2(jic) => &mut jic.filters,
3496        }
3497    }
3498}
3499
3500/// Newer version of `JoinInputCharacteristics`, with `enable_join_prioritize_arranged` turned on.
3501#[derive(
3502    Eq, PartialEq, Ord, PartialOrd, Debug, Clone, Serialize, Deserialize, Hash, MzReflect, Arbitrary,
3503)]
3504pub struct JoinInputCharacteristicsV2 {
3505    /// An excellent indication that record count will not increase.
3506    pub unique_key: bool,
3507    /// Cross joins are bad.
3508    /// (`key_length > 0` also implies that it is not a cross join. However, we need to note cross
3509    /// joins in a separate field, because not being a cross join is more important than `arranged`,
3510    /// but otherwise `key_length` is less important than `arranged`.)
3511    pub not_cross: bool,
3512    /// Indicates that there will be no additional in-memory footprint.
3513    pub arranged: bool,
3514    /// A weaker signal that record count will not increase.
3515    pub key_length: usize,
3516    /// Estimated cardinality (lower is better)
3517    pub cardinality: Option<std::cmp::Reverse<usize>>,
3518    /// Characteristics of the filter that is applied at this input.
3519    pub filters: FilterCharacteristics,
3520    /// We want to prefer input earlier in the input list, for stability of ordering.
3521    pub input: std::cmp::Reverse<usize>,
3522}
3523
3524impl JoinInputCharacteristicsV2 {
3525    /// Creates a new instance with the given characteristics.
3526    pub fn new(
3527        unique_key: bool,
3528        key_length: usize,
3529        arranged: bool,
3530        cardinality: Option<usize>,
3531        filters: FilterCharacteristics,
3532        input: usize,
3533    ) -> Self {
3534        Self {
3535            unique_key,
3536            not_cross: key_length > 0,
3537            arranged,
3538            key_length,
3539            cardinality: cardinality.map(std::cmp::Reverse),
3540            filters,
3541            input: std::cmp::Reverse(input),
3542        }
3543    }
3544
3545    /// Turns the instance into a String to be printed in EXPLAIN.
3546    pub fn explain(&self) -> String {
3547        let mut e = "".to_owned();
3548        if self.unique_key {
3549            e.push_str("U");
3550        }
3551        // Don't need to print `not_cross`, because that is visible in the printed key.
3552        // if !self.not_cross {
3553        //     e.push_str("C");
3554        // }
3555        for _ in 0..self.key_length {
3556            e.push_str("K");
3557        }
3558        if self.arranged {
3559            e.push_str("A");
3560        }
3561        if let Some(std::cmp::Reverse(cardinality)) = self.cardinality {
3562            e.push_str(&format!("|{cardinality}|"));
3563        }
3564        e.push_str(&self.filters.explain());
3565        e
3566    }
3567}
3568
3569/// Old version of `JoinInputCharacteristics`, with `enable_join_prioritize_arranged` turned off.
3570#[derive(
3571    Eq, PartialEq, Ord, PartialOrd, Debug, Clone, Serialize, Deserialize, Hash, MzReflect, Arbitrary,
3572)]
3573pub struct JoinInputCharacteristicsV1 {
3574    /// An excellent indication that record count will not increase.
3575    pub unique_key: bool,
3576    /// A weaker signal that record count will not increase.
3577    pub key_length: usize,
3578    /// Indicates that there will be no additional in-memory footprint.
3579    pub arranged: bool,
3580    /// Estimated cardinality (lower is better)
3581    pub cardinality: Option<std::cmp::Reverse<usize>>,
3582    /// Characteristics of the filter that is applied at this input.
3583    pub filters: FilterCharacteristics,
3584    /// We want to prefer input earlier in the input list, for stability of ordering.
3585    pub input: std::cmp::Reverse<usize>,
3586}
3587
3588impl JoinInputCharacteristicsV1 {
3589    /// Creates a new instance with the given characteristics.
3590    pub fn new(
3591        unique_key: bool,
3592        key_length: usize,
3593        arranged: bool,
3594        cardinality: Option<usize>,
3595        filters: FilterCharacteristics,
3596        input: usize,
3597    ) -> Self {
3598        Self {
3599            unique_key,
3600            key_length,
3601            arranged,
3602            cardinality: cardinality.map(std::cmp::Reverse),
3603            filters,
3604            input: std::cmp::Reverse(input),
3605        }
3606    }
3607
3608    /// Turns the instance into a String to be printed in EXPLAIN.
3609    pub fn explain(&self) -> String {
3610        let mut e = "".to_owned();
3611        if self.unique_key {
3612            e.push_str("U");
3613        }
3614        for _ in 0..self.key_length {
3615            e.push_str("K");
3616        }
3617        if self.arranged {
3618            e.push_str("A");
3619        }
3620        if let Some(std::cmp::Reverse(cardinality)) = self.cardinality {
3621            e.push_str(&format!("|{cardinality}|"));
3622        }
3623        e.push_str(&self.filters.explain());
3624        e
3625    }
3626}
3627
3628/// Instructions for finishing the result of a query.
3629///
3630/// The primary reason for the existence of this structure and attendant code
3631/// is that SQL's ORDER BY requires sorting rows (as already implied by the
3632/// keywords), whereas much of the rest of SQL is defined in terms of unordered
3633/// multisets. But as it turns out, the same idea can be used to optimize
3634/// trivial peeks.
3635///
3636/// The generic parameters are for accommodating prepared statement parameters in
3637/// `limit` and `offset`: the planner can hold these fields as HirScalarExpr long enough to call
3638/// `bind_parameters` on them.
3639#[derive(Arbitrary, Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
3640pub struct RowSetFinishing<L = NonNeg<i64>, O = usize> {
3641    /// Order rows by the given columns.
3642    pub order_by: Vec<ColumnOrder>,
3643    /// Include only as many rows (after offset).
3644    pub limit: Option<L>,
3645    /// Omit as many rows.
3646    pub offset: O,
3647    /// Include only given columns.
3648    pub project: Vec<usize>,
3649}
3650
3651impl RustType<ProtoRowSetFinishing> for RowSetFinishing {
3652    fn into_proto(&self) -> ProtoRowSetFinishing {
3653        ProtoRowSetFinishing {
3654            order_by: self.order_by.into_proto(),
3655            limit: self.limit.into_proto(),
3656            offset: self.offset.into_proto(),
3657            project: self.project.into_proto(),
3658        }
3659    }
3660
3661    fn from_proto(x: ProtoRowSetFinishing) -> Result<Self, TryFromProtoError> {
3662        Ok(RowSetFinishing {
3663            order_by: x.order_by.into_rust()?,
3664            limit: x.limit.into_rust()?,
3665            offset: x.offset.into_rust()?,
3666            project: x.project.into_rust()?,
3667        })
3668    }
3669}
3670
3671impl<L> RowSetFinishing<L> {
3672    /// Returns a trivial finishing, i.e., that does nothing to the result set.
3673    pub fn trivial(arity: usize) -> RowSetFinishing<L> {
3674        RowSetFinishing {
3675            order_by: Vec::new(),
3676            limit: None,
3677            offset: 0,
3678            project: (0..arity).collect(),
3679        }
3680    }
3681    /// True if the finishing does nothing to any result set.
3682    pub fn is_trivial(&self, arity: usize) -> bool {
3683        self.limit.is_none()
3684            && self.order_by.is_empty()
3685            && self.offset == 0
3686            && self.project.iter().copied().eq(0..arity)
3687    }
3688    /// True if the finishing does not require an ORDER BY.
3689    ///
3690    /// LIMIT and OFFSET without an ORDER BY _are_ streamable: without an
3691    /// explicit ordering we will skip an arbitrary bag of elements and return
3692    /// the first arbitrary elements in the remaining bag. The result semantics
3693    /// are still correct but maybe surprising for some users.
3694    pub fn is_streamable(&self, arity: usize) -> bool {
3695        self.order_by.is_empty() && self.project.iter().copied().eq(0..arity)
3696    }
3697}
3698
3699impl RowSetFinishing<NonNeg<i64>, usize> {
3700    /// The number of rows needed from before the finishing to evaluate the finishing:
3701    /// offset + limit.
3702    ///
3703    /// If it returns None, then we need all the rows.
3704    pub fn num_rows_needed(&self) -> Option<usize> {
3705        self.limit
3706            .as_ref()
3707            .map(|l| usize::cast_from(u64::from(l.clone())) + self.offset)
3708    }
3709}
3710
3711impl RowSetFinishing {
3712    /// Applies finishing actions to a [`RowCollection`], and reports the total
3713    /// time it took to run.
3714    ///
3715    /// Returns a [`SortedRowCollectionIter`] that contains all of the response data, as
3716    /// well as the size of the response in bytes.
3717    pub fn finish(
3718        &self,
3719        rows: RowCollection,
3720        max_result_size: u64,
3721        max_returned_query_size: Option<u64>,
3722        duration_histogram: &Histogram,
3723    ) -> Result<(SortedRowCollectionIter, usize), String> {
3724        let now = Instant::now();
3725        let result = self.finish_inner(rows, max_result_size, max_returned_query_size);
3726        let duration = now.elapsed();
3727        duration_histogram.observe(duration.as_secs_f64());
3728
3729        result
3730    }
3731
3732    /// Implementation for [`RowSetFinishing::finish`].
3733    fn finish_inner(
3734        &self,
3735        rows: RowCollection,
3736        max_result_size: u64,
3737        max_returned_query_size: Option<u64>,
3738    ) -> Result<(SortedRowCollectionIter, usize), String> {
3739        // How much additional memory is required to make a sorted view.
3740        let sorted_view_mem = rows.entries().saturating_mul(std::mem::size_of::<usize>());
3741        let required_memory = rows.byte_len().saturating_add(sorted_view_mem);
3742
3743        // Bail if creating the sorted view would require us to use too much memory.
3744        if required_memory > usize::cast_from(max_result_size) {
3745            let max_bytes = ByteSize::b(max_result_size);
3746            return Err(format!("result exceeds max size of {max_bytes}",));
3747        }
3748
3749        let sorted_view = rows.sorted_view(&self.order_by);
3750        let mut iter = sorted_view
3751            .into_row_iter()
3752            .apply_offset(self.offset)
3753            .with_projection(self.project.clone());
3754
3755        if let Some(limit) = self.limit {
3756            let limit = u64::from(limit);
3757            let limit = usize::cast_from(limit);
3758            iter = iter.with_limit(limit);
3759        };
3760
3761        // TODO(parkmycar): Re-think how we can calculate the total response size without
3762        // having to iterate through the entire collection of Rows, while still
3763        // respecting the LIMIT, OFFSET, and projections.
3764        //
3765        // Note: It feels a bit bad always calculating the response size, but we almost
3766        // always need it to either check the `max_returned_query_size`, or for reporting
3767        // in the query history.
3768        let response_size: usize = iter.clone().map(|row| row.data().len()).sum();
3769
3770        // Bail if we would end up returning more data to the client than they can support.
3771        if let Some(max) = max_returned_query_size {
3772            if response_size > usize::cast_from(max) {
3773                let max_bytes = ByteSize::b(max);
3774                return Err(format!("result exceeds max size of {max_bytes}"));
3775            }
3776        }
3777
3778        Ok((iter, response_size))
3779    }
3780}
3781
/// A [RowSetFinishing] that can be repeatedly applied to batches of updates (in
/// a [RowCollection]) and keeps track of the remaining limit, offset, and cap
/// on query result size.
///
/// There is no `order_by` field: batches are finished in the order they arrive.
#[derive(Debug)]
pub struct RowSetFinishingIncremental {
    /// How many more rows may still be returned (counted after the offset has
    /// been applied), or `None` if there is no limit.
    pub remaining_limit: Option<usize>,
    /// How many more leading rows must still be skipped.
    pub remaining_offset: usize,
    /// The maximum allowed result size, as requested by the client.
    pub max_returned_query_size: Option<u64>,
    /// Tracks our remaining allowed budget for result size.
    pub remaining_max_returned_query_size: Option<u64>,
    /// Include only given columns.
    pub project: Vec<usize>,
}
3798
3799impl RowSetFinishingIncremental {
3800    /// Turns the given [RowSetFinishing] into a [RowSetFinishingIncremental].
3801    /// Can only be used when [is_streamable](RowSetFinishing::is_streamable) is
3802    /// `true`.
3803    ///
3804    /// # Panics
3805    ///
3806    /// Panics if the result is not streamable, that is it has an ORDER BY.
3807    pub fn new(
3808        offset: usize,
3809        limit: Option<NonNeg<i64>>,
3810        project: Vec<usize>,
3811        max_returned_query_size: Option<u64>,
3812    ) -> Self {
3813        let limit = limit.map(|l| {
3814            let l = u64::from(l);
3815            let l = usize::cast_from(l);
3816            l
3817        });
3818
3819        RowSetFinishingIncremental {
3820            remaining_limit: limit,
3821            remaining_offset: offset,
3822            max_returned_query_size,
3823            remaining_max_returned_query_size: max_returned_query_size,
3824            project,
3825        }
3826    }
3827
3828    /// Applies finishing actions to the given [`RowCollection`], and reports
3829    /// the total time it took to run.
3830    ///
3831    /// Returns a [`SortedRowCollectionIter`] that contains all of the response
3832    /// data.
3833    pub fn finish_incremental(
3834        &mut self,
3835        rows: RowCollection,
3836        max_result_size: u64,
3837        duration_histogram: &Histogram,
3838    ) -> Result<SortedRowCollectionIter, String> {
3839        let now = Instant::now();
3840        let result = self.finish_incremental_inner(rows, max_result_size);
3841        let duration = now.elapsed();
3842        duration_histogram.observe(duration.as_secs_f64());
3843
3844        result
3845    }
3846
3847    fn finish_incremental_inner(
3848        &mut self,
3849        rows: RowCollection,
3850        max_result_size: u64,
3851    ) -> Result<SortedRowCollectionIter, String> {
3852        // How much additional memory is required to make a sorted view.
3853        let sorted_view_mem = rows.entries().saturating_mul(std::mem::size_of::<usize>());
3854        let required_memory = rows.byte_len().saturating_add(sorted_view_mem);
3855
3856        // Bail if creating the sorted view would require us to use too much memory.
3857        if required_memory > usize::cast_from(max_result_size) {
3858            let max_bytes = ByteSize::b(max_result_size);
3859            return Err(format!("total result exceeds max size of {max_bytes}",));
3860        }
3861
3862        let batch_num_rows = rows.count(0, None);
3863
3864        let sorted_view = rows.sorted_view(&[]);
3865        let mut iter = sorted_view
3866            .into_row_iter()
3867            .apply_offset(self.remaining_offset)
3868            .with_projection(self.project.clone());
3869
3870        if let Some(limit) = self.remaining_limit {
3871            iter = iter.with_limit(limit);
3872        };
3873
3874        self.remaining_offset = self.remaining_offset.saturating_sub(batch_num_rows);
3875        if let Some(remaining_limit) = self.remaining_limit.as_mut() {
3876            *remaining_limit -= iter.count();
3877        }
3878
3879        // TODO(parkmycar): Re-think how we can calculate the total response size without
3880        // having to iterate through the entire collection of Rows, while still
3881        // respecting the LIMIT, OFFSET, and projections.
3882        //
3883        // Note: It feels a bit bad always calculating the response size, but we almost
3884        // always need it to either check the `max_returned_query_size`, or for reporting
3885        // in the query history.
3886        let response_size: usize = iter.clone().map(|row| row.data().len()).sum();
3887
3888        // Bail if we would end up returning more data to the client than they can support.
3889        if let Some(max) = self.remaining_max_returned_query_size {
3890            if response_size > usize::cast_from(max) {
3891                let max_bytes = ByteSize::b(self.max_returned_query_size.expect("known to exist"));
3892                return Err(format!("total result exceeds max size of {max_bytes}"));
3893            }
3894        }
3895
3896        Ok(iter)
3897    }
3898}
3899
3900/// Compare `left` and `right` using `order`. If that doesn't produce a strict
3901/// ordering, call `tiebreaker`.
3902pub fn compare_columns<F>(
3903    order: &[ColumnOrder],
3904    left: &[Datum],
3905    right: &[Datum],
3906    tiebreaker: F,
3907) -> Ordering
3908where
3909    F: Fn() -> Ordering,
3910{
3911    for order in order {
3912        let cmp = match (&left[order.column], &right[order.column]) {
3913            (Datum::Null, Datum::Null) => Ordering::Equal,
3914            (Datum::Null, _) => {
3915                if order.nulls_last {
3916                    Ordering::Greater
3917                } else {
3918                    Ordering::Less
3919                }
3920            }
3921            (_, Datum::Null) => {
3922                if order.nulls_last {
3923                    Ordering::Less
3924                } else {
3925                    Ordering::Greater
3926                }
3927            }
3928            (lval, rval) => {
3929                if order.desc {
3930                    rval.cmp(lval)
3931                } else {
3932                    lval.cmp(rval)
3933                }
3934            }
3935        };
3936        if cmp != Ordering::Equal {
3937            return cmp;
3938        }
3939    }
3940    tiebreaker()
3941}
3942
/// Describes a window frame, e.g. `RANGE UNBOUNDED PRECEDING` or
/// `ROWS BETWEEN 5 PRECEDING AND CURRENT ROW`.
///
/// A window frame defines a subset of the partition, and only a subset of
/// window functions make use of the window frame.
#[derive(
    Arbitrary, Debug, Clone, Eq, PartialEq, Ord, PartialOrd, Serialize, Deserialize, Hash, MzReflect,
)]
pub struct WindowFrame {
    /// ROWS, RANGE or GROUPS
    pub units: WindowFrameUnits,
    /// Where the frame starts
    pub start_bound: WindowFrameBound,
    /// Where the frame ends
    pub end_bound: WindowFrameBound,
}
3959
3960impl Display for WindowFrame {
3961    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
3962        write!(
3963            f,
3964            "{} between {} and {}",
3965            self.units, self.start_bound, self.end_bound
3966        )
3967    }
3968}
3969
3970impl WindowFrame {
3971    /// Return the default window frame used when one is not explicitly defined
3972    pub fn default() -> Self {
3973        WindowFrame {
3974            units: WindowFrameUnits::Range,
3975            start_bound: WindowFrameBound::UnboundedPreceding,
3976            end_bound: WindowFrameBound::CurrentRow,
3977        }
3978    }
3979
3980    fn includes_current_row(&self) -> bool {
3981        use WindowFrameBound::*;
3982        match self.start_bound {
3983            UnboundedPreceding => match self.end_bound {
3984                UnboundedPreceding => false,
3985                OffsetPreceding(0) => true,
3986                OffsetPreceding(_) => false,
3987                CurrentRow => true,
3988                OffsetFollowing(_) => true,
3989                UnboundedFollowing => true,
3990            },
3991            OffsetPreceding(0) => match self.end_bound {
3992                UnboundedPreceding => unreachable!(),
3993                OffsetPreceding(0) => true,
3994                // Any nonzero offsets here will create an empty window
3995                OffsetPreceding(_) => false,
3996                CurrentRow => true,
3997                OffsetFollowing(_) => true,
3998                UnboundedFollowing => true,
3999            },
4000            OffsetPreceding(_) => match self.end_bound {
4001                UnboundedPreceding => unreachable!(),
4002                // Window ends at the current row
4003                OffsetPreceding(0) => true,
4004                OffsetPreceding(_) => false,
4005                CurrentRow => true,
4006                OffsetFollowing(_) => true,
4007                UnboundedFollowing => true,
4008            },
4009            CurrentRow => true,
4010            OffsetFollowing(0) => match self.end_bound {
4011                UnboundedPreceding => unreachable!(),
4012                OffsetPreceding(_) => unreachable!(),
4013                CurrentRow => unreachable!(),
4014                OffsetFollowing(_) => true,
4015                UnboundedFollowing => true,
4016            },
4017            OffsetFollowing(_) => match self.end_bound {
4018                UnboundedPreceding => unreachable!(),
4019                OffsetPreceding(_) => unreachable!(),
4020                CurrentRow => unreachable!(),
4021                OffsetFollowing(_) => false,
4022                UnboundedFollowing => false,
4023            },
4024            UnboundedFollowing => false,
4025        }
4026    }
4027}
4028
4029impl RustType<ProtoWindowFrame> for WindowFrame {
4030    fn into_proto(&self) -> ProtoWindowFrame {
4031        ProtoWindowFrame {
4032            units: Some(self.units.into_proto()),
4033            start_bound: Some(self.start_bound.into_proto()),
4034            end_bound: Some(self.end_bound.into_proto()),
4035        }
4036    }
4037
4038    fn from_proto(proto: ProtoWindowFrame) -> Result<Self, TryFromProtoError> {
4039        Ok(WindowFrame {
4040            units: proto.units.into_rust_if_some("ProtoWindowFrame::units")?,
4041            start_bound: proto
4042                .start_bound
4043                .into_rust_if_some("ProtoWindowFrame::start_bound")?,
4044            end_bound: proto
4045                .end_bound
4046                .into_rust_if_some("ProtoWindowFrame::end_bound")?,
4047        })
4048    }
4049}
4050
/// Describes how the bounds of a [WindowFrame] are interpreted.
#[derive(
    Arbitrary, Debug, Clone, Eq, PartialEq, Ord, PartialOrd, Serialize, Deserialize, Hash, MzReflect,
)]
pub enum WindowFrameUnits {
    /// Each row is treated as the unit of work for bounds
    Rows,
    /// Each peer group is treated as the unit of work for bounds,
    /// and offset-based bounds use the value of the ORDER BY expression
    Range,
    /// Each peer group is treated as the unit of work for bounds.
    /// Groups is currently not supported, and it is rejected during planning.
    Groups,
}
4065
4066impl Display for WindowFrameUnits {
4067    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
4068        match self {
4069            WindowFrameUnits::Rows => write!(f, "rows"),
4070            WindowFrameUnits::Range => write!(f, "range"),
4071            WindowFrameUnits::Groups => write!(f, "groups"),
4072        }
4073    }
4074}
4075
4076impl RustType<proto_window_frame::ProtoWindowFrameUnits> for WindowFrameUnits {
4077    fn into_proto(&self) -> proto_window_frame::ProtoWindowFrameUnits {
4078        use proto_window_frame::proto_window_frame_units::Kind::*;
4079        proto_window_frame::ProtoWindowFrameUnits {
4080            kind: Some(match self {
4081                WindowFrameUnits::Rows => Rows(()),
4082                WindowFrameUnits::Range => Range(()),
4083                WindowFrameUnits::Groups => Groups(()),
4084            }),
4085        }
4086    }
4087
4088    fn from_proto(
4089        proto: proto_window_frame::ProtoWindowFrameUnits,
4090    ) -> Result<Self, TryFromProtoError> {
4091        use proto_window_frame::proto_window_frame_units::Kind::*;
4092        Ok(match proto.kind {
4093            Some(Rows(())) => WindowFrameUnits::Rows,
4094            Some(Range(())) => WindowFrameUnits::Range,
4095            Some(Groups(())) => WindowFrameUnits::Groups,
4096            None => {
4097                return Err(TryFromProtoError::missing_field(
4098                    "ProtoWindowFrameUnits::kind",
4099                ));
4100            }
4101        })
4102    }
4103}
4104
/// Specifies [WindowFrame]'s `start_bound` and `end_bound`
///
/// The order between frame bounds is significant, as Postgres enforces
/// some restrictions there. (The derived `PartialOrd`/`Ord` follow the
/// declaration order of the variants below.)
#[derive(
    Arbitrary, Debug, Clone, Serialize, Deserialize, PartialEq, Eq, Hash, MzReflect, PartialOrd, Ord,
)]
pub enum WindowFrameBound {
    /// `UNBOUNDED PRECEDING`
    UnboundedPreceding,
    /// `<N> PRECEDING`
    OffsetPreceding(u64),
    /// `CURRENT ROW`
    CurrentRow,
    /// `<N> FOLLOWING`
    OffsetFollowing(u64),
    /// `UNBOUNDED FOLLOWING`
    UnboundedFollowing,
}
4124
4125impl Display for WindowFrameBound {
4126    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
4127        match self {
4128            WindowFrameBound::UnboundedPreceding => write!(f, "unbounded preceding"),
4129            WindowFrameBound::OffsetPreceding(offset) => write!(f, "{} preceding", offset),
4130            WindowFrameBound::CurrentRow => write!(f, "current row"),
4131            WindowFrameBound::OffsetFollowing(offset) => write!(f, "{} following", offset),
4132            WindowFrameBound::UnboundedFollowing => write!(f, "unbounded following"),
4133        }
4134    }
4135}
4136
4137impl RustType<proto_window_frame::ProtoWindowFrameBound> for WindowFrameBound {
4138    fn into_proto(&self) -> proto_window_frame::ProtoWindowFrameBound {
4139        use proto_window_frame::proto_window_frame_bound::Kind::*;
4140        proto_window_frame::ProtoWindowFrameBound {
4141            kind: Some(match self {
4142                WindowFrameBound::UnboundedPreceding => UnboundedPreceding(()),
4143                WindowFrameBound::OffsetPreceding(offset) => OffsetPreceding(*offset),
4144                WindowFrameBound::CurrentRow => CurrentRow(()),
4145                WindowFrameBound::OffsetFollowing(offset) => OffsetFollowing(*offset),
4146                WindowFrameBound::UnboundedFollowing => UnboundedFollowing(()),
4147            }),
4148        }
4149    }
4150
4151    fn from_proto(x: proto_window_frame::ProtoWindowFrameBound) -> Result<Self, TryFromProtoError> {
4152        use proto_window_frame::proto_window_frame_bound::Kind::*;
4153        Ok(match x.kind {
4154            Some(UnboundedPreceding(())) => WindowFrameBound::UnboundedPreceding,
4155            Some(OffsetPreceding(offset)) => WindowFrameBound::OffsetPreceding(offset),
4156            Some(CurrentRow(())) => WindowFrameBound::CurrentRow,
4157            Some(OffsetFollowing(offset)) => WindowFrameBound::OffsetFollowing(offset),
4158            Some(UnboundedFollowing(())) => WindowFrameBound::UnboundedFollowing,
4159            None => {
4160                return Err(TryFromProtoError::missing_field(
4161                    "ProtoWindowFrameBound::kind",
4162                ));
4163            }
4164        })
4165    }
4166}
4167
/// Maximum iterations for a LetRec.
#[derive(
    Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash, Serialize, Deserialize, Arbitrary,
)]
pub struct LetRecLimit {
    /// Maximum number of iterations to evaluate.
    pub max_iters: NonZeroU64,
    /// Whether to return normally, rather than throw an error, when reaching
    /// the above limit.
    /// If true, we simply use the current contents of each Id as the final result.
    pub return_at_limit: bool,
}
4179
4180impl LetRecLimit {
4181    /// Compute the smallest limit from a Vec of `LetRecLimit`s.
4182    pub fn min_max_iter(limits: &Vec<Option<LetRecLimit>>) -> Option<u64> {
4183        limits
4184            .iter()
4185            .filter_map(|l| l.as_ref().map(|l| l.max_iters.get()))
4186            .min()
4187    }
4188
4189    /// The default value of `LetRecLimit::return_at_limit` when using the RECURSION LIMIT option of
4190    /// WMR without ERROR AT or RETURN AT.
4191    pub const RETURN_AT_LIMIT_DEFAULT: bool = false;
4192}
4193
4194impl Display for LetRecLimit {
4195    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
4196        write!(f, "[recursion_limit={}", self.max_iters)?;
4197        if self.return_at_limit != LetRecLimit::RETURN_AT_LIMIT_DEFAULT {
4198            write!(f, ", return_at_limit")?;
4199        }
4200        write!(f, "]")
4201    }
4202}
4203
/// For a global Get, this indicates whether we are going to read from Persist or from an index.
/// (See comment in MirRelationExpr::Get.)
#[derive(Clone, Debug, Eq, PartialEq, Ord, PartialOrd, Serialize, Deserialize, Hash, Arbitrary)]
pub enum AccessStrategy {
    /// It's either a local Get (a CTE), or unknown at the time.
    /// `prune_and_annotate_dataflow_index_imports` decides it for global Gets, and thus switches to
    /// one of the other variants.
    UnknownOrLocal,
    /// The Get will read from Persist.
    Persist,
    /// The Get will read from one or more indexes; each entry pairs an index id with a
    /// description of how that index will be used.
    Index(Vec<(GlobalId, IndexUsageType)>),
    /// The Get will read a collection that is computed by the same dataflow, but in a different
    /// `BuildDesc` in `objects_to_build`.
    SameDataflow,
}
4220
#[cfg(test)]
mod tests {
    use mz_ore::assert_ok;
    use mz_proto::protobuf_roundtrip;
    use mz_repr::explain::text::text_string_at;
    use proptest::prelude::*;

    use crate::explain::HumanizedExplain;

    use super::*;

    // Protobuf round-trip properties: encoding a value and decoding it again must
    // succeed and yield the original value.
    proptest! {
        #[mz_ore::test]
        fn column_order_protobuf_roundtrip(original in any::<ColumnOrder>()) {
            let decoded = protobuf_roundtrip::<_, ProtoColumnOrder>(&original);
            assert_ok!(decoded);
            assert_eq!(decoded.unwrap(), original);
        }

        #[mz_ore::test]
        #[cfg_attr(miri, ignore)] // error: unsupported operation: can't call foreign function `decContextDefault` on OS `linux`
        fn aggregate_expr_protobuf_roundtrip(original in any::<AggregateExpr>()) {
            let decoded = protobuf_roundtrip::<_, ProtoAggregateExpr>(&original);
            assert_ok!(decoded);
            assert_eq!(decoded.unwrap(), original);
        }

        #[mz_ore::test]
        fn window_frame_units_protobuf_roundtrip(original in any::<WindowFrameUnits>()) {
            let decoded =
                protobuf_roundtrip::<_, proto_window_frame::ProtoWindowFrameUnits>(&original);
            assert_ok!(decoded);
            assert_eq!(decoded.unwrap(), original);
        }

        #[mz_ore::test]
        fn window_frame_bound_protobuf_roundtrip(original in any::<WindowFrameBound>()) {
            let decoded =
                protobuf_roundtrip::<_, proto_window_frame::ProtoWindowFrameBound>(&original);
            assert_ok!(decoded);
            assert_eq!(decoded.unwrap(), original);
        }

        #[mz_ore::test]
        fn window_frame_protobuf_roundtrip(original in any::<WindowFrame>()) {
            let decoded = protobuf_roundtrip::<_, ProtoWindowFrame>(&original);
            assert_ok!(decoded);
            assert_eq!(decoded.unwrap(), original);
        }
    }

    /// Checks the single-line text rendering of a finishing with an ordering,
    /// a limit, and a non-identity projection.
    #[mz_ore::test]
    fn test_row_set_finishing_as_text() {
        let order_by = vec![ColumnOrder {
            column: 4,
            desc: true,
            nulls_last: true,
        }];
        let finishing = RowSetFinishing {
            order_by,
            limit: Some(NonNeg::try_from(7).unwrap()),
            offset: Default::default(),
            project: vec![1, 3, 4, 5],
        };

        let mode = HumanizedExplain::new(false);
        let expr = mode.expr(&finishing, None);

        let act = text_string_at(&expr, mz_ore::str::Indent::default);

        // One segment per attribute, followed by the trailing newline.
        let exp = concat!(
            "Finish",
            " order_by=[#4 desc nulls_last]",
            " limit=7",
            " output=[#1, #3..=#5]",
            "\n",
        )
        .to_string();

        assert_eq!(act, exp);
    }
}
4310
4311/// An iterator over AST structures, which calls out nodes in difference.
4312///
4313/// The iterators visit two ASTs in tandem, continuing as long as the AST node data matches,
4314/// and yielding an output pair as soon as the AST nodes do not match. Their intent is to call
4315/// attention to the moments in the ASTs where they differ, and incidentally a stack-free way
4316/// to compare two ASTs.
4317mod structured_diff {
4318
4319    use super::MirRelationExpr;
4320
4321    ///  An iterator over structured differences between two `MirRelationExpr` instances.
4322    pub struct MreDiff<'a> {
4323        /// Pairs of expressions that must still be compared.
4324        todo: Vec<(&'a MirRelationExpr, &'a MirRelationExpr)>,
4325    }
4326
4327    impl<'a> MreDiff<'a> {
4328        /// Create a new `MirRelationExpr` structured difference.
4329        pub fn new(expr1: &'a MirRelationExpr, expr2: &'a MirRelationExpr) -> Self {
4330            MreDiff {
4331                todo: vec![(expr1, expr2)],
4332            }
4333        }
4334    }
4335
4336    impl<'a> Iterator for MreDiff<'a> {
4337        // Pairs of expressions that do not match.
4338        type Item = (&'a MirRelationExpr, &'a MirRelationExpr);
4339
4340        fn next(&mut self) -> Option<Self::Item> {
4341            while let Some((expr1, expr2)) = self.todo.pop() {
4342                match (expr1, expr2) {
4343                    (
4344                        MirRelationExpr::Constant {
4345                            rows: rows1,
4346                            typ: typ1,
4347                        },
4348                        MirRelationExpr::Constant {
4349                            rows: rows2,
4350                            typ: typ2,
4351                        },
4352                    ) => {
4353                        if rows1 != rows2 || typ1 != typ2 {
4354                            return Some((expr1, expr2));
4355                        }
4356                    }
4357                    (
4358                        MirRelationExpr::Get {
4359                            id: id1,
4360                            typ: typ1,
4361                            access_strategy: as1,
4362                        },
4363                        MirRelationExpr::Get {
4364                            id: id2,
4365                            typ: typ2,
4366                            access_strategy: as2,
4367                        },
4368                    ) => {
4369                        if id1 != id2 || typ1 != typ2 || as1 != as2 {
4370                            return Some((expr1, expr2));
4371                        }
4372                    }
4373                    (
4374                        MirRelationExpr::Let {
4375                            id: id1,
4376                            body: body1,
4377                            value: value1,
4378                        },
4379                        MirRelationExpr::Let {
4380                            id: id2,
4381                            body: body2,
4382                            value: value2,
4383                        },
4384                    ) => {
4385                        if id1 != id2 {
4386                            return Some((expr1, expr2));
4387                        } else {
4388                            self.todo.push((body1, body2));
4389                            self.todo.push((value1, value2));
4390                        }
4391                    }
4392                    (
4393                        MirRelationExpr::LetRec {
4394                            ids: ids1,
4395                            body: body1,
4396                            values: values1,
4397                            limits: limits1,
4398                        },
4399                        MirRelationExpr::LetRec {
4400                            ids: ids2,
4401                            body: body2,
4402                            values: values2,
4403                            limits: limits2,
4404                        },
4405                    ) => {
4406                        if ids1 != ids2 || values1.len() != values2.len() || limits1 != limits2 {
4407                            return Some((expr1, expr2));
4408                        } else {
4409                            self.todo.push((body1, body2));
4410                            self.todo.extend(values1.iter().zip(values2.iter()));
4411                        }
4412                    }
4413                    (
4414                        MirRelationExpr::Project {
4415                            outputs: outputs1,
4416                            input: input1,
4417                        },
4418                        MirRelationExpr::Project {
4419                            outputs: outputs2,
4420                            input: input2,
4421                        },
4422                    ) => {
4423                        if outputs1 != outputs2 {
4424                            return Some((expr1, expr2));
4425                        } else {
4426                            self.todo.push((input1, input2));
4427                        }
4428                    }
4429                    (
4430                        MirRelationExpr::Map {
4431                            scalars: scalars1,
4432                            input: input1,
4433                        },
4434                        MirRelationExpr::Map {
4435                            scalars: scalars2,
4436                            input: input2,
4437                        },
4438                    ) => {
4439                        if scalars1 != scalars2 {
4440                            return Some((expr1, expr2));
4441                        } else {
4442                            self.todo.push((input1, input2));
4443                        }
4444                    }
4445                    (
4446                        MirRelationExpr::Filter {
4447                            predicates: predicates1,
4448                            input: input1,
4449                        },
4450                        MirRelationExpr::Filter {
4451                            predicates: predicates2,
4452                            input: input2,
4453                        },
4454                    ) => {
4455                        if predicates1 != predicates2 {
4456                            return Some((expr1, expr2));
4457                        } else {
4458                            self.todo.push((input1, input2));
4459                        }
4460                    }
4461                    (
4462                        MirRelationExpr::FlatMap {
4463                            input: input1,
4464                            func: func1,
4465                            exprs: exprs1,
4466                        },
4467                        MirRelationExpr::FlatMap {
4468                            input: input2,
4469                            func: func2,
4470                            exprs: exprs2,
4471                        },
4472                    ) => {
4473                        if func1 != func2 || exprs1 != exprs2 {
4474                            return Some((expr1, expr2));
4475                        } else {
4476                            self.todo.push((input1, input2));
4477                        }
4478                    }
4479                    (
4480                        MirRelationExpr::Join {
4481                            inputs: inputs1,
4482                            equivalences: eq1,
4483                            implementation: impl1,
4484                        },
4485                        MirRelationExpr::Join {
4486                            inputs: inputs2,
4487                            equivalences: eq2,
4488                            implementation: impl2,
4489                        },
4490                    ) => {
4491                        if inputs1.len() != inputs2.len() || eq1 != eq2 || impl1 != impl2 {
4492                            return Some((expr1, expr2));
4493                        } else {
4494                            self.todo.extend(inputs1.iter().zip(inputs2.iter()));
4495                        }
4496                    }
4497                    (
4498                        MirRelationExpr::Reduce {
4499                            aggregates: aggregates1,
4500                            input: inputs1,
4501                            group_key: gk1,
4502                            monotonic: m1,
4503                            expected_group_size: egs1,
4504                        },
4505                        MirRelationExpr::Reduce {
4506                            aggregates: aggregates2,
4507                            input: inputs2,
4508                            group_key: gk2,
4509                            monotonic: m2,
4510                            expected_group_size: egs2,
4511                        },
4512                    ) => {
4513                        if aggregates1 != aggregates2 || gk1 != gk2 || m1 != m2 || egs1 != egs2 {
4514                            return Some((expr1, expr2));
4515                        } else {
4516                            self.todo.push((inputs1, inputs2));
4517                        }
4518                    }
4519                    (
4520                        MirRelationExpr::TopK {
4521                            group_key: gk1,
4522                            order_key: order1,
4523                            input: input1,
4524                            limit: l1,
4525                            offset: o1,
4526                            monotonic: m1,
4527                            expected_group_size: egs1,
4528                        },
4529                        MirRelationExpr::TopK {
4530                            group_key: gk2,
4531                            order_key: order2,
4532                            input: input2,
4533                            limit: l2,
4534                            offset: o2,
4535                            monotonic: m2,
4536                            expected_group_size: egs2,
4537                        },
4538                    ) => {
4539                        if order1 != order2
4540                            || gk1 != gk2
4541                            || l1 != l2
4542                            || o1 != o2
4543                            || m1 != m2
4544                            || egs1 != egs2
4545                        {
4546                            return Some((expr1, expr2));
4547                        } else {
4548                            self.todo.push((input1, input2));
4549                        }
4550                    }
4551                    (
4552                        MirRelationExpr::Negate { input: input1 },
4553                        MirRelationExpr::Negate { input: input2 },
4554                    ) => {
4555                        self.todo.push((input1, input2));
4556                    }
4557                    (
4558                        MirRelationExpr::Threshold { input: input1 },
4559                        MirRelationExpr::Threshold { input: input2 },
4560                    ) => {
4561                        self.todo.push((input1, input2));
4562                    }
4563                    (
4564                        MirRelationExpr::Union {
4565                            base: base1,
4566                            inputs: inputs1,
4567                        },
4568                        MirRelationExpr::Union {
4569                            base: base2,
4570                            inputs: inputs2,
4571                        },
4572                    ) => {
4573                        if inputs1.len() != inputs2.len() {
4574                            return Some((expr1, expr2));
4575                        } else {
4576                            self.todo.push((base1, base2));
4577                            self.todo.extend(inputs1.iter().zip(inputs2.iter()));
4578                        }
4579                    }
4580                    (
4581                        MirRelationExpr::ArrangeBy {
4582                            keys: keys1,
4583                            input: input1,
4584                        },
4585                        MirRelationExpr::ArrangeBy {
4586                            keys: keys2,
4587                            input: input2,
4588                        },
4589                    ) => {
4590                        if keys1 != keys2 {
4591                            return Some((expr1, expr2));
4592                        } else {
4593                            self.todo.push((input1, input2));
4594                        }
4595                    }
4596                    _ => {
4597                        return Some((expr1, expr2));
4598                    }
4599                }
4600            }
4601            None
4602        }
4603    }
4604}